From f1cf4ee19434af54d5d20460ee108d482545fec0 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Wed, 26 Jul 2023 17:56:57 +0200 Subject: [PATCH] fix(GODT-2822): Sync Cache When the sync fails, store the previously downloaded data in memory so that on next retry we don't have to re-download everything. --- internal/user/events.go | 2 + internal/user/sync.go | 41 ++++++++--- internal/user/sync_downloader.go | 82 ++++++++++++++++------ internal/user/sync_downloader_test.go | 89 +++++++++++++++++++++--- internal/user/sync_message_cache.go | 98 +++++++++++++++++++++++++++ internal/user/user.go | 2 + 6 files changed, 275 insertions(+), 39 deletions(-) create mode 100644 internal/user/sync_message_cache.go diff --git a/internal/user/events.go b/internal/user/events.go index 6876aef6..ce79758f 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -89,6 +89,8 @@ func (user *User) handleRefreshEvent(ctx context.Context, refresh proton.Refresh // Re-sync messages after the user, address and label refresh. defer user.goSync() + user.syncCache.Clear() + return user.syncUserAddressesLabelsAndClearSync(ctx, false) } diff --git a/internal/user/sync.go b/internal/user/sync.go index 84cce7c6..25cded53 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -373,13 +373,15 @@ func (user *User) syncMessages( ) type flushUpdate struct { - messageID string - err error - batchLen int + batchMessageID string + messages []proton.FullMessage + err error + batchLen int } type builtMessageBatch struct { - batch []*buildRes + batch []*buildRes + messages []proton.FullMessage } downloadCh := make(chan downloadRequest) @@ -455,7 +457,7 @@ func (user *User) syncMessages( }, logging.Labels{"sync-stage": "meta-data"}) // Goroutine in charge of downloading and building messages in maxBatchSize batches. 
- buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, downloadCh, syncLimits) + buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, user.syncCache, downloadCh, syncLimits) // Goroutine which builds messages after they have been downloaded async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { @@ -501,7 +503,7 @@ func (user *User) syncMessages( } select { - case flushCh <- builtMessageBatch{result}: + case flushCh <- builtMessageBatch{batch: result, messages: buildBatch.batch}: case <-ctx.Done(): return @@ -580,9 +582,11 @@ func (user *User) syncMessages( select { case flushUpdateCh <- flushUpdate{ - messageID: downloadBatch.batch[0].messageID, - err: nil, - batchLen: len(downloadBatch.batch), + batchMessageID: downloadBatch.batch[0].messageID, + messages: downloadBatch.messages, + + err: nil, + batchLen: len(downloadBatch.batch), }: case <-ctx.Done(): return @@ -595,14 +599,29 @@ func (user *User) syncMessages( return flushUpdate.err } - if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil { + if err := vault.SetLastMessageID(flushUpdate.batchMessageID); err != nil { return fmt.Errorf("failed to set last synced message ID: %w", err) } + for _, m := range flushUpdate.messages { + user.syncCache.DeleteMessages(m.ID) + if m.NumAttachments != 0 { + user.syncCache.DeleteAttachments(xslices.Map(m.Attachments, func(a proton.Attachment) string { + return a.ID + })...) 
+ } + } + syncReporter.add(flushUpdate.batchLen) } - return <-errorCh + err := <-errorCh + + if err != nil { + user.syncCache.Clear() + } + + return err } func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated { diff --git a/internal/user/sync_downloader.go b/internal/user/sync_downloader.go index 4d61b2f0..e84450bb 100644 --- a/internal/user/sync_downloader.go +++ b/internal/user/sync_downloader.go @@ -63,7 +63,14 @@ type downloadResult struct { err error } -func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, downloadCh <-chan downloadRequest, syncLimits syncLimits) (<-chan downloadedMessageBatch, <-chan error) { +func startSyncDownloader( + ctx context.Context, + panicHandler async.PanicHandler, + downloader MessageDownloader, + cache *SyncDownloadCache, + downloadCh <-chan downloadRequest, + syncLimits syncLimits, +) (<-chan downloadedMessageBatch, <-chan error) { buildCh := make(chan downloadedMessageBatch) errorCh := make(chan error, syncLimits.MaxParallelDownloads*4) @@ -75,7 +82,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d logrus.Debugf("sync downloader exit") }() - attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, syncLimits.MaxParallelDownloads) + attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, cache, syncLimits.MaxParallelDownloads) defer attachmentDownloader.close() for request := range downloadCh { @@ -85,7 +92,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d return } - result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, syncLimits.MaxParallelDownloads) + result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, cache, syncLimits.MaxParallelDownloads) if err != nil { errorCh <- err return @@ -96,7 +103,7 @@ func 
startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d return } - batch, err := downloadMessagesStage2(ctx, result, downloader, SyncRetryCooldown) + batch, err := downloadMessagesStage2(ctx, result, downloader, cache, SyncRetryCooldown) if err != nil { errorCh <- err return @@ -132,7 +139,7 @@ type attachmentDownloader struct { cancel context.CancelFunc } -func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-chan attachmentJob) { +func attachmentWorker(ctx context.Context, downloader MessageDownloader, cache *SyncDownloadCache, work <-chan attachmentJob) { for { select { case <-ctx.Done(): @@ -141,26 +148,45 @@ func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <- if !ok { return } - var b bytes.Buffer - b.Grow(int(job.size)) - err := downloader.GetAttachmentInto(ctx, job.id, &b) + + var result attachmentResult + if data, ok := cache.GetAttachment(job.id); ok { + result.attachment = data + result.err = nil + } else { + var b bytes.Buffer + b.Grow(int(job.size)) + err := downloader.GetAttachmentInto(ctx, job.id, &b) + result.attachment = b.Bytes() + result.err = err + if err == nil { + cache.StoreAttachment(job.id, result.attachment) + } + } + select { case <-ctx.Done(): close(job.result) return - case job.result <- attachmentResult{attachment: b.Bytes(), err: err}: + case job.result <- result: close(job.result) } } } } -func newAttachmentDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, workerCount int) *attachmentDownloader { +func newAttachmentDownloader( + ctx context.Context, + panicHandler async.PanicHandler, + downloader MessageDownloader, + cache *SyncDownloadCache, + workerCount int, +) *attachmentDownloader { workerCh := make(chan attachmentJob, (workerCount+2)*workerCount) ctx, cancel := context.WithCancel(ctx) for i := 0; i < workerCount; i++ { workerCh = make(chan attachmentJob) - async.GoAnnotated(ctx, panicHandler, func(ctx 
context.Context) { attachmentWorker(ctx, downloader, workerCh) }, logging.Labels{ + async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, cache, workerCh) }, logging.Labels{ "sync": fmt.Sprintf("att-downloader %v", i), }) } @@ -209,6 +235,7 @@ func downloadMessageStage1( request downloadRequest, downloader MessageDownloader, attachmentDownloader *attachmentDownloader, + cache *SyncDownloadCache, parallelDownloads int, ) ([]downloadResult, error) { // 1st attempt download everything in parallel @@ -217,21 +244,28 @@ func downloadMessageStage1( result := downloadResult{ID: id} - msg, err := downloader.GetMessage(ctx, id) - if err != nil { - logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message") - result.err = err - return result, nil + v, ok := cache.GetMessage(id) + if !ok { + msg, err := downloader.GetMessage(ctx, id) + if err != nil { + logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message") + result.err = err + return result, nil + } + + cache.StoreMessage(msg) + result.Message.Message = msg + } else { + result.Message.Message = v } - result.Message.Message = msg result.State = downloadStateHasMessage - attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments) + attachments, err := attachmentDownloader.getAttachments(ctx, result.Message.Attachments) result.Message.AttData = attachments if err != nil { - logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments") + logrus.WithError(err).WithField("msgID", id).Error("Failed to download message attachments") result.err = err return result, nil } @@ -242,7 +276,13 @@ func downloadMessageStage1( }) } -func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloader MessageDownloader, coolDown time.Duration) ([]proton.FullMessage, error) { +func downloadMessagesStage2( + ctx context.Context, + state []downloadResult, + downloader 
MessageDownloader, + cache *SyncDownloadCache, + coolDown time.Duration, +) ([]proton.FullMessage, error) { logrus.Debug("Entering download stage 2") var retryList []int var shouldWaitBeforeRetry bool @@ -289,6 +329,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa return nil, err } + cache.StoreMessage(message) st.Message.Message = message st.State = downloadStateHasMessage } @@ -314,6 +355,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa } st.Message.AttData[i] = buffer.Bytes() + cache.StoreAttachment(st.Message.Attachments[i].ID, st.Message.AttData[i]) } } diff --git a/internal/user/sync_downloader_test.go b/internal/user/sync_downloader_test.go index b8edfdd9..c03d6061 100644 --- a/internal/user/sync_downloader_test.go +++ b/internal/user/sync_downloader_test.go @@ -89,10 +89,11 @@ func TestSyncDownloader_Stage1_429(t *testing.T) { return err }) - attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, 1) + cache := newSyncDownloadCache() + attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1) defer attachmentDownloader.close() - result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, 1) + result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1) require.NoError(t, err) require.Equal(t, 3, len(result)) // Check message 1 @@ -115,12 +116,21 @@ func TestSyncDownloader_Stage1_429(t *testing.T) { require.Nil(t, result[2].Message.AttData[0]) require.NotEqual(t, attachmentData, result[2].Message.AttData[1]) require.NotNil(t, result[2].err) + + _, ok := cache.GetMessage("MsgID1") + require.True(t, ok) + _, ok = cache.GetMessage("MsgID3") + require.True(t, ok) + att, ok := cache.GetAttachment("Attachment1_1") + require.True(t, ok) + require.Equal(t, attachmentData, string(att)) } func 
TestSyncDownloader_Stage2_Everything200(t *testing.T) { mockCtrl := gomock.NewController(t) messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) ctx := context.Background() + cache := newSyncDownloadCache() downloadResult := []downloadResult{ { @@ -133,7 +143,7 @@ func TestSyncDownloader_Stage2_Everything200(t *testing.T) { }, } - result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond) require.NoError(t, err) require.Equal(t, 2, len(result)) } @@ -142,6 +152,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) { mockCtrl := gomock.NewController(t) messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) ctx := context.Background() + cache := newSyncDownloadCache() msgErr := fmt.Errorf("something not 429") downloadResult := []downloadResult{ @@ -160,7 +171,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) { }, } - _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond) require.Error(t, err) require.Equal(t, msgErr, err) } @@ -169,6 +180,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) { mockCtrl := gomock.NewController(t) messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) ctx := context.Background() + cache := newSyncDownloadCache() msgErr := &proton.APIError{Status: 500} downloadResult := []downloadResult{ @@ -183,7 +195,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) { }, } - _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond) require.Error(t, err) require.Equal(t, msgErr, err) } @@ -192,6 +204,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) { mockCtrl := gomock.NewController(t) 
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) ctx := context.Background() + cache := newSyncDownloadCache() const attachmentData1 = "attachment data 1" const attachmentData2 = "attachment data 2" @@ -290,7 +303,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) { }) } - messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond) require.NoError(t, err) require.Equal(t, 3, len(messages)) @@ -304,12 +317,28 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) { require.Equal(t, attachmentData1, string(messages[1].AttData[0])) require.Equal(t, attachmentData2, string(messages[1].AttData[1])) require.Empty(t, messages[2].AttData) + + _, ok := cache.GetMessage("Msg3") + require.True(t, ok) + + att3, ok := cache.GetAttachment("A3") + require.True(t, ok) + require.Equal(t, attachmentData3, string(att3)) + + att1, ok := cache.GetAttachment("A1") + require.True(t, ok) + require.Equal(t, attachmentData1, string(att1)) + + att2, ok := cache.GetAttachment("A2") + require.True(t, ok) + require.Equal(t, attachmentData2, string(att2)) } func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) { mockCtrl := gomock.NewController(t) messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) ctx := context.Background() + cache := newSyncDownloadCache() err429 := &proton.APIError{Status: 429} err500 := &proton.APIError{Status: 500} @@ -352,7 +381,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) { messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500) } - messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond) require.Error(t, err) require.Empty(t, 0, messages) } @@ 
-361,6 +390,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) { mockCtrl := gomock.NewController(t) messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) ctx := context.Background() + cache := newSyncDownloadCache() err429 := &proton.APIError{Status: 429} err500 := &proton.APIError{Status: 500} @@ -394,7 +424,50 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) { // 500 for second attachment messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500) - messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond) require.Error(t, err) require.Empty(t, 0, messages) } + +func TestSyncDownloader_Stage1_DoNotDownloadIfAlreadyInCache(t *testing.T) { + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + panicHandler := &async.NoopPanicHandler{} + ctx := context.Background() + + requests := downloadRequest{ + ids: []string{"Msg1", "Msg3"}, + expectedSize: 0, + err: nil, + } + + cache := newSyncDownloadCache() + attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1) + defer attachmentDownloader.close() + + const attachmentData = "attachment data" + + cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg1", NumAttachments: 1}, Attachments: []proton.Attachment{{ID: "A1"}}}) + cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3", NumAttachments: 2}, Attachments: []proton.Attachment{{ID: "A2"}}}) + + cache.StoreAttachment("A1", []byte(attachmentData)) + cache.StoreAttachment("A2", []byte(attachmentData)) + + result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1) + require.NoError(t, err) + 
require.Equal(t, 2, len(result)) + + require.Equal(t, result[0].State, downloadStateFinished) + require.Equal(t, result[0].Message.ID, "Msg1") + require.NotEmpty(t, result[0].Message.AttData) + require.NotEqual(t, attachmentData, result[0].Message.AttData[0]) + require.NotNil(t, result[0].Message.AttData[0]) + require.Nil(t, result[0].err) + + require.Equal(t, result[1].State, downloadStateFinished) + require.Equal(t, result[1].Message.ID, "Msg3") + require.NotEmpty(t, result[1].Message.AttData) + require.NotEqual(t, attachmentData, result[1].Message.AttData[0]) + require.NotNil(t, result[1].Message.AttData[0]) + require.Nil(t, result[1].err) +} diff --git a/internal/user/sync_message_cache.go b/internal/user/sync_message_cache.go new file mode 100644 index 00000000..6cec6e10 --- /dev/null +++ b/internal/user/sync_message_cache.go @@ -0,0 +1,98 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>. 
+ +package user + +import ( + "sync" + + "github.com/ProtonMail/go-proton-api" +) + +type SyncDownloadCache struct { + messageLock sync.RWMutex + messages map[string]proton.Message + attachmentLock sync.RWMutex + attachments map[string][]byte +} + +func newSyncDownloadCache() *SyncDownloadCache { + return &SyncDownloadCache{ + messages: make(map[string]proton.Message, 64), + attachments: make(map[string][]byte, 64), + } +} + +func (s *SyncDownloadCache) StoreMessage(message proton.Message) { + s.messageLock.Lock() + defer s.messageLock.Unlock() + + s.messages[message.ID] = message +} + +func (s *SyncDownloadCache) StoreAttachment(id string, data []byte) { + s.attachmentLock.Lock() + defer s.attachmentLock.Unlock() + + s.attachments[id] = data +} + +func (s *SyncDownloadCache) DeleteMessages(id ...string) { + s.messageLock.Lock() + defer s.messageLock.Unlock() + + for _, id := range id { + delete(s.messages, id) + } +} + +func (s *SyncDownloadCache) DeleteAttachments(id ...string) { + s.attachmentLock.Lock() + defer s.attachmentLock.Unlock() + + for _, id := range id { + delete(s.attachments, id) + } +} + +func (s *SyncDownloadCache) GetMessage(id string) (proton.Message, bool) { + s.messageLock.RLock() + defer s.messageLock.RUnlock() + + v, ok := s.messages[id] + + return v, ok +} + +func (s *SyncDownloadCache) GetAttachment(id string) ([]byte, bool) { + s.attachmentLock.RLock() + defer s.attachmentLock.RUnlock() + + v, ok := s.attachments[id] + + return v, ok +} + +func (s *SyncDownloadCache) Clear() { + s.messageLock.Lock() + s.messages = make(map[string]proton.Message, 64) + s.messageLock.Unlock() + + s.attachmentLock.Lock() + s.attachments = make(map[string][]byte, 64) + s.attachmentLock.Unlock() +} diff --git a/internal/user/user.go b/internal/user/user.go index 49d29ab7..0399fba1 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -93,6 +93,7 @@ type User struct { showAllMail uint32 maxSyncMemory uint64 + syncCache *SyncDownloadCache 
panicHandler async.PanicHandler @@ -171,6 +172,7 @@ func New( showAllMail: b32(showAllMail), maxSyncMemory: maxSyncMemory, + syncCache: newSyncDownloadCache(), panicHandler: crashHandler,