diff --git a/internal/user/sync.go b/internal/user/sync.go
index c0ef6403..cac90f7c 100644
--- a/internal/user/sync.go
+++ b/internal/user/sync.go
@@ -19,7 +19,6 @@ package user
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"strings"
 	"time"
@@ -30,7 +29,7 @@ import (
 	"github.com/ProtonMail/proton-bridge/v2/internal/events"
 	"github.com/ProtonMail/proton-bridge/v2/internal/safe"
 	"github.com/ProtonMail/proton-bridge/v2/internal/vault"
-	"github.com/bradenaw/juniper/stream"
+	"github.com/bradenaw/juniper/parallel"
 	"github.com/bradenaw/juniper/xslices"
 	"github.com/google/uuid"
 	"gitlab.protontech.ch/go/liteapi"
@@ -39,7 +38,7 @@ import (
 
 const (
 	maxUpdateSize = 1 << 27 // 128 MiB
-	maxBatchSize  = 1 << 7  // 128
+	maxBatchSize  = 1 << 8  // 256
 )
 
 // doSync begins syncing the users data.
@@ -188,28 +187,22 @@ func syncMessages( //nolint:funlen
 	addrKRs map[string]*crypto.KeyRing,
 	updateCh map[string]*queue.QueuedChannel[imap.Update],
 	eventCh *queue.QueuedChannel[events.Event],
-	syncWorkers, syncBuffer int,
+	syncWorkers, _ int,
 ) error {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
 	// Determine which messages to sync.
 	messageIDs, err := client.GetMessageIDs(ctx, vault.SyncStatus().LastMessageID)
 	if err != nil {
 		return fmt.Errorf("failed to get message IDs to sync: %w", err)
 	}
 
-	// Fetch and build each message.
-	buildCh := stream.Map(
-		client.GetFullMessages(ctx, syncWorkers, syncBuffer, messageIDs...),
-		func(_ context.Context, full liteapi.FullMessage) (*buildRes, error) {
-			return buildRFC822(apiLabels, full, addrKRs[full.AddressID])
-		},
-	)
-
 	// Create the flushers, one per update channel.
-	flushers := make(map[string]*flusher)
+	flushers := make(map[string]*flusher, len(updateCh))
 
 	for addrID, updateCh := range updateCh {
 		flusher := newFlusher(updateCh, maxUpdateSize)
-		defer flusher.flush(ctx, true)
 
 		flushers[addrID] = flusher
 	}
@@ -218,24 +211,101 @@ func syncMessages( //nolint:funlen
 	reporter := newReporter(userID, eventCh, len(messageIDs), time.Second)
 	defer reporter.done()
 
-	// Send each update to the appropriate flusher.
-	return forEach(ctx, stream.Chunk(buildCh, maxBatchSize), func(batch []*buildRes) error {
-		for _, res := range batch {
-			flushers[res.addressID].push(ctx, res.update)
+	type flushUpdate struct {
+		messageID string
+		noOps     []*imap.Noop
+		batchLen  int
+	}
+
+	// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
+	// to the update flushing goroutine.
+	flushCh := make(chan []*buildRes, 2)
+
+	// Allow up to 4 batched wait requests.
+	flushUpdateCh := make(chan flushUpdate, 4)
+
+	errorCh := make(chan error, syncWorkers*2)
+
+	// Goroutine in charge of downloading and building messages in maxBatchSize batches.
+	go func() {
+		defer close(flushCh)
+		defer close(errorCh)
+
+		for _, batch := range xslices.Chunk(messageIDs, maxBatchSize) {
+			if ctx.Err() != nil {
+				errorCh <- ctx.Err()
+				return
+			}
+
+			result, err := parallel.MapContext(ctx, int(float32(syncWorkers)*1.5), batch, func(ctx context.Context, id string) (*buildRes, error) {
+				msg, err := client.GetFullMessage(ctx, id)
+				if err != nil {
+					return nil, err
+				}
+
+				if ctx.Err() != nil {
+					return nil, ctx.Err()
+				}
+
+				return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID])
+			})
+
+			if err != nil {
+				errorCh <- err
+				return
+			}
+
+			if ctx.Err() != nil {
+				errorCh <- ctx.Err()
+				return
+			}
+
+			flushCh <- result
+		}
+	}()
+
+	// Goroutine in charge of converting the messages into updates and building a waitable structure for progress
+	// tracking.
+	go func() {
+		defer close(flushUpdateCh)
+		for batch := range flushCh {
+			for _, res := range batch {
+				flushers[res.addressID].push(res.update)
+			}
+
+			for _, flusher := range flushers {
+				flusher.flush()
+			}
+
+			noopUpdates := make([]*imap.Noop, len(updateCh))
+			index := 0
+			for _, updateCh := range updateCh {
+				noopUpdates[index] = imap.NewNoop()
+				updateCh.Enqueue(noopUpdates[index])
+				index++
+			}
+
+			flushUpdateCh <- flushUpdate{
+				messageID: batch[len(batch)-1].messageID,
+				noOps:     noopUpdates,
+				batchLen:  len(batch),
+			}
+		}
+	}()
+
+	for flushUpdate := range flushUpdateCh {
+		for _, up := range flushUpdate.noOps {
+			up.WaitContext(ctx)
 		}
 
-		for _, flusher := range flushers {
-			flusher.flush(ctx, true)
-		}
-
-		if err := vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
+		if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
 			return fmt.Errorf("failed to set last synced message ID: %w", err)
 		}
 
-		reporter.add(len(batch))
+		reporter.add(flushUpdate.batchLen)
+	}
 
-		return nil
-	})
+	return <-errorCh
 }
 
 func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated {
@@ -338,20 +408,3 @@ func wantLabels(apiLabels map[string]liteapi.Label, labelIDs []string) []string
 		return wantLabel(apiLabels[labelID])
 	})
 }
-
-func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error {
-	defer streamer.Close()
-
-	for {
-		res, err := streamer.Next(ctx)
-		if errors.Is(err, stream.End) {
-			return nil
-		} else if err != nil {
-			return fmt.Errorf("failed to get next stream item: %w", err)
-		}
-
-		if err := fn(res); err != nil {
-			return fmt.Errorf("failed to process stream item: %w", err)
-		}
-	}
-}
diff --git a/internal/user/sync_flusher.go b/internal/user/sync_flusher.go
index d4b9e8e5..4f7877cb 100644
--- a/internal/user/sync_flusher.go
+++ b/internal/user/sync_flusher.go
@@ -18,9 +18,6 @@
 package user
 
 import (
-	"context"
-	"sync"
-
 	"github.com/ProtonMail/gluon/imap"
 	"github.com/ProtonMail/gluon/queue"
 )
@@ -31,8 +28,6 @@ type flusher struct {
 
 	maxUpdateSize int
 	curChunkSize  int
-
-	pushLock sync.Mutex
 }
 
 func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *flusher {
@@ -42,28 +37,18 @@ func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *
 	}
 }
 
-func (f *flusher) push(ctx context.Context, update *imap.MessageCreated) {
-	f.pushLock.Lock()
-	defer f.pushLock.Unlock()
-
+func (f *flusher) push(update *imap.MessageCreated) {
 	f.updates = append(f.updates, update)
 
 	if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxUpdateSize {
-		f.flush(ctx, false)
+		f.flush()
 	}
 }
 
-func (f *flusher) flush(ctx context.Context, wait bool) {
+func (f *flusher) flush() {
 	if len(f.updates) > 0 {
 		f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
 		f.updates = nil
 		f.curChunkSize = 0
 	}
-
-	if wait {
-		update := imap.NewNoop()
-		defer update.WaitContext(ctx)
-
-		f.updateCh.Enqueue(update)
-	}
 }
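For readers less familiar with this pattern, the sketch below shows, reduced to the standard library, the producer/consumer shape the patch moves to: one goroutine fills a small bounded channel with built batches, the caller drains it and records progress per batch, and a buffered error channel returns the producer's outcome at the end. All names here (syncPipeline, update, batchCh, errCh) are hypothetical illustrations, not Bridge code; the real change downloads with parallel.MapContext, pushes results into per-address flushers, and waits on imap.Noop acknowledgements before persisting the last synced message ID.

package main

import (
	"context"
	"fmt"
)

// update stands in for the patch's buildRes; the real code carries an IMAP
// update plus the message and address IDs.
type update struct{ id string }

// syncPipeline sketches the shape of the patched syncMessages: a producer
// goroutine builds batches into a bounded channel, the caller drains them in
// order, and a buffered error channel carries the producer's failure (if any).
func syncPipeline(ctx context.Context, ids []string, batchSize int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	batchCh := make(chan []update, 2) // small buffer lets the producer run ahead of the consumer
	errCh := make(chan error, 1)      // producer reports at most one error

	// Producer: build each batch, bailing out if the context is cancelled.
	go func() {
		defer close(batchCh)
		defer close(errCh)

		for start := 0; start < len(ids); start += batchSize {
			if ctx.Err() != nil {
				errCh <- ctx.Err()
				return
			}

			end := start + batchSize
			if end > len(ids) {
				end = len(ids)
			}

			batch := make([]update, 0, end-start)
			for _, id := range ids[start:end] {
				batch = append(batch, update{id: id}) // the real code downloads and builds RFC822 here
			}

			batchCh <- batch
		}
	}()

	// Consumer: apply batches in order; the patch persists the last synced
	// message ID at this point, once the batch's Noop updates have been waited on.
	for batch := range batchCh {
		for _, up := range batch {
			fmt.Println("applied", up.id)
		}
	}

	// Reading from the closed, empty channel yields nil, i.e. success.
	return <-errCh
}

func main() {
	if err := syncPipeline(context.Background(), []string{"a", "b", "c", "d", "e"}, 2); err != nil {
		fmt.Println("sync failed:", err)
	}
}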