fix(GODT-2822): Sync Cache

When the sync fails, store the previously downloaded data in memory so
that on next retry we don't have to re-download everything.
This commit is contained in:
Leander Beernaert
2023-07-26 17:56:57 +02:00
committed by Jakub
parent 5136919c36
commit f1cf4ee194
6 changed files with 275 additions and 39 deletions

View File

@ -89,6 +89,8 @@ func (user *User) handleRefreshEvent(ctx context.Context, refresh proton.Refresh
// Re-sync messages after the user, address and label refresh.
defer user.goSync()
user.syncCache.Clear()
return user.syncUserAddressesLabelsAndClearSync(ctx, false)
}

View File

@ -373,13 +373,15 @@ func (user *User) syncMessages(
)
type flushUpdate struct {
messageID string
err error
batchLen int
batchMessageID string
messages []proton.FullMessage
err error
batchLen int
}
type builtMessageBatch struct {
batch []*buildRes
batch []*buildRes
messages []proton.FullMessage
}
downloadCh := make(chan downloadRequest)
@ -455,7 +457,7 @@ func (user *User) syncMessages(
}, logging.Labels{"sync-stage": "meta-data"})
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, downloadCh, syncLimits)
buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, user.syncCache, downloadCh, syncLimits)
// Goroutine which builds messages after they have been downloaded
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
@ -501,7 +503,7 @@ func (user *User) syncMessages(
}
select {
case flushCh <- builtMessageBatch{result}:
case flushCh <- builtMessageBatch{batch: result, messages: buildBatch.batch}:
case <-ctx.Done():
return
@ -580,9 +582,11 @@ func (user *User) syncMessages(
select {
case flushUpdateCh <- flushUpdate{
messageID: downloadBatch.batch[0].messageID,
err: nil,
batchLen: len(downloadBatch.batch),
batchMessageID: downloadBatch.batch[0].messageID,
messages: downloadBatch.messages,
err: nil,
batchLen: len(downloadBatch.batch),
}:
case <-ctx.Done():
return
@ -595,14 +599,29 @@ func (user *User) syncMessages(
return flushUpdate.err
}
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
if err := vault.SetLastMessageID(flushUpdate.batchMessageID); err != nil {
return fmt.Errorf("failed to set last synced message ID: %w", err)
}
for _, m := range flushUpdate.messages {
user.syncCache.DeleteMessages(m.ID)
if m.NumAttachments != 0 {
user.syncCache.DeleteAttachments(xslices.Map(m.Attachments, func(a proton.Attachment) string {
return a.ID
})...)
}
}
syncReporter.add(flushUpdate.batchLen)
}
return <-errorCh
err := <-errorCh
if err != nil {
user.syncCache.Clear()
}
return err
}
func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated {

View File

@ -63,7 +63,14 @@ type downloadResult struct {
err error
}
func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, downloadCh <-chan downloadRequest, syncLimits syncLimits) (<-chan downloadedMessageBatch, <-chan error) {
func startSyncDownloader(
ctx context.Context,
panicHandler async.PanicHandler,
downloader MessageDownloader,
cache *SyncDownloadCache,
downloadCh <-chan downloadRequest,
syncLimits syncLimits,
) (<-chan downloadedMessageBatch, <-chan error) {
buildCh := make(chan downloadedMessageBatch)
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
@ -75,7 +82,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
logrus.Debugf("sync downloader exit")
}()
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, syncLimits.MaxParallelDownloads)
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, cache, syncLimits.MaxParallelDownloads)
defer attachmentDownloader.close()
for request := range downloadCh {
@ -85,7 +92,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
return
}
result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, syncLimits.MaxParallelDownloads)
result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, cache, syncLimits.MaxParallelDownloads)
if err != nil {
errorCh <- err
return
@ -96,7 +103,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
return
}
batch, err := downloadMessagesStage2(ctx, result, downloader, SyncRetryCooldown)
batch, err := downloadMessagesStage2(ctx, result, downloader, cache, SyncRetryCooldown)
if err != nil {
errorCh <- err
return
@ -132,7 +139,7 @@ type attachmentDownloader struct {
cancel context.CancelFunc
}
func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-chan attachmentJob) {
func attachmentWorker(ctx context.Context, downloader MessageDownloader, cache *SyncDownloadCache, work <-chan attachmentJob) {
for {
select {
case <-ctx.Done():
@ -141,26 +148,45 @@ func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-
if !ok {
return
}
var b bytes.Buffer
b.Grow(int(job.size))
err := downloader.GetAttachmentInto(ctx, job.id, &b)
var result attachmentResult
if data, ok := cache.GetAttachment(job.id); ok {
result.attachment = data
result.err = nil
} else {
var b bytes.Buffer
b.Grow(int(job.size))
err := downloader.GetAttachmentInto(ctx, job.id, &b)
result.attachment = b.Bytes()
result.err = err
if err == nil {
cache.StoreAttachment(job.id, result.attachment)
}
}
select {
case <-ctx.Done():
close(job.result)
return
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
case job.result <- result:
close(job.result)
}
}
}
}
func newAttachmentDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, workerCount int) *attachmentDownloader {
func newAttachmentDownloader(
ctx context.Context,
panicHandler async.PanicHandler,
downloader MessageDownloader,
cache *SyncDownloadCache,
workerCount int,
) *attachmentDownloader {
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
ctx, cancel := context.WithCancel(ctx)
for i := 0; i < workerCount; i++ {
workerCh = make(chan attachmentJob)
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, workerCh) }, logging.Labels{
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, cache, workerCh) }, logging.Labels{
"sync": fmt.Sprintf("att-downloader %v", i),
})
}
@ -209,6 +235,7 @@ func downloadMessageStage1(
request downloadRequest,
downloader MessageDownloader,
attachmentDownloader *attachmentDownloader,
cache *SyncDownloadCache,
parallelDownloads int,
) ([]downloadResult, error) {
// 1st attempt download everything in parallel
@ -217,21 +244,28 @@ func downloadMessageStage1(
result := downloadResult{ID: id}
msg, err := downloader.GetMessage(ctx, id)
if err != nil {
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
result.err = err
return result, nil
v, ok := cache.GetMessage(id)
if !ok {
msg, err := downloader.GetMessage(ctx, id)
if err != nil {
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
result.err = err
return result, nil
}
cache.StoreMessage(msg)
result.Message.Message = msg
} else {
result.Message.Message = v
}
result.Message.Message = msg
result.State = downloadStateHasMessage
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
attachments, err := attachmentDownloader.getAttachments(ctx, result.Message.Attachments)
result.Message.AttData = attachments
if err != nil {
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
logrus.WithError(err).WithField("msgID", id).Error("Failed to download message attachments")
result.err = err
return result, nil
}
@ -242,7 +276,13 @@ func downloadMessageStage1(
})
}
func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloader MessageDownloader, coolDown time.Duration) ([]proton.FullMessage, error) {
func downloadMessagesStage2(
ctx context.Context,
state []downloadResult,
downloader MessageDownloader,
cache *SyncDownloadCache,
coolDown time.Duration,
) ([]proton.FullMessage, error) {
logrus.Debug("Entering download stage 2")
var retryList []int
var shouldWaitBeforeRetry bool
@ -289,6 +329,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa
return nil, err
}
cache.StoreMessage(message)
st.Message.Message = message
st.State = downloadStateHasMessage
}
@ -314,6 +355,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa
}
st.Message.AttData[i] = buffer.Bytes()
cache.StoreAttachment(st.Message.Attachments[i].ID, st.Message.AttData[i])
}
}

