From 5136919c369affdf60c792e5681d4f060e0ee269 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Wed, 26 Jul 2023 14:54:58 +0200 Subject: [PATCH] fix(GODT-2822): Handle 429 during message download When we run into 429 during a message download, do not cancel the whole batch and switch to a sequential downloader to avoid API overload. --- Makefile | 1 + internal/user/mocks/mocks.go | 66 +++++ internal/user/sync.go | 161 +---------- internal/user/sync_downloader.go | 339 ++++++++++++++++++++++ internal/user/sync_downloader_test.go | 400 ++++++++++++++++++++++++++ 5 files changed, 807 insertions(+), 160 deletions(-) create mode 100644 internal/user/mocks/mocks.go create mode 100644 internal/user/sync_downloader.go create mode 100644 internal/user/sync_downloader_test.go diff --git a/Makefile b/Makefile index c4156f84..fd200737 100644 --- a/Makefile +++ b/Makefile @@ -274,6 +274,7 @@ mocks: mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/user MessageDownloader > internal/user/mocks/mocks.go lint: gofiles lint-golang lint-license lint-dependencies lint-changelog diff --git a/internal/user/mocks/mocks.go b/internal/user/mocks/mocks.go new file mode 100644 index 00000000..0d72426f --- /dev/null +++ b/internal/user/mocks/mocks.go @@ -0,0 +1,66 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ProtonMail/proton-bridge/v3/internal/user (interfaces: MessageDownloader) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + io "io" + reflect "reflect" + + proton "github.com/ProtonMail/go-proton-api" + gomock "github.com/golang/mock/gomock" +) + +// MockMessageDownloader is a mock of MessageDownloader interface. +type MockMessageDownloader struct { + ctrl *gomock.Controller + recorder *MockMessageDownloaderMockRecorder +} + +// MockMessageDownloaderMockRecorder is the mock recorder for MockMessageDownloader. +type MockMessageDownloaderMockRecorder struct { + mock *MockMessageDownloader +} + +// NewMockMessageDownloader creates a new mock instance. +func NewMockMessageDownloader(ctrl *gomock.Controller) *MockMessageDownloader { + mock := &MockMessageDownloader{ctrl: ctrl} + mock.recorder = &MockMessageDownloaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMessageDownloader) EXPECT() *MockMessageDownloaderMockRecorder { + return m.recorder +} + +// GetAttachmentInto mocks base method. +func (m *MockMessageDownloader) GetAttachmentInto(arg0 context.Context, arg1 string, arg2 io.ReaderFrom) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAttachmentInto", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// GetAttachmentInto indicates an expected call of GetAttachmentInto. +func (mr *MockMessageDownloaderMockRecorder) GetAttachmentInto(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachmentInto", reflect.TypeOf((*MockMessageDownloader)(nil).GetAttachmentInto), arg0, arg1, arg2) +} + +// GetMessage mocks base method. +func (m *MockMessageDownloader) GetMessage(arg0 context.Context, arg1 string) (proton.Message, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMessage", arg0, arg1) + ret0, _ := ret[0].(proton.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMessage indicates an expected call of GetMessage. +func (mr *MockMessageDownloaderMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockMessageDownloader)(nil).GetMessage), arg0, arg1) +} diff --git a/internal/user/sync.go b/internal/user/sync.go index 9fb59342..84cce7c6 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -378,32 +378,18 @@ func (user *User) syncMessages( batchLen int } - type downloadRequest struct { - ids []string - expectedSize uint64 - err error - } - - type downloadedMessageBatch struct { - batch []proton.FullMessage - } - type builtMessageBatch struct { batch []*buildRes } downloadCh := make(chan downloadRequest) - buildCh := make(chan downloadedMessageBatch) - // The higher this value, the longer we can continue our download iteration before being blocked on channel writes // to the update flushing goroutine. flushCh := make(chan builtMessageBatch) flushUpdateCh := make(chan flushUpdate) - errorCh := make(chan error, syncLimits.MaxParallelDownloads*4) - // Go routine in charge of downloading message metadata async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { defer close(downloadCh) @@ -469,65 +455,7 @@ func (user *User) syncMessages( }, logging.Labels{"sync-stage": "meta-data"}) // Goroutine in charge of downloading and building messages in maxBatchSize batches. - async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { - defer close(buildCh) - defer close(errorCh) - defer func() { - logrus.Debugf("sync downloader exit") - }() - - attachmentDownloader := user.newAttachmentDownloader(ctx, client, syncLimits.MaxParallelDownloads) - defer attachmentDownloader.close() - - for request := range downloadCh { - logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize)) - if request.err != nil { - errorCh <- request.err - return - } - - if ctx.Err() != nil { - errorCh <- ctx.Err() - return - } - - result, err := parallel.MapContext(ctx, syncLimits.MaxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) { - defer async.HandlePanic(user.panicHandler) - - var result proton.FullMessage - - msg, err := client.GetMessage(ctx, id) - if err != nil { - logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message") - return proton.FullMessage{}, err - } - - attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments) - if err != nil { - logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments") - return proton.FullMessage{}, err - } - - result.Message = msg - result.AttData = attachments - - return result, nil - }) - if err != nil { - errorCh <- err - return - } - - select { - case buildCh <- downloadedMessageBatch{ - batch: result, - }: - - case <-ctx.Done(): - return - } - } - }, logging.Labels{"sync-stage": "download"}) + buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, downloadCh, syncLimits) // Goroutine which builds messages after they have been downloaded async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { @@ -793,93 +721,6 @@ func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string { }) } -type attachmentResult struct { - attachment []byte - err error -} - -type attachmentJob struct { - id string - size int64 - result chan attachmentResult -} - -type attachmentDownloader struct { - workerCh chan attachmentJob - cancel context.CancelFunc -} - -func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) { - for { - select { - case <-ctx.Done(): - return - case job, ok := <-work: - if !ok { - return - } - var b bytes.Buffer - b.Grow(int(job.size)) - err := client.GetAttachmentInto(ctx, job.id, &b) - select { - case <-ctx.Done(): - close(job.result) - return - case job.result <- attachmentResult{attachment: b.Bytes(), err: err}: - close(job.result) - } - } - } -} - -func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader { - workerCh := make(chan attachmentJob, (workerCount+2)*workerCount) - ctx, cancel := context.WithCancel(ctx) - for i := 0; i < workerCount; i++ { - workerCh = make(chan attachmentJob) - async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{ - "sync": fmt.Sprintf("att-downloader %v", i), - }) - } - - return &attachmentDownloader{ - workerCh: workerCh, - cancel: cancel, - } -} - -func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) { - resultChs := make([]chan attachmentResult, len(attachments)) - for i, id := range attachments { - resultChs[i] = make(chan attachmentResult, 1) - select { - case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}: - case <-ctx.Done(): - return nil, ctx.Err() - } - } - - result := make([][]byte, len(attachments)) - var err error - for i := 0; i < len(attachments); i++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case r := <-resultChs[i]: - if r.err != nil { - err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err) - } - result[i] = r.attachment - } - } - - return result, err -} - -func (a *attachmentDownloader) close() { - a.cancel() -} - func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage { var expectedMemUsage uint64 var chunks [][]proton.FullMessage diff --git a/internal/user/sync_downloader.go b/internal/user/sync_downloader.go new file mode 100644 index 00000000..4d61b2f0 --- /dev/null +++ b/internal/user/sync_downloader.go @@ -0,0 +1,339 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package user + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "time" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" + "github.com/ProtonMail/go-proton-api" + "github.com/bradenaw/juniper/parallel" + "github.com/bradenaw/juniper/xslices" + "github.com/sirupsen/logrus" +) + +type downloadRequest struct { + ids []string + expectedSize uint64 + err error +} + +type downloadedMessageBatch struct { + batch []proton.FullMessage +} + +type MessageDownloader interface { + GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error + GetMessage(ctx context.Context, messageID string) (proton.Message, error) +} + +type downloadState int + +const ( + downloadStateZero downloadState = iota + downloadStateHasMessage + downloadStateFinished +) + +type downloadResult struct { + ID string + State downloadState + Message proton.FullMessage + err error +} + +func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, downloadCh <-chan downloadRequest, syncLimits syncLimits) (<-chan downloadedMessageBatch, <-chan error) { + buildCh := make(chan downloadedMessageBatch) + errorCh := make(chan error, syncLimits.MaxParallelDownloads*4) + + // Goroutine in charge of downloading and building messages in maxBatchSize batches. + async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { + defer close(buildCh) + defer close(errorCh) + defer func() { + logrus.Debugf("sync downloader exit") + }() + + attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, syncLimits.MaxParallelDownloads) + defer attachmentDownloader.close() + + for request := range downloadCh { + logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize)) + if request.err != nil { + errorCh <- request.err + return + } + + result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, syncLimits.MaxParallelDownloads) + if err != nil { + errorCh <- err + return + } + + if ctx.Err() != nil { + errorCh <- ctx.Err() + return + } + + batch, err := downloadMessagesStage2(ctx, result, downloader, SyncRetryCooldown) + if err != nil { + errorCh <- err + return + } + + select { + case buildCh <- downloadedMessageBatch{ + batch: batch, + }: + + case <-ctx.Done(): + return + } + } + }, logging.Labels{"sync-stage": "download"}) + + return buildCh, errorCh +} + +type attachmentResult struct { + attachment []byte + err error +} + +type attachmentJob struct { + id string + size int64 + result chan attachmentResult +} + +type attachmentDownloader struct { + workerCh chan attachmentJob + cancel context.CancelFunc +} + +func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-chan attachmentJob) { + for { + select { + case <-ctx.Done(): + return + case job, ok := <-work: + if !ok { + return + } + var b bytes.Buffer + b.Grow(int(job.size)) + err := downloader.GetAttachmentInto(ctx, job.id, &b) + select { + case <-ctx.Done(): + close(job.result) + return + case job.result <- attachmentResult{attachment: b.Bytes(), err: err}: + close(job.result) + } + } + } +} + +func newAttachmentDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, workerCount int) *attachmentDownloader { + workerCh := make(chan attachmentJob, (workerCount+2)*workerCount) + ctx, cancel := context.WithCancel(ctx) + for i := 0; i < workerCount; i++ { + workerCh = make(chan attachmentJob) + async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, workerCh) }, logging.Labels{ + "sync": fmt.Sprintf("att-downloader %v", i), + }) + } + + return &attachmentDownloader{ + workerCh: workerCh, + cancel: cancel, + } +} + +func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) { + resultChs := make([]chan attachmentResult, len(attachments)) + for i, id := range attachments { + resultChs[i] = make(chan attachmentResult, 1) + select { + case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + result := make([][]byte, len(attachments)) + var err error + for i := 0; i < len(attachments); i++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-resultChs[i]: + if r.err != nil { + err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err) + } + result[i] = r.attachment + } + } + + return result, err +} + +func (a *attachmentDownloader) close() { + a.cancel() +} + +func downloadMessageStage1( + ctx context.Context, + panicHandler async.PanicHandler, + request downloadRequest, + downloader MessageDownloader, + attachmentDownloader *attachmentDownloader, + parallelDownloads int, +) ([]downloadResult, error) { + // 1st attempt download everything in parallel + return parallel.MapContext(ctx, parallelDownloads, request.ids, func(ctx context.Context, id string) (downloadResult, error) { + defer async.HandlePanic(panicHandler) + + result := downloadResult{ID: id} + + msg, err := downloader.GetMessage(ctx, id) + if err != nil { + logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message") + result.err = err + return result, nil + } + + result.Message.Message = msg + result.State = downloadStateHasMessage + + attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments) + result.Message.AttData = attachments + + if err != nil { + logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments") + result.err = err + return result, nil + } + + result.State = downloadStateFinished + + return result, nil + }) +} + +func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloader MessageDownloader, coolDown time.Duration) ([]proton.FullMessage, error) { + logrus.Debug("Entering download stage 2") + var retryList []int + var shouldWaitBeforeRetry bool + + for { + if shouldWaitBeforeRetry { + time.Sleep(coolDown) + } + + retryList = nil + shouldWaitBeforeRetry = false + + for index, s := range state { + if s.State == downloadStateFinished { + continue + } + + if s.err != nil { + if is429Error(s.err) { + logrus.WithField("msg-id", s.ID).Debug("Message download failed due to 429, retrying") + retryList = append(retryList, index) + continue + } + return nil, s.err + } + } + + if len(retryList) == 0 { + break + } + + for _, i := range retryList { + st := &state[i] + if st.State == downloadStateZero { + message, err := downloader.GetMessage(ctx, st.ID) + if err != nil { + logrus.WithField("msg-id", st.ID).WithError(err).Error("failed to download message (429)") + if is429Error(err) { + st.err = err + shouldWaitBeforeRetry = true + continue + } + + return nil, err + } + + st.Message.Message = message + st.State = downloadStateHasMessage + } + + if st.Message.AttData == nil && st.Message.NumAttachments != 0 { + st.Message.AttData = make([][]byte, st.Message.NumAttachments) + } + + hasAllAttachments := true + for i := 0; i < st.Message.NumAttachments; i++ { + if st.Message.AttData[i] == nil { + buffer := bytes.Buffer{} + if err := downloader.GetAttachmentInto(ctx, st.Message.Attachments[i].ID, &buffer); err != nil { + logrus.WithField("msg-id", st.ID).WithError(err).Errorf("failed to download attachment %v/%v (429)", i+1, len(st.Message.Attachments)) + if is429Error(err) { + st.err = err + shouldWaitBeforeRetry = true + hasAllAttachments = false + continue + } + + return nil, err + } + + st.Message.AttData[i] = buffer.Bytes() + } + } + + if hasAllAttachments { + st.State = downloadStateFinished + } + } + } + + logrus.Debug("All message downloaded successfully") + return xslices.Map(state, func(s downloadResult) proton.FullMessage { + return s.Message + }), nil +} + +func is429Error(err error) bool { + var apiErr *proton.APIError + if errors.As(err, &apiErr) { + return apiErr.Status == 429 + } + + return false +} diff --git a/internal/user/sync_downloader_test.go b/internal/user/sync_downloader_test.go new file mode 100644 index 00000000..b8edfdd9 --- /dev/null +++ b/internal/user/sync_downloader_test.go @@ -0,0 +1,400 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package user + +import ( + "context" + "fmt" + "io" + "strings" + "testing" + "time" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/user/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +func TestSyncDownloader_Stage1_429(t *testing.T) { + // Check 429 is correctly caught and download state recorded correctly + // Message 1: All ok + // Message 2: Message failed + // Message 3: One attachment failed. + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + panicHandler := &async.NoopPanicHandler{} + ctx := context.Background() + + requests := downloadRequest{ + ids: []string{"Msg1", "Msg2", "Msg3"}, + expectedSize: 0, + err: nil, + } + + messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg1")).Times(1).Return(proton.Message{ + MessageMetadata: proton.MessageMetadata{ + ID: "MsgID1", + NumAttachments: 1, + }, + Attachments: []proton.Attachment{ + { + ID: "Attachment1_1", + }, + }, + }, nil) + + messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg2")).Times(1).Return(proton.Message{}, &proton.APIError{Status: 429}) + messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{ + MessageMetadata: proton.MessageMetadata{ + ID: "MsgID3", + NumAttachments: 2, + }, + Attachments: []proton.Attachment{ + { + ID: "Attachment3_1", + }, + { + ID: "Attachment3_2", + }, + }, + }, nil) + + const attachmentData = "attachment data" + + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment1_1"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error { + _, err := r.ReadFrom(strings.NewReader(attachmentData)) + return err + }) + + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_1"), gomock.Any()).Times(1).Return(&proton.APIError{Status: 429}) + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_2"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error { + _, err := r.ReadFrom(strings.NewReader(attachmentData)) + return err + }) + + attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, 1) + defer attachmentDownloader.close() + + result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, 1) + require.NoError(t, err) + require.Equal(t, 3, len(result)) + // Check message 1 + require.Equal(t, result[0].State, downloadStateFinished) + require.Equal(t, result[0].Message.ID, "MsgID1") + require.NotEmpty(t, result[0].Message.AttData) + require.NotEqual(t, attachmentData, result[0].Message.AttData[0]) + require.NotNil(t, result[0].Message.AttData[0]) + require.Nil(t, result[0].err) + + // Check message 2 + require.Equal(t, result[1].State, downloadStateZero) + require.Empty(t, result[1].Message.ID) + require.NotNil(t, result[1].err) + + require.Equal(t, result[2].State, downloadStateHasMessage) + require.Equal(t, result[2].Message.ID, "MsgID3") + require.Equal(t, 2, len(result[2].Message.AttData)) + require.NotNil(t, result[2].err) + require.Nil(t, result[2].Message.AttData[0]) + require.NotEqual(t, attachmentData, result[2].Message.AttData[1]) + require.NotNil(t, result[2].err) +} + +func TestSyncDownloader_Stage2_Everything200(t *testing.T) { + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + ctx := context.Background() + + downloadResult := []downloadResult{ + { + ID: "Msg1", + State: downloadStateFinished, + }, + { + ID: "Msg2", + State: downloadStateFinished, + }, + } + + result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + require.NoError(t, err) + require.Equal(t, 2, len(result)) +} + +func TestSyncDownloader_Stage2_Not429(t *testing.T) { + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + ctx := context.Background() + + msgErr := fmt.Errorf("something not 429") + downloadResult := []downloadResult{ + { + ID: "Msg1", + State: downloadStateFinished, + }, + { + ID: "Msg2", + State: downloadStateHasMessage, + err: msgErr, + }, + { + ID: "Msg3", + State: downloadStateFinished, + }, + } + + _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + require.Error(t, err) + require.Equal(t, msgErr, err) +} + +func TestSyncDownloader_Stage2_API500(t *testing.T) { + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + ctx := context.Background() + + msgErr := &proton.APIError{Status: 500} + downloadResult := []downloadResult{ + { + ID: "Msg2", + State: downloadStateHasMessage, + err: msgErr, + }, + { + ID: "Msg3", + State: downloadStateFinished, + }, + } + + _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + require.Error(t, err) + require.Equal(t, msgErr, err) +} + +func TestSyncDownloader_Stage2_Some429(t *testing.T) { + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + ctx := context.Background() + + const attachmentData1 = "attachment data 1" + const attachmentData2 = "attachment data 2" + const attachmentData3 = "attachment data 3" + const attachmentData4 = "attachment data 4" + + err429 := &proton.APIError{Status: 429} + downloadResult := []downloadResult{ + { + // Full message , but missing 1 of 2 attachments + ID: "Msg1", + Message: proton.FullMessage{ + Message: proton.Message{ + MessageMetadata: proton.MessageMetadata{ + ID: "Msg1", + NumAttachments: 2, + }, + Attachments: []proton.Attachment{ + { + ID: "A3", + }, + { + ID: "A4", + }, + }, + }, + AttData: [][]byte{ + nil, + []byte(attachmentData4), + }, + }, + State: downloadStateHasMessage, + err: err429, + }, + { + // Full message, but missing all attachments + ID: "Msg2", + Message: proton.FullMessage{ + Message: proton.Message{ + MessageMetadata: proton.MessageMetadata{ + ID: "Msg2", + NumAttachments: 2, + }, + Attachments: []proton.Attachment{ + { + ID: "A1", + }, + { + ID: "A2", + }, + }, + }, + AttData: nil, + }, + State: downloadStateHasMessage, + err: err429, + }, + { + // Missing everything + ID: "Msg3", + State: downloadStateZero, + Message: proton.FullMessage{ + Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}}, + }, + err: err429, + }, + } + + { + // Simulate 2 failures for message 3 body. + firstCall := messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(2).Return(proton.Message{}, err429) + messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).After(firstCall).Times(1).Return(proton.Message{ + MessageMetadata: proton.MessageMetadata{ + ID: "Msg3", + }, + }, nil) + } + + { + // Simulate failures for message 2 attachments. + firstCall := messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).Times(2).Return(err429) + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).After(firstCall).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error { + _, err := r.ReadFrom(strings.NewReader(attachmentData1)) + return err + }) + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A2"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error { + _, err := r.ReadFrom(strings.NewReader(attachmentData2)) + return err + }) + } + + { + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error { + _, err := r.ReadFrom(strings.NewReader(attachmentData3)) + return err + }) + } + + messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + require.NoError(t, err) + require.Equal(t, 3, len(messages)) + + require.Equal(t, messages[0].Message.ID, "Msg1") + require.Equal(t, messages[1].Message.ID, "Msg2") + require.Equal(t, messages[2].Message.ID, "Msg3") + + // check attachments + require.Equal(t, attachmentData3, string(messages[0].AttData[0])) + require.Equal(t, attachmentData4, string(messages[0].AttData[1])) + require.Equal(t, attachmentData1, string(messages[1].AttData[0])) + require.Equal(t, attachmentData2, string(messages[1].AttData[1])) + require.Empty(t, messages[2].AttData) +} + +func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) { + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + ctx := context.Background() + + err429 := &proton.APIError{Status: 429} + err500 := &proton.APIError{Status: 500} + downloadResult := []downloadResult{ + { + // Missing everything + ID: "Msg3", + State: downloadStateZero, + Message: proton.FullMessage{ + Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}}, + }, + err: err429, + }, + { + // Full message , but missing 1 of 2 attachments + ID: "Msg1", + Message: proton.FullMessage{ + Message: proton.Message{ + MessageMetadata: proton.MessageMetadata{ + ID: "Msg1", + NumAttachments: 2, + }, + Attachments: []proton.Attachment{ + { + ID: "A3", + }, + { + ID: "A4", + }, + }, + }, + }, + State: downloadStateHasMessage, + err: err429, + }, + } + + { + // Simulate 2 failures for message 3 body, + messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500) + } + + messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + require.Error(t, err) + require.Empty(t, 0, messages) +} + +func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) { + mockCtrl := gomock.NewController(t) + messageDownloader := mocks.NewMockMessageDownloader(mockCtrl) + ctx := context.Background() + + err429 := &proton.APIError{Status: 429} + err500 := &proton.APIError{Status: 500} + downloadResult := []downloadResult{ + { + // Full message , but missing 1 of 2 attachments + ID: "Msg1", + Message: proton.FullMessage{ + Message: proton.Message{ + MessageMetadata: proton.MessageMetadata{ + ID: "Msg1", + NumAttachments: 2, + }, + Attachments: []proton.Attachment{ + { + ID: "A3", + }, + { + ID: "A4", + }, + }, + }, + }, + State: downloadStateHasMessage, + err: err429, + }, + } + + // 429 for first attachment + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).Return(err429) + // 500 for second attachment + messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500) + + messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond) + require.Error(t, err) + require.Empty(t, 0, messages) +}