fix(GODT-3124): Handling of sync child jobs

Improve the handling of sync child jobs to ensure it behaves correctly
in all scenarios.

The sync service now uses an isolated context to avoid all the pipeline
stages shutting down before all the sync tasks have had the opportunity
to run their course.

The job waiter now starts immediately with a counter of 1 and waits
until all the child jobs and the parent job finish before considering
the work to be finished.

Finally, we also handle the case where a sync job can't be queued
because the calling context has been cancelled.
This commit is contained in:
Leander Beernaert
2023-11-29 13:52:12 +01:00
parent 9449177553
commit 7a1c7e8743
15 changed files with 100 additions and 78 deletions

View File

@ -307,7 +307,7 @@ func newBridge(
bridge.heartbeat.init(bridge, heartbeatManager) bridge.heartbeat.init(bridge, heartbeatManager)
} }
bridge.syncService.Run(bridge.tasks) bridge.syncService.Run()
return bridge, nil return bridge, nil
} }
@ -451,6 +451,8 @@ func (bridge *Bridge) Close(ctx context.Context) {
logrus.WithError(err).Error("Failed to close servers") logrus.WithError(err).Error("Failed to close servers")
} }
bridge.syncService.Close()
// Stop all ongoing tasks. // Stop all ongoing tasks.
bridge.tasks.CancelAndWait() bridge.tasks.CancelAndWait()

View File

