From c4f80103b6aa15b97bb058ed44c0a0ac6269b716 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Thu, 6 Jul 2023 12:00:19 +0200 Subject: [PATCH] fix(GODT-2774): Add external context to telemetry tasks This ensures they get cancelled if the parent context becomes invalid --- internal/bridge/heartbeat.go | 10 ++++----- internal/bridge/mocks.go | 2 +- internal/bridge/mocks/telemetry_mocks.go | 17 ++++++++-------- internal/bridge/settings.go | 2 +- internal/bridge/smtp_backend.go | 3 ++- internal/bridge/user.go | 4 ++-- internal/telemetry/heartbeat.go | 7 ++++--- internal/telemetry/heartbeat_test.go | 11 +++++----- internal/telemetry/mocks/mocks.go | 17 ++++++++-------- internal/telemetry/types_heartbeat.go | 5 +++-- internal/user/config_status.go | 26 ++++++++++++------------ internal/user/imap.go | 22 +++++++++++--------- internal/user/user.go | 2 +- internal/user/user_test.go | 2 +- tests/bridge_test.go | 2 +- tests/ctx_heartbeat_test.go | 7 ++++--- 16 files changed, 74 insertions(+), 65 deletions(-) diff --git a/internal/bridge/heartbeat.go b/internal/bridge/heartbeat.go index a24fad96..9b382c4c 100644 --- a/internal/bridge/heartbeat.go +++ b/internal/bridge/heartbeat.go @@ -32,7 +32,7 @@ import ( const HeartbeatCheckInterval = time.Hour -func (bridge *Bridge) IsTelemetryAvailable() bool { +func (bridge *Bridge) IsTelemetryAvailable(ctx context.Context) bool { var flag = true if bridge.GetTelemetryDisabled() { return false @@ -40,14 +40,14 @@ func (bridge *Bridge) IsTelemetryAvailable() bool { safe.RLock(func() { for _, user := range bridge.users { - flag = flag && user.IsTelemetryEnabled(context.Background()) + flag = flag && user.IsTelemetryEnabled(ctx) } }, bridge.usersLock) return flag } -func (bridge *Bridge) SendHeartbeat(heartbeat *telemetry.HeartbeatData) bool { +func (bridge *Bridge) SendHeartbeat(ctx context.Context, heartbeat *telemetry.HeartbeatData) bool { data, err := json.Marshal(heartbeat) if err != nil { if err := bridge.reporter.ReportMessageWithContext("Cannot parse heartbeat data.", reporter.Context{ @@ -62,7 +62,7 @@ func (bridge *Bridge) SendHeartbeat(heartbeat *telemetry.HeartbeatData) bool { safe.RLock(func() { for _, user := range bridge.users { - if err := user.SendTelemetry(context.Background(), data); err == nil { + if err := user.SendTelemetry(ctx, data); err == nil { sent = true break } @@ -87,7 +87,7 @@ func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) { bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) { logrus.Debug("Checking for heartbeat") - bridge.heartbeat.TrySending() + bridge.heartbeat.TrySending(ctx) }) bridge.heartbeat.SetRollout(bridge.GetUpdateRollout()) diff --git a/internal/bridge/mocks.go b/internal/bridge/mocks.go index 63b16b8f..4412e1f2 100644 --- a/internal/bridge/mocks.go +++ b/internal/bridge/mocks.go @@ -50,7 +50,7 @@ func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks { mocks.CrashHandler.EXPECT().HandlePanic(gomock.Any()).AnyTimes() // this is called at start of heartbeat process. - mocks.Heartbeat.EXPECT().IsTelemetryAvailable().AnyTimes() + mocks.Heartbeat.EXPECT().IsTelemetryAvailable(gomock.Any()).AnyTimes() return mocks } diff --git a/internal/bridge/mocks/telemetry_mocks.go b/internal/bridge/mocks/telemetry_mocks.go index be5251e2..a1be4d5d 100644 --- a/internal/bridge/mocks/telemetry_mocks.go +++ b/internal/bridge/mocks/telemetry_mocks.go @@ -5,6 +5,7 @@ package mocks import ( + context "context" reflect "reflect" time "time" @@ -50,31 +51,31 @@ func (mr *MockHeartbeatManagerMockRecorder) GetLastHeartbeatSent() *gomock.Call } // IsTelemetryAvailable mocks base method. -func (m *MockHeartbeatManager) IsTelemetryAvailable() bool { +func (m *MockHeartbeatManager) IsTelemetryAvailable(arg0 context.Context) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsTelemetryAvailable") + ret := m.ctrl.Call(m, "IsTelemetryAvailable", arg0) ret0, _ := ret[0].(bool) return ret0 } // IsTelemetryAvailable indicates an expected call of IsTelemetryAvailable. -func (mr *MockHeartbeatManagerMockRecorder) IsTelemetryAvailable() *gomock.Call { +func (mr *MockHeartbeatManagerMockRecorder) IsTelemetryAvailable(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsTelemetryAvailable", reflect.TypeOf((*MockHeartbeatManager)(nil).IsTelemetryAvailable)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsTelemetryAvailable", reflect.TypeOf((*MockHeartbeatManager)(nil).IsTelemetryAvailable), arg0) } // SendHeartbeat mocks base method. -func (m *MockHeartbeatManager) SendHeartbeat(arg0 *telemetry.HeartbeatData) bool { +func (m *MockHeartbeatManager) SendHeartbeat(arg0 context.Context, arg1 *telemetry.HeartbeatData) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendHeartbeat", arg0) + ret := m.ctrl.Call(m, "SendHeartbeat", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } // SendHeartbeat indicates an expected call of SendHeartbeat. -func (mr *MockHeartbeatManagerMockRecorder) SendHeartbeat(arg0 interface{}) *gomock.Call { +func (mr *MockHeartbeatManagerMockRecorder) SendHeartbeat(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeartbeat", reflect.TypeOf((*MockHeartbeatManager)(nil).SendHeartbeat), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeartbeat", reflect.TypeOf((*MockHeartbeatManager)(nil).SendHeartbeat), arg0, arg1) } // SetLastHeartbeatSent mocks base method. diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 752779a5..04aca77e 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -297,7 +297,7 @@ func (bridge *Bridge) SetColorScheme(colorScheme string) error { // Note: it does not clear the keychain. The only entry in the keychain is the vault password, // which we need at next startup to decrypt the vault. func (bridge *Bridge) FactoryReset(ctx context.Context) { - useTelemetry := bridge.IsTelemetryAvailable() + useTelemetry := bridge.IsTelemetryAvailable(ctx) // Delete all the users. safe.Lock(func() { diff --git a/internal/bridge/smtp_backend.go b/internal/bridge/smtp_backend.go index d4bacf18..ec93f524 100644 --- a/internal/bridge/smtp_backend.go +++ b/internal/bridge/smtp_backend.go @@ -18,6 +18,7 @@ package bridge import ( + "context" "fmt" "io" "strings" @@ -61,7 +62,7 @@ func (s *smtpSession) AuthPlain(username, password string) error { s.Bridge.setUserAgent(useragent.UnknownClient, useragent.DefaultVersion) } - user.SendConfigStatusSuccess() + user.SendConfigStatusSuccess(context.Background()) return nil } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 3d95c9a8..221bab28 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -243,7 +243,7 @@ func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error { func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { logrus.WithField("userID", userID).Info("Deleting user") - useTelemetry := bridge.IsTelemetryAvailable() + useTelemetry := bridge.IsTelemetryAvailable(ctx) return safe.LockRet(func() error { if !bridge.vault.HasUser(userID) { @@ -602,7 +602,7 @@ func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI, // if this is actually a remove account if withData && withAPI { - user.SendConfigStatusAbort(withTelemetry) + user.SendConfigStatusAbort(ctx, withTelemetry) } logrus.WithFields(logrus.Fields{ diff --git a/internal/telemetry/heartbeat.go b/internal/telemetry/heartbeat.go index 00036e05..e7d52acb 100644 --- a/internal/telemetry/heartbeat.go +++ b/internal/telemetry/heartbeat.go @@ -18,6 +18,7 @@ package telemetry import ( + "context" "strconv" "time" @@ -149,12 +150,12 @@ func (heartbeat *Heartbeat) SetPrevVersion(val string) { heartbeat.metrics.Dimensions.PrevVersion = val } -func (heartbeat *Heartbeat) TrySending() { - if heartbeat.manager.IsTelemetryAvailable() { +func (heartbeat *Heartbeat) TrySending(ctx context.Context) { + if heartbeat.manager.IsTelemetryAvailable(ctx) { lastSent := heartbeat.manager.GetLastHeartbeatSent() now := time.Now() if now.Year() > lastSent.Year() || (now.Year() == lastSent.Year() && now.YearDay() > lastSent.YearDay()) { - if !heartbeat.manager.SendHeartbeat(&heartbeat.metrics) { + if !heartbeat.manager.SendHeartbeat(ctx, &heartbeat.metrics) { heartbeat.log.WithFields(logrus.Fields{ "metrics": heartbeat.metrics, }).Error("Failed to send heartbeat") diff --git a/internal/telemetry/heartbeat_test.go b/internal/telemetry/heartbeat_test.go index 046b0d18..43f2cb33 100644 --- a/internal/telemetry/heartbeat_test.go +++ b/internal/telemetry/heartbeat_test.go @@ -18,6 +18,7 @@ package telemetry_test import ( + "context" "testing" "time" @@ -52,21 +53,21 @@ func TestHeartbeat_default_heartbeat(t *testing.T) { }, } - mock.EXPECT().IsTelemetryAvailable().Return(true) + mock.EXPECT().IsTelemetryAvailable(context.Background()).Return(true) mock.EXPECT().GetLastHeartbeatSent().Return(time.Date(2022, 6, 4, 0, 0, 0, 0, time.UTC)) - mock.EXPECT().SendHeartbeat(&data).Return(true) + mock.EXPECT().SendHeartbeat(context.Background(), &data).Return(true) mock.EXPECT().SetLastHeartbeatSent(gomock.Any()).Return(nil) - hb.TrySending() + hb.TrySending(context.Background()) }) } func TestHeartbeat_already_sent_heartbeat(t *testing.T) { withHeartbeat(t, 1143, 1025, "/tmp", "defaultKeychain", func(hb *telemetry.Heartbeat, mock *mocks.MockHeartbeatManager) { - mock.EXPECT().IsTelemetryAvailable().Return(true) + mock.EXPECT().IsTelemetryAvailable(context.Background()).Return(true) mock.EXPECT().GetLastHeartbeatSent().Return(time.Now().Truncate(24 * time.Hour)) - hb.TrySending() + hb.TrySending(context.Background()) }) } diff --git a/internal/telemetry/mocks/mocks.go b/internal/telemetry/mocks/mocks.go index be5251e2..a1be4d5d 100644 --- a/internal/telemetry/mocks/mocks.go +++ b/internal/telemetry/mocks/mocks.go @@ -5,6 +5,7 @@ package mocks import ( + context "context" reflect "reflect" time "time" @@ -50,31 +51,31 @@ func (mr *MockHeartbeatManagerMockRecorder) GetLastHeartbeatSent() *gomock.Call } // IsTelemetryAvailable mocks base method. -func (m *MockHeartbeatManager) IsTelemetryAvailable() bool { +func (m *MockHeartbeatManager) IsTelemetryAvailable(arg0 context.Context) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsTelemetryAvailable") + ret := m.ctrl.Call(m, "IsTelemetryAvailable", arg0) ret0, _ := ret[0].(bool) return ret0 } // IsTelemetryAvailable indicates an expected call of IsTelemetryAvailable. -func (mr *MockHeartbeatManagerMockRecorder) IsTelemetryAvailable() *gomock.Call { +func (mr *MockHeartbeatManagerMockRecorder) IsTelemetryAvailable(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsTelemetryAvailable", reflect.TypeOf((*MockHeartbeatManager)(nil).IsTelemetryAvailable)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsTelemetryAvailable", reflect.TypeOf((*MockHeartbeatManager)(nil).IsTelemetryAvailable), arg0) } // SendHeartbeat mocks base method. -func (m *MockHeartbeatManager) SendHeartbeat(arg0 *telemetry.HeartbeatData) bool { +func (m *MockHeartbeatManager) SendHeartbeat(arg0 context.Context, arg1 *telemetry.HeartbeatData) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendHeartbeat", arg0) + ret := m.ctrl.Call(m, "SendHeartbeat", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } // SendHeartbeat indicates an expected call of SendHeartbeat. -func (mr *MockHeartbeatManagerMockRecorder) SendHeartbeat(arg0 interface{}) *gomock.Call { +func (mr *MockHeartbeatManagerMockRecorder) SendHeartbeat(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeartbeat", reflect.TypeOf((*MockHeartbeatManager)(nil).SendHeartbeat), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeartbeat", reflect.TypeOf((*MockHeartbeatManager)(nil).SendHeartbeat), arg0, arg1) } // SetLastHeartbeatSent mocks base method. diff --git a/internal/telemetry/types_heartbeat.go b/internal/telemetry/types_heartbeat.go index 35abf3a7..0460e4b8 100644 --- a/internal/telemetry/types_heartbeat.go +++ b/internal/telemetry/types_heartbeat.go @@ -18,6 +18,7 @@ package telemetry import ( + "context" "time" "github.com/sirupsen/logrus" @@ -33,12 +34,12 @@ const ( ) type Availability interface { - IsTelemetryAvailable() bool + IsTelemetryAvailable(ctx context.Context) bool } type HeartbeatManager interface { Availability - SendHeartbeat(heartbeat *HeartbeatData) bool + SendHeartbeat(ctx context.Context, heartbeat *HeartbeatData) bool GetLastHeartbeatSent() time.Time SetLastHeartbeatSent(time.Time) error } diff --git a/internal/user/config_status.go b/internal/user/config_status.go index 8a93c8f3..cae0dee2 100644 --- a/internal/user/config_status.go +++ b/internal/user/config_status.go @@ -25,12 +25,12 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/configstatus" ) -func (user *User) SendConfigStatusSuccess() { +func (user *User) SendConfigStatusSuccess(ctx context.Context) { if user.configStatus.IsFromFailure() { - user.SendConfigStatusRecovery() + user.SendConfigStatusRecovery(ctx) return } - if !user.telemetryManager.IsTelemetryAvailable() { + if !user.telemetryManager.IsTelemetryAvailable(ctx) { return } if !user.configStatus.IsPending() { @@ -49,7 +49,7 @@ func (user *User) SendConfigStatusSuccess() { return } - if err := user.SendTelemetry(context.Background(), data); err == nil { + if err := user.SendTelemetry(ctx, data); err == nil { user.log.Info("Configuration Status Success event sent.") if err := user.configStatus.ApplySuccess(); err != nil { user.log.WithError(err).Error("Failed to ApplySuccess on config_status.") @@ -57,7 +57,7 @@ func (user *User) SendConfigStatusSuccess() { } } -func (user *User) SendConfigStatusAbort(withTelemetry bool) { +func (user *User) SendConfigStatusAbort(ctx context.Context, withTelemetry bool) { if err := user.configStatus.Remove(); err != nil { user.log.WithError(err).Error("Failed to remove config_status file.") } @@ -80,17 +80,17 @@ func (user *User) SendConfigStatusAbort(withTelemetry bool) { return } - if err := user.SendTelemetry(context.Background(), data); err == nil { + if err := user.SendTelemetry(ctx, data); err == nil { user.log.Info("Configuration Status Abort event sent.") } } -func (user *User) SendConfigStatusRecovery() { +func (user *User) SendConfigStatusRecovery(ctx context.Context) { if !user.configStatus.IsFromFailure() { - user.SendConfigStatusSuccess() + user.SendConfigStatusSuccess(ctx) return } - if !user.telemetryManager.IsTelemetryAvailable() { + if !user.telemetryManager.IsTelemetryAvailable(ctx) { return } if !user.configStatus.IsPending() { @@ -109,7 +109,7 @@ func (user *User) SendConfigStatusRecovery() { return } - if err := user.SendTelemetry(context.Background(), data); err == nil { + if err := user.SendTelemetry(ctx, data); err == nil { user.log.Info("Configuration Status Recovery event sent.") if err := user.configStatus.ApplySuccess(); err != nil { user.log.WithError(err).Error("Failed to ApplySuccess on config_status.") @@ -117,8 +117,8 @@ func (user *User) SendConfigStatusRecovery() { } } -func (user *User) SendConfigStatusProgress() { - if !user.telemetryManager.IsTelemetryAvailable() { +func (user *User) SendConfigStatusProgress(ctx context.Context) { + if !user.telemetryManager.IsTelemetryAvailable(ctx) { return } if !user.configStatus.IsPending() { @@ -143,7 +143,7 @@ func (user *User) SendConfigStatusProgress() { return } - if err := user.SendTelemetry(context.Background(), data); err == nil { + if err := user.SendTelemetry(ctx, data); err == nil { user.log.Info("Configuration Status Progress event sent.") if err := user.configStatus.ApplyProgress(); err != nil { user.log.WithError(err).Error("Failed to ApplyProgress on config_status.") diff --git a/internal/user/imap.go b/internal/user/imap.go index d2bd53e0..87afa3f5 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -76,7 +76,7 @@ func newIMAPConnector(user *User, addrID string) *imapConnector { } // Authorize returns whether the given username/password combination are valid for this connector. -func (conn *imapConnector) Authorize(_ context.Context, username string, password []byte) bool { +func (conn *imapConnector) Authorize(ctx context.Context, username string, password []byte) bool { addrID, err := conn.CheckAuth(username, password) if err != nil { return false @@ -86,7 +86,7 @@ func (conn *imapConnector) Authorize(_ context.Context, username string, passwor return false } - conn.User.SendConfigStatusSuccess() + conn.User.SendConfigStatusSuccess(ctx) return true } @@ -355,15 +355,17 @@ func (conn *imapConnector) CreateMessage( } msg, literal, err := conn.importMessage(ctx, literal, wantLabelIDs, wantFlags, unread) - if err != nil && errors.Is(err, proton.ErrImportSizeExceeded) { - // Remap error so that Gluon does not put this message in the recovery mailbox. - err = fmt.Errorf("%v: %w", err, connector.ErrMessageSizeExceedsLimits) - } + if err != nil { + if errors.Is(err, proton.ErrImportSizeExceeded) { + // Remap error so that Gluon does not put this message in the recovery mailbox. + err = fmt.Errorf("%v: %w", err, connector.ErrMessageSizeExceedsLimits) + } - if apiErr := new(proton.APIError); errors.As(err, &apiErr) { - logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("Failed to import message") - } else { - logrus.WithError(err).Error("Failed to import message") + if apiErr := new(proton.APIError); errors.As(err, &apiErr) { + logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("Failed to import message") + } else { + logrus.WithError(err).Error("Failed to import message") + } } return msg, literal, err diff --git a/internal/user/user.go b/internal/user/user.go index cc27b519..49d29ab7 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -180,7 +180,7 @@ func New( // Check for status_progress when triggered. user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) { - user.SendConfigStatusProgress() + user.SendConfigStatusProgress(ctx) }) defer user.goStatusProgress() diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 5b0ada78..a2f34d18 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -148,7 +148,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma ctl := gomock.NewController(tb) defer ctl.Finish() manager := mocks.NewMockHeartbeatManager(ctl) - manager.EXPECT().IsTelemetryAvailable().AnyTimes() + manager.EXPECT().IsTelemetryAvailable(context.Background()).AnyTimes() user, err := New(ctx, vaultUser, client, nil, apiUser, nil, true, vault.DefaultMaxSyncMemory, tb.TempDir(), manager) require.NoError(tb, err) defer user.Close() diff --git a/tests/bridge_test.go b/tests/bridge_test.go index fd5f2324..0989466c 100644 --- a/tests/bridge_test.go +++ b/tests/bridge_test.go @@ -314,7 +314,7 @@ func (s *scenario) bridgeTelemetryFeatureDisabled() error { } func (s *scenario) checkTelemetry(expect bool) error { - res := s.t.bridge.IsTelemetryAvailable() + res := s.t.bridge.IsTelemetryAvailable(context.Background()) if res != expect { return fmt.Errorf("expected telemetry feature %v but got %v ", expect, res) } diff --git a/tests/ctx_heartbeat_test.go b/tests/ctx_heartbeat_test.go index babf4f9e..07179e00 100644 --- a/tests/ctx_heartbeat_test.go +++ b/tests/ctx_heartbeat_test.go @@ -18,6 +18,7 @@ package tests import ( + "context" "errors" "testing" "time" @@ -54,14 +55,14 @@ func (hb *heartbeatRecorder) GetLastHeartbeatSent() time.Time { return hb.bridge.GetLastHeartbeatSent() } -func (hb *heartbeatRecorder) IsTelemetryAvailable() bool { +func (hb *heartbeatRecorder) IsTelemetryAvailable(ctx context.Context) bool { if hb.bridge == nil { return false } - return hb.bridge.IsTelemetryAvailable() + return hb.bridge.IsTelemetryAvailable(ctx) } -func (hb *heartbeatRecorder) SendHeartbeat(metrics *telemetry.HeartbeatData) bool { +func (hb *heartbeatRecorder) SendHeartbeat(_ context.Context, metrics *telemetry.HeartbeatData) bool { if hb.bridge == nil { return false }