Fixing unit tests for client manager.

* [x] pmapi: refresh auth uid won't change
* [x] bridge tests:
    * update mocks
    * delete auth when FinishLogin fails
    * check for mailbox password
    * add `gomock.InOrder` for better test control
* [x] fix linter issues except TODOs
* [x] make rootScheme unexported
* [x] store tests: update mocks
This commit is contained in:
Jakub
2020-04-14 07:54:11 +02:00
committed by James Houlahan
parent debd374d75
commit 80f4e1e346
25 changed files with 537 additions and 364 deletions

View File

@ -16,6 +16,11 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/)
* Adding DSN Sentry as build time parameter * Adding DSN Sentry as build time parameter
* GODT-124 bump go-appdir from v1.0.0 to v1.1.0 * GODT-124 bump go-appdir from v1.0.0 to v1.1.0
* CSB-72 Skip processing message update event if http statuscode is 422 * CSB-72 Skip processing message update event if http statuscode is 422
* Skip processing message update event if http statuscode is 422
* GODT-204 `pmapi.TokenManager` replaced by `pmapi.ClientManager`
* `expiresAt` is no longer part of client
* TODO Please fill here all logic changes
### Fixed ### Fixed
* Use correct binary name when finding location of addcert.scpt * Use correct binary name when finding location of addcert.scpt

View File

