From f7434109be9820947deb53d9aed8da736696f1c3 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Mon, 27 Nov 2023 16:30:27 +0100 Subject: [PATCH] fix(GODT-3124): Race conditions reported by race check --- .../services/imapservice/sync_reporter.go | 80 ++++++++++++------- tests/ctx_heartbeat_test.go | 11 +++ tests/heartbeat_test.go | 2 +- 3 files changed, 62 insertions(+), 31 deletions(-) diff --git a/internal/services/imapservice/sync_reporter.go b/internal/services/imapservice/sync_reporter.go index 21f3ac4e..bc7684f1 100644 --- a/internal/services/imapservice/sync_reporter.go +++ b/internal/services/imapservice/sync_reporter.go @@ -19,15 +19,13 @@ package imapservice import ( "context" + "sync" "time" "github.com/ProtonMail/proton-bridge/v3/internal/events" ) -type syncReporter struct { - userID string - eventPublisher events.EventPublisher - +type syncData struct { start time.Time total int64 count int64 @@ -36,8 +34,25 @@ type syncReporter struct { freq time.Duration } +type syncReporter struct { + userID string + eventPublisher events.EventPublisher + + dataLock sync.Mutex + data syncData +} + +func (rep *syncReporter) withData(f func(s *syncData)) { + rep.dataLock.Lock() + defer rep.dataLock.Unlock() + + f(&rep.data) +} + func (rep *syncReporter) OnStart(ctx context.Context) { - rep.start = time.Now() + rep.withData(func(s *syncData) { + s.start = time.Now() + }) rep.eventPublisher.PublishEvent(ctx, events.SyncStarted{UserID: rep.userID}) } @@ -55,35 +70,38 @@ func (rep *syncReporter) OnError(ctx context.Context, err error) { } func (rep *syncReporter) OnProgress(ctx context.Context, delta int64) { - rep.count += delta + rep.withData(func(s *syncData) { + s.count += delta + var progress float64 + var remaining time.Duration - var progress float64 - var remaining time.Duration + // It's possible for count to be bigger or smaller than total depending on when the sync begins and whether new + // messages are added/removed during this period. When this happens just limited the progress to 100%. + if s.count > s.total { + progress = 1 + } else { + progress = float64(s.count) / float64(s.total) + remaining = time.Since(s.start) * time.Duration(s.total-(s.count+1)) / time.Duration(s.count+1) + } - // It's possible for count to be bigger or smaller than total depending on when the sync begins and whether new - // messages are added/removed during this period. When this happens just limited the progress to 100%. - if rep.count > rep.total { - progress = 1 - } else { - progress = float64(rep.count) / float64(rep.total) - remaining = time.Since(rep.start) * time.Duration(rep.total-(rep.count+1)) / time.Duration(rep.count+1) - } + if time.Since(s.last) > s.freq { + rep.eventPublisher.PublishEvent(ctx, events.SyncProgress{ + UserID: rep.userID, + Progress: progress, + Elapsed: time.Since(s.start), + Remaining: remaining, + }) - if time.Since(rep.last) > rep.freq { - rep.eventPublisher.PublishEvent(ctx, events.SyncProgress{ - UserID: rep.userID, - Progress: progress, - Elapsed: time.Since(rep.start), - Remaining: remaining, - }) - - rep.last = time.Now() - } + s.last = time.Now() + } + }) } func (rep *syncReporter) InitializeProgressCounter(_ context.Context, current int64, total int64) { - rep.count = current - rep.total = total + rep.withData(func(s *syncData) { + s.count = current + s.total = total + }) } func newSyncReporter(userID string, eventsPublisher events.EventPublisher, freq time.Duration) *syncReporter { @@ -91,7 +109,9 @@ func newSyncReporter(userID string, eventsPublisher events.EventPublisher, freq userID: userID, eventPublisher: eventsPublisher, - start: time.Now(), - freq: freq, + data: syncData{ + start: time.Now(), + freq: freq, + }, } } diff --git a/tests/ctx_heartbeat_test.go b/tests/ctx_heartbeat_test.go index d375616f..99ac9351 100644 --- a/tests/ctx_heartbeat_test.go +++ b/tests/ctx_heartbeat_test.go @@ -20,6 +20,7 @@ package tests import ( "context" "errors" + "sync" "testing" "time" @@ -29,6 +30,7 @@ import ( ) type heartbeatRecorder struct { + lock sync.Mutex heartbeat telemetry.HeartbeatData bridge *bridge.Bridge reject bool @@ -74,10 +76,19 @@ func (hb *heartbeatRecorder) SendHeartbeat(_ context.Context, metrics *telemetry if hb.reject { return false } + hb.lock.Lock() + defer hb.lock.Unlock() hb.heartbeat = *metrics return true } +func (hb *heartbeatRecorder) GetRecordedHeartbeat() telemetry.HeartbeatData { + hb.lock.Lock() + defer hb.lock.Unlock() + + return hb.heartbeat +} + func (hb *heartbeatRecorder) SetLastHeartbeatSent(timestamp time.Time) error { if hb.bridge == nil { return errors.New("no bridge initialized") diff --git a/tests/heartbeat_test.go b/tests/heartbeat_test.go index 65faa254..5b58044d 100644 --- a/tests/heartbeat_test.go +++ b/tests/heartbeat_test.go @@ -43,7 +43,7 @@ func (s *scenario) bridgeSendsTheFollowingHeartbeat(text *godog.DocString) error return err } - return matchHeartbeat(s.t.heartbeat.heartbeat, wantHeartbeat) + return matchHeartbeat(s.t.heartbeat.GetRecordedHeartbeat(), wantHeartbeat) } func (s *scenario) bridgeNeedsToSendHeartbeat() error {