Files
proton-bridge/internal/user/sync.go
2022-11-16 12:26:09 +01:00

248 lines
6.8 KiB
Go

package user
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
const (
	// maxUpdateSize is the size threshold given to every flusher created in
	// syncMessages: 32 << 20 (== 1 << 25). Presumably bytes — confirm against
	// newFlusher.
	maxUpdateSize = 32 << 20

	// maxBatchSize is how many built messages syncMessages groups together
	// (via stream.Chunk) before flushing and checkpointing progress.
	maxBatchSize = 256
)
// sync performs the user's initial synchronisation in two phases: labels, then
// messages. Each phase is skipped when the vault's sync status already records
// it as complete, and is marked complete in the vault once it succeeds.
// The address keyrings obtained from withAddrKRs are used to build messages.
func (user *User) sync(ctx context.Context) error {
	return user.withAddrKRs(func(addrKRs map[string]*crypto.KeyRing) error {
		logrus.Info("Beginning sync")

		if !user.vault.SyncStatus().HasLabels {
			logrus.Info("Syncing labels")

			// Deduplicate the channels so each distinct channel receives the
			// label updates only once (presumably several addresses can share
			// one channel — confirm).
			if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
				return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
			}); err != nil {
				return fmt.Errorf("failed to sync labels: %w", err)
			}

			if err := user.vault.SetHasLabels(true); err != nil {
				return fmt.Errorf("failed to set has labels: %w", err)
			}
		} else {
			logrus.Info("Labels are already synced, skipping")
		}

		if !user.vault.SyncStatus().HasMessages {
			// Previously logged "Syncing labels" here by mistake.
			logrus.Info("Syncing messages")

			if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
				return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
			}); err != nil {
				return fmt.Errorf("failed to sync messages: %w", err)
			}

			if err := user.vault.SetHasMessages(true); err != nil {
				return fmt.Errorf("failed to set has messages: %w", err)
			}
		} else {
			logrus.Info("Messages are already synced, skipping")
		}

		return nil
	})
}
// syncLabels enqueues mailbox-creation updates on every given channel and,
// on a nil return, waits until each channel has processed them.
//
// Updates are sent in this order:
//  1. the system labels accepted by wantLabelID;
//  2. one \Noselect placeholder per prefix (folderPrefix, labelPrefix);
//  3. the user's folders;
//  4. the user's labels.
//
// A fresh update value is built for every channel, so each channel's
// placeholder mailboxes get their own random ID.
func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error {
	// broadcast builds one update per channel and enqueues it there.
	broadcast := func(build func() imap.Update) {
		for _, ch := range updateCh {
			ch.Enqueue(build())
		}
	}

	system, err := client.GetLabels(ctx, liteapi.LabelTypeSystem)
	if err != nil {
		return fmt.Errorf("failed to get system labels: %w", err)
	}

	for _, label := range system {
		if !wantLabelID(label.ID) {
			continue
		}

		broadcast(func() imap.Update {
			return newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name)
		})
	}

	// Parent mailboxes for folders and labels: random ID, \Noselect attribute.
	for _, prefix := range []string{folderPrefix, labelPrefix} {
		broadcast(func() imap.Update {
			return newPlaceHolderMailboxCreatedUpdate(prefix)
		})
	}

	folders, err := client.GetLabels(ctx, liteapi.LabelTypeFolder)
	if err != nil {
		return fmt.Errorf("failed to get folders: %w", err)
	}

	for _, folder := range folders {
		broadcast(func() imap.Update {
			return newMailboxCreatedUpdate(imap.LabelID(folder.ID), getMailboxName(folder))
		})
	}

	labels, err := client.GetLabels(ctx, liteapi.LabelTypeLabel)
	if err != nil {
		return fmt.Errorf("failed to get labels: %w", err)
	}

	for _, label := range labels {
		broadcast(func() imap.Update {
			return newMailboxCreatedUpdate(imap.LabelID(label.ID), getMailboxName(label))
		})
	}

	// Queue a no-op behind the mailbox updates on each channel. The deferred
	// waits run when we return. They block until each no-op has been handled
	// (or ctx ends), so the label updates queued before it have been applied.
	for _, ch := range updateCh {
		noop := imap.NewNoop()
		defer noop.WaitContext(ctx)
		ch.Enqueue(noop)
	}

	return nil
}
// syncMessages downloads the user's messages, builds each one with
// buildRFC822, and pushes the resulting update to the flusher of the address
// the message belongs to.
//
// Work is done in batches of maxBatchSize. After each batch, all flushers are
// flushed and the batch's last message ID is stored in the vault. A later call
// then resumes from the message following that ID (if it is still present).
// Progress is reported on eventCh through a reporter.
//
// It returns an error if the metadata listing, message building, flushing
// checkpoint, or address-to-channel lookup fails.
func syncMessages(
	ctx context.Context,
	userID string,
	client *liteapi.Client,
	vault *vault.User,
	addrKRs map[string]*crypto.KeyRing,
	updateCh map[string]*queue.QueuedChannel[imap.Update],
	eventCh *queue.QueuedChannel[events.Event],
) error {
	// Determine which messages to sync.
	metadata, err := client.GetAllMessageMetadata(ctx, nil)
	if err != nil {
		return fmt.Errorf("get all message metadata: %w", err)
	}

	messageIDs := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
		return metadata.ID
	})

	// If possible, begin syncing from one beyond the last synced message.
	// NOTE(review): relies on GetAllMessageMetadata returning messages in a
	// stable order between calls — confirm.
	if idx := xslices.Index(messageIDs, vault.SyncStatus().LastMessageID); idx >= 0 {
		messageIDs = messageIDs[idx+1:]
	}

	// Fetch and build each message using the keyring of its address.
	buildCh := stream.Map(
		client.GetFullMessages(ctx, messageIDs...),
		func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
			return buildRFC822(ctx, full, addrKRs[full.AddressID])
		},
	)
	defer buildCh.Close()

	// One flusher per address. Each is flushed once more on return, including
	// on the error paths.
	flushers := make(map[string]*flusher, len(updateCh))
	for addrID, ch := range updateCh {
		f := newFlusher(ch, maxUpdateSize)
		defer f.flush(ctx, true)
		flushers[addrID] = f
	}

	// Report sync progress at most about once per second.
	reporter := newReporter(userID, eventCh, len(messageIDs), time.Second)
	defer reporter.done()

	return forEach(ctx, stream.Chunk(buildCh, maxBatchSize), func(batch []*buildRes) error {
		if len(batch) == 0 {
			return nil
		}

		// Check the whole batch before pushing any of it. A message whose
		// address has no flusher would otherwise call push on a nil *flusher
		// (a panic), possibly after part of the batch was already pushed.
		for _, res := range batch {
			if _, ok := flushers[res.addressID]; !ok {
				return fmt.Errorf("no update channel for address %q of message %q", res.addressID, res.messageID)
			}
		}

		for _, res := range batch {
			flushers[res.addressID].push(ctx, res.update)
		}

		for _, f := range flushers {
			f.flush(ctx, true)
		}

		// Checkpoint only after the batch has been flushed.
		if err := vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
			return fmt.Errorf("failed to set last synced message ID: %w", err)
		}

		reporter.add(len(batch))

		return nil
	})
}
// newSystemMailboxCreatedUpdate returns a MailboxCreated update for a system
// label. Any name equal to imap.Inbox ignoring case is normalised to exactly
// imap.Inbox, because IMAP treats INBOX case-insensitively. The mailbox
// carries the default flags and the \Noinferiors attribute.
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
	name := labelName
	if strings.EqualFold(name, imap.Inbox) {
		name = imap.Inbox
	}

	mailbox := imap.Mailbox{
		ID:             labelID,
		Name:           []string{name},
		Flags:          defaultFlags,
		PermanentFlags: defaultPermanentFlags,
		Attributes:     imap.NewFlagSet(imap.AttrNoInferiors),
	}

	return imap.NewMailboxCreated(mailbox)
}
// newPlaceHolderMailboxCreatedUpdate returns a MailboxCreated update for a
// top-level placeholder mailbox named labelName. Every call generates a new
// random (UUID) ID. The mailbox has the \Noselect attribute, so it cannot be
// selected itself.
func newPlaceHolderMailboxCreatedUpdate(labelName string) *imap.MailboxCreated {
	placeholder := imap.Mailbox{
		ID:             imap.LabelID(uuid.NewString()),
		Name:           []string{labelName},
		Flags:          defaultFlags,
		PermanentFlags: defaultPermanentFlags,
		Attributes:     imap.NewFlagSet(imap.AttrNoSelect),
	}

	return imap.NewMailboxCreated(placeholder)
}
// newMailboxCreatedUpdate returns a MailboxCreated update for a regular
// folder or label. labelName is the hierarchical path (one element per
// level). The mailbox uses the default flags and has no attributes.
func newMailboxCreatedUpdate(labelID imap.LabelID, labelName []string) *imap.MailboxCreated {
	mailbox := imap.Mailbox{
		ID:             labelID,
		Name:           labelName,
		Flags:          defaultFlags,
		PermanentFlags: defaultPermanentFlags,
		Attributes:     imap.NewFlagSet(),
	}

	return imap.NewMailboxCreated(mailbox)
}
// wantLabelID reports whether a system label should be created as a mailbox.
// AllDrafts, AllSent and Outbox are excluded; every other label ID is wanted.
func wantLabelID(labelID string) bool {
	excluded := labelID == liteapi.AllDraftsLabel ||
		labelID == liteapi.AllSentLabel ||
		labelID == liteapi.OutboxLabel

	return !excluded
}
// forEach drains streamer, calling fn on each item in order.
//
// It returns nil once the stream reports stream.End. It stops at the first
// error from either the stream or fn and returns it wrapped with context.
func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error {
	for {
		item, nextErr := streamer.Next(ctx)

		switch {
		case errors.Is(nextErr, stream.End):
			return nil

		case nextErr != nil:
			return fmt.Errorf("failed to get next stream item: %w", nextErr)
		}

		if fnErr := fn(item); fnErr != nil {
			return fmt.Errorf("failed to process stream item: %w", fnErr)
		}
	}
}