From ca1996a670ee38bb9b67412a29b38d46ccf6df9e Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Mon, 6 Feb 2023 15:27:29 +0100 Subject: [PATCH] fix(GODT-2327): Properly cancel event stream when handling refresh --- internal/user/events.go | 5 +++- internal/user/user.go | 58 +++++++++++++++++++++++++---------------- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/internal/user/events.go b/internal/user/events.go index 36f8b362..390f30b5 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -85,8 +85,11 @@ func (user *User) handleRefreshEvent(ctx context.Context, refresh proton.Refresh l.WithError(err).Error("Failed to report refresh to sentry") } + // Cancel the event stream once this refresh is done. + defer user.pollAbort.Abort() + // Cancel and restart ongoing syncs. - user.abortable.Abort() + user.syncAbort.Abort() defer user.goSync() return safe.LockRet(func() error { diff --git a/internal/user/user.go b/internal/user/user.go index c3444359..0a763769 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -78,7 +78,8 @@ type User struct { updateChLock safe.RWMutex tasks *async.Group - abortable async.Abortable + syncAbort async.Abortable + pollAbort async.Abortable goSync func() pollAPIEventsCh chan chan struct{} @@ -182,29 +183,38 @@ func New( } } - // When triggered, attempt to sync the user. + // When triggered, sync the user and then begin streaming API events. user.goSync = user.tasks.Trigger(func(ctx context.Context) { user.log.Debug("Sync triggered") - user.abortable.Do(ctx, func(ctx context.Context) { - if !user.vault.SyncStatus().IsComplete() { - if err := user.doSync(ctx); err != nil { - user.log.WithError(err).Error("Failed to sync, will retry later") - - go func() { - select { - case <-ctx.Done(): - user.log.WithError(err).Warn("Aborting sync retry") - case <-time.After(SyncRetryCooldown): - user.goSync() - } - }() - } + // Sync the user. + user.syncAbort.Do(ctx, func(ctx context.Context) { + if user.vault.SyncStatus().IsComplete() { + user.log.Debug("Sync already complete, skipping") + return + } + + if err := user.doSync(ctx); err != nil { + user.log.WithError(err).Error("Failed to sync, will retry later") + + go func() { + select { + case <-ctx.Done(): + user.log.WithError(err).Warn("Aborting sync retry") + case <-time.After(SyncRetryCooldown): + user.goSync() + } + }() } - // Once we know the sync has completed, we can start polling for API events. - user.startEvents(ctx) }) + + // Once we know the sync has completed, we can start polling for API events. + if user.vault.SyncStatus().IsComplete() { + user.pollAbort.Do(ctx, func(ctx context.Context) { + user.startEvents(ctx) + }) + } }) return user, nil @@ -270,7 +280,8 @@ func (user *User) GetAddressMode() vault.AddressMode { func (user *User) SetAddressMode(_ context.Context, mode vault.AddressMode) error { user.log.WithField("mode", mode).Info("Setting address mode") - user.abortable.Abort() + user.syncAbort.Abort() + user.pollAbort.Abort() defer user.goSync() return safe.LockRet(func() error { @@ -461,7 +472,8 @@ func (user *User) OnStatusUp(context.Context) { func (user *User) OnStatusDown(context.Context) { user.log.Info("Connection is down") - user.abortable.Abort() + user.syncAbort.Abort() + user.pollAbort.Abort() } // GetSyncStatus returns the sync status of the user. @@ -471,6 +483,7 @@ func (user *User) GetSyncStatus() vault.SyncStatus { // ClearSyncStatus clears the sync status of the user. // This also drops any updates in the update channel(s). +// Warning: the gluon user must be removed and re-added if this happens! func (user *User) ClearSyncStatus() error { user.log.Info("Clearing sync status") @@ -481,6 +494,7 @@ func (user *User) ClearSyncStatus() error { // clearSyncStatus clears the sync status of the user. // This also drops any updates in the update channel(s). +// Warning: the gluon user must be removed and re-added if this happens! // It is assumed that the eventLock, apiAddrsLock and updateChLock are already locked. func (user *User) clearSyncStatus() error { user.log.Info("Clearing sync status") @@ -593,9 +607,7 @@ func (user *User) startEvents(ctx context.Context) { user.log.Debug("Event poll triggered") - if !user.vault.SyncStatus().IsComplete() { - user.log.Debug("Sync is incomplete, skipping event poll") - } else if err := user.doEventPoll(ctx); err != nil { + if err := user.doEventPoll(ctx); err != nil { user.log.WithError(err).Error("Failed to poll events") }