forked from Silverfish/proton-bridge
feat(GODT-2829): New Sync Service
Implementation of the new sync service that interleaves syncing jobs for all active users. It also includes improvements to the message downloader. The download will now auto rate limit the parallel workers based on the server responses. Additionally each of the stages is now tested in isolation to ensure the behavior matches the expectations. Finally, this patch does not replace the existing IMAP sync. A follow up patch is necessary to integrate the IMAP bits into the interfaces required by these changes.
This commit is contained in:
5
Makefile
5
Makefile
@ -299,6 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres
|
||||
> internal/events/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \
|
||||
> internal/services/useridentity/mocks/mocks.go
|
||||
mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \
|
||||
ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\
|
||||
StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \
|
||||
> tmp
|
||||
mv tmp internal/services/sync/mocks_test.go
|
||||
|
||||
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report
|
||||
|
||||
|
||||
1
go.sum
1
go.sum
@ -266,6 +266,7 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
|
||||
123
internal/network/proton.go
Normal file
123
internal/network/proton.go
Normal file
@ -0,0 +1,123 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
)
|
||||
|
||||
type CoolDownProvider interface {
|
||||
GetNextWaitTime() time.Duration
|
||||
Reset()
|
||||
}
|
||||
|
||||
func jitter(max int) time.Duration {
|
||||
return time.Duration(rand.Intn(max)) * time.Second //nolint:gosec
|
||||
}
|
||||
|
||||
type ExpCoolDown struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (c *ExpCoolDown) GetNextWaitTime() time.Duration {
|
||||
waitTimes := []time.Duration{
|
||||
20 * time.Second,
|
||||
40 * time.Second,
|
||||
80 * time.Second,
|
||||
160 * time.Second,
|
||||
300 * time.Second,
|
||||
600 * time.Second,
|
||||
}
|
||||
|
||||
last := len(waitTimes) - 1
|
||||
|
||||
if c.count >= last {
|
||||
return waitTimes[last] + jitter(10)
|
||||
}
|
||||
|
||||
c.count++
|
||||
|
||||
return waitTimes[c.count-1] + jitter(10)
|
||||
}
|
||||
|
||||
func (c *ExpCoolDown) Reset() {
|
||||
c.count = 0
|
||||
}
|
||||
|
||||
type NoCoolDown struct{}
|
||||
|
||||
func (c *NoCoolDown) GetNextWaitTime() time.Duration { return time.Millisecond }
|
||||
func (c *NoCoolDown) Reset() {}
|
||||
|
||||
func Is429Or5XXError(err error) bool {
|
||||
var apiErr *proton.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
return apiErr.Status == 429 || apiErr.Status >= 500
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type ProtonClientRetryWrapper[T any] struct {
|
||||
client T
|
||||
coolDown CoolDownProvider
|
||||
encountered429or5xx bool
|
||||
}
|
||||
|
||||
func NewClientRetryWrapper[T any](client T, coolDown CoolDownProvider) *ProtonClientRetryWrapper[T] {
|
||||
return &ProtonClientRetryWrapper[T]{client: client, coolDown: coolDown}
|
||||
}
|
||||
|
||||
func (p *ProtonClientRetryWrapper[T]) DidEncounter429or5xx() bool {
|
||||
return p.encountered429or5xx
|
||||
}
|
||||
|
||||
func (p *ProtonClientRetryWrapper[T]) Retry(ctx context.Context, f func(context.Context, T) error) error {
|
||||
p.coolDown.Reset()
|
||||
p.encountered429or5xx = false
|
||||
for {
|
||||
err := f(ctx, p.client)
|
||||
if Is429Or5XXError(err) {
|
||||
p.encountered429or5xx = true
|
||||
coolDown := p.coolDown.GetNextWaitTime()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(coolDown):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func RetryWithClient[T any, R any](ctx context.Context, p *ProtonClientRetryWrapper[T], f func(context.Context, T) (R, error)) (R, error) {
|
||||
var result R
|
||||
err := p.Retry(ctx, func(ctx context.Context, t T) error {
|
||||
r, err := f(ctx, t)
|
||||
result = r
|
||||
return err
|
||||
})
|
||||
|
||||
return result, err
|
||||
}
|
||||
36
internal/services/syncservice/api_client.go
Normal file
36
internal/services/syncservice/api_client.go
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
)
|
||||
|
||||
type APIClient interface {
|
||||
GetGroupedMessageCount(ctx context.Context) ([]proton.MessageGroupCount, error)
|
||||
GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error)
|
||||
GetMessage(ctx context.Context, messageID string) (proton.Message, error)
|
||||
GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error)
|
||||
GetMessageIDs(ctx context.Context, afterID string) ([]string, error)
|
||||
GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error)
|
||||
GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
|
||||
GetAttachment(ctx context.Context, attachmentID string) ([]byte, error)
|
||||
}
|
||||
115
internal/services/syncservice/download_cache.go
Normal file
115
internal/services/syncservice/download_cache.go
Normal file
@ -0,0 +1,115 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
)
|
||||
|
||||
type DownloadCache struct {
|
||||
messageLock sync.RWMutex
|
||||
messages map[string]proton.Message
|
||||
attachmentLock sync.RWMutex
|
||||
attachments map[string][]byte
|
||||
}
|
||||
|
||||
func newDownloadCache() *DownloadCache {
|
||||
return &DownloadCache{
|
||||
messages: make(map[string]proton.Message, 64),
|
||||
attachments: make(map[string][]byte, 64),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DownloadCache) StoreMessage(message proton.Message) {
|
||||
s.messageLock.Lock()
|
||||
defer s.messageLock.Unlock()
|
||||
|
||||
s.messages[message.ID] = message
|
||||
}
|
||||
|
||||
func (s *DownloadCache) StoreAttachment(id string, data []byte) {
|
||||
s.attachmentLock.Lock()
|
||||
defer s.attachmentLock.Unlock()
|
||||
|
||||
s.attachments[id] = data
|
||||
}
|
||||
|
||||
func (s *DownloadCache) DeleteMessages(id ...string) {
|
||||
s.messageLock.Lock()
|
||||
defer s.messageLock.Unlock()
|
||||
|
||||
for _, id := range id {
|
||||
delete(s.messages, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DownloadCache) DeleteAttachments(id ...string) {
|
||||
s.attachmentLock.Lock()
|
||||
defer s.attachmentLock.Unlock()
|
||||
|
||||
for _, id := range id {
|
||||
delete(s.attachments, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DownloadCache) GetMessage(id string) (proton.Message, bool) {
|
||||
s.messageLock.RLock()
|
||||
defer s.messageLock.RUnlock()
|
||||
|
||||
v, ok := s.messages[id]
|
||||
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (s *DownloadCache) GetAttachment(id string) ([]byte, bool) {
|
||||
s.attachmentLock.RLock()
|
||||
defer s.attachmentLock.RUnlock()
|
||||
|
||||
v, ok := s.attachments[id]
|
||||
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (s *DownloadCache) Clear() {
|
||||
s.messageLock.Lock()
|
||||
s.messages = make(map[string]proton.Message, 64)
|
||||
s.messageLock.Unlock()
|
||||
|
||||
s.attachmentLock.Lock()
|
||||
s.attachments = make(map[string][]byte, 64)
|
||||
s.attachmentLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *DownloadCache) Count() (int, int) {
|
||||
var (
|
||||
messageCount int
|
||||
attachmentCount int
|
||||
)
|
||||
|
||||
s.messageLock.Lock()
|
||||
messageCount = len(s.messages)
|
||||
s.messageLock.Unlock()
|
||||
|
||||
s.attachmentLock.Lock()
|
||||
attachmentCount = len(s.attachments)
|
||||
s.attachmentLock.Unlock()
|
||||
|
||||
return messageCount, attachmentCount
|
||||
}
|
||||
218
internal/services/syncservice/handler.go
Normal file
218
internal/services/syncservice/handler.go
Normal file
@ -0,0 +1,218 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/network"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const DefaultRetryCoolDown = 20 * time.Second
|
||||
|
||||
type LabelMap = map[string]proton.Label
|
||||
|
||||
// Handler is the interface from which we control the syncing of the IMAP data. One instance should be created for each
|
||||
// user and used for every subsequent sync request.
|
||||
type Handler struct {
|
||||
regulator Regulator
|
||||
client APIClient
|
||||
userID string
|
||||
syncState StateProvider
|
||||
log *logrus.Entry
|
||||
group *async.Group
|
||||
syncFinishedCh chan error
|
||||
panicHandler async.PanicHandler
|
||||
downloadCache *DownloadCache
|
||||
}
|
||||
|
||||
func NewHandler(
|
||||
regulator Regulator,
|
||||
client APIClient,
|
||||
userID string,
|
||||
state StateProvider,
|
||||
log *logrus.Entry,
|
||||
panicHandler async.PanicHandler,
|
||||
) *Handler {
|
||||
return &Handler{
|
||||
client: client,
|
||||
userID: userID,
|
||||
syncState: state,
|
||||
log: log,
|
||||
syncFinishedCh: make(chan error),
|
||||
group: async.NewGroup(context.Background(), panicHandler),
|
||||
regulator: regulator,
|
||||
panicHandler: panicHandler,
|
||||
downloadCache: newDownloadCache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Handler) Close() {
|
||||
t.group.CancelAndWait()
|
||||
close(t.syncFinishedCh)
|
||||
}
|
||||
|
||||
func (t *Handler) CancelAndWait() {
|
||||
t.group.CancelAndWait()
|
||||
}
|
||||
|
||||
func (t *Handler) Cancel() {
|
||||
t.group.Cancel()
|
||||
}
|
||||
|
||||
func (t *Handler) OnSyncFinishedCH() <-chan error {
|
||||
return t.syncFinishedCh
|
||||
}
|
||||
|
||||
func (t *Handler) Execute(
|
||||
syncReporter Reporter,
|
||||
labels LabelMap,
|
||||
updateApplier UpdateApplier,
|
||||
messageBuilder MessageBuilder,
|
||||
coolDown time.Duration,
|
||||
) {
|
||||
t.log.Info("Sync triggered")
|
||||
t.group.Once(func(ctx context.Context) {
|
||||
start := time.Now()
|
||||
t.log.WithField("start", start).Info("Beginning user sync")
|
||||
|
||||
syncReporter.OnStart(ctx)
|
||||
var err error
|
||||
for {
|
||||
if err = ctx.Err(); err != nil {
|
||||
t.log.WithError(err).Error("Sync aborted")
|
||||
break
|
||||
} else if err = t.run(ctx, syncReporter, labels, updateApplier, messageBuilder); err != nil {
|
||||
t.log.WithError(err).Error("Failed to sync, will retry later")
|
||||
sleepCtx(ctx, coolDown)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
syncReporter.OnError(ctx, err)
|
||||
} else {
|
||||
syncReporter.OnFinished(ctx)
|
||||
}
|
||||
|
||||
t.log.WithField("duration", time.Since(start)).Info("Finished user sync")
|
||||
t.syncFinishedCh <- err
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Handler) run(ctx context.Context,
|
||||
syncReporter Reporter,
|
||||
labels LabelMap,
|
||||
updateApplier UpdateApplier,
|
||||
messageBuilder MessageBuilder,
|
||||
) error {
|
||||
syncStatus, err := t.syncState.GetSyncStatus(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get sync status: %w", err)
|
||||
}
|
||||
|
||||
if syncStatus.IsComplete() {
|
||||
t.log.Info("Sync already complete, only system labels will be updated")
|
||||
|
||||
if err := updateApplier.SyncSystemLabelsOnly(ctx, labels); err != nil {
|
||||
t.log.WithError(err).Error("Failed to sync system labels")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if !syncStatus.HasLabels {
|
||||
t.log.Info("Syncing labels")
|
||||
if err := updateApplier.SyncLabels(ctx, labels); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := t.syncState.SetHasLabels(ctx, true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
|
||||
t.log.Info("Synced labels")
|
||||
}
|
||||
|
||||
if !syncStatus.HasMessageCount {
|
||||
wrapper := network.NewClientRetryWrapper(t.client, &network.ExpCoolDown{})
|
||||
|
||||
messageCounts, err := network.RetryWithClient(ctx, wrapper, func(ctx context.Context, c APIClient) ([]proton.MessageGroupCount, error) {
|
||||
return c.GetGroupedMessageCount(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve message ids: %w", err)
|
||||
}
|
||||
|
||||
var totalMessageCount int64
|
||||
|
||||
for _, gc := range messageCounts {
|
||||
if gc.LabelID == proton.AllMailLabel {
|
||||
totalMessageCount = int64(gc.Total)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil {
|
||||
return fmt.Errorf("failed to store message count: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !syncStatus.HasMessages {
|
||||
t.log.Info("Syncing messages")
|
||||
|
||||
stageContext := NewJob(
|
||||
ctx,
|
||||
t.client,
|
||||
t.userID,
|
||||
labels,
|
||||
messageBuilder,
|
||||
updateApplier,
|
||||
syncReporter,
|
||||
t.syncState,
|
||||
t.panicHandler,
|
||||
t.downloadCache,
|
||||
t.log,
|
||||
)
|
||||
|
||||
t.regulator.Sync(ctx, stageContext)
|
||||
|
||||
// Wait on reply
|
||||
if err := stageContext.wait(ctx); err != nil {
|
||||
return fmt.Errorf("failed sync messages: %w", err)
|
||||
}
|
||||
|
||||
if err := t.syncState.SetHasMessages(ctx, true); err != nil {
|
||||
return fmt.Errorf("failed to set sync as completed: %w", err)
|
||||
}
|
||||
|
||||
t.log.Info("Synced messages")
|
||||
} else {
|
||||
t.log.Info("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
352
internal/services/syncservice/handler_test.go
Normal file
352
internal/services/syncservice/handler_test.go
Normal file
@ -0,0 +1,352 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/bradenaw/juniper/xmaps"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTask_NoStateAndSucceeds(t *testing.T) {
|
||||
const MessageTotal int64 = 50
|
||||
const MessageID string = "foo"
|
||||
const MessageDelta int64 = 10
|
||||
|
||||
labels := getTestLabels()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tt := newTestHandler(mockCtrl, "u")
|
||||
|
||||
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
|
||||
|
||||
{
|
||||
call1 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
HasLabels: false,
|
||||
HasMessages: false,
|
||||
HasMessageCount: false,
|
||||
FailedMessages: xmaps.SetFromSlice([]string{}),
|
||||
LastSyncedMessageID: "",
|
||||
NumSyncedMessages: 0,
|
||||
TotalMessageCount: 0,
|
||||
}, nil
|
||||
})
|
||||
call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
|
||||
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
|
||||
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
|
||||
call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
|
||||
tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
HasLabels: true,
|
||||
HasMessages: true,
|
||||
HasMessageCount: true,
|
||||
FailedMessages: xmaps.SetFromSlice([]string{}),
|
||||
LastSyncedMessageID: MessageID,
|
||||
NumSyncedMessages: MessageDelta,
|
||||
TotalMessageCount: MessageTotal,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
{
|
||||
call1 := tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(1).Return(nil)
|
||||
tt.updateApplier.EXPECT().SyncSystemLabelsOnly(gomock.Any(), gomock.Eq(labels)).After(call1).Times(1).Return(nil)
|
||||
}
|
||||
|
||||
{
|
||||
tt.client.EXPECT().GetGroupedMessageCount(gomock.Any()).Return([]proton.MessageGroupCount{
|
||||
{
|
||||
LabelID: proton.AllMailLabel,
|
||||
Total: int(MessageTotal),
|
||||
Unread: 0,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
|
||||
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
|
||||
|
||||
// First run.
|
||||
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second Run, it's completed sync labels only.
|
||||
err = tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTask_StateHasLabels(t *testing.T) {
|
||||
const MessageTotal int64 = 50
|
||||
const MessageID string = "foo"
|
||||
const MessageDelta int64 = 10
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
tt := newTestHandler(mockCtrl, "u")
|
||||
|
||||
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
|
||||
|
||||
{
|
||||
call2 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
HasLabels: true,
|
||||
HasMessages: false,
|
||||
HasMessageCount: false,
|
||||
FailedMessages: xmaps.SetFromSlice([]string{}),
|
||||
LastSyncedMessageID: "",
|
||||
NumSyncedMessages: 0,
|
||||
TotalMessageCount: 0,
|
||||
}, nil
|
||||
})
|
||||
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
|
||||
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
|
||||
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
|
||||
}
|
||||
|
||||
{
|
||||
tt.client.EXPECT().GetGroupedMessageCount(gomock.Any()).Return([]proton.MessageGroupCount{
|
||||
{
|
||||
LabelID: proton.AllMailLabel,
|
||||
Total: int(MessageTotal),
|
||||
Unread: 0,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
|
||||
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
|
||||
|
||||
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTask_StateHasLabelsAndMessageCount(t *testing.T) {
|
||||
const MessageTotal int64 = 50
|
||||
const MessageID string = "foo"
|
||||
const MessageDelta int64 = 10
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tt := newTestHandler(mockCtrl, "u")
|
||||
|
||||
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
|
||||
|
||||
{
|
||||
call3 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
HasLabels: true,
|
||||
HasMessages: false,
|
||||
HasMessageCount: true,
|
||||
FailedMessages: xmaps.SetFromSlice([]string{}),
|
||||
LastSyncedMessageID: "",
|
||||
NumSyncedMessages: 0,
|
||||
TotalMessageCount: MessageTotal,
|
||||
}, nil
|
||||
})
|
||||
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
|
||||
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
|
||||
}
|
||||
|
||||
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
|
||||
|
||||
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTask_StateHasSyncedState(t *testing.T) {
|
||||
const MessageTotal int64 = 50
|
||||
const MessageID string = "foo"
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tt := newTestHandler(mockCtrl, "u")
|
||||
|
||||
tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
HasLabels: true,
|
||||
HasMessages: true,
|
||||
HasMessageCount: true,
|
||||
FailedMessages: xmaps.SetFromSlice([]string{}),
|
||||
LastSyncedMessageID: MessageID,
|
||||
NumSyncedMessages: MessageTotal,
|
||||
TotalMessageCount: MessageTotal,
|
||||
}, nil
|
||||
})
|
||||
|
||||
tt.updateApplier.EXPECT().SyncSystemLabelsOnly(gomock.Any(), gomock.Eq(labels)).Return(nil)
|
||||
|
||||
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTask_RepeatsOnSyncFailure(t *testing.T) {
|
||||
const MessageTotal int64 = 50
|
||||
const MessageID string = "foo"
|
||||
const MessageDelta int64 = 10
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tt := newTestHandler(mockCtrl, "u")
|
||||
|
||||
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
|
||||
|
||||
{
|
||||
call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
HasLabels: false,
|
||||
HasMessages: false,
|
||||
HasMessageCount: false,
|
||||
FailedMessages: xmaps.SetFromSlice([]string{}),
|
||||
LastSyncedMessageID: "",
|
||||
NumSyncedMessages: 0,
|
||||
TotalMessageCount: 0,
|
||||
}, nil
|
||||
})
|
||||
call1 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
HasLabels: false,
|
||||
HasMessages: false,
|
||||
HasMessageCount: false,
|
||||
FailedMessages: xmaps.SetFromSlice([]string{}),
|
||||
LastSyncedMessageID: "",
|
||||
NumSyncedMessages: 0,
|
||||
TotalMessageCount: 0,
|
||||
}, nil
|
||||
}).After(call0)
|
||||
call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
|
||||
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
|
||||
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
|
||||
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
|
||||
}
|
||||
|
||||
{
|
||||
call0 := tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(1).Return(fmt.Errorf("failed"))
|
||||
tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(1).Return(nil).After(call0)
|
||||
}
|
||||
|
||||
{
|
||||
tt.client.EXPECT().GetGroupedMessageCount(gomock.Any()).Return([]proton.MessageGroupCount{
|
||||
{
|
||||
LabelID: proton.AllMailLabel,
|
||||
Total: int(MessageTotal),
|
||||
Unread: 0,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
|
||||
tt.syncReporter.EXPECT().OnStart(gomock.Any())
|
||||
tt.syncReporter.EXPECT().OnFinished(gomock.Any())
|
||||
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
|
||||
|
||||
tt.task.Execute(tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder, time.Microsecond)
|
||||
require.NoError(t, <-tt.task.OnSyncFinishedCH())
|
||||
}
|
||||
|
||||
func getTestLabels() map[string]proton.Label {
|
||||
return map[string]proton.Label{
|
||||
proton.AllMailLabel: {
|
||||
ID: proton.AllMailLabel,
|
||||
Name: "All Mail",
|
||||
Type: proton.LabelTypeSystem,
|
||||
},
|
||||
proton.InboxLabel: {
|
||||
ID: proton.InboxLabel,
|
||||
Name: "Inbox",
|
||||
Type: proton.LabelTypeSystem,
|
||||
},
|
||||
proton.DraftsLabel: {
|
||||
ID: proton.DraftsLabel,
|
||||
Name: "Drafts",
|
||||
Type: proton.LabelTypeSystem,
|
||||
},
|
||||
proton.TrashLabel: {
|
||||
ID: proton.DraftsLabel,
|
||||
Name: "Drafts",
|
||||
Type: proton.LabelTypeSystem,
|
||||
},
|
||||
"label1": {
|
||||
ID: "label1",
|
||||
Name: "label1",
|
||||
Type: proton.LabelTypeLabel,
|
||||
},
|
||||
"folder1": {
|
||||
ID: "folder1",
|
||||
Name: "folder1",
|
||||
Type: proton.LabelTypeFolder,
|
||||
},
|
||||
"folder2": {
|
||||
ID: "folder2",
|
||||
Name: "folder2",
|
||||
ParentID: "folder1",
|
||||
Type: proton.LabelTypeFolder,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type thandler struct {
|
||||
task *Handler
|
||||
regulator *MockRegulator
|
||||
syncState *MockStateProvider
|
||||
updateApplier *MockUpdateApplier
|
||||
messageBuilder *MockMessageBuilder
|
||||
client *MockAPIClient
|
||||
syncReporter *MockReporter
|
||||
}
|
||||
|
||||
func (t thandler) addMessageSyncCompletedExpectation(messageID string, delta int64) { //nolint:unparam
|
||||
t.regulator.EXPECT().Sync(gomock.Any(), gomock.Any()).Do(func(_ context.Context, job *Job) {
|
||||
job.begin()
|
||||
j := job.newChildJob(messageID, delta)
|
||||
j.onFinished(context.Background())
|
||||
job.end()
|
||||
})
|
||||
}
|
||||
|
||||
func newTestHandler(mockCtrl *gomock.Controller, userID string) thandler { // nolint:unparam
|
||||
regulator := NewMockRegulator(mockCtrl)
|
||||
syncState := NewMockStateProvider(mockCtrl)
|
||||
updateApplier := NewMockUpdateApplier(mockCtrl)
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
messageBuilder := NewMockMessageBuilder(mockCtrl)
|
||||
syncReporter := NewMockReporter(mockCtrl)
|
||||
task := NewHandler(regulator, client, userID, syncState, logrus.WithField("test", "test"), &async.NoopPanicHandler{})
|
||||
|
||||
return thandler{
|
||||
task: task,
|
||||
regulator: regulator,
|
||||
syncState: syncState,
|
||||
updateApplier: updateApplier,
|
||||
messageBuilder: messageBuilder,
|
||||
syncReporter: syncReporter,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
88
internal/services/syncservice/interfaces.go
Normal file
88
internal/services/syncservice/interfaces.go
Normal file
@ -0,0 +1,88 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/bradenaw/juniper/xmaps"
|
||||
)
|
||||
|
||||
type StateProvider interface {
|
||||
AddFailedMessageID(context.Context, string) error
|
||||
RemFailedMessageID(context.Context, string) error
|
||||
GetSyncStatus(context.Context) (Status, error)
|
||||
ClearSyncStatus(context.Context) error
|
||||
SetHasLabels(context.Context, bool) error
|
||||
SetHasMessages(context.Context, bool) error
|
||||
SetLastMessageID(context.Context, string, int64) error
|
||||
SetMessageCount(context.Context, int64) error
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
HasLabels bool
|
||||
HasMessages bool
|
||||
HasMessageCount bool
|
||||
FailedMessages xmaps.Set[string]
|
||||
LastSyncedMessageID string
|
||||
NumSyncedMessages int64
|
||||
TotalMessageCount int64
|
||||
}
|
||||
|
||||
func DefaultStatus() Status {
|
||||
return Status{
|
||||
FailedMessages: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s Status) IsComplete() bool {
|
||||
return s.HasLabels && s.HasMessages
|
||||
}
|
||||
|
||||
// Regulator is an abstraction for the sync service, since it regulates the number of concurrent sync activities.
|
||||
type Regulator interface {
|
||||
Sync(ctx context.Context, stage *Job)
|
||||
}
|
||||
|
||||
type BuildResult struct {
|
||||
AddressID string
|
||||
MessageID string
|
||||
Update *imap.MessageCreated
|
||||
}
|
||||
|
||||
type MessageBuilder interface {
|
||||
WithKeys(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error
|
||||
BuildMessage(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing, buffer *bytes.Buffer) (BuildResult, error)
|
||||
}
|
||||
|
||||
type UpdateApplier interface {
|
||||
ApplySyncUpdates(ctx context.Context, updates []BuildResult) error
|
||||
SyncSystemLabelsOnly(ctx context.Context, labels map[string]proton.Label) error
|
||||
SyncLabels(ctx context.Context, labels map[string]proton.Label) error
|
||||
}
|
||||
|
||||
type Reporter interface {
|
||||
OnStart(ctx context.Context)
|
||||
OnFinished(ctx context.Context)
|
||||
OnError(ctx context.Context, err error)
|
||||
OnProgress(ctx context.Context, delta int64)
|
||||
}
|
||||
201
internal/services/syncservice/job.go
Normal file
201
internal/services/syncservice/job.go
Normal file
@ -0,0 +1,201 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Job represents a unit of work that will travel down the sync pipeline. The job will be split up into child jobs
|
||||
// for each batch. The parent job (this) will then wait until all the children have finished executing. Execution can
|
||||
// terminate by either:
|
||||
// * Completing the pipeline successfully
|
||||
// * Context Cancellation
|
||||
// * Errors
|
||||
// On error, or context cancellation all child jobs are cancelled.
|
||||
type Job struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
|
||||
client APIClient
|
||||
state StateProvider
|
||||
|
||||
userID string
|
||||
labels LabelMap
|
||||
messageBuilder MessageBuilder
|
||||
updateApplier UpdateApplier
|
||||
syncReporter Reporter
|
||||
|
||||
log *logrus.Entry
|
||||
errorCh *async.QueuedChannel[error]
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
downloadCache *DownloadCache
|
||||
}
|
||||
|
||||
func NewJob(ctx context.Context,
|
||||
client APIClient,
|
||||
userID string,
|
||||
labels LabelMap,
|
||||
messageBuilder MessageBuilder,
|
||||
updateApplier UpdateApplier,
|
||||
syncReporter Reporter,
|
||||
state StateProvider,
|
||||
panicHandler async.PanicHandler,
|
||||
cache *DownloadCache,
|
||||
log *logrus.Entry,
|
||||
) *Job {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &Job{
|
||||
ctx: ctx,
|
||||
client: client,
|
||||
userID: userID,
|
||||
cancel: cancel,
|
||||
state: state,
|
||||
log: log,
|
||||
labels: labels,
|
||||
messageBuilder: messageBuilder,
|
||||
updateApplier: updateApplier,
|
||||
syncReporter: syncReporter,
|
||||
errorCh: async.NewQueuedChannel[error](4, 8, panicHandler, fmt.Sprintf("sync-job-error-%v", userID)),
|
||||
panicHandler: panicHandler,
|
||||
downloadCache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (j *Job) Close() {
|
||||
j.errorCh.CloseAndDiscardQueued()
|
||||
j.wg.Wait()
|
||||
}
|
||||
|
||||
func (j *Job) onError(err error) {
|
||||
defer j.wg.Done()
|
||||
|
||||
// context cancelled is caught & handled in a different location.
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
|
||||
j.errorCh.Enqueue(err)
|
||||
j.cancel()
|
||||
}
|
||||
|
||||
func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int64) {
|
||||
defer j.wg.Done()
|
||||
|
||||
if err := j.state.SetLastMessageID(ctx, lastMessageID, count); err != nil {
|
||||
j.log.WithError(err).Error("Failed to store last synced message id")
|
||||
j.onError(err)
|
||||
return
|
||||
}
|
||||
j.syncReporter.OnProgress(ctx, count)
|
||||
}
|
||||
|
||||
// begin is expected to be called once the job enters the pipeline.
|
||||
func (j *Job) begin() {
|
||||
j.log.Info("Job started")
|
||||
j.wg.Add(1)
|
||||
j.startChildWaiter()
|
||||
}
|
||||
|
||||
// end is expected to be called once the job has no further work left.
|
||||
func (j *Job) end() {
|
||||
j.log.Info("Job finished")
|
||||
j.wg.Done()
|
||||
}
|
||||
|
||||
// wait waits until the job has finished, the context got cancelled or an error occurred.
|
||||
func (j *Job) wait(ctx context.Context) error {
|
||||
defer j.wg.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
j.cancel()
|
||||
return ctx.Err()
|
||||
case err := <-j.errorCh.GetChannel():
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (j *Job) newChildJob(messageID string, messageCount int64) childJob {
|
||||
j.log.Infof("Creating new child job")
|
||||
j.wg.Add(1)
|
||||
return childJob{job: j, lastMessageID: messageID, messageCount: messageCount}
|
||||
}
|
||||
|
||||
func (j *Job) startChildWaiter() {
|
||||
j.once.Do(func() {
|
||||
go func() {
|
||||
defer async.HandlePanic(j.panicHandler)
|
||||
|
||||
j.wg.Wait()
|
||||
j.log.Info("All child jobs succeeded")
|
||||
j.errorCh.Enqueue(j.ctx.Err())
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// childJob represents a batch of work that goes down the pipeline. It keeps track of the message ID that is in the
|
||||
// batch and the number of messages in the batch.
|
||||
type childJob struct {
|
||||
job *Job
|
||||
lastMessageID string
|
||||
messageCount int64
|
||||
cachedMessageIDs []string
|
||||
cachedAttachmentIDs []string
|
||||
}
|
||||
|
||||
func (s *childJob) onError(err error) {
|
||||
s.job.log.WithError(err).Info("Child job ran into error")
|
||||
s.job.onError(err)
|
||||
}
|
||||
|
||||
func (s *childJob) userID() string {
|
||||
return s.job.userID
|
||||
}
|
||||
|
||||
func (s *childJob) onFinished(ctx context.Context) {
|
||||
s.job.log.Infof("Child job finished")
|
||||
s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount)
|
||||
s.job.downloadCache.DeleteMessages(s.cachedMessageIDs...)
|
||||
s.job.downloadCache.DeleteAttachments(s.cachedAttachmentIDs...)
|
||||
}
|
||||
|
||||
func (s *childJob) checkCancelled() bool {
|
||||
err := s.job.ctx.Err()
|
||||
if err != nil {
|
||||
s.job.log.Infof("Child job exit due to context cancelled")
|
||||
s.job.wg.Done()
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *childJob) getContext() context.Context {
|
||||
return s.job.ctx
|
||||
}
|
||||
237
internal/services/syncservice/job_test.go
Normal file
237
internal/services/syncservice/job_test.go
Normal file
@ -0,0 +1,237 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func setupGoLeak() goleak.Option {
|
||||
logrus.Trace("prepare for go leak")
|
||||
return goleak.IgnoreCurrent()
|
||||
}
|
||||
|
||||
func TestJob_WaitsOnChildren(t *testing.T) {
|
||||
options := setupGoLeak()
|
||||
defer goleak.VerifyNone(t, options)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
|
||||
|
||||
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("1"), gomock.Eq(int64(0))).Return(nil)
|
||||
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("2"), gomock.Eq(int64(1))).Return(nil)
|
||||
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).Times(2)
|
||||
|
||||
go func() {
|
||||
tj.job.begin()
|
||||
job1 := tj.job.newChildJob("1", 0)
|
||||
job2 := tj.job.newChildJob("2", 1)
|
||||
|
||||
job1.onFinished(context.Background())
|
||||
job2.onFinished(context.Background())
|
||||
tj.job.end()
|
||||
}()
|
||||
|
||||
require.NoError(t, tj.job.wait(context.Background()))
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_WaitsOnAllChildrenOnError(t *testing.T) {
|
||||
options := setupGoLeak()
|
||||
defer goleak.VerifyNone(t, options)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
|
||||
|
||||
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("1"), gomock.Eq(int64(0))).Return(nil)
|
||||
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
|
||||
|
||||
jobErr := errors.New("failed")
|
||||
|
||||
go func() {
|
||||
job1 := tj.job.newChildJob("1", 0)
|
||||
job2 := tj.job.newChildJob("2", 1)
|
||||
|
||||
job1.onFinished(context.Background())
|
||||
job2.onError(jobErr)
|
||||
}()
|
||||
|
||||
err := tj.job.wait(context.Background())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, jobErr)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_MultipleChildrenReportError(t *testing.T) {
|
||||
options := setupGoLeak()
|
||||
defer goleak.VerifyNone(t, options)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
|
||||
|
||||
jobErr := errors.New("failed")
|
||||
|
||||
startCh := make(chan struct{})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
job := tj.job.newChildJob("1", 0)
|
||||
<-startCh
|
||||
|
||||
job.onError(jobErr)
|
||||
}()
|
||||
}
|
||||
|
||||
close(startCh)
|
||||
err := tj.job.wait(context.Background())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, jobErr)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) {
|
||||
options := setupGoLeak()
|
||||
defer goleak.VerifyNone(t, options)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
|
||||
|
||||
jobErr := errors.New("failed")
|
||||
|
||||
failJob := tj.job.newChildJob("0", 1)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
job := tj.job.newChildJob("1", 0)
|
||||
<-job.getContext().Done()
|
||||
require.ErrorIs(t, job.getContext().Err(), context.Canceled)
|
||||
require.True(t, job.checkCancelled())
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
failJob.onError(jobErr)
|
||||
}()
|
||||
|
||||
err := tj.job.wait(context.Background())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, jobErr)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
|
||||
options := setupGoLeak()
|
||||
defer goleak.VerifyNone(t, options)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
job := tj.job.newChildJob("1", 0)
|
||||
<-job.getContext().Done()
|
||||
require.ErrorIs(t, job.getContext().Err(), context.Canceled)
|
||||
require.True(t, job.checkCancelled())
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_WithoutChildJobsCanBeTerminated(t *testing.T) {
|
||||
options := setupGoLeak()
|
||||
defer goleak.VerifyNone(t, options)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
go func() {
|
||||
tj.job.begin()
|
||||
tj.job.end()
|
||||
}()
|
||||
err := tj.job.wait(ctx)
|
||||
require.NoError(t, err)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
type tjob struct {
|
||||
job *Job
|
||||
client *MockAPIClient
|
||||
messageBuilder *MockMessageBuilder
|
||||
updateApplier *MockUpdateApplier
|
||||
syncReporter *MockReporter
|
||||
state *MockStateProvider
|
||||
}
|
||||
|
||||
func newTestJob(
|
||||
ctx context.Context,
|
||||
mockCtrl *gomock.Controller,
|
||||
userID string,
|
||||
labels LabelMap,
|
||||
) tjob {
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
messageBuilder := NewMockMessageBuilder(mockCtrl)
|
||||
updateApplier := NewMockUpdateApplier(mockCtrl)
|
||||
syncReporter := NewMockReporter(mockCtrl)
|
||||
state := NewMockStateProvider(mockCtrl)
|
||||
|
||||
job := NewJob(
|
||||
ctx,
|
||||
client,
|
||||
userID,
|
||||
labels,
|
||||
messageBuilder,
|
||||
updateApplier,
|
||||
syncReporter,
|
||||
state,
|
||||
&async.NoopPanicHandler{},
|
||||
newDownloadCache(),
|
||||
logrus.WithField("s", "test"),
|
||||
)
|
||||
|
||||
return tjob{
|
||||
job: job,
|
||||
client: client,
|
||||
messageBuilder: messageBuilder,
|
||||
updateApplier: updateApplier,
|
||||
syncReporter: syncReporter,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
121
internal/services/syncservice/limits.go
Normal file
121
internal/services/syncservice/limits.go
Normal file
@ -0,0 +1,121 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/pbnjay/memory"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const Kilobyte = uint64(1024)
|
||||
const Megabyte = 1024 * Kilobyte
|
||||
const Gigabyte = 1024 * Megabyte
|
||||
|
||||
func toMB(v uint64) float64 {
|
||||
return float64(v) / float64(Megabyte)
|
||||
}
|
||||
|
||||
type syncLimits struct {
|
||||
MaxDownloadRequestMem uint64
|
||||
MinDownloadRequestMem uint64
|
||||
MaxMessageBuildingMem uint64
|
||||
MinMessageBuildingMem uint64
|
||||
MaxSyncMemory uint64
|
||||
MaxParallelDownloads int
|
||||
DownloadRequestMem uint64
|
||||
MessageBuildMem uint64
|
||||
}
|
||||
|
||||
func newSyncLimits(maxSyncMemory uint64) syncLimits {
|
||||
limits := syncLimits{
|
||||
// There's no point in using more than 128MB of download data per stage, after that we reach a point of diminishing
|
||||
// returns as we can't keep the pipeline fed fast enough.
|
||||
MaxDownloadRequestMem: 128 * Megabyte,
|
||||
|
||||
// Any lower than this and we may fail to download messages.
|
||||
MinDownloadRequestMem: 40 * Megabyte,
|
||||
|
||||
// This value can be increased to your hearts content. The more system memory the user has, the more messages
|
||||
// we can build in parallel.
|
||||
MaxMessageBuildingMem: 128 * Megabyte,
|
||||
MinMessageBuildingMem: 64 * Megabyte,
|
||||
|
||||
// Maximum recommend value for parallel downloads by the API team.
|
||||
MaxParallelDownloads: 20,
|
||||
|
||||
MaxSyncMemory: maxSyncMemory,
|
||||
}
|
||||
|
||||
if _, ok := os.LookupEnv("BRIDGE_SYNC_FORCE_MINIMUM_SPEC"); ok {
|
||||
logrus.Warn("Sync specs forced to minimum")
|
||||
limits.MaxDownloadRequestMem = 50 * Megabyte
|
||||
limits.MaxMessageBuildingMem = 80 * Megabyte
|
||||
limits.MaxParallelDownloads = 2
|
||||
limits.MaxSyncMemory = 800 * Megabyte
|
||||
}
|
||||
|
||||
// Expected mem usage for this whole process should be the sum of MaxMessageBuildingMem and MaxDownloadRequestMem
|
||||
// times x due to pipeline and all additional memory used by network requests and compression+io.
|
||||
|
||||
totalMemory := memory.TotalMemory()
|
||||
|
||||
if limits.MaxSyncMemory >= totalMemory/2 {
|
||||
logrus.Warnf("Requested max sync memory of %v MB is greater than half of system memory (%v MB), forcing to half of system memory",
|
||||
toMB(limits.MaxSyncMemory), toMB(totalMemory/2))
|
||||
limits.MaxSyncMemory = totalMemory / 2
|
||||
}
|
||||
|
||||
if limits.MaxSyncMemory < 800*Megabyte {
|
||||
logrus.Warnf("Requested max sync memory of %v MB, but minimum recommended is 800 MB, forcing max syncMemory to 800MB", toMB(limits.MaxSyncMemory))
|
||||
limits.MaxSyncMemory = 800 * Megabyte
|
||||
}
|
||||
|
||||
logrus.Debugf("Total System Memory: %v", toMB(totalMemory))
|
||||
|
||||
// If less than 2GB available try and limit max memory to 512 MB
|
||||
switch {
|
||||
case limits.MaxSyncMemory < 2*Gigabyte:
|
||||
if limits.MaxSyncMemory < 800*Megabyte {
|
||||
logrus.Warnf("System has less than 800MB of memory, you may experience issues sycing large mailboxes")
|
||||
}
|
||||
limits.DownloadRequestMem = limits.MinDownloadRequestMem
|
||||
limits.MessageBuildMem = limits.MinMessageBuildingMem
|
||||
case limits.MaxSyncMemory == 2*Gigabyte:
|
||||
// Increasing the max download capacity has very little effect on sync speed. We could increase the download
|
||||
// memory but the user would see less sync notifications. A smaller value here leads to more frequent
|
||||
// updates. Additionally, most of sync time is spent in the message building.
|
||||
limits.DownloadRequestMem = limits.MaxDownloadRequestMem
|
||||
// Currently limited so that if a user has multiple accounts active it also doesn't cause excessive memory usage.
|
||||
limits.MessageBuildMem = limits.MaxMessageBuildingMem
|
||||
default:
|
||||
// Divide by 8 as download stage and build stage will use aprox. 4x the specified memory.
|
||||
remainingMemory := (limits.MaxSyncMemory - 2*Gigabyte) / 8
|
||||
limits.DownloadRequestMem = limits.MaxDownloadRequestMem + remainingMemory
|
||||
limits.MessageBuildMem = limits.MaxMessageBuildingMem + remainingMemory
|
||||
}
|
||||
|
||||
logrus.Debugf("Max memory usage for sync Download=%vMB Building=%vMB Predicted Max Total=%vMB",
|
||||
toMB(limits.DownloadRequestMem),
|
||||
toMB(limits.MessageBuildMem),
|
||||
toMB((limits.MessageBuildMem*4)+(limits.DownloadRequestMem*4)),
|
||||
)
|
||||
|
||||
return limits
|
||||
}
|
||||
916
internal/services/syncservice/mocks_test.go
Normal file
916
internal/services/syncservice/mocks_test.go
Normal file
@ -0,0 +1,916 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
|
||||
|
||||
// Package sync is a generated GoMock package.
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
bytes "bytes"
|
||||
context "context"
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
|
||||
proton "github.com/ProtonMail/go-proton-api"
|
||||
crypto "github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockApplyStageInput is a mock of ApplyStageInput interface.
|
||||
type MockApplyStageInput struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockApplyStageInputMockRecorder
|
||||
}
|
||||
|
||||
// MockApplyStageInputMockRecorder is the mock recorder for MockApplyStageInput.
|
||||
type MockApplyStageInputMockRecorder struct {
|
||||
mock *MockApplyStageInput
|
||||
}
|
||||
|
||||
// NewMockApplyStageInput creates a new mock instance.
|
||||
func NewMockApplyStageInput(ctrl *gomock.Controller) *MockApplyStageInput {
|
||||
mock := &MockApplyStageInput{ctrl: ctrl}
|
||||
mock.recorder = &MockApplyStageInputMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockApplyStageInput) EXPECT() *MockApplyStageInputMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Consume mocks base method.
|
||||
func (m *MockApplyStageInput) Consume(arg0 context.Context) (ApplyRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Consume", arg0)
|
||||
ret0, _ := ret[0].(ApplyRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Consume indicates an expected call of Consume.
|
||||
func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(ApplyRequest)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockBuildStageInput is a mock of BuildStageInput interface.
|
||||
type MockBuildStageInput struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockBuildStageInputMockRecorder
|
||||
}
|
||||
|
||||
// MockBuildStageInputMockRecorder is the mock recorder for MockBuildStageInput.
|
||||
type MockBuildStageInputMockRecorder struct {
|
||||
mock *MockBuildStageInput
|
||||
}
|
||||
|
||||
// NewMockBuildStageInput creates a new mock instance.
|
||||
func NewMockBuildStageInput(ctrl *gomock.Controller) *MockBuildStageInput {
|
||||
mock := &MockBuildStageInput{ctrl: ctrl}
|
||||
mock.recorder = &MockBuildStageInputMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockBuildStageInput) EXPECT() *MockBuildStageInputMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Consume mocks base method.
|
||||
func (m *MockBuildStageInput) Consume(arg0 context.Context) (BuildRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Consume", arg0)
|
||||
ret0, _ := ret[0].(BuildRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Consume indicates an expected call of Consume.
|
||||
func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(BuildRequest)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockBuildStageOutput is a mock of BuildStageOutput interface.
|
||||
type MockBuildStageOutput struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockBuildStageOutputMockRecorder
|
||||
}
|
||||
|
||||
// MockBuildStageOutputMockRecorder is the mock recorder for MockBuildStageOutput.
|
||||
type MockBuildStageOutputMockRecorder struct {
|
||||
mock *MockBuildStageOutput
|
||||
}
|
||||
|
||||
// NewMockBuildStageOutput creates a new mock instance.
|
||||
func NewMockBuildStageOutput(ctrl *gomock.Controller) *MockBuildStageOutput {
|
||||
mock := &MockBuildStageOutput{ctrl: ctrl}
|
||||
mock.recorder = &MockBuildStageOutputMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockBuildStageOutput) EXPECT() *MockBuildStageOutputMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockBuildStageOutput) Close() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Close")
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockBuildStageOutputMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockBuildStageOutput)(nil).Close))
|
||||
}
|
||||
|
||||
// Produce mocks base method.
|
||||
func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Produce", arg0, arg1)
|
||||
}
|
||||
|
||||
// Produce indicates an expected call of Produce.
|
||||
func (mr *MockBuildStageOutputMockRecorder) Produce(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Produce", reflect.TypeOf((*MockBuildStageOutput)(nil).Produce), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockDownloadStageInput is a mock of DownloadStageInput interface.
|
||||
type MockDownloadStageInput struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDownloadStageInputMockRecorder
|
||||
}
|
||||
|
||||
// MockDownloadStageInputMockRecorder is the mock recorder for MockDownloadStageInput.
|
||||
type MockDownloadStageInputMockRecorder struct {
|
||||
mock *MockDownloadStageInput
|
||||
}
|
||||
|
||||
// NewMockDownloadStageInput creates a new mock instance.
|
||||
func NewMockDownloadStageInput(ctrl *gomock.Controller) *MockDownloadStageInput {
|
||||
mock := &MockDownloadStageInput{ctrl: ctrl}
|
||||
mock.recorder = &MockDownloadStageInputMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDownloadStageInput) EXPECT() *MockDownloadStageInputMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Consume mocks base method.
|
||||
func (m *MockDownloadStageInput) Consume(arg0 context.Context) (DownloadRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Consume", arg0)
|
||||
ret0, _ := ret[0].(DownloadRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Consume indicates an expected call of Consume.
|
||||
func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(DownloadRequest)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockDownloadStageOutput is a mock of DownloadStageOutput interface.
|
||||
type MockDownloadStageOutput struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDownloadStageOutputMockRecorder
|
||||
}
|
||||
|
||||
// MockDownloadStageOutputMockRecorder is the mock recorder for MockDownloadStageOutput.
|
||||
type MockDownloadStageOutputMockRecorder struct {
|
||||
mock *MockDownloadStageOutput
|
||||
}
|
||||
|
||||
// NewMockDownloadStageOutput creates a new mock instance.
|
||||
func NewMockDownloadStageOutput(ctrl *gomock.Controller) *MockDownloadStageOutput {
|
||||
mock := &MockDownloadStageOutput{ctrl: ctrl}
|
||||
mock.recorder = &MockDownloadStageOutputMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDownloadStageOutput) EXPECT() *MockDownloadStageOutputMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockDownloadStageOutput) Close() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Close")
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockDownloadStageOutputMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDownloadStageOutput)(nil).Close))
|
||||
}
|
||||
|
||||
// Produce mocks base method.
|
||||
func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Produce", arg0, arg1)
|
||||
}
|
||||
|
||||
// Produce indicates an expected call of Produce.
|
||||
func (mr *MockDownloadStageOutputMockRecorder) Produce(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Produce", reflect.TypeOf((*MockDownloadStageOutput)(nil).Produce), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockMetadataStageInput is a mock of MetadataStageInput interface.
|
||||
type MockMetadataStageInput struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMetadataStageInputMockRecorder
|
||||
}
|
||||
|
||||
// MockMetadataStageInputMockRecorder is the mock recorder for MockMetadataStageInput.
|
||||
type MockMetadataStageInputMockRecorder struct {
|
||||
mock *MockMetadataStageInput
|
||||
}
|
||||
|
||||
// NewMockMetadataStageInput creates a new mock instance.
|
||||
func NewMockMetadataStageInput(ctrl *gomock.Controller) *MockMetadataStageInput {
|
||||
mock := &MockMetadataStageInput{ctrl: ctrl}
|
||||
mock.recorder = &MockMetadataStageInputMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMetadataStageInput) EXPECT() *MockMetadataStageInputMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Consume mocks base method.
|
||||
func (m *MockMetadataStageInput) Consume(arg0 context.Context) (*Job, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Consume", arg0)
|
||||
ret0, _ := ret[0].(*Job)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Consume indicates an expected call of Consume.
|
||||
func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(*Job)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockMetadataStageOutput is a mock of MetadataStageOutput interface.
|
||||
type MockMetadataStageOutput struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMetadataStageOutputMockRecorder
|
||||
}
|
||||
|
||||
// MockMetadataStageOutputMockRecorder is the mock recorder for MockMetadataStageOutput.
|
||||
type MockMetadataStageOutputMockRecorder struct {
|
||||
mock *MockMetadataStageOutput
|
||||
}
|
||||
|
||||
// NewMockMetadataStageOutput creates a new mock instance.
|
||||
func NewMockMetadataStageOutput(ctrl *gomock.Controller) *MockMetadataStageOutput {
|
||||
mock := &MockMetadataStageOutput{ctrl: ctrl}
|
||||
mock.recorder = &MockMetadataStageOutputMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMetadataStageOutput) EXPECT() *MockMetadataStageOutputMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockMetadataStageOutput) Close() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Close")
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockMetadataStageOutputMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMetadataStageOutput)(nil).Close))
|
||||
}
|
||||
|
||||
// Produce mocks base method.
|
||||
func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Produce", arg0, arg1)
|
||||
}
|
||||
|
||||
// Produce indicates an expected call of Produce.
|
||||
func (mr *MockMetadataStageOutputMockRecorder) Produce(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Produce", reflect.TypeOf((*MockMetadataStageOutput)(nil).Produce), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockStateProvider is a mock of StateProvider interface.
|
||||
type MockStateProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStateProviderMockRecorder
|
||||
}
|
||||
|
||||
// MockStateProviderMockRecorder is the mock recorder for MockStateProvider.
|
||||
type MockStateProviderMockRecorder struct {
|
||||
mock *MockStateProvider
|
||||
}
|
||||
|
||||
// NewMockStateProvider creates a new mock instance.
|
||||
func NewMockStateProvider(ctrl *gomock.Controller) *MockStateProvider {
|
||||
mock := &MockStateProvider{ctrl: ctrl}
|
||||
mock.recorder = &MockStateProviderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddFailedMessageID mocks base method.
|
||||
func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddFailedMessageID indicates an expected call of AddFailedMessageID.
|
||||
func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1)
|
||||
}
|
||||
|
||||
// ClearSyncStatus mocks base method.
|
||||
func (m *MockStateProvider) ClearSyncStatus(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearSyncStatus", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearSyncStatus indicates an expected call of ClearSyncStatus.
|
||||
func (mr *MockStateProviderMockRecorder) ClearSyncStatus(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearSyncStatus", reflect.TypeOf((*MockStateProvider)(nil).ClearSyncStatus), arg0)
|
||||
}
|
||||
|
||||
// GetSyncStatus mocks base method.
|
||||
func (m *MockStateProvider) GetSyncStatus(arg0 context.Context) (Status, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSyncStatus", arg0)
|
||||
ret0, _ := ret[0].(Status)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSyncStatus indicates an expected call of GetSyncStatus.
|
||||
func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSyncStatus", reflect.TypeOf((*MockStateProvider)(nil).GetSyncStatus), arg0)
|
||||
}
|
||||
|
||||
// RemFailedMessageID mocks base method.
|
||||
func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemFailedMessageID indicates an expected call of RemFailedMessageID.
|
||||
func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1)
|
||||
}
|
||||
|
||||
// SetHasLabels mocks base method.
|
||||
func (m *MockStateProvider) SetHasLabels(arg0 context.Context, arg1 bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetHasLabels", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetHasLabels indicates an expected call of SetHasLabels.
|
||||
func (mr *MockStateProviderMockRecorder) SetHasLabels(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHasLabels", reflect.TypeOf((*MockStateProvider)(nil).SetHasLabels), arg0, arg1)
|
||||
}
|
||||
|
||||
// SetHasMessages mocks base method.
|
||||
func (m *MockStateProvider) SetHasMessages(arg0 context.Context, arg1 bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetHasMessages", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetHasMessages indicates an expected call of SetHasMessages.
|
||||
func (mr *MockStateProviderMockRecorder) SetHasMessages(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHasMessages", reflect.TypeOf((*MockStateProvider)(nil).SetHasMessages), arg0, arg1)
|
||||
}
|
||||
|
||||
// SetLastMessageID mocks base method.
|
||||
func (m *MockStateProvider) SetLastMessageID(arg0 context.Context, arg1 string, arg2 int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetLastMessageID", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetLastMessageID indicates an expected call of SetLastMessageID.
|
||||
func (mr *MockStateProviderMockRecorder) SetLastMessageID(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLastMessageID", reflect.TypeOf((*MockStateProvider)(nil).SetLastMessageID), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetMessageCount mocks base method.
|
||||
func (m *MockStateProvider) SetMessageCount(arg0 context.Context, arg1 int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetMessageCount", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetMessageCount indicates an expected call of SetMessageCount.
|
||||
func (mr *MockStateProviderMockRecorder) SetMessageCount(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMessageCount", reflect.TypeOf((*MockStateProvider)(nil).SetMessageCount), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockRegulator is a mock of Regulator interface.
|
||||
type MockRegulator struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRegulatorMockRecorder
|
||||
}
|
||||
|
||||
// MockRegulatorMockRecorder is the mock recorder for MockRegulator.
|
||||
type MockRegulatorMockRecorder struct {
|
||||
mock *MockRegulator
|
||||
}
|
||||
|
||||
// NewMockRegulator creates a new mock instance.
|
||||
func NewMockRegulator(ctrl *gomock.Controller) *MockRegulator {
|
||||
mock := &MockRegulator{ctrl: ctrl}
|
||||
mock.recorder = &MockRegulatorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRegulator) EXPECT() *MockRegulatorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Sync mocks base method.
|
||||
func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Sync", arg0, arg1)
|
||||
}
|
||||
|
||||
// Sync indicates an expected call of Sync.
|
||||
func (mr *MockRegulatorMockRecorder) Sync(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockRegulator)(nil).Sync), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockUpdateApplier is a mock of UpdateApplier interface.
|
||||
type MockUpdateApplier struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUpdateApplierMockRecorder
|
||||
}
|
||||
|
||||
// MockUpdateApplierMockRecorder is the mock recorder for MockUpdateApplier.
|
||||
type MockUpdateApplierMockRecorder struct {
|
||||
mock *MockUpdateApplier
|
||||
}
|
||||
|
||||
// NewMockUpdateApplier creates a new mock instance.
|
||||
func NewMockUpdateApplier(ctrl *gomock.Controller) *MockUpdateApplier {
|
||||
mock := &MockUpdateApplier{ctrl: ctrl}
|
||||
mock.recorder = &MockUpdateApplierMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockUpdateApplier) EXPECT() *MockUpdateApplierMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ApplySyncUpdates mocks base method.
|
||||
func (m *MockUpdateApplier) ApplySyncUpdates(arg0 context.Context, arg1 []BuildResult) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ApplySyncUpdates", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ApplySyncUpdates indicates an expected call of ApplySyncUpdates.
|
||||
func (mr *MockUpdateApplierMockRecorder) ApplySyncUpdates(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplySyncUpdates", reflect.TypeOf((*MockUpdateApplier)(nil).ApplySyncUpdates), arg0, arg1)
|
||||
}
|
||||
|
||||
// SyncLabels mocks base method.
|
||||
func (m *MockUpdateApplier) SyncLabels(arg0 context.Context, arg1 map[string]proton.Label) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncLabels", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SyncLabels indicates an expected call of SyncLabels.
|
||||
func (mr *MockUpdateApplierMockRecorder) SyncLabels(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncLabels", reflect.TypeOf((*MockUpdateApplier)(nil).SyncLabels), arg0, arg1)
|
||||
}
|
||||
|
||||
// SyncSystemLabelsOnly mocks base method.
|
||||
func (m *MockUpdateApplier) SyncSystemLabelsOnly(arg0 context.Context, arg1 map[string]proton.Label) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncSystemLabelsOnly", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SyncSystemLabelsOnly indicates an expected call of SyncSystemLabelsOnly.
|
||||
func (mr *MockUpdateApplierMockRecorder) SyncSystemLabelsOnly(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncSystemLabelsOnly", reflect.TypeOf((*MockUpdateApplier)(nil).SyncSystemLabelsOnly), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockMessageBuilder is a mock of MessageBuilder interface.
|
||||
type MockMessageBuilder struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMessageBuilderMockRecorder
|
||||
}
|
||||
|
||||
// MockMessageBuilderMockRecorder is the mock recorder for MockMessageBuilder.
|
||||
type MockMessageBuilderMockRecorder struct {
|
||||
mock *MockMessageBuilder
|
||||
}
|
||||
|
||||
// NewMockMessageBuilder creates a new mock instance.
|
||||
func NewMockMessageBuilder(ctrl *gomock.Controller) *MockMessageBuilder {
|
||||
mock := &MockMessageBuilder{ctrl: ctrl}
|
||||
mock.recorder = &MockMessageBuilderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMessageBuilder) EXPECT() *MockMessageBuilderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BuildMessage mocks base method.
|
||||
func (m *MockMessageBuilder) BuildMessage(arg0 map[string]proton.Label, arg1 proton.FullMessage, arg2 *crypto.KeyRing, arg3 *bytes.Buffer) (BuildResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BuildMessage", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(BuildResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BuildMessage indicates an expected call of BuildMessage.
|
||||
func (mr *MockMessageBuilderMockRecorder) BuildMessage(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildMessage", reflect.TypeOf((*MockMessageBuilder)(nil).BuildMessage), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// WithKeys mocks base method.
|
||||
func (m *MockMessageBuilder) WithKeys(arg0 func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "WithKeys", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// WithKeys indicates an expected call of WithKeys.
|
||||
func (mr *MockMessageBuilderMockRecorder) WithKeys(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithKeys", reflect.TypeOf((*MockMessageBuilder)(nil).WithKeys), arg0)
|
||||
}
|
||||
|
||||
// MockAPIClient is a mock of APIClient interface.
|
||||
type MockAPIClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAPIClientMockRecorder
|
||||
}
|
||||
|
||||
// MockAPIClientMockRecorder is the mock recorder for MockAPIClient.
|
||||
type MockAPIClientMockRecorder struct {
|
||||
mock *MockAPIClient
|
||||
}
|
||||
|
||||
// NewMockAPIClient creates a new mock instance.
|
||||
func NewMockAPIClient(ctrl *gomock.Controller) *MockAPIClient {
|
||||
mock := &MockAPIClient{ctrl: ctrl}
|
||||
mock.recorder = &MockAPIClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockAPIClient) EXPECT() *MockAPIClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetAttachment mocks base method.
|
||||
func (m *MockAPIClient) GetAttachment(arg0 context.Context, arg1 string) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAttachment indicates an expected call of GetAttachment.
|
||||
func (mr *MockAPIClientMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockAPIClient)(nil).GetAttachment), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetAttachmentInto mocks base method.
|
||||
func (m *MockAPIClient) GetAttachmentInto(arg0 context.Context, arg1 string, arg2 io.ReaderFrom) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAttachmentInto", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAttachmentInto indicates an expected call of GetAttachmentInto.
|
||||
func (mr *MockAPIClientMockRecorder) GetAttachmentInto(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachmentInto", reflect.TypeOf((*MockAPIClient)(nil).GetAttachmentInto), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetFullMessage mocks base method.
|
||||
func (m *MockAPIClient) GetFullMessage(arg0 context.Context, arg1 string, arg2 proton.Scheduler, arg3 proton.AttachmentAllocator) (proton.FullMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFullMessage", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(proton.FullMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetFullMessage indicates an expected call of GetFullMessage.
|
||||
func (mr *MockAPIClientMockRecorder) GetFullMessage(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFullMessage", reflect.TypeOf((*MockAPIClient)(nil).GetFullMessage), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// GetGroupedMessageCount mocks base method.
|
||||
func (m *MockAPIClient) GetGroupedMessageCount(arg0 context.Context) ([]proton.MessageGroupCount, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupedMessageCount", arg0)
|
||||
ret0, _ := ret[0].([]proton.MessageGroupCount)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupedMessageCount indicates an expected call of GetGroupedMessageCount.
|
||||
func (mr *MockAPIClientMockRecorder) GetGroupedMessageCount(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupedMessageCount", reflect.TypeOf((*MockAPIClient)(nil).GetGroupedMessageCount), arg0)
|
||||
}
|
||||
|
||||
// GetLabels mocks base method.
|
||||
func (m *MockAPIClient) GetLabels(arg0 context.Context, arg1 ...proton.LabelType) ([]proton.Label, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "GetLabels", varargs...)
|
||||
ret0, _ := ret[0].([]proton.Label)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetLabels indicates an expected call of GetLabels.
|
||||
func (mr *MockAPIClientMockRecorder) GetLabels(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLabels", reflect.TypeOf((*MockAPIClient)(nil).GetLabels), varargs...)
|
||||
}
|
||||
|
||||
// GetMessage mocks base method.
|
||||
func (m *MockAPIClient) GetMessage(arg0 context.Context, arg1 string) (proton.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
|
||||
ret0, _ := ret[0].(proton.Message)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMessage indicates an expected call of GetMessage.
|
||||
func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetMessageIDs mocks base method.
|
||||
func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMessageIDs indicates an expected call of GetMessageIDs.
|
||||
func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetMessageMetadataPage mocks base method.
|
||||
func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMessageMetadataPage", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]proton.MessageMetadata)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMessageMetadataPage indicates an expected call of GetMessageMetadataPage.
|
||||
func (mr *MockAPIClientMockRecorder) GetMessageMetadataPage(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageMetadataPage", reflect.TypeOf((*MockAPIClient)(nil).GetMessageMetadataPage), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// MockReporter is a mock of Reporter interface.
|
||||
type MockReporter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockReporterMockRecorder
|
||||
}
|
||||
|
||||
// MockReporterMockRecorder is the mock recorder for MockReporter.
|
||||
type MockReporterMockRecorder struct {
|
||||
mock *MockReporter
|
||||
}
|
||||
|
||||
// NewMockReporter creates a new mock instance.
|
||||
func NewMockReporter(ctrl *gomock.Controller) *MockReporter {
|
||||
mock := &MockReporter{ctrl: ctrl}
|
||||
mock.recorder = &MockReporterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockReporter) EXPECT() *MockReporterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// OnError mocks base method.
|
||||
func (m *MockReporter) OnError(arg0 context.Context, arg1 error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnError", arg0, arg1)
|
||||
}
|
||||
|
||||
// OnError indicates an expected call of OnError.
|
||||
func (mr *MockReporterMockRecorder) OnError(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockReporter)(nil).OnError), arg0, arg1)
|
||||
}
|
||||
|
||||
// OnFinished mocks base method.
|
||||
func (m *MockReporter) OnFinished(arg0 context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnFinished", arg0)
|
||||
}
|
||||
|
||||
// OnFinished indicates an expected call of OnFinished.
|
||||
func (mr *MockReporterMockRecorder) OnFinished(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnFinished", reflect.TypeOf((*MockReporter)(nil).OnFinished), arg0)
|
||||
}
|
||||
|
||||
// OnProgress mocks base method.
|
||||
func (m *MockReporter) OnProgress(arg0 context.Context, arg1 int64) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnProgress", arg0, arg1)
|
||||
}
|
||||
|
||||
// OnProgress indicates an expected call of OnProgress.
|
||||
func (mr *MockReporterMockRecorder) OnProgress(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnProgress", reflect.TypeOf((*MockReporter)(nil).OnProgress), arg0, arg1)
|
||||
}
|
||||
|
||||
// OnStart mocks base method.
|
||||
func (m *MockReporter) OnStart(arg0 context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnStart", arg0)
|
||||
}
|
||||
|
||||
// OnStart indicates an expected call of OnStart.
|
||||
func (mr *MockReporterMockRecorder) OnStart(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnStart", reflect.TypeOf((*MockReporter)(nil).OnStart), arg0)
|
||||
}
|
||||
|
||||
// MockDownloadRateModifier is a mock of DownloadRateModifier interface.
|
||||
type MockDownloadRateModifier struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDownloadRateModifierMockRecorder
|
||||
}
|
||||
|
||||
// MockDownloadRateModifierMockRecorder is the mock recorder for MockDownloadRateModifier.
|
||||
type MockDownloadRateModifierMockRecorder struct {
|
||||
mock *MockDownloadRateModifier
|
||||
}
|
||||
|
||||
// NewMockDownloadRateModifier creates a new mock instance.
|
||||
func NewMockDownloadRateModifier(ctrl *gomock.Controller) *MockDownloadRateModifier {
|
||||
mock := &MockDownloadRateModifier{ctrl: ctrl}
|
||||
mock.recorder = &MockDownloadRateModifierMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDownloadRateModifier) EXPECT() *MockDownloadRateModifierMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Apply mocks base method.
|
||||
func (m *MockDownloadRateModifier) Apply(arg0 bool, arg1, arg2 int) int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Apply", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Apply indicates an expected call of Apply.
|
||||
func (mr *MockDownloadRateModifierMockRecorder) Apply(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockDownloadRateModifier)(nil).Apply), arg0, arg1, arg2)
|
||||
}
|
||||
78
internal/services/syncservice/service.go
Normal file
78
internal/services/syncservice/service.go
Normal file
@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
)
|
||||
|
||||
// Service which mediates IMAP syncing in Bridge.
|
||||
// IMPORTANT: Be sure to cancel all ongoing sync Handlers before cancelling this service's Group.
|
||||
type Service struct {
|
||||
metadataStage *MetadataStage
|
||||
downloadStage *DownloadStage
|
||||
buildStage *BuildStage
|
||||
applyStage *ApplyStage
|
||||
limits syncLimits
|
||||
metaCh *ChannelConsumerProducer[*Job]
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
func NewService(reporter reporter.Reporter,
|
||||
panicHandler async.PanicHandler,
|
||||
) *Service {
|
||||
limits := newSyncLimits(2 * Gigabyte)
|
||||
|
||||
metaCh := NewChannelConsumerProducer[*Job]()
|
||||
downloadCh := NewChannelConsumerProducer[DownloadRequest]()
|
||||
buildCh := NewChannelConsumerProducer[BuildRequest]()
|
||||
applyCh := NewChannelConsumerProducer[ApplyRequest]()
|
||||
|
||||
return &Service{
|
||||
limits: limits,
|
||||
metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem),
|
||||
downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler),
|
||||
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
|
||||
applyStage: NewApplyStage(applyCh),
|
||||
metaCh: metaCh,
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run(group *async.Group) {
|
||||
group.Once(func(ctx context.Context) {
|
||||
syncGroup := async.NewGroup(ctx, s.panicHandler)
|
||||
|
||||
s.metadataStage.Run(syncGroup)
|
||||
s.downloadStage.Run(syncGroup)
|
||||
s.buildStage.Run(syncGroup)
|
||||
s.applyStage.Run(syncGroup)
|
||||
|
||||
defer s.metaCh.Close()
|
||||
defer syncGroup.CancelAndWait()
|
||||
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Sync(ctx context.Context, stage *Job) {
|
||||
s.metaCh.Produce(ctx, stage)
|
||||
}
|
||||
78
internal/services/syncservice/stage_apply.go
Normal file
78
internal/services/syncservice/stage_apply.go
Normal file
@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ApplyRequest struct {
|
||||
childJob
|
||||
messages []BuildResult
|
||||
}
|
||||
|
||||
type ApplyStageInput = StageInputConsumer[ApplyRequest]
|
||||
|
||||
// ApplyStage applies the sync updates and waits for their completion before proceeding with the next batch. This is
|
||||
// the final stage in the sync pipeline.
|
||||
type ApplyStage struct {
|
||||
input ApplyStageInput
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
func NewApplyStage(input ApplyStageInput) *ApplyStage {
|
||||
return &ApplyStage{input: input, log: logrus.WithField("sync-stage", "apply")}
|
||||
}
|
||||
|
||||
func (a *ApplyStage) Run(group *async.Group) {
|
||||
group.Once(a.run)
|
||||
}
|
||||
|
||||
func (a *ApplyStage) run(ctx context.Context) {
|
||||
for {
|
||||
req, err := a.input.Consume(ctx)
|
||||
if err != nil {
|
||||
if !(errors.Is(err, ErrNoMoreInput) || errors.Is(err, context.Canceled)) {
|
||||
a.log.WithError(err).Error("Exiting state with error")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.checkCancelled() {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(req.messages) == 0 {
|
||||
req.onFinished(req.getContext())
|
||||
continue
|
||||
}
|
||||
|
||||
if err := req.job.updateApplier.ApplySyncUpdates(ctx, req.messages); err != nil {
|
||||
a.log.WithError(err).Error("Failed to apply sync updates")
|
||||
req.job.onError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
req.onFinished(req.getContext())
|
||||
}
|
||||
}
|
||||
138
internal/services/syncservice/stage_apply_test.go
Normal file
138
internal/services/syncservice/stage_apply_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[ApplyRequest]()
|
||||
|
||||
stage := NewApplyStage(input)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
jobCancel()
|
||||
input.Produce(ctx, ApplyRequest{
|
||||
childJob: childJob,
|
||||
messages: nil,
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[ApplyRequest]()
|
||||
|
||||
stage := NewApplyStage(input)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
defer jobCancel()
|
||||
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
|
||||
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10)))
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, ApplyRequest{
|
||||
childJob: childJob,
|
||||
messages: nil,
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
cancel()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[ApplyRequest]()
|
||||
|
||||
stage := NewApplyStage(input)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
defer jobCancel()
|
||||
|
||||
buildResults := []BuildResult{
|
||||
{
|
||||
AddressID: "Foo",
|
||||
MessageID: "Bar",
|
||||
Update: &imap.MessageCreated{},
|
||||
},
|
||||
}
|
||||
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
|
||||
applyErr := errors.New("apply failed")
|
||||
tj.updateApplier.EXPECT().ApplySyncUpdates(gomock.Any(), gomock.Eq(buildResults)).Return(applyErr)
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, ApplyRequest{
|
||||
childJob: childJob,
|
||||
messages: buildResults,
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
cancel()
|
||||
require.ErrorIs(t, err, applyErr)
|
||||
}
|
||||
198
internal/services/syncservice/stage_build.go
Normal file
198
internal/services/syncservice/stage_build.go
Normal file
@ -0,0 +1,198 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type BuildRequest struct {
|
||||
childJob
|
||||
batch []proton.FullMessage
|
||||
}
|
||||
|
||||
type BuildStageInput = StageInputConsumer[BuildRequest]
|
||||
type BuildStageOutput = StageOutputProducer[ApplyRequest]
|
||||
|
||||
// BuildStage is in charge of decrypting and converting the downloaded messages from the previous stage into
|
||||
// RFC822 compliant messages which can then be sent to the IMAP server.
|
||||
type BuildStage struct {
|
||||
input BuildStageInput
|
||||
output BuildStageOutput
|
||||
maxBuildMem uint64
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
reporter reporter.Reporter
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
func NewBuildStage(
|
||||
input BuildStageInput,
|
||||
output BuildStageOutput,
|
||||
maxBuildMem uint64,
|
||||
panicHandler async.PanicHandler,
|
||||
reporter reporter.Reporter,
|
||||
) *BuildStage {
|
||||
return &BuildStage{
|
||||
input: input,
|
||||
output: output,
|
||||
maxBuildMem: maxBuildMem,
|
||||
log: logrus.WithField("sync-stage", "build"),
|
||||
panicHandler: panicHandler,
|
||||
reporter: reporter,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BuildStage) Run(group *async.Group) {
|
||||
group.Once(b.run)
|
||||
}
|
||||
|
||||
func (b *BuildStage) run(ctx context.Context) {
|
||||
maxMessagesInParallel := runtime.NumCPU()
|
||||
|
||||
defer b.output.Close()
|
||||
for {
|
||||
req, err := b.input.Consume(ctx)
|
||||
if err != nil {
|
||||
if !(errors.Is(err, ErrNoMoreInput) || errors.Is(err, context.Canceled)) {
|
||||
b.log.WithError(err).Error("Exiting state with error")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.checkCancelled() {
|
||||
continue
|
||||
}
|
||||
|
||||
err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) {
|
||||
defer async.HandlePanic(b.panicHandler)
|
||||
|
||||
kr, ok := addrKRs[msg.AddressID]
|
||||
if !ok {
|
||||
req.job.log.Errorf("Address '%v' on message '%v' does not have an unlocked kerying", msg.AddressID, msg.ID)
|
||||
|
||||
if err := req.job.state.AddFailedMessageID(req.getContext(), msg.ID); err != nil {
|
||||
req.job.log.WithError(err).Error("Failed to add failed message ID")
|
||||
}
|
||||
|
||||
if err := b.reporter.ReportMessageWithContext("Failed to build message - no unlocked keyring (sync)", reporter.Context{
|
||||
"messageID": msg.ID,
|
||||
"userID": req.userID(),
|
||||
}); err != nil {
|
||||
req.job.log.WithError(err).Error("Failed to report message build error")
|
||||
}
|
||||
return BuildResult{}, nil
|
||||
}
|
||||
|
||||
res, err := req.job.messageBuilder.BuildMessage(req.job.labels, msg, kr, new(bytes.Buffer))
|
||||
if err != nil {
|
||||
req.job.log.WithError(err).WithField("msgID", msg.ID).Error("Failed to build message (syn)")
|
||||
|
||||
if err := req.job.state.AddFailedMessageID(req.getContext(), msg.ID); err != nil {
|
||||
req.job.log.WithError(err).Error("Failed to add failed message ID")
|
||||
}
|
||||
|
||||
if err := b.reporter.ReportMessageWithContext("Failed to build message (sync)", reporter.Context{
|
||||
"messageID": msg.ID,
|
||||
"error": err,
|
||||
"userID": req.userID(),
|
||||
}); err != nil {
|
||||
req.job.log.WithError(err).Error("Failed to report message build error")
|
||||
}
|
||||
|
||||
// We could sync a placeholder message here, but for now we skip it entirely.
|
||||
return BuildResult{}, nil
|
||||
}
|
||||
|
||||
if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil {
|
||||
req.job.log.WithError(err).Error("Failed to remove failed message ID")
|
||||
}
|
||||
|
||||
return res, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.output.Produce(ctx, ApplyRequest{
|
||||
childJob: req.childJob,
|
||||
messages: xslices.Filter(result, func(t BuildResult) bool {
|
||||
return t.Update != nil
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
req.job.onError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
|
||||
var expectedMemUsage uint64
|
||||
var chunks [][]proton.FullMessage
|
||||
var lastIndex int
|
||||
var index int
|
||||
|
||||
for _, v := range batch {
|
||||
var dataSize uint64
|
||||
for _, a := range v.Attachments {
|
||||
dataSize += uint64(a.Size)
|
||||
}
|
||||
|
||||
// 2x increase for attachment due to extra memory needed for decrypting and writing
|
||||
// in memory buffer.
|
||||
dataSize *= 2
|
||||
dataSize += uint64(len(v.Body))
|
||||
|
||||
nextMemSize := expectedMemUsage + dataSize
|
||||
if nextMemSize >= maxMemory {
|
||||
chunks = append(chunks, batch[lastIndex:index])
|
||||
lastIndex = index
|
||||
expectedMemUsage = dataSize
|
||||
} else {
|
||||
expectedMemUsage = nextMemSize
|
||||
}
|
||||
|
||||
index++
|
||||
}
|
||||
|
||||
if lastIndex < len(batch) {
|
||||
chunks = append(chunks, batch[lastIndex:])
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
317
internal/services/syncservice/stage_build_test.go
Normal file
317
internal/services/syncservice/stage_build_test.go
Normal file
@ -0,0 +1,317 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/bridge/mocks"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSyncChunkSyncBuilderBatch(t *testing.T) {
|
||||
// GODT-2424 - Some messages were not fully built due to a bug in the chunking if the total memory used by the
|
||||
// message would be higher than the maximum we allowed.
|
||||
const totalMessageCount = 100
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
Attachments: []proton.Attachment{
|
||||
{
|
||||
Size: int64(8 * Megabyte),
|
||||
},
|
||||
},
|
||||
},
|
||||
AttData: nil,
|
||||
}
|
||||
|
||||
messages := xslices.Repeat(msg, totalMessageCount)
|
||||
|
||||
chunks := chunkSyncBuilderBatch(messages, 16*Megabyte)
|
||||
|
||||
var totalMessagesInChunks int
|
||||
|
||||
for _, v := range chunks {
|
||||
totalMessagesInChunks += len(v)
|
||||
}
|
||||
|
||||
require.Equal(t, totalMessagesInChunks, totalMessageCount)
|
||||
}
|
||||
|
||||
func TestBuildStage_SuccessRemovesFailedMessage(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[BuildRequest]()
|
||||
output := NewChannelConsumerProducer[ApplyRequest]()
|
||||
reporter := mocks.NewMockReporter(mockCtrl)
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tj := newTestJob(ctx, mockCtrl, "u", labels)
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: "MSG",
|
||||
AddressID: "addrID",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
require.NoError(t, f(nil, map[string]*crypto.KeyRing{
|
||||
"addrID": {},
|
||||
}))
|
||||
return nil
|
||||
})
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
buildResult := BuildResult{
|
||||
AddressID: "addrID",
|
||||
MessageID: "MSG",
|
||||
Update: &imap.MessageCreated{},
|
||||
}
|
||||
|
||||
tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(buildResult, nil)
|
||||
tj.state.EXPECT().RemFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
|
||||
|
||||
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, reporter)
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
||||
|
||||
req, err := output.Consume(ctx)
|
||||
cancel()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, req.messages, 1)
|
||||
require.Equal(t, buildResult, req.messages[0])
|
||||
}
|
||||
|
||||
func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[BuildRequest]()
|
||||
output := NewChannelConsumerProducer[ApplyRequest]()
|
||||
mockReporter := mocks.NewMockReporter(mockCtrl)
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tj := newTestJob(ctx, mockCtrl, "u", labels)
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: "MSG",
|
||||
AddressID: "addrID",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
require.NoError(t, f(nil, map[string]*crypto.KeyRing{
|
||||
"addrID": {},
|
||||
}))
|
||||
return nil
|
||||
})
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
buildError := errors.New("it failed")
|
||||
|
||||
tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError)
|
||||
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
|
||||
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
|
||||
"userID": "u",
|
||||
"messageID": "MSG",
|
||||
"error": buildError,
|
||||
})).Return(nil)
|
||||
|
||||
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
||||
|
||||
req, err := output.Consume(ctx)
|
||||
cancel()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, req.messages)
|
||||
}
|
||||
|
||||
func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[BuildRequest]()
|
||||
output := NewChannelConsumerProducer[ApplyRequest]()
|
||||
mockReporter := mocks.NewMockReporter(mockCtrl)
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tj := newTestJob(ctx, mockCtrl, "u", labels)
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: "MSG",
|
||||
AddressID: "addrID",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
require.NoError(t, f(nil, map[string]*crypto.KeyRing{}))
|
||||
return nil
|
||||
})
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
|
||||
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
|
||||
"userID": "u",
|
||||
"messageID": "MSG",
|
||||
})).Return(nil)
|
||||
|
||||
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
||||
|
||||
req, err := output.Consume(ctx)
|
||||
cancel()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, req.messages)
|
||||
}
|
||||
|
||||
func TestBuildStage_OtherErrorsFailJob(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[BuildRequest]()
|
||||
output := NewChannelConsumerProducer[ApplyRequest]()
|
||||
mockReporter := mocks.NewMockReporter(mockCtrl)
|
||||
|
||||
labels := getTestLabels()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tj := newTestJob(ctx, mockCtrl, "u", labels)
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: "MSG",
|
||||
AddressID: "addrID",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expectedErr := errors.New("something went wrong")
|
||||
|
||||
tj.messageBuilder.EXPECT().WithKeys(gomock.Any()).DoAndReturn(func(f func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
return expectedErr
|
||||
})
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
require.Equal(t, expectedErr, err)
|
||||
|
||||
cancel()
|
||||
|
||||
_, err = output.Consume(context.Background())
|
||||
require.ErrorIs(t, err, ErrNoMoreInput)
|
||||
}
|
||||
|
||||
func TestBuildStage_CancelledJobIsDiscarded(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[BuildRequest]()
|
||||
output := NewChannelConsumerProducer[ApplyRequest]()
|
||||
mockReporter := mocks.NewMockReporter(mockCtrl)
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: "MSG",
|
||||
AddressID: "addrID",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
|
||||
tj.job.begin()
|
||||
defer tj.job.end()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
jobCancel()
|
||||
input.Produce(ctx, BuildRequest{
|
||||
childJob: childJob,
|
||||
batch: []proton.FullMessage{msg},
|
||||
})
|
||||
|
||||
go func() { cancel() }()
|
||||
|
||||
_, err := output.Consume(context.Background())
|
||||
require.ErrorIs(t, err, ErrNoMoreInput)
|
||||
}
|
||||
279
internal/services/syncservice/stage_download.go
Normal file
279
internal/services/syncservice/stage_download.go
Normal file
@ -0,0 +1,279 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/network"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type DownloadRequest struct {
|
||||
childJob
|
||||
ids []string
|
||||
}
|
||||
|
||||
type DownloadStageInput = StageInputConsumer[DownloadRequest]
|
||||
type DownloadStageOutput = StageOutputProducer[BuildRequest]
|
||||
|
||||
// DownloadStage downloads the messages and attachments. It auto-throttles the download of the messages based on
|
||||
// whether we run into 429|5xx codes.
|
||||
type DownloadStage struct {
|
||||
input DownloadStageInput
|
||||
output DownloadStageOutput
|
||||
maxParallelDownloads int
|
||||
panicHandler async.PanicHandler
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
func NewDownloadStage(
|
||||
input DownloadStageInput,
|
||||
output DownloadStageOutput,
|
||||
maxParallelDownloads int,
|
||||
panicHandler async.PanicHandler,
|
||||
) *DownloadStage {
|
||||
return &DownloadStage{
|
||||
input: input,
|
||||
output: output,
|
||||
maxParallelDownloads: maxParallelDownloads,
|
||||
panicHandler: panicHandler,
|
||||
log: logrus.WithField("sync-stage", "download"),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DownloadStage) Run(group *async.Group) {
|
||||
group.Once(func(ctx context.Context) {
|
||||
logging.DoAnnotated(ctx, func(ctx context.Context) {
|
||||
d.run(ctx)
|
||||
}, logging.Labels{"sync-stage": "Download"})
|
||||
})
|
||||
}
|
||||
|
||||
func (d *DownloadStage) run(ctx context.Context) {
|
||||
defer d.output.Close()
|
||||
|
||||
newCoolDown := func() network.CoolDownProvider {
|
||||
return &network.ExpCoolDown{}
|
||||
}
|
||||
|
||||
for {
|
||||
request, err := d.input.Consume(ctx)
|
||||
if err != nil {
|
||||
if !(errors.Is(err, ErrNoMoreInput) || errors.Is(err, context.Canceled)) {
|
||||
d.log.WithError(err).Error("Exiting state with error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if request.checkCancelled() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 1: Download Messages.
|
||||
result, err := autoDownloadRate(
|
||||
ctx,
|
||||
&DefaultDownloadRateModifier{},
|
||||
request.job.client,
|
||||
d.maxParallelDownloads,
|
||||
request.ids,
|
||||
newCoolDown,
|
||||
func(ctx context.Context, client APIClient, input string) (proton.FullMessage, error) {
|
||||
msg, err := downloadMessage(ctx, request.job.downloadCache, client, input)
|
||||
if err != nil {
|
||||
var apiErr *proton.APIError
|
||||
if errors.As(err, &apiErr) && apiErr.Status == 422 {
|
||||
return proton.FullMessage{}, nil
|
||||
}
|
||||
|
||||
return proton.FullMessage{}, err
|
||||
}
|
||||
|
||||
var attData [][]byte
|
||||
if msg.NumAttachments > 0 {
|
||||
attData = make([][]byte, msg.NumAttachments)
|
||||
}
|
||||
|
||||
return proton.FullMessage{Message: msg, AttData: attData}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
request.job.onError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 2: Prepare attachment ids for download.
|
||||
type attachmentMeta struct {
|
||||
msgIdx int
|
||||
attIdx int
|
||||
}
|
||||
|
||||
// Filter out any messages that don't exist.
|
||||
result = xslices.Filter(result, func(t proton.FullMessage) bool {
|
||||
return t.ID != ""
|
||||
})
|
||||
|
||||
attachmentIndices := make([]attachmentMeta, 0, len(result))
|
||||
attachmentIDs := make([]string, 0, len(result))
|
||||
|
||||
for msgIdx, v := range result {
|
||||
for attIdx := 0; attIdx < v.NumAttachments; attIdx++ {
|
||||
attachmentIndices = append(attachmentIndices, attachmentMeta{
|
||||
msgIdx: msgIdx,
|
||||
attIdx: attIdx,
|
||||
})
|
||||
|
||||
attachmentIDs = append(attachmentIDs, result[msgIdx].Attachments[attIdx].ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Download attachments data to the message.
|
||||
attachments, err := autoDownloadRate(
|
||||
ctx,
|
||||
&DefaultDownloadRateModifier{},
|
||||
request.job.client,
|
||||
d.maxParallelDownloads,
|
||||
attachmentIndices,
|
||||
newCoolDown,
|
||||
func(ctx context.Context, client APIClient, input attachmentMeta) ([]byte, error) {
|
||||
return downloadAttachment(ctx, request.job.downloadCache, client, result[input.msgIdx].Attachments[input.attIdx].ID)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
request.job.onError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 4: attach attachment data to the message.
|
||||
for i, meta := range attachmentIndices {
|
||||
result[meta.msgIdx].AttData[meta.attIdx] = attachments[i]
|
||||
}
|
||||
|
||||
request.cachedAttachmentIDs = attachmentIDs
|
||||
request.cachedMessageIDs = request.ids
|
||||
|
||||
// Step 5: Publish result.
|
||||
d.output.Produce(ctx, BuildRequest{
|
||||
batch: result,
|
||||
childJob: request.childJob,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func downloadMessage(ctx context.Context, cache *DownloadCache, client APIClient, id string) (proton.Message, error) {
|
||||
msg, ok := cache.GetMessage(id)
|
||||
if ok {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
msg, err := client.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
return proton.Message{}, err
|
||||
}
|
||||
|
||||
cache.StoreMessage(msg)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func downloadAttachment(ctx context.Context, cache *DownloadCache, client APIClient, id string) ([]byte, error) {
|
||||
data, ok := cache.GetAttachment(id)
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
data, err := client.GetAttachment(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache.StoreAttachment(id, data)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type DownloadRateModifier interface {
|
||||
Apply(wasSuccess bool, current int, max int) int
|
||||
}
|
||||
|
||||
func autoDownloadRate[T any, R any](
|
||||
ctx context.Context,
|
||||
modifier DownloadRateModifier,
|
||||
client APIClient,
|
||||
maxParallelDownloads int,
|
||||
data []T,
|
||||
newCoolDown func() network.CoolDownProvider,
|
||||
f func(ctx context.Context, client APIClient, input T) (R, error),
|
||||
) ([]R, error) {
|
||||
result := make([]R, 0, len(data))
|
||||
|
||||
proton429or5xxCounter := int32(0)
|
||||
parallelTasks := maxParallelDownloads
|
||||
for _, chunk := range xslices.Chunk(data, maxParallelDownloads) {
|
||||
parallelTasks = modifier.Apply(atomic.LoadInt32(&proton429or5xxCounter) != 0, parallelTasks, maxParallelDownloads)
|
||||
|
||||
atomic.StoreInt32(&proton429or5xxCounter, 0)
|
||||
|
||||
chunkResult, err := parallel.MapContext(
|
||||
ctx,
|
||||
parallelTasks,
|
||||
chunk,
|
||||
func(ctx context.Context, in T) (R, error) {
|
||||
wrapper := network.NewClientRetryWrapper(client, newCoolDown())
|
||||
msg, err := network.RetryWithClient(ctx, wrapper, func(ctx context.Context, c APIClient) (R, error) {
|
||||
return f(ctx, c, in)
|
||||
})
|
||||
|
||||
if wrapper.DidEncounter429or5xx() {
|
||||
atomic.AddInt32(&proton429or5xxCounter, 1)
|
||||
}
|
||||
|
||||
return msg, err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, chunkResult...)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type DefaultDownloadRateModifier struct{}
|
||||
|
||||
func (d DefaultDownloadRateModifier) Apply(wasSuccess bool, current int, max int) int {
|
||||
if !wasSuccess {
|
||||
return 2
|
||||
}
|
||||
|
||||
parallelTasks := current * 2
|
||||
if parallelTasks > max {
|
||||
parallelTasks = max
|
||||
}
|
||||
|
||||
return parallelTasks
|
||||
}
|
||||
472
internal/services/syncservice/stage_download_test.go
Normal file
472
internal/services/syncservice/stage_download_test.go
Normal file
@ -0,0 +1,472 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/network"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDownloadMessage_NotInCache(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
cache := newDownloadCache()
|
||||
client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{}, nil)
|
||||
|
||||
_, err := downloadMessage(context.Background(), cache, client, "msg")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDownloadMessage_InCache(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
cache := newDownloadCache()
|
||||
|
||||
msg := proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{ID: "msg", Size: 1024},
|
||||
}
|
||||
|
||||
cache.StoreMessage(msg)
|
||||
|
||||
downloaded, err := downloadMessage(context.Background(), cache, client, "msg")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg, downloaded)
|
||||
}
|
||||
|
||||
func TestDownloadAttachment_NotInCache(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
cache := newDownloadCache()
|
||||
client.EXPECT().GetAttachment(gomock.Any(), gomock.Any()).Return(nil, nil)
|
||||
|
||||
_, err := downloadAttachment(context.Background(), cache, client, "id")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDownloadAttachment_InCache(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
cache := newDownloadCache()
|
||||
attachment := []byte("hello world")
|
||||
cache.StoreAttachment("id", attachment)
|
||||
|
||||
downloaded, err := downloadAttachment(context.Background(), cache, client, "id")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, attachment, downloaded)
|
||||
}
|
||||
|
||||
func TestAutoDownloadScale_AllOkay(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
data := buildDownloadScaleData(15)
|
||||
|
||||
const MaxParallel = 5
|
||||
|
||||
call1 := client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Times(5).DoAndReturn(autoDownloadScaleClientDoAndReturn)
|
||||
call2 := client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Times(5).After(call1).DoAndReturn(autoDownloadScaleClientDoAndReturn)
|
||||
client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Times(5).After(call2).DoAndReturn(autoDownloadScaleClientDoAndReturn)
|
||||
|
||||
msgs, err := autoDownloadRate(
|
||||
context.Background(),
|
||||
&DefaultDownloadRateModifier{},
|
||||
client,
|
||||
MaxParallel,
|
||||
data,
|
||||
autoScaleCoolDown,
|
||||
func(ctx context.Context, client APIClient, input string) (proton.Message, error) {
|
||||
return client.GetMessage(ctx, input)
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, xslices.Map(data, newDownloadScaleMessage), msgs)
|
||||
}
|
||||
|
||||
func TestAutoDownloadScale_429or500x(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
client := NewMockAPIClient(mockCtrl)
|
||||
data := buildDownloadScaleData(32)
|
||||
|
||||
rateModifier := NewMockDownloadRateModifier(mockCtrl)
|
||||
|
||||
const MaxParallel = 8
|
||||
|
||||
for _, d := range data {
|
||||
switch d {
|
||||
case "m7":
|
||||
call429 := client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m7")).DoAndReturn(func(_ context.Context, id string) (proton.Message, error) {
|
||||
return proton.Message{}, &proton.APIError{Status: 429}
|
||||
})
|
||||
client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m7")).After(call429).DoAndReturn(autoDownloadScaleClientDoAndReturn)
|
||||
case "m23":
|
||||
call503 := client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m23")).DoAndReturn(func(_ context.Context, id string) (proton.Message, error) {
|
||||
return proton.Message{}, &proton.APIError{Status: 503}
|
||||
})
|
||||
client.EXPECT().GetMessage(gomock.Any(), gomock.Eq("m23")).After(call503).DoAndReturn(autoDownloadScaleClientDoAndReturn)
|
||||
default:
|
||||
client.EXPECT().GetMessage(gomock.Any(), gomock.Eq(d)).DoAndReturn(autoDownloadScaleClientDoAndReturn)
|
||||
}
|
||||
}
|
||||
|
||||
defaultRateModifier := DefaultDownloadRateModifier{}
|
||||
|
||||
// First call catches failure in message m7, we throttle.
|
||||
call1 := rateModifier.EXPECT().Apply(gomock.Eq(false), gomock.Eq(8), gomock.Eq(8)).Return(defaultRateModifier.Apply(false, 8, 8))
|
||||
// Next batch succeeds. So we bump the parallel downloads by 2x.
|
||||
call2 := rateModifier.EXPECT().Apply(gomock.Eq(true), gomock.Eq(2), gomock.Eq(8)).Return(defaultRateModifier.Apply(true, 2, 8))
|
||||
// We now encounter a 503 with m23. Reset to 2 parallel downloads.
|
||||
call3 := rateModifier.EXPECT().Apply(gomock.Eq(false), gomock.Eq(4), gomock.Eq(8)).Return(defaultRateModifier.Apply(false, 4, 8))
|
||||
// The next batch succeeds once again.
|
||||
call4 := rateModifier.EXPECT().Apply(gomock.Eq(true), gomock.Eq(2), gomock.Eq(8)).Return(defaultRateModifier.Apply(true, 2, 8))
|
||||
|
||||
gomock.InOrder(call1, call2, call3, call4)
|
||||
|
||||
msgs, err := autoDownloadRate(
|
||||
context.Background(),
|
||||
rateModifier,
|
||||
client,
|
||||
MaxParallel,
|
||||
data,
|
||||
autoScaleCoolDown,
|
||||
func(ctx context.Context, client APIClient, input string) (proton.Message, error) {
|
||||
return client.GetMessage(ctx, input)
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, xslices.Map(data, newDownloadScaleMessage), msgs)
|
||||
}
|
||||
|
||||
func TestDownloadStage_Run(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[DownloadRequest]()
|
||||
output := NewChannelConsumerProducer[BuildRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tj := newTestJob(ctx, mockCtrl, "", map[string]proton.Label{})
|
||||
|
||||
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
|
||||
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10))).Return(nil)
|
||||
|
||||
tj.job.begin()
|
||||
defer tj.job.end()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
|
||||
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
|
||||
|
||||
msgIDs, expected := buildDownloadStageData(&tj, 56, false)
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, DownloadRequest{
|
||||
childJob: childJob,
|
||||
ids: msgIDs,
|
||||
})
|
||||
|
||||
out, err := output.Consume(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, out.batch)
|
||||
out.onFinished(ctx)
|
||||
cancel()
|
||||
|
||||
cachedMessages, cachedAttachments := tj.job.downloadCache.Count()
|
||||
require.Zero(t, cachedMessages)
|
||||
require.Zero(t, cachedAttachments)
|
||||
}
|
||||
|
||||
func TestDownloadStage_RunWith422(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[DownloadRequest]()
|
||||
output := NewChannelConsumerProducer[BuildRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tj := newTestJob(ctx, mockCtrl, "", map[string]proton.Label{})
|
||||
|
||||
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any())
|
||||
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10))).Return(nil)
|
||||
|
||||
tj.job.begin()
|
||||
defer tj.job.end()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
|
||||
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
|
||||
|
||||
msgIDs, expected := buildDownloadStageData(&tj, 56, true)
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, DownloadRequest{
|
||||
childJob: childJob,
|
||||
ids: msgIDs,
|
||||
})
|
||||
|
||||
out, err := output.Consume(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, out.batch)
|
||||
out.onFinished(ctx)
|
||||
cancel()
|
||||
|
||||
cachedMessages, cachedAttachments := tj.job.downloadCache.Count()
|
||||
require.Zero(t, cachedMessages)
|
||||
require.Zero(t, cachedAttachments)
|
||||
}
|
||||
|
||||
func TestDownloadStage_CancelledJobIsDiscarded(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[DownloadRequest]()
|
||||
output := NewChannelConsumerProducer[BuildRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
|
||||
tj.job.begin()
|
||||
defer tj.job.end()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
|
||||
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
jobCancel()
|
||||
input.Produce(ctx, DownloadRequest{
|
||||
childJob: childJob,
|
||||
ids: nil,
|
||||
})
|
||||
|
||||
go func() { cancel() }()
|
||||
|
||||
_, err := output.Consume(context.Background())
|
||||
require.ErrorIs(t, err, ErrNoMoreInput)
|
||||
}
|
||||
|
||||
func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[DownloadRequest]()
|
||||
output := NewChannelConsumerProducer[BuildRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
defer jobCancel()
|
||||
|
||||
expectedErr := errors.New("fail")
|
||||
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{}, expectedErr)
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, DownloadRequest{
|
||||
childJob: childJob,
|
||||
ids: []string{"foo"},
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
require.Equal(t, expectedErr, err)
|
||||
|
||||
cancel()
|
||||
|
||||
_, err = output.Consume(context.Background())
|
||||
require.ErrorIs(t, err, ErrNoMoreInput)
|
||||
}
|
||||
|
||||
func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
input := NewChannelConsumerProducer[DownloadRequest]()
|
||||
output := NewChannelConsumerProducer[BuildRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
defer jobCancel()
|
||||
|
||||
expectedErr := errors.New("fail")
|
||||
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: "msg",
|
||||
NumAttachments: 1,
|
||||
},
|
||||
Header: "",
|
||||
ParsedHeaders: nil,
|
||||
Body: "",
|
||||
MIMEType: "",
|
||||
Attachments: []proton.Attachment{{
|
||||
ID: "attach",
|
||||
}},
|
||||
}, nil)
|
||||
tj.client.EXPECT().GetAttachment(gomock.Any(), gomock.Eq("attach")).Return(nil, expectedErr)
|
||||
|
||||
tj.job.begin()
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
stage := NewDownloadStage(input, output, 4, &async.NoopPanicHandler{})
|
||||
|
||||
go func() {
|
||||
stage.run(ctx)
|
||||
}()
|
||||
|
||||
input.Produce(ctx, DownloadRequest{
|
||||
childJob: childJob,
|
||||
ids: []string{"foo"},
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
require.Equal(t, expectedErr, err)
|
||||
|
||||
cancel()
|
||||
|
||||
_, err = output.Consume(context.Background())
|
||||
require.ErrorIs(t, err, ErrNoMoreInput)
|
||||
}
|
||||
|
||||
func buildDownloadStageData(tj *tjob, numMessages int, with422 bool) ([]string, []proton.FullMessage) {
|
||||
result := make([]proton.FullMessage, numMessages)
|
||||
msgIDs := make([]string, numMessages)
|
||||
|
||||
for i := 0; i < numMessages; i++ {
|
||||
msgID := fmt.Sprintf("msg-%v", i)
|
||||
msgIDs[i] = msgID
|
||||
result[i] = proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: msgID,
|
||||
Size: len([]byte(msgID)),
|
||||
},
|
||||
Header: "",
|
||||
ParsedHeaders: nil,
|
||||
Body: msgID,
|
||||
MIMEType: "",
|
||||
Attachments: nil,
|
||||
},
|
||||
AttData: nil,
|
||||
}
|
||||
|
||||
buildDownloadStageAttachments(&result[i], i)
|
||||
}
|
||||
|
||||
for i, m := range result {
|
||||
if with422 && i%2 == 0 {
|
||||
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Eq(m.ID)).Return(proton.Message{}, &proton.APIError{Status: 422})
|
||||
continue
|
||||
}
|
||||
|
||||
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Eq(m.ID)).Return(m.Message, nil)
|
||||
|
||||
for idx, a := range m.Attachments {
|
||||
tj.client.EXPECT().GetAttachment(gomock.Any(), gomock.Eq(a.ID)).Return(m.AttData[idx], nil)
|
||||
}
|
||||
}
|
||||
|
||||
if with422 {
|
||||
result422 := make([]proton.FullMessage, 0, numMessages/2)
|
||||
|
||||
for i := 0; i < numMessages; i++ {
|
||||
if i%2 == 0 {
|
||||
continue
|
||||
}
|
||||
result422 = append(result422, result[i])
|
||||
}
|
||||
|
||||
return msgIDs, result422
|
||||
}
|
||||
|
||||
return msgIDs, result
|
||||
}
|
||||
|
||||
func buildDownloadStageAttachments(msg *proton.FullMessage, index int) {
|
||||
mod := index % 4
|
||||
if mod == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
genDownloadStageAttachmentInfo(msg, index, mod)
|
||||
}
|
||||
|
||||
func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) {
|
||||
msg.Attachments = make([]proton.Attachment, count)
|
||||
msg.AttData = make([][]byte, count)
|
||||
msg.NumAttachments = count
|
||||
for i := 0; i < count; i++ {
|
||||
data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i)
|
||||
msg.Attachments[i] = proton.Attachment{
|
||||
ID: data,
|
||||
Size: int64(len([]byte(data))),
|
||||
}
|
||||
msg.AttData[i] = []byte(data)
|
||||
msg.Size += len([]byte(data))
|
||||
}
|
||||
}
|
||||
|
||||
func autoScaleCoolDown() network.CoolDownProvider {
|
||||
return &network.NoCoolDown{}
|
||||
}
|
||||
|
||||
func buildDownloadScaleData(count int) []string {
|
||||
r := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
r[i] = fmt.Sprintf("m%v", i)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func newDownloadScaleMessage(id string) proton.Message {
|
||||
return proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{ID: id},
|
||||
}
|
||||
}
|
||||
|
||||
func autoDownloadScaleClientDoAndReturn(_ context.Context, id string) (proton.Message, error) {
|
||||
return newDownloadScaleMessage(id), nil
|
||||
}
|
||||
218
internal/services/syncservice/stage_metadata.go
Normal file
218
internal/services/syncservice/stage_metadata.go
Normal file
@ -0,0 +1,218 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/network"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type MetadataStageOutput = StageOutputProducer[DownloadRequest]
|
||||
type MetadataStageInput = StageInputConsumer[*Job]
|
||||
|
||||
// MetadataStage is responsible for the throttling the sync pipeline by only allowing `MetadataMaxMessages` or up to
|
||||
// maximum allowed memory usage messages to go through the pipeline. It is also responsible for interleaving
|
||||
// different sync jobs so all jobs can progress and finish.
|
||||
type MetadataStage struct {
|
||||
output MetadataStageOutput
|
||||
input MetadataStageInput
|
||||
maxDownloadMem uint64
|
||||
jobs []*metadataIterator
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage {
|
||||
return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")}
|
||||
}
|
||||
|
||||
const MetadataPageSize = 150
|
||||
const MetadataMaxMessages = 250
|
||||
|
||||
func (m MetadataStage) Run(group *async.Group) {
|
||||
group.Once(func(ctx context.Context) {
|
||||
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
|
||||
})
|
||||
}
|
||||
|
||||
func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
|
||||
defer m.output.Close()
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if new job has been submitted
|
||||
job, ok, err := m.input.TryConsume(ctx)
|
||||
if err != nil {
|
||||
m.log.WithError(err).Error("Error trying to retrieve more work")
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
job.begin()
|
||||
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
|
||||
if err != nil {
|
||||
job.onError(err)
|
||||
continue
|
||||
}
|
||||
m.jobs = append(m.jobs, state)
|
||||
}
|
||||
|
||||
// Iterate over all jobs and produce work.
|
||||
for i := 0; i < len(m.jobs); {
|
||||
job := m.jobs[i]
|
||||
|
||||
// If the job's context has been cancelled, remove from the list.
|
||||
if job.stage.ctx.Err() != nil {
|
||||
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
|
||||
job.stage.end()
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for more work.
|
||||
output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
|
||||
if err != nil {
|
||||
job.stage.onError(err)
|
||||
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
|
||||
continue
|
||||
}
|
||||
|
||||
// If there is actually more work, push it down the pipeline.
|
||||
if len(output.ids) != 0 {
|
||||
m.output.Produce(ctx, output)
|
||||
}
|
||||
|
||||
// If this job has no more work left, signal completion.
|
||||
if !hasMore {
|
||||
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
|
||||
job.stage.end()
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type metadataIterator struct {
|
||||
stage *Job
|
||||
client *network.ProtonClientRetryWrapper[APIClient]
|
||||
lastMessageID string
|
||||
remaining []proton.MessageMetadata
|
||||
downloadReqIDs []string
|
||||
expectedSize uint64
|
||||
}
|
||||
|
||||
func newMetadataIterator(ctx context.Context, stage *Job, metadataPageSize int, coolDown network.CoolDownProvider) (*metadataIterator, error) {
|
||||
syncStatus, err := stage.state.GetSyncStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &metadataIterator{
|
||||
stage: stage,
|
||||
client: network.NewClientRetryWrapper(stage.client, coolDown),
|
||||
lastMessageID: syncStatus.LastSyncedMessageID,
|
||||
remaining: nil,
|
||||
downloadReqIDs: make([]string, 0, metadataPageSize),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *metadataIterator) Next(maxDownloadMem uint64, metadataPageSize int, maxMessages int) (DownloadRequest, bool, error) {
|
||||
for {
|
||||
if m.stage.ctx.Err() != nil {
|
||||
return DownloadRequest{}, false, m.stage.ctx.Err()
|
||||
}
|
||||
|
||||
if len(m.remaining) == 0 {
|
||||
metadata, err := network.RetryWithClient(m.stage.ctx, m.client, func(ctx context.Context, c APIClient) ([]proton.MessageMetadata, error) {
|
||||
// To get the metadata of the messages in batches we need to initialize the state with a call to
|
||||
// GetMessageMetadata withe filter{Desc:true}.
|
||||
if m.lastMessageID == "" {
|
||||
return c.GetMessageMetadataPage(ctx, 0, metadataPageSize, proton.MessageFilter{
|
||||
Desc: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Afterward we perform the same query but set the EndID to the last message of the previous batch.
|
||||
// Care must be taken here as the EndID will appear again as the first metadata result if it has not
|
||||
// been eliminated.
|
||||
meta, err := c.GetMessageMetadataPage(ctx, 0, metadataPageSize, proton.MessageFilter{
|
||||
EndID: m.lastMessageID,
|
||||
Desc: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// To break the loop we need to check that either:
|
||||
// * There are no messages returned
|
||||
if len(meta) == 0 {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
// * There is only one message returned and it matches the EndID query
|
||||
if meta[0].ID == m.lastMessageID {
|
||||
return meta[1:], nil
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
})
|
||||
if err != nil {
|
||||
m.stage.log.WithError(err).Errorf("Failed to download message metadata with lastMessageID=%v", m.lastMessageID)
|
||||
return DownloadRequest{}, false, err
|
||||
}
|
||||
|
||||
m.remaining = append(m.remaining, metadata...)
|
||||
|
||||
// Update the last message ID
|
||||
if len(m.remaining) != 0 {
|
||||
m.lastMessageID = m.remaining[len(m.remaining)-1].ID
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.remaining) == 0 {
|
||||
if len(m.downloadReqIDs) != 0 {
|
||||
return DownloadRequest{childJob: m.stage.newChildJob(m.downloadReqIDs[len(m.downloadReqIDs)-1], int64(len(m.downloadReqIDs))), ids: m.downloadReqIDs}, false, nil
|
||||
}
|
||||
|
||||
return DownloadRequest{}, false, nil
|
||||
}
|
||||
|
||||
for idx, meta := range m.remaining {
|
||||
nextSize := m.expectedSize + uint64(meta.Size)
|
||||
if nextSize >= maxDownloadMem || len(m.downloadReqIDs) >= maxMessages {
|
||||
m.expectedSize = 0
|
||||
m.remaining = m.remaining[idx:]
|
||||
downloadReqIDs := m.downloadReqIDs
|
||||
m.downloadReqIDs = make([]string, 0, metadataPageSize)
|
||||
|
||||
return DownloadRequest{childJob: m.stage.newChildJob(downloadReqIDs[len(downloadReqIDs)-1], int64(len(downloadReqIDs))), ids: downloadReqIDs}, true, nil
|
||||
}
|
||||
|
||||
m.downloadReqIDs = append(m.downloadReqIDs, meta.ID)
|
||||
m.expectedSize = nextSize
|
||||
}
|
||||
|
||||
m.remaining = nil
|
||||
}
|
||||
}
|
||||
463
internal/services/syncservice/stage_metadata_test.go
Normal file
463
internal/services/syncservice/stage_metadata_test.go
Normal file
@ -0,0 +1,463 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/network"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const TestMetadataPageSize = 5
|
||||
const TestMaxDownloadMem = Kilobyte
|
||||
const TestMaxMessages = 10
|
||||
|
||||
func TestMetadataStage_RunFinishesWith429(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tj := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
|
||||
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
|
||||
LastSyncedMessageID: "",
|
||||
}, nil)
|
||||
|
||||
input := NewChannelConsumerProducer[*Job]()
|
||||
output := NewChannelConsumerProducer[DownloadRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
|
||||
|
||||
numMessages := 50
|
||||
messageSize := 100
|
||||
|
||||
msgs := setupMetadataSuccessRunWith429(&tj, numMessages, messageSize)
|
||||
|
||||
go func() {
|
||||
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
||||
}()
|
||||
|
||||
input.Produce(ctx, tj.job)
|
||||
|
||||
for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) {
|
||||
req, err := output.Consume(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, req.ids, xslices.Map(chunk, func(m proton.MessageMetadata) string {
|
||||
return m.ID
|
||||
}))
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
jobCtx, jobCancel := context.WithCancel(context.Background())
|
||||
tj := newTestFixedMetadataJob(jobCtx, mockCtrl, "u", getTestLabels())
|
||||
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
|
||||
LastSyncedMessageID: "",
|
||||
}, nil)
|
||||
tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
input := NewChannelConsumerProducer[*Job]()
|
||||
output := NewChannelConsumerProducer[DownloadRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
|
||||
|
||||
go func() {
|
||||
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
||||
}()
|
||||
|
||||
input.Produce(ctx, tj.job)
|
||||
|
||||
// read one output then cancel
|
||||
request, err := output.Consume(ctx)
|
||||
require.NoError(t, err)
|
||||
request.onFinished(ctx)
|
||||
// cancel job context
|
||||
jobCancel()
|
||||
|
||||
// The next stages should check whether the job has been cancelled or not. Here we need to do it manually.
|
||||
go func() {
|
||||
for {
|
||||
req, err := output.Consume(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req.checkCancelled()
|
||||
}
|
||||
}()
|
||||
|
||||
err = tj.job.wait(context.Background())
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestMetadataStage_RunInterleaved(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
tj1 := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
|
||||
tj1.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
|
||||
tj1.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tj1.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
tj2 := newTestJob(context.Background(), mockCtrl, "u", getTestLabels())
|
||||
tj2.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
|
||||
tj2.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tj2.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
input := NewChannelConsumerProducer[*Job]()
|
||||
output := NewChannelConsumerProducer[DownloadRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
|
||||
|
||||
numMessages := 50
|
||||
messageSize := 100
|
||||
|
||||
setupMetadataSuccessRunWith429(&tj1, numMessages, messageSize)
|
||||
setupMetadataSuccessRunWith429(&tj2, numMessages, messageSize)
|
||||
|
||||
go func() {
|
||||
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
||||
}()
|
||||
|
||||
go func() {
|
||||
input.Produce(ctx, tj1.job)
|
||||
input.Produce(ctx, tj2.job)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
req, err := output.Consume(ctx)
|
||||
if err != nil {
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
return
|
||||
}
|
||||
|
||||
req.onFinished(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
require.NoError(t, tj1.job.wait(ctx))
|
||||
require.NoError(t, tj2.job.wait(ctx))
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestMetadataIterator_ExitNoMoreMetadata(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
|
||||
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
|
||||
LastSyncedMessageID: "foo",
|
||||
}, nil)
|
||||
|
||||
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(proton.MessageFilter{Desc: true, EndID: "foo"})).Return(nil, nil)
|
||||
|
||||
iter, err := newMetadataIterator(ctx, tj.job, TestMetadataPageSize, &network.NoCoolDown{})
|
||||
require.NoError(t, err)
|
||||
|
||||
j, hasMore, err := iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
|
||||
require.NoError(t, err)
|
||||
require.False(t, hasMore)
|
||||
require.Empty(t, j.ids)
|
||||
}
|
||||
|
||||
func TestMetadataIterator_ExitLastCallAlwaysReturnLastMessageID(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
|
||||
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{
|
||||
LastSyncedMessageID: "foo",
|
||||
}, nil)
|
||||
|
||||
tj.client.EXPECT().GetMessageMetadataPage(
|
||||
gomock.Any(),
|
||||
gomock.Eq(0),
|
||||
gomock.Eq(TestMetadataPageSize),
|
||||
gomock.Eq(proton.MessageFilter{Desc: true, EndID: "foo"}),
|
||||
).Return([]proton.MessageMetadata{{
|
||||
ID: "foo",
|
||||
Size: 100,
|
||||
}}, nil)
|
||||
|
||||
iter, err := newMetadataIterator(ctx, tj.job, TestMetadataPageSize, &network.NoCoolDown{})
|
||||
require.NoError(t, err)
|
||||
|
||||
j, hasMore, err := iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
|
||||
require.NoError(t, err)
|
||||
require.False(t, hasMore)
|
||||
require.Empty(t, j.ids)
|
||||
}
|
||||
|
||||
func TestMetadataIterator_ExitWithRemainingReturnsNoMore(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
|
||||
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
|
||||
|
||||
const MetadataPageSize = 2
|
||||
|
||||
tj.client.EXPECT().GetMessageMetadataPage(
|
||||
gomock.Any(),
|
||||
gomock.Eq(0),
|
||||
gomock.Eq(MetadataPageSize),
|
||||
gomock.Eq(proton.MessageFilter{
|
||||
Desc: true,
|
||||
}),
|
||||
).Return([]proton.MessageMetadata{
|
||||
{
|
||||
ID: "foo",
|
||||
Size: 100,
|
||||
},
|
||||
{
|
||||
ID: "bar",
|
||||
Size: 100,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
tj.client.EXPECT().GetMessageMetadataPage(
|
||||
gomock.Any(),
|
||||
gomock.Eq(0),
|
||||
gomock.Eq(MetadataPageSize),
|
||||
gomock.Eq(proton.MessageFilter{
|
||||
Desc: true,
|
||||
EndID: "bar",
|
||||
}),
|
||||
).Return([]proton.MessageMetadata{
|
||||
{
|
||||
ID: "bar",
|
||||
Size: 100,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
iter, err := newMetadataIterator(ctx, tj.job, MetadataPageSize, &network.NoCoolDown{})
|
||||
require.NoError(t, err)
|
||||
|
||||
j, hasMore, err := iter.Next(TestMaxDownloadMem, MetadataPageSize, 3)
|
||||
require.NoError(t, err)
|
||||
require.False(t, hasMore)
|
||||
require.Equal(t, []string{"foo", "bar"}, j.ids)
|
||||
}
|
||||
|
||||
func TestMetadataIterator_RespectsSizeLimit(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
|
||||
tj.state.EXPECT().GetSyncStatus(gomock.Any()).Return(Status{}, nil)
|
||||
|
||||
// First call.
|
||||
tj.client.EXPECT().GetMessageMetadataPage(
|
||||
gomock.Any(),
|
||||
gomock.Eq(0),
|
||||
gomock.Eq(TestMetadataPageSize),
|
||||
gomock.Eq(proton.MessageFilter{Desc: true}),
|
||||
).Return([]proton.MessageMetadata{
|
||||
{
|
||||
ID: testMsgID(0),
|
||||
Size: 256,
|
||||
},
|
||||
{
|
||||
ID: testMsgID(1),
|
||||
Size: 512,
|
||||
},
|
||||
{
|
||||
ID: testMsgID(2),
|
||||
Size: 128,
|
||||
},
|
||||
{
|
||||
ID: testMsgID(3),
|
||||
Size: 256,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
// Second Call
|
||||
tj.client.EXPECT().GetMessageMetadataPage(
|
||||
gomock.Any(),
|
||||
gomock.Eq(0),
|
||||
gomock.Eq(TestMetadataPageSize),
|
||||
gomock.Eq(proton.MessageFilter{Desc: true, EndID: testMsgID(3)}),
|
||||
).Return([]proton.MessageMetadata{
|
||||
{
|
||||
ID: testMsgID(3),
|
||||
Size: 256,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
iter, err := newMetadataIterator(ctx, tj.job, TestMetadataPageSize, &network.NoCoolDown{})
|
||||
require.NoError(t, err)
|
||||
|
||||
j, hasMore, err := iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
|
||||
require.NoError(t, err)
|
||||
require.True(t, hasMore)
|
||||
require.Equal(t, []string{testMsgID(0), testMsgID(1), testMsgID(2)}, j.ids)
|
||||
|
||||
j, hasMore, err = iter.Next(TestMaxDownloadMem, TestMetadataPageSize, TestMaxMessages)
|
||||
require.NoError(t, err)
|
||||
require.False(t, hasMore)
|
||||
require.Equal(t, []string{testMsgID(3)}, j.ids)
|
||||
}
|
||||
|
||||
func testMsgID(i int) string {
|
||||
return fmt.Sprintf("msg-id-%v", i)
|
||||
}
|
||||
|
||||
func setupMetadataSuccessRunWith429(tj *tjob, msgCount int, msgSize int) []proton.MessageMetadata {
|
||||
msgs := make([]proton.MessageMetadata, msgCount)
|
||||
|
||||
for i := 0; i < msgCount; i++ {
|
||||
msgs[i].ID = testMsgID(i)
|
||||
msgs[i].Size = msgSize
|
||||
}
|
||||
|
||||
// setup api call
|
||||
for i := 0; i < msgCount; i += TestMetadataPageSize - 1 {
|
||||
filter := proton.MessageFilter{
|
||||
Desc: true,
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
filter.EndID = msgs[i].ID
|
||||
}
|
||||
|
||||
if i+TestMetadataPageSize > msgCount {
|
||||
call := tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
|
||||
nil, &proton.APIError{Status: 503},
|
||||
)
|
||||
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
|
||||
msgs[i:], nil,
|
||||
).After(call)
|
||||
} else {
|
||||
call := tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
|
||||
nil, &proton.APIError{Status: 429},
|
||||
)
|
||||
|
||||
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(filter)).Return(
|
||||
msgs[i:i+TestMetadataPageSize], nil,
|
||||
).After(call)
|
||||
}
|
||||
}
|
||||
|
||||
// Last call with last metadata id
|
||||
tj.client.EXPECT().GetMessageMetadataPage(gomock.Any(), gomock.Eq(0), gomock.Eq(TestMetadataPageSize), gomock.Eq(proton.MessageFilter{Desc: true, EndID: msgs[msgCount-1].ID})).Return(
|
||||
msgs[msgCount-1:], nil,
|
||||
)
|
||||
|
||||
return msgs
|
||||
}
|
||||
|
||||
func newTestFixedMetadataJob(
|
||||
ctx context.Context,
|
||||
mockCtrl *gomock.Controller,
|
||||
userID string,
|
||||
labels LabelMap,
|
||||
) tjob {
|
||||
messageBuilder := NewMockMessageBuilder(mockCtrl)
|
||||
updateApplier := NewMockUpdateApplier(mockCtrl)
|
||||
syncReporter := NewMockReporter(mockCtrl)
|
||||
state := NewMockStateProvider(mockCtrl)
|
||||
client := newFixedMetadataClient(50)
|
||||
|
||||
job := NewJob(
|
||||
ctx,
|
||||
client,
|
||||
userID,
|
||||
labels,
|
||||
messageBuilder,
|
||||
updateApplier,
|
||||
syncReporter,
|
||||
state,
|
||||
&async.NoopPanicHandler{},
|
||||
newDownloadCache(),
|
||||
logrus.WithField("s", "test"),
|
||||
)
|
||||
|
||||
return tjob{
|
||||
job: job,
|
||||
client: nil,
|
||||
messageBuilder: messageBuilder,
|
||||
updateApplier: updateApplier,
|
||||
syncReporter: syncReporter,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
type fixedMetadataClient struct {
|
||||
msg []proton.MessageMetadata
|
||||
offset int
|
||||
}
|
||||
|
||||
func newFixedMetadataClient(msgCount int) APIClient {
|
||||
msgs := make([]proton.MessageMetadata, msgCount)
|
||||
|
||||
for i := 0; i < msgCount; i++ {
|
||||
msgs[i].ID = testMsgID(i)
|
||||
msgs[i].Size = 100
|
||||
}
|
||||
|
||||
return &fixedMetadataClient{msg: msgs}
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetGroupedMessageCount(_ context.Context) ([]proton.MessageGroupCount, error) {
|
||||
panic("should not be called")
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetLabels(_ context.Context, _ ...proton.LabelType) ([]proton.Label, error) {
|
||||
panic("should not be called")
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetMessage(_ context.Context, _ string) (proton.Message, error) {
|
||||
panic("should not be called")
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetMessageMetadataPage(_ context.Context, _, pageSize int, _ proton.MessageFilter) ([]proton.MessageMetadata, error) {
|
||||
result := c.msg[c.offset : c.offset+pageSize]
|
||||
c.offset += pageSize
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetMessageIDs(_ context.Context, _ string) ([]string, error) {
|
||||
panic("should not be called")
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetFullMessage(_ context.Context, _ string, _ proton.Scheduler, _ proton.AttachmentAllocator) (proton.FullMessage, error) {
|
||||
panic("should not be called")
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetAttachmentInto(_ context.Context, _ string, _ io.ReaderFrom) error {
|
||||
panic("should not be called")
|
||||
}
|
||||
|
||||
func (c *fixedMetadataClient) GetAttachment(_ context.Context, _ string) ([]byte, error) {
|
||||
panic("should not be called")
|
||||
}
|
||||
85
internal/services/syncservice/stage_output.go
Normal file
85
internal/services/syncservice/stage_output.go
Normal file
@ -0,0 +1,85 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type StageOutputProducer[T any] interface {
|
||||
Produce(ctx context.Context, value T)
|
||||
Close()
|
||||
}
|
||||
|
||||
var ErrNoMoreInput = errors.New("no more input")
|
||||
|
||||
type StageInputConsumer[T any] interface {
|
||||
Consume(ctx context.Context) (T, error)
|
||||
TryConsume(ctx context.Context) (T, bool, error)
|
||||
}
|
||||
|
||||
type ChannelConsumerProducer[T any] struct {
|
||||
ch chan T
|
||||
}
|
||||
|
||||
func NewChannelConsumerProducer[T any]() *ChannelConsumerProducer[T] {
|
||||
return &ChannelConsumerProducer[T]{ch: make(chan T)}
|
||||
}
|
||||
|
||||
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case c.ch <- value:
|
||||
}
|
||||
}
|
||||
|
||||
func (c ChannelConsumerProducer[T]) Close() {
|
||||
close(c.ch)
|
||||
}
|
||||
|
||||
func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
var t T
|
||||
return t, ctx.Err()
|
||||
case t, ok := <-c.ch:
|
||||
if !ok {
|
||||
return t, ErrNoMoreInput
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
var t T
|
||||
return t, false, ctx.Err()
|
||||
case t, ok := <-c.ch:
|
||||
if !ok {
|
||||
return t, false, ErrNoMoreInput
|
||||
}
|
||||
|
||||
return t, true, nil
|
||||
default:
|
||||
var t T
|
||||
return t, false, nil
|
||||
}
|
||||
}
|
||||
31
internal/services/syncservice/utils.go
Normal file
31
internal/services/syncservice/utils.go
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sleepCtx sleeps for the given duration, or until the context is canceled.
|
||||
func sleepCtx(ctx context.Context, d time.Duration) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(d):
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user