forked from Silverfish/proton-bridge
fix(GODT-2822): Handle 429 during message download
When we run into 429 during a message download, do not cancel the whole batch and switch to a sequential downloader to avoid API overload.
This commit is contained in:
1
Makefile
1
Makefile
@ -274,6 +274,7 @@ mocks:
|
|||||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
|
||||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go
|
||||||
cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go
|
cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go
|
||||||
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/user MessageDownloader > internal/user/mocks/mocks.go
|
||||||
|
|
||||||
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog
|
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog
|
||||||
|
|
||||||
|
|||||||
66
internal/user/mocks/mocks.go
Normal file
66
internal/user/mocks/mocks.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/ProtonMail/proton-bridge/v3/internal/user (interfaces: MessageDownloader)
|
||||||
|
|
||||||
|
// Package mocks is a generated GoMock package.
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
io "io"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
proton "github.com/ProtonMail/go-proton-api"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockMessageDownloader is a mock of MessageDownloader interface.
|
||||||
|
type MockMessageDownloader struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockMessageDownloaderMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockMessageDownloaderMockRecorder is the mock recorder for MockMessageDownloader.
|
||||||
|
type MockMessageDownloaderMockRecorder struct {
|
||||||
|
mock *MockMessageDownloader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockMessageDownloader creates a new mock instance.
|
||||||
|
func NewMockMessageDownloader(ctrl *gomock.Controller) *MockMessageDownloader {
|
||||||
|
mock := &MockMessageDownloader{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockMessageDownloaderMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockMessageDownloader) EXPECT() *MockMessageDownloaderMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAttachmentInto mocks base method.
|
||||||
|
func (m *MockMessageDownloader) GetAttachmentInto(arg0 context.Context, arg1 string, arg2 io.ReaderFrom) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetAttachmentInto", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAttachmentInto indicates an expected call of GetAttachmentInto.
|
||||||
|
func (mr *MockMessageDownloaderMockRecorder) GetAttachmentInto(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachmentInto", reflect.TypeOf((*MockMessageDownloader)(nil).GetAttachmentInto), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessage mocks base method.
|
||||||
|
func (m *MockMessageDownloader) GetMessage(arg0 context.Context, arg1 string) (proton.Message, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(proton.Message)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessage indicates an expected call of GetMessage.
|
||||||
|
func (mr *MockMessageDownloaderMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockMessageDownloader)(nil).GetMessage), arg0, arg1)
|
||||||
|
}
|
||||||
@ -378,32 +378,18 @@ func (user *User) syncMessages(
|
|||||||
batchLen int
|
batchLen int
|
||||||
}
|
}
|
||||||
|
|
||||||
type downloadRequest struct {
|
|
||||||
ids []string
|
|
||||||
expectedSize uint64
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type downloadedMessageBatch struct {
|
|
||||||
batch []proton.FullMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
type builtMessageBatch struct {
|
type builtMessageBatch struct {
|
||||||
batch []*buildRes
|
batch []*buildRes
|
||||||
}
|
}
|
||||||
|
|
||||||
downloadCh := make(chan downloadRequest)
|
downloadCh := make(chan downloadRequest)
|
||||||
|
|
||||||
buildCh := make(chan downloadedMessageBatch)
|
|
||||||
|
|
||||||
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
|
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
|
||||||
// to the update flushing goroutine.
|
// to the update flushing goroutine.
|
||||||
flushCh := make(chan builtMessageBatch)
|
flushCh := make(chan builtMessageBatch)
|
||||||
|
|
||||||
flushUpdateCh := make(chan flushUpdate)
|
flushUpdateCh := make(chan flushUpdate)
|
||||||
|
|
||||||
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
|
|
||||||
|
|
||||||
// Go routine in charge of downloading message metadata
|
// Go routine in charge of downloading message metadata
|
||||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||||
defer close(downloadCh)
|
defer close(downloadCh)
|
||||||
@ -469,65 +455,7 @@ func (user *User) syncMessages(
|
|||||||
}, logging.Labels{"sync-stage": "meta-data"})
|
}, logging.Labels{"sync-stage": "meta-data"})
|
||||||
|
|
||||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
buildCh, errorCh := startSyncDownloader(ctx, user.panicHandler, user.client, downloadCh, syncLimits)
|
||||||
defer close(buildCh)
|
|
||||||
defer close(errorCh)
|
|
||||||
defer func() {
|
|
||||||
logrus.Debugf("sync downloader exit")
|
|
||||||
}()
|
|
||||||
|
|
||||||
attachmentDownloader := user.newAttachmentDownloader(ctx, client, syncLimits.MaxParallelDownloads)
|
|
||||||
defer attachmentDownloader.close()
|
|
||||||
|
|
||||||
for request := range downloadCh {
|
|
||||||
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
|
|
||||||
if request.err != nil {
|
|
||||||
errorCh <- request.err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
errorCh <- ctx.Err()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := parallel.MapContext(ctx, syncLimits.MaxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
|
|
||||||
defer async.HandlePanic(user.panicHandler)
|
|
||||||
|
|
||||||
var result proton.FullMessage
|
|
||||||
|
|
||||||
msg, err := client.GetMessage(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
|
||||||
return proton.FullMessage{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
|
|
||||||
return proton.FullMessage{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
result.Message = msg
|
|
||||||
result.AttData = attachments
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
errorCh <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case buildCh <- downloadedMessageBatch{
|
|
||||||
batch: result,
|
|
||||||
}:
|
|
||||||
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, logging.Labels{"sync-stage": "download"})
|
|
||||||
|
|
||||||
// Goroutine which builds messages after they have been downloaded
|
// Goroutine which builds messages after they have been downloaded
|
||||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||||
@ -793,93 +721,6 @@ func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type attachmentResult struct {
|
|
||||||
attachment []byte
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type attachmentJob struct {
|
|
||||||
id string
|
|
||||||
size int64
|
|
||||||
result chan attachmentResult
|
|
||||||
}
|
|
||||||
|
|
||||||
type attachmentDownloader struct {
|
|
||||||
workerCh chan attachmentJob
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case job, ok := <-work:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var b bytes.Buffer
|
|
||||||
b.Grow(int(job.size))
|
|
||||||
err := client.GetAttachmentInto(ctx, job.id, &b)
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
close(job.result)
|
|
||||||
return
|
|
||||||
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
|
|
||||||
close(job.result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
|
|
||||||
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
for i := 0; i < workerCount; i++ {
|
|
||||||
workerCh = make(chan attachmentJob)
|
|
||||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
|
|
||||||
"sync": fmt.Sprintf("att-downloader %v", i),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return &attachmentDownloader{
|
|
||||||
workerCh: workerCh,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
|
|
||||||
resultChs := make([]chan attachmentResult, len(attachments))
|
|
||||||
for i, id := range attachments {
|
|
||||||
resultChs[i] = make(chan attachmentResult, 1)
|
|
||||||
select {
|
|
||||||
case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([][]byte, len(attachments))
|
|
||||||
var err error
|
|
||||||
for i := 0; i < len(attachments); i++ {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case r := <-resultChs[i]:
|
|
||||||
if r.err != nil {
|
|
||||||
err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
|
|
||||||
}
|
|
||||||
result[i] = r.attachment
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *attachmentDownloader) close() {
|
|
||||||
a.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
|
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
|
||||||
var expectedMemUsage uint64
|
var expectedMemUsage uint64
|
||||||
var chunks [][]proton.FullMessage
|
var chunks [][]proton.FullMessage
|
||||||
|
|||||||
339
internal/user/sync_downloader.go
Normal file
339
internal/user/sync_downloader.go
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
// Copyright (c) 2023 Proton AG
|
||||||
|
//
|
||||||
|
// This file is part of Proton Mail Bridge.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/gluon/async"
|
||||||
|
"github.com/ProtonMail/gluon/logging"
|
||||||
|
"github.com/ProtonMail/go-proton-api"
|
||||||
|
"github.com/bradenaw/juniper/parallel"
|
||||||
|
"github.com/bradenaw/juniper/xslices"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type downloadRequest struct {
|
||||||
|
ids []string
|
||||||
|
expectedSize uint64
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type downloadedMessageBatch struct {
|
||||||
|
batch []proton.FullMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageDownloader interface {
|
||||||
|
GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
|
||||||
|
GetMessage(ctx context.Context, messageID string) (proton.Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type downloadState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
downloadStateZero downloadState = iota
|
||||||
|
downloadStateHasMessage
|
||||||
|
downloadStateFinished
|
||||||
|
)
|
||||||
|
|
||||||
|
type downloadResult struct {
|
||||||
|
ID string
|
||||||
|
State downloadState
|
||||||
|
Message proton.FullMessage
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSyncDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, downloadCh <-chan downloadRequest, syncLimits syncLimits) (<-chan downloadedMessageBatch, <-chan error) {
|
||||||
|
buildCh := make(chan downloadedMessageBatch)
|
||||||
|
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
|
||||||
|
|
||||||
|
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||||
|
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) {
|
||||||
|
defer close(buildCh)
|
||||||
|
defer close(errorCh)
|
||||||
|
defer func() {
|
||||||
|
logrus.Debugf("sync downloader exit")
|
||||||
|
}()
|
||||||
|
|
||||||
|
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, syncLimits.MaxParallelDownloads)
|
||||||
|
defer attachmentDownloader.close()
|
||||||
|
|
||||||
|
for request := range downloadCh {
|
||||||
|
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
|
||||||
|
if request.err != nil {
|
||||||
|
errorCh <- request.err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := downloadMessageStage1(ctx, panicHandler, request, downloader, attachmentDownloader, syncLimits.MaxParallelDownloads)
|
||||||
|
if err != nil {
|
||||||
|
errorCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
errorCh <- ctx.Err()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
batch, err := downloadMessagesStage2(ctx, result, downloader, SyncRetryCooldown)
|
||||||
|
if err != nil {
|
||||||
|
errorCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case buildCh <- downloadedMessageBatch{
|
||||||
|
batch: batch,
|
||||||
|
}:
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, logging.Labels{"sync-stage": "download"})
|
||||||
|
|
||||||
|
return buildCh, errorCh
|
||||||
|
}
|
||||||
|
|
||||||
|
type attachmentResult struct {
|
||||||
|
attachment []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type attachmentJob struct {
|
||||||
|
id string
|
||||||
|
size int64
|
||||||
|
result chan attachmentResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type attachmentDownloader struct {
|
||||||
|
workerCh chan attachmentJob
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func attachmentWorker(ctx context.Context, downloader MessageDownloader, work <-chan attachmentJob) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case job, ok := <-work:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var b bytes.Buffer
|
||||||
|
b.Grow(int(job.size))
|
||||||
|
err := downloader.GetAttachmentInto(ctx, job.id, &b)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(job.result)
|
||||||
|
return
|
||||||
|
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
|
||||||
|
close(job.result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAttachmentDownloader(ctx context.Context, panicHandler async.PanicHandler, downloader MessageDownloader, workerCount int) *attachmentDownloader {
|
||||||
|
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
for i := 0; i < workerCount; i++ {
|
||||||
|
workerCh = make(chan attachmentJob)
|
||||||
|
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, workerCh) }, logging.Labels{
|
||||||
|
"sync": fmt.Sprintf("att-downloader %v", i),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &attachmentDownloader{
|
||||||
|
workerCh: workerCh,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
|
||||||
|
resultChs := make([]chan attachmentResult, len(attachments))
|
||||||
|
for i, id := range attachments {
|
||||||
|
resultChs[i] = make(chan attachmentResult, 1)
|
||||||
|
select {
|
||||||
|
case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([][]byte, len(attachments))
|
||||||
|
var err error
|
||||||
|
for i := 0; i < len(attachments); i++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case r := <-resultChs[i]:
|
||||||
|
if r.err != nil {
|
||||||
|
err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
|
||||||
|
}
|
||||||
|
result[i] = r.attachment
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *attachmentDownloader) close() {
|
||||||
|
a.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func downloadMessageStage1(
|
||||||
|
ctx context.Context,
|
||||||
|
panicHandler async.PanicHandler,
|
||||||
|
request downloadRequest,
|
||||||
|
downloader MessageDownloader,
|
||||||
|
attachmentDownloader *attachmentDownloader,
|
||||||
|
parallelDownloads int,
|
||||||
|
) ([]downloadResult, error) {
|
||||||
|
// 1st attempt download everything in parallel
|
||||||
|
return parallel.MapContext(ctx, parallelDownloads, request.ids, func(ctx context.Context, id string) (downloadResult, error) {
|
||||||
|
defer async.HandlePanic(panicHandler)
|
||||||
|
|
||||||
|
result := downloadResult{ID: id}
|
||||||
|
|
||||||
|
msg, err := downloader.GetMessage(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
||||||
|
result.err = err
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Message.Message = msg
|
||||||
|
result.State = downloadStateHasMessage
|
||||||
|
|
||||||
|
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
|
||||||
|
result.Message.AttData = attachments
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
|
||||||
|
result.err = err
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result.State = downloadStateFinished
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func downloadMessagesStage2(ctx context.Context, state []downloadResult, downloader MessageDownloader, coolDown time.Duration) ([]proton.FullMessage, error) {
|
||||||
|
logrus.Debug("Entering download stage 2")
|
||||||
|
var retryList []int
|
||||||
|
var shouldWaitBeforeRetry bool
|
||||||
|
|
||||||
|
for {
|
||||||
|
if shouldWaitBeforeRetry {
|
||||||
|
time.Sleep(coolDown)
|
||||||
|
}
|
||||||
|
|
||||||
|
retryList = nil
|
||||||
|
shouldWaitBeforeRetry = false
|
||||||
|
|
||||||
|
for index, s := range state {
|
||||||
|
if s.State == downloadStateFinished {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.err != nil {
|
||||||
|
if is429Error(s.err) {
|
||||||
|
logrus.WithField("msg-id", s.ID).Debug("Message download failed due to 429, retrying")
|
||||||
|
retryList = append(retryList, index)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(retryList) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, i := range retryList {
|
||||||
|
st := &state[i]
|
||||||
|
if st.State == downloadStateZero {
|
||||||
|
message, err := downloader.GetMessage(ctx, st.ID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithField("msg-id", st.ID).WithError(err).Error("failed to download message (429)")
|
||||||
|
if is429Error(err) {
|
||||||
|
st.err = err
|
||||||
|
shouldWaitBeforeRetry = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
st.Message.Message = message
|
||||||
|
st.State = downloadStateHasMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
if st.Message.AttData == nil && st.Message.NumAttachments != 0 {
|
||||||
|
st.Message.AttData = make([][]byte, st.Message.NumAttachments)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasAllAttachments := true
|
||||||
|
for i := 0; i < st.Message.NumAttachments; i++ {
|
||||||
|
if st.Message.AttData[i] == nil {
|
||||||
|
buffer := bytes.Buffer{}
|
||||||
|
if err := downloader.GetAttachmentInto(ctx, st.Message.Attachments[i].ID, &buffer); err != nil {
|
||||||
|
logrus.WithField("msg-id", st.ID).WithError(err).Errorf("failed to download attachment %v/%v (429)", i+1, len(st.Message.Attachments))
|
||||||
|
if is429Error(err) {
|
||||||
|
st.err = err
|
||||||
|
shouldWaitBeforeRetry = true
|
||||||
|
hasAllAttachments = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
st.Message.AttData[i] = buffer.Bytes()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasAllAttachments {
|
||||||
|
st.State = downloadStateFinished
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Debug("All message downloaded successfully")
|
||||||
|
return xslices.Map(state, func(s downloadResult) proton.FullMessage {
|
||||||
|
return s.Message
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func is429Error(err error) bool {
|
||||||
|
var apiErr *proton.APIError
|
||||||
|
if errors.As(err, &apiErr) {
|
||||||
|
return apiErr.Status == 429
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
400
internal/user/sync_downloader_test.go
Normal file
400
internal/user/sync_downloader_test.go
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
// Copyright (c) 2023 Proton AG
|
||||||
|
//
|
||||||
|
// This file is part of Proton Mail Bridge.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/gluon/async"
|
||||||
|
"github.com/ProtonMail/go-proton-api"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/user/mocks"
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage1_429(t *testing.T) {
|
||||||
|
// Check 429 is correctly caught and download state recorded correctly
|
||||||
|
// Message 1: All ok
|
||||||
|
// Message 2: Message failed
|
||||||
|
// Message 3: One attachment failed.
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
panicHandler := &async.NoopPanicHandler{}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
requests := downloadRequest{
|
||||||
|
ids: []string{"Msg1", "Msg2", "Msg3"},
|
||||||
|
expectedSize: 0,
|
||||||
|
err: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg1")).Times(1).Return(proton.Message{
|
||||||
|
MessageMetadata: proton.MessageMetadata{
|
||||||
|
ID: "MsgID1",
|
||||||
|
NumAttachments: 1,
|
||||||
|
},
|
||||||
|
Attachments: []proton.Attachment{
|
||||||
|
{
|
||||||
|
ID: "Attachment1_1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg2")).Times(1).Return(proton.Message{}, &proton.APIError{Status: 429})
|
||||||
|
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{
|
||||||
|
MessageMetadata: proton.MessageMetadata{
|
||||||
|
ID: "MsgID3",
|
||||||
|
NumAttachments: 2,
|
||||||
|
},
|
||||||
|
Attachments: []proton.Attachment{
|
||||||
|
{
|
||||||
|
ID: "Attachment3_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "Attachment3_2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
const attachmentData = "attachment data"
|
||||||
|
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment1_1"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
|
||||||
|
_, err := r.ReadFrom(strings.NewReader(attachmentData))
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_1"), gomock.Any()).Times(1).Return(&proton.APIError{Status: 429})
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_2"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
|
||||||
|
_, err := r.ReadFrom(strings.NewReader(attachmentData))
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, 1)
|
||||||
|
defer attachmentDownloader.close()
|
||||||
|
|
||||||
|
result, err := downloadMessageStage1(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 3, len(result))
|
||||||
|
// Check message 1
|
||||||
|
require.Equal(t, result[0].State, downloadStateFinished)
|
||||||
|
require.Equal(t, result[0].Message.ID, "MsgID1")
|
||||||
|
require.NotEmpty(t, result[0].Message.AttData)
|
||||||
|
require.NotEqual(t, attachmentData, result[0].Message.AttData[0])
|
||||||
|
require.NotNil(t, result[0].Message.AttData[0])
|
||||||
|
require.Nil(t, result[0].err)
|
||||||
|
|
||||||
|
// Check message 2
|
||||||
|
require.Equal(t, result[1].State, downloadStateZero)
|
||||||
|
require.Empty(t, result[1].Message.ID)
|
||||||
|
require.NotNil(t, result[1].err)
|
||||||
|
|
||||||
|
require.Equal(t, result[2].State, downloadStateHasMessage)
|
||||||
|
require.Equal(t, result[2].Message.ID, "MsgID3")
|
||||||
|
require.Equal(t, 2, len(result[2].Message.AttData))
|
||||||
|
require.NotNil(t, result[2].err)
|
||||||
|
require.Nil(t, result[2].Message.AttData[0])
|
||||||
|
require.NotEqual(t, attachmentData, result[2].Message.AttData[1])
|
||||||
|
require.NotNil(t, result[2].err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
downloadResult := []downloadResult{
|
||||||
|
{
|
||||||
|
ID: "Msg1",
|
||||||
|
State: downloadStateFinished,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "Msg2",
|
||||||
|
State: downloadStateFinished,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 2, len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage2_Not429(t *testing.T) {
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msgErr := fmt.Errorf("something not 429")
|
||||||
|
downloadResult := []downloadResult{
|
||||||
|
{
|
||||||
|
ID: "Msg1",
|
||||||
|
State: downloadStateFinished,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "Msg2",
|
||||||
|
State: downloadStateHasMessage,
|
||||||
|
err: msgErr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "Msg3",
|
||||||
|
State: downloadStateFinished,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, msgErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage2_API500(t *testing.T) {
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msgErr := &proton.APIError{Status: 500}
|
||||||
|
downloadResult := []downloadResult{
|
||||||
|
{
|
||||||
|
ID: "Msg2",
|
||||||
|
State: downloadStateHasMessage,
|
||||||
|
err: msgErr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "Msg3",
|
||||||
|
State: downloadStateFinished,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, msgErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage2_Some429(t *testing.T) {
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
const attachmentData1 = "attachment data 1"
|
||||||
|
const attachmentData2 = "attachment data 2"
|
||||||
|
const attachmentData3 = "attachment data 3"
|
||||||
|
const attachmentData4 = "attachment data 4"
|
||||||
|
|
||||||
|
err429 := &proton.APIError{Status: 429}
|
||||||
|
downloadResult := []downloadResult{
|
||||||
|
{
|
||||||
|
// Full message , but missing 1 of 2 attachments
|
||||||
|
ID: "Msg1",
|
||||||
|
Message: proton.FullMessage{
|
||||||
|
Message: proton.Message{
|
||||||
|
MessageMetadata: proton.MessageMetadata{
|
||||||
|
ID: "Msg1",
|
||||||
|
NumAttachments: 2,
|
||||||
|
},
|
||||||
|
Attachments: []proton.Attachment{
|
||||||
|
{
|
||||||
|
ID: "A3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "A4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AttData: [][]byte{
|
||||||
|
nil,
|
||||||
|
[]byte(attachmentData4),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
State: downloadStateHasMessage,
|
||||||
|
err: err429,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Full message, but missing all attachments
|
||||||
|
ID: "Msg2",
|
||||||
|
Message: proton.FullMessage{
|
||||||
|
Message: proton.Message{
|
||||||
|
MessageMetadata: proton.MessageMetadata{
|
||||||
|
ID: "Msg2",
|
||||||
|
NumAttachments: 2,
|
||||||
|
},
|
||||||
|
Attachments: []proton.Attachment{
|
||||||
|
{
|
||||||
|
ID: "A1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "A2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AttData: nil,
|
||||||
|
},
|
||||||
|
State: downloadStateHasMessage,
|
||||||
|
err: err429,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Missing everything
|
||||||
|
ID: "Msg3",
|
||||||
|
State: downloadStateZero,
|
||||||
|
Message: proton.FullMessage{
|
||||||
|
Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}},
|
||||||
|
},
|
||||||
|
err: err429,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Simulate 2 failures for message 3 body.
|
||||||
|
firstCall := messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(2).Return(proton.Message{}, err429)
|
||||||
|
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).After(firstCall).Times(1).Return(proton.Message{
|
||||||
|
MessageMetadata: proton.MessageMetadata{
|
||||||
|
ID: "Msg3",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Simulate failures for message 2 attachments.
|
||||||
|
firstCall := messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).Times(2).Return(err429)
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).After(firstCall).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
|
||||||
|
_, err := r.ReadFrom(strings.NewReader(attachmentData1))
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A2"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
|
||||||
|
_, err := r.ReadFrom(strings.NewReader(attachmentData2))
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).DoAndReturn(func(_ context.Context, _ string, r io.ReaderFrom) error {
|
||||||
|
_, err := r.ReadFrom(strings.NewReader(attachmentData3))
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 3, len(messages))
|
||||||
|
|
||||||
|
require.Equal(t, messages[0].Message.ID, "Msg1")
|
||||||
|
require.Equal(t, messages[1].Message.ID, "Msg2")
|
||||||
|
require.Equal(t, messages[2].Message.ID, "Msg3")
|
||||||
|
|
||||||
|
// check attachments
|
||||||
|
require.Equal(t, attachmentData3, string(messages[0].AttData[0]))
|
||||||
|
require.Equal(t, attachmentData4, string(messages[0].AttData[1]))
|
||||||
|
require.Equal(t, attachmentData1, string(messages[1].AttData[0]))
|
||||||
|
require.Equal(t, attachmentData2, string(messages[1].AttData[1]))
|
||||||
|
require.Empty(t, messages[2].AttData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err429 := &proton.APIError{Status: 429}
|
||||||
|
err500 := &proton.APIError{Status: 500}
|
||||||
|
downloadResult := []downloadResult{
|
||||||
|
{
|
||||||
|
// Missing everything
|
||||||
|
ID: "Msg3",
|
||||||
|
State: downloadStateZero,
|
||||||
|
Message: proton.FullMessage{
|
||||||
|
Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}},
|
||||||
|
},
|
||||||
|
err: err429,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Full message , but missing 1 of 2 attachments
|
||||||
|
ID: "Msg1",
|
||||||
|
Message: proton.FullMessage{
|
||||||
|
Message: proton.Message{
|
||||||
|
MessageMetadata: proton.MessageMetadata{
|
||||||
|
ID: "Msg1",
|
||||||
|
NumAttachments: 2,
|
||||||
|
},
|
||||||
|
Attachments: []proton.Attachment{
|
||||||
|
{
|
||||||
|
ID: "A3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "A4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
State: downloadStateHasMessage,
|
||||||
|
err: err429,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Simulate 2 failures for message 3 body,
|
||||||
|
messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, 0, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err429 := &proton.APIError{Status: 429}
|
||||||
|
err500 := &proton.APIError{Status: 500}
|
||||||
|
downloadResult := []downloadResult{
|
||||||
|
{
|
||||||
|
// Full message , but missing 1 of 2 attachments
|
||||||
|
ID: "Msg1",
|
||||||
|
Message: proton.FullMessage{
|
||||||
|
Message: proton.Message{
|
||||||
|
MessageMetadata: proton.MessageMetadata{
|
||||||
|
ID: "Msg1",
|
||||||
|
NumAttachments: 2,
|
||||||
|
},
|
||||||
|
Attachments: []proton.Attachment{
|
||||||
|
{
|
||||||
|
ID: "A3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "A4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
State: downloadStateHasMessage,
|
||||||
|
err: err429,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 429 for first attachment
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).Return(err429)
|
||||||
|
// 500 for second attachment
|
||||||
|
messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500)
|
||||||
|
|
||||||
|
messages, err := downloadMessagesStage2(ctx, downloadResult, messageDownloader, time.Millisecond)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, 0, messages)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user