feat(GODT-2829): New Sync Service

Implementation of the new sync service, which interleaves syncing jobs for
all active users.

It also includes improvements to the message downloader: the downloader
now automatically rate-limits its parallel workers based on the server
responses.

Additionally, each of the stages is now tested in isolation to ensure its
behavior matches expectations.

Finally, this patch does not replace the existing IMAP sync. A follow-up
patch is necessary to integrate the IMAP bits into the interfaces
required by these changes.
Leander Beernaert
2023-08-17 09:50:03 +02:00
parent a731237701
commit 78f7cbdc79
23 changed files with 4770 additions and 0 deletions
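
For orientation, a minimal sketch (not part of the diff) of how a caller might drive the new per-user sync handler, using only names introduced below; the regulator, client, state, reporter, updateApplier and messageBuilder values are assumed to satisfy the interfaces in this patch:

handler := syncservice.NewHandler(regulator, client, userID, state,
    logrus.WithField("user", userID), panicHandler)
defer handler.Close()

// Execute runs asynchronously; the result arrives on a channel.
handler.Execute(reporter, labels, updateApplier, messageBuilder,
    syncservice.DefaultRetryCoolDown)

if err := <-handler.OnSyncFinishedCH(); err != nil {
    logrus.WithError(err).Error("Sync failed")
}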


@@ -299,6 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres
> internal/events/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \
> internal/services/useridentity/mocks/mocks.go
mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \
ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\
StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \
> tmp
mv tmp internal/services/sync/mocks_test.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report

go.sum (+1)

@@ -266,6 +266,7 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=

internal/network/proton.go (new file, +123)

@@ -0,0 +1,123 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package network
import (
"context"
"errors"
"math/rand"
"time"
"github.com/ProtonMail/go-proton-api"
)
type CoolDownProvider interface {
GetNextWaitTime() time.Duration
Reset()
}
func jitter(max int) time.Duration {
return time.Duration(rand.Intn(max)) * time.Second //nolint:gosec
}
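// ExpCoolDown implements CoolDownProvider with a roughly doubling backoff
// (20s up to a 600s cap), adding up to 10s of random jitter to each wait.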
type ExpCoolDown struct {
count int
}
func (c *ExpCoolDown) GetNextWaitTime() time.Duration {
waitTimes := []time.Duration{
20 * time.Second,
40 * time.Second,
80 * time.Second,
160 * time.Second,
300 * time.Second,
600 * time.Second,
}
last := len(waitTimes) - 1
if c.count >= last {
return waitTimes[last] + jitter(10)
}
c.count++
return waitTimes[c.count-1] + jitter(10)
}
func (c *ExpCoolDown) Reset() {
c.count = 0
}
type NoCoolDown struct{}
func (c *NoCoolDown) GetNextWaitTime() time.Duration { return time.Millisecond }
func (c *NoCoolDown) Reset() {}
func Is429Or5XXError(err error) bool {
var apiErr *proton.APIError
if errors.As(err, &apiErr) {
return apiErr.Status == 429 || apiErr.Status >= 500
}
return false
}
type ProtonClientRetryWrapper[T any] struct {
client T
coolDown CoolDownProvider
encountered429or5xx bool
}
func NewClientRetryWrapper[T any](client T, coolDown CoolDownProvider) *ProtonClientRetryWrapper[T] {
return &ProtonClientRetryWrapper[T]{client: client, coolDown: coolDown}
}
func (p *ProtonClientRetryWrapper[T]) DidEncounter429or5xx() bool {
return p.encountered429or5xx
}
func (p *ProtonClientRetryWrapper[T]) Retry(ctx context.Context, f func(context.Context, T) error) error {
p.coolDown.Reset()
p.encountered429or5xx = false
for {
err := f(ctx, p.client)
if Is429Or5XXError(err) {
p.encountered429or5xx = true
coolDown := p.coolDown.GetNextWaitTime()
select {
case <-ctx.Done():
case <-time.After(coolDown):
}
continue
}
return err
}
}
func RetryWithClient[T any, R any](ctx context.Context, p *ProtonClientRetryWrapper[T], f func(context.Context, T) (R, error)) (R, error) {
var result R
err := p.Retry(ctx, func(ctx context.Context, t T) error {
r, err := f(ctx, t)
result = r
return err
})
return result, err
}
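
For illustration, a hedged sketch of using the wrapper; the ctx, client and messageID values are assumed, and the example relies on go-proton-api's proton.Client exposing GetMessage as the sync APIClient below does:

wrapper := network.NewClientRetryWrapper(client, &network.ExpCoolDown{})
msg, err := network.RetryWithClient(ctx, wrapper,
    func(ctx context.Context, c *proton.Client) (proton.Message, error) {
        // Retried with exponential cool-down whenever the call returns 429/5xx.
        return c.GetMessage(ctx, messageID)
    })
if err == nil && wrapper.DidEncounter429or5xx() {
    // The call eventually succeeded after throttling; callers can use this
    // signal to scale down their parallelism.
}
_ = msg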


@@ -0,0 +1,36 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"io"
"github.com/ProtonMail/go-proton-api"
)
type APIClient interface {
GetGroupedMessageCount(ctx context.Context) ([]proton.MessageGroupCount, error)
GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error)
GetMessage(ctx context.Context, messageID string) (proton.Message, error)
GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error)
GetMessageIDs(ctx context.Context, afterID string) ([]string, error)
GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error)
GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
GetAttachment(ctx context.Context, attachmentID string) ([]byte, error)
}
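
A compile-time assertion of this shape (not part of the patch) would pin the interface to the concrete client, assuming these signatures were lifted verbatim from go-proton-api's proton.Client:

var _ APIClient = (*proton.Client)(nil)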


@@ -0,0 +1,115 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"sync"
"github.com/ProtonMail/go-proton-api"
)
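// DownloadCache is a concurrency-safe, in-memory store for downloaded
// messages and attachments, keyed by their IDs.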
type DownloadCache struct {
messageLock sync.RWMutex
messages map[string]proton.Message
attachmentLock sync.RWMutex
attachments map[string][]byte
}
func newDownloadCache() *DownloadCache {
return &DownloadCache{
messages: make(map[string]proton.Message, 64),
attachments: make(map[string][]byte, 64),
}
}
func (s *DownloadCache) StoreMessage(message proton.Message) {
s.messageLock.Lock()
defer s.messageLock.Unlock()
s.messages[message.ID] = message
}
func (s *DownloadCache) StoreAttachment(id string, data []byte) {
s.attachmentLock.Lock()
defer s.attachmentLock.Unlock()
s.attachments[id] = data
}
func (s *DownloadCache) DeleteMessages(id ...string) {
s.messageLock.Lock()
defer s.messageLock.Unlock()
for _, id := range id {
delete(s.messages, id)
}
}
func (s *DownloadCache) DeleteAttachments(id ...string) {
s.attachmentLock.Lock()
defer s.attachmentLock.Unlock()
for _, id := range id {
delete(s.attachments, id)
}
}
func (s *DownloadCache) GetMessage(id string) (proton.Message, bool) {
s.messageLock.RLock()
defer s.messageLock.RUnlock()
v, ok := s.messages[id]
return v, ok
}
func (s *DownloadCache) GetAttachment(id string) ([]byte, bool) {
s.attachmentLock.RLock()
defer s.attachmentLock.RUnlock()
v, ok := s.attachments[id]
return v, ok
}
func (s *DownloadCache) Clear() {
s.messageLock.Lock()
s.messages = make(map[string]proton.Message, 64)
s.messageLock.Unlock()
s.attachmentLock.Lock()
s.attachments = make(map[string][]byte, 64)
s.attachmentLock.Unlock()
}
func (s *DownloadCache) Count() (int, int) {
var (
messageCount int
attachmentCount int
)
s.messageLock.Lock()
messageCount = len(s.messages)
s.messageLock.Unlock()
s.attachmentLock.Lock()
attachmentCount = len(s.attachments)
s.attachmentLock.Unlock()
return messageCount, attachmentCount
}
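
A small sketch of the intended lifecycle, inferred from how the job code below evicts entries once a batch is applied (the msg and data values are assumed):

cache := newDownloadCache()
cache.StoreMessage(msg)                 // filled while downloading
cache.StoreAttachment("attach-1", data)
if m, ok := cache.GetMessage(msg.ID); ok {
    _ = m // later stages reuse the download instead of re-fetching
}
cache.DeleteMessages(msg.ID)            // evicted once the batch is applied
cache.DeleteAttachments("attach-1")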


@@ -0,0 +1,218 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"fmt"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/sirupsen/logrus"
)
const DefaultRetryCoolDown = 20 * time.Second
type LabelMap = map[string]proton.Label
// Handler controls the syncing of a user's IMAP data. One instance should be created for each
// user and reused for every subsequent sync request.
type Handler struct {
regulator Regulator
client APIClient
userID string
syncState StateProvider
log *logrus.Entry
group *async.Group
syncFinishedCh chan error
panicHandler async.PanicHandler
downloadCache *DownloadCache
}
func NewHandler(
regulator Regulator,
client APIClient,
userID string,
state StateProvider,
log *logrus.Entry,
panicHandler async.PanicHandler,
) *Handler {
return &Handler{
client: client,
userID: userID,
syncState: state,
log: log,
syncFinishedCh: make(chan error),
group: async.NewGroup(context.Background(), panicHandler),
regulator: regulator,
panicHandler: panicHandler,
downloadCache: newDownloadCache(),
}
}
func (t *Handler) Close() {
t.group.CancelAndWait()
close(t.syncFinishedCh)
}
func (t *Handler) CancelAndWait() {
t.group.CancelAndWait()
}
func (t *Handler) Cancel() {
t.group.Cancel()
}
func (t *Handler) OnSyncFinishedCH() <-chan error {
return t.syncFinishedCh
}
func (t *Handler) Execute(
syncReporter Reporter,
labels LabelMap,
updateApplier UpdateApplier,
messageBuilder MessageBuilder,
coolDown time.Duration,
) {
t.log.Info("Sync triggered")
t.group.Once(func(ctx context.Context) {
start := time.Now()
t.log.WithField("start", start).Info("Beginning user sync")
syncReporter.OnStart(ctx)
var err error
for {
if err = ctx.Err(); err != nil {
t.log.WithError(err).Error("Sync aborted")
break
} else if err = t.run(ctx, syncReporter, labels, updateApplier, messageBuilder); err != nil {
t.log.WithError(err).Error("Failed to sync, will retry later")
sleepCtx(ctx, coolDown)
} else {
break
}
}
if err != nil {
syncReporter.OnError(ctx, err)
} else {
syncReporter.OnFinished(ctx)
}
t.log.WithField("duration", time.Since(start)).Info("Finished user sync")
t.syncFinishedCh <- err
})
}
func (t *Handler) run(ctx context.Context,
syncReporter Reporter,
labels LabelMap,
updateApplier UpdateApplier,
messageBuilder MessageBuilder,
) error {
syncStatus, err := t.syncState.GetSyncStatus(ctx)
if err != nil {
return fmt.Errorf("failed to get sync status: %w", err)
}
if syncStatus.IsComplete() {
t.log.Info("Sync already complete, only system labels will be updated")
if err := updateApplier.SyncSystemLabelsOnly(ctx, labels); err != nil {
t.log.WithError(err).Error("Failed to sync system labels")
return err
}
return nil
}
if !syncStatus.HasLabels {
t.log.Info("Syncing labels")
if err := updateApplier.SyncLabels(ctx, labels); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := t.syncState.SetHasLabels(ctx, true); err != nil {
return fmt.Errorf("failed to set has labels: %w", err)
}
t.log.Info("Synced labels")
}
if !syncStatus.HasMessageCount {
wrapper := network.NewClientRetryWrapper(t.client, &network.ExpCoolDown{})
messageCounts, err := network.RetryWithClient(ctx, wrapper, func(ctx context.Context, c APIClient) ([]proton.MessageGroupCount, error) {
return c.GetGroupedMessageCount(ctx)
})
if err != nil {
return fmt.Errorf("failed to retrieve message ids: %w", err)
}
var totalMessageCount int64
for _, gc := range messageCounts {
if gc.LabelID == proton.AllMailLabel {
totalMessageCount = int64(gc.Total)
break
}
}
if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil {
return fmt.Errorf("failed to store message count: %w", err)
}
}
if !syncStatus.HasMessages {
t.log.Info("Syncing messages")
stageContext := NewJob(
ctx,
t.client,
t.userID,
labels,
messageBuilder,
updateApplier,
syncReporter,
t.syncState,
t.panicHandler,
t.downloadCache,
t.log,
)
t.regulator.Sync(ctx, stageContext)
// Wait on reply
if err := stageContext.wait(ctx); err != nil {
return fmt.Errorf("failed sync messages: %w", err)
}
if err := t.syncState.SetHasMessages(ctx, true); err != nil {
return fmt.Errorf("failed to set sync as completed: %w", err)
}
t.log.Info("Synced messages")
} else {
t.log.Info("Messages are already synced, skipping")
}
return nil
}


@@ -0,0 +1,352 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"fmt"
"testing"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/xmaps"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
func TestTask_NoStateAndSucceeds(t *testing.T) {
const MessageTotal int64 = 50
const MessageID string = "foo"
const MessageDelta int64 = 10
labels := getTestLabels()
mockCtrl := gomock.NewController(t)
tt := newTestHandler(mockCtrl, "u")
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
{
call1 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{
HasLabels: false,
HasMessages: false,
HasMessageCount: false,
FailedMessages: xmaps.SetFromSlice([]string{}),
LastSyncedMessageID: "",
NumSyncedMessages: 0,
TotalMessageCount: 0,
}, nil
})
call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{
HasLabels: true,
HasMessages: true,
HasMessageCount: true,
FailedMessages: xmaps.SetFromSlice([]string{}),
LastSyncedMessageID: MessageID,
NumSyncedMessages: MessageDelta,
TotalMessageCount: MessageTotal,
}, nil
})
}
{
call1 := tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(1).Return(nil)
tt.updateApplier.EXPECT().SyncSystemLabelsOnly(gomock.Any(), gomock.Eq(labels)).After(call1).Times(1).Return(nil)
}
{
tt.client.EXPECT().GetGroupedMessageCount(gomock.Any()).Return([]proton.MessageGroupCount{
{
LabelID: proton.AllMailLabel,
Total: int(MessageTotal),
Unread: 0,
},
}, nil)
}
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
// First run.
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
require.NoError(t, err)
// Second run: the sync is already complete, so only system labels are synced.
err = tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
require.NoError(t, err)
}
func TestTask_StateHasLabels(t *testing.T) {
const MessageTotal int64 = 50
const MessageID string = "foo"
const MessageDelta int64 = 10
labels := getTestLabels()
mockCtrl := gomock.NewController(t)
tt := newTestHandler(mockCtrl, "u")
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
{
call2 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{
HasLabels: true,
HasMessages: false,
HasMessageCount: false,
FailedMessages: xmaps.SetFromSlice([]string{}),
LastSyncedMessageID: "",
NumSyncedMessages: 0,
TotalMessageCount: 0,
}, nil
})
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
}
{
tt.client.EXPECT().GetGroupedMessageCount(gomock.Any()).Return([]proton.MessageGroupCount{
{
LabelID: proton.AllMailLabel,
Total: int(MessageTotal),
Unread: 0,
},
}, nil)
}
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
require.NoError(t, err)
}
func TestTask_StateHasLabelsAndMessageCount(t *testing.T) {
const MessageTotal int64 = 50
const MessageID string = "foo"
const MessageDelta int64 = 10
labels := getTestLabels()
mockCtrl := gomock.NewController(t)
tt := newTestHandler(mockCtrl, "u")
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
{
call3 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{
HasLabels: true,
HasMessages: false,
HasMessageCount: true,
FailedMessages: xmaps.SetFromSlice([]string{}),
LastSyncedMessageID: "",
NumSyncedMessages: 0,
TotalMessageCount: MessageTotal,
}, nil
})
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
}
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
require.NoError(t, err)
}
func TestTask_StateHasSyncedState(t *testing.T) {
const MessageTotal int64 = 50
const MessageID string = "foo"
labels := getTestLabels()
mockCtrl := gomock.NewController(t)
tt := newTestHandler(mockCtrl, "u")
tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{
HasLabels: true,
HasMessages: true,
HasMessageCount: true,
FailedMessages: xmaps.SetFromSlice([]string{}),
LastSyncedMessageID: MessageID,
NumSyncedMessages: MessageTotal,
TotalMessageCount: MessageTotal,
}, nil
})
tt.updateApplier.EXPECT().SyncSystemLabelsOnly(gomock.Any(), gomock.Eq(labels)).Return(nil)
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
require.NoError(t, err)
}
func TestTask_RepeatsOnSyncFailure(t *testing.T) {
const MessageTotal int64 = 50
const MessageID string = "foo"
const MessageDelta int64 = 10
labels := getTestLabels()
mockCtrl := gomock.NewController(t)
tt := newTestHandler(mockCtrl, "u")
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
{
call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{
HasLabels: false,
HasMessages: false,
HasMessageCount: false,
FailedMessages: xmaps.SetFromSlice([]string{}),
LastSyncedMessageID: "",
NumSyncedMessages: 0,
TotalMessageCount: 0,
}, nil
})
call1 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{
HasLabels: false,
HasMessages: false,
HasMessageCount: false,
FailedMessages: xmaps.SetFromSlice([]string{}),
LastSyncedMessageID: "",
NumSyncedMessages: 0,
TotalMessageCount: 0,
}, nil
}).After(call0)
call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
}
{
call0 := tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(1).Return(fmt.Errorf("failed"))
tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(1).Return(nil).After(call0)
}
{
tt.client.EXPECT().GetGroupedMessageCount(gomock.Any()).Return([]proton.MessageGroupCount{
{
LabelID: proton.AllMailLabel,
Total: int(MessageTotal),
Unread: 0,
},
}, nil)
}
tt.syncReporter.EXPECT().OnStart(gomock.Any())
tt.syncReporter.EXPECT().OnFinished(gomock.Any())
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
tt.task.Execute(tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder, time.Microsecond)
require.NoError(t, <-tt.task.OnSyncFinishedCH())
}
func getTestLabels() map[string]proton.Label {
return map[string]proton.Label{
proton.AllMailLabel: {
ID: proton.AllMailLabel,
Name: "All Mail",
Type: proton.LabelTypeSystem,
},
proton.InboxLabel: {
ID: proton.InboxLabel,
Name: "Inbox",
Type: proton.LabelTypeSystem,
},
proton.DraftsLabel: {
ID: proton.DraftsLabel,
Name: "Drafts",
Type: proton.LabelTypeSystem,
},
proton.TrashLabel: {
ID: proton.TrashLabel,
Name: "Trash",
Type: proton.LabelTypeSystem,
},
"label1": {
ID: "label1",
Name: "label1",
Type: proton.LabelTypeLabel,
},
"folder1": {
ID: "folder1",
Name: "folder1",
Type: proton.LabelTypeFolder,
},
"folder2": {
ID: "folder2",
Name: "folder2",
ParentID: "folder1",
Type: proton.LabelTypeFolder,
},
}
}
type thandler struct {
task *Handler
regulator *MockRegulator
syncState *MockStateProvider
updateApplier *MockUpdateApplier
messageBuilder *MockMessageBuilder
client *MockAPIClient
syncReporter *MockReporter
}
func (t thandler) addMessageSyncCompletedExpectation(messageID string, delta int64) { //nolint:unparam
t.regulator.EXPECT().Sync(gomock.Any(), gomock.Any()).Do(func(_ context.Context, job *Job) {
job.begin()
j := job.newChildJob(messageID, delta)
j.onFinished(context.Background())
job.end()
})
}
func newTestHandler(mockCtrl *gomock.Controller, userID string) thandler { // nolint:unparam
regulator := NewMockRegulator(mockCtrl)
syncState := NewMockStateProvider(mockCtrl)
updateApplier := NewMockUpdateApplier(mockCtrl)
client := NewMockAPIClient(mockCtrl)
messageBuilder := NewMockMessageBuilder(mockCtrl)
syncReporter := NewMockReporter(mockCtrl)
task := NewHandler(regulator, client, userID, syncState, logrus.WithField("test", "test"), &async.NoopPanicHandler{})
return thandler{
task: task,
regulator: regulator,
syncState: syncState,
updateApplier: updateApplier,
messageBuilder: messageBuilder,
syncReporter: syncReporter,
client: client,
}
}


@@ -0,0 +1,88 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"bytes"
"context"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xmaps"
)
type StateProvider interface {
AddFailedMessageID(context.Context, string) error
RemFailedMessageID(context.Context, string) error
GetSyncStatus(context.Context) (Status, error)
ClearSyncStatus(context.Context) error
SetHasLabels(context.Context, bool) error
SetHasMessages(context.Context, bool) error
SetLastMessageID(context.Context, string, int64) error
SetMessageCount(context.Context, int64) error
}
type Status struct {
HasLabels bool
HasMessages bool
HasMessageCount bool
FailedMessages xmaps.Set[string]
LastSyncedMessageID string
NumSyncedMessages int64
TotalMessageCount int64
}
func DefaultStatus() Status {
return Status{
FailedMessages: make(map[string]struct{}),
}
}
func (s Status) IsComplete() bool {
return s.HasLabels && s.HasMessages
}
// Regulator is an abstraction over the sync service; it regulates the number of concurrent sync activities.
type Regulator interface {
Sync(ctx context.Context, stage *Job)
}
type BuildResult struct {
AddressID string
MessageID string
Update *imap.MessageCreated
}
type MessageBuilder interface {
WithKeys(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error
BuildMessage(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing, buffer *bytes.Buffer) (BuildResult, error)
}
type UpdateApplier interface {
ApplySyncUpdates(ctx context.Context, updates []BuildResult) error
SyncSystemLabelsOnly(ctx context.Context, labels map[string]proton.Label) error
SyncLabels(ctx context.Context, labels map[string]proton.Label) error
}
type Reporter interface {
OnStart(ctx context.Context)
OnFinished(ctx context.Context)
OnError(ctx context.Context, err error)
OnProgress(ctx context.Context, delta int64)
}
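
To make the Regulator abstraction concrete, a toy implementation capping concurrent syncs with a semaphore (illustrative only; the real service interleaves jobs across shared pipeline stages):

type semaphoreRegulator struct {
    slots chan struct{} // buffered; capacity = max concurrent syncs
}

func (r *semaphoreRegulator) Sync(ctx context.Context, stage *Job) {
    select {
    case r.slots <- struct{}{}: // acquire a slot
    case <-ctx.Done():
        return
    }
    go func() {
        defer func() { <-r.slots }() // release the slot
        stage.begin()
        // ... feed the job through the metadata/download/build/apply stages ...
        stage.end()
    }()
}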


@@ -0,0 +1,201 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
"fmt"
"sync"
"github.com/ProtonMail/gluon/async"
"github.com/sirupsen/logrus"
)
// Job represents a unit of work that will travel down the sync pipeline. The job will be split up into child jobs
// for each batch. The parent job (this) will then wait until all the children have finished executing. Execution can
// terminate by either:
// * Completing the pipeline successfully
// * Context Cancellation
// * Errors
// On error or context cancellation, all child jobs are cancelled.
type Job struct {
ctx context.Context
cancel func()
client APIClient
state StateProvider
userID string
labels LabelMap
messageBuilder MessageBuilder
updateApplier UpdateApplier
syncReporter Reporter
log *logrus.Entry
errorCh *async.QueuedChannel[error]
wg sync.WaitGroup
once sync.Once
panicHandler async.PanicHandler
downloadCache *DownloadCache
}
func NewJob(ctx context.Context,
client APIClient,
userID string,
labels LabelMap,
messageBuilder MessageBuilder,
updateApplier UpdateApplier,
syncReporter Reporter,
state StateProvider,
panicHandler async.PanicHandler,
cache *DownloadCache,
log *logrus.Entry,
) *Job {
ctx, cancel := context.WithCancel(ctx)
return &Job{
ctx: ctx,
client: client,
userID: userID,
cancel: cancel,
state: state,
log: log,
labels: labels,
messageBuilder: messageBuilder,
updateApplier: updateApplier,
syncReporter: syncReporter,
errorCh: async.NewQueuedChannel[error](4, 8, panicHandler, fmt.Sprintf("sync-job-error-%v", userID)),
panicHandler: panicHandler,
downloadCache: cache,
}
}
func (j *Job) Close() {
j.errorCh.CloseAndDiscardQueued()
j.wg.Wait()
}
func (j *Job) onError(err error) {
defer j.wg.Done()
// context cancelled is caught & handled in a different location.
if errors.Is(err, context.Canceled) {
return
}
j.errorCh.Enqueue(err)
j.cancel()
}
func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int64) {
defer j.wg.Done()
if err := j.state.SetLastMessageID(ctx, lastMessageID, count); err != nil {
j.log.WithError(err).Error("Failed to store last synced message id")
j.onError(err)
return
}
j.syncReporter.OnProgress(ctx, count)
}
// begin is expected to be called once the job enters the pipeline.
func (j *Job) begin() {
j.log.Info("Job started")
j.wg.Add(1)
j.startChildWaiter()
}
// end is expected to be called once the job has no further work left.
func (j *Job) end() {
j.log.Info("Job finished")
j.wg.Done()
}
// wait waits until the job has finished, the context got cancelled or an error occurred.
func (j *Job) wait(ctx context.Context) error {
defer j.wg.Wait()
select {
case <-ctx.Done():
j.cancel()
return ctx.Err()
case err := <-j.errorCh.GetChannel():
return err
}
}
func (j *Job) newChildJob(messageID string, messageCount int64) childJob {
j.log.Infof("Creating new child job")
j.wg.Add(1)
return childJob{job: j, lastMessageID: messageID, messageCount: messageCount}
}
func (j *Job) startChildWaiter() {
j.once.Do(func() {
go func() {
defer async.HandlePanic(j.panicHandler)
j.wg.Wait()
j.log.Info("All child jobs succeeded")
j.errorCh.Enqueue(j.ctx.Err())
}()
})
}
// childJob represents a batch of work that goes down the pipeline. It keeps track of the message ID that is in the
// batch and the number of messages in the batch.
type childJob struct {
job *Job
lastMessageID string
messageCount int64
cachedMessageIDs []string
cachedAttachmentIDs []string
}
func (s *childJob) onError(err error) {
s.job.log.WithError(err).Info("Child job ran into error")
s.job.onError(err)
}
func (s *childJob) userID() string {
return s.job.userID
}
func (s *childJob) onFinished(ctx context.Context) {
s.job.log.Infof("Child job finished")
s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount)
s.job.downloadCache.DeleteMessages(s.cachedMessageIDs...)
s.job.downloadCache.DeleteAttachments(s.cachedAttachmentIDs...)
}
func (s *childJob) checkCancelled() bool {
err := s.job.ctx.Err()
if err != nil {
s.job.log.Infof("Child job exit due to context cancelled")
s.job.wg.Done()
return true
}
return false
}
func (s *childJob) getContext() context.Context {
return s.job.ctx
}
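
The lifecycle, as exercised by the tests below: the parent is begun, a child job is split off per batch, and wait unblocks once every child has reported back (the ctx value and the job construction via NewJob are assumed):

job.begin()                           // parent enters the pipeline
batch := job.newChildJob("msg-42", 10)
// ... download, build and apply the batch ...
batch.onFinished(ctx)                 // persists progress and reports the delta
job.end()                             // no further batches will be created
err := job.wait(ctx)                  // nil, ctx.Err(), or the first child error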


@@ -0,0 +1,237 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
"testing"
"github.com/ProtonMail/gluon/async"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func setupGoLeak() goleak.Option {
logrus.Trace("prepare for go leak")
return goleak.IgnoreCurrent()
}
func TestJob_WaitsOnChildren(t *testing.T) {
options := setupGoLeak()
defer goleak.VerifyNone(t, options)
mockCtrl := gomock.NewController(t)
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("1"), gomock.Eq(int64(0))).Return(nil)
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("2"), gomock.Eq(int64(1))).Return(nil)
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).Times(2)
go func() {
tj.job.begin()
job1 := tj.job.newChildJob("1", 0)
job2 := tj.job.newChildJob("2", 1)
job1.onFinished(context.Background())
job2.onFinished(context.Background())
tj.job.end()
}()
require.NoError(t, tj.job.wait(context.Background()))
tj.job.Close()
}
func TestJob_WaitsOnAllChildrenOnError(t *testing.T) {
options := setupGoLeak()
defer goleak.VerifyNone(t, options)
mockCtrl := gomock.NewController(t)
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("1"), gomock.Eq(int64(0))).Return(nil)
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
jobErr := errors.New("failed")
go func() {
job1 := tj.job.newChildJob("1", 0)
job2 := tj.job.newChildJob("2", 1)
job1.onFinished(context.Background())
job2.onError(jobErr)
}()
err := tj.job.wait(context.Background())
require.Error(t, err)
require.ErrorIs(t, err, jobErr)
tj.job.Close()
}
func TestJob_MultipleChildrenReportError(t *testing.T) {
options := setupGoLeak()
defer goleak.VerifyNone(t, options)
mockCtrl := gomock.NewController(t)
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
jobErr := errors.New("failed")
startCh := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
job := tj.job.newChildJob("1", 0)
<-startCh
job.onError(jobErr)
}()
}
close(startCh)
err := tj.job.wait(context.Background())
require.Error(t, err)
require.ErrorIs(t, err, jobErr)
tj.job.Close()
}
func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) {
options := setupGoLeak()
defer goleak.VerifyNone(t, options)
mockCtrl := gomock.NewController(t)
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
jobErr := errors.New("failed")
failJob := tj.job.newChildJob("0", 1)
for i := 0; i < 10; i++ {
go func() {
job := tj.job.newChildJob("1", 0)
<-job.getContext().Done()
require.ErrorIs(t, job.getContext().Err(), context.Canceled)
require.True(t, job.checkCancelled())
}()
}
go func() {
failJob.onError(jobErr)
}()
err := tj.job.wait(context.Background())
require.Error(t, err)
require.ErrorIs(t, err, jobErr)
tj.job.Close()
}
func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
options := setupGoLeak()
defer goleak.VerifyNone(t, options)
mockCtrl := gomock.NewController(t)
ctx, cancel := context.WithCancel(context.Background())
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
for i := 0; i < 10; i++ {
go func() {
job := tj.job.newChildJob("1", 0)
<-job.getContext().Done()
require.ErrorIs(t, job.getContext().Err(), context.Canceled)
require.True(t, job.checkCancelled())
}()
}
go func() {
cancel()
}()
err := tj.job.wait(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
tj.job.Close()
}
func TestJob_WithoutChildJobsCanBeTerminated(t *testing.T) {
options := setupGoLeak()
defer goleak.VerifyNone(t, options)
mockCtrl := gomock.NewController(t)
ctx := context.Background()
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
go func() {
tj.job.begin()
tj.job.end()
}()
err := tj.job.wait(ctx)
require.NoError(t, err)
tj.job.Close()
}
type tjob struct {
job *Job
client *MockAPIClient
messageBuilder *MockMessageBuilder
updateApplier *MockUpdateApplier
syncReporter *MockReporter
state *MockStateProvider
}
func newTestJob(
ctx context.Context,
mockCtrl *gomock.Controller,
userID string,
labels LabelMap,
) tjob {
client := NewMockAPIClient(mockCtrl)
messageBuilder := NewMockMessageBuilder(mockCtrl)
updateApplier := NewMockUpdateApplier(mockCtrl)
syncReporter := NewMockReporter(mockCtrl)
state := NewMockStateProvider(mockCtrl)
job := NewJob(
ctx,
client,
userID,
labels,
messageBuilder,
updateApplier,
syncReporter,
state,
&async.NoopPanicHandler{},
newDownloadCache(),
logrus.WithField("s", "test"),
)
return tjob{
job: job,
client: client,
messageBuilder: messageBuilder,
updateApplier: updateApplier,
syncReporter: syncReporter,
state: state,
}
}


@@ -0,0 +1,121 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"os"
"github.com/pbnjay/memory"
"github.com/sirupsen/logrus"
)
const Kilobyte = uint64(1024)
const Megabyte = 1024 * Kilobyte
const Gigabyte = 1024 * Megabyte
func toMB(v uint64) float64 {
return float64(v) / float64(Megabyte)
}
type syncLimits struct {
MaxDownloadRequestMem uint64
MinDownloadRequestMem uint64
MaxMessageBuildingMem uint64
MinMessageBuildingMem uint64
MaxSyncMemory uint64
MaxParallelDownloads int
DownloadRequestMem uint64
MessageBuildMem uint64
}
func newSyncLimits(maxSyncMemory uint64) syncLimits {
limits := syncLimits{
// There's no point in using more than 128MB of download data per stage; beyond that we hit diminishing
// returns, as we can't keep the pipeline fed fast enough.
MaxDownloadRequestMem: 128 * Megabyte,
// Any lower than this and we may fail to download messages.
MinDownloadRequestMem: 40 * Megabyte,
// This value can be increased to your heart's content. The more system memory the user has, the more messages
// we can build in parallel.
MaxMessageBuildingMem: 128 * Megabyte,
MinMessageBuildingMem: 64 * Megabyte,
// Maximum recommended value for parallel downloads by the API team.
MaxParallelDownloads: 20,
MaxSyncMemory: maxSyncMemory,
}
if _, ok := os.LookupEnv("BRIDGE_SYNC_FORCE_MINIMUM_SPEC"); ok {
logrus.Warn("Sync specs forced to minimum")
limits.MaxDownloadRequestMem = 50 * Megabyte
limits.MaxMessageBuildingMem = 80 * Megabyte
limits.MaxParallelDownloads = 2
limits.MaxSyncMemory = 800 * Megabyte
}
// Expected memory usage for the whole process is roughly the sum of MaxMessageBuildingMem and MaxDownloadRequestMem,
// multiplied by a small factor (approx. 4x per stage, see below) due to pipelining, plus the additional memory used by network requests and compression+IO.
totalMemory := memory.TotalMemory()
if limits.MaxSyncMemory >= totalMemory/2 {
logrus.Warnf("Requested max sync memory of %v MB is greater than half of system memory (%v MB), forcing to half of system memory",
toMB(limits.MaxSyncMemory), toMB(totalMemory/2))
limits.MaxSyncMemory = totalMemory / 2
}
if limits.MaxSyncMemory < 800*Megabyte {
logrus.Warnf("Requested max sync memory of %v MB, but minimum recommended is 800 MB, forcing max syncMemory to 800MB", toMB(limits.MaxSyncMemory))
limits.MaxSyncMemory = 800 * Megabyte
}
logrus.Debugf("Total System Memory: %v", toMB(totalMemory))
// If less than 2GB is available, limit both stages to their minimum memory budgets.
switch {
case limits.MaxSyncMemory < 2*Gigabyte:
if limits.MaxSyncMemory < 800*Megabyte {
logrus.Warnf("System has less than 800MB of memory, you may experience issues sycing large mailboxes")
}
limits.DownloadRequestMem = limits.MinDownloadRequestMem
limits.MessageBuildMem = limits.MinMessageBuildingMem
case limits.MaxSyncMemory == 2*Gigabyte:
// Increasing the max download capacity has very little effect on sync speed. We could increase the download
// memory, but the user would see fewer sync notifications. A smaller value here leads to more frequent
// updates. Additionally, most of the sync time is spent building messages.
limits.DownloadRequestMem = limits.MaxDownloadRequestMem
// Currently limited so that if a user has multiple accounts active it also doesn't cause excessive memory usage.
limits.MessageBuildMem = limits.MaxMessageBuildingMem
default:
// Divide by 8, as the download and build stages will each use approx. 4x the specified memory.
remainingMemory := (limits.MaxSyncMemory - 2*Gigabyte) / 8
limits.DownloadRequestMem = limits.MaxDownloadRequestMem + remainingMemory
limits.MessageBuildMem = limits.MaxMessageBuildingMem + remainingMemory
}
logrus.Debugf("Max memory usage for sync Download=%vMB Building=%vMB Predicted Max Total=%vMB",
toMB(limits.DownloadRequestMem),
toMB(limits.MessageBuildMem),
toMB((limits.MessageBuildMem*4)+(limits.DownloadRequestMem*4)),
)
return limits
}
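
A worked instance of the arithmetic above, assuming maxSyncMemory = 4 GB on a machine with ample RAM:

remaining := (4*Gigabyte - 2*Gigabyte) / 8 // 256 MB
downloadMem := 128*Megabyte + remaining    // DownloadRequestMem = 384 MB
buildMem := 128*Megabyte + remaining       // MessageBuildMem = 384 MB
_ = 4*downloadMem + 4*buildMem             // predicted max total = 3072 MB (~3 GB)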


@@ -0,0 +1,916 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
// Package sync is a generated GoMock package.
package syncservice
import (
bytes "bytes"
context "context"
io "io"
reflect "reflect"
proton "github.com/ProtonMail/go-proton-api"
crypto "github.com/ProtonMail/gopenpgp/v2/crypto"
gomock "github.com/golang/mock/gomock"
)
// MockApplyStageInput is a mock of ApplyStageInput interface.
type MockApplyStageInput struct {
ctrl *gomock.Controller
recorder *MockApplyStageInputMockRecorder
}
// MockApplyStageInputMockRecorder is the mock recorder for MockApplyStageInput.
type MockApplyStageInputMockRecorder struct {
mock *MockApplyStageInput
}
// NewMockApplyStageInput creates a new mock instance.
func NewMockApplyStageInput(ctrl *gomock.Controller) *MockApplyStageInput {
mock := &MockApplyStageInput{ctrl: ctrl}
mock.recorder = &MockApplyStageInputMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockApplyStageInput) EXPECT() *MockApplyStageInputMockRecorder {
return m.recorder
}
// Consume mocks base method.
func (m *MockApplyStageInput) Consume(arg0 context.Context) (ApplyRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Consume", arg0)
ret0, _ := ret[0].(ApplyRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Consume indicates an expected call of Consume.
func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(ApplyRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0)
}
// MockBuildStageInput is a mock of BuildStageInput interface.
type MockBuildStageInput struct {
ctrl *gomock.Controller
recorder *MockBuildStageInputMockRecorder
}
// MockBuildStageInputMockRecorder is the mock recorder for MockBuildStageInput.
type MockBuildStageInputMockRecorder struct {
mock *MockBuildStageInput
}
// NewMockBuildStageInput creates a new mock instance.
func NewMockBuildStageInput(ctrl *gomock.Controller) *MockBuildStageInput {
mock := &MockBuildStageInput{ctrl: ctrl}
mock.recorder = &MockBuildStageInputMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockBuildStageInput) EXPECT() *MockBuildStageInputMockRecorder {
return m.recorder
}
// Consume mocks base method.
func (m *MockBuildStageInput) Consume(arg0 context.Context) (BuildRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Consume", arg0)
ret0, _ := ret[0].(BuildRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Consume indicates an expected call of Consume.
func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(BuildRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0)
}
// MockBuildStageOutput is a mock of BuildStageOutput interface.
type MockBuildStageOutput struct {
ctrl *gomock.Controller
recorder *MockBuildStageOutputMockRecorder
}
// MockBuildStageOutputMockRecorder is the mock recorder for MockBuildStageOutput.
type MockBuildStageOutputMockRecorder struct {
mock *MockBuildStageOutput
}
// NewMockBuildStageOutput creates a new mock instance.
func NewMockBuildStageOutput(ctrl *gomock.Controller) *MockBuildStageOutput {
mock := &MockBuildStageOutput{ctrl: ctrl}
mock.recorder = &MockBuildStageOutputMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockBuildStageOutput) EXPECT() *MockBuildStageOutputMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockBuildStageOutput) Close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close")
}
// Close indicates an expected call of Close.
func (mr *MockBuildStageOutputMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockBuildStageOutput)(nil).Close))
}
// Produce mocks base method.
func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1)
}
// Produce indicates an expected call of Produce.
func (mr *MockBuildStageOutputMockRecorder) Produce(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Produce", reflect.TypeOf((*MockBuildStageOutput)(nil).Produce), arg0, arg1)
}
// MockDownloadStageInput is a mock of DownloadStageInput interface.
type MockDownloadStageInput struct {
ctrl *gomock.Controller
recorder *MockDownloadStageInputMockRecorder
}
// MockDownloadStageInputMockRecorder is the mock recorder for MockDownloadStageInput.
type MockDownloadStageInputMockRecorder struct {
mock *MockDownloadStageInput
}
// NewMockDownloadStageInput creates a new mock instance.
func NewMockDownloadStageInput(ctrl *gomock.Controller) *MockDownloadStageInput {
mock := &MockDownloadStageInput{ctrl: ctrl}
mock.recorder = &MockDownloadStageInputMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDownloadStageInput) EXPECT() *MockDownloadStageInputMockRecorder {
return m.recorder
}
// Consume mocks base method.
func (m *MockDownloadStageInput) Consume(arg0 context.Context) (DownloadRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Consume", arg0)
ret0, _ := ret[0].(DownloadRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Consume indicates an expected call of Consume.
func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(DownloadRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0)
}
// MockDownloadStageOutput is a mock of DownloadStageOutput interface.
type MockDownloadStageOutput struct {
ctrl *gomock.Controller
recorder *MockDownloadStageOutputMockRecorder
}
// MockDownloadStageOutputMockRecorder is the mock recorder for MockDownloadStageOutput.
type MockDownloadStageOutputMockRecorder struct {
mock *MockDownloadStageOutput
}
// NewMockDownloadStageOutput creates a new mock instance.
func NewMockDownloadStageOutput(ctrl *gomock.Controller) *MockDownloadStageOutput {
mock := &MockDownloadStageOutput{ctrl: ctrl}
mock.recorder = &MockDownloadStageOutputMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDownloadStageOutput) EXPECT() *MockDownloadStageOutputMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockDownloadStageOutput) Close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close")
}
// Close indicates an expected call of Close.
func (mr *MockDownloadStageOutputMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDownloadStageOutput)(nil).Close))
}
// Produce mocks base method.
func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1)
}
// Produce indicates an expected call of Produce.
func (mr *MockDownloadStageOutputMockRecorder) Produce(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Produce", reflect.TypeOf((*MockDownloadStageOutput)(nil).Produce), arg0, arg1)
}
// MockMetadataStageInput is a mock of MetadataStageInput interface.
type MockMetadataStageInput struct {
ctrl *gomock.Controller
recorder *MockMetadataStageInputMockRecorder
}
// MockMetadataStageInputMockRecorder is the mock recorder for MockMetadataStageInput.
type MockMetadataStageInputMockRecorder struct {
mock *MockMetadataStageInput
}
// NewMockMetadataStageInput creates a new mock instance.
func NewMockMetadataStageInput(ctrl *gomock.Controller) *MockMetadataStageInput {
mock := &MockMetadataStageInput{ctrl: ctrl}
mock.recorder = &MockMetadataStageInputMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMetadataStageInput) EXPECT() *MockMetadataStageInputMockRecorder {
return m.recorder
}
// Consume mocks base method.
func (m *MockMetadataStageInput) Consume(arg0 context.Context) (*Job, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Consume", arg0)
ret0, _ := ret[0].(*Job)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Consume indicates an expected call of Consume.
func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(*Job)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0)
}
// MockMetadataStageOutput is a mock of MetadataStageOutput interface.
type MockMetadataStageOutput struct {
ctrl *gomock.Controller
recorder *MockMetadataStageOutputMockRecorder
}
// MockMetadataStageOutputMockRecorder is the mock recorder for MockMetadataStageOutput.
type MockMetadataStageOutputMockRecorder struct {
mock *MockMetadataStageOutput
}
// NewMockMetadataStageOutput creates a new mock instance.
func NewMockMetadataStageOutput(ctrl *gomock.Controller) *MockMetadataStageOutput {
mock := &MockMetadataStageOutput{ctrl: ctrl}
mock.recorder = &MockMetadataStageOutputMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMetadataStageOutput) EXPECT() *MockMetadataStageOutputMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockMetadataStageOutput) Close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close")
}
// Close indicates an expected call of Close.
func (mr *MockMetadataStageOutputMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMetadataStageOutput)(nil).Close))
}
// Produce mocks base method.
func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1)
}
// Produce indicates an expected call of Produce.
func (mr *MockMetadataStageOutputMockRecorder) Produce(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Produce", reflect.TypeOf((*MockMetadataStageOutput)(nil).Produce), arg0, arg1)
}
// MockStateProvider is a mock of StateProvider interface.
type MockStateProvider struct {
ctrl *gomock.Controller
recorder *MockStateProviderMockRecorder
}
// MockStateProviderMockRecorder is the mock recorder for MockStateProvider.
type MockStateProviderMockRecorder struct {
mock *MockStateProvider
}
// NewMockStateProvider creates a new mock instance.
func NewMockStateProvider(ctrl *gomock.Controller) *MockStateProvider {
mock := &MockStateProvider{ctrl: ctrl}
mock.recorder = &MockStateProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder {
return m.recorder
}
// AddFailedMessageID mocks base method.
func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// AddFailedMessageID indicates an expected call of AddFailedMessageID.
func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1)
}
// ClearSyncStatus mocks base method.
func (m *MockStateProvider) ClearSyncStatus(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClearSyncStatus", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// ClearSyncStatus indicates an expected call of ClearSyncStatus.
func (mr *MockStateProviderMockRecorder) ClearSyncStatus(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearSyncStatus", reflect.TypeOf((*MockStateProvider)(nil).ClearSyncStatus), arg0)
}
// GetSyncStatus mocks base method.
func (m *MockStateProvider) GetSyncStatus(arg0 context.Context) (Status, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSyncStatus", arg0)
ret0, _ := ret[0].(Status)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSyncStatus indicates an expected call of GetSyncStatus.
func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSyncStatus", reflect.TypeOf((*MockStateProvider)(nil).GetSyncStatus), arg0)
}
// RemFailedMessageID mocks base method.
func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// RemFailedMessageID indicates an expected call of RemFailedMessageID.
func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1)
}
// SetHasLabels mocks base method.
func (m *MockStateProvider) SetHasLabels(arg0 context.Context, arg1 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetHasLabels", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SetHasLabels indicates an expected call of SetHasLabels.
func (mr *MockStateProviderMockRecorder) SetHasLabels(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHasLabels", reflect.TypeOf((*MockStateProvider)(nil).SetHasLabels), arg0, arg1)
}
// SetHasMessages mocks base method.
func (m *MockStateProvider) SetHasMessages(arg0 context.Context, arg1 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetHasMessages", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SetHasMessages indicates an expected call of SetHasMessages.
func (mr *MockStateProviderMockRecorder) SetHasMessages(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHasMessages", reflect.TypeOf((*MockStateProvider)(nil).SetHasMessages), arg0, arg1)
}
// SetLastMessageID mocks base method.
func (m *MockStateProvider) SetLastMessageID(arg0 context.Context, arg1 string, arg2 int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetLastMessageID", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SetLastMessageID indicates an expected call of SetLastMessageID.
func (mr *MockStateProviderMockRecorder) SetLastMessageID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLastMessageID", reflect.TypeOf((*MockStateProvider)(nil).SetLastMessageID), arg0, arg1, arg2)
}
// SetMessageCount mocks base method.
func (m *MockStateProvider) SetMessageCount(arg0 context.Context, arg1 int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetMessageCount", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SetMessageCount indicates an expected call of SetMessageCount.
func (mr *MockStateProviderMockRecorder) SetMessageCount(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMessageCount", reflect.TypeOf((*MockStateProvider)(nil).SetMessageCount), arg0, arg1)
}
// MockRegulator is a mock of Regulator interface.
type MockRegulator struct {
ctrl *gomock.Controller
recorder *MockRegulatorMockRecorder
}
// MockRegulatorMockRecorder is the mock recorder for MockRegulator.
type MockRegulatorMockRecorder struct {
mock *MockRegulator
}
// NewMockRegulator creates a new mock instance.
func NewMockRegulator(ctrl *gomock.Controller) *MockRegulator {
mock := &MockRegulator{ctrl: ctrl}
mock.recorder = &MockRegulatorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRegulator) EXPECT() *MockRegulatorMockRecorder {
return m.recorder
}
// Sync mocks base method.
func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Sync", arg0, arg1)
}
// Sync indicates an expected call of Sync.
func (mr *MockRegulatorMockRecorder) Sync(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockRegulator)(nil).Sync), arg0, arg1)
}
// MockUpdateApplier is a mock of UpdateApplier interface.
type MockUpdateApplier struct {
ctrl *gomock.Controller
recorder *MockUpdateApplierMockRecorder
}
// MockUpdateApplierMockRecorder is the mock recorder for MockUpdateApplier.
type MockUpdateApplierMockRecorder struct {
mock *MockUpdateApplier
}
// NewMockUpdateApplier creates a new mock instance.
func NewMockUpdateApplier(ctrl *gomock.Controller) *MockUpdateApplier {
mock := &MockUpdateApplier{ctrl: ctrl}
mock.recorder = &MockUpdateApplierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUpdateApplier) EXPECT() *MockUpdateApplierMockRecorder {
return m.recorder
}
// ApplySyncUpdates mocks base method.
func (m *MockUpdateApplier) ApplySyncUpdates(arg0 context.Context, arg1 []BuildResult) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ApplySyncUpdates", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// ApplySyncUpdates indicates an expected call of ApplySyncUpdates.
func (mr *MockUpdateApplierMockRecorder) ApplySyncUpdates(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplySyncUpdates", reflect.TypeOf((*MockUpdateApplier)(nil).ApplySyncUpdates), arg0, arg1)
}
// SyncLabels mocks base method.
func (m *MockUpdateApplier) SyncLabels(arg0 context.Context, arg1 map[string]proton.Label) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncLabels", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SyncLabels indicates an expected call of SyncLabels.
func (mr *MockUpdateApplierMockRecorder) SyncLabels(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncLabels", reflect.TypeOf((*MockUpdateApplier)(nil).SyncLabels), arg0, arg1)
}
// SyncSystemLabelsOnly mocks base method.
func (m *MockUpdateApplier) SyncSystemLabelsOnly(arg0 context.Context, arg1 map[string]proton.Label) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncSystemLabelsOnly", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SyncSystemLabelsOnly indicates an expected call of SyncSystemLabelsOnly.
func (mr *MockUpdateApplierMockRecorder) SyncSystemLabelsOnly(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncSystemLabelsOnly", reflect.TypeOf((*MockUpdateApplier)(nil).SyncSystemLabelsOnly), arg0, arg1)
}
// MockMessageBuilder is a mock of MessageBuilder interface.
type MockMessageBuilder struct {
ctrl *gomock.Controller
recorder *MockMessageBuilderMockRecorder
}
// MockMessageBuilderMockRecorder is the mock recorder for MockMessageBuilder.
type MockMessageBuilderMockRecorder struct {
mock *MockMessageBuilder
}
// NewMockMessageBuilder creates a new mock instance.
func NewMockMessageBuilder(ctrl *gomock.Controller) *MockMessageBuilder {
mock := &MockMessageBuilder{ctrl: ctrl}
mock.recorder = &MockMessageBuilderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMessageBuilder) EXPECT() *MockMessageBuilderMockRecorder {
return m.recorder
}
// BuildMessage mocks base method.
func (m *MockMessageBuilder) BuildMessage(arg0 map[string]proton.Label, arg1 proton.FullMessage, arg2 *crypto.KeyRing, arg3 *bytes.Buffer) (BuildResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildMessage", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(BuildResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BuildMessage indicates an expected call of BuildMessage.
func (mr *MockMessageBuilderMockRecorder) BuildMessage(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildMessage", reflect.TypeOf((*MockMessageBuilder)(nil).BuildMessage), arg0, arg1, arg2, arg3)
}
// WithKeys mocks base method.
func (m *MockMessageBuilder) WithKeys(arg0 func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WithKeys", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// WithKeys indicates an expected call of WithKeys.
func (mr *MockMessageBuilderMockRecorder) WithKeys(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithKeys", reflect.TypeOf((*MockMessageBuilder)(nil).WithKeys), arg0)
}
// MockAPIClient is a mock of APIClient interface.
type MockAPIClient struct {
ctrl *gomock.Controller
recorder *MockAPIClientMockRecorder
}
// MockAPIClientMockRecorder is the mock recorder for MockAPIClient.
type MockAPIClientMockRecorder struct {
mock *MockAPIClient
}
// NewMockAPIClient creates a new mock instance.
func NewMockAPIClient(ctrl *gomock.Controller) *MockAPIClient {
mock := &MockAPIClient{ctrl: ctrl}
mock.recorder = &MockAPIClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAPIClient) EXPECT() *MockAPIClientMockRecorder {
return m.recorder
}
// GetAttachment mocks base method.
func (m *MockAPIClient) GetAttachment(arg0 context.Context, arg1 string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAttachment indicates an expected call of GetAttachment.
func (mr *MockAPIClientMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockAPIClient)(nil).GetAttachment), arg0, arg1)
}
// GetAttachmentInto mocks base method.
func (m *MockAPIClient) GetAttachmentInto(arg0 context.Context, arg1 string, arg2 io.ReaderFrom) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAttachmentInto", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// GetAttachmentInto indicates an expected call of GetAttachmentInto.
func (mr *MockAPIClientMockRecorder) GetAttachmentInto(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachmentInto", reflect.TypeOf((*MockAPIClient)(nil).GetAttachmentInto), arg0, arg1, arg2)
}
// GetFullMessage mocks base method.
func (m *MockAPIClient) GetFullMessage(arg0 context.Context, arg1 string, arg2 proton.Scheduler, arg3 proton.AttachmentAllocator) (proton.FullMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetFullMessage", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(proton.FullMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetFullMessage indicates an expected call of GetFullMessage.
func (mr *MockAPIClientMockRecorder) GetFullMessage(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFullMessage", reflect.TypeOf((*MockAPIClient)(nil).GetFullMessage), arg0, arg1, arg2, arg3)
}
// GetGroupedMessageCount mocks base method.
func (m *MockAPIClient) GetGroupedMessageCount(arg0 context.Context) ([]proton.MessageGroupCount, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupedMessageCount", arg0)
ret0, _ := ret[0].([]proton.MessageGroupCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupedMessageCount indicates an expected call of GetGroupedMessageCount.
func (mr *MockAPIClientMockRecorder) GetGroupedMessageCount(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupedMessageCount", reflect.TypeOf((*MockAPIClient)(nil).GetGroupedMessageCount), arg0)
}
// GetLabels mocks base method.
func (m *MockAPIClient) GetLabels(arg0 context.Context, arg1 ...proton.LabelType) ([]proton.Label, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "GetLabels", varargs...)
ret0, _ := ret[0].([]proton.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetLabels indicates an expected call of GetLabels.
func (mr *MockAPIClientMockRecorder) GetLabels(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLabels", reflect.TypeOf((*MockAPIClient)(nil).GetLabels), varargs...)
}
// GetMessage mocks base method.
func (m *MockAPIClient) GetMessage(arg0 context.Context, arg1 string) (proton.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
ret0, _ := ret[0].(proton.Message)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMessage indicates an expected call of GetMessage.
func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1)
}
// GetMessageIDs mocks base method.
func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMessageIDs indicates an expected call of GetMessageIDs.
func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1)
}
// GetMessageMetadataPage mocks base method.
func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessageMetadataPage", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]proton.MessageMetadata)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMessageMetadataPage indicates an expected call of GetMessageMetadataPage.
func (mr *MockAPIClientMockRecorder) GetMessageMetadataPage(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageMetadataPage", reflect.TypeOf((*MockAPIClient)(nil).GetMessageMetadataPage), arg0, arg1, arg2, arg3)
}
// MockReporter is a mock of Reporter interface.
type MockReporter struct {
ctrl *gomock.Controller
recorder *MockReporterMockRecorder
}
// MockReporterMockRecorder is the mock recorder for MockReporter.
type MockReporterMockRecorder struct {
mock *MockReporter
}
// NewMockReporter creates a new mock instance.
func NewMockReporter(ctrl *gomock.Controller) *MockReporter {
mock := &MockReporter{ctrl: ctrl}
mock.recorder = &MockReporterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockReporter) EXPECT() *MockReporterMockRecorder {
return m.recorder
}
// OnError mocks base method.
func (m *MockReporter) OnError(arg0 context.Context, arg1 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnError", arg0, arg1)
}
// OnError indicates an expected call of OnError.
func (mr *MockReporterMockRecorder) OnError(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockReporter)(nil).OnError), arg0, arg1)
}
// OnFinished mocks base method.
func (m *MockReporter) OnFinished(arg0 context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnFinished", arg0)
}
// OnFinished indicates an expected call of OnFinished.
func (mr *MockReporterMockRecorder) OnFinished(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnFinished", reflect.TypeOf((*MockReporter)(nil).OnFinished), arg0)
}
// OnProgress mocks base method.
func (m *MockReporter) OnProgress(arg0 context.Context, arg1 int64) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnProgress", arg0, arg1)
}
// OnProgress indicates an expected call of OnProgress.
func (mr *MockReporterMockRecorder) OnProgress(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnProgress", reflect.TypeOf((*MockReporter)(nil).OnProgress), arg0, arg1)
}
// OnStart mocks base method.
func (m *MockReporter) OnStart(arg0 context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnStart", arg0)
}
// OnStart indicates an expected call of OnStart.
func (mr *MockReporterMockRecorder) OnStart(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnStart", reflect.TypeOf((*MockReporter)(nil).OnStart), arg0)
}
// MockDownloadRateModifier is a mock of DownloadRateModifier interface.
type MockDownloadRateModifier struct {
ctrl *gomock.Controller
recorder *MockDownloadRateModifierMockRecorder
}
// MockDownloadRateModifierMockRecorder is the mock recorder for MockDownloadRateModifier.
type MockDownloadRateModifierMockRecorder struct {
mock *MockDownloadRateModifier
}
// NewMockDownloadRateModifier creates a new mock instance.
func NewMockDownloadRateModifier(ctrl *gomock.Controller) *MockDownloadRateModifier {
mock := &MockDownloadRateModifier{ctrl: ctrl}
mock.recorder = &MockDownloadRateModifierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDownloadRateModifier) EXPECT() *MockDownloadRateModifierMockRecorder {
return m.recorder
}
// Apply mocks base method.
func (m *MockDownloadRateModifier) Apply(arg0 bool, arg1, arg2 int) int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Apply", arg0, arg1, arg2)
ret0, _ := ret[0].(int)
return ret0
}
// Apply indicates an expected call of Apply.
func (mr *MockDownloadRateModifierMockRecorder) Apply(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockDownloadRateModifier)(nil).Apply), arg0, arg1, arg2)
}
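
// Usage sketch for the generated mocks above (illustrative only, not part of the
// generated output): a controller is bound to the test, EXPECT() registers the
// expected call, and Return stubs its result.
//
//	func TestStateProviderUsage(t *testing.T) {
//		ctrl := gomock.NewController(t)
//		state := NewMockStateProvider(ctrl)
//		state.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).Return(nil)
//		require.NoError(t, state.SetHasLabels(context.Background(), true))
//	}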

View File

@ -0,0 +1,78 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/reporter"
)
// Service which mediates IMAP syncing in Bridge.
// IMPORTANT: Be sure to cancel all ongoing sync Handlers before cancelling this service's Group.
type Service struct {
metadataStage *MetadataStage
downloadStage *DownloadStage
buildStage *BuildStage
applyStage *ApplyStage
limits syncLimits
metaCh *ChannelConsumerProducer[*Job]
panicHandler async.PanicHandler
}
func NewService(reporter reporter.Reporter,
panicHandler async.PanicHandler,
) *Service {
limits := newSyncLimits(2 * Gigabyte)
metaCh := NewChannelConsumerProducer[*Job]()
downloadCh := NewChannelConsumerProducer[DownloadRequest]()
buildCh := NewChannelConsumerProducer[BuildRequest]()
applyCh := NewChannelConsumerProducer[ApplyRequest]()
return &Service{
limits: limits,
metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem),
downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler),
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
applyStage: NewApplyStage(applyCh),
metaCh: metaCh,
panicHandler: panicHandler,
}
}
func (s *Service) Run(group *async.Group) {
group.Once(func(ctx context.Context) {
syncGroup := async.NewGroup(ctx, s.panicHandler)
s.metadataStage.Run(syncGroup)
s.downloadStage.Run(syncGroup)
s.buildStage.Run(syncGroup)
s.applyStage.Run(syncGroup)
defer s.metaCh.Close()
defer syncGroup.CancelAndWait()
<-ctx.Done()
})
}
func (s *Service) Sync(ctx context.Context, stage *Job) {
s.metaCh.Produce(ctx, stage)
}
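
// Wiring sketch (illustrative): how a caller might run the service and hand it work.
// Job construction is defined elsewhere in this patch and elided here.
//
//	func runSync(ctx context.Context, rep reporter.Reporter, ph async.PanicHandler, job *Job) {
//		service := NewService(rep, ph)
//		group := async.NewGroup(ctx, ph)
//		service.Run(group)     // spawns the metadata, download, build and apply stages
//		service.Sync(ctx, job) // enqueue a sync job into the metadata stage
//		group.CancelAndWait()  // shut down; per the note above, cancel ongoing sync handlers first
//	}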

View File

@ -0,0 +1,78 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
"github.com/ProtonMail/gluon/async"
"github.com/sirupsen/logrus"
)
type ApplyRequest struct {
childJob
messages []BuildResult
}
type ApplyStageInput = StageInputConsumer[ApplyRequest]
// ApplyStage applies the sync updates and waits for their completion before proceeding with the next batch. This is
// the final stage in the sync pipeline.
type ApplyStage struct {
input ApplyStageInput
log *logrus.Entry
}
func NewApplyStage(input ApplyStageInput) *ApplyStage {
return &ApplyStage{input: input, log: logrus.WithField("sync-stage", "apply")}
}
func (a *ApplyStage) Run(group *async.Group) {
group.Once(a.run)
}
func (a *ApplyStage) run(ctx context.Context) {
for {
req, err := a.input.Consume(ctx)
if err != nil {
if !(errors.Is(err, ErrNoMoreInput) || errors.Is(err, context.Canceled)) {
a.log.WithError(err).Error("Exiting state with error")
}
return
}
if req.checkCancelled() {
continue
}
if len(req.messages) == 0 {
req.onFinished(req.getContext())
continue
}
if err := req.job.updateApplier.ApplySyncUpdates(ctx, req.messages); err != nil {
a.log.WithError(err).Error("Failed to apply sync updates")
req.job.onError(err)
continue
}
req.onFinished(req.getContext())
}
}
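
// Contract sketch for the stage plumbing (illustrative): once the producer side of a
// ChannelConsumerProducer is closed, Consume reports ErrNoMoreInput, which the stage
// loops treat as a clean shutdown rather than a failure.
//
//	ch := NewChannelConsumerProducer[int]()
//	go func() {
//		ch.Produce(ctx, 1)
//		ch.Close()
//	}()
//	for {
//		if _, err := ch.Consume(ctx); err != nil {
//			// errors.Is(err, ErrNoMoreInput) => producer closed, exit normally.
//			break
//		}
//	}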

View File

@ -0,0 +1,138 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
"testing"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/go-proton-api"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[ApplyRequest]()
stage := NewApplyStage(input)
ctx, cancel := context.WithCancel(context.Background())
jobCtx, jobCancel := context.WithCancel(context.Background())
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
go func() {
stage.run(ctx)
}()
jobCancel()
input.Produce(ctx, ApplyRequest{
childJob: childJob,
messages: nil,
})
err := tj.job.wait(ctx)
require.ErrorIs(t, err, context.Canceled)
cancel()
}
func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[ApplyRequest]()
stage := NewApplyStage(input)
ctx, cancel := context.WithCancel(context.Background())
jobCtx, jobCancel := context.WithCancel(context.Background())
defer jobCancel()
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10)))
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
go func() {
stage.run(ctx)
}()
input.Produce(ctx, ApplyRequest{
childJob: childJob,
messages: nil,
})
err := tj.job.wait(ctx)
cancel()
require.NoError(t, err)
}
func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[ApplyRequest]()
stage := NewApplyStage(input)
ctx, cancel := context.WithCancel(context.Background())
jobCtx, jobCancel := context.WithCancel(context.Background())
defer jobCancel()
buildResults := []BuildResult{
{
AddressID: "Foo",
MessageID: "Bar",
Update: &imap.MessageCreated{},
},
}
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
applyErr := errors.New("apply failed")
tj.updateApplier.EXPECT().ApplySyncUpdates(gomock.Any(), gomock.Eq(buildResults)).Return(applyErr)
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
go func() {
stage.run(ctx)
}()
input.Produce(ctx, ApplyRequest{
childJob: childJob,
messages: buildResults,
})
err := tj.job.wait(ctx)
cancel()
require.ErrorIs(t, err, applyErr)
}

View File

@ -0,0 +1,198 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"bytes"
"context"
"errors"
"runtime"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
)
type BuildRequest struct {
childJob
batch []proton.FullMessage
}
type BuildStageInput = StageInputConsumer[BuildRequest]
type BuildStageOutput = StageOutputProducer[ApplyRequest]
// BuildStage decrypts the messages downloaded by the previous stage and converts them into
// RFC822-compliant messages which can then be sent to the IMAP server.
type BuildStage struct {
input BuildStageInput
output BuildStageOutput
maxBuildMem uint64
panicHandler async.PanicHandler
reporter reporter.Reporter
log *logrus.Entry
}
func NewBuildStage(
input BuildStageInput,
output BuildStageOutput,
maxBuildMem uint64,
panicHandler async.PanicHandler,
reporter reporter.Reporter,
) *BuildStage {
return &BuildStage{
input: input,
output: output,
maxBuildMem: maxBuildMem,
log: logrus.WithField("sync-stage", "build"),
panicHandler: panicHandler,
reporter: reporter,
}
}
func (b *BuildStage) Run(group *async.Group) {
group.Once(b.run)
}
func (b *BuildStage) run(ctx context.Context) {
maxMessagesInParallel := runtime.NumCPU()
defer b.output.Close()
for {
req, err := b.input.Consume(ctx)
if err != nil {
if !(errors.Is(err, ErrNoMoreInput) || errors.Is(err, context.Canceled)) {
b.log.WithError(err).Error("Exiting state with error")
}
return
}
if req.checkCancelled() {
continue
}
err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem)
for _, chunk := range chunks {
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) {
defer async.HandlePanic(b.panicHandler)
kr, ok := addrKRs[msg.AddressID]
if !ok {
req.job.log.Errorf("Address '%v' on message '%v' does not have an unlocked kerying", msg.AddressID, msg.ID)
if err := req.job.state.AddFailedMessageID(req.getContext(), msg.ID); err != nil {
req.job.log.WithError(err).Error("Failed to add failed message ID")
}
if err := b.reporter.ReportMessageWithContext("Failed to build message - no unlocked keyring (sync)", reporter.Context{
"messageID": msg.ID,
"userID": req.userID(),
}); err != nil {
req.job.log.WithError(err).Error("Failed to report message build error")
}
return BuildResult{}, nil
}
res, err := req.job.messageBuilder.BuildMessage(req.job.labels, msg, kr, new(bytes.Buffer))
if err != nil {
req.job.log.WithError(err).WithField("msgID", msg.ID).Error("Failed to build message (syn)")
if err := req.job.state.AddFailedMessageID(req.getContext(), msg.ID); err != nil {
req.job.log.WithError(err).Error("Failed to add failed message ID")
}
if err := b.reporter.ReportMessageWithContext("Failed to build message (sync)", reporter.Context{
"messageID": msg.ID,
"error": err,
"userID": req.userID(),
}); err != nil {
req.job.log.WithError(err).Error("Failed to report message build error")
}
// We could sync a placeholder message here, but for now we skip it entirely.
return BuildResult{}, nil
}
if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil {
req.job.log.WithError(err).Error("Failed to remove failed message ID")
}
return res, nil
})
if err != nil {
return err
}
b.output.Produce(ctx, ApplyRequest{
childJob: req.childJob,
messages: xslices.Filter(result, func(t BuildResult) bool {
return t.Update != nil
}),
})
}
return nil
})
if err != nil {
req.job.onError(err)
}
}
}
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
var expectedMemUsage uint64
var chunks [][]proton.FullMessage
var lastIndex int
var index int
for _, v := range batch {
var dataSize uint64
for _, a := range v.Attachments {
dataSize += uint64(a.Size)
}
// Attachments are costed at twice their size to account for the extra memory needed to
// decrypt them and write them into an in-memory buffer.
dataSize *= 2
dataSize += uint64(len(v.Body))
nextMemSize := expectedMemUsage + dataSize
if nextMemSize >= maxMemory {
chunks = append(chunks, batch[lastIndex:index])
lastIndex = index
expectedMemUsage = dataSize
} else {
expectedMemUsage = nextMemSize
}
index++
}
if lastIndex < len(batch) {
chunks = append(chunks, batch[lastIndex:])
}
return chunks
}
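
// Worked example (illustrative): a message with one 4 MB attachment and a 1 MB body is
// costed at 4*2 + 1 = 9 MB. With a 16 MB budget, two such messages (18 MB) would exceed
// it, so each message ends up in its own chunk:
//
//	msg := proton.FullMessage{Message: proton.Message{
//		Attachments: []proton.Attachment{{Size: int64(4 * Megabyte)}},
//		Body:        string(make([]byte, Megabyte)),
//	}}
//	chunks := chunkSyncBuilderBatch(xslices.Repeat(msg, 3), 16*Megabyte)
//	// len(chunks) == 3, one message per chunk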

View File

@ -0,0 +1,317 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
"testing"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge/mocks"
"github.com/bradenaw/juniper/xslices"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestSyncChunkSyncBuilderBatch(t *testing.T) {
// GODT-2424 - Some messages were never built because of a chunking bug that dropped them whenever a single
// message's memory estimate exceeded the maximum we allowed.
const totalMessageCount = 100
msg := proton.FullMessage{
Message: proton.Message{
Attachments: []proton.Attachment{
{
Size: int64(8 * Megabyte),
},
},
},
AttData: nil,
}
messages := xslices.Repeat(msg, totalMessageCount)
chunks := chunkSyncBuilderBatch(messages, 16*Megabyte)
var totalMessagesInChunks int
for _, v := range chunks {
totalMessagesInChunks += len(v)
}
require.Equal(t, totalMessageCount, totalMessagesInChunks)
}
func TestBuildStage_SuccessRemovesFailedMessage(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[BuildRequest]()
output := NewChannelConsumerProducer[ApplyRequest]()
reporter := mocks.NewMockReporter(mockCtrl)
labels := getTestLabels()
ctx, cancel := context.WithCancel(context.Background())
tj := newTestJob(ctx, mockCtrl, "u", labels)
msg := proton.FullMessage{
Message: proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: "MSG",
AddressID: "addrID",
},
},
}
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
require.NoError(t, f(nil, map[string]*crypto.KeyRing{
"addrID": {},
}))
return nil
})
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
buildResult := BuildResult{
AddressID: "addrID",
MessageID: "MSG",
Update: &imap.MessageCreated{},
}
tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(buildResult, nil)
tj.state.EXPECT().RemFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, reporter)
go func() {
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
req, err := output.Consume(ctx)
cancel()
require.NoError(t, err)
require.Len(t, req.messages, 1)
require.Equal(t, buildResult, req.messages[0])
}
func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[BuildRequest]()
output := NewChannelConsumerProducer[ApplyRequest]()
mockReporter := mocks.NewMockReporter(mockCtrl)
labels := getTestLabels()
ctx, cancel := context.WithCancel(context.Background())
tj := newTestJob(ctx, mockCtrl, "u", labels)
msg := proton.FullMessage{
Message: proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: "MSG",
AddressID: "addrID",
},
},
}
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
require.NoError(t, f(nil, map[string]*crypto.KeyRing{
"addrID": {},
}))
return nil
})
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
buildError := errors.New("it failed")
tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError)
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
"userID": "u",
"messageID": "MSG",
"error": buildError,
})).Return(nil)
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
go func() {
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
req, err := output.Consume(ctx)
cancel()
require.NoError(t, err)
require.Empty(t, req.messages)
}
func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[BuildRequest]()
output := NewChannelConsumerProducer[ApplyRequest]()
mockReporter := mocks.NewMockReporter(mockCtrl)
labels := getTestLabels()
ctx, cancel := context.WithCancel(context.Background())
tj := newTestJob(ctx, mockCtrl, "u", labels)
msg := proton.FullMessage{
Message: proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: "MSG",
AddressID: "addrID",
},
},
}
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
require.NoError(t, f(nil, map[string]*crypto.KeyRing{}))
return nil
})
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
"userID": "u",
"messageID": "MSG",
})).Return(nil)
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
go func() {
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
req, err := output.Consume(ctx)
cancel()
require.NoError(t, err)
require.Empty(t, req.messages)
}
func TestBuildStage_OtherErrorsFailJob(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[BuildRequest]()
output := NewChannelConsumerProducer[ApplyRequest]()
mockReporter := mocks.NewMockReporter(mockCtrl)
labels := getTestLabels()
ctx, cancel := context.WithCancel(context.Background())
tj := newTestJob(ctx, mockCtrl, "u", labels)
msg := proton.FullMessage{
Message: proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: "MSG",
AddressID: "addrID",
},
},
}
expectedErr := errors.New("something went wrong")
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
return expectedErr
})
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
go func() {
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
err := tj.job.wait(ctx)
require.Equal(t, expectedErr, err)
cancel()
_, err = output.Consume(context.Background())
require.ErrorIs(t, err, ErrNoMoreInput)
}
func TestBuildStage_CancelledJobIsDiscarded(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[BuildRequest]()
output := NewChannelConsumerProducer[ApplyRequest]()
mockReporter := mocks.NewMockReporter(mockCtrl)
msg := proton.FullMessage{
Message: proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: "MSG",
AddressID: "addrID",
},
},
}
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
ctx, cancel := context.WithCancel(context.Background())
jobCtx, jobCancel := context.WithCancel(context.Background())
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.job.begin()
defer tj.job.end()
childJob := tj.job.newChildJob("f", 10)
go func() {
stage.run(ctx)
}()
jobCancel()
input.Produce(ctx, BuildRequest{
childJob: childJob,
batch: []proton.FullMessage{msg},
})
go func() { cancel() }()
_, err := output.Consume(context.Background())
require.ErrorIs(t, err, ErrNoMoreInput)
}

View File

@ -0,0 +1,279 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
"sync/atomic"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
)
type DownloadRequest struct {
childJob
ids []string
}
type DownloadStageInput = StageInputConsumer[DownloadRequest]
type DownloadStageOutput = StageOutputProducer[BuildRequest]
// DownloadStage downloads the messages and their attachments. It auto-throttles the parallel downloads based on
// whether we run into 429 or 5xx responses.
type DownloadStage struct {
input DownloadStageInput
output DownloadStageOutput
maxParallelDownloads int
panicHandler async.PanicHandler
log *logrus.Entry
}
func NewDownloadStage(
input DownloadStageInput,
output DownloadStageOutput,
maxParallelDownloads int,
panicHandler async.PanicHandler,
) *DownloadStage {
return &DownloadStage{
input: input,
output: output,
maxParallelDownloads: maxParallelDownloads,
panicHandler: panicHandler,
log: logrus.WithField("sync-stage", "download"),
}
}
func (d *DownloadStage) Run(group *async.Group) {
group.Once(func(ctx context.Context) {
logging.DoAnnotated(ctx, func(ctx context.Context) {
d.run(ctx)
}, logging.Labels{"sync-stage": "Download"})
})
}
func (d *DownloadStage) run(ctx context.Context) {
defer d.output.Close()
newCoolDown := func() network.CoolDownProvider {
return &network.ExpCoolDown{}
}
for {
request, err := d.input.Consume(ctx)
if err != nil {
if !(errors.Is(err, ErrNoMoreInput) || errors.Is(err, context.Canceled)) {
d.log.WithError(err).Error("Exiting state with error")
}
return
}
if request.checkCancelled() {
continue
}
// Step 1: Download Messages.
result, err := autoDownloadRate(
ctx,
&DefaultDownloadRateModifier{},
request.job.client,
d.maxParallelDownloads,
request.ids,
newCoolDown,
func(ctx context.Context, client APIClient, input string) (proton.FullMessage, error) {
msg, err := downloadMessage(ctx, request.job.downloadCache, client, input)
if err != nil {
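// A 422 response here means the message no longer exists server-side; return an
// empty placeholder, which Step 2 filters out below.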
var apiErr *proton.APIError
if errors.As(err, &apiErr) && apiErr.Status == 422 {
return proton.FullMessage{}, nil
}
return proton.FullMessage{}, err
}
var attData [][]byte
if msg.NumAttachments > 0 {
attData = make([][]byte, msg.NumAttachments)
}
return proton.FullMessage{Message: msg, AttData: attData}, nil
},
)
if err != nil {
request.job.onError(err)
continue
}
// Step 2: Prepare attachment IDs for download.
type attachmentMeta struct {
msgIdx int
attIdx int
}
// Filter out any messages that don't exist.
result = xslices.Filter(result, func(t proton.FullMessage) bool {
return t.ID != ""
})
attachmentIndices := make([]attachmentMeta, 0, len(result))
attachmentIDs := make([]string, 0, len(result))
for msgIdx, v := range result {
for attIdx := 0; attIdx < v.NumAttachments; attIdx++ {
attachmentIndices = append(attachmentIndices, attachmentMeta{
msgIdx: msgIdx,
attIdx: attIdx,
})
attachmentIDs = append(attachmentIDs, result[msgIdx].Attachments[attIdx].ID)
}
}
// Step 3: Download the attachment data.
attachments, err := autoDownloadRate(
ctx,
&DefaultDownloadRateModifier{},
request.job.client,
d.maxParallelDownloads,
attachmentIndices,
newCoolDown,
func(ctx context.Context, client APIClient, input attachmentMeta) ([]byte, error) {
return downloadAttachment(ctx, request.job.downloadCache, client, result[input.msgIdx].Attachments[input.attIdx].ID)
},
)
if err != nil {
request.job.onError(err)
continue
}
// Step 4: Attach the downloaded data to its message.
for i, meta := range attachmentIndices {
result[meta.msgIdx].AttData[meta.attIdx] = attachments[i]
}
request.cachedAttachmentIDs = attachmentIDs
request.cachedMessageIDs = request.ids
// Step 5: Publish result.
d.output.Produce(ctx, BuildRequest{
batch: result,
childJob: request.childJob,
})
}
}
func downloadMessage(ctx context.Context, cache *DownloadCache, client APIClient, id string) (proton.Message, error) {
msg, ok := cache.GetMessage(id)
if ok {
return msg, nil
}
msg, err := client.GetMessage(ctx, id)
if err != nil {
return proton.Message{}, err
}
cache.StoreMessage(msg)
return msg, nil
}
func downloadAttachment(ctx context.Context, cache *DownloadCache, client APIClient, id string) ([]byte, error) {
data, ok := cache.GetAttachment(id)
if ok {
return data, nil
}
data, err := client.GetAttachment(ctx, id)
if err != nil {
return nil, err
}
cache.StoreAttachment(id, data)
return data, nil
}
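
// Both helpers consult the job's DownloadCache first, so that, e.g., a chunk retried
// after a transient failure does not re-fetch data it already holds. The cache entries
// are released once the child job finishes (see the cachedMessageIDs and
// cachedAttachmentIDs bookkeeping above and the Count checks in the tests).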
type DownloadRateModifier interface {
Apply(wasSuccess bool, current int, max int) int
}
func autoDownloadRate[T any, R any](
ctx context.Context,
modifier DownloadRateModifier,
client APIClient,
maxParallelDownloads int,
data []T,
newCoolDown func() network.CoolDownProvider,
f func(ctx context.Context, client APIClient, input T) (R, error),
) ([]R, error) {
result := make([]R, 0, len(data))
proton429or5xxCounter := int32(0)
parallelTasks := maxParallelDownloads
for _, chunk := range xslices.Chunk(data, maxParallelDownloads) {
parallelTasks = modifier.Apply(atomic.LoadInt32(&proton429or5xxCounter) != 0, parallelTasks, maxParallelDownloads)
atomic.StoreInt32(&proton429or5xxCounter, 0)
chunkResult, err := parallel.MapContext(
ctx,
parallelTasks,
chunk,
func(ctx context.Context, in T) (R, error) {
wrapper := network.NewClientRetryWrapper(client, newCoolDown())
msg, err := network.RetryWithClient(ctx, wrapper, func(ctx context.Context, c APIClient) (R, error) {
return f(ctx, c, in)
})
if wrapper.DidEncounter429or5xx() {
atomic.AddInt32(&proton429or5xxCounter, 1)
}
return msg, err
})
if err != nil {
return nil, err
}
result = append(result, chunkResult...)
}
return result, nil
}
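
// Call sketch (illustrative; mirrors Step 1 of DownloadStage.run): fetch a batch of
// message IDs with at most 8 workers, cooling down and rescaling via the defaults.
//
//	msgs, err := autoDownloadRate(ctx, &DefaultDownloadRateModifier{}, client, 8, ids,
//		func() network.CoolDownProvider { return &network.ExpCoolDown{} },
//		func(ctx context.Context, c APIClient, id string) (proton.Message, error) {
//			return c.GetMessage(ctx, id)
//		},
//	)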
type DefaultDownloadRateModifier struct{}
func (d DefaultDownloadRateModifier) Apply(wasSuccess bool, current int, max int) int {
if !wasSuccess {
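// Fall straight back to two parallel downloads after an unsuccessful round.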
return 2
}
parallelTasks := current * 2
if parallelTasks > max {
parallelTasks = max
}
return parallelTasks
}
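
// Walkthrough (illustrative): reset to two workers on failure, double toward the cap
// on success.
//
//	m := DefaultDownloadRateModifier{}
//	m.Apply(false, 8, 8) // -> 2
//	m.Apply(true, 2, 8)  // -> 4
//	m.Apply(true, 4, 8)  // -> 8 (capped at max)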

View File

@ -0,0 +1,472 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
"fmt"
"testing"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/bradenaw/juniper/xslices"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestDownloadMessage_NotInCache(t *testing.T) {
mockCtrl := gomock.NewController(t)
client := NewMockAPIClient(mockCtrl)
cache := newDownloadCache()
client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{}, nil)
_, err := downloadMessage(context.Background(), cache, client, "msg")
require.NoError(t, err)
}
func TestDownloadMessage_InCache(t *testing.T) {
mockCtrl := gomock.NewController(t)
client := NewMockAPIClient(mockCtrl)
cache := newDownloadCache()
msg := proton.Message{
MessageMetadata: proton.MessageMetadata{ID: "msg", Size: 1024},
}
cache.StoreMessage(msg)
downloaded, err := downloadMessage(context.Background(), cache, client, "msg")
require.NoError(t, err)
require.Equal(t, msg, downloaded)
}
func TestDownloadAttachment_NotInCache(t *testing.T) {
mockCtrl := gomock.NewController(t)
client := NewMockAPIClient(mockCtrl)
cache := newDownloadCache()
client.EXPECT().GetAttachment(gomock.Any(), gomock.Any()).Return(nil, nil)
_, err := downloadAttachment(context.Background(), cache, client, "id")
require.NoError(t, err)
}
func TestDownloadAttachment_InCache(t *testing.T) {
mockCtrl := gomock.NewController(t)
client := NewMockAPIClient(mockCtrl)
cache := newDownloadCache()
attachment := []byte("hello world")
cache.StoreAttachment("id", attachment)
downloaded, err := downloadAttachment(context.Background(), cache, client, "id")
require.NoError(t, err)
require.Equal(t, attachment, downloaded)
}
func TestAutoDownloadScale_AllOkay(t *testing.T) {
mockCtrl := gomock.NewController(t)
client := NewMockAPIClient(mockCtrl)
data := buildDownloadScaleData(15)
const MaxParallel = 5
call1 := client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Times(5).DoAndReturn(autoDownloadScaleClientDoAndReturn)
call2 := client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Times(5).After(call1).DoAndReturn(autoDownloadScaleClientDoAndReturn)
client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Times(5).After(call2).DoAndReturn(autoDownloadScaleClientDoAndReturn)
msgs, err := autoDownloadRate(
context.Background(),
&DefaultDownloadRateModifier{},
client,
MaxParallel,
data,
autoScaleCoolDown,
func(ctx context.Context, client APIClient, input string) (proton.Message, error) {
return client.GetMessage(ctx, input)
},
)
require.NoError(t, err)
require.Equal(t, xslices.Map(data, newDownloadScaleMessage), msgs)
}
func TestAutoDownloadScale_429or500x(t *testing.T) {
mockCtrl := gomock.NewController(t)
client := NewMockAPIClient(mockCtrl)
data := buildDownloadScaleData(32)
rateModifier := NewMockDownloadRateModifier(mockCtrl)
const MaxParallel = 8
for _, d := range data {
switch d {
case "m7":
call429 := client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m7")).DoAndReturn(func(_ context.Context, id string) (proton.Message, error) {
return proton.Message{}, &proton.APIError{Status: 429}
})
client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m7")).After(call429).DoAndReturn(autoDownloadScaleClientDoAndReturn)
case "m23":
call503 := client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m23")).DoAndReturn(func(_ context.Context, id string) (proton.Message, error) {
return proton.Message{}, &proton.APIError{Status: 503}
})
client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m23")).After(call503).DoAndReturn(autoDownloadScaleClientDoAndReturn)
default:
client.EXPECT().GetMessage(gomock.Any(), gomock.Eq(d)).DoAndReturn(autoDownloadScaleClientDoAndReturn)
}
}
defaultRateModifier := DefaultDownloadRateModifier{}
// Apply runs once before each chunk of 8; its first argument reports whether the
// previous chunk hit a 429/5xx, and its return value sets the parallelism for the
// chunk about to start.
// Chunk 1 (m0-m7): nothing has been downloaded yet, so the counter is still zero; we start at 2 workers.
call1 := rateModifier.EXPECT().Apply(gomock.Eq(false), gomock.Eq(8), gomock.Eq(8)).Return(defaultRateModifier.Apply(false, 8, 8))
// Chunk 2 (m8-m15): the 429 on m7 was recorded during chunk 1; parallelism moves to 4.
call2 := rateModifier.EXPECT().Apply(gomock.Eq(true), gomock.Eq(2), gomock.Eq(8)).Return(defaultRateModifier.Apply(true, 2, 8))
// Chunk 3 (m16-m23): chunk 2 completed without 429/5xx; we drop back to 2 workers. m23 then hits a 503.
call3 := rateModifier.EXPECT().Apply(gomock.Eq(false), gomock.Eq(4), gomock.Eq(8)).Return(defaultRateModifier.Apply(false, 4, 8))
// Chunk 4 (m24-m31): the 503 on m23 was recorded during chunk 3; parallelism moves to 4 again.
call4 := rateModifier.EXPECT().Apply(gomock.Eq(true), gomock.Eq(2), gomock.Eq(8)).Return(defaultRateModifier.Apply(true, 2, 8))
gomock.InOrder(call1, call2, call3, call4)
msgs, err := autoDownloadRate(
context.Background(),
rateModifier,
client,
MaxParallel,
data,
autoScaleCoolDown,
func(ctx context.Context, client APIClient, input string) (proton.Message, error) {
return client.GetMessage(ctx, input)
},
)
require.NoError(t, err)
require.Equal(t, xslices.Map(data, newDownloadScaleMessage), msgs)
}
func TestDownloadStage_Run(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[DownloadRequest]()
output := NewChannelConsumerProducer[BuildRequest]()
ctx, cancel := context.WithCancel(context.Background())
tj := newTestJob(ctx, mockCtrl, "", map[string]proton.Label{})
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10))).Return(nil)
tj.job.begin()
defer tj.job.end()
childJob := tj.job.newChildJob("f", 10)
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
msgIDs, expected := buildDownloadStageData(&tj, 56, false)
go func() {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: msgIDs,
})
out, err := output.Consume(ctx)
require.NoError(t, err)
require.Equal(t, expected, out.batch)
out.onFinished(ctx)
cancel()
cachedMessages, cachedAttachments := tj.job.downloadCache.Count()
require.Zero(t, cachedMessages)
require.Zero(t, cachedAttachments)
}
func TestDownloadStage_RunWith422(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[DownloadRequest]()
output := NewChannelConsumerProducer[BuildRequest]()
ctx, cancel := context.WithCancel(context.Background())
tj := newTestJob(ctx, mockCtrl, "", map[string]proton.Label{})
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10))).Return(nil)
tj.job.begin()
defer tj.job.end()
childJob := tj.job.newChildJob("f", 10)
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
msgIDs, expected := buildDownloadStageData(&tj, 56, true)
go func() {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: msgIDs,
})
out, err := output.Consume(ctx)
require.NoError(t, err)
require.Equal(t, expected, out.batch)
out.onFinished(ctx)
cancel()
cachedMessages, cachedAttachments := tj.job.downloadCache.Count()
require.Zero(t, cachedMessages)
require.Zero(t, cachedAttachments)
}
func TestDownloadStage_CancelledJobIsDiscarded(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[DownloadRequest]()
output := NewChannelConsumerProducer[BuildRequest]()
ctx, cancel := context.WithCancel(context.Background())
jobCtx, jobCancel := context.WithCancel(context.Background())
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.job.begin()
defer tj.job.end()
childJob := tj.job.newChildJob("f", 10)
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
go func() {
stage.run(ctx)
}()
jobCancel()
input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: nil,
})
go func() { cancel() }()
_, err := output.Consume(context.Background())
require.ErrorIs(t, err, ErrNoMoreInput)
}
func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[DownloadRequest]()
output := NewChannelConsumerProducer[BuildRequest]()
ctx, cancel := context.WithCancel(context.Background())
jobCtx, jobCancel := context.WithCancel(context.Background())
defer jobCancel()
expectedErr := errors.New("fail")
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{}, expectedErr)
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
go func() {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: []string{"foo"},
})
err := tj.job.wait(ctx)
require.Equal(t, expectedErr, err)
cancel()
_, err = output.Consume(context.Background())
require.ErrorIs(t, err, ErrNoMoreInput)
}
func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
mockCtrl := gomock.NewController(t)
input := NewChannelConsumerProducer[DownloadRequest]()
output := NewChannelConsumerProducer[BuildRequest]()
ctx, cancel := context.WithCancel(context.Background())
jobCtx, jobCancel := context.WithCancel(context.Background())
defer jobCancel()
expectedErr := errors.New("fail")
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: "msg",
NumAttachments: 1,
},
Header: "",
ParsedHeaders: nil,
Body: "",
MIMEType: "",
Attachments: []proton.Attachment{{
ID: "attach",
}},
}, nil)
tj.client.EXPECT().GetAttachment(gomock.Any(), gomock.Eq("attach")).Return(nil, expectedErr)
tj.job.begin()
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
go func() {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: []string{"foo"},
})
err := tj.job.wait(ctx)
require.Equal(t, expectedErr, err)
cancel()
_, err = output.Consume(context.Background())
require.ErrorIs(t, err, ErrNoMoreInput)
}
func buildDownloadStageData(tj *tjob, numMessages int, with422 bool) ([]string, []proton.FullMessage) {
result := make([]proton.FullMessage, numMessages)
msgIDs := make([]string, numMessages)
for i := 0; i < numMessages; i++ {
msgID := fmt.Sprintf("msg-%v", i)
msgIDs[i] = msgID
result[i] = proton.FullMessage{
Message: proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: msgID,
Size: len([]byte(msgID)),
},
Header: "",
ParsedHeaders: nil,
Body: msgID,
MIMEType: "",
Attachments: nil,
},
AttData: nil,
}
buildDownloadStageAttachments(&result[i], i)
}
for i, m := range result {
if with422 && i%2 == 0 {
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Eq(m.ID)).Return(proton.Message{}, &proton.APIError{Status: 422})
continue
}
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Eq(m.ID)).Return(m.Message, nil)
for idx, a := range m.Attachments {
tj.client.EXPECT().GetAttachment(gomock.Any(), gomock.Eq(a.ID)).Return(m.AttData[idx], nil)
}
}
if with422 {
result422 := make([]proton.FullMessage, 0, numMessages/2)
for i := 0; i < numMessages; i++ {
if i%2 == 0 {
continue
}
result422 = append(result422, result[i])
}
return msgIDs, result422
}
return msgIDs, result
}
func buildDownloadStageAttachments(msg *proton.FullMessage, index int) {
mod := index % 4
if mod == 0 {
return
}
genDownloadStageAttachmentInfo(msg, index, mod)
}
func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) {
msg.Attachments = make([]proton.Attachment, count)
msg.AttData = make([][]byte, count)
msg.NumAttachments = count
for i := 0; i < count; i++ {
data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i)
msg.Attachments[i] = proton.Attachment{
ID: data,
Size: int64(len([]byte(data))),
}
msg.AttData[i] = []byte(data)
msg.Size += len([]byte(data))
}
}
func autoScaleCoolDown() network.CoolDownProvider {
return &network.NoCoolDown{}
}
func buildDownloadScaleData(count int) []string {
r := make([]string, count)
for i := 0; i < count; i++ {
r[i] = fmt.Sprintf("m%v", i)
}
return r
}
func newDownloadScaleMessage(id string) proton.Message {
return proton.Message{
MessageMetadata: proton.MessageMetadata{ID: id},
}
}
func autoDownloadScaleClientDoAndReturn(_ context.Context, id string) (proton.Message, error) {
return newDownloadScaleMessage(id), nil
}
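The three helpers above appear to back the download auto-scaling tests, which fall outside this hunk. A minimal sketch of how they might be wired together, assuming the tj fixture from newTestJob and gomock's DoAndReturn (the message count of 64 is arbitrary):

	ids := buildDownloadScaleData(64)
	tj.client.EXPECT().
		GetMessage(gomock.Any(), gomock.Any()).
		DoAndReturn(autoDownloadScaleClientDoAndReturn).
		Times(len(ids))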


@ -0,0 +1,218 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
)
type MetadataStageOutput = StageOutputProducer[DownloadRequest]
type MetadataStageInput = StageInputConsumer[*Job]
// MetadataStage throttles the sync pipeline by admitting at most `MetadataMaxMessages` messages, or as many
// messages as fit within the maximum allowed memory budget, at a time. It is also responsible for interleaving
// different sync jobs so that all jobs can progress and finish.
type MetadataStage struct {
output MetadataStageOutput
input MetadataStageInput
maxDownloadMem uint64
jobs []*metadataIterator
log *logrus.Entry
}
func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage {
return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")}
}
const MetadataPageSize = 150
const MetadataMaxMessages = 250
func (m MetadataStage) Run(group *async.Group) {
group.Once(func(ctx context.Context) {
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
})
}
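// An illustrative sketch (not part of this file) of how a service might drive
// this stage; async.NewGroup's exact signature is assumed from gluon's async
// package:
//
//	group := async.NewGroup(ctx, &async.NoopPanicHandler{})
//	stage := NewMetadataStage(input, output, maxDownloadMem)
//	stage.Run(group)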
func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
defer m.output.Close()
for {
if ctx.Err() != nil {
return
}
// Check whether a new job has been submitted.
job, ok, err := m.input.TryConsume(ctx)
if err != nil {
m.log.WithError(err).Error("Error trying to retrieve more work")
return
}
if ok {
job.begin()
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
if err != nil {
job.onError(err)
continue
}
m.jobs = append(m.jobs, state)
}
// Iterate over all jobs and produce work.
for i := 0; i < len(m.jobs); {
job := m.jobs[i]
// If the job's context has been cancelled, remove it from the list.
if job.stage.ctx.Err() != nil {
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
job.stage.end()
continue
}
// Check for more work.
output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
if err != nil {
job.stage.onError(err)
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
continue
}
// If there is actually more work, push it down the pipeline.
if len(output.ids) != 0 {
m.output.Produce(ctx, output)
}
// If this job has no more work left, signal completion.
if !hasMore {
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
job.stage.end()
continue
}
i++
}
}
}
type metadataIterator struct {
stage *Job
client *network.ProtonClientRetryWrapper[APIClient]
lastMessageID string
remaining []proton.MessageMetadata
downloadReqIDs []string
expectedSize uint64
}
func newMetadataIterator(ctx context.Context, stage *Job, metadataPageSize int, coolDown network.CoolDownProvider) (*metadataIterator, error) {
syncStatus, err := stage.state.GetSyncStatus(ctx)
if err != nil {
return nil, err
}
return &metadataIterator{
stage: stage,
client: network.NewClientRetryWrapper(stage.client, coolDown),
lastMessageID: syncStatus.LastSyncedMessageID,
remaining: nil,
downloadReqIDs: make([]string, 0, metadataPageSize),
}, nil
}
func (m *metadataIterator) Next(maxDownloadMem uint64, metadataPageSize int, maxMessages int) (DownloadRequest, bool, error) {
for {
if m.stage.ctx.Err() != nil {
return DownloadRequest{}, false, m.stage.ctx.Err()
}
if len(m.remaining) == 0 {
metadata, err := network.RetryWithClient(m.stage.ctx, m.client, func(ctx context.Context, c APIClient) ([]proton.MessageMetadata, error) {
// To get the metadata of the messages in batches we need to initialize the state with a call to
// GetMessageMetadataPage with MessageFilter{Desc: true}.
if m.lastMessageID == "" {
return c.GetMessageMetadataPage(ctx, 0, metadataPageSize, proton.MessageFilter{
Desc: true,
})
}
// Afterward we perform the same query but set EndID to the last message of the previous batch.
// Care must be taken here, as the message matching EndID appears again as the first metadata
// result and has to be skipped.
meta, err := c.GetMessageMetadataPage(ctx, 0, metadataPageSize, proton.MessageFilter{
EndID: m.lastMessageID,
Desc: true,
})
if err != nil {
return nil, err
}
// To break the loop we need to check that either:
// * There are no messages returned
if len(meta) == 0 {
return meta, err
}
// * There is only one message returned and it matches the EndID query
if meta[0].ID == m.lastMessageID {
return meta[1:], nil
}
return meta, nil
})
if err != nil {
m.stage.log.WithError(err).Errorf("Failed to download message metadata with lastMessageID=%v", m.lastMessageID)
return DownloadRequest{}, false, err
}
m.remaining = append(m.remaining, metadata...)
// Update the last message ID
if len(m.remaining) != 0 {
m.lastMessageID = m.remaining[len(m.remaining)-1].ID
}
}
if len(m.remaining) == 0 {
if len(m.downloadReqIDs) != 0 {
return DownloadRequest{childJob: m.stage.newChildJob(m.downloadReqIDs[len(m.downloadReqIDs)-1], int64(len(m.downloadReqIDs))), ids: m.downloadReqIDs}, false, nil
}
return DownloadRequest{}, false, nil
}
for idx, meta := range m.remaining {
nextSize := m.expectedSize + uint64(meta.Size)
if nextSize >= maxDownloadMem || len(m.downloadReqIDs) >= maxMessages {
m.expectedSize = 0
m.remaining = m.remaining[idx:]
downloadReqIDs := m.downloadReqIDs
m.downloadReqIDs = make([]string, 0, metadataPageSize)
return DownloadRequest{childJob: m.stage.newChildJob(downloadReqIDs[len(downloadReqIDs)-1], int64(len(downloadReqIDs))), ids: downloadReqIDs}, true, nil
}
m.downloadReqIDs = append(m.downloadReqIDs, meta.ID)
m.expectedSize = nextSize
}
m.remaining = nil
}
}
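The EndID pagination rule implemented in Next above is subtle enough to warrant a standalone illustration. A minimal sketch, where the hypothetical fetch callback stands in for GetMessageMetadataPage and process for the downstream handling:

	func paginateDesc(
		ctx context.Context,
		fetch func(ctx context.Context, endID string) ([]proton.MessageMetadata, error),
		process func([]proton.MessageMetadata),
	) error {
		lastID := ""
		for {
			page, err := fetch(ctx, lastID)
			if err != nil {
				return err
			}
			// The message matching EndID is returned again as the first element; skip it.
			if lastID != "" && len(page) > 0 && page[0].ID == lastID {
				page = page[1:]
			}
			if len(page) == 0 {
				return nil // Nothing new: all metadata has been paged.
			}
			lastID = page[len(page)-1].ID
			process(page)
		}
	}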


@ -0,0 +1,463 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"fmt"
"io"
"testing"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/bradenaw/juniper/xslices"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
const TestMetadataPageSize = 5
const TestMaxDownloadMem = Kilobyte
const TestMaxMessages = 10
func TestMetadataStage_RunFinishesWith429(t *testing.T) {
mockCtrl := gomock.NewController(t)
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
LastSyncedMessageID: "",
}, nil)
input := NewChannelConsumerProducer[*Job]()
output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
numMessages := 50
messageSize := 100
msgs := setupMetadataSuccessRunWith429(&tj, numMessages, messageSize)
go func() {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
}()
input.Produce(ctx, tj.job)
for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) {
req, err := output.Consume(ctx)
require.NoError(t, err)
require.Equal(t, req.ids, xslices.Map(chunk, func(m proton.MessageMetadata) string {
return m.ID
}))
}
cancel()
}
func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
mockCtrl := gomock.NewController(t)
jobCtx, jobCancel := context.WithCancel(context.Background())
tj := newTestFixedMetadataJob(jobCtx, mockCtrl, "u", getTestLabels())
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
LastSyncedMessageID: "",
}, nil)
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).AnyTimes()
input := NewChannelConsumerProducer[*Job]()
output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
go func() {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
}()
input.Produce(ctx, tj.job)
// Read one output, then cancel.
request, err := output.Consume(ctx)
require.NoError(t, err)
request.onFinished(ctx)
// cancel job context
jobCancel()
// The downstream stages would normally check whether the job has been cancelled; here we have to do it manually.
go func() {
for {
req, err := output.Consume(ctx)
if err != nil {
return
}
req.checkCancelled()
}
}()
err = tj.job.wait(context.Background())
require.ErrorIs(t, err, context.Canceled)
cancel()
}
func TestMetadataStage_RunInterleaved(t *testing.T) {
mockCtrl := gomock.NewController(t)
tj1 := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
tj1.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
tj1.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tj1.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).AnyTimes()
tj2 := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
tj2.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
tj2.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tj2.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).AnyTimes()
input := NewChannelConsumerProducer[*Job]()
output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
numMessages := 50
messageSize := 100
setupMetadataSuccessRunWith429(&tj1, numMessages, messageSize)
setupMetadataSuccessRunWith429(&tj2, numMessages, messageSize)
go func() {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
}()
go func() {
input.Produce(ctx, tj1.job)
input.Produce(ctx, tj2.job)
}()
go func() {
for {
req, err := output.Consume(ctx)
if err != nil {
require.ErrorIs(t, err, context.Canceled)
return
}
req.onFinished(ctx)
}
}()
require.NoError(t, tj1.job.wait(ctx))
require.NoError(t, tj2.job.wait(ctx))
cancel()
}
func TestMetadataIterator_ExitNoMoreMetadata(t *testing.T) {
mockCtrl := gomock.NewController(t)
ctx := context.Background()
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
LastSyncedMessageID: "foo",
}, nil)
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(proton.MessageFilter{Desc: true, EndID: "foo"})).Return(nil, nil)
iter, err := newMetadataIterator(ctx, tj.job, TestMetadataPageSize, &network.NoCoolDown{})
require.NoError(t, err)
j, hasMore, err := iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
require.NoError(t, err)
require.False(t, hasMore)
require.Empty(t, j.ids)
}
func TestMetadataIterator_ExitLastCallAlwaysReturnLastMessageID(t *testing.T) {
mockCtrl := gomock.NewController(t)
ctx := context.Background()
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
LastSyncedMessageID: "foo",
}, nil)
tj.client.EXPECT().GetMessageMetadataPage(
gomock.Any(),
gomock.Eq(0),
gomock.Eq(TestMetadataPageSize),
gomock.Eq(proton.MessageFilter{Desc: true, EndID: "foo"}),
).Return([]proton.MessageMetadata{{
ID: "foo",
Size: 100,
}}, nil)
iter, err := newMetadataIterator(ctx, tj.job, TestMetadataPageSize, &network.NoCoolDown{})
require.NoError(t, err)
j, hasMore, err := iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
require.NoError(t, err)
require.False(t, hasMore)
require.Empty(t, j.ids)
}
func TestMetadataIterator_ExitWithRemainingReturnsNoMore(t *testing.T) {
mockCtrl := gomock.NewController(t)
ctx := context.Background()
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
const MetadataPageSize = 2
tj.client.EXPECT().GetMessageMetadataPage(
gomock.Any(),
gomock.Eq(0),
gomock.Eq(MetadataPageSize),
gomock.Eq(proton.MessageFilter{
Desc: true,
}),
).Return([]proton.MessageMetadata{
{
ID: "foo",
Size: 100,
},
{
ID: "bar",
Size: 100,
},
}, nil)
tj.client.EXPECT().GetMessageMetadataPage(
gomock.Any(),
gomock.Eq(0),
gomock.Eq(MetadataPageSize),
gomock.Eq(proton.MessageFilter{
Desc: true,
EndID: "bar",
}),
).Return([]proton.MessageMetadata{
{
ID: "bar",
Size: 100,
},
}, nil)
iter, err := newMetadataIterator(ctx, tj.job, MetadataPageSize, &network.NoCoolDown{})
require.NoError(t, err)
j, hasMore, err := iter.Next(TestMaxDownloadMem, MetadataPageSize, 3)
require.NoError(t, err)
require.False(t, hasMore)
require.Equal(t, []string{"foo", "bar"}, j.ids)
}
func TestMetadataIterator_RespectsSizeLimit(t *testing.T) {
mockCtrl := gomock.NewController(t)
ctx := context.Background()
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
// First call.
tj.client.EXPECT().GetMessageMetadataPage(
gomock.Any(),
gomock.Eq(0),
gomock.Eq(TestMetadataPageSize),
gomock.Eq(proton.MessageFilter{Desc: true}),
).Return([]proton.MessageMetadata{
{
ID: testMsgID(0),
Size: 256,
},
{
ID: testMsgID(1),
Size: 512,
},
{
ID: testMsgID(2),
Size: 128,
},
{
ID: testMsgID(3),
Size: 256,
},
}, nil)
// Second call.
tj.client.EXPECT().GetMessageMetadataPage(
gomock.Any(),
gomock.Eq(0),
gomock.Eq(TestMetadataPageSize),
gomock.Eq(proton.MessageFilter{Desc: true, EndID: testMsgID(3)}),
).Return([]proton.MessageMetadata{
{
ID: testMsgID(3),
Size: 256,
},
}, nil)
iter, err := newMetadataIterator(ctx, tj.job, TestMetadataPageSize, &network.NoCoolDown{})
require.NoError(t, err)
j, hasMore, err := iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
require.NoError(t, err)
require.True(t, hasMore)
require.Equal(t, []string{testMsgID(0), testMsgID(1), testMsgID(2)}, j.ids)
j, hasMore, err = iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
require.NoError(t, err)
require.False(t, hasMore)
require.Equal(t, []string{testMsgID(3)}, j.ids)
}
func testMsgID(i int) string {
return fmt.Sprintf("msg-id-%v", i)
}
func setupMetadataSuccessRunWith429(tj *tjob, msgCount int, msgSize int) []proton.MessageMetadata {
msgs := make([]proton.MessageMetadata, msgCount)
for i := 0; i < msgCount; i++ {
msgs[i].ID = testMsgID(i)
msgs[i].Size = msgSize
}
// Set up the paged API calls: pages advance by TestMetadataPageSize-1 because the EndID message is
// returned again, and each page is first rejected with a retryable error before succeeding.
for i := 0; i < msgCount; i += TestMetadataPageSize - 1 {
filter := proton.MessageFilter{
Desc: true,
}
if i != 0 {
filter.EndID = msgs[i].ID
}
if i+TestMetadataPageSize > msgCount {
call := tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
nil, &proton.APIError{Status: 503},
)
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
msgs[i:], nil,
).After(call)
} else {
call := tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
nil, &proton.APIError{Status: 429},
)
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
msgs[i:i+TestMetadataPageSize], nil,
).After(call)
}
}
// Final call with the last metadata ID returns only that message again, which ends the iteration.
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(proton.MessageFilter{Desc: true, EndID: msgs[msgCount-1].ID})).Return(
msgs[msgCount-1:], nil,
)
return msgs
}
func newTestFixedMetadataJob(
ctx context.Context,
mockCtrl *gomock.Controller,
userID string,
labels LabelMap,
) tjob {
messageBuilder := NewMockMessageBuilder(mockCtrl)
updateApplier := NewMockUpdateApplier(mockCtrl)
syncReporter := NewMockReporter(mockCtrl)
state := NewMockStateProvider(mockCtrl)
client := newFixedMetadataClient(50)
job := NewJob(
ctx,
client,
userID,
labels,
messageBuilder,
updateApplier,
syncReporter,
state,
&async.NoopPanicHandler{},
newDownloadCache(),
logrus.WithField("s", "test"),
)
return tjob{
job: job,
client: nil, // The fixed client below is not a gomock mock, so no mock handle is stored.
messageBuilder: messageBuilder,
updateApplier: updateApplier,
syncReporter: syncReporter,
state: state,
}
}
type fixedMetadataClient struct {
msg []proton.MessageMetadata
offset int
}
func newFixedMetadataClient(msgCount int) APIClient {
msgs := make([]proton.MessageMetadata, msgCount)
for i := 0; i < msgCount; i++ {
msgs[i].ID = testMsgID(i)
msgs[i].Size = 100
}
return &fixedMetadataClient{msg: msgs}
}
func (c *fixedMetadataClient) GetGroupedMessageCount(_ context.Context) ([]proton.MessageGroupCount, error) {
panic("should not be called")
}
func (c *fixedMetadataClient) GetLabels(_ context.Context, _ ...proton.LabelType) ([]proton.Label, error) {
panic("should not be called")
}
func (c *fixedMetadataClient) GetMessage(_ context.Context, _ string) (proton.Message, error) {
panic("should not be called")
}
func (c *fixedMetadataClient) GetMessageMetadataPage(_ context.Context, _, pageSize int, _ proton.MessageFilter) ([]proton.MessageMetadata, error) {
// Clamp the page to the remaining messages so repeated calls cannot slice out of range.
end := c.offset + pageSize
if end > len(c.msg) {
end = len(c.msg)
}
result := c.msg[c.offset:end]
c.offset = end
return result, nil
}
func (c *fixedMetadataClient) GetMessageIDs(_ context.Context, _ string) ([]string, error) {
panic("should not be called")
}
func (c *fixedMetadataClient) GetFullMessage(_ context.Context, _ string, _ proton.Scheduler, _ proton.AttachmentAllocator) (proton.FullMessage, error) {
panic("should not be called")
}
func (c *fixedMetadataClient) GetAttachmentInto(_ context.Context, _ string, _ io.ReaderFrom) error {
panic("should not be called")
}
func (c *fixedMetadataClient) GetAttachment(_ context.Context, _ string) ([]byte, error) {
panic("should not be called")
}


@ -0,0 +1,85 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"errors"
)
type StageOutputProducer[T any] interface {
Produce(ctx context.Context, value T)
Close()
}
var ErrNoMoreInput = errors.New("no more input")
type StageInputConsumer[T any] interface {
Consume(ctx context.Context) (T, error)
TryConsume(ctx context.Context) (T, bool, error)
}
type ChannelConsumerProducer[T any] struct {
ch chan T
}
func NewChannelConsumerProducer[T any]() *ChannelConsumerProducer[T] {
return &ChannelConsumerProducer[T]{ch: make(chan T)}
}
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) {
select {
case <-ctx.Done():
case c.ch <- value:
}
}
func (c ChannelConsumerProducer[T]) Close() {
close(c.ch)
}
func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) {
select {
case <-ctx.Done():
var t T
return t, ctx.Err()
case t, ok := <-c.ch:
if !ok {
return t, ErrNoMoreInput
}
return t, nil
}
}
func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) {
select {
case <-ctx.Done():
var t T
return t, false, ctx.Err()
case t, ok := <-c.ch:
if !ok {
return t, false, ErrNoMoreInput
}
return t, true, nil
default:
var t T
return t, false, nil
}
}
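A minimal usage sketch of the queue primitive above (the integer payload and goroutine split are illustrative only), showing how Close surfaces to the consumer as ErrNoMoreInput:

	func exampleQueue(ctx context.Context) {
		q := NewChannelConsumerProducer[int]()
		go func() {
			defer q.Close()
			for i := 0; i < 3; i++ {
				q.Produce(ctx, i)
			}
		}()
		for {
			v, err := q.Consume(ctx)
			if errors.Is(err, ErrNoMoreInput) {
				return // The producer closed the queue.
			}
			if err != nil {
				return // The context was cancelled.
			}
			_ = v // Handle the value.
		}
	}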


@ -0,0 +1,31 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"time"
)
// sleepCtx sleeps for the given duration, or until the context is canceled.
func sleepCtx(ctx context.Context, d time.Duration) {
select {
case <-ctx.Done():
case <-time.After(d):
}
}
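A short sketch of how sleepCtx might pair with a backoff loop; retryWithBackoff, attempt, and the wait durations are hypothetical stand-ins, not part of this patch:

	func retryWithBackoff(ctx context.Context, attempt func(context.Context) error) error {
		var err error
		for _, d := range []time.Duration{time.Second, 2 * time.Second, 4 * time.Second} {
			if err = attempt(ctx); err == nil {
				return nil
			}
			sleepCtx(ctx, d) // Wakes early if ctx is cancelled.
			if ctx.Err() != nil {
				return ctx.Err()
			}
		}
		return err
	}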