diff --git a/Makefile b/Makefile
index c4156f84..fd200737 100644
--- a/Makefile
+++ b/Makefile
@@ -274,6 +274,7 @@ mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go
cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go
+ mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/user MessageDownloader > internal/user/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog
diff --git a/internal/user/mocks/mocks.go b/internal/user/mocks/mocks.go
new file mode 100644
index 00000000..0d72426f
--- /dev/null
+++ b/internal/user/mocks/mocks.go
@@ -0,0 +1,66 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/ProtonMail/proton-bridge/v3/internal/user (interfaces: MessageDownloader)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ io "io"
+ reflect "reflect"
+
+ proton "github.com/ProtonMail/go-proton-api"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockMessageDownloader is a mock of MessageDownloader interface.
+type MockMessageDownloader struct {
+ ctrl *gomock.Controller
+ recorder *MockMessageDownloaderMockRecorder
+}
+
+// MockMessageDownloaderMockRecorder is the mock recorder for MockMessageDownloader.
+type MockMessageDownloaderMockRecorder struct {
+ mock *MockMessageDownloader
+}
+
+// NewMockMessageDownloader creates a new mock instance.
+func NewMockMessageDownloader(ctrl *gomock.Controller) *MockMessageDownloader {
+ mock := &MockMessageDownloader{ctrl: ctrl}
+ mock.recorder = &MockMessageDownloaderMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockMessageDownloader) EXPECT() *MockMessageDownloaderMockRecorder {
+ return m.recorder
+}
+
+// GetAttachmentInto mocks base method.
+func (m *MockMessageDownloader) GetAttachmentInto(arg0 context.Context, arg1 string, arg2 io.ReaderFrom) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetAttachmentInto", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// GetAttachmentInto indicates an expected call of GetAttachmentInto.
+func (mr *MockMessageDownloaderMockRecorder) GetAttachmentInto(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachmentInto", reflect.TypeOf((*MockMessageDownloader)(nil).GetAttachmentInto), arg0, arg1, arg2)
+}
+
+// GetMessage mocks base method.
+func (m *MockMessageDownloader) GetMessage(arg0 context.Context, arg1 string) (proton.Message, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
+ ret0, _ := ret[0].(proton.Message)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetMessage indicates an expected call of GetMessage.
+func (mr *MockMessageDownloaderMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockMessageDownloader)(nil).GetMessage), arg0, arg1)
+}
diff --git a/internal/user/sync.go b/internal/user/sync.go
index 9fb59342..84cce7c6 100644
--- a/internal/user/sync.go
+++ b/internal/user/sync.go
@@ -378,32 +378,18 @@ func (user *User) syncMessages(
batchLen int
}
- type downloadRequest struct {
- ids []string
- expectedSize uint64
- err error
- }
-
- type downloadedMessageBatch struct {
- batch []proton.FullMessage
- }
-
type builtMessageBatch struct {
batch []*buildRes
}
downloadCh := make(chan downloadRequest)
- buildCh := make(chan downloadedMessageBatch)
-
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
// to the update flushing goroutine.
flushCh := make(chan builtMessageBatch)
flushUpdateCh := make(chan flushUpdate)
- errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
-
// Go routine in charge of downloading message metadata
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(downloadCh)
@@ -469,65 +455,7 @@ func (user *User) syncMessages(
}, logging.Labels{"sync-stage": "meta-data"})
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
- async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
- defer close(buildCh)
- defer close(errorCh)
- defer func() {
- logrus.Debugf("sync downloader exit")
- }()
-
- attachmentDownloader := user.newAttachmentDownloader(ctx, client, syncLimits.MaxParallelDownloads)
- defer attachmentDownloader.close()
-
- for request := range downloadCh {
- logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
- if request.err != nil {
- errorCh <- request.err
- return
- }
-
- if ctx.Err() != nil {
- errorCh <- ctx.Err()
- return
- }
-
- result, err := parallel.MapContext(ctx, syncLimits.MaxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
- defer async.HandlePanic(user.panicHandler)
-
- var result proton.FullMessage
-
- msg, err := client.GetMessage(ctx, id)
- if err != nil {
- logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
- return proton.FullMessage{}, err
- }
-
- attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
- if err != nil {
- logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
- return proton.FullMessage{}, err
- }
-
- result.Message = msg
- result.AttData = attachments
-
- return result, nil
- })
- if err != nil {
- errorCh <- err
- return
- }
-
- select {
- case buildCh <- downloadedMessageBatch{
- batch: result,
- }:
-
- case <-ctx.Done():
- return
- }
- }
- }, logging.Labels{"sync-stage": "download"})
+ buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, downloadCh, syncLimits)
// Goroutine which builds messages after they have been downloaded
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
@@ -793,93 +721,6 @@ func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string {
})
}
-type attachmentResult struct {
- attachment []byte
- err error
-}
-
-type attachmentJob struct {
- id string
- size int64
- result chan attachmentResult
-}
-
-type attachmentDownloader struct {
- workerCh chan attachmentJob
- cancel context.CancelFunc
-}
-
-func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) {
- for {
- select {
- case <-ctx.Done():
- return
- case job, ok := <-work:
- if !ok {
- return
- }
- var b bytes.Buffer
- b.Grow(int(job.size))
- err := client.GetAttachmentInto(ctx, job.id, &b)
- select {
- case <-ctx.Done():
- close(job.result)
- return
- case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
- close(job.result)
- }
- }
- }
-}
-
-func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
- workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
- ctx, cancel := context.WithCancel(ctx)
- for i := 0; i < workerCount; i++ {
- workerCh = make(chan attachmentJob)
- async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
- "sync": fmt.Sprintf("att-downloader %v", i),
- })
- }
-
- return &attachmentDownloader{
- workerCh: workerCh,
- cancel: cancel,
- }
-}
-
-func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
- resultChs := make([]chan attachmentResult, len(attachments))
- for i, id := range attachments {
- resultChs[i] = make(chan attachmentResult, 1)
- select {
- case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
- case <-ctx.Done():
- return nil, ctx.Err()
- }
- }
-
- result := make([][]byte, len(attachments))
- var err error
- for i := 0; i < len(attachments); i++ {
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case r := <-resultChs[i]:
- if r.err != nil {
- err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
- }
- result[i] = r.attachment
- }
- }
-
- return result, err
-}
-
-func (a *attachmentDownloader) close() {
- a.cancel()
-}
-
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
var expectedMemUsage uint64
var chunks [][]proton.FullMessage
diff --git a/internal/user/sync_downloader.go b/internal/user/sync_downloader.go
new file mode 100644
index 00000000..4d61b2f0
--- /dev/null
+++ b/internal/user/sync_downloader.go
@@ -0,0 +1,339 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package user
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/gluon/logging"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/bradenaw/juniper/parallel"
+ "github.com/bradenaw/juniper/xslices"
+ "github.com/sirupsen/logrus"
+)
+
+type downloadRequest struct {
+ ids []string
+ expectedSize uint64
+ err error
+}
+
+type downloadedMessageBatch struct {
+ batch []proton.FullMessage
+}
+
+type MessageDownloader interface {
+ GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
+ GetMessage(ctx context.Context, messageID string) (proton.Message, error)
+}
+
+type downloadState int
+
+const (
+ downloadStateZero downloadState = iota
+ downloadStateHasMessage
+ downloadStateFinished
+)
+
+type downloadResult struct {
+ ID string
+ State downloadState
+ Message proton.FullMessage
+ err error
+}
+
+func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, downloadCh <-chan downloadRequest, syncLimits syncLimits) (<-chan downloadedMessageBatch, <-chan error) {
+ buildCh := make(chan downloadedMessageBatch)
+ errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
+
+ // Goroutine in charge of downloading and building messages in maxBatchSize batches.
+ async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) {
+ defer close(buildCh)
+ defer close(errorCh)
+ defer func() {
+ logrus.Debugf("sync downloader exit")
+ }()
+
+ attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, syncLimits.MaxParallelDownloads)
+ defer attachmentDownloader.close()
+
+ for request := range downloadCh {
+ logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
+ if request.err != nil {
+ errorCh <- request.err
+ return
+ }
+
+ result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, syncLimits.MaxParallelDownloads)
+ if err != nil {
+ errorCh <- err
+ return
+ }
+
+ if ctx.Err() != nil {
+ errorCh <- ctx.Err()
+ return
+ }
+
+ batch, err := downloadMessagesStage2(ctx, result, downloader, SyncRetryCooldown)
+ if err != nil {
+ errorCh <- err
+ return
+ }
+
+ select {
+ case buildCh <- downloadedMessageBatch{
+ batch: batch,
+ }:
+
+ case <-ctx.Done():
+ return
+ }
+ }
+ }, logging.Labels{"sync-stage": "download"})
+
+ return buildCh, errorCh
+}
+
+type attachmentResult struct {
+ attachment []byte
+ err error
+}
+
+type attachmentJob struct {
+ id string
+ size int64
+ result chan attachmentResult
+}
+
+type attachmentDownloader struct {
+ workerCh chan attachmentJob
+ cancel context.CancelFunc
+}
+
+func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-chan attachmentJob) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case job, ok := <-work:
+ if !ok {
+ return
+ }
+ var b bytes.Buffer
+ b.Grow(int(job.size))
+ err := downloader.GetAttachmentInto(ctx, job.id, &b)
+ select {
+ case <-ctx.Done():
+ close(job.result)
+ return
+ case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
+ close(job.result)
+ }
+ }
+ }
+}
+
+func newAttachmentDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, workerCount int) *attachmentDownloader {
+ workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
+ ctx, cancel := context.WithCancel(ctx)
+ for i := 0; i < workerCount; i++ {
+ workerCh = make(chan attachmentJob)
+ async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, workerCh) }, logging.Labels{
+ "sync": fmt.Sprintf("att-downloader %v", i),
+ })
+ }
+
+ return &attachmentDownloader{
+ workerCh: workerCh,
+ cancel: cancel,
+ }
+}
+
+func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
+ resultChs := make([]chan attachmentResult, len(attachments))
+ for i, id := range attachments {
+ resultChs[i] = make(chan attachmentResult, 1)
+ select {
+ case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ }
+
+ result := make([][]byte, len(attachments))
+ var err error
+ for i := 0; i < len(attachments); i++ {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case r := <-resultChs[i]:
+ if r.err != nil {
+ err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
+ }
+ result[i] = r.attachment
+ }
+ }
+
+ return result, err
+}
+
+func (a *attachmentDownloader) close() {
+ a.cancel()
+}
+
+func downloadMessageStage1(
+ ctx context.Context,
+ panicHandler async.PanicHandler,
+ request downloadRequest,
+ downloader MessageDownloader,
+ attachmentDownloader *attachmentDownloader,
+ parallelDownloads int,
+) ([]downloadResult, error) {
+ // 1st attempt download everything in parallel
+ return parallel.MapContext(ctx, parallelDownloads, request.ids, func(ctx context.Context, id string) (downloadResult, error) {
+ defer async.HandlePanic(panicHandler)
+
+ result := downloadResult{ID: id}
+
+ msg, err := downloader.GetMessage(ctx, id)
+ if err != nil {
+ logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
+ result.err = err
+ return result, nil
+ }
+
+ result.Message.Message = msg
+ result.State = downloadStateHasMessage
+
+ attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
+ result.Message.AttData = attachments
+
+ if err != nil {
+ logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
+ result.err = err
+ return result, nil
+ }
+
+ result.State = downloadStateFinished
+
+ return result, nil
+ })
+}
+
+func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloader MessageDownloader, coolDown time.Duration) ([]proton.FullMessage, error) {
+ logrus.Debug("Entering download stage 2")
+ var retryList []int
+ var shouldWaitBeforeRetry bool
+
+ for {
+ if shouldWaitBeforeRetry {
+ time.Sleep(coolDown)
+ }
+
+ retryList = nil
+ shouldWaitBeforeRetry = false
+
+ for index, s := range state {
+ if s.State == downloadStateFinished {
+ continue
+ }
+
+ if s.err != nil {
+ if is429Error(s.err) {
+ logrus.WithField("msg-id", s.ID).Debug("Message download failed due to 429, retrying")
+ retryList = append(retryList, index)
+ continue
+ }
+ return nil, s.err
+ }
+ }
+
+ if len(retryList) == 0 {
+ break
+ }
+
+ for _, i := range retryList {
+ st := &state[i]
+ if st.State == downloadStateZero {
+ message, err := downloader.GetMessage(ctx, st.ID)
+ if err != nil {
+ logrus.WithField("msg-id", st.ID).WithError(err).Error("failed to download message (429)")
+ if is429Error(err) {
+ st.err = err
+ shouldWaitBeforeRetry = true
+ continue
+ }
+
+ return nil, err
+ }
+
+ st.Message.Message = message
+ st.State = downloadStateHasMessage
+ }
+
+ if st.Message.AttData == nil && st.Message.NumAttachments != 0 {
+ st.Message.AttData = make([][]byte, st.Message.NumAttachments)
+ }
+
+ hasAllAttachments := true
+ for i := 0; i < st.Message.NumAttachments; i++ {
+ if st.Message.AttData[i] == nil {
+ buffer := bytes.Buffer{}
+ if err := downloader.GetAttachmentInto(ctx, st.Message.Attachments[i].ID, &buffer); err != nil {
+ logrus.WithField("msg-id", st.ID).WithError(err).Errorf("failed to download attachment %v/%v (429)", i+1, len(st.Message.Attachments))
+ if is429Error(err) {
+ st.err = err
+ shouldWaitBeforeRetry = true
+ hasAllAttachments = false
+ continue
+ }
+
+ return nil, err
+ }
+
+ st.Message.AttData[i] = buffer.Bytes()
+ }
+ }
+
+ if hasAllAttachments {
+ st.State = downloadStateFinished
+ }
+ }
+ }
+
+ logrus.Debug("All message downloaded successfully")
+ return xslices.Map(state, func(s downloadResult) proton.FullMessage {
+ return s.Message
+ }), nil
+}
+
+func is429Error(err error) bool {
+ var apiErr *proton.APIError
+ if errors.As(err, &apiErr) {
+ return apiErr.Status == 429
+ }
+
+ return false
+}
diff --git a/internal/user/sync_downloader_test.go b/internal/user/sync_downloader_test.go
new file mode 100644
index 00000000..b8edfdd9
--- /dev/null
+++ b/internal/user/sync_downloader_test.go
@@ -0,0 +1,400 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package user
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/ProtonMail/proton-bridge/v3/internal/user/mocks"
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSyncDownloader_Stage1_429(t *testing.T) {
+ // Check 429 is correctly caught and download state recorded correctly
+ // Message 1: All ok
+ // Message 2: Message failed
+ // Message 3: One attachment failed.
+ mockCtrl := gomock.NewController(t)
+ messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
+ panicHandler := &async.NoopPanicHandler{}
+ ctx := context.Background()
+
+ requests := downloadRequest{
+ ids: []string{"Msg1", "Msg2", "Msg3"},
+ expectedSize: 0,
+ err: nil,
+ }
+
+ messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg1")).Times(1).Return(proton.Message{
+ MessageMetadata: proton.MessageMetadata{
+ ID: "MsgID1",
+ NumAttachments: 1,
+ },
+ Attachments: []proton.Attachment{
+ {
+ ID: "Attachment1_1",
+ },
+ },
+ }, nil)
+
+ messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg2")).Times(1).Return(proton.Message{}, &proton.APIError{Status: 429})
+ messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{
+ MessageMetadata: proton.MessageMetadata{
+ ID: "MsgID3",
+ NumAttachments: 2,
+ },
+ Attachments: []proton.Attachment{
+ {
+ ID: "Attachment3_1",
+ },
+ {
+ ID: "Attachment3_2",
+ },
+ },
+ }, nil)
+
+ const attachmentData = "attachment data"
+
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment1_1"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
+ _, err := r.ReadFrom(strings.NewReader(attachmentData))
+ return err
+ })
+
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_1"), gomock.Any()).Times(1).Return(&proton.APIError{Status: 429})
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_2"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
+ _, err := r.ReadFrom(strings.NewReader(attachmentData))
+ return err
+ })
+
+ attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, 1)
+ defer attachmentDownloader.close()
+
+ result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, 1)
+ require.NoError(t, err)
+ require.Equal(t, 3, len(result))
+ // Check message 1
+ require.Equal(t, result[0].State, downloadStateFinished)
+ require.Equal(t, result[0].Message.ID, "MsgID1")
+ require.NotEmpty(t, result[0].Message.AttData)
+ require.NotEqual(t, attachmentData, result[0].Message.AttData[0])
+ require.NotNil(t, result[0].Message.AttData[0])
+ require.Nil(t, result[0].err)
+
+ // Check message 2
+ require.Equal(t, result[1].State, downloadStateZero)
+ require.Empty(t, result[1].Message.ID)
+ require.NotNil(t, result[1].err)
+
+ require.Equal(t, result[2].State, downloadStateHasMessage)
+ require.Equal(t, result[2].Message.ID, "MsgID3")
+ require.Equal(t, 2, len(result[2].Message.AttData))
+ require.NotNil(t, result[2].err)
+ require.Nil(t, result[2].Message.AttData[0])
+ require.NotEqual(t, attachmentData, result[2].Message.AttData[1])
+ require.NotNil(t, result[2].err)
+}
+
+func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
+ ctx := context.Background()
+
+ downloadResult := []downloadResult{
+ {
+ ID: "Msg1",
+ State: downloadStateFinished,
+ },
+ {
+ ID: "Msg2",
+ State: downloadStateFinished,
+ },
+ }
+
+ result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
+ require.NoError(t, err)
+ require.Equal(t, 2, len(result))
+}
+
+func TestSyncDownloader_Stage2_Not429(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
+ ctx := context.Background()
+
+ msgErr := fmt.Errorf("something not 429")
+ downloadResult := []downloadResult{
+ {
+ ID: "Msg1",
+ State: downloadStateFinished,
+ },
+ {
+ ID: "Msg2",
+ State: downloadStateHasMessage,
+ err: msgErr,
+ },
+ {
+ ID: "Msg3",
+ State: downloadStateFinished,
+ },
+ }
+
+ _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
+ require.Error(t, err)
+ require.Equal(t, msgErr, err)
+}
+
+func TestSyncDownloader_Stage2_API500(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
+ ctx := context.Background()
+
+ msgErr := &proton.APIError{Status: 500}
+ downloadResult := []downloadResult{
+ {
+ ID: "Msg2",
+ State: downloadStateHasMessage,
+ err: msgErr,
+ },
+ {
+ ID: "Msg3",
+ State: downloadStateFinished,
+ },
+ }
+
+ _, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
+ require.Error(t, err)
+ require.Equal(t, msgErr, err)
+}
+
+func TestSyncDownloader_Stage2_Some429(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
+ ctx := context.Background()
+
+ const attachmentData1 = "attachment data 1"
+ const attachmentData2 = "attachment data 2"
+ const attachmentData3 = "attachment data 3"
+ const attachmentData4 = "attachment data 4"
+
+ err429 := &proton.APIError{Status: 429}
+ downloadResult := []downloadResult{
+ {
+ // Full message , but missing 1 of 2 attachments
+ ID: "Msg1",
+ Message: proton.FullMessage{
+ Message: proton.Message{
+ MessageMetadata: proton.MessageMetadata{
+ ID: "Msg1",
+ NumAttachments: 2,
+ },
+ Attachments: []proton.Attachment{
+ {
+ ID: "A3",
+ },
+ {
+ ID: "A4",
+ },
+ },
+ },
+ AttData: [][]byte{
+ nil,
+ []byte(attachmentData4),
+ },
+ },
+ State: downloadStateHasMessage,
+ err: err429,
+ },
+ {
+ // Full message, but missing all attachments
+ ID: "Msg2",
+ Message: proton.FullMessage{
+ Message: proton.Message{
+ MessageMetadata: proton.MessageMetadata{
+ ID: "Msg2",
+ NumAttachments: 2,
+ },
+ Attachments: []proton.Attachment{
+ {
+ ID: "A1",
+ },
+ {
+ ID: "A2",
+ },
+ },
+ },
+ AttData: nil,
+ },
+ State: downloadStateHasMessage,
+ err: err429,
+ },
+ {
+ // Missing everything
+ ID: "Msg3",
+ State: downloadStateZero,
+ Message: proton.FullMessage{
+ Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}},
+ },
+ err: err429,
+ },
+ }
+
+ {
+ // Simulate 2 failures for message 3 body.
+ firstCall := messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(2).Return(proton.Message{}, err429)
+ messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).After(firstCall).Times(1).Return(proton.Message{
+ MessageMetadata: proton.MessageMetadata{
+ ID: "Msg3",
+ },
+ }, nil)
+ }
+
+ {
+ // Simulate failures for message 2 attachments.
+ firstCall := messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).Times(2).Return(err429)
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).After(firstCall).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
+ _, err := r.ReadFrom(strings.NewReader(attachmentData1))
+ return err
+ })
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A2"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
+ _, err := r.ReadFrom(strings.NewReader(attachmentData2))
+ return err
+ })
+ }
+
+ {
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
+ _, err := r.ReadFrom(strings.NewReader(attachmentData3))
+ return err
+ })
+ }
+
+ messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
+ require.NoError(t, err)
+ require.Equal(t, 3, len(messages))
+
+ require.Equal(t, messages[0].Message.ID, "Msg1")
+ require.Equal(t, messages[1].Message.ID, "Msg2")
+ require.Equal(t, messages[2].Message.ID, "Msg3")
+
+ // check attachments
+ require.Equal(t, attachmentData3, string(messages[0].AttData[0]))
+ require.Equal(t, attachmentData4, string(messages[0].AttData[1]))
+ require.Equal(t, attachmentData1, string(messages[1].AttData[0]))
+ require.Equal(t, attachmentData2, string(messages[1].AttData[1]))
+ require.Empty(t, messages[2].AttData)
+}
+
+func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
+ ctx := context.Background()
+
+ err429 := &proton.APIError{Status: 429}
+ err500 := &proton.APIError{Status: 500}
+ downloadResult := []downloadResult{
+ {
+ // Missing everything
+ ID: "Msg3",
+ State: downloadStateZero,
+ Message: proton.FullMessage{
+ Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}},
+ },
+ err: err429,
+ },
+ {
+ // Full message , but missing 1 of 2 attachments
+ ID: "Msg1",
+ Message: proton.FullMessage{
+ Message: proton.Message{
+ MessageMetadata: proton.MessageMetadata{
+ ID: "Msg1",
+ NumAttachments: 2,
+ },
+ Attachments: []proton.Attachment{
+ {
+ ID: "A3",
+ },
+ {
+ ID: "A4",
+ },
+ },
+ },
+ },
+ State: downloadStateHasMessage,
+ err: err429,
+ },
+ }
+
+ {
+ // Simulate 2 failures for message 3 body,
+ messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500)
+ }
+
+ messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
+ require.Error(t, err)
+ require.Empty(t, 0, messages)
+}
+
+func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
+ ctx := context.Background()
+
+ err429 := &proton.APIError{Status: 429}
+ err500 := &proton.APIError{Status: 500}
+ downloadResult := []downloadResult{
+ {
+ // Full message , but missing 1 of 2 attachments
+ ID: "Msg1",
+ Message: proton.FullMessage{
+ Message: proton.Message{
+ MessageMetadata: proton.MessageMetadata{
+ ID: "Msg1",
+ NumAttachments: 2,
+ },
+ Attachments: []proton.Attachment{
+ {
+ ID: "A3",
+ },
+ {
+ ID: "A4",
+ },
+ },
+ },
+ },
+ State: downloadStateHasMessage,
+ err: err429,
+ },
+ }
+
+ // 429 for first attachment
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).Return(err429)
+ // 500 for second attachment
+ messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500)
+
+ messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
+ require.Error(t, err)
+ require.Empty(t, 0, messages)
+}