@ -210,7 +210,11 @@ func (t *Handler) run(ctx context.Context,
stageContext.metadataFetched = syncStatus.NumSyncedMessages stageContext.metadataFetched = syncStatus.NumSyncedMessages
stageContext.totalMessageCount = syncStatus.TotalMessageCount stageContext.totalMessageCount = syncStatus.TotalMessageCount
t.regulator.Sync(ctx, stageContext) if err := t.regulator.Sync(ctx, stageContext); err != nil {
stageContext.onError(err)
_ = stageContext.waitAndClose(ctx)
return fmt.Errorf("failed to start sync job: %w", err)
}
// Wait on reply // Wait on reply
if err := stageContext.waitAndClose(ctx); err != nil { if err := stageContext.waitAndClose(ctx); err != nil {

View File

@ -64,7 +64,7 @@ func (s Status) InProgress() bool {
// Regulator is an abstraction for the sync service, since it regulates the number of concurrent sync activities. // Regulator is an abstraction for the sync service, since it regulates the number of concurrent sync activities.
type Regulator interface { type Regulator interface {
Sync(ctx context.Context, stage *Job) Sync(ctx context.Context, stage *Job) error
} }
type BuildResult struct { type BuildResult struct {

View File

@ -119,7 +119,6 @@ func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int
// begin is expected to be called once the job enters the pipeline. // begin is expected to be called once the job enters the pipeline.
func (j *Job) begin() { func (j *Job) begin() {
j.log.Info("Job started") j.log.Info("Job started")
j.jw.onTaskCreated()
} }
// end is expected to be called once the job has no further work left. // end is expected to be called once the job has no further work left.
@ -133,7 +132,6 @@ func (j *Job) waitAndClose(ctx context.Context) error {
defer j.close() defer j.close()
select { select {
case <-ctx.Done(): case <-ctx.Done():
j.jw.onContextCancelled()
<-j.jw.doneCh <-j.jw.doneCh
return ctx.Err() return ctx.Err()
case e := <-j.jw.doneCh: case e := <-j.jw.doneCh:
@ -227,7 +225,6 @@ type JobWaiterMessage int
const ( const (
JobWaiterMessageCreated JobWaiterMessage = iota JobWaiterMessageCreated JobWaiterMessage = iota
JobWaiterMessageFinished JobWaiterMessageFinished
JobWaiterMessageCtxErr
) )
type jobWaiterMessagePair struct { type jobWaiterMessagePair struct {
@ -248,7 +245,7 @@ type jobWaiter struct {
func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter { func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter {
return &jobWaiter{ return &jobWaiter{
ch: make(chan jobWaiterMessagePair), ch: make(chan jobWaiterMessagePair),
doneCh: make(chan error), doneCh: make(chan error, 2),
log: log, log: log,
panicHandler: panicHandler, panicHandler: panicHandler,
} }
@ -273,15 +270,11 @@ func (j *jobWaiter) onTaskCreated() {
j.sendMessage(JobWaiterMessageCreated, nil) j.sendMessage(JobWaiterMessageCreated, nil)
} }
func (j *jobWaiter) onContextCancelled() {
j.sendMessage(JobWaiterMessageCtxErr, nil)
}
func (j *jobWaiter) begin() { func (j *jobWaiter) begin() {
go func() { go func() {
defer async.HandlePanic(j.panicHandler) defer async.HandlePanic(j.panicHandler)
total := 0 total := 1
var err error var err error
defer func() { defer func() {
@ -296,8 +289,6 @@ func (j *jobWaiter) begin() {
} }
switch m.m { switch m.m {
case JobWaiterMessageCtxErr:
// DO nothing
case JobWaiterMessageCreated: case JobWaiterMessageCreated:
total++ total++
case JobWaiterMessageFinished: case JobWaiterMessageFinished:

View File

@ -83,6 +83,7 @@ func TestJob_WaitsOnAllChildrenOnError(t *testing.T) {
job1.onFinished(context.Background()) job1.onFinished(context.Background())
job2.onError(jobErr) job2.onError(jobErr)
tj.job.end()
}() }()
close(startCh) close(startCh)
@ -115,6 +116,7 @@ func TestJob_MultipleChildrenReportError(t *testing.T) {
} }
wg.Wait() wg.Wait()
tj.job.end()
close(startCh) close(startCh)
err := tj.job.waitAndClose(context.Background()) err := tj.job.waitAndClose(context.Background())
require.Error(t, err) require.Error(t, err)
@ -179,6 +181,7 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
go func() { go func() {
wg.Wait() wg.Wait()
tj.job.end()
cancel() cancel()
}() }()
@ -201,6 +204,7 @@ func TestJob_CtxCancelBeforeBegin(t *testing.T) {
go func() { go func() {
wg.Wait() wg.Wait()
cancel() cancel()
tj.job.end()
}() }()
wg.Done() wg.Done()

View File

@ -127,9 +127,11 @@ func (mr *MockBuildStageOutputMockRecorder) Close() *gomock.Call {
} }
// Produce mocks base method. // Produce mocks base method.
func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) { func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1) ret := m.ctrl.Call(m, "Produce", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
} }
// Produce indicates an expected call of Produce. // Produce indicates an expected call of Produce.
@ -212,9 +214,11 @@ func (mr *MockDownloadStageOutputMockRecorder) Close() *gomock.Call {
} }
// Produce mocks base method. // Produce mocks base method.
func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) { func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1) ret := m.ctrl.Call(m, "Produce", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
} }
// Produce indicates an expected call of Produce. // Produce indicates an expected call of Produce.
@ -297,9 +301,11 @@ func (mr *MockMetadataStageOutputMockRecorder) Close() *gomock.Call {
} }
// Produce mocks base method. // Produce mocks base method.
func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) { func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1) ret := m.ctrl.Call(m, "Produce", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
} }
// Produce indicates an expected call of Produce. // Produce indicates an expected call of Produce.
@ -478,9 +484,11 @@ func (m *MockRegulator) EXPECT() *MockRegulatorMockRecorder {
} }
// Sync mocks base method. // Sync mocks base method.
func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) { func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Sync", arg0, arg1) ret := m.ctrl.Call(m, "Sync", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
} }
// Sync indicates an expected call of Sync. // Sync indicates an expected call of Sync.

View File

@ -33,7 +33,7 @@ type Service struct {
applyStage *ApplyStage applyStage *ApplyStage
limits syncLimits limits syncLimits
metaCh *ChannelConsumerProducer[*Job] metaCh *ChannelConsumerProducer[*Job]
panicHandler async.PanicHandler group *async.Group
} }
func NewService(reporter reporter.Reporter, func NewService(reporter reporter.Reporter,
@ -53,26 +53,22 @@ func NewService(reporter reporter.Reporter,
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter), buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
applyStage: NewApplyStage(applyCh), applyStage: NewApplyStage(applyCh),
metaCh: metaCh, metaCh: metaCh,
panicHandler: panicHandler, group: async.NewGroup(context.Background(), panicHandler),
} }
} }
func (s *Service) Run(group *async.Group) { func (s *Service) Run() {
group.Once(func(ctx context.Context) { s.metadataStage.Run(s.group)
syncGroup := async.NewGroup(ctx, s.panicHandler) s.downloadStage.Run(s.group)
s.buildStage.Run(s.group)
s.metadataStage.Run(syncGroup) s.applyStage.Run(s.group)
s.downloadStage.Run(syncGroup)
s.buildStage.Run(syncGroup)
s.applyStage.Run(syncGroup)
defer s.metaCh.Close()
defer syncGroup.CancelAndWait()
<-ctx.Done()
})
} }
func (s *Service) Sync(ctx context.Context, stage *Job) { func (s *Service) Sync(ctx context.Context, stage *Job) error {
s.metaCh.Produce(ctx, stage) return s.metaCh.Produce(ctx, stage)
}
func (s *Service) Close() {
s.group.CancelAndWait()
s.metaCh.Close()
} }

View File

@ -50,10 +50,10 @@ func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) {
}() }()
jobCancel() jobCancel()
input.Produce(ctx, ApplyRequest{ require.NoError(t, input.Produce(ctx, ApplyRequest{
childJob: childJob, childJob: childJob,
messages: nil, messages: nil,
}) }))
err := tj.job.waitAndClose(ctx) err := tj.job.waitAndClose(ctx)
require.ErrorIs(t, err, context.Canceled) require.ErrorIs(t, err, context.Canceled)
@ -84,10 +84,10 @@ func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, ApplyRequest{ require.NoError(t, input.Produce(ctx, ApplyRequest{
childJob: childJob, childJob: childJob,
messages: nil, messages: nil,
}) }))
err := tj.job.waitAndClose(ctx) err := tj.job.waitAndClose(ctx)
cancel() cancel()
@ -127,10 +127,10 @@ func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, ApplyRequest{ require.NoError(t, input.Produce(ctx, ApplyRequest{
childJob: childJob, childJob: childJob,
messages: buildResults, messages: buildResults,
}) }))
err := tj.job.waitAndClose(ctx) err := tj.job.waitAndClose(ctx)
cancel() cancel()

