From 5a434fafbca6d0e659e3430c80f938aebcb90b3c Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Thu, 16 Nov 2023 11:05:40 +0100 Subject: [PATCH] fix(GODT-3125): Heartbeat crash on exit Ensure that the heartbeat background task is stopped before we close the users as it accesses data within these instances. Additionally, we also make sure that when telemetry is disabled, we stop the background task. Finally, `HeartbeatManager` now specifies what the desired interval is so we can better configure the test cases. --- internal/app/app.go | 3 - internal/app/bridge.go | 1 + internal/bridge/bridge.go | 21 +++- internal/bridge/bridge_test.go | 4 +- internal/bridge/heartbeat.go | 130 +++++++++++++++-------- internal/bridge/mocks.go | 2 + internal/bridge/mocks/telemetry_mocks.go | 14 +++ internal/bridge/settings.go | 7 +- internal/bridge/user.go | 2 +- internal/telemetry/mocks/mocks.go | 14 +++ internal/telemetry/types_heartbeat.go | 1 + tests/ctx_bridge_test.go | 3 +- tests/ctx_heartbeat_test.go | 4 + 13 files changed, 146 insertions(+), 60 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index f4e36035..b082d5b2 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -282,9 +282,6 @@ func run(c *cli.Context) error { // Remove old updates files b.RemoveOldUpdates() - // Start telemetry heartbeat process - b.StartHeartbeat(b) - // Run the frontend. return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID)) }) diff --git a/internal/app/bridge.go b/internal/app/bridge.go index 97af2937..421feba6 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -113,6 +113,7 @@ func withBridge( crashHandler, reporter, imap.DefaultEpochUIDValidityGenerator(), + nil, // The logging stuff. c.String(flagLogIMAP) == "client" || c.String(flagLogIMAP) == "all", diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 6a91697f..ef21ae1e 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -75,7 +75,7 @@ type Bridge struct { installCh chan installJob // heartbeat is the telemetry heartbeat for metrics. - heartbeat telemetry.Heartbeat + heartbeat *heartBeatState // curVersion is the current version of the bridge, // newVersion is the version that was installed by the updater. @@ -128,9 +128,6 @@ type Bridge struct { // goUpdate triggers a check/install of updates. goUpdate func() - // goHeartbeat triggers a check/sending if heartbeat is needed. - goHeartbeat func() - serverManager *imapsmtpserver.Service syncService *syncservice.Service } @@ -153,6 +150,7 @@ func New( panicHandler async.PanicHandler, reporter reporter.Reporter, uidValidityGenerator imap.UIDValidityGenerator, + heartBeatManager telemetry.HeartbeatManager, logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity logSMTP bool, // whether to log SMTP activity @@ -168,6 +166,7 @@ func New( // bridge is the bridge. bridge, err := newBridge( + context.Background(), tasks, imapEventCh, @@ -184,6 +183,7 @@ func New( identifier, proxyCtl, uidValidityGenerator, + heartBeatManager, logIMAPClient, logIMAPServer, logSMTP, ) if err != nil { @@ -202,6 +202,7 @@ func New( } func newBridge( + ctx context.Context, tasks *async.Group, imapEventCh chan imapEvents.Event, @@ -218,6 +219,7 @@ func newBridge( identifier identifier.Identifier, proxyCtl ProxyController, uidValidityGenerator imap.UIDValidityGenerator, + heartbeatManager telemetry.HeartbeatManager, logIMAPClient, logIMAPServer, logSMTP bool, ) (*Bridge, error) { @@ -268,6 +270,8 @@ func newBridge( panicHandler: panicHandler, reporter: reporter, + heartbeat: newHeartBeatState(ctx, panicHandler), + focusService: focusService, autostarter: autostarter, locator: locator, @@ -297,6 +301,12 @@ func newBridge( return nil, err } + if heartbeatManager == nil { + bridge.heartbeat.init(bridge, bridge) + } else { + bridge.heartbeat.init(bridge, heartbeatManager) + } + bridge.syncService.Run(bridge.tasks) return bridge, nil @@ -426,6 +436,9 @@ func (bridge *Bridge) GetErrors() []error { func (bridge *Bridge) Close(ctx context.Context) { logrus.Info("Closing bridge") + // Stop heart beat before closing users. + bridge.heartbeat.stop() + // Close all users. safe.Lock(func() { for _, user := range bridge.users { diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 9540e02c..bb20e464 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -963,6 +963,7 @@ func withBridgeNoMocks( mocks.CrashHandler, mocks.Reporter, testUIDValidityGenerator, + mocks.Heartbeat, // The logging stuff. os.Getenv("BRIDGE_LOG_IMAP_CLIENT") == "1", @@ -972,9 +973,6 @@ func withBridgeNoMocks( require.NoError(t, err) require.Empty(t, bridge.GetErrors()) - // Start the Heartbeat process. - bridge.StartHeartbeat(mocks.Heartbeat) - // Wait for bridge to finish loading users. waitForEvent(t, eventCh, events.AllUsersLoaded{}) diff --git a/internal/bridge/heartbeat.go b/internal/bridge/heartbeat.go index a851ecb5..241bfd88 100644 --- a/internal/bridge/heartbeat.go +++ b/internal/bridge/heartbeat.go @@ -20,8 +20,10 @@ package bridge import ( "context" "encoding/json" + "sync" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry" @@ -31,6 +33,87 @@ import ( const HeartbeatCheckInterval = time.Hour +type heartBeatState struct { + task *async.Group + telemetry.Heartbeat + taskLock sync.Mutex + taskStarted bool + taskInterval time.Duration +} + +func newHeartBeatState(ctx context.Context, panicHandler async.PanicHandler) *heartBeatState { + return &heartBeatState{ + task: async.NewGroup(ctx, panicHandler), + } +} + +func (h *heartBeatState) init(bridge *Bridge, manager telemetry.HeartbeatManager) { + h.Heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper()) + h.taskInterval = manager.GetHeartbeatPeriodicInterval() + h.SetRollout(bridge.GetUpdateRollout()) + h.SetAutoStart(bridge.GetAutostart()) + h.SetAutoUpdate(bridge.GetAutoUpdate()) + h.SetBeta(bridge.GetUpdateChannel()) + h.SetDoh(bridge.GetProxyAllowed()) + h.SetShowAllMail(bridge.GetShowAllMail()) + h.SetIMAPConnectionMode(bridge.GetIMAPSSL()) + h.SetSMTPConnectionMode(bridge.GetSMTPSSL()) + h.SetIMAPPort(bridge.GetIMAPPort()) + h.SetSMTPPort(bridge.GetSMTPPort()) + h.SetCacheLocation(bridge.GetGluonCacheDir()) + if val, err := bridge.GetKeychainApp(); err != nil { + h.SetKeyChainPref(val) + } else { + h.SetKeyChainPref(bridge.keychains.GetDefaultHelper()) + } + h.SetPrevVersion(bridge.GetLastVersion().String()) + + safe.RLock(func() { + var splitMode = false + for _, user := range bridge.users { + if user.GetAddressMode() == vault.SplitMode { + splitMode = true + break + } + } + var nbAccount = len(bridge.users) + h.SetNbAccount(nbAccount) + h.SetSplitMode(splitMode) + + // Do not try to send if there is no user yet. + if nbAccount > 0 { + defer h.start() + } + }, bridge.usersLock) +} + +func (h *heartBeatState) start() { + h.taskLock.Lock() + defer h.taskLock.Unlock() + if h.taskStarted { + return + } + + h.taskStarted = true + + h.task.PeriodicOrTrigger(h.taskInterval, 0, func(ctx context.Context) { + logrus.Debug("Checking for heartbeat") + + h.TrySending(ctx) + }) +} + +func (h *heartBeatState) stop() { + h.taskLock.Lock() + defer h.taskLock.Unlock() + if !h.taskStarted { + return + } + + h.task.CancelAndWait() + h.taskStarted = false +} + func (bridge *Bridge) IsTelemetryAvailable(ctx context.Context) bool { var flag = true if bridge.GetTelemetryDisabled() { @@ -79,49 +162,6 @@ func (bridge *Bridge) SetLastHeartbeatSent(timestamp time.Time) error { return bridge.vault.SetLastHeartbeatSent(timestamp) } -func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) { - bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper()) - - // Check for heartbeat when triggered. - bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) { - logrus.Debug("Checking for heartbeat") - - bridge.heartbeat.TrySending(ctx) - }) - - bridge.heartbeat.SetRollout(bridge.GetUpdateRollout()) - bridge.heartbeat.SetAutoStart(bridge.GetAutostart()) - bridge.heartbeat.SetAutoUpdate(bridge.GetAutoUpdate()) - bridge.heartbeat.SetBeta(bridge.GetUpdateChannel()) - bridge.heartbeat.SetDoh(bridge.GetProxyAllowed()) - bridge.heartbeat.SetShowAllMail(bridge.GetShowAllMail()) - bridge.heartbeat.SetIMAPConnectionMode(bridge.GetIMAPSSL()) - bridge.heartbeat.SetSMTPConnectionMode(bridge.GetSMTPSSL()) - bridge.heartbeat.SetIMAPPort(bridge.GetIMAPPort()) - bridge.heartbeat.SetSMTPPort(bridge.GetSMTPPort()) - bridge.heartbeat.SetCacheLocation(bridge.GetGluonCacheDir()) - if val, err := bridge.GetKeychainApp(); err != nil { - bridge.heartbeat.SetKeyChainPref(val) - } else { - bridge.heartbeat.SetKeyChainPref(bridge.keychains.GetDefaultHelper()) - } - bridge.heartbeat.SetPrevVersion(bridge.GetLastVersion().String()) - - safe.RLock(func() { - var splitMode = false - for _, user := range bridge.users { - if user.GetAddressMode() == vault.SplitMode { - splitMode = true - break - } - } - var nbAccount = len(bridge.users) - bridge.heartbeat.SetNbAccount(nbAccount) - bridge.heartbeat.SetSplitMode(splitMode) - - // Do not try to send if there is no user yet. - if nbAccount > 0 { - defer bridge.goHeartbeat() - } - }, bridge.usersLock) +func (bridge *Bridge) GetHeartbeatPeriodicInterval() time.Duration { + return HeartbeatCheckInterval } diff --git a/internal/bridge/mocks.go b/internal/bridge/mocks.go index f0818752..d8c74098 100644 --- a/internal/bridge/mocks.go +++ b/internal/bridge/mocks.go @@ -7,6 +7,7 @@ import ( "os" "sync" "testing" + "time" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v3/internal/bridge/mocks" @@ -51,6 +52,7 @@ func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks { // this is called at start of heartbeat process. mocks.Heartbeat.EXPECT().IsTelemetryAvailable(gomock.Any()).AnyTimes() + mocks.Heartbeat.EXPECT().GetHeartbeatPeriodicInterval().AnyTimes().Return(500 * time.Millisecond) return mocks } diff --git a/internal/bridge/mocks/telemetry_mocks.go b/internal/bridge/mocks/telemetry_mocks.go index a1be4d5d..ff7aed05 100644 --- a/internal/bridge/mocks/telemetry_mocks.go +++ b/internal/bridge/mocks/telemetry_mocks.go @@ -36,6 +36,20 @@ func (m *MockHeartbeatManager) EXPECT() *MockHeartbeatManagerMockRecorder { return m.recorder } +// GetHeartbeatPeriodicInterval mocks base method. +func (m *MockHeartbeatManager) GetHeartbeatPeriodicInterval() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHeartbeatPeriodicInterval") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// GetHeartbeatPeriodicInterval indicates an expected call of GetHeartbeatPeriodicInterval. +func (mr *MockHeartbeatManagerMockRecorder) GetHeartbeatPeriodicInterval() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeartbeatPeriodicInterval", reflect.TypeOf((*MockHeartbeatManager)(nil).GetHeartbeatPeriodicInterval)) +} + // GetLastHeartbeatSent mocks base method. func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time { m.ctrl.T.Helper() diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 5471068f..b12fee1a 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -261,9 +261,12 @@ func (bridge *Bridge) SetTelemetryDisabled(isDisabled bool) error { return err } // If telemetry is re-enabled locally, try to send the heartbeat. - if !isDisabled { - defer bridge.goHeartbeat() + if isDisabled { + bridge.heartbeat.stop() + } else { + bridge.heartbeat.start() } + return nil } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index c9c1d719..9fa60c60 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -594,7 +594,7 @@ func (bridge *Bridge) addUserWithVault( }, bridge.usersLock) // As we need at least one user to send heartbeat, try to send it. - defer bridge.goHeartbeat() + bridge.heartbeat.start() return nil } diff --git a/internal/telemetry/mocks/mocks.go b/internal/telemetry/mocks/mocks.go index a1be4d5d..ff7aed05 100644 --- a/internal/telemetry/mocks/mocks.go +++ b/internal/telemetry/mocks/mocks.go @@ -36,6 +36,20 @@ func (m *MockHeartbeatManager) EXPECT() *MockHeartbeatManagerMockRecorder { return m.recorder } +// GetHeartbeatPeriodicInterval mocks base method. +func (m *MockHeartbeatManager) GetHeartbeatPeriodicInterval() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHeartbeatPeriodicInterval") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// GetHeartbeatPeriodicInterval indicates an expected call of GetHeartbeatPeriodicInterval. +func (mr *MockHeartbeatManagerMockRecorder) GetHeartbeatPeriodicInterval() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeartbeatPeriodicInterval", reflect.TypeOf((*MockHeartbeatManager)(nil).GetHeartbeatPeriodicInterval)) +} + // GetLastHeartbeatSent mocks base method. func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time { m.ctrl.T.Helper() diff --git a/internal/telemetry/types_heartbeat.go b/internal/telemetry/types_heartbeat.go index 0460e4b8..cad1e78c 100644 --- a/internal/telemetry/types_heartbeat.go +++ b/internal/telemetry/types_heartbeat.go @@ -42,6 +42,7 @@ type HeartbeatManager interface { SendHeartbeat(ctx context.Context, heartbeat *HeartbeatData) bool GetLastHeartbeatSent() time.Time SetLastHeartbeatSent(time.Time) error + GetHeartbeatPeriodicInterval() time.Duration } type HeartbeatValues struct { diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index e551510c..b344f446 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -166,6 +166,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) { t.mocks.CrashHandler, t.reporter, imap.DefaultEpochUIDValidityGenerator(), + t.heartbeat, // Logging stuff logIMAP, @@ -179,8 +180,6 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) { t.bridge = bridge t.heartbeat.setBridge(bridge) - bridge.StartHeartbeat(t.heartbeat) - return t.events.collectFrom(eventCh), nil } diff --git a/tests/ctx_heartbeat_test.go b/tests/ctx_heartbeat_test.go index 07179e00..d375616f 100644 --- a/tests/ctx_heartbeat_test.go +++ b/tests/ctx_heartbeat_test.go @@ -85,6 +85,10 @@ func (hb *heartbeatRecorder) SetLastHeartbeatSent(timestamp time.Time) error { return hb.bridge.SetLastHeartbeatSent(timestamp) } +func (hb *heartbeatRecorder) GetHeartbeatPeriodicInterval() time.Duration { + return 200 * time.Millisecond +} + func (hb *heartbeatRecorder) rejectSend() { hb.reject = true }