forked from Silverfish/proton-bridge
fix(GODT-3124): Handling of sync child jobs
Improve the handling of sync child jobs to ensure it behaves correctly in all scenarios. The sync service now uses a isolated context to avoid all the pipeline stages shutting down before all the sync tasks have had the opportunity to run their course. The job waiter now immediately starts with a counter of 1 and waits until all the child and the parent job finish before considering the work to be finished. Finally, we also handle the case where a sync job can't be queued because the calling context has been cancelled.
This commit is contained in:
@ -307,7 +307,7 @@ func newBridge(
|
|||||||
bridge.heartbeat.init(bridge, heartbeatManager)
|
bridge.heartbeat.init(bridge, heartbeatManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.syncService.Run(bridge.tasks)
|
bridge.syncService.Run()
|
||||||
|
|
||||||
return bridge, nil
|
return bridge, nil
|
||||||
}
|
}
|
||||||
@ -451,6 +451,8 @@ func (bridge *Bridge) Close(ctx context.Context) {
|
|||||||
logrus.WithError(err).Error("Failed to close servers")
|
logrus.WithError(err).Error("Failed to close servers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bridge.syncService.Close()
|
||||||
|
|
||||||
// Stop all ongoing tasks.
|
// Stop all ongoing tasks.
|
||||||
bridge.tasks.CancelAndWait()
|
bridge.tasks.CancelAndWait()
|
||||||
|
|
||||||
|
|||||||
@ -210,7 +210,11 @@ func (t *Handler) run(ctx context.Context,
|
|||||||
stageContext.metadataFetched = syncStatus.NumSyncedMessages
|
stageContext.metadataFetched = syncStatus.NumSyncedMessages
|
||||||
stageContext.totalMessageCount = syncStatus.TotalMessageCount
|
stageContext.totalMessageCount = syncStatus.TotalMessageCount
|
||||||
|
|
||||||
t.regulator.Sync(ctx, stageContext)
|
if err := t.regulator.Sync(ctx, stageContext); err != nil {
|
||||||
|
stageContext.onError(err)
|
||||||
|
_ = stageContext.waitAndClose(ctx)
|
||||||
|
return fmt.Errorf("failed to start sync job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Wait on reply
|
// Wait on reply
|
||||||
if err := stageContext.waitAndClose(ctx); err != nil {
|
if err := stageContext.waitAndClose(ctx); err != nil {
|
||||||
|
|||||||
@ -64,7 +64,7 @@ func (s Status) InProgress() bool {
|
|||||||
|
|
||||||
// Regulator is an abstraction for the sync service, since it regulates the number of concurrent sync activities.
|
// Regulator is an abstraction for the sync service, since it regulates the number of concurrent sync activities.
|
||||||
type Regulator interface {
|
type Regulator interface {
|
||||||
Sync(ctx context.Context, stage *Job)
|
Sync(ctx context.Context, stage *Job) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type BuildResult struct {
|
type BuildResult struct {
|
||||||
|
|||||||
@ -119,7 +119,6 @@ func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int
|
|||||||
// begin is expected to be called once the job enters the pipeline.
|
// begin is expected to be called once the job enters the pipeline.
|
||||||
func (j *Job) begin() {
|
func (j *Job) begin() {
|
||||||
j.log.Info("Job started")
|
j.log.Info("Job started")
|
||||||
j.jw.onTaskCreated()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// end is expected to be called once the job has no further work left.
|
// end is expected to be called once the job has no further work left.
|
||||||
@ -133,7 +132,6 @@ func (j *Job) waitAndClose(ctx context.Context) error {
|
|||||||
defer j.close()
|
defer j.close()
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
j.jw.onContextCancelled()
|
|
||||||
<-j.jw.doneCh
|
<-j.jw.doneCh
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case e := <-j.jw.doneCh:
|
case e := <-j.jw.doneCh:
|
||||||
@ -227,7 +225,6 @@ type JobWaiterMessage int
|
|||||||
const (
|
const (
|
||||||
JobWaiterMessageCreated JobWaiterMessage = iota
|
JobWaiterMessageCreated JobWaiterMessage = iota
|
||||||
JobWaiterMessageFinished
|
JobWaiterMessageFinished
|
||||||
JobWaiterMessageCtxErr
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type jobWaiterMessagePair struct {
|
type jobWaiterMessagePair struct {
|
||||||
@ -248,7 +245,7 @@ type jobWaiter struct {
|
|||||||
func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter {
|
func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter {
|
||||||
return &jobWaiter{
|
return &jobWaiter{
|
||||||
ch: make(chan jobWaiterMessagePair),
|
ch: make(chan jobWaiterMessagePair),
|
||||||
doneCh: make(chan error),
|
doneCh: make(chan error, 2),
|
||||||
log: log,
|
log: log,
|
||||||
panicHandler: panicHandler,
|
panicHandler: panicHandler,
|
||||||
}
|
}
|
||||||
@ -273,15 +270,11 @@ func (j *jobWaiter) onTaskCreated() {
|
|||||||
j.sendMessage(JobWaiterMessageCreated, nil)
|
j.sendMessage(JobWaiterMessageCreated, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *jobWaiter) onContextCancelled() {
|
|
||||||
j.sendMessage(JobWaiterMessageCtxErr, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jobWaiter) begin() {
|
func (j *jobWaiter) begin() {
|
||||||
go func() {
|
go func() {
|
||||||
defer async.HandlePanic(j.panicHandler)
|
defer async.HandlePanic(j.panicHandler)
|
||||||
|
|
||||||
total := 0
|
total := 1
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -296,8 +289,6 @@ func (j *jobWaiter) begin() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch m.m {
|
switch m.m {
|
||||||
case JobWaiterMessageCtxErr:
|
|
||||||
// DO nothing
|
|
||||||
case JobWaiterMessageCreated:
|
case JobWaiterMessageCreated:
|
||||||
total++
|
total++
|
||||||
case JobWaiterMessageFinished:
|
case JobWaiterMessageFinished:
|
||||||
|
|||||||
@ -83,6 +83,7 @@ func TestJob_WaitsOnAllChildrenOnError(t *testing.T) {
|
|||||||
|
|
||||||
job1.onFinished(context.Background())
|
job1.onFinished(context.Background())
|
||||||
job2.onError(jobErr)
|
job2.onError(jobErr)
|
||||||
|
tj.job.end()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
close(startCh)
|
close(startCh)
|
||||||
@ -115,6 +116,7 @@ func TestJob_MultipleChildrenReportError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
tj.job.end()
|
||||||
close(startCh)
|
close(startCh)
|
||||||
err := tj.job.waitAndClose(context.Background())
|
err := tj.job.waitAndClose(context.Background())
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@ -179,6 +181,7 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
tj.job.end()
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -201,6 +204,7 @@ func TestJob_CtxCancelBeforeBegin(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
cancel()
|
cancel()
|
||||||
|
tj.job.end()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
wg.Done()
|
wg.Done()
|
||||||
|
|||||||
@ -127,9 +127,11 @@ func (mr *MockBuildStageOutputMockRecorder) Close() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Produce mocks base method.
|
// Produce mocks base method.
|
||||||
func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) {
|
func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "Produce", arg0, arg1)
|
ret := m.ctrl.Call(m, "Produce", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Produce indicates an expected call of Produce.
|
// Produce indicates an expected call of Produce.
|
||||||
@ -212,9 +214,11 @@ func (mr *MockDownloadStageOutputMockRecorder) Close() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Produce mocks base method.
|
// Produce mocks base method.
|
||||||
func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) {
|
func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "Produce", arg0, arg1)
|
ret := m.ctrl.Call(m, "Produce", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Produce indicates an expected call of Produce.
|
// Produce indicates an expected call of Produce.
|
||||||
@ -297,9 +301,11 @@ func (mr *MockMetadataStageOutputMockRecorder) Close() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Produce mocks base method.
|
// Produce mocks base method.
|
||||||
func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) {
|
func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "Produce", arg0, arg1)
|
ret := m.ctrl.Call(m, "Produce", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Produce indicates an expected call of Produce.
|
// Produce indicates an expected call of Produce.
|
||||||
@ -478,9 +484,11 @@ func (m *MockRegulator) EXPECT() *MockRegulatorMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sync mocks base method.
|
// Sync mocks base method.
|
||||||
func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) {
|
func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "Sync", arg0, arg1)
|
ret := m.ctrl.Call(m, "Sync", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync indicates an expected call of Sync.
|
// Sync indicates an expected call of Sync.
|
||||||
|
|||||||
@ -33,7 +33,7 @@ type Service struct {
|
|||||||
applyStage *ApplyStage
|
applyStage *ApplyStage
|
||||||
limits syncLimits
|
limits syncLimits
|
||||||
metaCh *ChannelConsumerProducer[*Job]
|
metaCh *ChannelConsumerProducer[*Job]
|
||||||
panicHandler async.PanicHandler
|
group *async.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(reporter reporter.Reporter,
|
func NewService(reporter reporter.Reporter,
|
||||||
@ -53,26 +53,22 @@ func NewService(reporter reporter.Reporter,
|
|||||||
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
|
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
|
||||||
applyStage: NewApplyStage(applyCh),
|
applyStage: NewApplyStage(applyCh),
|
||||||
metaCh: metaCh,
|
metaCh: metaCh,
|
||||||
panicHandler: panicHandler,
|
group: async.NewGroup(context.Background(), panicHandler),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run(group *async.Group) {
|
func (s *Service) Run() {
|
||||||
group.Once(func(ctx context.Context) {
|
s.metadataStage.Run(s.group)
|
||||||
syncGroup := async.NewGroup(ctx, s.panicHandler)
|
s.downloadStage.Run(s.group)
|
||||||
|
s.buildStage.Run(s.group)
|
||||||
s.metadataStage.Run(syncGroup)
|
s.applyStage.Run(s.group)
|
||||||
s.downloadStage.Run(syncGroup)
|
|
||||||
s.buildStage.Run(syncGroup)
|
|
||||||
s.applyStage.Run(syncGroup)
|
|
||||||
|
|
||||||
defer s.metaCh.Close()
|
|
||||||
defer syncGroup.CancelAndWait()
|
|
||||||
|
|
||||||
<-ctx.Done()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Sync(ctx context.Context, stage *Job) {
|
func (s *Service) Sync(ctx context.Context, stage *Job) error {
|
||||||
s.metaCh.Produce(ctx, stage)
|
return s.metaCh.Produce(ctx, stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Close() {
|
||||||
|
s.group.CancelAndWait()
|
||||||
|
s.metaCh.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -50,10 +50,10 @@ func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
jobCancel()
|
jobCancel()
|
||||||
input.Produce(ctx, ApplyRequest{
|
require.NoError(t, input.Produce(ctx, ApplyRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
messages: nil,
|
messages: nil,
|
||||||
})
|
}))
|
||||||
|
|
||||||
err := tj.job.waitAndClose(ctx)
|
err := tj.job.waitAndClose(ctx)
|
||||||
require.ErrorIs(t, err, context.Canceled)
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
@ -84,10 +84,10 @@ func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, ApplyRequest{
|
require.NoError(t, input.Produce(ctx, ApplyRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
messages: nil,
|
messages: nil,
|
||||||
})
|
}))
|
||||||
|
|
||||||
err := tj.job.waitAndClose(ctx)
|
err := tj.job.waitAndClose(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
@ -127,10 +127,10 @@ func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, ApplyRequest{
|
require.NoError(t, input.Produce(ctx, ApplyRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
messages: buildResults,
|
messages: buildResults,
|
||||||
})
|
}))
|
||||||
|
|
||||||
err := tj.job.waitAndClose(ctx)
|
err := tj.job.waitAndClose(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/async"
|
"github.com/ProtonMail/gluon/async"
|
||||||
@ -182,10 +183,12 @@ func (b *BuildStage) run(ctx context.Context) {
|
|||||||
|
|
||||||
outJob.onStageCompleted(ctx)
|
outJob.onStageCompleted(ctx)
|
||||||
|
|
||||||
b.output.Produce(ctx, ApplyRequest{
|
if err := b.output.Produce(ctx, ApplyRequest{
|
||||||
childJob: outJob,
|
childJob: outJob,
|
||||||
messages: success,
|
messages: success,
|
||||||
})
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to produce output for next stage: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -111,7 +111,7 @@ func TestBuildStage_SuccessRemovesFailedMessage(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
|
||||||
|
|
||||||
req, err := output.Consume(ctx)
|
req, err := output.Consume(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
@ -170,7 +170,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
|
||||||
|
|
||||||
req, err := output.Consume(ctx)
|
req, err := output.Consume(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
@ -222,7 +222,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
|
||||||
|
|
||||||
req, err := output.Consume(ctx)
|
req, err := output.Consume(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
@ -267,7 +267,7 @@ func TestBuildStage_OtherErrorsFailJob(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
|
||||||
|
|
||||||
err := tj.job.waitAndClose(ctx)
|
err := tj.job.waitAndClose(ctx)
|
||||||
require.Equal(t, expectedErr, err)
|
require.Equal(t, expectedErr, err)
|
||||||
@ -311,10 +311,10 @@ func TestBuildStage_CancelledJobIsDiscarded(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
jobCancel()
|
jobCancel()
|
||||||
input.Produce(ctx, BuildRequest{
|
require.NoError(t, input.Produce(ctx, BuildRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
batch: []proton.FullMessage{msg},
|
batch: []proton.FullMessage{msg},
|
||||||
})
|
}))
|
||||||
|
|
||||||
go func() { cancel() }()
|
go func() { cancel() }()
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/async"
|
"github.com/ProtonMail/gluon/async"
|
||||||
@ -183,10 +184,12 @@ func (d *DownloadStage) run(ctx context.Context) {
|
|||||||
// Step 5: Publish result.
|
// Step 5: Publish result.
|
||||||
request.onStageCompleted(ctx)
|
request.onStageCompleted(ctx)
|
||||||
|
|
||||||
d.output.Produce(ctx, BuildRequest{
|
if err := d.output.Produce(ctx, BuildRequest{
|
||||||
batch: result,
|
batch: result,
|
||||||
childJob: request.childJob,
|
childJob: request.childJob,
|
||||||
})
|
}); err != nil {
|
||||||
|
request.job.onError(fmt.Errorf("failed to produce output for next stage: %w", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -189,10 +189,10 @@ func TestDownloadStage_Run(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, DownloadRequest{
|
require.NoError(t, input.Produce(ctx, DownloadRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
ids: msgIDs,
|
ids: msgIDs,
|
||||||
})
|
}))
|
||||||
|
|
||||||
out, err := output.Consume(ctx)
|
out, err := output.Consume(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -232,10 +232,10 @@ func TestDownloadStage_RunWith422(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, DownloadRequest{
|
require.NoError(t, input.Produce(ctx, DownloadRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
ids: msgIDs,
|
ids: msgIDs,
|
||||||
})
|
}))
|
||||||
|
|
||||||
out, err := output.Consume(ctx)
|
out, err := output.Consume(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -271,10 +271,11 @@ func TestDownloadStage_CancelledJobIsDiscarded(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
jobCancel()
|
jobCancel()
|
||||||
input.Produce(ctx, DownloadRequest{
|
|
||||||
|
require.NoError(t, input.Produce(ctx, DownloadRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
ids: nil,
|
ids: nil,
|
||||||
})
|
}))
|
||||||
|
|
||||||
go func() { cancel() }()
|
go func() { cancel() }()
|
||||||
|
|
||||||
@ -308,10 +309,10 @@ func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, DownloadRequest{
|
require.NoError(t, input.Produce(ctx, DownloadRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
ids: []string{"foo"},
|
ids: []string{"foo"},
|
||||||
})
|
}))
|
||||||
|
|
||||||
err := tj.job.waitAndClose(ctx)
|
err := tj.job.waitAndClose(ctx)
|
||||||
require.Equal(t, expectedErr, err)
|
require.Equal(t, expectedErr, err)
|
||||||
@ -359,10 +360,10 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
|
|||||||
stage.run(ctx)
|
stage.run(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, DownloadRequest{
|
require.NoError(t, input.Produce(ctx, DownloadRequest{
|
||||||
childJob: childJob,
|
childJob: childJob,
|
||||||
ids: []string{"foo"},
|
ids: []string{"foo"},
|
||||||
})
|
}))
|
||||||
|
|
||||||
err := tj.job.waitAndClose(ctx)
|
err := tj.job.waitAndClose(ctx)
|
||||||
require.Equal(t, expectedErr, err)
|
require.Equal(t, expectedErr, err)
|
||||||
|
|||||||
@ -20,6 +20,7 @@ package syncservice
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/async"
|
"github.com/ProtonMail/gluon/async"
|
||||||
"github.com/ProtonMail/gluon/logging"
|
"github.com/ProtonMail/gluon/logging"
|
||||||
@ -87,10 +88,6 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if job.ctx.Err() != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
job.begin()
|
job.begin()
|
||||||
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
|
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -119,7 +116,10 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag
|
|||||||
|
|
||||||
output.onStageCompleted(ctx)
|
output.onStageCompleted(ctx)
|
||||||
|
|
||||||
m.output.Produce(ctx, output)
|
if err := m.output.Produce(ctx, output); err != nil {
|
||||||
|
job.onError(fmt.Errorf("failed to produce output for next stage: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this job has no more work left, signal completion.
|
// If this job has no more work left, signal completion.
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/async"
|
"github.com/ProtonMail/gluon/async"
|
||||||
@ -59,7 +60,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) {
|
|||||||
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, tj.job)
|
require.NoError(t, input.Produce(ctx, tj.job))
|
||||||
|
|
||||||
for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) {
|
for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) {
|
||||||
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(len(chunk))))
|
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(len(chunk))))
|
||||||
@ -93,7 +94,10 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
|
|||||||
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
input.Produce(ctx, tj.job)
|
{
|
||||||
|
err := input.Produce(ctx, tj.job)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
// read one output then cancel
|
// read one output then cancel
|
||||||
request, err := output.Consume(ctx)
|
request, err := output.Consume(ctx)
|
||||||
@ -102,8 +106,11 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
|
|||||||
// cancel job context
|
// cancel job context
|
||||||
jobCancel()
|
jobCancel()
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
// The next stages should check whether the job has been cancelled or not. Here we need to do it manually.
|
// The next stages should check whether the job has been cancelled or not. Here we need to do it manually.
|
||||||
go func() {
|
go func() {
|
||||||
|
wg.Done()
|
||||||
for {
|
for {
|
||||||
req, err := output.Consume(ctx)
|
req, err := output.Consume(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -113,8 +120,9 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
|
|||||||
req.checkCancelled()
|
req.checkCancelled()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
wg.Wait()
|
||||||
err = tj.job.waitAndClose(ctx)
|
err = tj.job.waitAndClose(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
require.ErrorIs(t, err, context.Canceled)
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
@ -149,8 +157,8 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
input.Produce(ctx, tj1.job)
|
require.NoError(t, input.Produce(ctx, tj1.job))
|
||||||
input.Produce(ctx, tj2.job)
|
require.NoError(t, input.Produce(ctx, tj2.job))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@ -23,7 +23,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type StageOutputProducer[T any] interface {
|
type StageOutputProducer[T any] interface {
|
||||||
Produce(ctx context.Context, value T)
|
Produce(ctx context.Context, value T) error
|
||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,10 +41,12 @@ func NewChannelConsumerProducer[T any]() *ChannelConsumerProducer[T] {
|
|||||||
return &ChannelConsumerProducer[T]{ch: make(chan T)}
|
return &ChannelConsumerProducer[T]{ch: make(chan T)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) {
|
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) error {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
case c.ch <- value:
|
case c.ch <- value:
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user