@ -181,18 +181,20 @@ func (b *Bridge) watchBridgeOutdated() {
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user. // watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
func (b *Bridge) watchAPIAuths() { func (b *Bridge) watchAPIAuths() {
for auth := range b.clientManager.GetAuthUpdateChannel() { for auth := range b.clientManager.GetAuthUpdateChannel() {
logrus.Debug("Bridge received auth from ClientManager") log.Debug("Bridge received auth from ClientManager")
user, ok := b.hasUser(auth.UserID) user, ok := b.hasUser(auth.UserID)
if !ok { if !ok {
logrus.WithField("userID", auth.UserID).Info("User not available for auth update") log.WithField("userID", auth.UserID).Info("User not available for auth update")
continue continue
} }
if auth.Auth != nil { if auth.Auth != nil {
user.updateAuthToken(auth.Auth) user.updateAuthToken(auth.Auth)
} else { } else if err := user.logout(); err != nil {
user.logout() log.WithError(err).
WithField("userID", auth.UserID).
Error("User logout failed while watching API auths")
} }
} }
} }
@ -241,6 +243,14 @@ func (b *Bridge) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPasswo
if err == pmapi.ErrUpgradeApplication { if err == pmapi.ErrUpgradeApplication {
b.events.Emit(events.UpgradeApplicationEvent, "") b.events.Emit(events.UpgradeApplicationEvent, "")
} }
if err != nil {
log.WithError(err).Debug("Login not finished; removing auth session")
if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil {
log.WithError(delAuthErr).Error("Failed to clear login session after unlock")
}
}
// The anonymous client will be removed from list and authentication will not be deleted.
authClient.Logout()
}() }()
apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword) apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword)
@ -249,7 +259,8 @@ func (b *Bridge) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPasswo
return return
} }
if user, err = b.GetUser(apiUser.ID); err == nil { var ok bool
if user, ok = b.hasUser(apiUser.ID); ok {
if err = b.connectExistingUser(user, auth, hashedPassword); err != nil { if err = b.connectExistingUser(user, auth, hashedPassword); err != nil {
log.WithError(err).Error("Failed to connect existing user") log.WithError(err).Error("Failed to connect existing user")
return return
@ -305,7 +316,7 @@ func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword s
return errors.Wrap(err, "failed to refresh token in new client") return errors.Wrap(err, "failed to refresh token in new client")
} }
if user, err = client.UpdateUser(); err != nil { if user, err = client.CurrentUser(); err != nil {
return errors.Wrap(err, "failed to update API user") return errors.Wrap(err, "failed to update API user")
} }
@ -330,7 +341,7 @@ func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword s
b.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)) b.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel))
return return err
} }
func getAPIUser(client pmapi.Client, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) { func getAPIUser(client pmapi.Client, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) {
@ -340,7 +351,13 @@ func getAPIUser(client pmapi.Client, auth *pmapi.Auth, mbPassword string) (user
return return
} }
if user, err = client.UpdateUser(); err != nil { // We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
if _, err = client.Unlock(hashedPassword); err != nil {
log.WithError(err).Error("Wrong mailbox password")
return
}
if user, err = client.CurrentUser(); err != nil {
log.WithError(err).Error("Could not load API user") log.WithError(err).Error("Could not load API user")
return return
} }

View File

@ -29,17 +29,20 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBridgeFinishLoginBadPassword(t *testing.T) { func TestBridgeFinishLoginBadMailboxPassword(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
err := errors.New("bad password")
gomock.InOrder(
// Init bridge with no user from keychain. // Init bridge with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil) m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// Set up mocks for FinishLogin. // Set up mocks for FinishLogin.
err := errors.New("bad password") m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err),
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err) m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout() m.pmapiClient.EXPECT().Logout(),
)
checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err)
} }
@ -48,15 +51,18 @@ func TestBridgeFinishLoginUpgradeApplication(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
err := errors.New("Cannot logout when upgrade needed")
gomock.InOrder(
// Init bridge with no user from keychain. // Init bridge with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil) m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// Set up mocks for FinishLogin. // Set up mocks for FinishLogin.
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, pmapi.ErrUpgradeApplication) m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, pmapi.ErrUpgradeApplication),
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""),
err := errors.New("Cannot logout when upgrade needed") m.pmapiClient.EXPECT().DeleteAuth().Return(err),
m.pmapiClient.EXPECT().Logout().Return(err) m.pmapiClient.EXPECT().Logout(),
)
checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", pmapi.ErrUpgradeApplication) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", pmapi.ErrUpgradeApplication)
} }
@ -79,49 +85,57 @@ func TestBridgeFinishLoginNewUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
// Basically every call client has get client manager
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// Bridge finds no users in the keychain. // Bridge finds no users in the keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil) m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// Get user to be able to setup new client with proper userID. // Get user to be able to setup new client with proper userID.
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// Setup of new client. // bridge.Bridge.addNewUser(()
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil) m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Set up mocks for authorising the new user (in user.init). // bridge.newUser()
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}) m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil).Times(2) m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil).Times(2),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil)
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil)
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil) // bridge.User.init()
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
//TODO m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil),
//TODO m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil),
// authorize if necessary
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
// Set up mocks for creating the user's store (in store.New). // Set up mocks for creating the user's store (in store.New).
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Emit event for new user and send metrics. // Emit event for new user and send metrics.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)) m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)),
m.pmapiClient.EXPECT().Logout(),
// Set up mocks for starting the store's event loop (in store.New). // Reload account list in GUI.
// The event loop runs in another goroutine so this might happen at any time. m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) // defer logout anonymous
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) m.pmapiClient.EXPECT().Logout(),
)
// Set up mocks for performing the initial store sync. mockEventLoopNoAction(m)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
} }
func TestBridgeFinishLoginExistingUser(t *testing.T) { func TestBridgeFinishLoginExistingDisconnectedUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
@ -129,87 +143,82 @@ func TestBridgeFinishLoginExistingUser(t *testing.T) {
loggedOutCreds.APIToken = "" loggedOutCreds.APIToken = ""
loggedOutCreds.MailboxPassword = "" loggedOutCreds.MailboxPassword = ""
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// Bridge finds one logged out user in the keychain. // Bridge finds one logged out user in the keychain.
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
// New user // New user
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil) m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// Init user // Init user
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil) m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken) m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken),
m.pmapiClient.EXPECT().Addresses().Return(nil) m.pmapiClient.EXPECT().Addresses().Return(nil),
// Get user to be able to setup new client with proper userID. // Get user to be able to setup new client with proper userID.
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// Setup of new client. // bridge.Bridge.connectExistingUser
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil) m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil),
// bridge.User.init()
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
// authorize if necessary
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
/* TODO
// Set up mocks for authorising the new user (in user.init). // Set up mocks for authorising the new user (in user.init).
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}) m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil) m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil)
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken("afterCredentials"), nil)
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil) m.credentialsStore.EXPECT().UpdateToken("user", ":afterCredentials").Return(nil),
*/
// Set up mocks for creating the user's store (in store.New). // Set up mocks for creating the user's store (in store.New).
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Reload account list in GUI. // Reload account list in GUI.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
// defer logout anonymous
m.pmapiClient.EXPECT().Logout(),
)
// Set up mocks for starting the store's event loop (in store.New) mockEventLoopNoAction(m)
// The event loop runs in another goroutine so this might happen at any time.
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
// Set up mocks for performing the initial store sync.
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
} }
func TestBridgeDoubleLogin(t *testing.T) { func TestBridgeFinishLoginConnectedUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
// Firstly, start bridge with existing user... m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) mockConnectedUser(m)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) mockEventLoopNoAction(m)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil)
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil)
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress})
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
bridge := testNewBridge(t, m) bridge := testNewBridge(t, m)
defer cleanUpBridgeUserData(bridge) defer cleanUpBridgeUserData(bridge)
// Then, try to log in again... // Then, try to log in again...
gomock.InOrder(
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil) m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().Logout() m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout(),
)
_, err := bridge.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword) _, err := bridge.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword)
assert.Equal(t, "user is already logged in", err.Error()) assert.Equal(t, "user is already connected", err.Error())
} }
func checkBridgeFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) { func checkBridgeFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) {

View File

@ -52,11 +52,16 @@ func TestNewBridgeWithDisconnectedUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) // Basically every call client has get client manager.
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
m.pmapiClient.EXPECT().Addresses().Return(nil) gomock.InOrder(
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
} }
@ -65,9 +70,10 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")) m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
@ -82,26 +88,33 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) {
checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
} }
func mockConnectedUser(m mocks) {
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
//TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil),
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
// Set up mocks for store initialisation for the authorized user.
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
)
}
func TestNewBridgeWithConnectedUser(t *testing.T) { func TestNewBridgeWithConnectedUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
// Set up mocks for store initialisation for the authorized user. mockConnectedUser(m)
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) mockEventLoopNoAction(m)
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress})
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil)
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
checkBridgeNew(t, m, []*credentials.Credentials{testCredentials}) checkBridgeNew(t, m, []*credentials.Credentials{testCredentials})
} }
@ -112,27 +125,22 @@ func TestNewBridgeWithUsers(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
m.credentialsStore.EXPECT().List().Return([]string{"user", "user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
// Set up mocks for store initialisation for the unauth user. // Set up mocks for store initialisation for the unauth user.
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")) m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().Addresses().Return(nil) m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
// Set up mocks for store initialisation for the authorized user. mockConnectedUser(m)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) mockEventLoopNoAction(m)
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress})
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil)
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials}) checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
} }
@ -141,9 +149,13 @@ func TestNewBridgeFirstStart(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(true) gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{}, nil) m.credentialsStore.EXPECT().List().Return([]string{}, nil),
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.FirstStart), gomock.Any()) m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(true),
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.FirstStart), gomock.Any()),
m.pmapiClient.EXPECT().Logout(),
)
testNewBridge(t, m) testNewBridge(t, m)
} }

