diff --git a/internal/app/app.go b/internal/app/app.go index f4e36035..b082d5b2 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -282,9 +282,6 @@ func run(c *cli.Context) error { // Remove old updates files b.RemoveOldUpdates() - // Start telemetry heartbeat process - b.StartHeartbeat(b) - // Run the frontend. return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID)) }) diff --git a/internal/app/bridge.go b/internal/app/bridge.go index 97af2937..421feba6 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -113,6 +113,7 @@ func withBridge( crashHandler, reporter, imap.DefaultEpochUIDValidityGenerator(), + nil, // The logging stuff. c.String(flagLogIMAP) == "client" || c.String(flagLogIMAP) == "all", diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 6a91697f..ef21ae1e 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -75,7 +75,7 @@ type Bridge struct { installCh chan installJob // heartbeat is the telemetry heartbeat for metrics. - heartbeat telemetry.Heartbeat + heartbeat *heartBeatState // curVersion is the current version of the bridge, // newVersion is the version that was installed by the updater. @@ -128,9 +128,6 @@ type Bridge struct { // goUpdate triggers a check/install of updates. goUpdate func() - // goHeartbeat triggers a check/sending if heartbeat is needed. - goHeartbeat func() - serverManager *imapsmtpserver.Service syncService *syncservice.Service } @@ -153,6 +150,7 @@ func New( panicHandler async.PanicHandler, reporter reporter.Reporter, uidValidityGenerator imap.UIDValidityGenerator, + heartBeatManager telemetry.HeartbeatManager, logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity logSMTP bool, // whether to log SMTP activity @@ -168,6 +166,7 @@ func New( // bridge is the bridge. bridge, err := newBridge( + context.Background(), tasks, imapEventCh, @@ -184,6 +183,7 @@ func New( identifier, proxyCtl, uidValidityGenerator, + heartBeatManager, logIMAPClient, logIMAPServer, logSMTP, ) if err != nil { @@ -202,6 +202,7 @@ func New( } func newBridge( + ctx context.Context, tasks *async.Group, imapEventCh chan imapEvents.Event, @@ -218,6 +219,7 @@ func newBridge( identifier identifier.Identifier, proxyCtl ProxyController, uidValidityGenerator imap.UIDValidityGenerator, + heartbeatManager telemetry.HeartbeatManager, logIMAPClient, logIMAPServer, logSMTP bool, ) (*Bridge, error) { @@ -268,6 +270,8 @@ func newBridge( panicHandler: panicHandler, reporter: reporter, + heartbeat: newHeartBeatState(ctx, panicHandler), + focusService: focusService, autostarter: autostarter, locator: locator, @@ -297,6 +301,12 @@ func newBridge( return nil, err } + if heartbeatManager == nil { + bridge.heartbeat.init(bridge, bridge) + } else { + bridge.heartbeat.init(bridge, heartbeatManager) + } + bridge.syncService.Run(bridge.tasks) return bridge, nil @@ -426,6 +436,9 @@ func (bridge *Bridge) GetErrors() []error { func (bridge *Bridge) Close(ctx context.Context) { logrus.Info("Closing bridge") + // Stop heart beat before closing users. + bridge.heartbeat.stop() + // Close all users. safe.Lock(func() { for _, user := range bridge.users { diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 9540e02c..bb20e464 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -963,6 +963,7 @@ func withBridgeNoMocks( mocks.CrashHandler, mocks.Reporter, testUIDValidityGenerator, + mocks.Heartbeat, // The logging stuff. os.Getenv("BRIDGE_LOG_IMAP_CLIENT") == "1", @@ -972,9 +973,6 @@ func withBridgeNoMocks( require.NoError(t, err) require.Empty(t, bridge.GetErrors()) - // Start the Heartbeat process. - bridge.StartHeartbeat(mocks.Heartbeat) - // Wait for bridge to finish loading users. waitForEvent(t, eventCh, events.AllUsersLoaded{}) diff --git a/internal/bridge/heartbeat.go b/internal/bridge/heartbeat.go index a851ecb5..241bfd88 100644 --- a/internal/bridge/heartbeat.go +++ b/internal/bridge/heartbeat.go @@ -20,8 +20,10 @@ package bridge import ( "context" "encoding/json" + "sync" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry" @@ -31,6 +33,87 @@ import ( const HeartbeatCheckInterval = time.Hour +type heartBeatState struct { + task *async.Group + telemetry.Heartbeat + taskLock sync.Mutex + taskStarted bool + taskInterval time.Duration +} + +func newHeartBeatState(ctx context.Context, panicHandler async.PanicHandler) *heartBeatState { + return &heartBeatState{ + task: async.NewGroup(ctx, panicHandler), + } +} + +func (h *heartBeatState) init(bridge *Bridge, manager telemetry.HeartbeatManager) { + h.Heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper()) + h.taskInterval = manager.GetHeartbeatPeriodicInterval() + h.SetRollout(bridge.GetUpdateRollout()) + h.SetAutoStart(bridge.GetAutostart()) + h.SetAutoUpdate(bridge.GetAutoUpdate()) + h.SetBeta(bridge.GetUpdateChannel()) + h.SetDoh(bridge.GetProxyAllowed()) + h.SetShowAllMail(bridge.GetShowAllMail()) + h.SetIMAPConnectionMode(bridge.GetIMAPSSL()) + h.SetSMTPConnectionMode(bridge.GetSMTPSSL()) + h.SetIMAPPort(bridge.GetIMAPPort()) + h.SetSMTPPort(bridge.GetSMTPPort()) + h.SetCacheLocation(bridge.GetGluonCacheDir()) + if val, err := bridge.GetKeychainApp(); err != nil { + h.SetKeyChainPref(val) + } else { + h.SetKeyChainPref(bridge.keychains.GetDefaultHelper()) + } + h.SetPrevVersion(bridge.GetLastVersion().String()) + + safe.RLock(func() { + var splitMode = false + for _, user := range bridge.users { + if user.GetAddressMode() == vault.SplitMode { + splitMode = true + break + } + } + var nbAccount = len(bridge.users) + h.SetNbAccount(nbAccount) + h.SetSplitMode(splitMode) + + // Do not try to send if there is no user yet. + if nbAccount > 0 { + defer h.start() + } + }, bridge.usersLock) +} + +func (h *heartBeatState) start() { + h.taskLock.Lock() + defer h.taskLock.Unlock() + if h.taskStarted { + return + } + + h.taskStarted = true + + h.task.PeriodicOrTrigger(h.taskInterval, 0, func(ctx context.Context) { + logrus.Debug("Checking for heartbeat") + + h.TrySending(ctx) + }) +} + +func (h *heartBeatState) stop() { + h.taskLock.Lock() + defer h.taskLock.Unlock() + if !h.taskStarted { + return + } + + h.task.CancelAndWait() + h.taskStarted = false +} + func (bridge *Bridge) IsTelemetryAvailable(ctx context.Context) bool { var flag = true if bridge.GetTelemetryDisabled() { @@ -79,49 +162,6 @@ func (bridge *Bridge) SetLastHeartbeatSent(timestamp time.Time) error { return bridge.vault.SetLastHeartbeatSent(timestamp) } -func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) { - bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper()) - - // Check for heartbeat when triggered. - bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) { - logrus.Debug("Checking for heartbeat") - - bridge.heartbeat.TrySending(ctx) - }) - - bridge.heartbeat.SetRollout(bridge.GetUpdateRollout()) - bridge.heartbeat.SetAutoStart(bridge.GetAutostart()) - bridge.heartbeat.SetAutoUpdate(bridge.GetAutoUpdate()) - bridge.heartbeat.SetBeta(bridge.GetUpdateChannel()) - bridge.heartbeat.SetDoh(bridge.GetProxyAllowed()) - bridge.heartbeat.SetShowAllMail(bridge.GetShowAllMail()) - bridge.heartbeat.SetIMAPConnectionMode(bridge.GetIMAPSSL()) - bridge.heartbeat.SetSMTPConnectionMode(bridge.GetSMTPSSL()) - bridge.heartbeat.SetIMAPPort(bridge.GetIMAPPort()) - bridge.heartbeat.SetSMTPPort(bridge.GetSMTPPort()) - bridge.heartbeat.SetCacheLocation(bridge.GetGluonCacheDir()) - if val, err := bridge.GetKeychainApp(); err != nil { - bridge.heartbeat.SetKeyChainPref(val) - } else { - bridge.heartbeat.SetKeyChainPref(bridge.keychains.GetDefaultHelper()) - } - bridge.heartbeat.SetPrevVersion(bridge.GetLastVersion().String()) - - safe.RLock(func() { - var splitMode = false - for _, user := range bridge.users { - if user.GetAddressMode() == vault.SplitMode { - splitMode = true - break - } - } - var nbAccount = len(bridge.users) - bridge.heartbeat.SetNbAccount(nbAccount) - bridge.heartbeat.SetSplitMode(splitMode) - - // Do not try to send if there is no user yet. - if nbAccount > 0 { - defer bridge.goHeartbeat() - } - }, bridge.usersLock) +func (bridge *Bridge) GetHeartbeatPeriodicInterval() time.Duration { + return HeartbeatCheckInterval } diff --git a/internal/bridge/mocks.go b/internal/bridge/mocks.go index f0818752..d8c74098 100644 --- a/internal/bridge/mocks.go +++ b/internal/bridge/mocks.go @@ -7,6 +7,7 @@ import ( "os" "sync" "testing" + "time" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v3/internal/bridge/mocks" @@ -51,6 +52,7 @@ func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks { // this is called at start of heartbeat process. mocks.Heartbeat.EXPECT().IsTelemetryAvailable(gomock.Any()).AnyTimes() + mocks.Heartbeat.EXPECT().GetHeartbeatPeriodicInterval().AnyTimes().Return(500 * time.Millisecond) return mocks } diff --git a/internal/bridge/mocks/telemetry_mocks.go b/internal/bridge/mocks/telemetry_mocks.go index a1be4d5d..ff7aed05 100644 --- a/internal/bridge/mocks/telemetry_mocks.go +++ b/internal/bridge/mocks/telemetry_mocks.go @@ -36,6 +36,20 @@ func (m *MockHeartbeatManager) EXPECT() *MockHeartbeatManagerMockRecorder { return m.recorder } +// GetHeartbeatPeriodicInterval mocks base method. +func (m *MockHeartbeatManager) GetHeartbeatPeriodicInterval() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHeartbeatPeriodicInterval") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// GetHeartbeatPeriodicInterval indicates an expected call of GetHeartbeatPeriodicInterval. +func (mr *MockHeartbeatManagerMockRecorder) GetHeartbeatPeriodicInterval() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeartbeatPeriodicInterval", reflect.TypeOf((*MockHeartbeatManager)(nil).GetHeartbeatPeriodicInterval)) +} + // GetLastHeartbeatSent mocks base method. func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time { m.ctrl.T.Helper() diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 5471068f..b12fee1a 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -261,9 +261,12 @@ func (bridge *Bridge) SetTelemetryDisabled(isDisabled bool) error { return err } // If telemetry is re-enabled locally, try to send the heartbeat. - if !isDisabled { - defer bridge.goHeartbeat() + if isDisabled { + bridge.heartbeat.stop() + } else { + bridge.heartbeat.start() } + return nil } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index c9c1d719..9fa60c60 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -594,7 +594,7 @@ func (bridge *Bridge) addUserWithVault( }, bridge.usersLock) // As we need at least one user to send heartbeat, try to send it. - defer bridge.goHeartbeat() + bridge.heartbeat.start() return nil } diff --git a/internal/telemetry/mocks/mocks.go b/internal/telemetry/mocks/mocks.go index a1be4d5d..ff7aed05 100644 --- a/internal/telemetry/mocks/mocks.go +++ b/internal/telemetry/mocks/mocks.go @@ -36,6 +36,20 @@ func (m *MockHeartbeatManager) EXPECT() *MockHeartbeatManagerMockRecorder { return m.recorder } +// GetHeartbeatPeriodicInterval mocks base method. +func (m *MockHeartbeatManager) GetHeartbeatPeriodicInterval() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHeartbeatPeriodicInterval") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// GetHeartbeatPeriodicInterval indicates an expected call of GetHeartbeatPeriodicInterval. +func (mr *MockHeartbeatManagerMockRecorder) GetHeartbeatPeriodicInterval() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeartbeatPeriodicInterval", reflect.TypeOf((*MockHeartbeatManager)(nil).GetHeartbeatPeriodicInterval)) +} + // GetLastHeartbeatSent mocks base method. func (m *MockHeartbeatManager) GetLastHeartbeatSent() time.Time { m.ctrl.T.Helper() diff --git a/internal/telemetry/types_heartbeat.go b/internal/telemetry/types_heartbeat.go index 0460e4b8..cad1e78c 100644 --- a/internal/telemetry/types_heartbeat.go +++ b/internal/telemetry/types_heartbeat.go @@ -42,6 +42,7 @@ type HeartbeatManager interface { SendHeartbeat(ctx context.Context, heartbeat *HeartbeatData) bool GetLastHeartbeatSent() time.Time SetLastHeartbeatSent(time.Time) error + GetHeartbeatPeriodicInterval() time.Duration } type HeartbeatValues struct { diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index e551510c..b344f446 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -166,6 +166,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) { t.mocks.CrashHandler, t.reporter, imap.DefaultEpochUIDValidityGenerator(), + t.heartbeat, // Logging stuff logIMAP, @@ -179,8 +180,6 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) { t.bridge = bridge t.heartbeat.setBridge(bridge) - bridge.StartHeartbeat(t.heartbeat) - return t.events.collectFrom(eventCh), nil } diff --git a/tests/ctx_heartbeat_test.go b/tests/ctx_heartbeat_test.go index 07179e00..d375616f 100644 --- a/tests/ctx_heartbeat_test.go +++ b/tests/ctx_heartbeat_test.go @@ -85,6 +85,10 @@ func (hb *heartbeatRecorder) SetLastHeartbeatSent(timestamp time.Time) error { return hb.bridge.SetLastHeartbeatSent(timestamp) } +func (hb *heartbeatRecorder) GetHeartbeatPeriodicInterval() time.Duration { + return 200 * time.Millisecond +} + func (hb *heartbeatRecorder) rejectSend() { hb.reject = true }