mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-18 16:17:03 +00:00
fix(GODT-2822): Sync Cache
When the sync fail, store the previously downloaded data in memory so that on next retry we don't have to re-download everything.
This commit is contained in:
@ -89,6 +89,8 @@ func (user *User) handleRefreshEvent(ctx context.Context, refresh proton.Refresh
|
||||
// Re-sync messages after the user, address and label refresh.
|
||||
defer user.goSync()
|
||||
|
||||
user.syncCache.Clear()
|
||||
|
||||
return user.syncUserAddressesLabelsAndClearSync(ctx, false)
|
||||
}
|
||||
|
||||
|
||||
@ -373,13 +373,15 @@ func (user *User) syncMessages(
|
||||
)
|
||||
|
||||
type flushUpdate struct {
|
||||
messageID string
|
||||
err error
|
||||
batchLen int
|
||||
batchMessageID string
|
||||
messages []proton.FullMessage
|
||||
err error
|
||||
batchLen int
|
||||
}
|
||||
|
||||
type builtMessageBatch struct {
|
||||
batch []*buildRes
|
||||
batch []*buildRes
|
||||
messages []proton.FullMessage
|
||||
}
|
||||
|
||||
downloadCh := make(chan downloadRequest)
|
||||
@ -455,7 +457,7 @@ func (user *User) syncMessages(
|
||||
}, logging.Labels{"sync-stage": "meta-data"})
|
||||
|
||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||
buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, downloadCh, syncLimits)
|
||||
buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, user.syncCache, downloadCh, syncLimits)
|
||||
|
||||
// Goroutine which builds messages after they have been downloaded
|
||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
@ -501,7 +503,7 @@ func (user *User) syncMessages(
|
||||
}
|
||||
|
||||
select {
|
||||
case flushCh <- builtMessageBatch{result}:
|
||||
case flushCh <- builtMessageBatch{batch: result, messages: buildBatch.batch}:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@ -580,9 +582,11 @@ func (user *User) syncMessages(
|
||||
|
||||
select {
|
||||
case flushUpdateCh <- flushUpdate{
|
||||
messageID: downloadBatch.batch[0].messageID,
|
||||
err: nil,
|
||||
batchLen: len(downloadBatch.batch),
|
||||
batchMessageID: downloadBatch.batch[0].messageID,
|
||||
messages: downloadBatch.messages,
|
||||
|
||||
err: nil,
|
||||
batchLen: len(downloadBatch.batch),
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@ -595,14 +599,29 @@ func (user *User) syncMessages(
|
||||
return flushUpdate.err
|
||||
}
|
||||
|
||||
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
|
||||
if err := vault.SetLastMessageID(flushUpdate.batchMessageID); err != nil {
|
||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||
}
|
||||
|
||||
for _, m := range flushUpdate.messages {
|
||||
user.syncCache.DeleteMessages(m.ID)
|
||||
if m.NumAttachments != 0 {
|
||||
user.syncCache.DeleteAttachments(xslices.Map(m.Attachments, func(a proton.Attachment) string {
|
||||
return a.ID
|
||||
})...)
|
||||
}
|
||||
}
|
||||
|
||||
syncReporter.add(flushUpdate.batchLen)
|
||||
}
|
||||
|
||||
return <-errorCh
|
||||
err := <-errorCh
|
||||
|
||||
if err != nil {
|
||||
user.syncCache.Clear()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated {
|
||||
|
||||
@ -63,7 +63,14 @@ type downloadResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, downloadCh <-chan downloadRequest, syncLimits syncLimits) (<-chan downloadedMessageBatch, <-chan error) {
|
||||
func startSyncDownloader(
|
||||
ctx context.Context,
|
||||
panicHandler async.PanicHandler,
|
||||
downloader MessageDownloader,
|
||||
cache *SyncDownloadCache,
|
||||
downloadCh <-chan downloadRequest,
|
||||
syncLimits syncLimits,
|
||||
) (<-chan downloadedMessageBatch, <-chan error) {
|
||||
buildCh := make(chan downloadedMessageBatch)
|
||||
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
|
||||
|
||||
@ -75,7 +82,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
|
||||
logrus.Debugf("sync downloader exit")
|
||||
}()
|
||||
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, syncLimits.MaxParallelDownloads)
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, cache, syncLimits.MaxParallelDownloads)
|
||||
defer attachmentDownloader.close()
|
||||
|
||||
for request := range downloadCh {
|
||||
@ -85,7 +92,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
|
||||
return
|
||||
}
|
||||
|
||||
result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, syncLimits.MaxParallelDownloads)
|
||||
result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, cache, syncLimits.MaxParallelDownloads)
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
@ -96,7 +103,7 @@ func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, d
|
||||
return
|
||||
}
|
||||
|
||||
batch, err := downloadMessagesStage2(ctx, result, downloader, SyncRetryCooldown)
|
||||
batch, err := downloadMessagesStage2(ctx, result, downloader, cache, SyncRetryCooldown)
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
@ -132,7 +139,7 @@ type attachmentDownloader struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-chan attachmentJob) {
|
||||
func attachmentWorker(ctx context.Context, downloader MessageDownloader, cache *SyncDownloadCache, work <-chan attachmentJob) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -141,26 +148,45 @@ func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var b bytes.Buffer
|
||||
b.Grow(int(job.size))
|
||||
err := downloader.GetAttachmentInto(ctx, job.id, &b)
|
||||
|
||||
var result attachmentResult
|
||||
if data, ok := cache.GetAttachment(job.id); ok {
|
||||
result.attachment = data
|
||||
result.err = nil
|
||||
} else {
|
||||
var b bytes.Buffer
|
||||
b.Grow(int(job.size))
|
||||
err := downloader.GetAttachmentInto(ctx, job.id, &b)
|
||||
result.attachment = b.Bytes()
|
||||
result.err = err
|
||||
if err == nil {
|
||||
cache.StoreAttachment(job.id, result.attachment)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(job.result)
|
||||
return
|
||||
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
|
||||
case job.result <- result:
|
||||
close(job.result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newAttachmentDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, workerCount int) *attachmentDownloader {
|
||||
func newAttachmentDownloader(
|
||||
ctx context.Context,
|
||||
panicHandler async.PanicHandler,
|
||||
downloader MessageDownloader,
|
||||
cache *SyncDownloadCache,
|
||||
workerCount int,
|
||||
) *attachmentDownloader {
|
||||
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := 0; i < workerCount; i++ {
|
||||
workerCh = make(chan attachmentJob)
|
||||
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, workerCh) }, logging.Labels{
|
||||
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, cache, workerCh) }, logging.Labels{
|
||||
"sync": fmt.Sprintf("att-downloader %v", i),
|
||||
})
|
||||
}
|
||||
@ -209,6 +235,7 @@ func downloadMessageStage1(
|
||||
request downloadRequest,
|
||||
downloader MessageDownloader,
|
||||
attachmentDownloader *attachmentDownloader,
|
||||
cache *SyncDownloadCache,
|
||||
parallelDownloads int,
|
||||
) ([]downloadResult, error) {
|
||||
// 1st attempt download everything in parallel
|
||||
@ -217,21 +244,28 @@ func downloadMessageStage1(
|
||||
|
||||
result := downloadResult{ID: id}
|
||||
|
||||
msg, err := downloader.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
||||
result.err = err
|
||||
return result, nil
|
||||
v, ok := cache.GetMessage(id)
|
||||
if !ok {
|
||||
msg, err := downloader.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
||||
result.err = err
|
||||
return result, nil
|
||||
}
|
||||
|
||||
cache.StoreMessage(msg)
|
||||
result.Message.Message = msg
|
||||
} else {
|
||||
result.Message.Message = v
|
||||
}
|
||||
|
||||
result.Message.Message = msg
|
||||
result.State = downloadStateHasMessage
|
||||
|
||||
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
|
||||
attachments, err := attachmentDownloader.getAttachments(ctx, result.Message.Attachments)
|
||||
result.Message.AttData = attachments
|
||||
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
|
||||
logrus.WithError(err).WithField("msgID", id).Error("Failed to download message attachments")
|
||||
result.err = err
|
||||
return result, nil
|
||||
}
|
||||
@ -242,7 +276,13 @@ func downloadMessageStage1(
|
||||
})
|
||||
}
|
||||
|
||||
func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloader MessageDownloader, coolDown time.Duration) ([]proton.FullMessage, error) {
|
||||
func downloadMessagesStage2(
|
||||
ctx context.Context,
|
||||
state []downloadResult,
|
||||
downloader MessageDownloader,
|
||||
cache *SyncDownloadCache,
|
||||
coolDown time.Duration,
|
||||
) ([]proton.FullMessage, error) {
|
||||
logrus.Debug("Entering download stage 2")
|
||||
var retryList []int
|
||||
var shouldWaitBeforeRetry bool
|
||||
@ -289,6 +329,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache.StoreMessage(message)
|
||||
st.Message.Message = message
|
||||
st.State = downloadStateHasMessage
|
||||
}
|
||||
@ -314,6 +355,7 @@ func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloa
|
||||
}
|
||||
|
||||
st.Message.AttData[i] = buffer.Bytes()
|
||||
cache.StoreAttachment(st.Message.Attachments[i].ID, st.Message.AttData[i])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -89,10 +89,11 @@ func TestSyncDownloader_Stage1_429(t *testing.T) {
|
||||
return err
|
||||
})
|
||||
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, 1)
|
||||
cache := newSyncDownloadCache()
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
|
||||
defer attachmentDownloader.close()
|
||||
|
||||
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, 1)
|
||||
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(result))
|
||||
// Check message 1
|
||||
@ -115,12 +116,21 @@ func TestSyncDownloader_Stage1_429(t *testing.T) {
|
||||
require.Nil(t, result[2].Message.AttData[0])
|
||||
require.NotEqual(t, attachmentData, result[2].Message.AttData[1])
|
||||
require.NotNil(t, result[2].err)
|
||||
|
||||
_, ok := cache.GetMessage("MsgID1")
|
||||
require.True(t, ok)
|
||||
_, ok = cache.GetMessage("MsgID3")
|
||||
require.True(t, ok)
|
||||
att, ok := cache.GetAttachment("Attachment1_1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, attachmentData, string(att))
|
||||
}
|
||||
|
||||
func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||
ctx := context.Background()
|
||||
cache := newSyncDownloadCache()
|
||||
|
||||
downloadResult := []downloadResult{
|
||||
{
|
||||
@ -133,7 +143,7 @@ func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||
result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(result))
|
||||
}
|
||||
@ -142,6 +152,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||
ctx := context.Background()
|
||||
cache := newSyncDownloadCache()
|
||||
|
||||
msgErr := fmt.Errorf("something not 429")
|
||||
downloadResult := []downloadResult{
|
||||
@ -160,7 +171,7 @@ func TestSyncDownloader_Stage2_Not429(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, msgErr, err)
|
||||
}
|
||||
@ -169,6 +180,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||
ctx := context.Background()
|
||||
cache := newSyncDownloadCache()
|
||||
|
||||
msgErr := &proton.APIError{Status: 500}
|
||||
downloadResult := []downloadResult{
|
||||
@ -183,7 +195,7 @@ func TestSyncDownloader_Stage2_API500(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, msgErr, err)
|
||||
}
|
||||
@ -192,6 +204,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||
ctx := context.Background()
|
||||
cache := newSyncDownloadCache()
|
||||
|
||||
const attachmentData1 = "attachment data 1"
|
||||
const attachmentData2 = "attachment data 2"
|
||||
@ -290,7 +303,7 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(messages))
|
||||
|
||||
@ -304,12 +317,28 @@ func TestSyncDownloader_Stage2_Some429(t *testing.T) {
|
||||
require.Equal(t, attachmentData1, string(messages[1].AttData[0]))
|
||||
require.Equal(t, attachmentData2, string(messages[1].AttData[1]))
|
||||
require.Empty(t, messages[2].AttData)
|
||||
|
||||
_, ok := cache.GetMessage("Msg3")
|
||||
require.True(t, ok)
|
||||
|
||||
att3, ok := cache.GetAttachment("A3")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, attachmentData3, string(att3))
|
||||
|
||||
att1, ok := cache.GetAttachment("A1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, attachmentData1, string(att1))
|
||||
|
||||
att2, ok := cache.GetAttachment("A2")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, attachmentData2, string(att2))
|
||||
}
|
||||
|
||||
func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||
ctx := context.Background()
|
||||
cache := newSyncDownloadCache()
|
||||
|
||||
err429 := &proton.APIError{Status: 429}
|
||||
err500 := &proton.APIError{Status: 500}
|
||||
@ -352,7 +381,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
|
||||
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500)
|
||||
}
|
||||
|
||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, 0, messages)
|
||||
}
|
||||
@ -361,6 +390,7 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||
ctx := context.Background()
|
||||
cache := newSyncDownloadCache()
|
||||
|
||||
err429 := &proton.APIError{Status: 429}
|
||||
err500 := &proton.APIError{Status: 500}
|
||||
@ -394,7 +424,50 @@ func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
|
||||
// 500 for second attachment
|
||||
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500)
|
||||
|
||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, cache, time.Millisecond)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, 0, messages)
|
||||
}
|
||||
|
||||
func TestSyncDownloader_Stage1_DoNotDownloadIfAlreadyInCache(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||
panicHandler := &async.NoopPanicHandler{}
|
||||
ctx := context.Background()
|
||||
|
||||
requests := downloadRequest{
|
||||
ids: []string{"Msg1", "Msg3"},
|
||||
expectedSize: 0,
|
||||
err: nil,
|
||||
}
|
||||
|
||||
cache := newSyncDownloadCache()
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
|
||||
defer attachmentDownloader.close()
|
||||
|
||||
const attachmentData = "attachment data"
|
||||
|
||||
cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg1", NumAttachments: 1}, Attachments: []proton.Attachment{{ID: "A1"}}})
|
||||
cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3", NumAttachments: 2}, Attachments: []proton.Attachment{{ID: "A2"}}})
|
||||
|
||||
cache.StoreAttachment("A1", []byte(attachmentData))
|
||||
cache.StoreAttachment("A2", []byte(attachmentData))
|
||||
|
||||
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(result))
|
||||
|
||||
require.Equal(t, result[0].State, downloadStateFinished)
|
||||
require.Equal(t, result[0].Message.ID, "Msg1")
|
||||
require.NotEmpty(t, result[0].Message.AttData)
|
||||
require.NotEqual(t, attachmentData, result[0].Message.AttData[0])
|
||||
require.NotNil(t, result[0].Message.AttData[0])
|
||||
require.Nil(t, result[0].err)
|
||||
|
||||
require.Equal(t, result[1].State, downloadStateFinished)
|
||||
require.Equal(t, result[1].Message.ID, "Msg3")
|
||||
require.NotEmpty(t, result[1].Message.AttData)
|
||||
require.NotEqual(t, attachmentData, result[1].Message.AttData[0])
|
||||
require.NotNil(t, result[1].Message.AttData[0])
|
||||
require.Nil(t, result[1].err)
|
||||
}
|
||||
|
||||
98
internal/user/sync_message_cache.go
Normal file
98
internal/user/sync_message_cache.go
Normal file
@ -0,0 +1,98 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
)
|
||||
|
||||
type SyncDownloadCache struct {
|
||||
messageLock sync.RWMutex
|
||||
messages map[string]proton.Message
|
||||
attachmentLock sync.RWMutex
|
||||
attachments map[string][]byte
|
||||
}
|
||||
|
||||
func newSyncDownloadCache() *SyncDownloadCache {
|
||||
return &SyncDownloadCache{
|
||||
messages: make(map[string]proton.Message, 64),
|
||||
attachments: make(map[string][]byte, 64),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncDownloadCache) StoreMessage(message proton.Message) {
|
||||
s.messageLock.Lock()
|
||||
defer s.messageLock.Unlock()
|
||||
|
||||
s.messages[message.ID] = message
|
||||
}
|
||||
|
||||
func (s *SyncDownloadCache) StoreAttachment(id string, data []byte) {
|
||||
s.attachmentLock.Lock()
|
||||
defer s.attachmentLock.Unlock()
|
||||
|
||||
s.attachments[id] = data
|
||||
}
|
||||
|
||||
func (s *SyncDownloadCache) DeleteMessages(id ...string) {
|
||||
s.messageLock.Lock()
|
||||
defer s.messageLock.Unlock()
|
||||
|
||||
for _, id := range id {
|
||||
delete(s.messages, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncDownloadCache) DeleteAttachments(id ...string) {
|
||||
s.attachmentLock.Lock()
|
||||
defer s.attachmentLock.Unlock()
|
||||
|
||||
for _, id := range id {
|
||||
delete(s.attachments, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncDownloadCache) GetMessage(id string) (proton.Message, bool) {
|
||||
s.messageLock.RLock()
|
||||
defer s.messageLock.RUnlock()
|
||||
|
||||
v, ok := s.messages[id]
|
||||
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (s *SyncDownloadCache) GetAttachment(id string) ([]byte, bool) {
|
||||
s.attachmentLock.RLock()
|
||||
defer s.attachmentLock.RUnlock()
|
||||
|
||||
v, ok := s.attachments[id]
|
||||
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (s *SyncDownloadCache) Clear() {
|
||||
s.messageLock.Lock()
|
||||
s.messages = make(map[string]proton.Message, 64)
|
||||
s.messageLock.Unlock()
|
||||
|
||||
s.attachmentLock.Lock()
|
||||
s.attachments = make(map[string][]byte, 64)
|
||||
s.attachmentLock.Unlock()
|
||||
}
|
||||
@ -93,6 +93,7 @@ type User struct {
|
||||
showAllMail uint32
|
||||
|
||||
maxSyncMemory uint64
|
||||
syncCache *SyncDownloadCache
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
|
||||
@ -171,6 +172,7 @@ func New(
|
||||
showAllMail: b32(showAllMail),
|
||||
|
||||
maxSyncMemory: maxSyncMemory,
|
||||
syncCache: newSyncDownloadCache(),
|
||||
|
||||
panicHandler: crashHandler,
|
||||
|
||||
|
||||
Reference in New Issue
Block a user