View File

@ -18,8 +18,10 @@
package bridge package bridge
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"runtime/debug"
"testing" "testing"
"time" "time"
@ -37,6 +39,9 @@ import (
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
if os.Getenv("VERBOSITY") == "fatal" {
logrus.SetLevel(logrus.FatalLevel)
}
if os.Getenv("VERBOSITY") == "trace" { if os.Getenv("VERBOSITY") == "trace" {
logrus.SetLevel(logrus.TraceLevel) logrus.SetLevel(logrus.TraceLevel)
} }
@ -138,8 +143,27 @@ type mocks struct {
storeCache *store.Cache storeCache *store.Cache
} }
type fullStackReporter struct {
T testing.TB
}
func (fr *fullStackReporter) Errorf(format string, args ...interface{}) {
fmt.Printf("err: "+format+"\n", args...)
fr.T.Fail()
}
func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
debug.PrintStack()
fmt.Printf("fail: "+format+"\n", args...)
fr.T.FailNow()
}
func initMocks(t *testing.T) mocks { func initMocks(t *testing.T) mocks {
mockCtrl := gomock.NewController(t) var mockCtrl *gomock.Controller
if os.Getenv("VERBOSITY") == "trace" {
mockCtrl = gomock.NewController(&fullStackReporter{t})
} else {
mockCtrl = gomock.NewController(t)
}
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db") cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
require.NoError(t, err, "could not get temporary file for store cache") require.NoError(t, err, "could not get temporary file for store cache")
@ -172,35 +196,38 @@ func initMocks(t *testing.T) mocks {
} }
func testNewBridgeWithUsers(t *testing.T, m mocks) *Bridge { func testNewBridgeWithUsers(t *testing.T, m mocks) *Bridge {
// Events are asynchronous
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
// Init for user. // Init for user.
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) // TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) // TODO m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
// Init for users. // Init for users.
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) // TODO m.credentialsStore.EXPECT().UpdateToken("users", ":reftok").Return(nil),
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses) // TODO m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil).Times(2) m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
m.credentialsStore.EXPECT().UpdateToken("users", ":reftok").Return(nil) m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil) m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil) m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) )
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil)
return testNewBridge(t, m) return testNewBridge(t, m)
} }
@ -214,7 +241,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge {
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes() m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes() m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan *pmapi.ClientAuth)) m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan pmapi.ClientAuth))
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore) bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore)
@ -233,6 +260,9 @@ func TestClearData(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
bridge := testNewBridgeWithUsers(t, m) bridge := testNewBridgeWithUsers(t, m)
defer cleanUpBridgeUserData(bridge) defer cleanUpBridgeUserData(bridge)
@ -255,3 +285,14 @@ func TestClearData(t *testing.T) {
waitForEvents() waitForEvents()
} }
func mockEventLoopNoAction(m mocks) {
// Set up mocks for starting the store's event loop (in store.New).
// The event loop runs in another goroutine so this might happen at any time.
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil),
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil),
// Set up mocks for performing the initial store sync.
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil),
)
}

