diff --git a/internal/services/syncservice/handler.go b/internal/services/syncservice/handler.go index dbccf55b..89637eb6 100644 --- a/internal/services/syncservice/handler.go +++ b/internal/services/syncservice/handler.go @@ -29,6 +29,7 @@ import ( ) const DefaultRetryCoolDown = 20 * time.Second +const NumSyncStages = 4 type LabelMap = map[string]proton.Label @@ -187,7 +188,7 @@ func (t *Handler) run(ctx context.Context, syncStatus.TotalMessageCount = totalMessageCount } - syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount) + syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount*NumSyncStages) if !syncStatus.HasMessages { t.log.Info("Syncing messages") diff --git a/internal/services/syncservice/handler_test.go b/internal/services/syncservice/handler_test.go index ff71e377..2cf71c9a 100644 --- a/internal/services/syncservice/handler_test.go +++ b/internal/services/syncservice/handler_test.go @@ -57,7 +57,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) { }) call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil) call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) - tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal*NumSyncStages)) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) { @@ -126,7 +126,7 @@ func TestTask_StateHasLabels(t *testing.T) { call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) - tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal*NumSyncStages)) } { @@ -172,7 +172,7 @@ func TestTask_StateHasLabelsAndMessageCount(t *testing.T) { }) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) - tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal*NumSyncStages)) } tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta)) @@ -222,7 +222,7 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) { tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta) - tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal)) + tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal*NumSyncStages)) { call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) { diff --git a/internal/services/syncservice/job.go b/internal/services/syncservice/job.go index 01b2dc1a..8978ae6b 100644 --- a/internal/services/syncservice/job.go +++ b/internal/services/syncservice/job.go @@ -108,6 +108,10 @@ func (j *Job) onError(err error) { j.cancel() } +func (j *Job) onStageCompleted(ctx context.Context, count int64) { + j.syncReporter.OnProgress(ctx, count) +} + func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int64) { defer j.wg.Done() @@ -219,6 +223,10 @@ func (s *childJob) onFinished(ctx context.Context) { s.job.downloadCache.DeleteAttachments(s.cachedAttachmentIDs...) } +func (s *childJob) onStageCompleted(ctx context.Context) { + s.job.onStageCompleted(ctx, s.messageCount) +} + func (s *childJob) checkCancelled() bool { err := s.job.ctx.Err() if err != nil { diff --git a/internal/services/syncservice/stage_build.go b/internal/services/syncservice/stage_build.go index a9957eff..1010c73e 100644 --- a/internal/services/syncservice/stage_build.go +++ b/internal/services/syncservice/stage_build.go @@ -178,8 +178,12 @@ func (b *BuildStage) run(ctx context.Context) { } } + outJob := chunkedJobs[idx] + + outJob.onStageCompleted(ctx) + b.output.Produce(ctx, ApplyRequest{ - childJob: chunkedJobs[idx], + childJob: outJob, messages: success, }) } diff --git a/internal/services/syncservice/stage_build_test.go b/internal/services/syncservice/stage_build_test.go index 4c803173..6b94c297 100644 --- a/internal/services/syncservice/stage_build_test.go +++ b/internal/services/syncservice/stage_build_test.go @@ -90,6 +90,8 @@ func TestBuildStage_SuccessRemovesFailedMessage(t *testing.T) { return nil }) + tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(10))) + tj.job.begin() childJob := tj.job.newChildJob("f", 10) tj.job.end() @@ -160,6 +162,8 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) { "error": buildError, })).Return(nil) + tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(10))) + stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter) go func() { @@ -210,6 +214,8 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin "messageID": "MSG", })).Return(nil) + tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(10))) + stage := NewBuildStage(input, output, 1024, &async.NoopPanicHandler{}, mockReporter) go func() { diff --git a/internal/services/syncservice/stage_download.go b/internal/services/syncservice/stage_download.go index 2e1b1413..5e0a0375 100644 --- a/internal/services/syncservice/stage_download.go +++ b/internal/services/syncservice/stage_download.go @@ -181,6 +181,8 @@ func (d *DownloadStage) run(ctx context.Context) { request.cachedMessageIDs = request.ids // Step 5: Publish result. + request.onStageCompleted(ctx) + d.output.Produce(ctx, BuildRequest{ batch: result, childJob: request.childJob, diff --git a/internal/services/syncservice/stage_download_test.go b/internal/services/syncservice/stage_download_test.go index 237886c9..ed640444 100644 --- a/internal/services/syncservice/stage_download_test.go +++ b/internal/services/syncservice/stage_download_test.go @@ -175,6 +175,8 @@ func TestDownloadStage_Run(t *testing.T) { tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()) tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10))).Return(nil) + tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(10))) + tj.job.begin() defer tj.job.end() childJob := tj.job.newChildJob("f", 10) @@ -216,6 +218,8 @@ func TestDownloadStage_RunWith422(t *testing.T) { tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Any()) tj.state.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq("f"), gomock.Eq(int64(10))).Return(nil) + tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(10))) + tj.job.begin() defer tj.job.end() childJob := tj.job.newChildJob("f", 10) diff --git a/internal/services/syncservice/stage_metadata.go b/internal/services/syncservice/stage_metadata.go index b7ce6935..30a5ce3e 100644 --- a/internal/services/syncservice/stage_metadata.go +++ b/internal/services/syncservice/stage_metadata.go @@ -113,6 +113,8 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag state.stage.metadataFetched += int64(len(output.ids)) job.log.Debugf("Metada collected: %v/%v", state.stage.metadataFetched, state.stage.totalMessageCount) + output.onStageCompleted(ctx) + m.output.Produce(ctx, output) } diff --git a/internal/services/syncservice/stage_metadata_test.go b/internal/services/syncservice/stage_metadata_test.go index f77adc71..65a87035 100644 --- a/internal/services/syncservice/stage_metadata_test.go +++ b/internal/services/syncservice/stage_metadata_test.go @@ -62,6 +62,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) { input.Produce(ctx, tj.job) for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) { + tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(len(chunk)))) req, err := output.Consume(ctx) require.NoError(t, err) require.Equal(t, req.ids, xslices.Map(chunk, func(m proton.MessageMetadata) string {