From e95aece6d3fab51ffbdde9aece552886b1d6a718 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Thu, 16 Apr 2020 17:17:44 +0200 Subject: [PATCH] refactor: don't pass client directly to store syncer --- internal/bridge/bridge_login_test.go | 61 +++++++++++------------- internal/bridge/bridge_new_test.go | 15 +++++- internal/bridge/bridge_test.go | 14 ++++-- internal/bridge/bridge_users_test.go | 1 - internal/bridge/user_credentials_test.go | 8 ---- internal/bridge/user_new_test.go | 30 ++++-------- internal/bridge/user_test.go | 3 ++ internal/store/sync.go | 6 +-- internal/store/sync_test.go | 6 +-- internal/store/user_sync.go | 3 +- 10 files changed, 71 insertions(+), 76 deletions(-) diff --git a/internal/bridge/bridge_login_test.go b/internal/bridge/bridge_login_test.go index bd2f0663..65f73c3e 100644 --- a/internal/bridge/bridge_login_test.go +++ b/internal/bridge/bridge_login_test.go @@ -89,32 +89,27 @@ func TestBridgeFinishLoginNewUser(t *testing.T) { m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) gomock.InOrder( - // Bridge finds no users in the keychain. + // bridge.New() finds no users in keychain. m.credentialsStore.EXPECT().List().Return([]string{}, nil), - // Get user to be able to setup new client with proper userID. + // getAPIUser() loads user info from API (e.g. userID). m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - // bridge.Bridge.addNewUser(() + // addNewUser() m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil), m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - - // bridge.newUser() m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}), - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil).Times(2), + m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil), - // bridge.User.init() + // user.init() in addNewUser + m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil), m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil), - //TODO m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil), - //TODO m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil), - - // authorize if necessary m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil), - // Set up mocks for creating the user's store (in store.New). + // store.New() in user.init m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), @@ -126,13 +121,16 @@ func TestBridgeFinishLoginNewUser(t *testing.T) { // Reload account list in GUI. m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"), + // defer logout anonymous m.pmapiClient.EXPECT().Logout(), ) mockEventLoopNoAction(m) - checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) + user := checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) + + mockAuthUpdate(user, "afterCredentials", m) } func TestBridgeFinishLoginExistingDisconnectedUser(t *testing.T) { @@ -146,54 +144,51 @@ func TestBridgeFinishLoginExistingDisconnectedUser(t *testing.T) { m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) gomock.InOrder( - // Bridge finds one logged out user in the keychain. + // bridge.New() finds one existing user in keychain. m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil), - // New user + + // newUser() m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil), - // Init user + + // user.init() m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil), + + // store.New() in user.init m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken), m.pmapiClient.EXPECT().Addresses().Return(nil), - // Get user to be able to setup new client with proper userID. + // getAPIUser() loads user info from API (e.g. userID). m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - // bridge.Bridge.connectExistingUser + // connectExistingUser() m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil), m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil), m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil), - // bridge.User.init() + // user.init() in connectExistingUser m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil), m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil), - - // authorize if necessary m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil), - /* TODO - // Set up mocks for authorising the new user (in user.init). - m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}), - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil), - - m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil), - */ - - // Set up mocks for creating the user's store (in store.New). + // store.New() in user.init m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), // Reload account list in GUI. m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"), + // defer logout anonymous m.pmapiClient.EXPECT().Logout(), ) mockEventLoopNoAction(m) - checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) + user := checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) + + mockAuthUpdate(user, "afterCredentials", m) } func TestBridgeFinishLoginConnectedUser(t *testing.T) { @@ -221,7 +216,7 @@ func TestBridgeFinishLoginConnectedUser(t *testing.T) { assert.Equal(t, "user is already connected", err.Error()) } -func checkBridgeFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) { +func checkBridgeFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User { bridge := testNewBridge(t, m) defer cleanUpBridgeUserData(bridge) @@ -239,4 +234,6 @@ func checkBridgeFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPass assert.Equal(t, (*User)(nil), user) assert.Equal(t, 0, len(bridge.users)) } + + return user } diff --git a/internal/bridge/bridge_new_test.go b/internal/bridge/bridge_new_test.go index fc0024f6..67aae3f9 100644 --- a/internal/bridge/bridge_new_test.go +++ b/internal/bridge/bridge_new_test.go @@ -94,7 +94,6 @@ func mockConnectedUser(m mocks) { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), - //TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil), m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil), @@ -106,6 +105,20 @@ func mockConnectedUser(m mocks) { ) } +// mockAuthUpdate simulates bridge calling UpdateAuthToken on the given user. +// This would normally be done by Bridge when it receives an auth from the ClientManager, +// but as we don't have a full bridge instance here, we do this manually. +func mockAuthUpdate(user *User, token string, m mocks) { + gomock.InOrder( + m.credentialsStore.EXPECT().UpdateToken("user", ":"+token).Return(nil), + m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(token), nil), + ) + + user.updateAuthToken(refreshWithToken(token)) + + waitForEvents() +} + func TestNewBridgeWithConnectedUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 95f866e2..0006f522 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -208,8 +208,6 @@ func testNewBridgeWithUsers(t *testing.T, m mocks) *Bridge { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), - // TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - // TODO m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), @@ -220,8 +218,6 @@ func testNewBridgeWithUsers(t *testing.T, m mocks) *Bridge { m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), - // TODO m.credentialsStore.EXPECT().UpdateToken("users", ":reftok").Return(nil), - // TODO m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), @@ -229,7 +225,15 @@ func testNewBridgeWithUsers(t *testing.T, m mocks) *Bridge { m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses), ) - return testNewBridge(t, m) + bridge := testNewBridge(t, m) + + user, _ := bridge.GetUser("user") + mockAuthUpdate(user, "reftok", m) + + users, _ := bridge.GetUser("user") + mockAuthUpdate(users, "reftok", m) + + return bridge } func testNewBridge(t *testing.T, m mocks) *Bridge { diff --git a/internal/bridge/bridge_users_test.go b/internal/bridge/bridge_users_test.go index 650be22c..783dad26 100644 --- a/internal/bridge/bridge_users_test.go +++ b/internal/bridge/bridge_users_test.go @@ -112,7 +112,6 @@ func TestDeleteUserWithFailingLogout(t *testing.T) { m.credentialsStore.EXPECT().Delete("user").Return(nil), m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")), m.credentialsStore.EXPECT().Delete("user").Return(nil), - //TODO m.credentialsStore.EXPECT().Delete("user").Return(errors.New("no such user")), ) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") diff --git a/internal/bridge/user_credentials_test.go b/internal/bridge/user_credentials_test.go index 3f1c1e70..2da9904e 100644 --- a/internal/bridge/user_credentials_test.go +++ b/internal/bridge/user_credentials_test.go @@ -35,7 +35,6 @@ func TestUpdateUser(t *testing.T) { defer cleanUpUserData(user) gomock.InOrder( - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), @@ -156,9 +155,6 @@ func TestCheckBridgeLoginOK(t *testing.T) { defer cleanUpUserData(user) gomock.InOrder( - // TODO why u.HasAPIAuth() = false - // TODO why not :reftoken - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), ) @@ -217,7 +213,6 @@ func TestCheckBridgeLoginLoggedOut(t *testing.T) { err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword) waitForEvents() - assert.Equal(t, "bridge account is logged out, use bridge to login again", err.Error()) } @@ -229,9 +224,6 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) { defer cleanUpUserData(user) gomock.InOrder( - // TODO why u.HasAPIAuth() = false - // TODO why not :reftoken - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil), ) diff --git a/internal/bridge/user_new_test.go b/internal/bridge/user_new_test.go index 444fae70..5606dacc 100644 --- a/internal/bridge/user_new_test.go +++ b/internal/bridge/user_new_test.go @@ -21,6 +21,7 @@ import ( "errors" "testing" + "github.com/ProtonMail/proton-bridge/internal/bridge/credentials" "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" @@ -52,7 +53,7 @@ func TestNewUserBridgeOutdated(t *testing.T) { m.pmapiClient.EXPECT().Addresses().Return(nil), ) - checkNewUser(m) + checkNewUserHasCredentials(testCredentials, m) } func TestNewUserNoInternetConnection(t *testing.T) { @@ -72,7 +73,7 @@ func TestNewUserNoInternetConnection(t *testing.T) { m.pmapiClient.EXPECT().GetEvent("").Return(nil, pmapi.ErrAPINotReachable).AnyTimes(), ) - checkNewUser(m) + checkNewUserHasCredentials(testCredentials, m) } func TestNewUserAuthRefreshFails(t *testing.T) { @@ -95,7 +96,7 @@ func TestNewUserAuthRefreshFails(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), ) - checkNewUserDisconnected(m) + checkNewUserHasCredentials(testCredentialsDisconnected, m) } func TestNewUserUnlockFails(t *testing.T) { @@ -111,7 +112,6 @@ func TestNewUserUnlockFails(t *testing.T) { gomock.InOrder( m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - // TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil), m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password")), @@ -121,7 +121,7 @@ func TestNewUserUnlockFails(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), ) - checkNewUserDisconnected(m) + checkNewUserHasCredentials(testCredentialsDisconnected, m) } func TestNewUserUnlockAddressesFails(t *testing.T) { @@ -137,7 +137,6 @@ func TestNewUserUnlockAddressesFails(t *testing.T) { gomock.InOrder( m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - // TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil), m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil), @@ -148,7 +147,7 @@ func TestNewUserUnlockAddressesFails(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), ) - checkNewUserDisconnected(m) + checkNewUserHasCredentials(testCredentialsDisconnected, m) } func TestNewUser(t *testing.T) { @@ -159,10 +158,10 @@ func TestNewUser(t *testing.T) { mockConnectedUser(m) mockEventLoopNoAction(m) - checkNewUser(m) + checkNewUserHasCredentials(testCredentials, m) } -func checkNewUser(m mocks) { +func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) { user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") defer cleanUpUserData(user) @@ -170,18 +169,7 @@ func checkNewUser(m mocks) { waitForEvents() - a.Equal(m.t, testCredentials, user.creds) -} - -func checkNewUserDisconnected(m mocks) { - user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") - defer cleanUpUserData(user) - - _ = user.init(nil) - - waitForEvents() - - a.Equal(m.t, testCredentialsDisconnected, user.creds) + a.Equal(m.t, creds, user.creds) } func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen] diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index 17ddaa71..1c17f529 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/require" ) +// testNewUser sets up a new, authorised user. func testNewUser(m mocks) *User { m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) @@ -43,6 +44,8 @@ func testNewUser(m mocks) *User { err = user.init(nil) assert.NoError(m.t, err) + mockAuthUpdate(user, "reftok", m) + return user } diff --git a/internal/store/sync.go b/internal/store/sync.go index 9a83cdbe..d5380f1d 100644 --- a/internal/store/sync.go +++ b/internal/store/sync.go @@ -42,7 +42,7 @@ type messageLister interface { ListMessages(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) } -func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api messageLister, syncState *syncState) error { +func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func() messageLister, syncState *syncState) error { labelID := pmapi.AllMailLabel // When the full sync starts (i.e. is not already in progress), we need to load @@ -53,7 +53,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api message return errors.Wrap(err, "failed to load message IDs") } - if err := findIDRanges(labelID, api, syncState); err != nil { + if err := findIDRanges(labelID, api(), syncState); err != nil { return errors.Wrap(err, "failed to load IDs ranges") } syncState.save() @@ -71,7 +71,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api message defer panicHandler.HandlePanic() defer wg.Done() - err := syncBatch(labelID, store, api, syncState, idRange, &shouldStop) + err := syncBatch(labelID, store, api(), syncState, idRange, &shouldStop) if err != nil { shouldStop = 1 resultError = errors.Wrap(err, "failed to sync group") diff --git a/internal/store/sync_test.go b/internal/store/sync_test.go index 2f7fae8f..6488e772 100644 --- a/internal/store/sync_test.go +++ b/internal/store/sync_test.go @@ -178,7 +178,7 @@ func TestSyncAllMail(t *testing.T) { //nolint[funlen] } syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted) - err := syncAllMail(m.panicHandler, store, api, syncState) + err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState) require.Nil(t, err) // Check all messages were created or updated. @@ -226,7 +226,7 @@ func TestSyncAllMail_FailedListing(t *testing.T) { } syncState := newTestSyncState(store) - err := syncAllMail(m.panicHandler, store, api, syncState) + err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState) require.EqualError(t, err, "failed to sync group: failed to list messages: error") } @@ -245,7 +245,7 @@ func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) { } syncState := newTestSyncState(store) - err := syncAllMail(m.panicHandler, store, api, syncState) + err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState) require.EqualError(t, err, "failed to sync group: failed to create or update messages: error") } diff --git a/internal/store/user_sync.go b/internal/store/user_sync.go index c4962d1b..fb58403b 100644 --- a/internal/store/user_sync.go +++ b/internal/store/user_sync.go @@ -144,8 +144,7 @@ func (store *Store) triggerSync() { store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started") - // TODO: Is it okay to pass in a client directly? What if it is logged out in the meantime? - err := syncAllMail(store.panicHandler, store, store.client(), syncState) + err := syncAllMail(store.panicHandler, store, func() messageLister { return store.client() }, syncState) if err != nil { log.WithError(err).Error("Store sync failed") return