View File

@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -29,6 +30,9 @@ func TestGetNoUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkBridgeGetUser(t, m, "nouser", -1, "user nouser not found") checkBridgeGetUser(t, m, "nouser", -1, "user nouser not found")
} }
@ -36,6 +40,9 @@ func TestGetUserByID(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkBridgeGetUser(t, m, "user", 0, "") checkBridgeGetUser(t, m, "user", 0, "")
checkBridgeGetUser(t, m, "users", 1, "") checkBridgeGetUser(t, m, "users", 1, "")
} }
@ -44,6 +51,9 @@ func TestGetUserByName(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkBridgeGetUser(t, m, "username", 0, "") checkBridgeGetUser(t, m, "username", 0, "")
checkBridgeGetUser(t, m, "usersname", 1, "") checkBridgeGetUser(t, m, "usersname", 1, "")
} }
@ -52,6 +62,9 @@ func TestGetUserByEmail(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkBridgeGetUser(t, m, "user@pm.me", 0, "") checkBridgeGetUser(t, m, "user@pm.me", 0, "")
checkBridgeGetUser(t, m, "users@pm.me", 1, "") checkBridgeGetUser(t, m, "users@pm.me", 1, "")
checkBridgeGetUser(t, m, "anotheruser@pm.me", 1, "") checkBridgeGetUser(t, m, "anotheruser@pm.me", 1, "")
@ -62,14 +75,18 @@ func TestDeleteUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
bridge := testNewBridgeWithUsers(t, m) bridge := testNewBridgeWithUsers(t, m)
defer cleanUpBridgeUserData(bridge) defer cleanUpBridgeUserData(bridge)
m.pmapiClient.EXPECT().Logout().Return(nil) gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
@ -83,14 +100,20 @@ func TestDeleteUserWithFailingLogout(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
bridge := testNewBridgeWithUsers(t, m) bridge := testNewBridgeWithUsers(t, m)
defer cleanUpBridgeUserData(bridge) defer cleanUpBridgeUserData(bridge)
m.pmapiClient.EXPECT().Logout().Return(nil) gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")) m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil).Times(2) m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
//TODO m.credentialsStore.EXPECT().Delete("user").Return(errors.New("no such user")),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")

View File

@ -5,11 +5,10 @@
package mocks package mocks
import ( import (
reflect "reflect"
credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials" credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
) )
// MockConfiger is a mock of Configer interface // MockConfiger is a mock of Configer interface
@ -272,10 +271,10 @@ func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
return ret0 return ret0
} }
// GetAuthUpdateChannel indicates an expected call of GetBridgeAuthChannel // GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel
func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call { func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel))
} }
// GetClient mocks base method // GetClient mocks base method

View File