View File

@ -21,6 +21,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"runtime" "runtime"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
@ -182,10 +183,12 @@ func (b *BuildStage) run(ctx context.Context) {
outJob.onStageCompleted(ctx) outJob.onStageCompleted(ctx)
b.output.Produce(ctx, ApplyRequest{ if err := b.output.Produce(ctx, ApplyRequest{
childJob: outJob, childJob: outJob,
messages: success, messages: success,
}) }); err != nil {
return fmt.Errorf("failed to produce output for next stage: %w", err)
}
} }
return nil return nil

View File

@ -111,7 +111,7 @@ func TestBuildStage_SuccessRemovesFailedMessage(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
req, err := output.Consume(ctx) req, err := output.Consume(ctx)
cancel() cancel()
@ -170,7 +170,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
req, err := output.Consume(ctx) req, err := output.Consume(ctx)
cancel() cancel()
@ -222,7 +222,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
req, err := output.Consume(ctx) req, err := output.Consume(ctx)
cancel() cancel()
@ -267,7 +267,7 @@ func TestBuildStage_OtherErrorsFailJob(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}) require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
err := tj.job.waitAndClose(ctx) err := tj.job.waitAndClose(ctx)
require.Equal(t, expectedErr, err) require.Equal(t, expectedErr, err)
@ -311,10 +311,10 @@ func TestBuildStage_CancelledJobIsDiscarded(t *testing.T) {
}() }()
jobCancel() jobCancel()
input.Produce(ctx, BuildRequest{ require.NoError(t, input.Produce(ctx, BuildRequest{
childJob: childJob, childJob: childJob,
batch: []proton.FullMessage{msg}, batch: []proton.FullMessage{msg},
}) }))
go func() { cancel() }() go func() { cancel() }()

View File

@ -21,6 +21,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"sync/atomic" "sync/atomic"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
@ -183,10 +184,12 @@ func (d *DownloadStage) run(ctx context.Context) {
// Step 5: Publish result. // Step 5: Publish result.
request.onStageCompleted(ctx) request.onStageCompleted(ctx)
d.output.Produce(ctx, BuildRequest{ if err := d.output.Produce(ctx, BuildRequest{
batch: result, batch: result,
childJob: request.childJob, childJob: request.childJob,
}) }); err != nil {
request.job.onError(fmt.Errorf("failed to produce output for next stage: %w", err))
}
} }
} }

View File

