From aa77a67a1c9de4d143dd5edce85908b54d6f9bfd Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Fri, 25 Aug 2023 15:01:03 +0200 Subject: [PATCH] fix(GODT-2829): Sync Service fixes Fix tracking of child jobs. The build stage splits the incoming work even further, but this was not reflected in the wait group counter. This also fixes an issue where the cache was cleared to late. Add more debug info for analysis. Refactor sync state interface in order to have persistent sync rate. --- Makefile | 4 +- internal/services/syncservice/api_client.go | 1 - internal/services/syncservice/handler.go | 15 ++- internal/services/syncservice/handler_test.go | 5 + internal/services/syncservice/interfaces.go | 5 +- internal/services/syncservice/job.go | 34 +++++ internal/services/syncservice/mocks_test.go | 121 +++++------------- internal/services/syncservice/service.go | 2 +- internal/services/syncservice/stage_apply.go | 11 +- internal/services/syncservice/stage_build.go | 51 ++++++-- .../services/syncservice/stage_build_test.go | 4 +- .../services/syncservice/stage_download.go | 10 +- .../syncservice/stage_download_test.go | 5 +- .../services/syncservice/stage_metadata.go | 118 +++++++++-------- .../syncservice/stage_metadata_test.go | 6 +- internal/services/syncservice/stage_output.go | 18 --- 16 files changed, 221 insertions(+), 189 deletions(-) diff --git a/Makefile b/Makefile index a6471daf..6f010be6 100644 --- a/Makefile +++ b/Makefile @@ -299,11 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres > internal/events/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \ > internal/services/useridentity/mocks/mocks.go - mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \ + mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice" -package syncservice github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice \ ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\ StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \ > tmp - mv tmp internal/services/sync/mocks_test.go + mv tmp internal/services/syncservice/mocks_test.go lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report diff --git a/internal/services/syncservice/api_client.go b/internal/services/syncservice/api_client.go index 6e4f6ca0..b6c0e867 100644 --- a/internal/services/syncservice/api_client.go +++ b/internal/services/syncservice/api_client.go @@ -29,7 +29,6 @@ type APIClient interface { GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error) GetMessage(ctx context.Context, messageID string) (proton.Message, error) GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error) - GetMessageIDs(ctx context.Context, afterID string) ([]string, error) GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error) GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error GetAttachment(ctx context.Context, attachmentID string) ([]byte, error) diff --git a/internal/services/syncservice/handler.go b/internal/services/syncservice/handler.go index c3536387..dbccf55b 100644 --- a/internal/services/syncservice/handler.go +++ b/internal/services/syncservice/handler.go @@ -117,7 +117,11 @@ func (t *Handler) Execute( } t.log.WithField("duration", time.Since(start)).Info("Finished user sync") - t.syncFinishedCh <- err + select { + case <-ctx.Done(): + return + case t.syncFinishedCh <- err: + } }) } @@ -179,8 +183,12 @@ func (t *Handler) run(ctx context.Context, if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil { return fmt.Errorf("failed to store message count: %w", err) } + + syncStatus.TotalMessageCount = totalMessageCount } + syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount) + if !syncStatus.HasMessages { t.log.Info("Syncing messages") @@ -198,6 +206,11 @@ func (t *Handler) run(ctx context.Context, t.log, ) + stageContext.metadataFetched = syncStatus.NumSyncedMessages + stageContext.totalMessageCount = syncStatus.TotalMessageCount + + defer stageContext.Close() + t.regulator.Sync(ctx, stageContext) // Wait on reply diff --git a/internal/services/syncservice/handler_test.go b/internal/services/syncservice/handler_test.go index 9ddf72d1..ff71e377 100644 --- a/internal/services/syncservice/handler_test.go +++ b/internal/services/syncservice/handler_test.go @@ -57,6 +57,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) { }) call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil) call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) { @@ -125,6 +126,7 @@ func TestTask_StateHasLabels(t *testing.T) { call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) } { @@ -170,6 +172,7 @@ func TestTask_StateHasLabelsAndMessageCount(t *testing.T) { }) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) } tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta)) @@ -219,6 +222,8 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) { tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) + { call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) { return Status{ diff --git a/internal/services/syncservice/interfaces.go b/internal/services/syncservice/interfaces.go index 2af9da33..39dd9d28 100644 --- a/internal/services/syncservice/interfaces.go +++ b/internal/services/syncservice/interfaces.go @@ -28,8 +28,8 @@ import ( ) type StateProvider interface { - AddFailedMessageID(context.Context, string) error - RemFailedMessageID(context.Context, string) error + AddFailedMessageID(context.Context, ...string) error + RemFailedMessageID(context.Context, ...string) error GetSyncStatus(context.Context) (Status, error) ClearSyncStatus(context.Context) error SetHasLabels(context.Context, bool) error @@ -85,4 +85,5 @@ type Reporter interface { OnFinished(ctx context.Context) OnError(ctx context.Context, err error) OnProgress(ctx context.Context, delta int64) + InitializeProgressCounter(ctx context.Context, current int64, total int64) } diff --git a/internal/services/syncservice/job.go b/internal/services/syncservice/job.go index 1985cf25..01b2dc1a 100644 --- a/internal/services/syncservice/job.go +++ b/internal/services/syncservice/job.go @@ -24,6 +24,7 @@ import ( "sync" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" "github.com/sirupsen/logrus" ) @@ -54,6 +55,9 @@ type Job struct { panicHandler async.PanicHandler downloadCache *DownloadCache + + metadataFetched int64 + totalMessageCount int64 } func NewJob(ctx context.Context, @@ -178,6 +182,36 @@ func (s *childJob) userID() string { return s.job.userID } +func (s *childJob) chunkDivide(chunks [][]proton.FullMessage) []childJob { + numChunks := len(chunks) + + if numChunks == 1 { + return []childJob{*s} + } + + result := make([]childJob, numChunks) + for i := 0; i < numChunks-1; i++ { + result[i] = s.job.newChildJob(chunks[i][len(chunks[i])-1].ID, int64(len(chunks[i]))) + collectIDs(&result[i], chunks[i]) + } + + result[numChunks-1] = *s + collectIDs(&result[numChunks-1], chunks[numChunks-1]) + + return result +} + +func collectIDs(j *childJob, msgs []proton.FullMessage) { + j.cachedAttachmentIDs = make([]string, 0, len(msgs)) + j.cachedMessageIDs = make([]string, 0, len(msgs)) + for _, msg := range msgs { + j.cachedMessageIDs = append(j.cachedMessageIDs, msg.ID) + for _, attach := range msg.Attachments { + j.cachedAttachmentIDs = append(j.cachedAttachmentIDs, attach.ID) + } + } +} + func (s *childJob) onFinished(ctx context.Context) { s.job.log.Infof("Child job finished") s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount) diff --git a/internal/services/syncservice/mocks_test.go b/internal/services/syncservice/mocks_test.go index cdb09c78..aaac4474 100644 --- a/internal/services/syncservice/mocks_test.go +++ b/internal/services/syncservice/mocks_test.go @@ -1,7 +1,7 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier) +// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier) -// Package sync is a generated GoMock package. +// Package syncservice is a generated GoMock package. package syncservice import ( @@ -53,22 +53,6 @@ func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(ApplyRequest) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0) -} - // MockBuildStageInput is a mock of BuildStageInput interface. type MockBuildStageInput struct { ctrl *gomock.Controller @@ -107,22 +91,6 @@ func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(BuildRequest) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0) -} - // MockBuildStageOutput is a mock of BuildStageOutput interface. type MockBuildStageOutput struct { ctrl *gomock.Controller @@ -208,22 +176,6 @@ func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(DownloadRequest) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0) -} - // MockDownloadStageOutput is a mock of DownloadStageOutput interface. type MockDownloadStageOutput struct { ctrl *gomock.Controller @@ -309,22 +261,6 @@ func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(*Job) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0) -} - // MockMetadataStageOutput is a mock of MetadataStageOutput interface. type MockMetadataStageOutput struct { ctrl *gomock.Controller @@ -396,17 +332,22 @@ func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder { } // AddFailedMessageID mocks base method. -func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error { +func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 ...string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1) + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "AddFailedMessageID", varargs...) ret0, _ := ret[0].(error) return ret0 } // AddFailedMessageID indicates an expected call of AddFailedMessageID. -func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1) + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), varargs...) } // ClearSyncStatus mocks base method. @@ -439,17 +380,22 @@ func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock } // RemFailedMessageID mocks base method. -func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error { +func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 ...string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1) + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "RemFailedMessageID", varargs...) ret0, _ := ret[0].(error) return ret0 } // RemFailedMessageID indicates an expected call of RemFailedMessageID. -func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1) + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), varargs...) } // SetHasLabels mocks base method. @@ -777,21 +723,6 @@ func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1) } -// GetMessageIDs mocks base method. -func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetMessageIDs indicates an expected call of GetMessageIDs. -func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1) -} - // GetMessageMetadataPage mocks base method. func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) { m.ctrl.T.Helper() @@ -830,6 +761,18 @@ func (m *MockReporter) EXPECT() *MockReporterMockRecorder { return m.recorder } +// InitializeProgressCounter mocks base method. +func (m *MockReporter) InitializeProgressCounter(arg0 context.Context, arg1, arg2 int64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "InitializeProgressCounter", arg0, arg1, arg2) +} + +// InitializeProgressCounter indicates an expected call of InitializeProgressCounter. +func (mr *MockReporterMockRecorder) InitializeProgressCounter(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeProgressCounter", reflect.TypeOf((*MockReporter)(nil).InitializeProgressCounter), arg0, arg1, arg2) +} + // OnError mocks base method. func (m *MockReporter) OnError(arg0 context.Context, arg1 error) { m.ctrl.T.Helper() diff --git a/internal/services/syncservice/service.go b/internal/services/syncservice/service.go index cdc77a7b..10f7087e 100644 --- a/internal/services/syncservice/service.go +++ b/internal/services/syncservice/service.go @@ -48,7 +48,7 @@ func NewService(reporter reporter.Reporter, return &Service{ limits: limits, - metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem), + metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem, panicHandler), downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler), buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter), applyStage: NewApplyStage(applyCh), diff --git a/internal/services/syncservice/stage_apply.go b/internal/services/syncservice/stage_apply.go index 117c8142..fa47238f 100644 --- a/internal/services/syncservice/stage_apply.go +++ b/internal/services/syncservice/stage_apply.go @@ -22,6 +22,7 @@ import ( "errors" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" "github.com/sirupsen/logrus" ) @@ -44,7 +45,15 @@ func NewApplyStage(input ApplyStageInput) *ApplyStage { } func (a *ApplyStage) Run(group *async.Group) { - group.Once(a.run) + group.Once(func(ctx context.Context) { + logging.DoAnnotated( + ctx, + func(ctx context.Context) { + a.run(ctx) + }, + logging.Labels{"sync-stage": "apply"}, + ) + }) } func (a *ApplyStage) run(ctx context.Context) { diff --git a/internal/services/syncservice/stage_build.go b/internal/services/syncservice/stage_build.go index d4b0f3b4..a9957eff 100644 --- a/internal/services/syncservice/stage_build.go +++ b/internal/services/syncservice/stage_build.go @@ -24,6 +24,7 @@ import ( "runtime" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" @@ -70,7 +71,15 @@ func NewBuildStage( } func (b *BuildStage) Run(group *async.Group) { - group.Once(b.run) + group.Once(func(ctx context.Context) { + logging.DoAnnotated( + ctx, + func(ctx context.Context) { + b.run(ctx) + }, + logging.Labels{"sync-stage": "build"}, + ) + }) } func (b *BuildStage) run(ctx context.Context) { @@ -94,7 +103,21 @@ func (b *BuildStage) run(ctx context.Context) { err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem) - for _, chunk := range chunks { + // This stage will split our existing job into many smaller bits. We need to update the Parent Job so + // that it correctly tracks the lifetime of extra jobs. Additionally, we also need to make sure + // that only the last chunk contains the metadata to clear the cache. + chunkedJobs := req.chunkDivide(chunks) + + for idx, chunk := range chunks { + if chunkedJobs[idx].checkCancelled() { + // Cancel all other chunks. + for i := idx + 1; i < len(chunkedJobs); i++ { + chunkedJobs[i].checkCancelled() + } + + return nil + } + result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) { defer async.HandlePanic(b.panicHandler) @@ -135,21 +158,29 @@ func (b *BuildStage) run(ctx context.Context) { return BuildResult{}, nil } - if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil { - req.job.log.WithError(err).Error("Failed to remove failed message ID") - } - return res, nil }) if err != nil { return err } + success := xslices.Filter(result, func(t BuildResult) bool { + return t.Update != nil + }) + + if len(success) > 0 { + successIDs := xslices.Map(success, func(t BuildResult) string { + return t.MessageID + }) + + if err := req.job.state.RemFailedMessageID(req.getContext(), successIDs...); err != nil { + req.job.log.WithError(err).Error("Failed to remove failed message ID") + } + } + b.output.Produce(ctx, ApplyRequest{ - childJob: req.childJob, - messages: xslices.Filter(result, func(t BuildResult) bool { - return t.Update != nil - }), + childJob: chunkedJobs[idx], + messages: success, }) } diff --git a/internal/services/syncservice/stage_build_test.go b/internal/services/syncservice/stage_build_test.go index ae473faa..4c803173 100644 --- a/internal/services/syncservice/stage_build_test.go +++ b/internal/services/syncservice/stage_build_test.go @@ -153,7 +153,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) { buildError := errors.New("it failed") tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError) - tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG")) + tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"})) mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{ "userID": "u", "messageID": "MSG", @@ -204,7 +204,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin childJob := tj.job.newChildJob("f", 10) tj.job.end() - tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG")) + tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"})) mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{ "userID": "u", "messageID": "MSG", diff --git a/internal/services/syncservice/stage_download.go b/internal/services/syncservice/stage_download.go index ba9d030e..a0d6d981 100644 --- a/internal/services/syncservice/stage_download.go +++ b/internal/services/syncservice/stage_download.go @@ -112,8 +112,11 @@ func (d *DownloadStage) run(ctx context.Context) { } var attData [][]byte - if msg.NumAttachments > 0 { - attData = make([][]byte, msg.NumAttachments) + + numAttachments := len(msg.Attachments) + + if numAttachments > 0 { + attData = make([][]byte, numAttachments) } return proton.FullMessage{Message: msg, AttData: attData}, nil @@ -139,7 +142,8 @@ func (d *DownloadStage) run(ctx context.Context) { attachmentIDs := make([]string, 0, len(result)) for msgIdx, v := range result { - for attIdx := 0; attIdx < v.NumAttachments; attIdx++ { + numAttachments := len(v.Attachments) + for attIdx := 0; attIdx < numAttachments; attIdx++ { attachmentIndices = append(attachmentIndices, attachmentMeta{ msgIdx: msgIdx, attIdx: attIdx, diff --git a/internal/services/syncservice/stage_download_test.go b/internal/services/syncservice/stage_download_test.go index 775244dd..8225ddcb 100644 --- a/internal/services/syncservice/stage_download_test.go +++ b/internal/services/syncservice/stage_download_test.go @@ -333,8 +333,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) { tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{}) tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{ MessageMetadata: proton.MessageMetadata{ - ID: "msg", - NumAttachments: 1, + ID: "msg", }, Header: "", ParsedHeaders: nil, @@ -436,7 +435,7 @@ func buildDownloadStageAttachments(msg *proton.FullMessage, index int) { func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) { msg.Attachments = make([]proton.Attachment, count) msg.AttData = make([][]byte, count) - msg.NumAttachments = count + for i := 0; i < count; i++ { data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i) msg.Attachments[i] = proton.Attachment{ diff --git a/internal/services/syncservice/stage_metadata.go b/internal/services/syncservice/stage_metadata.go index a186e73e..b39fcb65 100644 --- a/internal/services/syncservice/stage_metadata.go +++ b/internal/services/syncservice/stage_metadata.go @@ -19,11 +19,12 @@ package syncservice import ( "context" + "errors" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/proton-bridge/v3/internal/network" - "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" ) @@ -37,80 +38,91 @@ type MetadataStage struct { output MetadataStageOutput input MetadataStageInput maxDownloadMem uint64 - jobs []*metadataIterator log *logrus.Entry + panicHandler async.PanicHandler } -func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage { - return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")} +func NewMetadataStage( + input MetadataStageInput, + output MetadataStageOutput, + maxDownloadMem uint64, + panicHandler async.PanicHandler, +) *MetadataStage { + return &MetadataStage{ + input: input, + output: output, + maxDownloadMem: maxDownloadMem, + log: logrus.WithField("sync-stage", "metadata"), + panicHandler: panicHandler, + } } const MetadataPageSize = 150 const MetadataMaxMessages = 250 -func (m MetadataStage) Run(group *async.Group) { +func (m *MetadataStage) Run(group *async.Group) { group.Once(func(ctx context.Context) { - m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{}) + logging.DoAnnotated( + ctx, + func(ctx context.Context) { + m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{}) + }, + logging.Labels{"sync-stage": "metadata"}, + ) }) } -func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) { +func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) { defer m.output.Close() + group := async.NewGroup(ctx, m.panicHandler) + defer group.CancelAndWait() + for { - if ctx.Err() != nil { - return - } - - // Check if new job has been submitted - job, ok, err := m.input.TryConsume(ctx) + job, err := m.input.Consume(ctx) if err != nil { - m.log.WithError(err).Error("Error trying to retrieve more work") + if !(errors.Is(err, context.Canceled) || errors.Is(err, ErrNoMoreInput)) { + m.log.WithError(err).Error("Error trying to retrieve more work") + } return } - if ok { - job.begin() - state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) - if err != nil { - job.onError(err) - continue - } - m.jobs = append(m.jobs, state) + + job.begin() + state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) + if err != nil { + job.onError(err) + continue } - // Iterate over all jobs and produce work. - for i := 0; i < len(m.jobs); { - job := m.jobs[i] + group.Once(func(ctx context.Context) { + for { + if state.stage.ctx.Err() != nil { + state.stage.end() + return + } - // If the job's context has been cancelled, remove from the list. - if job.stage.ctx.Err() != nil { - m.jobs = xslices.RemoveUnordered(m.jobs, i, 1) - job.stage.end() - continue + // Check for more work. + output, hasMore, err := state.Next(m.maxDownloadMem, metadataPageSize, maxMessages) + if err != nil { + state.stage.onError(err) + return + } + + // If there is actually more work, push it down the pipeline. + if len(output.ids) != 0 { + state.stage.metadataFetched += int64(len(output.ids)) + job.log.Debugf("Metada collected: %v/%v", state.stage.metadataFetched, state.stage.totalMessageCount) + + m.output.Produce(ctx, output) + } + + // If this job has no more work left, signal completion. + if !hasMore { + state.stage.end() + return + } } - - // Check for more work. - output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages) - if err != nil { - job.stage.onError(err) - m.jobs = xslices.RemoveUnordered(m.jobs, i, 1) - continue - } - - // If there is actually more work, push it down the pipeline. - if len(output.ids) != 0 { - m.output.Produce(ctx, output) - } - - // If this job has no more work left, signal completion. - if !hasMore { - m.jobs = xslices.RemoveUnordered(m.jobs, i, 1) - job.stage.end() - continue - } - - i++ - } + }) } } diff --git a/internal/services/syncservice/stage_metadata_test.go b/internal/services/syncservice/stage_metadata_test.go index efa3ebe9..f77adc71 100644 --- a/internal/services/syncservice/stage_metadata_test.go +++ b/internal/services/syncservice/stage_metadata_test.go @@ -48,7 +48,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) { output := NewChannelConsumerProducer[DownloadRequest]() ctx, cancel := context.WithCancel(context.Background()) - metadata := NewMetadataStage(input, output, TestMaxDownloadMem) + metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{}) numMessages := 50 messageSize := 100 @@ -86,7 +86,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) { output := NewChannelConsumerProducer[DownloadRequest]() ctx, cancel := context.WithCancel(context.Background()) - metadata := NewMetadataStage(input, output, TestMaxDownloadMem) + metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{}) go func() { metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{}) @@ -135,7 +135,7 @@ func TestMetadataStage_RunInterleaved(t *testing.T) { output := NewChannelConsumerProducer[DownloadRequest]() ctx, cancel := context.WithCancel(context.Background()) - metadata := NewMetadataStage(input, output, TestMaxDownloadMem) + metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{}) numMessages := 50 messageSize := 100 diff --git a/internal/services/syncservice/stage_output.go b/internal/services/syncservice/stage_output.go index 4278862c..d595b0d5 100644 --- a/internal/services/syncservice/stage_output.go +++ b/internal/services/syncservice/stage_output.go @@ -31,7 +31,6 @@ var ErrNoMoreInput = errors.New("no more input") type StageInputConsumer[T any] interface { Consume(ctx context.Context) (T, error) - TryConsume(ctx context.Context) (T, bool, error) } type ChannelConsumerProducer[T any] struct { @@ -66,20 +65,3 @@ func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) { return t, nil } } - -func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) { - select { - case <-ctx.Done(): - var t T - return t, false, ctx.Err() - case t, ok := <-c.ch: - if !ok { - return t, false, ErrNoMoreInput - } - - return t, true, nil - default: - var t T - return t, false, nil - } -}