@ -409,7 +409,7 @@ func (u *User) GetBridgePassword() string {
} }
// CheckBridgeLogin checks whether the user is logged in and the bridge // CheckBridgeLogin checks whether the user is logged in and the bridge
// password is correct. // IMAP/SMTP password is correct.
func (u *User) CheckBridgeLogin(password string) error { func (u *User) CheckBridgeLogin(password string) error {
if isApplicationOutdated { if isApplicationOutdated {
u.listener.Emit(events.UpgradeApplicationEvent, "") u.listener.Emit(events.UpgradeApplicationEvent, "")

View File

@ -34,16 +34,24 @@ func TestUpdateUser(t *testing.T) {
user := testNewUser(m) user := testNewUser(m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) gomock.InOrder(
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil) m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}) m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
)
assert.NoError(t, user.UpdateUser()) assert.NoError(t, user.UpdateUser())
@ -105,9 +113,12 @@ func TestLogoutUser(t *testing.T) {
user := testNewUserForLogout(m) user := testNewUserForLogout(m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
m.pmapiClient.EXPECT().Logout().Return(nil) gomock.InOrder(
m.credentialsStore.EXPECT().Logout("user").Return(nil) m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout() err := user.Logout()
@ -124,10 +135,12 @@ func TestLogoutUserFailsLogout(t *testing.T) {
user := testNewUserForLogout(m) user := testNewUserForLogout(m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
m.pmapiClient.EXPECT().Logout().Return(nil) gomock.InOrder(
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")) m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Delete("user").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout() err := user.Logout()
@ -135,15 +148,20 @@ func TestLogoutUserFailsLogout(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestCheckBridgeLogin(t *testing.T) { func TestCheckBridgeLoginOK(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) gomock.InOrder(
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) // TODO why u.HasAPIAuth() = false
// TODO why not :reftoken
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword) err := user.CheckBridgeLogin(testCredentials.BridgePassword)
@ -162,11 +180,12 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
isApplicationOutdated = true isApplicationOutdated = true
err := user.CheckBridgeLogin("any-pass") err := user.CheckBridgeLogin("any-pass")
waitForEvents() waitForEvents()
isApplicationOutdated = false
assert.Equal(t, pmapi.ErrUpgradeApplication, err) assert.Equal(t, pmapi.ErrUpgradeApplication, err)
isApplicationOutdated = false
} }
func TestCheckBridgeLoginLoggedOut(t *testing.T) { func TestCheckBridgeLoginLoggedOut(t *testing.T) {
@ -174,19 +193,29 @@ func TestCheckBridgeLoginLoggedOut(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient)
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
m.pmapiClient.EXPECT().Addresses().Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) user, err := newUser(
_ = user.init(nil) m.PanicHandler, "user",
m.eventListener, m.credentialsStore,
m.clientManager, m.storeCache, "/tmp",
)
assert.NoError(t, err)
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
err = user.init(nil)
assert.Error(t, err)
defer cleanUpUserData(user) defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
err := user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword) err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
waitForEvents() waitForEvents()
assert.Equal(t, "bridge account is logged out, use bridge to login again", err.Error()) assert.Equal(t, "bridge account is logged out, use bridge to login again", err.Error())
@ -199,8 +228,13 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) {
user := testNewUser(m) user := testNewUser(m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) gomock.InOrder(
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) // TODO why u.HasAPIAuth() = false
// TODO why not :reftoken
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin("wrong!") err := user.CheckBridgeLogin("wrong!")
waitForEvents() waitForEvents()

View File

@ -41,12 +41,16 @@ func TestNewUserBridgeOutdated(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().Logout("user").Return(nil).AnyTimes()
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication).AnyTimes() gomock.InOrder(
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "").AnyTimes() m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().Addresses().Return(nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication),
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""),
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
checkNewUser(m) checkNewUser(m)
} }
@ -55,13 +59,18 @@ func TestNewUserNoInternetConnection(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable).AnyTimes()
m.eventListener.EXPECT().Emit(events.InternetOffEvent, "").AnyTimes()
m.pmapiClient.EXPECT().Addresses().Return(nil) gomock.InOrder(
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrAPINotReachable) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().GetEvent("").Return(nil, pmapi.ErrAPINotReachable).AnyTimes() m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable),
m.eventListener.EXPECT().Emit(events.InternetOffEvent, ""),
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrAPINotReachable),
m.pmapiClient.EXPECT().Addresses().Return(nil),
m.pmapiClient.EXPECT().GetEvent("").Return(nil, pmapi.ErrAPINotReachable).AnyTimes(),
)
checkNewUser(m) checkNewUser(m)
} }
@ -70,17 +79,22 @@ func TestNewUserAuthRefreshFails(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")).AnyTimes()
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout().Return(nil)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserDisconnected(m) checkNewUserDisconnected(m)
} }
@ -88,21 +102,25 @@ func TestNewUserUnlockFails(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password"))
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout().Return(nil)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
// TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserDisconnected(m) checkNewUserDisconnected(m)
} }
@ -110,22 +128,26 @@ func TestNewUserUnlockAddressesFails(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(errors.New("bad password"))
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout().Return(nil)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
// TODO m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(errors.New("bad password")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserDisconnected(m) checkNewUserDisconnected(m)
} }
@ -133,21 +155,9 @@ func TestNewUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) mockConnectedUser(m)
mockEventLoopNoAction(m)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil)
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress})
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil)
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil)
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
checkNewUser(m) checkNewUser(m)
} }

View File

@ -27,21 +27,15 @@ import (
) )
func testNewUser(m mocks) *User { func testNewUser(m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) mockConnectedUser(m)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil)
// Expectations for initial sync (when loading existing user from credentials store). gomock.InOrder(
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}) m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).AnyTimes() )
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err) assert.NoError(m.t, err)
@ -53,21 +47,15 @@ func testNewUser(m mocks) *User {
} }
func testNewUserForLogout(m mocks) *User { func testNewUserForLogout(m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) mockConnectedUser(m)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil)
// These may or may not be hit depending on how fast the log out happens. gomock.InOrder(
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}).AnyTimes() m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).AnyTimes() )
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err) assert.NoError(m.t, err)

View File