@ -189,10 +189,10 @@ func TestDownloadStage_Run(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, DownloadRequest{ require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob, childJob: childJob,
ids: msgIDs, ids: msgIDs,
}) }))
out, err := output.Consume(ctx) out, err := output.Consume(ctx)
require.NoError(t, err) require.NoError(t, err)
@ -232,10 +232,10 @@ func TestDownloadStage_RunWith422(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, DownloadRequest{ require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob, childJob: childJob,
ids: msgIDs, ids: msgIDs,
}) }))
out, err := output.Consume(ctx) out, err := output.Consume(ctx)
require.NoError(t, err) require.NoError(t, err)
@ -271,10 +271,11 @@ func TestDownloadStage_CancelledJobIsDiscarded(t *testing.T) {
}() }()
jobCancel() jobCancel()
input.Produce(ctx, DownloadRequest{
require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob, childJob: childJob,
ids: nil, ids: nil,
}) }))
go func() { cancel() }() go func() { cancel() }()
@ -308,10 +309,10 @@ func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, DownloadRequest{ require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob, childJob: childJob,
ids: []string{"foo"}, ids: []string{"foo"},
}) }))
err := tj.job.waitAndClose(ctx) err := tj.job.waitAndClose(ctx)
require.Equal(t, expectedErr, err) require.Equal(t, expectedErr, err)
@ -359,10 +360,10 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
stage.run(ctx) stage.run(ctx)
}() }()
input.Produce(ctx, DownloadRequest{ require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob, childJob: childJob,
ids: []string{"foo"}, ids: []string{"foo"},
}) }))
err := tj.job.waitAndClose(ctx) err := tj.job.waitAndClose(ctx)
require.Equal(t, expectedErr, err) require.Equal(t, expectedErr, err)

View File

@ -20,6 +20,7 @@ package syncservice
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/gluon/logging"
@ -87,10 +88,6 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag
return return
} }
if job.ctx.Err() != nil {
continue
}
job.begin() job.begin()
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
if err != nil { if err != nil {
@ -119,7 +116,10 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag
output.onStageCompleted(ctx) output.onStageCompleted(ctx)
m.output.Produce(ctx, output) if err := m.output.Produce(ctx, output); err != nil {
job.onError(fmt.Errorf("failed to produce output for next stage: %w", err))
return
}
} }
// If this job has no more work left, signal completion. // If this job has no more work left, signal completion.

View File

@ -21,6 +21,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"sync"
"testing" "testing"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
@ -59,7 +60,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{}) metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
}() }()
input.Produce(ctx, tj.job) require.NoError(t, input.Produce(ctx, tj.job))
for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) { for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) {
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(len(chunk)))) tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(len(chunk))))
@ -93,7 +94,10 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{}) metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
}() }()
input.Produce(ctx, tj.job) {
err := input.Produce(ctx, tj.job)
require.NoError(t, err)
}
// read one output then cancel // read one output then cancel
request, err := output.Consume(ctx) request, err := output.Consume(ctx)
@ -102,8 +106,11 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
// cancel job context // cancel job context
jobCancel() jobCancel()
wg := sync.WaitGroup{}
wg.Add(1)
// The next stages should check whether the job has been cancelled or not. Here we need to do it manually. // The next stages should check whether the job has been cancelled or not. Here we need to do it manually.
go func() { go func() {
wg.Done()
for { for {
req, err := output.Consume(ctx) req, err := output.Consume(ctx)
if err != nil { if err != nil {
@ -113,8 +120,9 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
req.checkCancelled() req.checkCancelled()
} }
}() }()
wg.Wait()
err = tj.job.waitAndClose(ctx) err = tj.job.waitAndClose(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled) require.ErrorIs(t, err, context.Canceled)
cancel() cancel()
} }
@ -149,8 +157,8 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
}() }()
go func() { go func() {
input.Produce(ctx, tj1.job) require.NoError(t, input.Produce(ctx, tj1.job))
input.Produce(ctx, tj2.job) require.NoError(t, input.Produce(ctx, tj2.job))
}() }()
go func() { go func() {

View File

@ -23,7 +23,7 @@ import (
) )
type StageOutputProducer[T any] interface { type StageOutputProducer[T any] interface {
Produce(ctx context.Context, value T) Produce(ctx context.Context, value T) error
Close() Close()
} }
@ -41,10 +41,12 @@ func NewChannelConsumerProducer[T any]() *ChannelConsumerProducer[T] {
return &ChannelConsumerProducer[T]{ch: make(chan T)} return &ChannelConsumerProducer[T]{ch: make(chan T)}
} }
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) { func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err()
case c.ch <- value: case c.ch <- value:
return nil
} }
} }