diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index ef21ae1e..590ca02e 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -307,7 +307,7 @@ func newBridge( bridge.heartbeat.init(bridge, heartbeatManager) } - bridge.syncService.Run(bridge.tasks) + bridge.syncService.Run() return bridge, nil } @@ -451,6 +451,8 @@ func (bridge *Bridge) Close(ctx context.Context) { logrus.WithError(err).Error("Failed to close servers") } + bridge.syncService.Close() + // Stop all ongoing tasks. bridge.tasks.CancelAndWait() diff --git a/internal/services/syncservice/handler.go b/internal/services/syncservice/handler.go index 0ad776ce..f5e7d22f 100644 --- a/internal/services/syncservice/handler.go +++ b/internal/services/syncservice/handler.go @@ -210,7 +210,11 @@ func (t *Handler) run(ctx context.Context, stageContext.metadataFetched = syncStatus.NumSyncedMessages stageContext.totalMessageCount = syncStatus.TotalMessageCount - t.regulator.Sync(ctx, stageContext) + if err := t.regulator.Sync(ctx, stageContext); err != nil { + stageContext.onError(err) + _ = stageContext.waitAndClose(ctx) + return fmt.Errorf("failed to start sync job: %w", err) + } // Wait on reply if err := stageContext.waitAndClose(ctx); err != nil { diff --git a/internal/services/syncservice/interfaces.go b/internal/services/syncservice/interfaces.go index b8b28aa1..9a8c236c 100644 --- a/internal/services/syncservice/interfaces.go +++ b/internal/services/syncservice/interfaces.go @@ -64,7 +64,7 @@ func (s Status) InProgress() bool { // Regulator is an abstraction for the sync service, since it regulates the number of concurrent sync activities. type Regulator interface { - Sync(ctx context.Context, stage *Job) + Sync(ctx context.Context, stage *Job) error } type BuildResult struct { diff --git a/internal/services/syncservice/job.go b/internal/services/syncservice/job.go index 29daafe3..6615dd21 100644 --- a/internal/services/syncservice/job.go +++ b/internal/services/syncservice/job.go @@ -119,7 +119,6 @@ func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int // begin is expected to be called once the job enters the pipeline. func (j *Job) begin() { j.log.Info("Job started") - j.jw.onTaskCreated() } // end is expected to be called once the job has no further work left. @@ -133,7 +132,6 @@ func (j *Job) waitAndClose(ctx context.Context) error { defer j.close() select { case <-ctx.Done(): - j.jw.onContextCancelled() <-j.jw.doneCh return ctx.Err() case e := <-j.jw.doneCh: @@ -227,7 +225,6 @@ type JobWaiterMessage int const ( JobWaiterMessageCreated JobWaiterMessage = iota JobWaiterMessageFinished - JobWaiterMessageCtxErr ) type jobWaiterMessagePair struct { @@ -248,7 +245,7 @@ type jobWaiter struct { func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter { return &jobWaiter{ ch: make(chan jobWaiterMessagePair), - doneCh: make(chan error), + doneCh: make(chan error, 2), log: log, panicHandler: panicHandler, } @@ -273,15 +270,11 @@ func (j *jobWaiter) onTaskCreated() { j.sendMessage(JobWaiterMessageCreated, nil) } -func (j *jobWaiter) onContextCancelled() { - j.sendMessage(JobWaiterMessageCtxErr, nil) -} - func (j *jobWaiter) begin() { go func() { defer async.HandlePanic(j.panicHandler) - total := 0 + total := 1 var err error defer func() { @@ -296,8 +289,6 @@ func (j *jobWaiter) begin() { } switch m.m { - case JobWaiterMessageCtxErr: - // DO nothing case JobWaiterMessageCreated: total++ case JobWaiterMessageFinished: diff --git a/internal/services/syncservice/job_test.go b/internal/services/syncservice/job_test.go index 243930c6..d39b27e3 100644 --- a/internal/services/syncservice/job_test.go +++ b/internal/services/syncservice/job_test.go @@ -83,6 +83,7 @@ func TestJob_WaitsOnAllChildrenOnError(t *testing.T) { job1.onFinished(context.Background()) job2.onError(jobErr) + tj.job.end() }() close(startCh) @@ -115,6 +116,7 @@ func TestJob_MultipleChildrenReportError(t *testing.T) { } wg.Wait() + tj.job.end() close(startCh) err := tj.job.waitAndClose(context.Background()) require.Error(t, err) @@ -179,6 +181,7 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) { go func() { wg.Wait() + tj.job.end() cancel() }() @@ -201,6 +204,7 @@ func TestJob_CtxCancelBeforeBegin(t *testing.T) { go func() { wg.Wait() cancel() + tj.job.end() }() wg.Done() diff --git a/internal/services/syncservice/mocks_test.go b/internal/services/syncservice/mocks_test.go index aaac4474..53e12612 100644 --- a/internal/services/syncservice/mocks_test.go +++ b/internal/services/syncservice/mocks_test.go @@ -127,9 +127,11 @@ func (mr *MockBuildStageOutputMockRecorder) Close() *gomock.Call { } // Produce mocks base method. -func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) { +func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Produce", arg0, arg1) + ret := m.ctrl.Call(m, "Produce", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // Produce indicates an expected call of Produce. @@ -212,9 +214,11 @@ func (mr *MockDownloadStageOutputMockRecorder) Close() *gomock.Call { } // Produce mocks base method. -func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) { +func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Produce", arg0, arg1) + ret := m.ctrl.Call(m, "Produce", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // Produce indicates an expected call of Produce. @@ -297,9 +301,11 @@ func (mr *MockMetadataStageOutputMockRecorder) Close() *gomock.Call { } // Produce mocks base method. -func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) { +func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Produce", arg0, arg1) + ret := m.ctrl.Call(m, "Produce", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // Produce indicates an expected call of Produce. @@ -478,9 +484,11 @@ func (m *MockRegulator) EXPECT() *MockRegulatorMockRecorder { } // Sync mocks base method. -func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) { +func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Sync", arg0, arg1) + ret := m.ctrl.Call(m, "Sync", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // Sync indicates an expected call of Sync. diff --git a/internal/services/syncservice/service.go b/internal/services/syncservice/service.go index 654ff3e3..158c8d7a 100644 --- a/internal/services/syncservice/service.go +++ b/internal/services/syncservice/service.go @@ -33,7 +33,7 @@ type Service struct { applyStage *ApplyStage limits syncLimits metaCh *ChannelConsumerProducer[*Job] - panicHandler async.PanicHandler + group *async.Group } func NewService(reporter reporter.Reporter, @@ -53,26 +53,22 @@ func NewService(reporter reporter.Reporter, buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter), applyStage: NewApplyStage(applyCh), metaCh: metaCh, - panicHandler: panicHandler, + group: async.NewGroup(context.Background(), panicHandler), } } -func (s *Service) Run(group *async.Group) { - group.Once(func(ctx context.Context) { - syncGroup := async.NewGroup(ctx, s.panicHandler) - - s.metadataStage.Run(syncGroup) - s.downloadStage.Run(syncGroup) - s.buildStage.Run(syncGroup) - s.applyStage.Run(syncGroup) - - defer s.metaCh.Close() - defer syncGroup.CancelAndWait() - - <-ctx.Done() - }) +func (s *Service) Run() { + s.metadataStage.Run(s.group) + s.downloadStage.Run(s.group) + s.buildStage.Run(s.group) + s.applyStage.Run(s.group) } -func (s *Service) Sync(ctx context.Context, stage *Job) { - s.metaCh.Produce(ctx, stage) +func (s *Service) Sync(ctx context.Context, stage *Job) error { + return s.metaCh.Produce(ctx, stage) +} + +func (s *Service) Close() { + s.group.CancelAndWait() + s.metaCh.Close() } diff --git a/internal/services/syncservice/stage_apply_test.go b/internal/services/syncservice/stage_apply_test.go index 0e36e4ca..10afdf22 100644 --- a/internal/services/syncservice/stage_apply_test.go +++ b/internal/services/syncservice/stage_apply_test.go @@ -50,10 +50,10 @@ func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) { }() jobCancel() - input.Produce(ctx, ApplyRequest{ + require.NoError(t, input.Produce(ctx, ApplyRequest{ childJob: childJob, messages: nil, - }) + })) err := tj.job.waitAndClose(ctx) require.ErrorIs(t, err, context.Canceled) @@ -84,10 +84,10 @@ func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, ApplyRequest{ + require.NoError(t, input.Produce(ctx, ApplyRequest{ childJob: childJob, messages: nil, - }) + })) err := tj.job.waitAndClose(ctx) cancel() @@ -127,10 +127,10 @@ func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, ApplyRequest{ + require.NoError(t, input.Produce(ctx, ApplyRequest{ childJob: childJob, messages: buildResults, - }) + })) err := tj.job.waitAndClose(ctx) cancel() diff --git a/internal/services/syncservice/stage_build.go b/internal/services/syncservice/stage_build.go index 1010c73e..b2a6a6d1 100644 --- a/internal/services/syncservice/stage_build.go +++ b/internal/services/syncservice/stage_build.go @@ -21,6 +21,7 @@ import ( "bytes" "context" "errors" + "fmt" "runtime" "github.com/ProtonMail/gluon/async" @@ -182,10 +183,12 @@ func (b *BuildStage) run(ctx context.Context) { outJob.onStageCompleted(ctx) - b.output.Produce(ctx, ApplyRequest{ + if err := b.output.Produce(ctx, ApplyRequest{ childJob: outJob, messages: success, - }) + }); err != nil { + return fmt.Errorf("failed to produce output for next stage: %w", err) + } } return nil diff --git a/internal/services/syncservice/stage_build_test.go b/internal/services/syncservice/stage_build_test.go index e8db0263..ce06724c 100644 --- a/internal/services/syncservice/stage_build_test.go +++ b/internal/services/syncservice/stage_build_test.go @@ -111,7 +111,7 @@ func TestBuildStage_SuccessRemovesFailedMessage(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) + require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})) req, err := output.Consume(ctx) cancel() @@ -170,7 +170,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) + require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})) req, err := output.Consume(ctx) cancel() @@ -222,7 +222,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin stage.run(ctx) }() - input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) + require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})) req, err := output.Consume(ctx) cancel() @@ -267,7 +267,7 @@ func TestBuildStage_OtherErrorsFailJob(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) + require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})) err := tj.job.waitAndClose(ctx) require.Equal(t, expectedErr, err) @@ -311,10 +311,10 @@ func TestBuildStage_CancelledJobIsDiscarded(t *testing.T) { }() jobCancel() - input.Produce(ctx, BuildRequest{ + require.NoError(t, input.Produce(ctx, BuildRequest{ childJob: childJob, batch: []proton.FullMessage{msg}, - }) + })) go func() { cancel() }() diff --git a/internal/services/syncservice/stage_download.go b/internal/services/syncservice/stage_download.go index 5e0a0375..d26e5266 100644 --- a/internal/services/syncservice/stage_download.go +++ b/internal/services/syncservice/stage_download.go @@ -21,6 +21,7 @@ import ( "bytes" "context" "errors" + "fmt" "sync/atomic" "github.com/ProtonMail/gluon/async" @@ -183,10 +184,12 @@ func (d *DownloadStage) run(ctx context.Context) { // Step 5: Publish result. request.onStageCompleted(ctx) - d.output.Produce(ctx, BuildRequest{ + if err := d.output.Produce(ctx, BuildRequest{ batch: result, childJob: request.childJob, - }) + }); err != nil { + request.job.onError(fmt.Errorf("failed to produce output for next stage: %w", err)) + } } } diff --git a/internal/services/syncservice/stage_download_test.go b/internal/services/syncservice/stage_download_test.go index c44f4900..2cd08bc8 100644 --- a/internal/services/syncservice/stage_download_test.go +++ b/internal/services/syncservice/stage_download_test.go @@ -189,10 +189,10 @@ func TestDownloadStage_Run(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, DownloadRequest{ + require.NoError(t, input.Produce(ctx, DownloadRequest{ childJob: childJob, ids: msgIDs, - }) + })) out, err := output.Consume(ctx) require.NoError(t, err) @@ -232,10 +232,10 @@ func TestDownloadStage_RunWith422(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, DownloadRequest{ + require.NoError(t, input.Produce(ctx, DownloadRequest{ childJob: childJob, ids: msgIDs, - }) + })) out, err := output.Consume(ctx) require.NoError(t, err) @@ -271,10 +271,11 @@ func TestDownloadStage_CancelledJobIsDiscarded(t *testing.T) { }() jobCancel() - input.Produce(ctx, DownloadRequest{ + + require.NoError(t, input.Produce(ctx, DownloadRequest{ childJob: childJob, ids: nil, - }) + })) go func() { cancel() }() @@ -308,10 +309,10 @@ func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, DownloadRequest{ + require.NoError(t, input.Produce(ctx, DownloadRequest{ childJob: childJob, ids: []string{"foo"}, - }) + })) err := tj.job.waitAndClose(ctx) require.Equal(t, expectedErr, err) @@ -359,10 +360,10 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) { stage.run(ctx) }() - input.Produce(ctx, DownloadRequest{ + require.NoError(t, input.Produce(ctx, DownloadRequest{ childJob: childJob, ids: []string{"foo"}, - }) + })) err := tj.job.waitAndClose(ctx) require.Equal(t, expectedErr, err) diff --git a/internal/services/syncservice/stage_metadata.go b/internal/services/syncservice/stage_metadata.go index 46ac28f5..0525af19 100644 --- a/internal/services/syncservice/stage_metadata.go +++ b/internal/services/syncservice/stage_metadata.go @@ -20,6 +20,7 @@ package syncservice import ( "context" "errors" + "fmt" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/logging" @@ -87,10 +88,6 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag return } - if job.ctx.Err() != nil { - continue - } - job.begin() state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) if err != nil { @@ -119,7 +116,10 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag output.onStageCompleted(ctx) - m.output.Produce(ctx, output) + if err := m.output.Produce(ctx, output); err != nil { + job.onError(fmt.Errorf("failed to produce output for next stage: %w", err)) + return + } } // If this job has no more work left, signal completion. diff --git a/internal/services/syncservice/stage_metadata_test.go b/internal/services/syncservice/stage_metadata_test.go index 9bc303e5..d168a675 100644 --- a/internal/services/syncservice/stage_metadata_test.go +++ b/internal/services/syncservice/stage_metadata_test.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "io" + "sync" "testing" "github.com/ProtonMail/gluon/async" @@ -59,7 +60,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) { metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{}) }() - input.Produce(ctx, tj.job) + require.NoError(t, input.Produce(ctx, tj.job)) for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) { tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(len(chunk)))) @@ -93,7 +94,10 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) { metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{}) }() - input.Produce(ctx, tj.job) + { + err := input.Produce(ctx, tj.job) + require.NoError(t, err) + } // read one output then cancel request, err := output.Consume(ctx) @@ -102,8 +106,11 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) { // cancel job context jobCancel() + wg := sync.WaitGroup{} + wg.Add(1) // The next stages should check whether the job has been cancelled or not. Here we need to do it manually. go func() { + wg.Done() for { req, err := output.Consume(ctx) if err != nil { @@ -113,8 +120,9 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) { req.checkCancelled() } }() - + wg.Wait() err = tj.job.waitAndClose(ctx) + require.Error(t, err) require.ErrorIs(t, err, context.Canceled) cancel() } @@ -149,8 +157,8 @@ func TestMetadataStage_RunInterleaved(t *testing.T) { }() go func() { - input.Produce(ctx, tj1.job) - input.Produce(ctx, tj2.job) + require.NoError(t, input.Produce(ctx, tj1.job)) + require.NoError(t, input.Produce(ctx, tj2.job)) }() go func() { diff --git a/internal/services/syncservice/stage_output.go b/internal/services/syncservice/stage_output.go index d595b0d5..4e01bd11 100644 --- a/internal/services/syncservice/stage_output.go +++ b/internal/services/syncservice/stage_output.go @@ -23,7 +23,7 @@ import ( ) type StageOutputProducer[T any] interface { - Produce(ctx context.Context, value T) + Produce(ctx context.Context, value T) error Close() } @@ -41,10 +41,12 @@ func NewChannelConsumerProducer[T any]() *ChannelConsumerProducer[T] { return &ChannelConsumerProducer[T]{ch: make(chan T)} } -func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) { +func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) error { select { case <-ctx.Done(): + return ctx.Err() case c.ch <- value: + return nil } }