diff --git a/internal/services/syncservice/handler.go b/internal/services/syncservice/handler.go index 160112e0..0ad776ce 100644 --- a/internal/services/syncservice/handler.go +++ b/internal/services/syncservice/handler.go @@ -210,12 +210,10 @@ func (t *Handler) run(ctx context.Context, stageContext.metadataFetched = syncStatus.NumSyncedMessages stageContext.totalMessageCount = syncStatus.TotalMessageCount - defer stageContext.Close() - t.regulator.Sync(ctx, stageContext) // Wait on reply - if err := stageContext.wait(ctx); err != nil { + if err := stageContext.waitAndClose(ctx); err != nil { return fmt.Errorf("failed sync messages: %w", err) } diff --git a/internal/services/syncservice/job.go b/internal/services/syncservice/job.go index cad55c90..29daafe3 100644 --- a/internal/services/syncservice/job.go +++ b/internal/services/syncservice/job.go @@ -19,9 +19,6 @@ package syncservice import ( "context" - "errors" - "fmt" - "sync" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/go-proton-api" @@ -48,10 +45,8 @@ type Job struct { updateApplier UpdateApplier syncReporter Reporter - log *logrus.Entry - errorCh *async.QueuedChannel[error] - wg sync.WaitGroup - once sync.Once + log *logrus.Entry + jw *jobWaiter panicHandler async.PanicHandler downloadCache *DownloadCache @@ -74,7 +69,7 @@ func NewJob(ctx context.Context, ) *Job { ctx, cancel := context.WithCancel(ctx) - return &Job{ + j := &Job{ ctx: ctx, client: client, userID: userID, @@ -85,26 +80,23 @@ func NewJob(ctx context.Context, messageBuilder: messageBuilder, updateApplier: updateApplier, syncReporter: syncReporter, - errorCh: async.NewQueuedChannel[error](4, 8, panicHandler, fmt.Sprintf("sync-job-error-%v", userID)), panicHandler: panicHandler, downloadCache: cache, + jw: newJobWaiter(log.WithField("sync-job", "waiter"), panicHandler), } + + j.jw.begin() + + return j } -func (j *Job) Close() { - j.errorCh.CloseAndDiscardQueued() - j.wg.Wait() +func (j *Job) close() { + j.jw.close() } func (j *Job) onError(err error) { - defer j.wg.Done() + defer j.jw.onTaskFinished(err) - // context cancelled is caught & handled in a different location. - if errors.Is(err, context.Canceled) { - return - } - - j.errorCh.Enqueue(err) j.cancel() } @@ -119,55 +111,42 @@ func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int return } - // j.onError() also calls j.wg.Done(). - j.wg.Done() + // j.onError() also calls j.jw.onTaskFinished(). + defer j.jw.onTaskFinished(nil) j.syncReporter.OnProgress(ctx, count) } // begin is expected to be called once the job enters the pipeline. func (j *Job) begin() { j.log.Info("Job started") - j.wg.Add(1) - j.startChildWaiter() + j.jw.onTaskCreated() } // end is expected to be called once the job has no further work left. func (j *Job) end() { j.log.Info("Job finished") - j.wg.Done() + j.jw.onTaskFinished(nil) } -// wait waits until the job has finished, the context got cancelled or an error occurred. -func (j *Job) wait(ctx context.Context) error { - defer j.wg.Wait() - +// waitAndClose waits until the job has finished, the context got cancelled or an error occurred. +func (j *Job) waitAndClose(ctx context.Context) error { + defer j.close() select { case <-ctx.Done(): - j.cancel() + j.jw.onContextCancelled() + <-j.jw.doneCh return ctx.Err() - case err := <-j.errorCh.GetChannel(): - return err + case e := <-j.jw.doneCh: + return e } } func (j *Job) newChildJob(messageID string, messageCount int64) childJob { j.log.Infof("Creating new child job") - j.wg.Add(1) + j.jw.onTaskCreated() return childJob{job: j, lastMessageID: messageID, messageCount: messageCount} } -func (j *Job) startChildWaiter() { - j.once.Do(func() { - go func() { - defer async.HandlePanic(j.panicHandler) - - j.wg.Wait() - j.log.Info("All child jobs succeeded") - j.errorCh.Enqueue(j.ctx.Err()) - }() - }) -} - // childJob represents a batch of work that goes down the pipeline. It keeps track of the message ID that is in the // batch and the number of messages in the batch. type childJob struct { @@ -232,7 +211,7 @@ func (s *childJob) checkCancelled() bool { err := s.job.ctx.Err() if err != nil { s.job.log.Infof("Child job exit due to context cancelled") - s.job.wg.Done() + s.job.jw.onTaskFinished(err) return true } @@ -242,3 +221,102 @@ func (s *childJob) checkCancelled() bool { func (s *childJob) getContext() context.Context { return s.job.ctx } + +type JobWaiterMessage int + +const ( + JobWaiterMessageCreated JobWaiterMessage = iota + JobWaiterMessageFinished + JobWaiterMessageCtxErr +) + +type jobWaiterMessagePair struct { + m JobWaiterMessage + err error +} + +// jobWaiter is meant to be used to track ongoing sync batches. Once all the child jobs +// have completed, the first recorded error (if any) will be written to doneCh and then this +// channel will be closed. +type jobWaiter struct { + ch chan jobWaiterMessagePair + doneCh chan error + log *logrus.Entry + panicHandler async.PanicHandler +} + +func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter { + return &jobWaiter{ + ch: make(chan jobWaiterMessagePair), + doneCh: make(chan error), + log: log, + panicHandler: panicHandler, + } +} + +func (j *jobWaiter) close() { + close(j.ch) +} + +func (j *jobWaiter) sendMessage(m JobWaiterMessage, err error) { + j.ch <- jobWaiterMessagePair{ + m: m, + err: err, + } +} + +func (j *jobWaiter) onTaskFinished(err error) { + j.sendMessage(JobWaiterMessageFinished, err) +} + +func (j *jobWaiter) onTaskCreated() { + j.sendMessage(JobWaiterMessageCreated, nil) +} + +func (j *jobWaiter) onContextCancelled() { + j.sendMessage(JobWaiterMessageCtxErr, nil) +} + +func (j *jobWaiter) begin() { + go func() { + defer async.HandlePanic(j.panicHandler) + + total := 0 + var err error + + defer func() { + j.doneCh <- err + close(j.doneCh) + }() + + for { + m, ok := <-j.ch + if !ok { + return + } + + switch m.m { + case JobWaiterMessageCtxErr: + // DO nothing + case JobWaiterMessageCreated: + total++ + case JobWaiterMessageFinished: + total-- + if m.err != nil && err == nil { + err = m.err + } + default: + j.log.Errorf("Unknown message type: %v", m.m) + continue + } + + if total <= 0 { + if total < 0 { + logrus.Errorf("Child count less than 0, shouldn't happen...") + } + j.log.Info("All child jobs completed") + return + } + } + }() +} diff --git a/internal/services/syncservice/job_test.go b/internal/services/syncservice/job_test.go index d195297f..243930c6 100644 --- a/internal/services/syncservice/job_test.go +++ b/internal/services/syncservice/job_test.go @@ -20,6 +20,7 @@ package syncservice import ( "context" "errors" + "sync" "testing" "github.com/ProtonMail/gluon/async" @@ -56,8 +57,7 @@ func TestJob_WaitsOnChildren(t *testing.T) { tj.job.end() }() - require.NoError(t, tj.job.wait(context.Background())) - tj.job.Close() + require.NoError(t, tj.job.waitAndClose(context.Background())) } func TestJob_WaitsOnAllChildrenOnError(t *testing.T) { @@ -73,18 +73,22 @@ func TestJob_WaitsOnAllChildrenOnError(t *testing.T) { jobErr := errors.New("failed") + startCh := make(chan struct{}) + go func() { job1 := tj.job.newChildJob("1", 0) job2 := tj.job.newChildJob("2", 1) + <-startCh + job1.onFinished(context.Background()) job2.onError(jobErr) }() - err := tj.job.wait(context.Background()) + close(startCh) + err := tj.job.waitAndClose(context.Background()) require.Error(t, err) require.ErrorIs(t, err, jobErr) - tj.job.Close() } func TestJob_MultipleChildrenReportError(t *testing.T) { @@ -99,20 +103,22 @@ func TestJob_MultipleChildrenReportError(t *testing.T) { startCh := make(chan struct{}) + wg := sync.WaitGroup{} for i := 0; i < 10; i++ { + wg.Add(1) go func() { job := tj.job.newChildJob("1", 0) + wg.Done() <-startCh - job.onError(jobErr) }() } + wg.Wait() close(startCh) - err := tj.job.wait(context.Background()) + err := tj.job.waitAndClose(context.Background()) require.Error(t, err) require.ErrorIs(t, err, jobErr) - tj.job.Close() } func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) { @@ -127,8 +133,12 @@ func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) { failJob := tj.job.newChildJob("0", 1) + tj.job.begin() + wg := sync.WaitGroup{} for i := 0; i < 10; i++ { + wg.Add(1) go func() { + defer wg.Done() job := tj.job.newChildJob("1", 0) <-job.getContext().Done() require.ErrorIs(t, job.getContext().Err(), context.Canceled) @@ -137,12 +147,13 @@ func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) { } go func() { failJob.onError(jobErr) + wg.Wait() + tj.job.end() }() - err := tj.job.wait(context.Background()) + err := tj.job.waitAndClose(context.Background()) require.Error(t, err) require.ErrorIs(t, err, jobErr) - tj.job.Close() } func TestJob_CtxCancelCancelsAllChildren(t *testing.T) { @@ -154,9 +165,12 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) tj := newTestJob(ctx, mockCtrl, "u", getTestLabels()) + wg := sync.WaitGroup{} for i := 0; i < 10; i++ { + wg.Add(1) go func() { job := tj.job.newChildJob("1", 0) + wg.Done() <-job.getContext().Done() require.ErrorIs(t, job.getContext().Err(), context.Canceled) require.True(t, job.checkCancelled()) @@ -164,13 +178,35 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) { } go func() { + wg.Wait() cancel() }() - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func TestJob_CtxCancelBeforeBegin(t *testing.T) { + options := setupGoLeak() + defer goleak.VerifyNone(t, options) + + mockCtrl := gomock.NewController(t) + + ctx, cancel := context.WithCancel(context.Background()) + tj := newTestJob(ctx, mockCtrl, "u", getTestLabels()) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Wait() + cancel() + }() + + wg.Done() + err := tj.job.waitAndClose(ctx) require.Error(t, err) require.ErrorIs(t, err, context.Canceled) - tj.job.Close() } func TestJob_WithoutChildJobsCanBeTerminated(t *testing.T) { @@ -186,9 +222,8 @@ func TestJob_WithoutChildJobsCanBeTerminated(t *testing.T) { tj.job.begin() tj.job.end() }() - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(context.Background()) require.NoError(t, err) - tj.job.Close() } type tjob struct { diff --git a/internal/services/syncservice/stage_apply_test.go b/internal/services/syncservice/stage_apply_test.go index 16464164..0e36e4ca 100644 --- a/internal/services/syncservice/stage_apply_test.go +++ b/internal/services/syncservice/stage_apply_test.go @@ -55,7 +55,7 @@ func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) { messages: nil, }) - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(ctx) require.ErrorIs(t, err, context.Canceled) cancel() } @@ -89,7 +89,7 @@ func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) { messages: nil, }) - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(ctx) cancel() require.NoError(t, err) } @@ -132,7 +132,7 @@ func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) { messages: buildResults, }) - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(ctx) cancel() require.ErrorIs(t, err, applyErr) } diff --git a/internal/services/syncservice/stage_build_test.go b/internal/services/syncservice/stage_build_test.go index 6b94c297..e8db0263 100644 --- a/internal/services/syncservice/stage_build_test.go +++ b/internal/services/syncservice/stage_build_test.go @@ -269,7 +269,7 @@ func TestBuildStage_OtherErrorsFailJob(t *testing.T) { input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(ctx) require.Equal(t, expectedErr, err) cancel() diff --git a/internal/services/syncservice/stage_download_test.go b/internal/services/syncservice/stage_download_test.go index 29f91638..c44f4900 100644 --- a/internal/services/syncservice/stage_download_test.go +++ b/internal/services/syncservice/stage_download_test.go @@ -313,7 +313,7 @@ func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) { ids: []string{"foo"}, }) - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(ctx) require.Equal(t, expectedErr, err) cancel() @@ -364,7 +364,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) { ids: []string{"foo"}, }) - err := tj.job.wait(ctx) + err := tj.job.waitAndClose(ctx) require.Equal(t, expectedErr, err) cancel() diff --git a/internal/services/syncservice/stage_metadata.go b/internal/services/syncservice/stage_metadata.go index 30a5ce3e..46ac28f5 100644 --- a/internal/services/syncservice/stage_metadata.go +++ b/internal/services/syncservice/stage_metadata.go @@ -87,6 +87,10 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag return } + if job.ctx.Err() != nil { + continue + } + job.begin() state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) if err != nil { diff --git a/internal/services/syncservice/stage_metadata_test.go b/internal/services/syncservice/stage_metadata_test.go index 65a87035..9bc303e5 100644 --- a/internal/services/syncservice/stage_metadata_test.go +++ b/internal/services/syncservice/stage_metadata_test.go @@ -114,7 +114,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) { } }() - err = tj.job.wait(context.Background()) + err = tj.job.waitAndClose(ctx) require.ErrorIs(t, err, context.Canceled) cancel() } @@ -165,8 +165,8 @@ func TestMetadataStage_RunInterleaved(t *testing.T) { } }() - require.NoError(t, tj1.job.wait(ctx)) - require.NoError(t, tj2.job.wait(ctx)) + require.NoError(t, tj1.job.waitAndClose(ctx)) + require.NoError(t, tj2.job.waitAndClose(ctx)) cancel() }