View File

@ -89,10 +89,11 @@ func TestSyncDownloader_Stage1_429(t *testing.T) {
return err
})
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, 1)
cache := newSyncDownloadCache()
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
defer attachmentDownloader.close()
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, 1)
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
require.NoError(t, err)
require.Equal(t, 3, len(result))
// Check message 1
@ -115,12 +116,21 @@ func TestSyncDownloader_Stage1_429(t *testing.T) {
require.Nil(t, result[2].Message.AttData[0])
require.NotEqual(t, attachmentData, result[2].Message.AttData[1])
require.NotNil(t, result[2].err)
_, ok := cache.GetMessage("MsgID1")
require.True(t, ok)
_, ok = cache.GetMessage("MsgID3")
require.True(t, ok)
att, ok := cache.GetAttachment("Attachment1_1")
require.True(t, ok)
require.Equal(t, attachmentData, string(att))
}
func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
mockCtrl := gomock.NewController(t)
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
ctx := context.Background()
cache := newSyncDownloadCache()
downloadResult := []downloadResult{
{
@ -133,7 +143,7 @@ func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
},
}
result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
require.NoError(t, err)
require.Equal(t, 2, len(result))
}
@ -142,6 +152,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) {
mockCtrl := gomock.NewController(t)
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
ctx := context.Background()
cache := newSyncDownloadCache()
msgErr := fmt.Errorf("something not 429")
downloadResult := []downloadResult{
@ -160,7 +171,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) {
},
}
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
require.Error(t, err)
require.Equal(t, msgErr, err)
}
@ -169,6 +180,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) {
mockCtrl := gomock.NewController(t)
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
ctx := context.Background()
cache := newSyncDownloadCache()
msgErr := &proton.APIError{Status: 500}
downloadResult := []downloadResult{
@ -183,7 +195,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) {
},
}
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
require.Error(t, err)
require.Equal(t, msgErr, err)
}
@ -192,6 +204,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
mockCtrl := gomock.NewController(t)
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
ctx := context.Background()
cache := newSyncDownloadCache()
const attachmentData1 = "attachment data 1"
const attachmentData2 = "attachment data 2"
@ -290,7 +303,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
})
}
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
require.NoError(t, err)
require.Equal(t, 3, len(messages))
@ -304,12 +317,28 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
require.Equal(t, attachmentData1, string(messages[1].AttData[0]))
require.Equal(t, attachmentData2, string(messages[1].AttData[1]))
require.Empty(t, messages[2].AttData)
_, ok := cache.GetMessage("Msg3")
require.True(t, ok)
att3, ok := cache.GetAttachment("A3")
require.True(t, ok)
require.Equal(t, attachmentData3, string(att3))
att1, ok := cache.GetAttachment("A1")
require.True(t, ok)
require.Equal(t, attachmentData1, string(att1))
att2, ok := cache.GetAttachment("A2")
require.True(t, ok)
require.Equal(t, attachmentData2, string(att2))
}
func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
mockCtrl := gomock.NewController(t)
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
ctx := context.Background()
cache := newSyncDownloadCache()
err429 := &proton.APIError{Status: 429}
err500 := &proton.APIError{Status: 500}
@ -352,7 +381,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500)
}
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
require.Error(t, err)
require.Empty(t, 0, messages)
}
@ -361,6 +390,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
mockCtrl := gomock.NewController(t)
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
ctx := context.Background()
cache := newSyncDownloadCache()
err429 := &proton.APIError{Status: 429}
err500 := &proton.APIError{Status: 500}
@ -394,7 +424,50 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
// 500 for second attachment
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500)
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
require.Error(t, err)
require.Empty(t, 0, messages)
}
// TestSyncDownloader_Stage1_DoNotDownloadIfAlreadyInCache verifies that stage 1
// serves messages and attachments from the sync cache when they are already
// stored there. No expectations are registered on the mock downloader, so any
// call to it is reported by gomock as an unexpected call and fails the test.
func TestSyncDownloader_Stage1_DoNotDownloadIfAlreadyInCache(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
	panicHandler := &async.NoopPanicHandler{}
	ctx := context.Background()

	requests := downloadRequest{
		ids:          []string{"Msg1", "Msg3"},
		expectedSize: 0,
		err:          nil,
	}

	cache := newSyncDownloadCache()
	attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
	defer attachmentDownloader.close()

	const attachmentData = "attachment data"

	// Pre-populate the cache with everything the request needs.
	cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg1", NumAttachments: 1}, Attachments: []proton.Attachment{{ID: "A1"}}})
	cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3", NumAttachments: 2}, Attachments: []proton.Attachment{{ID: "A2"}}})
	cache.StoreAttachment("A1", []byte(attachmentData))
	cache.StoreAttachment("A2", []byte(attachmentData))

	result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
	require.NoError(t, err)
	require.Equal(t, 2, len(result))

	// Check message 1.
	require.Equal(t, downloadStateFinished, result[0].State)
	require.Equal(t, "Msg1", result[0].Message.ID)
	require.NotEmpty(t, result[0].Message.AttData)
	require.NotNil(t, result[0].Message.AttData[0])
	// Compare as strings: comparing a string against a []byte can never be
	// equal, which would make the assertion meaningless.
	require.Equal(t, attachmentData, string(result[0].Message.AttData[0]))
	require.Nil(t, result[0].err)

	// Check message 3.
	require.Equal(t, downloadStateFinished, result[1].State)
	require.Equal(t, "Msg3", result[1].Message.ID)
	require.NotEmpty(t, result[1].Message.AttData)
	require.NotNil(t, result[1].Message.AttData[0])
	require.Equal(t, attachmentData, string(result[1].Message.AttData[0]))
	require.Nil(t, result[1].err)
}

