diff --git a/Changelog.md b/Changelog.md index 5b58614c..7eea40de 100644 --- a/Changelog.md +++ b/Changelog.md @@ -16,6 +16,11 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/) * Adding DSN Sentry as build time parameter * GODT-124 bump go-appdir from v1.0.0 to v1.1.0 * CSB-72 Skip processing message update event if http statuscode is 422 +* Skip processing message update event if http statuscode is 422 +* GODT-204 `pmapi.TokenManager` replaced by `pmapi.ClientManager` + * `expiresAt` is no longer part of client + * TODO Please fill here all logic changes + ### Fixed * Use correct binary name when finding location of addcert.scpt diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index a81a5313..edf8eb34 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -181,18 +181,20 @@ func (b *Bridge) watchBridgeOutdated() { // watchAPIAuths receives auths from the client manager and sends them to the appropriate user. func (b *Bridge) watchAPIAuths() { for auth := range b.clientManager.GetAuthUpdateChannel() { - logrus.Debug("Bridge received auth from ClientManager") + log.Debug("Bridge received auth from ClientManager") user, ok := b.hasUser(auth.UserID) if !ok { - logrus.WithField("userID", auth.UserID).Info("User not available for auth update") + log.WithField("userID", auth.UserID).Info("User not available for auth update") continue } if auth.Auth != nil { user.updateAuthToken(auth.Auth) - } else { - user.logout() + } else if err := user.logout(); err != nil { + log.WithError(err). + WithField("userID", auth.UserID). + Error("User logout failed while watching API auths") } } } @@ -241,6 +243,14 @@ func (b *Bridge) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPasswo if err == pmapi.ErrUpgradeApplication { b.events.Emit(events.UpgradeApplicationEvent, "") } + if err != nil { + log.WithError(err).Debug("Login not finished; removing auth session") + if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil { + log.WithError(delAuthErr).Error("Failed to clear login session after unlock") + } + } + // The anonymous client will be removed from list and authentication will not be deleted. + authClient.Logout() }() apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword) @@ -249,7 +259,8 @@ func (b *Bridge) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPasswo return } - if user, err = b.GetUser(apiUser.ID); err == nil { + var ok bool + if user, ok = b.hasUser(apiUser.ID); ok { if err = b.connectExistingUser(user, auth, hashedPassword); err != nil { log.WithError(err).Error("Failed to connect existing user") return @@ -305,7 +316,7 @@ func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword s return errors.Wrap(err, "failed to refresh token in new client") } - if user, err = client.UpdateUser(); err != nil { + if user, err = client.CurrentUser(); err != nil { return errors.Wrap(err, "failed to update API user") } @@ -330,7 +341,7 @@ func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword s b.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)) - return + return err } func getAPIUser(client pmapi.Client, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) { @@ -340,7 +351,13 @@ func getAPIUser(client pmapi.Client, auth *pmapi.Auth, mbPassword string) (user return } - if user, err = client.UpdateUser(); err != nil { + // We unlock the user's PGP key here to detect if the user's mailbox password is wrong. + if _, err = client.Unlock(hashedPassword); err != nil { + log.WithError(err).Error("Wrong mailbox password") + return + } + + if user, err = client.CurrentUser(); err != nil { log.WithError(err).Error("Could not load API user") return } diff --git a/internal/bridge/bridge_login_test.go b/internal/bridge/bridge_login_test.go index fc9d376b..bd2f0663 100644 --- a/internal/bridge/bridge_login_test.go +++ b/internal/bridge/bridge_login_test.go @@ -29,17 +29,20 @@ import ( "github.com/stretchr/testify/assert" ) -func TestBridgeFinishLoginBadPassword(t *testing.T) { +func TestBridgeFinishLoginBadMailboxPassword(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - // Init bridge with no user from keychain. - m.credentialsStore.EXPECT().List().Return([]string{}, nil) - - // Set up mocks for FinishLogin. err := errors.New("bad password") - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err) - m.pmapiClient.EXPECT().Logout() + gomock.InOrder( + // Init bridge with no user from keychain. + m.credentialsStore.EXPECT().List().Return([]string{}, nil), + + // Set up mocks for FinishLogin. + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err), + m.pmapiClient.EXPECT().DeleteAuth(), + m.pmapiClient.EXPECT().Logout(), + ) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err) } @@ -48,15 +51,18 @@ func TestBridgeFinishLoginUpgradeApplication(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - // Init bridge with no user from keychain. - m.credentialsStore.EXPECT().List().Return([]string{}, nil) - - // Set up mocks for FinishLogin. - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, pmapi.ErrUpgradeApplication) - - m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") err := errors.New("Cannot logout when upgrade needed") - m.pmapiClient.EXPECT().Logout().Return(err) + gomock.InOrder( + // Init bridge with no user from keychain. + m.credentialsStore.EXPECT().List().Return([]string{}, nil), + + // Set up mocks for FinishLogin. + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, pmapi.ErrUpgradeApplication), + + m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""), + m.pmapiClient.EXPECT().DeleteAuth().Return(err), + m.pmapiClient.EXPECT().Logout(), + ) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", pmapi.ErrUpgradeApplication) } @@ -79,49 +85,57 @@ func TestBridgeFinishLoginNewUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - // Bridge finds no users in the keychain. - m.credentialsStore.EXPECT().List().Return([]string{}, nil) + // Basically every call client has get client manager + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - // Get user to be able to setup new client with proper userID. - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) + gomock.InOrder( + // Bridge finds no users in the keychain. + m.credentialsStore.EXPECT().List().Return([]string{}, nil), - // Setup of new client. - m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil) - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) + // Get user to be able to setup new client with proper userID. + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), + m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - // Set up mocks for authorising the new user (in user.init). - m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}) - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil).Times(2) - m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil) - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil) - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) + // bridge.Bridge.addNewUser(() + m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil), + m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil) + // bridge.newUser() + m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}), + m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil).Times(2), - // Set up mocks for creating the user's store (in store.New). - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) + // bridge.User.init() + m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil), + //TODO m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil), + //TODO m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil), - // Emit event for new user and send metrics. - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)) + // authorize if necessary + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil), - // Set up mocks for starting the store's event loop (in store.New). - // The event loop runs in another goroutine so this might happen at any time. - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + // Set up mocks for creating the user's store (in store.New). + m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - // Set up mocks for performing the initial store sync. - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) + // Emit event for new user and send metrics. + m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient), + m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)), + m.pmapiClient.EXPECT().Logout(), + + // Reload account list in GUI. + m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"), + // defer logout anonymous + m.pmapiClient.EXPECT().Logout(), + ) + + mockEventLoopNoAction(m) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) } -func TestBridgeFinishLoginExistingUser(t *testing.T) { +func TestBridgeFinishLoginExistingDisconnectedUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() @@ -129,87 +143,82 @@ func TestBridgeFinishLoginExistingUser(t *testing.T) { loggedOutCreds.APIToken = "" loggedOutCreds.MailboxPassword = "" - // Bridge finds one logged out user in the keychain. - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - // New user - m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil) - // Init user - m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil) - m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken) - m.pmapiClient.EXPECT().Addresses().Return(nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - // Get user to be able to setup new client with proper userID. - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) + gomock.InOrder( + // Bridge finds one logged out user in the keychain. + m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil), + // New user + m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil), + // Init user + m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil), + m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken), + m.pmapiClient.EXPECT().Addresses().Return(nil), - // Setup of new client. - m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil) - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) + // Get user to be able to setup new client with proper userID. + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), + m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - // Set up mocks for authorising the new user (in user.init). - m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}) - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil) - m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil) - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil) - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) + // bridge.Bridge.connectExistingUser + m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil), + m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil), + m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil), - m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil) + // bridge.User.init() + m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil), + m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil), - // Set up mocks for creating the user's store (in store.New). - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) + // authorize if necessary + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil), - // Reload account list in GUI. - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") + /* TODO + // Set up mocks for authorising the new user (in user.init). + m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}), + m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil), - // Set up mocks for starting the store's event loop (in store.New) - // The event loop runs in another goroutine so this might happen at any time. - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil), + */ - // Set up mocks for performing the initial store sync. - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) + // Set up mocks for creating the user's store (in store.New). + m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + + // Reload account list in GUI. + m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"), + // defer logout anonymous + m.pmapiClient.EXPECT().Logout(), + ) + + mockEventLoopNoAction(m) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) } -func TestBridgeDoubleLogin(t *testing.T) { +func TestBridgeFinishLoginConnectedUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - // Firstly, start bridge with existing user... - + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + mockConnectedUser(m) + mockEventLoopNoAction(m) bridge := testNewBridge(t, m) defer cleanUpBridgeUserData(bridge) // Then, try to log in again... - - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) - m.pmapiClient.EXPECT().Logout() + gomock.InOrder( + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), + m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), + m.pmapiClient.EXPECT().DeleteAuth(), + m.pmapiClient.EXPECT().Logout(), + ) _, err := bridge.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword) - assert.Equal(t, "user is already logged in", err.Error()) + assert.Equal(t, "user is already connected", err.Error()) } func checkBridgeFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) { diff --git a/internal/bridge/bridge_new_test.go b/internal/bridge/bridge_new_test.go index f7d5bf20..fc0024f6 100644 --- a/internal/bridge/bridge_new_test.go +++ b/internal/bridge/bridge_new_test.go @@ -52,11 +52,16 @@ func TestNewBridgeWithDisconnectedUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2) - m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")) - m.pmapiClient.EXPECT().Addresses().Return(nil) - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient) + // Basically every call client has get client manager. + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + + gomock.InOrder( + m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")), + m.pmapiClient.EXPECT().Addresses().Return(nil), + ) checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) } @@ -65,9 +70,10 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.credentialsStore.EXPECT().Logout("user").Return(nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")) @@ -82,26 +88,33 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) { checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) } +func mockConnectedUser(m mocks) { + gomock.InOrder( + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + //TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil), + + m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil), + + // Set up mocks for store initialisation for the authorized user. + m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + ) +} + func TestNewBridgeWithConnectedUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - // Set up mocks for store initialisation for the authorized user. - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + mockConnectedUser(m) + mockEventLoopNoAction(m) checkBridgeNew(t, m, []*credentials.Credentials{testCredentials}) } @@ -112,27 +125,22 @@ func TestNewBridgeWithUsers(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil) - m.credentialsStore.EXPECT().List().Return([]string{"user", "user"}, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) + gomock.InOrder( + m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil), + m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil), + // Set up mocks for store initialisation for the unauth user. + m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient), + m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")), + m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient), + m.pmapiClient.EXPECT().Addresses().Return(nil), + ) - // Set up mocks for store initialisation for the unauth user. - m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")) - m.pmapiClient.EXPECT().Addresses().Return(nil) + mockConnectedUser(m) - // Set up mocks for store initialisation for the authorized user. - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + mockEventLoopNoAction(m) checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials}) } @@ -141,9 +149,13 @@ func TestNewBridgeFirstStart(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(true) - m.credentialsStore.EXPECT().List().Return([]string{}, nil) - m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.FirstStart), gomock.Any()) + gomock.InOrder( + m.credentialsStore.EXPECT().List().Return([]string{}, nil), + m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(true), + m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient), + m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.FirstStart), gomock.Any()), + m.pmapiClient.EXPECT().Logout(), + ) testNewBridge(t, m) } diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index f8dd2820..95f866e2 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -18,8 +18,10 @@ package bridge import ( + "fmt" "io/ioutil" "os" + "runtime/debug" "testing" "time" @@ -37,6 +39,9 @@ import ( ) func TestMain(m *testing.M) { + if os.Getenv("VERBOSITY") == "fatal" { + logrus.SetLevel(logrus.FatalLevel) + } if os.Getenv("VERBOSITY") == "trace" { logrus.SetLevel(logrus.TraceLevel) } @@ -138,8 +143,27 @@ type mocks struct { storeCache *store.Cache } +type fullStackReporter struct { + T testing.TB +} + +func (fr *fullStackReporter) Errorf(format string, args ...interface{}) { + fmt.Printf("err: "+format+"\n", args...) + fr.T.Fail() +} +func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) { + debug.PrintStack() + fmt.Printf("fail: "+format+"\n", args...) + fr.T.FailNow() +} + func initMocks(t *testing.T) mocks { - mockCtrl := gomock.NewController(t) + var mockCtrl *gomock.Controller + if os.Getenv("VERBOSITY") == "trace" { + mockCtrl = gomock.NewController(&fullStackReporter{t}) + } else { + mockCtrl = gomock.NewController(t) + } cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db") require.NoError(t, err, "could not get temporary file for store cache") @@ -172,35 +196,38 @@ func initMocks(t *testing.T) mocks { } func testNewBridgeWithUsers(t *testing.T, m mocks) *Bridge { - // Init for user. - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + // Events are asynchronous + m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).Times(2) + m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2) + m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2) - // Init for users. - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("users", ":reftok").Return(nil) - m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil) - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + gomock.InOrder( + m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil), - m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil) + // Init for user. + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + // TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) + // TODO m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), + m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + + // Init for users. + m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), + m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + // TODO m.credentialsStore.EXPECT().UpdateToken("users", ":reftok").Return(nil), + // TODO m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), + m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses), + ) return testNewBridge(t, m) } @@ -214,7 +241,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge { m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes() m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes() m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) - m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan *pmapi.ClientAuth)) + m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan pmapi.ClientAuth)) bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore) @@ -233,6 +260,9 @@ func TestClearData(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + bridge := testNewBridgeWithUsers(t, m) defer cleanUpBridgeUserData(bridge) @@ -255,3 +285,14 @@ func TestClearData(t *testing.T) { waitForEvents() } + +func mockEventLoopNoAction(m mocks) { + // Set up mocks for starting the store's event loop (in store.New). + // The event loop runs in another goroutine so this might happen at any time. + gomock.InOrder( + m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil), + m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil), + // Set up mocks for performing the initial store sync. + m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil), + ) +} diff --git a/internal/bridge/bridge_users_test.go b/internal/bridge/bridge_users_test.go index a3cf48a9..650be22c 100644 --- a/internal/bridge/bridge_users_test.go +++ b/internal/bridge/bridge_users_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/ProtonMail/proton-bridge/internal/events" + gomock "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -29,6 +30,9 @@ func TestGetNoUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + checkBridgeGetUser(t, m, "nouser", -1, "user nouser not found") } @@ -36,6 +40,9 @@ func TestGetUserByID(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + checkBridgeGetUser(t, m, "user", 0, "") checkBridgeGetUser(t, m, "users", 1, "") } @@ -44,6 +51,9 @@ func TestGetUserByName(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + checkBridgeGetUser(t, m, "username", 0, "") checkBridgeGetUser(t, m, "usersname", 1, "") } @@ -52,6 +62,9 @@ func TestGetUserByEmail(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + checkBridgeGetUser(t, m, "user@pm.me", 0, "") checkBridgeGetUser(t, m, "users@pm.me", 1, "") checkBridgeGetUser(t, m, "anotheruser@pm.me", 1, "") @@ -62,14 +75,18 @@ func TestDeleteUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + bridge := testNewBridgeWithUsers(t, m) defer cleanUpBridgeUserData(bridge) - m.pmapiClient.EXPECT().Logout().Return(nil) - - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.credentialsStore.EXPECT().Delete("user").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + gomock.InOrder( + m.pmapiClient.EXPECT().Logout().Return(), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + m.credentialsStore.EXPECT().Delete("user").Return(nil), + ) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") @@ -83,14 +100,20 @@ func TestDeleteUserWithFailingLogout(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + bridge := testNewBridgeWithUsers(t, m) defer cleanUpBridgeUserData(bridge) - m.pmapiClient.EXPECT().Logout().Return(nil) - - m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")) - m.credentialsStore.EXPECT().Delete("user").Return(nil).Times(2) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + gomock.InOrder( + m.pmapiClient.EXPECT().Logout().Return(), + m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")), + m.credentialsStore.EXPECT().Delete("user").Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")), + m.credentialsStore.EXPECT().Delete("user").Return(nil), + //TODO m.credentialsStore.EXPECT().Delete("user").Return(errors.New("no such user")), + ) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") diff --git a/internal/bridge/mocks/mocks.go b/internal/bridge/mocks/mocks.go index 0463001e..d1256dfb 100644 --- a/internal/bridge/mocks/mocks.go +++ b/internal/bridge/mocks/mocks.go @@ -5,11 +5,10 @@ package mocks import ( - reflect "reflect" - credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) // MockConfiger is a mock of Configer interface @@ -272,10 +271,10 @@ func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth { return ret0 } -// GetAuthUpdateChannel indicates an expected call of GetBridgeAuthChannel +// GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel)) } // GetClient mocks base method diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 41a7d160..32080616 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -409,7 +409,7 @@ func (u *User) GetBridgePassword() string { } // CheckBridgeLogin checks whether the user is logged in and the bridge -// password is correct. +// IMAP/SMTP password is correct. func (u *User) CheckBridgeLogin(password string) error { if isApplicationOutdated { u.listener.Emit(events.UpgradeApplicationEvent, "") diff --git a/internal/bridge/user_credentials_test.go b/internal/bridge/user_credentials_test.go index 688b43e1..3f1c1e70 100644 --- a/internal/bridge/user_credentials_test.go +++ b/internal/bridge/user_credentials_test.go @@ -34,16 +34,24 @@ func TestUpdateUser(t *testing.T) { user := testNewUser(m) defer cleanUpUserData(user) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) + gomock.InOrder( + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), - m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) + m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil), + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + ) + + gomock.InOrder( + m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1), + m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1), + ) assert.NoError(t, user.UpdateUser()) @@ -105,9 +113,12 @@ func TestLogoutUser(t *testing.T) { user := testNewUserForLogout(m) defer cleanUpUserData(user) - m.pmapiClient.EXPECT().Logout().Return(nil) - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + gomock.InOrder( + m.pmapiClient.EXPECT().Logout().Return(), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + ) + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") err := user.Logout() @@ -124,10 +135,12 @@ func TestLogoutUserFailsLogout(t *testing.T) { user := testNewUserForLogout(m) defer cleanUpUserData(user) - m.pmapiClient.EXPECT().Logout().Return(nil) - m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")) - m.credentialsStore.EXPECT().Delete("user").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + gomock.InOrder( + m.pmapiClient.EXPECT().Logout().Return(), + m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")), + m.credentialsStore.EXPECT().Delete("user").Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + ) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") err := user.Logout() @@ -135,15 +148,20 @@ func TestLogoutUserFailsLogout(t *testing.T) { assert.NoError(t, err) } -func TestCheckBridgeLogin(t *testing.T) { +func TestCheckBridgeLoginOK(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() user := testNewUser(m) defer cleanUpUserData(user) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) + gomock.InOrder( + // TODO why u.HasAPIAuth() = false + // TODO why not :reftoken + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), + ) err := user.CheckBridgeLogin(testCredentials.BridgePassword) @@ -162,11 +180,12 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) { m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") isApplicationOutdated = true + err := user.CheckBridgeLogin("any-pass") waitForEvents() - isApplicationOutdated = false - assert.Equal(t, pmapi.ErrUpgradeApplication, err) + + isApplicationOutdated = false } func TestCheckBridgeLoginLoggedOut(t *testing.T) { @@ -174,19 +193,29 @@ func TestCheckBridgeLoginLoggedOut(t *testing.T) { defer m.ctrl.Finish() m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) - m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient) - user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") - m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")) - m.pmapiClient.EXPECT().Addresses().Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) - _ = user.init(nil) + user, err := newUser( + m.PanicHandler, "user", + m.eventListener, m.credentialsStore, + m.clientManager, m.storeCache, "/tmp", + ) + assert.NoError(t, err) + + m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1) + gomock.InOrder( + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")), + m.pmapiClient.EXPECT().Addresses().Return(nil), + ) + + err = user.init(nil) + assert.Error(t, err) defer cleanUpUserData(user) m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") - err := user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword) + err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword) waitForEvents() assert.Equal(t, "bridge account is logged out, use bridge to login again", err.Error()) @@ -199,8 +228,13 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) { user := testNewUser(m) defer cleanUpUserData(user) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) + gomock.InOrder( + // TODO why u.HasAPIAuth() = false + // TODO why not :reftoken + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), + ) err := user.CheckBridgeLogin("wrong!") waitForEvents() diff --git a/internal/bridge/user_new_test.go b/internal/bridge/user_new_test.go index 8344a124..444fae70 100644 --- a/internal/bridge/user_new_test.go +++ b/internal/bridge/user_new_test.go @@ -41,12 +41,16 @@ func TestNewUserBridgeOutdated(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().Logout("user").Return(nil).AnyTimes() - m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication).AnyTimes() - m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "").AnyTimes() - m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication) - m.pmapiClient.EXPECT().Addresses().Return(nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + + gomock.InOrder( + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication), + m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""), + m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication), + m.pmapiClient.EXPECT().Addresses().Return(nil), + ) checkNewUser(m) } @@ -55,13 +59,18 @@ func TestNewUserNoInternetConnection(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable).AnyTimes() - m.eventListener.EXPECT().Emit(events.InternetOffEvent, "").AnyTimes() + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.pmapiClient.EXPECT().Addresses().Return(nil) - m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrAPINotReachable) - m.pmapiClient.EXPECT().GetEvent("").Return(nil, pmapi.ErrAPINotReachable).AnyTimes() + gomock.InOrder( + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable), + m.eventListener.EXPECT().Emit(events.InternetOffEvent, ""), + + m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrAPINotReachable), + m.pmapiClient.EXPECT().Addresses().Return(nil), + m.pmapiClient.EXPECT().GetEvent("").Return(nil, pmapi.ErrAPINotReachable).AnyTimes(), + ) checkNewUser(m) } @@ -70,17 +79,22 @@ func TestNewUserAuthRefreshFails(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")).AnyTimes() - + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.pmapiClient.EXPECT().Logout().Return(nil) - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") + gomock.InOrder( + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + + m.pmapiClient.EXPECT().Logout(), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + ) + checkNewUserDisconnected(m) } @@ -88,21 +102,25 @@ func TestNewUserUnlockFails(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - m.credentialsStore.EXPECT().Logout("user").Return(nil) - - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password")) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.pmapiClient.EXPECT().Logout().Return(nil) - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") + gomock.InOrder( + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + // TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password")), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + m.pmapiClient.EXPECT().Logout(), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + ) + checkNewUserDisconnected(m) } @@ -110,22 +128,26 @@ func TestNewUserUnlockAddressesFails(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - m.credentialsStore.EXPECT().Logout("user").Return(nil) - - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(errors.New("bad password")) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.pmapiClient.EXPECT().Logout().Return(nil) - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") + gomock.InOrder( + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + // TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil), + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), + + m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), + m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(errors.New("bad password")), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + m.pmapiClient.EXPECT().Logout(), + m.credentialsStore.EXPECT().Logout("user").Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + ) + checkNewUserDisconnected(m) } @@ -133,21 +155,9 @@ func TestNewUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) - - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) + mockConnectedUser(m) + mockEventLoopNoAction(m) checkNewUser(m) } diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index a39dc4e3..17ddaa71 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -27,21 +27,15 @@ import ( ) func testNewUser(m mocks) *User { - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) + mockConnectedUser(m) - // Expectations for initial sync (when loading existing user from credentials store). - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).AnyTimes() - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() + gomock.InOrder( + m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1), + m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1), + m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1), + ) user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") assert.NoError(m.t, err) @@ -53,21 +47,15 @@ func testNewUser(m mocks) *User { } func testNewUserForLogout(m mocks) *User { - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) - m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) + mockConnectedUser(m) - // These may or may not be hit depending on how fast the log out happens. - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil).AnyTimes() - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}).AnyTimes() - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) - m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).AnyTimes() - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() + gomock.InOrder( + m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1), + m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1), + m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1), + ) user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") assert.NoError(m.t, err) diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index 18d6a49e..1ef2a820 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -424,8 +424,7 @@ func (c *client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) { func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) { // If we don't yet have a saved access token, save this one in case the refresh fails! // That way we can try again later (see handleUnauthorizedStatus). - // TODO: - // c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken) + c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken) split := strings.Split(uidAndRefreshToken, ":") if len(split) != 2 { diff --git a/pkg/pmapi/auth_test.go b/pkg/pmapi/auth_test.go index 995b2c39..e3a8bece 100644 --- a/pkg/pmapi/auth_test.go +++ b/pkg/pmapi/auth_test.go @@ -33,8 +33,6 @@ import ( r "github.com/stretchr/testify/require" ) -var aLongTimeAgo = time.Unix(233431200, 0) - var testIdentity = &pmcrypto.Identity{ Name: "UserID", Email: "", @@ -276,10 +274,11 @@ func TestClient_AuthRefresh(t *testing.T) { auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken) Ok(t, err) + Equals(t, testUID, c.uid) exp := &Auth{} *exp = *testAuth - exp.uid = "" // AuthRefresh will not return UID (only Auth returns the UID). + exp.uid = testUID // AuthRefresh will not return UID (only Auth returns the UID) we should set testUID to be able to generate token, see `GetToken` exp.accessToken = testAccessToken exp.KeySalt = "" exp.EventID = "" diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index 694579c0..3fad3bc2 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -3,6 +3,7 @@ package pmapi import ( "fmt" "net/http" + "strings" "sync" "time" @@ -10,8 +11,6 @@ import ( "github.com/sirupsen/logrus" ) -var defaultProxyUseDuration = 24 * time.Hour - // ClientManager is a manager of clients. type ClientManager struct { // newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to @@ -21,9 +20,6 @@ type ClientManager struct { config *ClientConfig roundTripper http.RoundTripper - // TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests. - // But that screws up other things like not being able to clear sensitive info during logout - // unless the client interface contains a method for that. clients map[string]Client clientsLocker sync.Locker @@ -33,17 +29,19 @@ type ClientManager struct { expirations map[string]*tokenExpiration expirationsLocker sync.Locker - host, scheme string - hostLocker sync.Locker - bridgeAuths chan ClientAuth clientAuths chan ClientAuth + host, scheme string + hostLocker sync.RWMutex + allowProxy bool proxyProvider *proxyProvider proxyUseDuration time.Duration idGen idGen + + log *logrus.Entry } type idGen int @@ -81,14 +79,16 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { expirationsLocker: &sync.Mutex{}, host: RootURL, - scheme: RootScheme, - hostLocker: &sync.Mutex{}, + scheme: rootScheme, + hostLocker: sync.RWMutex{}, bridgeAuths: make(chan ClientAuth), clientAuths: make(chan ClientAuth), proxyProvider: newProxyProvider(dohProviders, proxyQuery), - proxyUseDuration: defaultProxyUseDuration, + proxyUseDuration: proxyUseDuration, + + log: logrus.WithField("pkg", "pmapi-manager"), } cm.newClient = func(userID string) Client { @@ -97,7 +97,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { go cm.forwardClientAuths() - return + return cm } func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) { @@ -140,20 +140,20 @@ func (cm *ClientManager) LogoutClient(userID string) { delete(cm.clients, userID) go func() { - if err := client.DeleteAuth(); err != nil { - // TODO: Retry if the request failed. + if !strings.HasPrefix(userID, "anonymous-") { + if err := client.DeleteAuth(); err != nil { + // TODO: Retry if the request failed. + } } client.ClearData() cm.clearToken(userID) }() - - return } // GetRootURL returns the full root URL (scheme+host). func (cm *ClientManager) GetRootURL() string { - cm.hostLocker.Lock() - defer cm.hostLocker.Unlock() + cm.hostLocker.RLock() + defer cm.hostLocker.RUnlock() return fmt.Sprintf("%v://%v", cm.scheme, cm.host) } @@ -161,24 +161,16 @@ func (cm *ClientManager) GetRootURL() string { // getHost returns the host to make requests to. // It does not include the protocol i.e. no "https://" (use getScheme for that). func (cm *ClientManager) getHost() string { - cm.hostLocker.Lock() - defer cm.hostLocker.Unlock() + cm.hostLocker.RLock() + defer cm.hostLocker.RUnlock() return cm.host } -// getScheme returns the scheme with which to make requests to the host. -func (cm *ClientManager) getScheme() string { - cm.hostLocker.Lock() - defer cm.hostLocker.Unlock() - - return cm.scheme -} - // IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be. func (cm *ClientManager) IsProxyAllowed() bool { - cm.hostLocker.Lock() - defer cm.hostLocker.Unlock() + cm.hostLocker.RLock() + defer cm.hostLocker.RUnlock() return cm.allowProxy } @@ -202,8 +194,8 @@ func (cm *ClientManager) DisallowProxy() { // IsProxyEnabled returns whether we are currently proxying requests. func (cm *ClientManager) IsProxyEnabled() bool { - cm.hostLocker.Lock() - defer cm.hostLocker.Unlock() + cm.hostLocker.RLock() + defer cm.hostLocker.RUnlock() return cm.host != RootURL } @@ -264,6 +256,21 @@ func (cm *ClientManager) forwardClientAuths() { } } +// SetTokenIfUnset sets the token for the given userID if it wasn't already set. +// The token does not expire. +func (cm *ClientManager) SetTokenIfUnset(userID, token string) { + cm.tokensLocker.Lock() + defer cm.tokensLocker.Unlock() + + if _, ok := cm.tokens[userID]; ok { + return + } + + logrus.WithField("userID", userID).Info("Setting token because it is currently unset") + + cm.tokens[userID] = token +} + // setToken sets the token for the given userID with the given expiration time. func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) { cm.tokensLocker.Lock() @@ -275,6 +282,7 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration cm.setTokenExpiration(userID, expiration) + // TODO: This should be one go routine per all tokens. go cm.watchTokenExpiration(userID) } @@ -311,7 +319,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) { // If we aren't managing this client, there's nothing to do. if _, ok := cm.clients[ca.UserID]; !ok { - logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client") + logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client") return } @@ -332,8 +340,12 @@ func (cm *ClientManager) watchTokenExpiration(userID string) { select { case <-expiration.timer.C: - logrus.WithField("userID", userID).Info("Auth token expired! Refreshing") - cm.clients[userID].AuthRefresh(cm.tokens[userID]) + cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing") + if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil { + cm.log.WithField("userID", userID). + WithError(err). + Error("Token refresh failed before expiration") + } case <-expiration.cancel: logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired") diff --git a/pkg/pmapi/config.go b/pkg/pmapi/config.go index 7cb1c7a1..0e46bf62 100644 --- a/pkg/pmapi/config.go +++ b/pkg/pmapi/config.go @@ -27,10 +27,9 @@ import ( // This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod. // Default is pmapi_prod. // -// It should not contain the protocol! The protocol should be in RootScheme. +// It must not contain the protocol! The protocol should be in rootScheme. var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals] - -var RootScheme = "https" +var rootScheme = "https" //nolint[gochecknoglobals] // CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program // version and email client. diff --git a/pkg/pmapi/config_dev.go b/pkg/pmapi/config_dev.go index 42c627ea..2c6266e5 100644 --- a/pkg/pmapi/config_dev.go +++ b/pkg/pmapi/config_dev.go @@ -21,5 +21,5 @@ package pmapi func init() { RootURL = "dev.protonmail.com/api" - RootScheme = "https" + rootScheme = "https" } diff --git a/pkg/pmapi/config_local.go b/pkg/pmapi/config_local.go index c952ba76..193bd357 100644 --- a/pkg/pmapi/config_local.go +++ b/pkg/pmapi/config_local.go @@ -28,7 +28,7 @@ func init() { // Use port above 1000 which doesn't need root access to start anything on it. // Now the port is rounded pi. :-) RootURL = "127.0.0.1:3142/api" - RootScheme = "http" + rootScheme = "http" // TLS certificate is self-signed defaultTransport = &http.Transport{ diff --git a/pkg/pmapi/dialer_with_proxy_test.go b/pkg/pmapi/dialer_with_proxy_test.go index 78a90a21..693c8c03 100644 --- a/pkg/pmapi/dialer_with_proxy_test.go +++ b/pkg/pmapi/dialer_with_proxy_test.go @@ -41,7 +41,7 @@ func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) { func TestTLSPinValid(t *testing.T) { cm := NewClientManager(testLiveConfig) cm.host = liveAPI - RootScheme = "https" + rootScheme = "https" called, _ := setTestDialerWithPinning(cm) client := cm.GetClient("pmapi" + t.Name()) diff --git a/pkg/pmapi/mocks/mocks.go b/pkg/pmapi/mocks/mocks.go index 21f66594..f9090e31 100644 --- a/pkg/pmapi/mocks/mocks.go +++ b/pkg/pmapi/mocks/mocks.go @@ -5,12 +5,11 @@ package mocks import ( - io "io" - reflect "reflect" - crypto "github.com/ProtonMail/gopenpgp/crypto" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" + io "io" + reflect "reflect" ) // MockClient is a mock of Client interface @@ -110,6 +109,18 @@ func (mr *MockClientMockRecorder) AuthRefresh(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRefresh", reflect.TypeOf((*MockClient)(nil).AuthRefresh), arg0) } +// ClearData mocks base method +func (m *MockClient) ClearData() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClearData") +} + +// ClearData indicates an expected call of ClearData +func (mr *MockClientMockRecorder) ClearData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearData", reflect.TypeOf((*MockClient)(nil).ClearData)) +} + // CountMessages mocks base method func (m *MockClient) CountMessages(arg0 string) ([]*pmapi.MessagesCount, error) { m.ctrl.T.Helper() @@ -214,6 +225,20 @@ func (mr *MockClientMockRecorder) DeleteAttachment(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAttachment", reflect.TypeOf((*MockClient)(nil).DeleteAttachment), arg0) } +// DeleteAuth mocks base method +func (m *MockClient) DeleteAuth() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAuth") + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAuth indicates an expected call of DeleteAuth +func (mr *MockClientMockRecorder) DeleteAuth() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuth", reflect.TypeOf((*MockClient)(nil).DeleteAuth)) +} + // DeleteLabel mocks base method func (m *MockClient) DeleteLabel(arg0 string) error { m.ctrl.T.Helper() diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go index e2628711..3ec2d0c0 100644 --- a/pkg/pmapi/proxy.go +++ b/pkg/pmapi/proxy.go @@ -30,7 +30,7 @@ import ( ) const ( - proxyRevertTime = 24 * time.Hour + proxyUseDuration = 24 * time.Hour proxySearchTimeout = 30 * time.Second proxyQueryTimeout = 10 * time.Second proxyLookupWait = 5 * time.Second diff --git a/pkg/pmapi/req.go b/pkg/pmapi/req.go index 899306f1..b4ff9fde 100644 --- a/pkg/pmapi/req.go +++ b/pkg/pmapi/req.go @@ -27,7 +27,6 @@ import ( // NewRequest creates a new request. func (c *client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) { - // TODO: Support other protocols (localhost needs http not https). req, err = http.NewRequest(method, c.cm.GetRootURL()+path, body) if req != nil { diff --git a/test/bridge_checks_test.go b/test/bridge_checks_test.go index 3a25b57c..e9e8ece0 100644 --- a/test/bridge_checks_test.go +++ b/test/bridge_checks_test.go @@ -180,9 +180,10 @@ func hasAPIAuth(accountName string) error { if err != nil { return internalError(err, "getting user %s", account.Username()) } - a.Eventually(ctx.GetTestingT(), func() bool { - return bridgeUser.HasAPIAuth() - }, 5*time.Second, 10*time.Millisecond) + a.Eventually(ctx.GetTestingT(), + bridgeUser.HasAPIAuth, + 5*time.Second, 10*time.Millisecond, + ) return ctx.GetTestingError() } diff --git a/test/context/bridge.go b/test/context/bridge.go index 699460ef..4b088519 100644 --- a/test/context/bridge.go +++ b/test/context/bridge.go @@ -24,7 +24,6 @@ import ( "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/pkg/listener" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" ) // GetBridge returns bridge instance. @@ -60,7 +59,7 @@ func newBridgeInstance( cfg *fakeConfig, credStore bridge.CredentialsStorer, eventListener listener.Listener, - clientManager *pmapi.ClientManager, + clientManager bridge.ClientManager, ) *bridge.Bridge { version := os.Getenv("VERSION") bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "") diff --git a/test/fakeapi/auth.go b/test/fakeapi/auth.go index e9b22d11..428f51b7 100644 --- a/test/fakeapi/auth.go +++ b/test/fakeapi/auth.go @@ -148,7 +148,9 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) { } func (api *FakePMAPI) Logout() { - api.DeleteAuth() + if err := api.DeleteAuth(); err != nil { + api.log.WithError(err).Error("delete auth failed during logout") + } api.ClearData() } diff --git a/test/liveapi/transport.go b/test/liveapi/transport.go index dd001e21..4a96da44 100644 --- a/test/liveapi/transport.go +++ b/test/liveapi/transport.go @@ -33,7 +33,7 @@ func (ctl *Controller) TurnInternetConnectionOn() { } type fakeTransport struct { - ctl *Controller + ctl *Controller transport http.RoundTripper }