@ -424,8 +424,7 @@ func (c *client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) { func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
// If we don't yet have a saved access token, save this one in case the refresh fails! // If we don't yet have a saved access token, save this one in case the refresh fails!
// That way we can try again later (see handleUnauthorizedStatus). // That way we can try again later (see handleUnauthorizedStatus).
// TODO: c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
// c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
split := strings.Split(uidAndRefreshToken, ":") split := strings.Split(uidAndRefreshToken, ":")
if len(split) != 2 { if len(split) != 2 {

View File

@ -33,8 +33,6 @@ import (
r "github.com/stretchr/testify/require" r "github.com/stretchr/testify/require"
) )
var aLongTimeAgo = time.Unix(233431200, 0)
var testIdentity = &pmcrypto.Identity{ var testIdentity = &pmcrypto.Identity{
Name: "UserID", Name: "UserID",
Email: "", Email: "",
@ -276,10 +274,11 @@ func TestClient_AuthRefresh(t *testing.T) {
auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken) auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken)
Ok(t, err) Ok(t, err)
Equals(t, testUID, c.uid)
exp := &Auth{} exp := &Auth{}
*exp = *testAuth *exp = *testAuth
exp.uid = "" // AuthRefresh will not return UID (only Auth returns the UID). exp.uid = testUID // AuthRefresh will not return UID (only Auth returns the UID) we should set testUID to be able to generate token, see `GetToken`
exp.accessToken = testAccessToken exp.accessToken = testAccessToken
exp.KeySalt = "" exp.KeySalt = ""
exp.EventID = "" exp.EventID = ""

View File

@ -3,6 +3,7 @@ package pmapi
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@ -10,8 +11,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
var defaultProxyUseDuration = 24 * time.Hour
// ClientManager is a manager of clients. // ClientManager is a manager of clients.
type ClientManager struct { type ClientManager struct {
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to // newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
@ -21,9 +20,6 @@ type ClientManager struct {
config *ClientConfig config *ClientConfig
roundTripper http.RoundTripper roundTripper http.RoundTripper
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
// But that screws up other things like not being able to clear sensitive info during logout
// unless the client interface contains a method for that.
clients map[string]Client clients map[string]Client
clientsLocker sync.Locker clientsLocker sync.Locker
@ -33,17 +29,19 @@ type ClientManager struct {
expirations map[string]*tokenExpiration expirations map[string]*tokenExpiration
expirationsLocker sync.Locker expirationsLocker sync.Locker
host, scheme string
hostLocker sync.Locker
bridgeAuths chan ClientAuth bridgeAuths chan ClientAuth
clientAuths chan ClientAuth clientAuths chan ClientAuth
host, scheme string
hostLocker sync.RWMutex
allowProxy bool allowProxy bool
proxyProvider *proxyProvider proxyProvider *proxyProvider
proxyUseDuration time.Duration proxyUseDuration time.Duration
idGen idGen idGen idGen
log *logrus.Entry
} }
type idGen int type idGen int
@ -81,14 +79,16 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
expirationsLocker: &sync.Mutex{}, expirationsLocker: &sync.Mutex{},
host: RootURL, host: RootURL,
scheme: RootScheme, scheme: rootScheme,
hostLocker: &sync.Mutex{}, hostLocker: sync.RWMutex{},
bridgeAuths: make(chan ClientAuth), bridgeAuths: make(chan ClientAuth),
clientAuths: make(chan ClientAuth), clientAuths: make(chan ClientAuth),
proxyProvider: newProxyProvider(dohProviders, proxyQuery), proxyProvider: newProxyProvider(dohProviders, proxyQuery),
proxyUseDuration: defaultProxyUseDuration, proxyUseDuration: proxyUseDuration,
log: logrus.WithField("pkg", "pmapi-manager"),
} }
cm.newClient = func(userID string) Client { cm.newClient = func(userID string) Client {
@ -97,7 +97,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
go cm.forwardClientAuths() go cm.forwardClientAuths()
return return cm
} }
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) { func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
@ -140,20 +140,20 @@ func (cm *ClientManager) LogoutClient(userID string) {
delete(cm.clients, userID) delete(cm.clients, userID)
go func() { go func() {
if !strings.HasPrefix(userID, "anonymous-") {
if err := client.DeleteAuth(); err != nil { if err := client.DeleteAuth(); err != nil {
// TODO: Retry if the request failed. // TODO: Retry if the request failed.
} }
}
client.ClearData() client.ClearData()
cm.clearToken(userID) cm.clearToken(userID)
}() }()
return
} }
// GetRootURL returns the full root URL (scheme+host). // GetRootURL returns the full root URL (scheme+host).
func (cm *ClientManager) GetRootURL() string { func (cm *ClientManager) GetRootURL() string {
cm.hostLocker.Lock() cm.hostLocker.RLock()
defer cm.hostLocker.Unlock() defer cm.hostLocker.RUnlock()
return fmt.Sprintf("%v://%v", cm.scheme, cm.host) return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
} }
@ -161,24 +161,16 @@ func (cm *ClientManager) GetRootURL() string {
// getHost returns the host to make requests to. // getHost returns the host to make requests to.
// It does not include the protocol i.e. no "https://" (use getScheme for that). // It does not include the protocol i.e. no "https://" (use getScheme for that).
func (cm *ClientManager) getHost() string { func (cm *ClientManager) getHost() string {
cm.hostLocker.Lock() cm.hostLocker.RLock()
defer cm.hostLocker.Unlock() defer cm.hostLocker.RUnlock()
return cm.host return cm.host
} }
// getScheme returns the scheme with which to make requests to the host.
func (cm *ClientManager) getScheme() string {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
return cm.scheme
}
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be. // IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
func (cm *ClientManager) IsProxyAllowed() bool { func (cm *ClientManager) IsProxyAllowed() bool {
cm.hostLocker.Lock() cm.hostLocker.RLock()
defer cm.hostLocker.Unlock() defer cm.hostLocker.RUnlock()
return cm.allowProxy return cm.allowProxy
} }
@ -202,8 +194,8 @@ func (cm *ClientManager) DisallowProxy() {
// IsProxyEnabled returns whether we are currently proxying requests. // IsProxyEnabled returns whether we are currently proxying requests.
func (cm *ClientManager) IsProxyEnabled() bool { func (cm *ClientManager) IsProxyEnabled() bool {
cm.hostLocker.Lock() cm.hostLocker.RLock()
defer cm.hostLocker.Unlock() defer cm.hostLocker.RUnlock()
return cm.host != RootURL return cm.host != RootURL
} }
@ -264,6 +256,21 @@ func (cm *ClientManager) forwardClientAuths() {
} }
} }
// SetTokenIfUnset sets the token for the given userID if it wasn't already set.
// The token does not expire.
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
if _, ok := cm.tokens[userID]; ok {
return
}
logrus.WithField("userID", userID).Info("Setting token because it is currently unset")
cm.tokens[userID] = token
}
// setToken sets the token for the given userID with the given expiration time. // setToken sets the token for the given userID with the given expiration time.
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) { func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
cm.tokensLocker.Lock() cm.tokensLocker.Lock()
@ -275,6 +282,7 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
cm.setTokenExpiration(userID, expiration) cm.setTokenExpiration(userID, expiration)
// TODO: This should be one go routine per all tokens.
go cm.watchTokenExpiration(userID) go cm.watchTokenExpiration(userID)
} }
@ -311,7 +319,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
// If we aren't managing this client, there's nothing to do. // If we aren't managing this client, there's nothing to do.
if _, ok := cm.clients[ca.UserID]; !ok { if _, ok := cm.clients[ca.UserID]; !ok {
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client") logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client")
return return
} }
@ -332,8 +340,12 @@ func (cm *ClientManager) watchTokenExpiration(userID string) {
select { select {
case <-expiration.timer.C: case <-expiration.timer.C:
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing") cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing")
cm.clients[userID].AuthRefresh(cm.tokens[userID]) if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil {
cm.log.WithField("userID", userID).
WithError(err).
Error("Token refresh failed before expiration")
}
case <-expiration.cancel: case <-expiration.cancel:
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired") logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")

View File

@ -27,10 +27,9 @@ import (
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod. // This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
// Default is pmapi_prod. // Default is pmapi_prod.
// //
// It should not contain the protocol! The protocol should be in RootScheme. // It must not contain the protocol! The protocol should be in rootScheme.
var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals] var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
var rootScheme = "https" //nolint[gochecknoglobals]
var RootScheme = "https"
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program // CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
// version and email client. // version and email client.

View File

@ -21,5 +21,5 @@ package pmapi
func init() { func init() {
RootURL = "dev.protonmail.com/api" RootURL = "dev.protonmail.com/api"
RootScheme = "https" rootScheme = "https"
} }

View File

@ -28,7 +28,7 @@ func init() {
// Use port above 1000 which doesn't need root access to start anything on it. // Use port above 1000 which doesn't need root access to start anything on it.
// Now the port is rounded pi. :-) // Now the port is rounded pi. :-)
RootURL = "127.0.0.1:3142/api" RootURL = "127.0.0.1:3142/api"
RootScheme = "http" rootScheme = "http"
// TLS certificate is self-signed // TLS certificate is self-signed
defaultTransport = &http.Transport{ defaultTransport = &http.Transport{

View File

@ -41,7 +41,7 @@ func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) {
func TestTLSPinValid(t *testing.T) { func TestTLSPinValid(t *testing.T) {
cm := NewClientManager(testLiveConfig) cm := NewClientManager(testLiveConfig)
cm.host = liveAPI cm.host = liveAPI
RootScheme = "https" rootScheme = "https"
called, _ := setTestDialerWithPinning(cm) called, _ := setTestDialerWithPinning(cm)
client := cm.GetClient("pmapi" + t.Name()) client := cm.GetClient("pmapi" + t.Name())

View File

@ -5,12 +5,11 @@
package mocks package mocks
import ( import (
io "io"
reflect "reflect"
crypto "github.com/ProtonMail/gopenpgp/crypto" crypto "github.com/ProtonMail/gopenpgp/crypto"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
io "io"
reflect "reflect"
) )
// MockClient is a mock of Client interface // MockClient is a mock of Client interface
@ -110,6 +109,18 @@ func (mr *MockClientMockRecorder) AuthRefresh(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRefresh", reflect.TypeOf((*MockClient)(nil).AuthRefresh), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRefresh", reflect.TypeOf((*MockClient)(nil).AuthRefresh), arg0)
} }
// ClearData mocks base method
func (m *MockClient) ClearData() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ClearData")
}
// ClearData indicates an expected call of ClearData
func (mr *MockClientMockRecorder) ClearData() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearData", reflect.TypeOf((*MockClient)(nil).ClearData))
}
// CountMessages mocks base method // CountMessages mocks base method
func (m *MockClient) CountMessages(arg0 string) ([]*pmapi.MessagesCount, error) { func (m *MockClient) CountMessages(arg0 string) ([]*pmapi.MessagesCount, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -214,6 +225,20 @@ func (mr *MockClientMockRecorder) DeleteAttachment(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAttachment", reflect.TypeOf((*MockClient)(nil).DeleteAttachment), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAttachment", reflect.TypeOf((*MockClient)(nil).DeleteAttachment), arg0)
} }
// DeleteAuth mocks base method
func (m *MockClient) DeleteAuth() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuth")
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAuth indicates an expected call of DeleteAuth
func (mr *MockClientMockRecorder) DeleteAuth() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuth", reflect.TypeOf((*MockClient)(nil).DeleteAuth))
}
// DeleteLabel mocks base method // DeleteLabel mocks base method
func (m *MockClient) DeleteLabel(arg0 string) error { func (m *MockClient) DeleteLabel(arg0 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -30,7 +30,7 @@ import (
) )
const ( const (
proxyRevertTime = 24 * time.Hour proxyUseDuration = 24 * time.Hour
proxySearchTimeout = 30 * time.Second proxySearchTimeout = 30 * time.Second
proxyQueryTimeout = 10 * time.Second proxyQueryTimeout = 10 * time.Second
proxyLookupWait = 5 * time.Second proxyLookupWait = 5 * time.Second

View File

@ -27,7 +27,6 @@ import (
// NewRequest creates a new request. // NewRequest creates a new request.
func (c *client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) { func (c *client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
// TODO: Support other protocols (localhost needs http not https).
req, err = http.NewRequest(method, c.cm.GetRootURL()+path, body) req, err = http.NewRequest(method, c.cm.GetRootURL()+path, body)
if req != nil { if req != nil {

View File

@ -180,9 +180,10 @@ func hasAPIAuth(accountName string) error {
if err != nil { if err != nil {
return internalError(err, "getting user %s", account.Username()) return internalError(err, "getting user %s", account.Username())
} }
a.Eventually(ctx.GetTestingT(), func() bool { a.Eventually(ctx.GetTestingT(),
return bridgeUser.HasAPIAuth() bridgeUser.HasAPIAuth,
}, 5*time.Second, 10*time.Millisecond) 5*time.Second, 10*time.Millisecond,
)
return ctx.GetTestingError() return ctx.GetTestingError()
} }

View File

@ -24,7 +24,6 @@ import (
"github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
) )
// GetBridge returns bridge instance. // GetBridge returns bridge instance.
@ -60,7 +59,7 @@ func newBridgeInstance(
cfg *fakeConfig, cfg *fakeConfig,
credStore bridge.CredentialsStorer, credStore bridge.CredentialsStorer,
eventListener listener.Listener, eventListener listener.Listener,
clientManager *pmapi.ClientManager, clientManager bridge.ClientManager,
) *bridge.Bridge { ) *bridge.Bridge {
version := os.Getenv("VERSION") version := os.Getenv("VERSION")
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "") bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")

View File

@ -148,7 +148,9 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) {
} }
func (api *FakePMAPI) Logout() { func (api *FakePMAPI) Logout() {
api.DeleteAuth() if err := api.DeleteAuth(); err != nil {
api.log.WithError(err).Error("delete auth failed during logout")
}
api.ClearData() api.ClearData()
} }