View File

@ -0,0 +1,98 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"sync"
"github.com/ProtonMail/go-proton-api"
)
// SyncDownloadCache is an in-memory store for messages and attachment payloads
// that have already been downloaded during a sync, so that they can be looked
// up by ID instead of being downloaded again. Messages and attachments are
// kept in separate maps, each protected by its own read/write lock.
type SyncDownloadCache struct {
// messageLock guards the messages map.
messageLock sync.RWMutex
// messages maps a message ID to the downloaded message.
messages map[string]proton.Message
// attachmentLock guards the attachments map.
attachmentLock sync.RWMutex
// attachments maps an ID to the downloaded attachment bytes.
attachments map[string][]byte
}
// newSyncDownloadCache returns an empty cache. Both the message map and the
// attachment map are created with an initial capacity hint of 64 entries.
func newSyncDownloadCache() *SyncDownloadCache {
	cache := new(SyncDownloadCache)
	cache.messages = make(map[string]proton.Message, 64)
	cache.attachments = make(map[string][]byte, 64)

	return cache
}
// StoreMessage saves message in the cache under message.ID, overwriting any
// message previously stored for that ID. It holds the message write lock for
// the duration of the update.
func (s *SyncDownloadCache) StoreMessage(message proton.Message) {
	lock := &s.messageLock

	lock.Lock()
	defer lock.Unlock()

	s.messages[message.ID] = message
}
// StoreAttachment saves data in the cache under id, overwriting any payload
// previously stored for that id. The slice is stored as given; no copy is
// made. It holds the attachment write lock for the duration of the update.
func (s *SyncDownloadCache) StoreAttachment(id string, data []byte) {
	lock := &s.attachmentLock

	lock.Lock()
	defer lock.Unlock()

	s.attachments[id] = data
}
// DeleteMessages removes the messages with the given IDs from the cache. IDs
// that are not present are ignored. All removals happen under a single
// acquisition of the message write lock.
func (s *SyncDownloadCache) DeleteMessages(id ...string) {
	s.messageLock.Lock()
	defer s.messageLock.Unlock()

	for idx := range id {
		delete(s.messages, id[idx])
	}
}
// DeleteAttachments removes the attachment payloads with the given IDs from
// the cache. IDs that are not present are ignored. All removals happen under a
// single acquisition of the attachment write lock.
func (s *SyncDownloadCache) DeleteAttachments(id ...string) {
	s.attachmentLock.Lock()
	defer s.attachmentLock.Unlock()

	for idx := range id {
		delete(s.attachments, id[idx])
	}
}
// GetMessage looks up the message stored under id. The boolean result reports
// whether an entry exists; when it does not, the returned message is the zero
// value. The lookup is performed under the message read lock.
func (s *SyncDownloadCache) GetMessage(id string) (proton.Message, bool) {
	s.messageLock.RLock()
	defer s.messageLock.RUnlock()

	if message, found := s.messages[id]; found {
		return message, true
	}

	return proton.Message{}, false
}
// GetAttachment looks up the attachment payload stored under id. The boolean
// result reports whether an entry exists; when it does not, the returned slice
// is nil. The lookup is performed under the attachment read lock.
func (s *SyncDownloadCache) GetAttachment(id string) ([]byte, bool) {
	s.attachmentLock.RLock()
	defer s.attachmentLock.RUnlock()

	if data, found := s.attachments[id]; found {
		return data, true
	}

	return nil, false
}
// Clear drops every cached message and attachment by replacing both maps with
// fresh, empty ones (capacity hint 64). The message map is swapped first under
// the message lock, then the attachment map under the attachment lock; the two
// locks are never held at the same time.
func (s *SyncDownloadCache) Clear() {
	emptyMessages := make(map[string]proton.Message, 64)
	emptyAttachments := make(map[string][]byte, 64)

	s.messageLock.Lock()
	s.messages = emptyMessages
	s.messageLock.Unlock()

	s.attachmentLock.Lock()
	s.attachments = emptyAttachments
	s.attachmentLock.Unlock()
}

View File

@ -93,6 +93,7 @@ type User struct {
showAllMail uint32
maxSyncMemory uint64
syncCache *SyncDownloadCache
panicHandler async.PanicHandler
@ -171,6 +172,7 @@ func New(
showAllMail: b32(showAllMail),
maxSyncMemory: maxSyncMemory,
syncCache: newSyncDownloadCache(),
panicHandler: crashHandler,