fix(GODT-3125): Heartbeat crash on exit

Ensure that the heartbeat background task is stopped before we close
the users as it accesses data within these instances.

Additionally, we also make sure that when telemetry is disabled, we stop
the background task.

Finally, `HeartbeatManager` now specifies what the desired interval is
so we can better configure the test cases.
This commit is contained in:
Leander Beernaert
2023-11-16 11:05:40 +01:00
parent ea1c2534df
commit 5a434fafbc
13 changed files with 146 additions and 60 deletions

View File

@ -282,9 +282,6 @@ func run(c *cli.Context) error {
// Remove old updates files // Remove old updates files
b.RemoveOldUpdates() b.RemoveOldUpdates()
// Start telemetry heartbeat process
b.StartHeartbeat(b)
// Run the frontend. // Run the frontend.
return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID)) return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID))
}) })

View File

@ -113,6 +113,7 @@ func withBridge(
crashHandler, crashHandler,
reporter, reporter,
imap.DefaultEpochUIDValidityGenerator(), imap.DefaultEpochUIDValidityGenerator(),
nil,
// The logging stuff. // The logging stuff.
c.String(flagLogIMAP) == "client" || c.String(flagLogIMAP) == "all", c.String(flagLogIMAP) == "client" || c.String(flagLogIMAP) == "all",

View File

@ -75,7 +75,7 @@ type Bridge struct {
installCh chan installJob installCh chan installJob
// heartbeat is the telemetry heartbeat for metrics. // heartbeat is the telemetry heartbeat for metrics.
heartbeat telemetry.Heartbeat heartbeat *heartBeatState
// curVersion is the current version of the bridge, // curVersion is the current version of the bridge,
// newVersion is the version that was installed by the updater. // newVersion is the version that was installed by the updater.
@ -128,9 +128,6 @@ type Bridge struct {
// goUpdate triggers a check/install of updates. // goUpdate triggers a check/install of updates.
goUpdate func() goUpdate func()
// goHeartbeat triggers a check/sending if heartbeat is needed.
goHeartbeat func()
serverManager *imapsmtpserver.Service serverManager *imapsmtpserver.Service
syncService *syncservice.Service syncService *syncservice.Service
} }
@ -153,6 +150,7 @@ func New(
panicHandler async.PanicHandler, panicHandler async.PanicHandler,
reporter reporter.Reporter, reporter reporter.Reporter,
uidValidityGenerator imap.UIDValidityGenerator, uidValidityGenerator imap.UIDValidityGenerator,
heartBeatManager telemetry.HeartbeatManager,
logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity
logSMTP bool, // whether to log SMTP activity logSMTP bool, // whether to log SMTP activity
@ -168,6 +166,7 @@ func New(
// bridge is the bridge. // bridge is the bridge.
bridge, err := newBridge( bridge, err := newBridge(
context.Background(),
tasks, tasks,
imapEventCh, imapEventCh,
@ -184,6 +183,7 @@ func New(
identifier, identifier,
proxyCtl, proxyCtl,
uidValidityGenerator, uidValidityGenerator,
heartBeatManager,
logIMAPClient, logIMAPServer, logSMTP, logIMAPClient, logIMAPServer, logSMTP,
) )
if err != nil { if err != nil {
@ -202,6 +202,7 @@ func New(
} }
func newBridge( func newBridge(
ctx context.Context,
tasks *async.Group, tasks *async.Group,
imapEventCh chan imapEvents.Event, imapEventCh chan imapEvents.Event,
@ -218,6 +219,7 @@ func newBridge(
identifier identifier.Identifier, identifier identifier.Identifier,
proxyCtl ProxyController, proxyCtl ProxyController,
uidValidityGenerator imap.UIDValidityGenerator, uidValidityGenerator imap.UIDValidityGenerator,
heartbeatManager telemetry.HeartbeatManager,
logIMAPClient, logIMAPServer, logSMTP bool, logIMAPClient, logIMAPServer, logSMTP bool,
) (*Bridge, error) { ) (*Bridge, error) {
@ -268,6 +270,8 @@ func newBridge(
panicHandler: panicHandler, panicHandler: panicHandler,
reporter: reporter, reporter: reporter,
heartbeat: newHeartBeatState(ctx, panicHandler),
focusService: focusService, focusService: focusService,
autostarter: autostarter, autostarter: autostarter,
locator: locator, locator: locator,
@ -297,6 +301,12 @@ func newBridge(
return nil, err return nil, err
} }
if heartbeatManager == nil {
bridge.heartbeat.init(bridge, bridge)
} else {
bridge.heartbeat.init(bridge, heartbeatManager)
}
bridge.syncService.Run(bridge.tasks) bridge.syncService.Run(bridge.tasks)
return bridge, nil return bridge, nil
@ -426,6 +436,9 @@ func (bridge *Bridge) GetErrors() []error {
func (bridge *Bridge) Close(ctx context.Context) { func (bridge *Bridge) Close(ctx context.Context) {
logrus.Info("Closing bridge") logrus.Info("Closing bridge")
// Stop heart beat before closing users.
bridge.heartbeat.stop()
// Close all users. // Close all users.
safe.Lock(func() { safe.Lock(func() {
for _, user := range bridge.users { for _, user := range bridge.users {

View File

@ -963,6 +963,7 @@ func withBridgeNoMocks(
mocks.CrashHandler, mocks.CrashHandler,
mocks.Reporter, mocks.Reporter,
testUIDValidityGenerator, testUIDValidityGenerator,
mocks.Heartbeat,
// The logging stuff. // The logging stuff.
os.Getenv("BRIDGE_LOG_IMAP_CLIENT") == "1", os.Getenv("BRIDGE_LOG_IMAP_CLIENT") == "1",
@ -972,9 +973,6 @@ func withBridgeNoMocks(
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, bridge.GetErrors()) require.Empty(t, bridge.GetErrors())
// Start the Heartbeat process.
bridge.StartHeartbeat(mocks.Heartbeat)
// Wait for bridge to finish loading users. // Wait for bridge to finish loading users.
waitForEvent(t, eventCh, events.AllUsersLoaded{}) waitForEvent(t, eventCh, events.AllUsersLoaded{})

View File

@ -20,8 +20,10 @@ package bridge
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"sync"
"time" "time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
@ -31,6 +33,87 @@ import (
const HeartbeatCheckInterval = time.Hour const HeartbeatCheckInterval = time.Hour
type heartBeatState struct {
task *async.Group
telemetry.Heartbeat
taskLock sync.Mutex
taskStarted bool
taskInterval time.Duration
}
func newHeartBeatState(ctx context.Context, panicHandler async.PanicHandler) *heartBeatState {
return &heartBeatState{
task: async.NewGroup(ctx, panicHandler),
}
}
func (h *heartBeatState) init(bridge *Bridge, manager telemetry.HeartbeatManager) {
h.Heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper())
h.taskInterval = manager.GetHeartbeatPeriodicInterval()
h.SetRollout(bridge.GetUpdateRollout())
h.SetAutoStart(bridge.GetAutostart())
h.SetAutoUpdate(bridge.GetAutoUpdate())
h.SetBeta(bridge.GetUpdateChannel())
h.SetDoh(bridge.GetProxyAllowed())
h.SetShowAllMail(bridge.GetShowAllMail())
h.SetIMAPConnectionMode(bridge.GetIMAPSSL())
h.SetSMTPConnectionMode(bridge.GetSMTPSSL())
h.SetIMAPPort(bridge.GetIMAPPort())
h.SetSMTPPort(bridge.GetSMTPPort())
h.SetCacheLocation(bridge.GetGluonCacheDir())
if val, err := bridge.GetKeychainApp(); err != nil {
h.SetKeyChainPref(val)
} else {
h.SetKeyChainPref(bridge.keychains.GetDefaultHelper())
}
h.SetPrevVersion(bridge.GetLastVersion().String())
safe.RLock(func() {
var splitMode = false
for _, user := range bridge.users {
if user.GetAddressMode() == vault.SplitMode {
splitMode = true
break
}
}
var nbAccount = len(bridge.users)
h.SetNbAccount(nbAccount)
h.SetSplitMode(splitMode)
// Do not try to send if there is no user yet.
if nbAccount > 0 {
defer h.start()
}
}, bridge.usersLock)
}
func (h *heartBeatState) start() {
h.taskLock.Lock()
defer h.taskLock.Unlock()
if h.taskStarted {
return
}
h.taskStarted = true
h.task.PeriodicOrTrigger(h.taskInterval, 0, func(ctx context.Context) {
logrus.Debug("Checking for heartbeat")
h.TrySending(ctx)
})
}
func (h *heartBeatState) stop() {
h.taskLock.Lock()
defer h.taskLock.Unlock()
if !h.taskStarted {
return
}
h.task.CancelAndWait()
h.taskStarted = false
}
func (bridge *Bridge) IsTelemetryAvailable(ctx context.Context) bool { func (bridge *Bridge) IsTelemetryAvailable(ctx context.Context) bool {
var flag = true var flag = true
if bridge.GetTelemetryDisabled() { if bridge.GetTelemetryDisabled() {
@ -79,49 +162,6 @@ func (bridge *Bridge) SetLastHeartbeatSent(timestamp time.Time) error {
return bridge.vault.SetLastHeartbeatSent(timestamp) return bridge.vault.SetLastHeartbeatSent(timestamp)
} }
func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) { func (bridge *Bridge) GetHeartbeatPeriodicInterval() time.Duration {
bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper()) return HeartbeatCheckInterval
// Check for heartbeat when triggered.
bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) {
logrus.Debug("Checking for heartbeat")
bridge.heartbeat.TrySending(ctx)
})
bridge.heartbeat.SetRollout(bridge.GetUpdateRollout())
bridge.heartbeat.SetAutoStart(bridge.GetAutostart())
bridge.heartbeat.SetAutoUpdate(bridge.GetAutoUpdate())
bridge.heartbeat.SetBeta(bridge.GetUpdateChannel())
bridge.heartbeat.SetDoh(bridge.GetProxyAllowed())
bridge.heartbeat.SetShowAllMail(bridge.GetShowAllMail())
bridge.heartbeat.SetIMAPConnectionMode(bridge.GetIMAPSSL())
bridge.heartbeat.SetSMTPConnectionMode(bridge.GetSMTPSSL())
bridge.heartbeat.SetIMAPPort(bridge.GetIMAPPort())
bridge.heartbeat.SetSMTPPort(bridge.GetSMTPPort())
bridge.heartbeat.SetCacheLocation(bridge.GetGluonCacheDir())
if val, err := bridge.GetKeychainApp(); err != nil {
bridge.heartbeat.SetKeyChainPref(val)
} else {
bridge.heartbeat.SetKeyChainPref(bridge.keychains.GetDefaultHelper())
}
bridge.heartbeat.SetPrevVersion(bridge.GetLastVersion().String())
safe.RLock(func() {
var splitMode = false
for _, user := range bridge.users {
if user.GetAddressMode() == vault.SplitMode {
splitMode = true
break
}
}
var nbAccount = len(bridge.users)
bridge.heartbeat.SetNbAccount(nbAccount)
bridge.heartbeat.SetSplitMode(splitMode)
// Do not try to send if there is no user yet.
if nbAccount > 0 {
defer bridge.goHeartbeat()
}
}, bridge.usersLock)
} }

View File

@ -7,6 +7,7 @@ import (
"os" "os"
"sync" "sync"
"testing" "testing"
"time"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge/mocks" "github.com/ProtonMail/proton-bridge/v3/internal/bridge/mocks"
@ -51,6 +52,7 @@ func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
// this is called at start of heartbeat process. // this is called at start of heartbeat process.
mocks.Heartbeat.EXPECT().IsTelemetryAvailable(gomock.Any()).AnyTimes() mocks.Heartbeat.EXPECT().IsTelemetryAvailable(gomock.Any()).AnyTimes()
mocks.Heartbeat.EXPECT().GetHeartbeatPeriodicInterval().AnyTimes().Return(500 * time.Millisecond)
return mocks return mocks
} }

View File

@ -36,6 +36,20 @@ func (m *MockHeartbeatManager) EXPECT() *MockHeartbeatManagerMockRecorder {
return m.recorder return m.recorder
} }
// GetHeartbeatPeriodicInterval mocks base method.
func (m *MockHeartbeatManager) GetHeartbeatPeriodicInterval() time.Duration {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHeartbeatPeriodicInterval")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// GetHeartbeatPeriodicInterval indicates an expected call of GetHeartbeatPeriodicInterval.
func (mr *MockHeartbeatManagerMockRecorder) GetHeartbeatPeriodicInterval() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeartbeatPeriodicInterval", reflect.TypeOf((*MockHeartbeatManager)(nil).GetHeartbeatPeriodicInterval))
}
// GetLastHeartbeatSent mocks base method. // GetLastHeartbeatSent mocks base method.
func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time { func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -261,9 +261,12 @@ func (bridge *Bridge) SetTelemetryDisabled(isDisabled bool) error {
return err return err
} }
// If telemetry is re-enabled locally, try to send the heartbeat. // If telemetry is re-enabled locally, try to send the heartbeat.
if !isDisabled { if isDisabled {
defer bridge.goHeartbeat() bridge.heartbeat.stop()
} else {
bridge.heartbeat.start()
} }
return nil return nil
} }

View File

@ -594,7 +594,7 @@ func (bridge *Bridge) addUserWithVault(
}, bridge.usersLock) }, bridge.usersLock)
// As we need at least one user to send heartbeat, try to send it. // As we need at least one user to send heartbeat, try to send it.
defer bridge.goHeartbeat() bridge.heartbeat.start()
return nil return nil
} }

View File

@ -36,6 +36,20 @@ func (m *MockHeartbeatManager) EXPECT() *MockHeartbeatManagerMockRecorder {
return m.recorder return m.recorder
} }
// GetHeartbeatPeriodicInterval mocks base method.
func (m *MockHeartbeatManager) GetHeartbeatPeriodicInterval() time.Duration {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHeartbeatPeriodicInterval")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// GetHeartbeatPeriodicInterval indicates an expected call of GetHeartbeatPeriodicInterval.
func (mr *MockHeartbeatManagerMockRecorder) GetHeartbeatPeriodicInterval() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeartbeatPeriodicInterval", reflect.TypeOf((*MockHeartbeatManager)(nil).GetHeartbeatPeriodicInterval))
}
// GetLastHeartbeatSent mocks base method. // GetLastHeartbeatSent mocks base method.
func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time { func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -42,6 +42,7 @@ type HeartbeatManager interface {
SendHeartbeat(ctx context.Context, heartbeat *HeartbeatData) bool SendHeartbeat(ctx context.Context, heartbeat *HeartbeatData) bool
GetLastHeartbeatSent() time.Time GetLastHeartbeatSent() time.Time
SetLastHeartbeatSent(time.Time) error SetLastHeartbeatSent(time.Time) error
GetHeartbeatPeriodicInterval() time.Duration
} }
type HeartbeatValues struct { type HeartbeatValues struct {

View File

@ -166,6 +166,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
t.mocks.CrashHandler, t.mocks.CrashHandler,
t.reporter, t.reporter,
imap.DefaultEpochUIDValidityGenerator(), imap.DefaultEpochUIDValidityGenerator(),
t.heartbeat,
// Logging stuff // Logging stuff
logIMAP, logIMAP,
@ -179,8 +180,6 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
t.bridge = bridge t.bridge = bridge
t.heartbeat.setBridge(bridge) t.heartbeat.setBridge(bridge)
bridge.StartHeartbeat(t.heartbeat)
return t.events.collectFrom(eventCh), nil return t.events.collectFrom(eventCh), nil
} }

View File

@ -85,6 +85,10 @@ func (hb *heartbeatRecorder) SetLastHeartbeatSent(timestamp time.Time) error {
return hb.bridge.SetLastHeartbeatSent(timestamp) return hb.bridge.SetLastHeartbeatSent(timestamp)
} }
func (hb *heartbeatRecorder) GetHeartbeatPeriodicInterval() time.Duration {
return 200 * time.Millisecond
}
func (hb *heartbeatRecorder) rejectSend() { func (hb *heartbeatRecorder) rejectSend() {
hb.reject = true hb.reject = true
} }