diff --git a/Makefile b/Makefile index a6471daf..6f010be6 100644 --- a/Makefile +++ b/Makefile @@ -299,11 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres > internal/events/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \ > internal/services/useridentity/mocks/mocks.go - mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \ + mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice" -package syncservice github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice \ ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\ StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \ > tmp - mv tmp internal/services/sync/mocks_test.go + mv tmp internal/services/syncservice/mocks_test.go lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report diff --git a/internal/services/syncservice/api_client.go b/internal/services/syncservice/api_client.go index 6e4f6ca0..b6c0e867 100644 --- a/internal/services/syncservice/api_client.go +++ b/internal/services/syncservice/api_client.go @@ -29,7 +29,6 @@ type APIClient interface { GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error) GetMessage(ctx context.Context, messageID string) (proton.Message, error) GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error) - GetMessageIDs(ctx context.Context, afterID string) ([]string, error) GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error) GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error GetAttachment(ctx context.Context, attachmentID string) ([]byte, error) diff --git a/internal/services/syncservice/handler.go b/internal/services/syncservice/handler.go index c3536387..dbccf55b 100644 --- a/internal/services/syncservice/handler.go +++ b/internal/services/syncservice/handler.go @@ -117,7 +117,11 @@ func (t *Handler) Execute( } t.log.WithField("duration", time.Since(start)).Info("Finished user sync") - t.syncFinishedCh <- err + select { + case <-ctx.Done(): + return + case t.syncFinishedCh <- err: + } }) } @@ -179,8 +183,12 @@ func (t *Handler) run(ctx context.Context, if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil { return fmt.Errorf("failed to store message count: %w", err) } + + syncStatus.TotalMessageCount = totalMessageCount } + syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount) + if !syncStatus.HasMessages { t.log.Info("Syncing messages") @@ -198,6 +206,11 @@ func (t *Handler) run(ctx context.Context, t.log, ) + stageContext.metadataFetched = syncStatus.NumSyncedMessages + stageContext.totalMessageCount = syncStatus.TotalMessageCount + + defer stageContext.Close() + t.regulator.Sync(ctx, stageContext) // Wait on reply diff --git a/internal/services/syncservice/handler_test.go b/internal/services/syncservice/handler_test.go index 9ddf72d1..ff71e377 100644 --- a/internal/services/syncservice/handler_test.go +++ b/internal/services/syncservice/handler_test.go @@ -57,6 +57,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) { }) call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil) call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) { @@ -125,6 +126,7 @@ func TestTask_StateHasLabels(t *testing.T) { call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) } { @@ -170,6 +172,7 @@ func TestTask_StateHasLabelsAndMessageCount(t *testing.T) { }) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) } tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta)) @@ -219,6 +222,8 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) { tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) + { call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) { return Status{ diff --git a/internal/services/syncservice/interfaces.go b/internal/services/syncservice/interfaces.go index 2af9da33..39dd9d28 100644 --- a/internal/services/syncservice/interfaces.go +++ b/internal/services/syncservice/interfaces.go @@ -28,8 +28,8 @@ import ( ) type StateProvider interface { - AddFailedMessageID(context.Context, string) error - RemFailedMessageID(context.Context, string) error + AddFailedMessageID(context.Context, ...string) error + RemFailedMessageID(context.Context, ...string) error GetSyncStatus(context.Context) (Status, error) ClearSyncStatus(context.Context) error SetHasLabels(context.Context, bool) error @@ -85,4 +85,5 @@ type Reporter interface { OnFinished(ctx context.Context) OnError(ctx context.Context, err error) OnProgress(ctx context.Context, delta int64) + InitializeProgressCounter(ctx context.Context, current int64, total int64) } diff --git a/internal/services/syncservice/job.go b/internal/services/syncservice/job.go index 1985cf25..01b2dc1a 100644 --- a/internal/services/syncservice/job.go +++ b/internal/services/syncservice/job.go @@ -24,6 +24,7 @@ import ( "sync" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" "github.com/sirupsen/logrus" ) @@ -54,6 +55,9 @@ type Job struct { panicHandler async.PanicHandler downloadCache *DownloadCache + + metadataFetched int64 + totalMessageCount int64 } func NewJob(ctx context.Context, @@ -178,6 +182,36 @@ func (s *childJob) userID() string { return s.job.userID } +func (s *childJob) chunkDivide(chunks [][]proton.FullMessage) []childJob { + numChunks := len(chunks) + + if numChunks == 1 { + return []childJob{*s} + } + + result := make([]childJob, numChunks) + for i := 0; i < numChunks-1; i++ { + result[i] = s.job.newChildJob(chunks[i][len(chunks[i])-1].ID, int64(len(chunks[i]))) + collectIDs(&result[i], chunks[i]) + } + + result[numChunks-1] = *s + collectIDs(&result[numChunks-1], chunks[numChunks-1]) + + return result +} + +func collectIDs(j *childJob, msgs []proton.FullMessage) { + j.cachedAttachmentIDs = make([]string, 0, len(msgs)) + j.cachedMessageIDs = make([]string, 0, len(msgs)) + for _, msg := range msgs { + j.cachedMessageIDs = append(j.cachedMessageIDs, msg.ID) + for _, attach := range msg.Attachments { + j.cachedAttachmentIDs = append(j.cachedAttachmentIDs, attach.ID) + } + } +} + func (s *childJob) onFinished(ctx context.Context) { s.job.log.Infof("Child job finished") s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount) diff --git a/internal/services/syncservice/mocks_test.go b/internal/services/syncservice/mocks_test.go index cdb09c78..aaac4474 100644 --- a/internal/services/syncservice/mocks_test.go +++ b/internal/services/syncservice/mocks_test.go @@ -1,7 +1,7 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier) +// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier) -// Package sync is a generated GoMock package. +// Package syncservice is a generated GoMock package. package syncservice import ( @@ -53,22 +53,6 @@ func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(ApplyRequest) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0) -} - // MockBuildStageInput is a mock of BuildStageInput interface. type MockBuildStageInput struct { ctrl *gomock.Controller @@ -107,22 +91,6 @@ func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(BuildRequest) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0) -} - // MockBuildStageOutput is a mock of BuildStageOutput interface. type MockBuildStageOutput struct { ctrl *gomock.Controller @@ -208,22 +176,6 @@ func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(DownloadRequest) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0) -} - // MockDownloadStageOutput is a mock of DownloadStageOutput interface. type MockDownloadStageOutput struct { ctrl *gomock.Controller @@ -309,22 +261,6 @@ func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0) } -// TryConsume mocks base method. -func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryConsume", arg0) - ret0, _ := ret[0].(*Job) - ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// TryConsume indicates an expected call of TryConsume. -func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0) -} - // MockMetadataStageOutput is a mock of MetadataStageOutput interface. type MockMetadataStageOutput struct { ctrl *gomock.Controller @@ -396,17 +332,22 @@ func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder { } // AddFailedMessageID mocks base method. -func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error { +func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 ...string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1) + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "AddFailedMessageID", varargs...) ret0, _ := ret[0].(error) return ret0 } // AddFailedMessageID indicates an expected call of AddFailedMessageID. -func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1) + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), varargs...) } // ClearSyncStatus mocks base method. @@ -439,17 +380,22 @@ func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock } // RemFailedMessageID mocks base method. -func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error { +func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 ...string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1) + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "RemFailedMessageID", varargs...) ret0, _ := ret[0].(error) return ret0 } // RemFailedMessageID indicates an expected call of RemFailedMessageID. -func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1) + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), varargs...) } // SetHasLabels mocks base method. @@ -777,21 +723,6 @@ func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1) } -// GetMessageIDs mocks base method. -func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetMessageIDs indicates an expected call of GetMessageIDs. -func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1) -} - // GetMessageMetadataPage mocks base method. func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) { m.ctrl.T.Helper() @@ -830,6 +761,18 @@ func (m *MockReporter) EXPECT() *MockReporterMockRecorder { return m.recorder } +// InitializeProgressCounter mocks base method. +func (m *MockReporter) InitializeProgressCounter(arg0 context.Context, arg1, arg2 int64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "InitializeProgressCounter", arg0, arg1, arg2) +} + +// InitializeProgressCounter indicates an expected call of InitializeProgressCounter. +func (mr *MockReporterMockRecorder) InitializeProgressCounter(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeProgressCounter", reflect.TypeOf((*MockReporter)(nil).InitializeProgressCounter), arg0, arg1, arg2) +} + // OnError mocks base method. func (m *MockReporter) OnError(arg0 context.Context, arg1 error) { m.ctrl.T.Helper() diff --git a/internal/services/syncservice/service.go b/internal/services/syncservice/service.go index cdc77a7b..10f7087e 100644 --- a/internal/services/syncservice/service.go +++ b/internal/services/syncservice/service.go @@ -48,7 +48,7 @@ func NewService(reporter reporter.Reporter, return &Service{ limits: limits, - metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem), + metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem, panicHandler), downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler), buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter), applyStage: NewApplyStage(applyCh), diff --git a/internal/services/syncservice/stage_apply.go b/internal/services/syncservice/stage_apply.go index 117c8142..fa47238f 100644 --- a/internal/services/syncservice/stage_apply.go +++ b/internal/services/syncservice/stage_apply.go @@ -22,6 +22,7 @@ import ( "errors" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" "github.com/sirupsen/logrus" ) @@ -44,7 +45,15 @@ func NewApplyStage(input ApplyStageInput) *ApplyStage { } func (a *ApplyStage) Run(group *async.Group) { - group.Once(a.run) + group.Once(func(ctx context.Context) { + logging.DoAnnotated( + ctx, + func(ctx context.Context) { + a.run(ctx) + }, + logging.Labels{"sync-stage": "apply"}, + ) + }) } func (a *ApplyStage) run(ctx context.Context) { diff --git a/internal/services/syncservice/stage_build.go b/internal/services/syncservice/stage_build.go index d4b0f3b4..a9957eff 100644 --- a/internal/services/syncservice/stage_build.go +++ b/internal/services/syncservice/stage_build.go @@ -24,6 +24,7 @@ import ( "runtime" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" @@ -70,7 +71,15 @@ func NewBuildStage( } func (b *BuildStage) Run(group *async.Group) { - group.Once(b.run) + group.Once(func(ctx context.Context) { + logging.DoAnnotated( + ctx, + func(ctx context.Context) { + b.run(ctx) + }, + logging.Labels{"sync-stage": "build"}, + ) + }) } func (b *BuildStage) run(ctx context.Context) { @@ -94,7 +103,21 @@ func (b *BuildStage) run(ctx context.Context) { err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem) - for _, chunk := range chunks { + // This stage will split our existing job into many smaller bits. We need to update the Parent Job so + // that it correctly tracks the lifetime of extra jobs. Additionally, we also need to make sure + // that only the last chunk contains the metadata to clear the cache. + chunkedJobs := req.chunkDivide(chunks) + + for idx, chunk := range chunks { + if chunkedJobs[idx].checkCancelled() { + // Cancel all other chunks. + for i := idx + 1; i < len(chunkedJobs); i++ { + chunkedJobs[i].checkCancelled() + } + + return nil + } + result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) { defer async.HandlePanic(b.panicHandler) @@ -135,21 +158,29 @@ func (b *BuildStage) run(ctx context.Context) { return BuildResult{}, nil } - if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil { - req.job.log.WithError(err).Error("Failed to remove failed message ID") - } - return res, nil }) if err != nil { return err } + success := xslices.Filter(result, func(t BuildResult) bool { + return t.Update != nil + }) + + if len(success) > 0 { + successIDs := xslices.Map(success, func(t BuildResult) string { + return t.MessageID + }) + + if err := req.job.state.RemFailedMessageID(req.getContext(), successIDs...); err != nil { + req.job.log.WithError(err).Error("Failed to remove failed message ID") + } + } + b.output.Produce(ctx, ApplyRequest{ - childJob: req.childJob, - messages: xslices.Filter(result, func(t BuildResult) bool { - return t.Update != nil - }), + childJob: chunkedJobs[idx], + messages: success, }) } diff --git a/internal/services/syncservice/stage_build_test.go b/internal/services/syncservice/stage_build_test.go index ae473faa..4c803173 100644 --- a/internal/services/syncservice/stage_build_test.go +++ b/internal/services/syncservice/stage_build_test.go @@ -153,7 +153,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) { buildError := errors.New("it failed") tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError) - tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG")) + tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"})) mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{ "userID": "u", "messageID": "MSG", @@ -204,7 +204,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin childJob := tj.job.newChildJob("f", 10) tj.job.end() - tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG")) + tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"})) mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{ "userID": "u", "messageID": "MSG", diff --git a/internal/services/syncservice/stage_download.go b/internal/services/syncservice/stage_download.go index ba9d030e..a0d6d981 100644 --- a/internal/services/syncservice/stage_download.go +++ b/internal/services/syncservice/stage_download.go @@ -112,8 +112,11 @@ func (d *DownloadStage) run(ctx context.Context) { } var attData [][]byte - if msg.NumAttachments > 0 { - attData = make([][]byte, msg.NumAttachments) + + numAttachments := len(msg.Attachments) + + if numAttachments > 0 { + attData = make([][]byte, numAttachments) } return proton.FullMessage{Message: msg, AttData: attData}, nil @@ -139,7 +142,8 @@ func (d *DownloadStage) run(ctx context.Context) { attachmentIDs := make([]string, 0, len(result)) for msgIdx, v := range result { - for attIdx := 0; attIdx < v.NumAttachments; attIdx++ { + numAttachments := len(v.Attachments) + for attIdx := 0; attIdx < numAttachments; attIdx++ { attachmentIndices = append(attachmentIndices, attachmentMeta{ msgIdx: msgIdx, attIdx: attIdx, diff --git a/internal/services/syncservice/stage_download_test.go b/internal/services/syncservice/stage_download_test.go index 775244dd..8225ddcb 100644 --- a/internal/services/syncservice/stage_download_test.go +++ b/internal/services/syncservice/stage_download_test.go @@ -333,8 +333,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) { tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{}) tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{ MessageMetadata: proton.MessageMetadata{ - ID: "msg", - NumAttachments: 1, + ID: "msg", }, Header: "", ParsedHeaders: nil, @@ -436,7 +435,7 @@ func buildDownloadStageAttachments(msg *proton.FullMessage, index int) { func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) { msg.Attachments = make([]proton.Attachment, count) msg.AttData = make([][]byte, count) - msg.NumAttachments = count + for i := 0; i < count; i++ { data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i) msg.Attachments[i] = proton.Attachment{ diff --git a/internal/services/syncservice/stage_metadata.go b/internal/services/syncservice/stage_metadata.go index a186e73e..b39fcb65 100644 --- a/internal/services/syncservice/stage_metadata.go +++ b/internal/services/syncservice/stage_metadata.go @@ -19,11 +19,12 @@ package syncservice import ( "context" + "errors" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/proton-bridge/v3/internal/network" - "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" ) @@ -37,80 +38,91 @@ type MetadataStage struct { output MetadataStageOutput input MetadataStageInput maxDownloadMem uint64 - jobs []*metadataIterator log *logrus.Entry + panicHandler async.PanicHandler } -func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage { - return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")} +func NewMetadataStage( + input MetadataStageInput, + output MetadataStageOutput, + maxDownloadMem uint64, + panicHandler async.PanicHandler, +) *MetadataStage { + return &MetadataStage{ + input: input, + output: output, + maxDownloadMem: maxDownloadMem, + log: logrus.WithField("sync-stage", "metadata"), + panicHandler: panicHandler, + } } const MetadataPageSize = 150 const MetadataMaxMessages = 250 -func (m MetadataStage) Run(group *async.Group) { +func (m *MetadataStage) Run(group *async.Group) { group.Once(func(ctx context.Context) { - m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{}) + logging.DoAnnotated( + ctx, + func(ctx context.Context) { + m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{}) + }, + logging.Labels{"sync-stage": "metadata"}, + ) }) } -func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) { +func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) { defer m.output.Close() + group := async.NewGroup(ctx, m.panicHandler) + defer group.CancelAndWait() + for { - if ctx.Err() != nil { - return - } - - // Check if new job has been submitted - job, ok, err := m.input.TryConsume(ctx) + job, err := m.input.Consume(ctx) if err != nil { - m.log.WithError(err).Error("Error trying to retrieve more work") + if !(errors.Is(err, context.Canceled) || errors.Is(err, ErrNoMoreInput)) { + m.log.WithError(err).Error("Error trying to retrieve more work") + } return } - if ok { - job.begin() - state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) - if err != nil { - job.onError(err) - continue - } - m.jobs = append(m.jobs, state) + + job.begin() + state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) + if err != nil { + job.onError(err) + continue } - // Iterate over all jobs and produce work. - for i := 0; i < len(m.jobs); { - job := m.jobs[i] + group.Once(func(ctx context.Context) { + for { + if state.stage.ctx.Err() != nil { + state.stage.end() + return + } - // If the job's context has been cancelled, remove from the list. - if job.stage.ctx.Err() != nil { - m.jobs = xslices.RemoveUnordered(m.jobs, i, 1) - job.stage.end() - continue + // Check for more work. + output, hasMore, err := state.Next(m.maxDownloadMem, metadataPageSize, maxMessages) + if err != nil { + state.stage.onError(err) + return + } + + // If there is actually more work, push it down the pipeline. + if len(output.ids) != 0 { + state.stage.metadataFetched += int64(len(output.ids)) + job.log.Debugf("Metada collected: %v/%v", state.stage.metadataFetched, state.stage.totalMessageCount) + + m.output.Produce(ctx, output) + } + + // If this job has no more work left, signal completion. + if !hasMore { + state.stage.end() + return + } } - - // Check for more work. - output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages) - if err != nil { - job.stage.onError(err) - m.jobs = xslices.RemoveUnordered(m.jobs, i, 1) - continue - } - - // If there is actually more work, push it down the pipeline. - if len(output.ids) != 0 { - m.output.Produce(ctx, output) - } - - // If this job has no more work left, signal completion. - if !hasMore { - m.jobs = xslices.RemoveUnordered(m.jobs, i, 1) - job.stage.end() - continue - } - - i++ - } + }) } } diff --git a/internal/services/syncservice/stage_metadata_test.go b/internal/services/syncservice/stage_metadata_test.go index efa3ebe9..f77adc71 100644 --- a/internal/services/syncservice/stage_metadata_test.go +++ b/internal/services/syncservice/stage_metadata_test.go @@ -48,7 +48,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) { output := NewChannelConsumerProducer[DownloadRequest]() ctx, cancel := context.WithCancel(context.Background()) - metadata := NewMetadataStage(input, output, TestMaxDownloadMem) + metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{}) numMessages := 50 messageSize := 100 @@ -86,7 +86,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) { output := NewChannelConsumerProducer[DownloadRequest]() ctx, cancel := context.WithCancel(context.Background()) - metadata := NewMetadataStage(input, output, TestMaxDownloadMem) + metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{}) go func() { metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{}) @@ -135,7 +135,7 @@ func TestMetadataStage_RunInterleaved(t *testing.T) { output := NewChannelConsumerProducer[DownloadRequest]() ctx, cancel := context.WithCancel(context.Background()) - metadata := NewMetadataStage(input, output, TestMaxDownloadMem) + metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{}) numMessages := 50 messageSize := 100 diff --git a/internal/services/syncservice/stage_output.go b/internal/services/syncservice/stage_output.go index 4278862c..d595b0d5 100644 --- a/internal/services/syncservice/stage_output.go +++ b/internal/services/syncservice/stage_output.go @@ -31,7 +31,6 @@ var ErrNoMoreInput = errors.New("no more input") type StageInputConsumer[T any] interface { Consume(ctx context.Context) (T, error) - TryConsume(ctx context.Context) (T, bool, error) } type ChannelConsumerProducer[T any] struct { @@ -66,20 +65,3 @@ func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) { return t, nil } } - -func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) { - select { - case <-ctx.Done(): - var t T - return t, false, ctx.Err() - case t, ok := <-c.ch: - if !ok { - return t, false, ErrNoMoreInput - } - - return t, true, nil - default: - var t T - return t, false, nil - } -}