forked from Silverfish/proton-bridge
fix(GODT-2829): Sync Service fixes
Fix tracking of child jobs. The build stage splits the incoming work even further, but this was not reflected in the wait group counter. This also fixes an issue where the cache was cleared to late. Add more debug info for analysis. Refactor sync state interface in order to have persistent sync rate.
This commit is contained in:
4
Makefile
4
Makefile
@ -299,11 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres
|
||||
> internal/events/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \
|
||||
> internal/services/useridentity/mocks/mocks.go
|
||||
mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \
|
||||
mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice" -package syncservice github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice \
|
||||
ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\
|
||||
StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \
|
||||
> tmp
|
||||
mv tmp internal/services/sync/mocks_test.go
|
||||
mv tmp internal/services/syncservice/mocks_test.go
|
||||
|
||||
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report
|
||||
|
||||
|
||||
@ -29,7 +29,6 @@ type APIClient interface {
|
||||
GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error)
|
||||
GetMessage(ctx context.Context, messageID string) (proton.Message, error)
|
||||
GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error)
|
||||
GetMessageIDs(ctx context.Context, afterID string) ([]string, error)
|
||||
GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error)
|
||||
GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
|
||||
GetAttachment(ctx context.Context, attachmentID string) ([]byte, error)
|
||||
|
||||
@ -117,7 +117,11 @@ func (t *Handler) Execute(
|
||||
}
|
||||
|
||||
t.log.WithField("duration", time.Since(start)).Info("Finished user sync")
|
||||
t.syncFinishedCh <- err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case t.syncFinishedCh <- err:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -179,8 +183,12 @@ func (t *Handler) run(ctx context.Context,
|
||||
if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil {
|
||||
return fmt.Errorf("failed to store message count: %w", err)
|
||||
}
|
||||
|
||||
syncStatus.TotalMessageCount = totalMessageCount
|
||||
}
|
||||
|
||||
syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount)
|
||||
|
||||
if !syncStatus.HasMessages {
|
||||
t.log.Info("Syncing messages")
|
||||
|
||||
@ -198,6 +206,11 @@ func (t *Handler) run(ctx context.Context,
|
||||
t.log,
|
||||
)
|
||||
|
||||
stageContext.metadataFetched = syncStatus.NumSyncedMessages
|
||||
stageContext.totalMessageCount = syncStatus.TotalMessageCount
|
||||
|
||||
defer stageContext.Close()
|
||||
|
||||
t.regulator.Sync(ctx, stageContext)
|
||||
|
||||
// Wait on reply
|
||||
|
||||
@ -57,6 +57,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) {
|
||||
})
|
||||
call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
|
||||
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
|
||||
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
|
||||
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
|
||||
call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
|
||||
tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
@ -125,6 +126,7 @@ func TestTask_StateHasLabels(t *testing.T) {
|
||||
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
|
||||
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
|
||||
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
|
||||
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
|
||||
}
|
||||
|
||||
{
|
||||
@ -170,6 +172,7 @@ func TestTask_StateHasLabelsAndMessageCount(t *testing.T) {
|
||||
})
|
||||
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
|
||||
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
|
||||
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
|
||||
}
|
||||
|
||||
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
|
||||
@ -219,6 +222,8 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) {
|
||||
|
||||
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
|
||||
|
||||
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
|
||||
|
||||
{
|
||||
call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
|
||||
return Status{
|
||||
|
||||
@ -28,8 +28,8 @@ import (
|
||||
)
|
||||
|
||||
type StateProvider interface {
|
||||
AddFailedMessageID(context.Context, string) error
|
||||
RemFailedMessageID(context.Context, string) error
|
||||
AddFailedMessageID(context.Context, ...string) error
|
||||
RemFailedMessageID(context.Context, ...string) error
|
||||
GetSyncStatus(context.Context) (Status, error)
|
||||
ClearSyncStatus(context.Context) error
|
||||
SetHasLabels(context.Context, bool) error
|
||||
@ -85,4 +85,5 @@ type Reporter interface {
|
||||
OnFinished(ctx context.Context)
|
||||
OnError(ctx context.Context, err error)
|
||||
OnProgress(ctx context.Context, delta int64)
|
||||
InitializeProgressCounter(ctx context.Context, current int64, total int64)
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -54,6 +55,9 @@ type Job struct {
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
downloadCache *DownloadCache
|
||||
|
||||
metadataFetched int64
|
||||
totalMessageCount int64
|
||||
}
|
||||
|
||||
func NewJob(ctx context.Context,
|
||||
@ -178,6 +182,36 @@ func (s *childJob) userID() string {
|
||||
return s.job.userID
|
||||
}
|
||||
|
||||
func (s *childJob) chunkDivide(chunks [][]proton.FullMessage) []childJob {
|
||||
numChunks := len(chunks)
|
||||
|
||||
if numChunks == 1 {
|
||||
return []childJob{*s}
|
||||
}
|
||||
|
||||
result := make([]childJob, numChunks)
|
||||
for i := 0; i < numChunks-1; i++ {
|
||||
result[i] = s.job.newChildJob(chunks[i][len(chunks[i])-1].ID, int64(len(chunks[i])))
|
||||
collectIDs(&result[i], chunks[i])
|
||||
}
|
||||
|
||||
result[numChunks-1] = *s
|
||||
collectIDs(&result[numChunks-1], chunks[numChunks-1])
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func collectIDs(j *childJob, msgs []proton.FullMessage) {
|
||||
j.cachedAttachmentIDs = make([]string, 0, len(msgs))
|
||||
j.cachedMessageIDs = make([]string, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
j.cachedMessageIDs = append(j.cachedMessageIDs, msg.ID)
|
||||
for _, attach := range msg.Attachments {
|
||||
j.cachedAttachmentIDs = append(j.cachedAttachmentIDs, attach.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *childJob) onFinished(ctx context.Context) {
|
||||
s.job.log.Infof("Child job finished")
|
||||
s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
|
||||
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
|
||||
|
||||
// Package sync is a generated GoMock package.
|
||||
// Package syncservice is a generated GoMock package.
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
@ -53,22 +53,6 @@ func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(ApplyRequest)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockBuildStageInput is a mock of BuildStageInput interface.
|
||||
type MockBuildStageInput struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -107,22 +91,6 @@ func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(BuildRequest)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockBuildStageOutput is a mock of BuildStageOutput interface.
|
||||
type MockBuildStageOutput struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -208,22 +176,6 @@ func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(DownloadRequest)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockDownloadStageOutput is a mock of DownloadStageOutput interface.
|
||||
type MockDownloadStageOutput struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -309,22 +261,6 @@ func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0)
|
||||
}
|
||||
|
||||
// TryConsume mocks base method.
|
||||
func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryConsume", arg0)
|
||||
ret0, _ := ret[0].(*Job)
|
||||
ret1, _ := ret[1].(bool)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// TryConsume indicates an expected call of TryConsume.
|
||||
func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0)
|
||||
}
|
||||
|
||||
// MockMetadataStageOutput is a mock of MetadataStageOutput interface.
|
||||
type MockMetadataStageOutput struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -396,17 +332,22 @@ func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder {
|
||||
}
|
||||
|
||||
// AddFailedMessageID mocks base method.
|
||||
func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error {
|
||||
func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 ...string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1)
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "AddFailedMessageID", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddFailedMessageID indicates an expected call of AddFailedMessageID.
|
||||
func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
|
||||
func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1)
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), varargs...)
|
||||
}
|
||||
|
||||
// ClearSyncStatus mocks base method.
|
||||
@ -439,17 +380,22 @@ func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock
|
||||
}
|
||||
|
||||
// RemFailedMessageID mocks base method.
|
||||
func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error {
|
||||
func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 ...string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1)
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "RemFailedMessageID", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemFailedMessageID indicates an expected call of RemFailedMessageID.
|
||||
func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
|
||||
func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1)
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), varargs...)
|
||||
}
|
||||
|
||||
// SetHasLabels mocks base method.
|
||||
@ -777,21 +723,6 @@ func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetMessageIDs mocks base method.
|
||||
func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMessageIDs indicates an expected call of GetMessageIDs.
|
||||
func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetMessageMetadataPage mocks base method.
|
||||
func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@ -830,6 +761,18 @@ func (m *MockReporter) EXPECT() *MockReporterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// InitializeProgressCounter mocks base method.
|
||||
func (m *MockReporter) InitializeProgressCounter(arg0 context.Context, arg1, arg2 int64) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "InitializeProgressCounter", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// InitializeProgressCounter indicates an expected call of InitializeProgressCounter.
|
||||
func (mr *MockReporterMockRecorder) InitializeProgressCounter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeProgressCounter", reflect.TypeOf((*MockReporter)(nil).InitializeProgressCounter), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// OnError mocks base method.
|
||||
func (m *MockReporter) OnError(arg0 context.Context, arg1 error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -48,7 +48,7 @@ func NewService(reporter reporter.Reporter,
|
||||
|
||||
return &Service{
|
||||
limits: limits,
|
||||
metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem),
|
||||
metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem, panicHandler),
|
||||
downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler),
|
||||
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
|
||||
applyStage: NewApplyStage(applyCh),
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -44,7 +45,15 @@ func NewApplyStage(input ApplyStageInput) *ApplyStage {
|
||||
}
|
||||
|
||||
func (a *ApplyStage) Run(group *async.Group) {
|
||||
group.Once(a.run)
|
||||
group.Once(func(ctx context.Context) {
|
||||
logging.DoAnnotated(
|
||||
ctx,
|
||||
func(ctx context.Context) {
|
||||
a.run(ctx)
|
||||
},
|
||||
logging.Labels{"sync-stage": "apply"},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *ApplyStage) run(ctx context.Context) {
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"runtime"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
@ -70,7 +71,15 @@ func NewBuildStage(
|
||||
}
|
||||
|
||||
func (b *BuildStage) Run(group *async.Group) {
|
||||
group.Once(b.run)
|
||||
group.Once(func(ctx context.Context) {
|
||||
logging.DoAnnotated(
|
||||
ctx,
|
||||
func(ctx context.Context) {
|
||||
b.run(ctx)
|
||||
},
|
||||
logging.Labels{"sync-stage": "build"},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func (b *BuildStage) run(ctx context.Context) {
|
||||
@ -94,7 +103,21 @@ func (b *BuildStage) run(ctx context.Context) {
|
||||
err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
// This stage will split our existing job into many smaller bits. We need to update the Parent Job so
|
||||
// that it correctly tracks the lifetime of extra jobs. Additionally, we also need to make sure
|
||||
// that only the last chunk contains the metadata to clear the cache.
|
||||
chunkedJobs := req.chunkDivide(chunks)
|
||||
|
||||
for idx, chunk := range chunks {
|
||||
if chunkedJobs[idx].checkCancelled() {
|
||||
// Cancel all other chunks.
|
||||
for i := idx + 1; i < len(chunkedJobs); i++ {
|
||||
chunkedJobs[i].checkCancelled()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) {
|
||||
defer async.HandlePanic(b.panicHandler)
|
||||
|
||||
@ -135,21 +158,29 @@ func (b *BuildStage) run(ctx context.Context) {
|
||||
return BuildResult{}, nil
|
||||
}
|
||||
|
||||
if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil {
|
||||
req.job.log.WithError(err).Error("Failed to remove failed message ID")
|
||||
}
|
||||
|
||||
return res, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
success := xslices.Filter(result, func(t BuildResult) bool {
|
||||
return t.Update != nil
|
||||
})
|
||||
|
||||
if len(success) > 0 {
|
||||
successIDs := xslices.Map(success, func(t BuildResult) string {
|
||||
return t.MessageID
|
||||
})
|
||||
|
||||
if err := req.job.state.RemFailedMessageID(req.getContext(), successIDs...); err != nil {
|
||||
req.job.log.WithError(err).Error("Failed to remove failed message ID")
|
||||
}
|
||||
}
|
||||
|
||||
b.output.Produce(ctx, ApplyRequest{
|
||||
childJob: req.childJob,
|
||||
messages: xslices.Filter(result, func(t BuildResult) bool {
|
||||
return t.Update != nil
|
||||
}),
|
||||
childJob: chunkedJobs[idx],
|
||||
messages: success,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -153,7 +153,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
|
||||
buildError := errors.New("it failed")
|
||||
|
||||
tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError)
|
||||
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
|
||||
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
|
||||
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
|
||||
"userID": "u",
|
||||
"messageID": "MSG",
|
||||
@ -204,7 +204,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin
|
||||
childJob := tj.job.newChildJob("f", 10)
|
||||
tj.job.end()
|
||||
|
||||
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
|
||||
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
|
||||
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
|
||||
"userID": "u",
|
||||
"messageID": "MSG",
|
||||
|
||||
@ -112,8 +112,11 @@ func (d *DownloadStage) run(ctx context.Context) {
|
||||
}
|
||||
|
||||
var attData [][]byte
|
||||
if msg.NumAttachments > 0 {
|
||||
attData = make([][]byte, msg.NumAttachments)
|
||||
|
||||
numAttachments := len(msg.Attachments)
|
||||
|
||||
if numAttachments > 0 {
|
||||
attData = make([][]byte, numAttachments)
|
||||
}
|
||||
|
||||
return proton.FullMessage{Message: msg, AttData: attData}, nil
|
||||
@ -139,7 +142,8 @@ func (d *DownloadStage) run(ctx context.Context) {
|
||||
attachmentIDs := make([]string, 0, len(result))
|
||||
|
||||
for msgIdx, v := range result {
|
||||
for attIdx := 0; attIdx < v.NumAttachments; attIdx++ {
|
||||
numAttachments := len(v.Attachments)
|
||||
for attIdx := 0; attIdx < numAttachments; attIdx++ {
|
||||
attachmentIndices = append(attachmentIndices, attachmentMeta{
|
||||
msgIdx: msgIdx,
|
||||
attIdx: attIdx,
|
||||
|
||||
@ -333,8 +333,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
|
||||
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
|
||||
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{
|
||||
MessageMetadata: proton.MessageMetadata{
|
||||
ID: "msg",
|
||||
NumAttachments: 1,
|
||||
ID: "msg",
|
||||
},
|
||||
Header: "",
|
||||
ParsedHeaders: nil,
|
||||
@ -436,7 +435,7 @@ func buildDownloadStageAttachments(msg *proton.FullMessage, index int) {
|
||||
func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) {
|
||||
msg.Attachments = make([]proton.Attachment, count)
|
||||
msg.AttData = make([][]byte, count)
|
||||
msg.NumAttachments = count
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i)
|
||||
msg.Attachments[i] = proton.Attachment{
|
||||
|
||||
@ -19,11 +19,12 @@ package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/network"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -37,80 +38,91 @@ type MetadataStage struct {
|
||||
output MetadataStageOutput
|
||||
input MetadataStageInput
|
||||
maxDownloadMem uint64
|
||||
jobs []*metadataIterator
|
||||
log *logrus.Entry
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage {
|
||||
return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")}
|
||||
func NewMetadataStage(
|
||||
input MetadataStageInput,
|
||||
output MetadataStageOutput,
|
||||
maxDownloadMem uint64,
|
||||
panicHandler async.PanicHandler,
|
||||
) *MetadataStage {
|
||||
return &MetadataStage{
|
||||
input: input,
|
||||
output: output,
|
||||
maxDownloadMem: maxDownloadMem,
|
||||
log: logrus.WithField("sync-stage", "metadata"),
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
}
|
||||
|
||||
const MetadataPageSize = 150
|
||||
const MetadataMaxMessages = 250
|
||||
|
||||
func (m MetadataStage) Run(group *async.Group) {
|
||||
func (m *MetadataStage) Run(group *async.Group) {
|
||||
group.Once(func(ctx context.Context) {
|
||||
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
|
||||
logging.DoAnnotated(
|
||||
ctx,
|
||||
func(ctx context.Context) {
|
||||
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
|
||||
},
|
||||
logging.Labels{"sync-stage": "metadata"},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
|
||||
func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
|
||||
defer m.output.Close()
|
||||
|
||||
group := async.NewGroup(ctx, m.panicHandler)
|
||||
defer group.CancelAndWait()
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if new job has been submitted
|
||||
job, ok, err := m.input.TryConsume(ctx)
|
||||
job, err := m.input.Consume(ctx)
|
||||
if err != nil {
|
||||
m.log.WithError(err).Error("Error trying to retrieve more work")
|
||||
if !(errors.Is(err, context.Canceled) || errors.Is(err, ErrNoMoreInput)) {
|
||||
m.log.WithError(err).Error("Error trying to retrieve more work")
|
||||
}
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
job.begin()
|
||||
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
|
||||
if err != nil {
|
||||
job.onError(err)
|
||||
continue
|
||||
}
|
||||
m.jobs = append(m.jobs, state)
|
||||
|
||||
job.begin()
|
||||
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
|
||||
if err != nil {
|
||||
job.onError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Iterate over all jobs and produce work.
|
||||
for i := 0; i < len(m.jobs); {
|
||||
job := m.jobs[i]
|
||||
group.Once(func(ctx context.Context) {
|
||||
for {
|
||||
if state.stage.ctx.Err() != nil {
|
||||
state.stage.end()
|
||||
return
|
||||
}
|
||||
|
||||
// If the job's context has been cancelled, remove from the list.
|
||||
if job.stage.ctx.Err() != nil {
|
||||
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
|
||||
job.stage.end()
|
||||
continue
|
||||
// Check for more work.
|
||||
output, hasMore, err := state.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
|
||||
if err != nil {
|
||||
state.stage.onError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// If there is actually more work, push it down the pipeline.
|
||||
if len(output.ids) != 0 {
|
||||
state.stage.metadataFetched += int64(len(output.ids))
|
||||
job.log.Debugf("Metada collected: %v/%v", state.stage.metadataFetched, state.stage.totalMessageCount)
|
||||
|
||||
m.output.Produce(ctx, output)
|
||||
}
|
||||
|
||||
// If this job has no more work left, signal completion.
|
||||
if !hasMore {
|
||||
state.stage.end()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for more work.
|
||||
output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
|
||||
if err != nil {
|
||||
job.stage.onError(err)
|
||||
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
|
||||
continue
|
||||
}
|
||||
|
||||
// If there is actually more work, push it down the pipeline.
|
||||
if len(output.ids) != 0 {
|
||||
m.output.Produce(ctx, output)
|
||||
}
|
||||
|
||||
// If this job has no more work left, signal completion.
|
||||
if !hasMore {
|
||||
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
|
||||
job.stage.end()
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) {
|
||||
output := NewChannelConsumerProducer[DownloadRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
|
||||
|
||||
numMessages := 50
|
||||
messageSize := 100
|
||||
@ -86,7 +86,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
|
||||
output := NewChannelConsumerProducer[DownloadRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
|
||||
|
||||
go func() {
|
||||
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
||||
@ -135,7 +135,7 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
|
||||
output := NewChannelConsumerProducer[DownloadRequest]()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
|
||||
metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
|
||||
|
||||
numMessages := 50
|
||||
messageSize := 100
|
||||
|
||||
@ -31,7 +31,6 @@ var ErrNoMoreInput = errors.New("no more input")
|
||||
|
||||
type StageInputConsumer[T any] interface {
|
||||
Consume(ctx context.Context) (T, error)
|
||||
TryConsume(ctx context.Context) (T, bool, error)
|
||||
}
|
||||
|
||||
type ChannelConsumerProducer[T any] struct {
|
||||
@ -66,20 +65,3 @@ func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
var t T
|
||||
return t, false, ctx.Err()
|
||||
case t, ok := <-c.ch:
|
||||
if !ok {
|
||||
return t, false, ErrNoMoreInput
|
||||
}
|
||||
|
||||
return t, true, nil
|
||||
default:
|
||||
var t T
|
||||
return t, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user