fix(GODT-3124): Race condition in sync task waiter

Fix incorrect use of `sync.WaitGroup` use to wait on sync jobs that
fail. After calling `WaitGroup.Wait()` it is advised to call
`WaitGroup.Add` until the existing counter has reached 0.

The code has been updated with a different mechanism that achieves the
same behavior which was previously available.
This commit is contained in:
Leander Beernaert
2023-11-28 08:04:57 +01:00
parent 6d7c21b2c9
commit 7d13c99710
8 changed files with 185 additions and 70 deletions

View File

@ -19,9 +19,6 @@ package syncservice
import (
"context"
"errors"
"fmt"
"sync"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
@ -48,10 +45,8 @@ type Job struct {
updateApplier UpdateApplier
syncReporter Reporter
log *logrus.Entry
errorCh *async.QueuedChannel[error]
wg sync.WaitGroup
once sync.Once
log *logrus.Entry
jw *jobWaiter
panicHandler async.PanicHandler
downloadCache *DownloadCache
@ -74,7 +69,7 @@ func NewJob(ctx context.Context,
) *Job {
ctx, cancel := context.WithCancel(ctx)
return &Job{
j := &Job{
ctx: ctx,
client: client,
userID: userID,
@ -85,26 +80,23 @@ func NewJob(ctx context.Context,
messageBuilder: messageBuilder,
updateApplier: updateApplier,
syncReporter: syncReporter,
errorCh: async.NewQueuedChannel[error](4, 8, panicHandler, fmt.Sprintf("sync-job-error-%v", userID)),
panicHandler: panicHandler,
downloadCache: cache,
jw: newJobWaiter(log.WithField("sync-job", "waiter"), panicHandler),
}
j.jw.begin()
return j
}
func (j *Job) Close() {
j.errorCh.CloseAndDiscardQueued()
j.wg.Wait()
func (j *Job) close() {
j.jw.close()
}
func (j *Job) onError(err error) {
defer j.wg.Done()
defer j.jw.onTaskFinished(err)
// context cancelled is caught & handled in a different location.
if errors.Is(err, context.Canceled) {
return
}
j.errorCh.Enqueue(err)
j.cancel()
}
@ -119,55 +111,42 @@ func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int
return
}
// j.onError() also calls j.wg.Done().
j.wg.Done()
// j.onError() also calls j.jw.onTaskFinished().
defer j.jw.onTaskFinished(nil)
j.syncReporter.OnProgress(ctx, count)
}
// begin is expected to be called once the job enters the pipeline.
func (j *Job) begin() {
j.log.Info("Job started")
j.wg.Add(1)
j.startChildWaiter()
j.jw.onTaskCreated()
}
// end is expected to be called once the job has no further work left.
func (j *Job) end() {
j.log.Info("Job finished")
j.wg.Done()
j.jw.onTaskFinished(nil)
}
// wait waits until the job has finished, the context got cancelled or an error occurred.
func (j *Job) wait(ctx context.Context) error {
defer j.wg.Wait()
// waitAndClose waits until the job has finished, the context got cancelled or an error occurred.
func (j *Job) waitAndClose(ctx context.Context) error {
defer j.close()
select {
case <-ctx.Done():
j.cancel()
j.jw.onContextCancelled()
<-j.jw.doneCh
return ctx.Err()
case err := <-j.errorCh.GetChannel():
return err
case e := <-j.jw.doneCh:
return e
}
}
func (j *Job) newChildJob(messageID string, messageCount int64) childJob {
j.log.Infof("Creating new child job")
j.wg.Add(1)
j.jw.onTaskCreated()
return childJob{job: j, lastMessageID: messageID, messageCount: messageCount}
}
func (j *Job) startChildWaiter() {
j.once.Do(func() {
go func() {
defer async.HandlePanic(j.panicHandler)
j.wg.Wait()
j.log.Info("All child jobs succeeded")
j.errorCh.Enqueue(j.ctx.Err())
}()
})
}
// childJob represents a batch of work that goes down the pipeline. It keeps track of the message ID that is in the
// batch and the number of messages in the batch.
type childJob struct {
@ -232,7 +211,7 @@ func (s *childJob) checkCancelled() bool {
err := s.job.ctx.Err()
if err != nil {
s.job.log.Infof("Child job exit due to context cancelled")
s.job.wg.Done()
s.job.jw.onTaskFinished(err)
return true
}
@ -242,3 +221,102 @@ func (s *childJob) checkCancelled() bool {
func (s *childJob) getContext() context.Context {
return s.job.ctx
}
type JobWaiterMessage int
const (
JobWaiterMessageCreated JobWaiterMessage = iota
JobWaiterMessageFinished
JobWaiterMessageCtxErr
)
type jobWaiterMessagePair struct {
m JobWaiterMessage
err error
}
// jobWaiter is meant to be used to track ongoing sync batches. Once all the child jobs
// have completed, the first recorded error (if any) will be written to doneCh and then this
// channel will be closed.
type jobWaiter struct {
ch chan jobWaiterMessagePair
doneCh chan error
log *logrus.Entry
panicHandler async.PanicHandler
}
func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter {
return &jobWaiter{
ch: make(chan jobWaiterMessagePair),
doneCh: make(chan error),
log: log,
panicHandler: panicHandler,
}
}
func (j *jobWaiter) close() {
close(j.ch)
}
func (j *jobWaiter) sendMessage(m JobWaiterMessage, err error) {
j.ch <- jobWaiterMessagePair{
m: m,
err: err,
}
}
func (j *jobWaiter) onTaskFinished(err error) {
j.sendMessage(JobWaiterMessageFinished, err)
}
func (j *jobWaiter) onTaskCreated() {
j.sendMessage(JobWaiterMessageCreated, nil)
}
func (j *jobWaiter) onContextCancelled() {
j.sendMessage(JobWaiterMessageCtxErr, nil)
}
func (j *jobWaiter) begin() {
go func() {
defer async.HandlePanic(j.panicHandler)
total := 0
var err error
defer func() {
j.doneCh <- err
close(j.doneCh)
}()
for {
m, ok := <-j.ch
if !ok {
return
}
switch m.m {
case JobWaiterMessageCtxErr:
// DO nothing
case JobWaiterMessageCreated:
total++
case JobWaiterMessageFinished:
total--
if m.err != nil && err == nil {
err = m.err
}
default:
j.log.Errorf("Unknown message type: %v", m.m)
continue
}
if total <= 0 {
if total < 0 {
logrus.Errorf("Child count less than 0, shouldn't happen...")
}
j.log.Info("All child jobs completed")
return
}
}
}()
}