mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-18 16:17:03 +00:00
fix(GODT-2822): Sync Cache
When the sync fail, store the previously downloaded data in memory so that on next retry we don't have to re-download everything.
This commit is contained in:
@ -89,6 +89,8 @@ func (user *User) handleRefreshEvent(ctx context.Context, refresh proton.Refresh
|
|||||||
// Re-sync messages after the user, address and label refresh.
|
// Re-sync messages after the user, address and label refresh.
|
||||||
defer user.goSync()
|
defer user.goSync()
|
||||||
|
|
||||||
|
user.syncCache.Clear()
|
||||||
|
|
||||||
return user.syncUserAddressesLabelsAndClearSync(ctx, false)
|
return user.syncUserAddressesLabelsAndClearSync(ctx, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -373,13 +373,15 @@ func (user *User) syncMessages(
|
|||||||
)
|
)
|
||||||
|
|
||||||
type flushUpdate struct {
|
type flushUpdate struct {
|
||||||
messageID string
|
batchMessageID string
|
||||||
|
messages []proton.FullMessage
|
||||||
err error
|
err error
|
||||||
batchLen int
|
batchLen int
|
||||||
}
|
}
|
||||||
|
|
||||||
type builtMessageBatch struct {
|
type builtMessageBatch struct {
|
||||||
batch []*buildRes
|
batch []*buildRes
|
||||||
|
messages []proton.FullMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
downloadCh := make(chan downloadRequest)
|
downloadCh := make(chan downloadRequest)
|
||||||
@ -455,7 +457,7 @@ func (user *User) syncMessages(
|
|||||||
}, logging.Labels{"sync-stage": "meta-data"})
|
}, logging.Labels{"sync-stage": "meta-data"})
|
||||||
|
|
||||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||||
buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, downloadCh, syncLimits)
|
buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, user.syncCache, downloadCh, syncLimits)
|
||||||
|
|
||||||
// Goroutine which builds messages after they have been downloaded
|
// Goroutine which builds messages after they have been downloaded
|
||||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||||
@ -501,7 +503,7 @@ func (user *User) syncMessages(
|
|||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case flushCh <- builtMessageBatch{result}:
|
case flushCh <- builtMessageBatch{batch: result, messages: buildBatch.batch}:
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
@ -580,7 +582,9 @@ func (user *User) syncMessages(
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case flushUpdateCh <- flushUpdate{
|
case flushUpdateCh <- flushUpdate{
|
||||||
messageID: downloadBatch.batch[0].messageID,
|
batchMessageID: downloadBatch.batch[0].messageID,
|
||||||
|
messages: downloadBatch.messages,
|
||||||
|
|
||||||
err: nil,
|
err: nil,
|
||||||
batchLen: len(downloadBatch.batch),
|
batchLen: len(downloadBatch.batch),
|
||||||
}:
|
}:
|
||||||
@ -595,14 +599,29 @@ func (user *User) syncMessages(
|
|||||||
return flushUpdate.err
|
return flushUpdate.err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
|
if err := vault.SetLastMessageID(flushUpdate.batchMessageID); err != nil {
|
||||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, m := range flushUpdate.messages {
|
||||||
|
user.syncCache.DeleteMessages(m.ID)
|
||||||
|
if m.NumAttachments != 0 {
|
||||||
|
user.syncCache.DeleteAttachments(xslices.Map(m.Attachments, func(a proton.Attachment) string {
|
||||||
|
return a.ID
|
||||||
|
})...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
syncReporter.add(flushUpdate.batchLen)
|
syncReporter.add(flushUpdate.batchLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
return <-errorCh
|
err := <-errorCh
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
user.syncCache.Clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated {
|
func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated {
|
||||||
|
|||||||
@ -63,7 +63,14 @@ type downloadResult struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, downloadCh <-chan downloadRequest, syncLimits syncLimits) (<-chan downloadedMessageBatch, <-chan error) {
|
func startSyncDownloader(
|
||||||
|
ctx context.Context,
|
||||||
|
panicHandler async.PanicHandler,
|
||||||
|
downloader MessageDownloader,
|
||||||
|
cache *SyncDownloadCache,
|
||||||
|
downloadCh <-chan downloadRequest,
|
||||||
|
syncLimits syncLimits,
|
||||||
|
) (<-chan downloadedMessageBatch, <-chan error) {
|
||||||
buildCh := make(chan downloadedMessageBatch)
|
buildCh := make(chan downloadedMessageBatch)
|
||||||
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
|
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
|
||||||
|
|
||||||
@ -75,7 +82,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
|
|||||||
logrus.Debugf("sync downloader exit")
|
logrus.Debugf("sync downloader exit")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, syncLimits.MaxParallelDownloads)
|
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, cache, syncLimits.MaxParallelDownloads)
|
||||||
defer attachmentDownloader.close()
|
defer attachmentDownloader.close()
|
||||||
|
|
||||||
for request := range downloadCh {
|
for request := range downloadCh {
|
||||||
@ -85,7 +92,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, syncLimits.MaxParallelDownloads)
|
result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, cache, syncLimits.MaxParallelDownloads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorCh <- err
|
errorCh <- err
|
||||||
return
|
return
|
||||||
@ -96,7 +103,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
batch, err := downloadMessagesStage2(ctx, result, downloader, SyncRetryCooldown)
|
batch, err := downloadMessagesStage2(ctx, result, downloader, cache, SyncRetryCooldown)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorCh <- err
|
errorCh <- err
|
||||||
return
|
return
|
||||||
@ -132,7 +139,7 @@ type attachmentDownloader struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-chan attachmentJob) {
|
func attachmentWorker(ctx context.Context, downloader MessageDownloader, cache *SyncDownloadCache, work <-chan attachmentJob) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@ -141,26 +148,45 @@ func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var result attachmentResult
|
||||||
|
if data, ok := cache.GetAttachment(job.id); ok {
|
||||||
|
result.attachment = data
|
||||||
|
result.err = nil
|
||||||
|
} else {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
b.Grow(int(job.size))
|
b.Grow(int(job.size))
|
||||||
err := downloader.GetAttachmentInto(ctx, job.id, &b)
|
err := downloader.GetAttachmentInto(ctx, job.id, &b)
|
||||||
|
result.attachment = b.Bytes()
|
||||||
|
result.err = err
|
||||||
|
if err == nil {
|
||||||
|
cache.StoreAttachment(job.id, result.attachment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
close(job.result)
|
close(job.result)
|
||||||
return
|
return
|
||||||
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
|
case job.result <- result:
|
||||||
close(job.result)
|
close(job.result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAttachmentDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, workerCount int) *attachmentDownloader {
|
func newAttachmentDownloader(
|
||||||
|
ctx context.Context,
|
||||||
|
panicHandler async.PanicHandler,
|
||||||
|
downloader MessageDownloader,
|
||||||
|
cache *SyncDownloadCache,
|
||||||
|
workerCount int,
|
||||||
|
) *attachmentDownloader {
|
||||||
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
for i := 0; i < workerCount; i++ {
|
for i := 0; i < workerCount; i++ {
|
||||||
workerCh = make(chan attachmentJob)
|
workerCh = make(chan attachmentJob)
|
||||||
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, workerCh) }, logging.Labels{
|
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, cache, workerCh) }, logging.Labels{
|
||||||
"sync": fmt.Sprintf("att-downloader %v", i),
|
"sync": fmt.Sprintf("att-downloader %v", i),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -209,6 +235,7 @@ func downloadMessageStage1(
|
|||||||
request downloadRequest,
|
request downloadRequest,
|
||||||
downloader MessageDownloader,
|
downloader MessageDownloader,
|
||||||
attachmentDownloader *attachmentDownloader,
|
attachmentDownloader *attachmentDownloader,
|
||||||
|
cache *SyncDownloadCache,
|
||||||
parallelDownloads int,
|
parallelDownloads int,
|
||||||
) ([]downloadResult, error) {
|
) ([]downloadResult, error) {
|
||||||
// 1st attempt download everything in parallel
|
// 1st attempt download everything in parallel
|
||||||
@ -217,6 +244,8 @@ func downloadMessageStage1(
|
|||||||
|
|
||||||
result := downloadResult{ID: id}
|
result := downloadResult{ID: id}
|
||||||
|
|
||||||
|
v, ok := cache.GetMessage(id)
|
||||||
|
if !ok {
|
||||||
msg, err := downloader.GetMessage(ctx, id)
|
msg, err := downloader.GetMessage(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
||||||
@ -224,14 +253,19 @@ func downloadMessageStage1(
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cache.StoreMessage(msg)
|
||||||
result.Message.Message = msg
|
result.Message.Message = msg
|
||||||
|
} else {
|
||||||
|
result.Message.Message = v
|
||||||
|
}
|
||||||
|
|
||||||
result.State = downloadStateHasMessage
|
result.State = downloadStateHasMessage
|
||||||
|
|
||||||
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
|
attachments, err := attachmentDownloader.getAttachments(ctx, result.Message.Attachments)
|
||||||
result.Message.AttData = attachments
|
result.Message.AttData = attachments
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
|
logrus.WithError(err).WithField("msgID", id).Error("Failed to download message attachments")
|
||||||
result.err = err
|
result.err = err
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
@ -242,7 +276,13 @@ func downloadMessageStage1(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloader MessageDownloader, coolDown time.Duration) ([]proton.FullMessage, error) {
|
func downloadMessagesStage2(
|
||||||
|
ctx context.Context,
|
||||||
|
state []downloadResult,
|
||||||
|
downloader MessageDownloader,
|
||||||
|
cache *SyncDownloadCache,
|
||||||
|
coolDown time.Duration,
|
||||||
|
) ([]proton.FullMessage, error) {
|
||||||
logrus.Debug("Entering download stage 2")
|
logrus.Debug("Entering download stage 2")
|
||||||
var retryList []int
|
var retryList []int
|
||||||
var shouldWaitBeforeRetry bool
|
var shouldWaitBeforeRetry bool
|
||||||
@ -289,6 +329,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cache.StoreMessage(message)
|
||||||
st.Message.Message = message
|
st.Message.Message = message
|
||||||
st.State = downloadStateHasMessage
|
st.State = downloadStateHasMessage
|
||||||
}
|
}
|
||||||
@ -314,6 +355,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa
|
|||||||
}
|
}
|
||||||
|
|
||||||
st.Message.AttData[i] = buffer.Bytes()
|
st.Message.AttData[i] = buffer.Bytes()
|
||||||
|
cache.StoreAttachment(st.Message.Attachments[i].ID, st.Message.AttData[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -89,10 +89,11 @@ func TestSyncDownloader_Stage1_429(t *testing.T) {
|
|||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, 1)
|
cache := newSyncDownloadCache()
|
||||||
|
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
|
||||||
defer attachmentDownloader.close()
|
defer attachmentDownloader.close()
|
||||||
|
|
||||||
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, 1)
|
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 3, len(result))
|
require.Equal(t, 3, len(result))
|
||||||
// Check message 1
|
// Check message 1
|
||||||
@ -115,12 +116,21 @@ func TestSyncDownloader_Stage1_429(t *testing.T) {
|
|||||||
require.Nil(t, result[2].Message.AttData[0])
|
require.Nil(t, result[2].Message.AttData[0])
|
||||||
require.NotEqual(t, attachmentData, result[2].Message.AttData[1])
|
require.NotEqual(t, attachmentData, result[2].Message.AttData[1])
|
||||||
require.NotNil(t, result[2].err)
|
require.NotNil(t, result[2].err)
|
||||||
|
|
||||||
|
_, ok := cache.GetMessage("MsgID1")
|
||||||
|
require.True(t, ok)
|
||||||
|
_, ok = cache.GetMessage("MsgID3")
|
||||||
|
require.True(t, ok)
|
||||||
|
att, ok := cache.GetAttachment("Attachment1_1")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, attachmentData, string(att))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
|
func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
|
||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
cache := newSyncDownloadCache()
|
||||||
|
|
||||||
downloadResult := []downloadResult{
|
downloadResult := []downloadResult{
|
||||||
{
|
{
|
||||||
@ -133,7 +143,7 @@ func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 2, len(result))
|
require.Equal(t, 2, len(result))
|
||||||
}
|
}
|
||||||
@ -142,6 +152,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) {
|
|||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
cache := newSyncDownloadCache()
|
||||||
|
|
||||||
msgErr := fmt.Errorf("something not 429")
|
msgErr := fmt.Errorf("something not 429")
|
||||||
downloadResult := []downloadResult{
|
downloadResult := []downloadResult{
|
||||||
@ -160,7 +171,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, msgErr, err)
|
require.Equal(t, msgErr, err)
|
||||||
}
|
}
|
||||||
@ -169,6 +180,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) {
|
|||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
cache := newSyncDownloadCache()
|
||||||
|
|
||||||
msgErr := &proton.APIError{Status: 500}
|
msgErr := &proton.APIError{Status: 500}
|
||||||
downloadResult := []downloadResult{
|
downloadResult := []downloadResult{
|
||||||
@ -183,7 +195,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, msgErr, err)
|
require.Equal(t, msgErr, err)
|
||||||
}
|
}
|
||||||
@ -192,6 +204,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
|
|||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
cache := newSyncDownloadCache()
|
||||||
|
|
||||||
const attachmentData1 = "attachment data 1"
|
const attachmentData1 = "attachment data 1"
|
||||||
const attachmentData2 = "attachment data 2"
|
const attachmentData2 = "attachment data 2"
|
||||||
@ -290,7 +303,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 3, len(messages))
|
require.Equal(t, 3, len(messages))
|
||||||
|
|
||||||
@ -304,12 +317,28 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
|
|||||||
require.Equal(t, attachmentData1, string(messages[1].AttData[0]))
|
require.Equal(t, attachmentData1, string(messages[1].AttData[0]))
|
||||||
require.Equal(t, attachmentData2, string(messages[1].AttData[1]))
|
require.Equal(t, attachmentData2, string(messages[1].AttData[1]))
|
||||||
require.Empty(t, messages[2].AttData)
|
require.Empty(t, messages[2].AttData)
|
||||||
|
|
||||||
|
_, ok := cache.GetMessage("Msg3")
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
att3, ok := cache.GetAttachment("A3")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, attachmentData3, string(att3))
|
||||||
|
|
||||||
|
att1, ok := cache.GetAttachment("A1")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, attachmentData1, string(att1))
|
||||||
|
|
||||||
|
att2, ok := cache.GetAttachment("A2")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, attachmentData2, string(att2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
|
func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
|
||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
cache := newSyncDownloadCache()
|
||||||
|
|
||||||
err429 := &proton.APIError{Status: 429}
|
err429 := &proton.APIError{Status: 429}
|
||||||
err500 := &proton.APIError{Status: 500}
|
err500 := &proton.APIError{Status: 500}
|
||||||
@ -352,7 +381,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
|
|||||||
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500)
|
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500)
|
||||||
}
|
}
|
||||||
|
|
||||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Empty(t, 0, messages)
|
require.Empty(t, 0, messages)
|
||||||
}
|
}
|
||||||
@ -361,6 +390,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
|
|||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
cache := newSyncDownloadCache()
|
||||||
|
|
||||||
err429 := &proton.APIError{Status: 429}
|
err429 := &proton.APIError{Status: 429}
|
||||||
err500 := &proton.APIError{Status: 500}
|
err500 := &proton.APIError{Status: 500}
|
||||||
@ -394,7 +424,50 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
|
|||||||
// 500 for second attachment
|
// 500 for second attachment
|
||||||
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500)
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500)
|
||||||
|
|
||||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Empty(t, 0, messages)
|
require.Empty(t, 0, messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage1_DoNotDownloadIfAlreadyInCache(t *testing.T) {
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
panicHandler := &async.NoopPanicHandler{}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
requests := downloadRequest{
|
||||||
|
ids: []string{"Msg1", "Msg3"},
|
||||||
|
expectedSize: 0,
|
||||||
|
err: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := newSyncDownloadCache()
|
||||||
|
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
|
||||||
|
defer attachmentDownloader.close()
|
||||||
|
|
||||||
|
const attachmentData = "attachment data"
|
||||||
|
|
||||||
|
cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg1", NumAttachments: 1}, Attachments: []proton.Attachment{{ID: "A1"}}})
|
||||||
|
cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3", NumAttachments: 2}, Attachments: []proton.Attachment{{ID: "A2"}}})
|
||||||
|
|
||||||
|
cache.StoreAttachment("A1", []byte(attachmentData))
|
||||||
|
cache.StoreAttachment("A2", []byte(attachmentData))
|
||||||
|
|
||||||
|
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 2, len(result))
|
||||||
|
|
||||||
|
require.Equal(t, result[0].State, downloadStateFinished)
|
||||||
|
require.Equal(t, result[0].Message.ID, "Msg1")
|
||||||
|
require.NotEmpty(t, result[0].Message.AttData)
|
||||||
|
require.NotEqual(t, attachmentData, result[0].Message.AttData[0])
|
||||||
|
require.NotNil(t, result[0].Message.AttData[0])
|
||||||
|
require.Nil(t, result[0].err)
|
||||||
|
|
||||||
|
require.Equal(t, result[1].State, downloadStateFinished)
|
||||||
|
require.Equal(t, result[1].Message.ID, "Msg3")
|
||||||
|
require.NotEmpty(t, result[1].Message.AttData)
|
||||||
|
require.NotEqual(t, attachmentData, result[1].Message.AttData[0])
|
||||||
|
require.NotNil(t, result[1].Message.AttData[0])
|
||||||
|
require.Nil(t, result[1].err)
|
||||||
|
}
|
||||||
|
|||||||
98
internal/user/sync_message_cache.go
Normal file
98
internal/user/sync_message_cache.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
// Copyright (c) 2023 Proton AG
|
||||||
|
//
|
||||||
|
// This file is part of Proton Mail Bridge.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/go-proton-api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SyncDownloadCache struct {
|
||||||
|
messageLock sync.RWMutex
|
||||||
|
messages map[string]proton.Message
|
||||||
|
attachmentLock sync.RWMutex
|
||||||
|
attachments map[string][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSyncDownloadCache() *SyncDownloadCache {
|
||||||
|
return &SyncDownloadCache{
|
||||||
|
messages: make(map[string]proton.Message, 64),
|
||||||
|
attachments: make(map[string][]byte, 64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncDownloadCache) StoreMessage(message proton.Message) {
|
||||||
|
s.messageLock.Lock()
|
||||||
|
defer s.messageLock.Unlock()
|
||||||
|
|
||||||
|
s.messages[message.ID] = message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncDownloadCache) StoreAttachment(id string, data []byte) {
|
||||||
|
s.attachmentLock.Lock()
|
||||||
|
defer s.attachmentLock.Unlock()
|
||||||
|
|
||||||
|
s.attachments[id] = data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncDownloadCache) DeleteMessages(id ...string) {
|
||||||
|
s.messageLock.Lock()
|
||||||
|
defer s.messageLock.Unlock()
|
||||||
|
|
||||||
|
for _, id := range id {
|
||||||
|
delete(s.messages, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncDownloadCache) DeleteAttachments(id ...string) {
|
||||||
|
s.attachmentLock.Lock()
|
||||||
|
defer s.attachmentLock.Unlock()
|
||||||
|
|
||||||
|
for _, id := range id {
|
||||||
|
delete(s.attachments, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncDownloadCache) GetMessage(id string) (proton.Message, bool) {
|
||||||
|
s.messageLock.RLock()
|
||||||
|
defer s.messageLock.RUnlock()
|
||||||
|
|
||||||
|
v, ok := s.messages[id]
|
||||||
|
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncDownloadCache) GetAttachment(id string) ([]byte, bool) {
|
||||||
|
s.attachmentLock.RLock()
|
||||||
|
defer s.attachmentLock.RUnlock()
|
||||||
|
|
||||||
|
v, ok := s.attachments[id]
|
||||||
|
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncDownloadCache) Clear() {
|
||||||
|
s.messageLock.Lock()
|
||||||
|
s.messages = make(map[string]proton.Message, 64)
|
||||||
|
s.messageLock.Unlock()
|
||||||
|
|
||||||
|
s.attachmentLock.Lock()
|
||||||
|
s.attachments = make(map[string][]byte, 64)
|
||||||
|
s.attachmentLock.Unlock()
|
||||||
|
}
|
||||||
@ -93,6 +93,7 @@ type User struct {
|
|||||||
showAllMail uint32
|
showAllMail uint32
|
||||||
|
|
||||||
maxSyncMemory uint64
|
maxSyncMemory uint64
|
||||||
|
syncCache *SyncDownloadCache
|
||||||
|
|
||||||
panicHandler async.PanicHandler
|
panicHandler async.PanicHandler
|
||||||
|
|
||||||
@ -171,6 +172,7 @@ func New(
|
|||||||
showAllMail: b32(showAllMail),
|
showAllMail: b32(showAllMail),
|
||||||
|
|
||||||
maxSyncMemory: maxSyncMemory,
|
maxSyncMemory: maxSyncMemory,
|
||||||
|
syncCache: newSyncDownloadCache(),
|
||||||
|
|
||||||
panicHandler: crashHandler,
|
panicHandler: crashHandler,
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user