diff --git a/Makefile b/Makefile index 5e17da8a..ca6d34b4 100644 --- a/Makefile +++ b/Makefile @@ -149,7 +149,7 @@ coverage: test go tool cover -html=/tmp/coverage.out -o=coverage.html mocks: - mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,PMAPIProvider,CredentialsStorer > internal/bridge/mocks/mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,ClientManager,PMAPIProvider,CredentialsStorer > internal/bridge/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser > internal/store/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 05091336..4f9b96ed 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -19,7 +19,6 @@ package bridge import ( - "errors" "strconv" "strings" "sync" @@ -33,6 +32,7 @@ import ( "github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/hashicorp/go-multierror" + "github.com/pkg/errors" logrus "github.com/sirupsen/logrus" ) @@ -48,7 +48,7 @@ type Bridge struct { panicHandler PanicHandler events listener.Listener version string - clientManager *pmapi.ClientManager + clientManager ClientManager credStorer CredentialsStorer storeCache *store.Cache @@ -76,7 +76,7 @@ func New( panicHandler PanicHandler, eventListener listener.Listener, version string, - clientManager *pmapi.ClientManager, + clientManager ClientManager, credStorer CredentialsStorer, ) *Bridge { log.Trace("Creating new bridge") @@ -107,7 +107,7 @@ func New( go func() { defer panicHandler.HandlePanic() - b.watchUserAuths() + b.watchAPIAuths() }() if b.credStorer == nil { @@ -178,16 +178,21 @@ func (b *Bridge) watchBridgeOutdated() { } } -// watchUserAuths receives auths from the client manager and sends them to the appropriate user. -func (b *Bridge) watchUserAuths() { +// watchAPIAuths receives auths from the client manager and sends them to the appropriate user. +func (b *Bridge) watchAPIAuths() { for auth := range b.clientManager.GetBridgeAuthChannel() { logrus.Debug("Bridge received auth from ClientManager") - if user, ok := b.hasUser(auth.UserID); ok { - logrus.Debug("Bridge is forwarding auth to user") - user.AuthorizeWithAPIAuth(auth.Auth) - } else { + user, ok := b.hasUser(auth.UserID) + if !ok { logrus.Info("User is not added to bridge yet") + continue + } + + if auth.Auth != nil { + user.updateAuthToken(auth.Auth) + } else { + user.logout() } } } @@ -209,113 +214,140 @@ func (b *Bridge) closeAllConnections() { // * In case user `auth.HasMailboxPassword()`, ask for it, otherwise use `password` // and then finish the login procedure. // user, err := bridge.FinishLogin(client, auth, mailboxPassword) -func (b *Bridge) Login(username, password string) (loginClient PMAPIProvider, auth *pmapi.Auth, err error) { - log.WithField("username", username).Trace("Logging in to bridge") - +func (b *Bridge) Login(username, password string) (authClient PMAPIProvider, auth *pmapi.Auth, err error) { b.crashBandicoot(username) - // We need to use "login" client because we need userID to properly assign access tokens into token manager. - loginClient = b.clientManager.GetClient("login") + // We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet. + authClient = b.clientManager.GetAnonymousClient() - authInfo, err := loginClient.AuthInfo(username) + authInfo, err := authClient.AuthInfo(username) if err != nil { log.WithField("username", username).WithError(err).Error("Could not get auth info for user") - return nil, nil, err + return } - if auth, err = loginClient.Auth(username, password, authInfo); err != nil { + if auth, err = authClient.Auth(username, password, authInfo); err != nil { log.WithField("username", username).WithError(err).Error("Could not get auth for user") - return loginClient, auth, err + return } - return loginClient, auth, nil + return } // FinishLogin finishes the login procedure and adds the user into the credentials store. // See `Login` for more details of the login flow. -func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen] - log.Trace("Finishing bridge login") - +func (b *Bridge) FinishLogin(authClient PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen] defer func() { if err == pmapi.ErrUpgradeApplication { b.events.Emit(events.UpgradeApplicationEvent, "") } }() - b.lock.Lock() - defer b.lock.Unlock() + apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword) + if err != nil { + log.WithError(err).Error("Failed to get API user") + return + } - defer loginClient.Logout() + if user, err = b.GetUser(apiUser.ID); err == nil { + if err = b.connectExistingUser(user, auth, hashedPassword); err != nil { + log.WithError(err).Error("Failed to connect existing user") + return + } + } else { + if err = b.addNewUser(apiUser, auth, hashedPassword); err != nil { + log.WithError(err).Error("Failed to add new user") + return + } + } - mbPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt) + b.events.Emit(events.UserRefreshEvent, apiUser.ID) + + return b.GetUser(apiUser.ID) +} + +// connectExistingUser connects an existing bridge user to the bridge. +func (b *Bridge) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassword string) (err error) { + if user.IsConnected() { + return errors.New("user is already connected") + } + + // Update the user's password in the cred store in case they changed it. + if err = b.credStorer.UpdatePassword(user.ID(), hashedPassword); err != nil { + return errors.Wrap(err, "failed to update password of user in credentials store") + } + + client := b.clientManager.GetClient(user.ID()) + + if auth, err = client.AuthRefresh(auth.GenToken()); err != nil { + return errors.Wrap(err, "failed to refresh auth token of new client") + } + + if err = b.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil { + return errors.Wrap(err, "failed to update token of user in credentials store") + } + + if err = user.init(b.idleUpdates); err != nil { + return errors.Wrap(err, "failed to initialise user") + } + + return +} + +// addNewUser adds a new bridge user to the bridge. +func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) { + client := b.clientManager.GetClient(user.ID) + + if auth, err = client.AuthRefresh(auth.GenToken()); err != nil { + return errors.Wrap(err, "failed to refresh token in new client") + } + + if user, err = client.UpdateUser(); err != nil { + return errors.Wrap(err, "failed to update API user") + } + + activeEmails := client.Addresses().ActiveEmails() + + if _, err = b.credStorer.Add(user.ID, user.Name, auth.GenToken(), hashedPassword, activeEmails); err != nil { + return errors.Wrap(err, "failed to add user to credentials store") + } + + bridgeUser, err := newUser(b.panicHandler, user.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir()) + if err != nil { + return errors.Wrap(err, "failed to create user") + } + + // The user needs to be part of the users list in order for it to receive an auth during initialisation. + // TODO: If adding the user fails, we don't want to leave it there. + b.users = append(b.users, bridgeUser) + + if err = bridgeUser.init(b.idleUpdates); err != nil { + return errors.Wrap(err, "failed to initialise user") + } + + b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel)) + + return +} + +func getAPIUser(client PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) { + hashedPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt) if err != nil { log.WithError(err).Error("Could not hash mailbox password") return } - if _, err = loginClient.Unlock(mbPassword); err != nil { + if _, err = client.Unlock(hashedPassword); err != nil { log.WithError(err).Error("Could not decrypt keyring") return } - apiUser, err := loginClient.CurrentUser() - if err != nil { - log.WithError(err).Error("Could not get login API user") + if user, err = client.UpdateUser(); err != nil { + log.WithError(err).Error("Could not update API user") return } - user, hasUser := b.hasUser(apiUser.ID) - - // If the user exists and is logged in, we don't want to do anything. - if hasUser && user.IsConnected() { - err = errors.New("user is already logged in") - log.WithError(err).Warn("User is already logged in") - return - } - - apiClient := b.clientManager.GetClient(apiUser.ID) - auth, err = apiClient.AuthRefresh(auth.GenToken()) - if err != nil { - log.WithError(err).Error("Could not refresh token in new client") - return - } - - // We load the current user again because it should now have addresses loaded. - apiUser, err = apiClient.CurrentUser() - if err != nil { - log.WithError(err).Error("Could not get current API user") - return - } - - activeEmails := apiClient.Addresses().ActiveEmails() - if _, err = b.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), mbPassword, activeEmails); err != nil { - log.WithError(err).Error("Could not add user to credentials store") - return - } - - // If it's a new user, generate the user object. - if !hasUser { - user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir()) - if err != nil { - log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user") - return - } - b.users = append(b.users, user) - } - - // Set up the user auth and store (which we do for both new and existing users). - if err = user.init(b.idleUpdates); err != nil { - log.WithField("user", user.userID).WithError(err).Error("Could not initialise user") - return - } - - if !hasUser { - b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel)) - } - - b.events.Emit(events.UserRefreshEvent, apiUser.ID) - - return user, err + return } // GetUsers returns all added users into keychain (even logged out users). @@ -326,8 +358,7 @@ func (b *Bridge) GetUsers() []*User { return b.users } -// GetUser returns a user by `query` which is compared to users' ID, username -// or any attached e-mail address. +// GetUser returns a user by `query` which is compared to users' ID, username or any attached e-mail address. func (b *Bridge) GetUser(query string) (*User, error) { b.crashBandicoot(query) diff --git a/internal/bridge/bridge_new_test.go b/internal/bridge/bridge_new_test.go index b4c02f89..ec6214d1 100644 --- a/internal/bridge/bridge_new_test.go +++ b/internal/bridge/bridge_new_test.go @@ -73,7 +73,6 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) { m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.pmapiClient.EXPECT().Logout().Return(nil) - m.pmapiClient.EXPECT().SetAuths(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index d636094c..786024df 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -129,6 +129,7 @@ type mocks struct { PanicHandler *bridgemocks.MockPanicHandler prefProvider *bridgemocks.MockPreferenceProvider pmapiClient *bridgemocks.MockPMAPIProvider + clientManager *bridgemocks.MockClientManager credentialsStore *bridgemocks.MockCredentialsStorer eventListener *MockListener @@ -208,14 +209,10 @@ func testNewBridge(t *testing.T, m mocks) *Bridge { m.prefProvider.EXPECT().GetBool(preferences.AllowProxyKey).Return(false).AnyTimes() m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes() m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes() - m.pmapiClient.EXPECT().SetAuths(gomock.Any()).AnyTimes() m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) - pmapiClientFactory := func(userID string) PMAPIProvider { - log.WithField("userID", userID).Info("Creating new pmclient") - return m.pmapiClient - } + m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient) - bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", pmapiClientFactory, m.credentialsStore) + bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore) waitForEvents() diff --git a/internal/bridge/credentials/store.go b/internal/bridge/credentials/store.go index f42d3ce3..3bfc795f 100644 --- a/internal/bridge/credentials/store.go +++ b/internal/bridge/credentials/store.go @@ -125,6 +125,20 @@ func (s *Store) UpdateEmails(userID string, emails []string) error { return s.saveCredentials(credentials) } +func (s *Store) UpdatePassword(userID, password string) error { + storeLocker.Lock() + defer storeLocker.Unlock() + + credentials, err := s.get(userID) + if err != nil { + return err + } + + credentials.MailboxPassword = password + + return s.saveCredentials(credentials) +} + func (s *Store) UpdateToken(userID, apiToken string) error { storeLocker.Lock() defer storeLocker.Unlock() diff --git a/internal/bridge/mocks/mocks.go b/internal/bridge/mocks/mocks.go index 62e7998e..c88099d9 100644 --- a/internal/bridge/mocks/mocks.go +++ b/internal/bridge/mocks/mocks.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,PMAPIProvider,CredentialsStorer) +// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,PMAPIProvider,CredentialsStorer) // Package mocks is a generated GoMock package. package mocks @@ -205,6 +205,95 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) } +// MockClientManager is a mock of ClientManager interface +type MockClientManager struct { + ctrl *gomock.Controller + recorder *MockClientManagerMockRecorder +} + +// MockClientManagerMockRecorder is the mock recorder for MockClientManager +type MockClientManagerMockRecorder struct { + mock *MockClientManager +} + +// NewMockClientManager creates a new mock instance +func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager { + mock := &MockClientManager{ctrl: ctrl} + mock.recorder = &MockClientManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder { + return m.recorder +} + +// AllowProxy mocks base method +func (m *MockClientManager) AllowProxy() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AllowProxy") +} + +// AllowProxy indicates an expected call of AllowProxy +func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy)) +} + +// DisallowProxy mocks base method +func (m *MockClientManager) DisallowProxy() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisallowProxy") +} + +// DisallowProxy indicates an expected call of DisallowProxy +func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy)) +} + +// GetAnonymousClient mocks base method +func (m *MockClientManager) GetAnonymousClient() *pmapi.Client { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAnonymousClient") + ret0, _ := ret[0].(*pmapi.Client) + return ret0 +} + +// GetAnonymousClient indicates an expected call of GetAnonymousClient +func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient)) +} + +// GetBridgeAuthChannel mocks base method +func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBridgeAuthChannel") + ret0, _ := ret[0].(chan pmapi.ClientAuth) + return ret0 +} + +// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel +func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel)) +} + +// GetClient mocks base method +func (m *MockClientManager) GetClient(arg0 string) *pmapi.Client { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClient", arg0) + ret0, _ := ret[0].(*pmapi.Client) + return ret0 +} + +// GetClient indicates an expected call of GetClient +func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0) +} + // MockPMAPIProvider is a mock of PMAPIProvider interface type MockPMAPIProvider struct { ctrl *gomock.Controller @@ -615,11 +704,9 @@ func (mr *MockPMAPIProviderMockRecorder) ListMessages(arg0 interface{}) *gomock. } // Logout mocks base method -func (m *MockPMAPIProvider) Logout() error { +func (m *MockPMAPIProvider) Logout() { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Logout") - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.Call(m, "Logout") } // Logout indicates an expected call of Logout @@ -700,18 +787,6 @@ func (mr *MockPMAPIProviderMockRecorder) SendSimpleMetric(arg0, arg1, arg2 inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSimpleMetric", reflect.TypeOf((*MockPMAPIProvider)(nil).SendSimpleMetric), arg0, arg1, arg2) } -// SetAuths mocks base method -func (m *MockPMAPIProvider) SetAuths(arg0 chan<- *pmapi.Auth) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetAuths", arg0) -} - -// SetAuths indicates an expected call of SetAuths -func (mr *MockPMAPIProviderMockRecorder) SetAuths(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAuths", reflect.TypeOf((*MockPMAPIProvider)(nil).SetAuths), arg0) -} - // UnlabelMessages mocks base method func (m *MockPMAPIProvider) UnlabelMessages(arg0 []string, arg1 string) error { m.ctrl.T.Helper() @@ -909,6 +984,20 @@ func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEmails", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateEmails), arg0, arg1) } +// UpdatePassword mocks base method +func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdatePassword indicates an expected call of UpdatePassword +func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdatePassword), arg0, arg1) +} + // UpdateToken mocks base method func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error { m.ctrl.T.Helper() diff --git a/internal/bridge/types.go b/internal/bridge/types.go index e74dd8ec..5e002b1d 100644 --- a/internal/bridge/types.go +++ b/internal/bridge/types.go @@ -43,7 +43,12 @@ type PanicHandler interface { HandlePanic() } -type Clientman interface { +type ClientManager interface { + GetClient(userID string) *pmapi.Client + GetAnonymousClient() *pmapi.Client + GetBridgeAuthChannel() chan pmapi.ClientAuth + AllowProxy() + DisallowProxy() } type PMAPIProvider interface { @@ -100,6 +105,7 @@ type CredentialsStorer interface { Get(userID string) (*credentials.Credentials, error) SwitchAddressMode(userID string) error UpdateEmails(userID string, emails []string) error + UpdatePassword(userID, password string) error UpdateToken(userID, apiToken string) error Logout(userID string) error Delete(userID string) error diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 60e3b290..426324be 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -41,7 +41,7 @@ type User struct { log *logrus.Entry panicHandler PanicHandler listener listener.Listener - clientMan *pmapi.ClientManager + clientMan ClientManager credStorer CredentialsStorer imapUpdatesChannel chan interface{} @@ -66,7 +66,7 @@ func newUser( userID string, eventListener listener.Listener, credStorer CredentialsStorer, - clientMan *pmapi.ClientManager, + clientMan ClientManager, storeCache *store.Cache, storeDir string, ) (u *User, err error) { @@ -243,20 +243,9 @@ func (u *User) authorizeAndUnlock() (err error) { return nil } -func (u *User) AuthorizeWithAPIAuth(auth *pmapi.Auth) { - u.lock.Lock() - defer u.lock.Unlock() - +func (u *User) updateAuthToken(auth *pmapi.Auth) { u.log.Debug("User received auth from bridge") - if auth == nil { - if err := u.logout(); err != nil { - u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API") - } - u.isAuthorized = false - return - } - if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil { u.log.WithError(err).Error("Failed to update refresh token in credentials store") return @@ -510,11 +499,16 @@ func (u *User) logout() error { u.lock.Lock() wasConnected := u.creds.IsConnected() u.lock.Unlock() + err := u.Logout() + if wasConnected { u.listener.Emit(events.LogoutEvent, u.userID) u.listener.Emit(events.UserRefreshEvent, u.userID) } + + u.isAuthorized = false + return err } @@ -534,6 +528,7 @@ func (u *User) Logout() (err error) { u.wasKeyringUnlocked = false u.unlockingKeyringLock.Unlock() + // TODO: Is this necessary or could it be done by ClientManager when a nil auth is received? u.client().Logout() if err = u.credStorer.Logout(u.userID); err != nil { @@ -550,6 +545,7 @@ func (u *User) Logout() (err error) { u.closeEventLoop() u.closeAllConnections() + runtime.GC() return err @@ -557,7 +553,7 @@ func (u *User) Logout() (err error) { func (u *User) refreshFromCredentials() { if credentials, err := u.credStorer.Get(u.userID); err != nil { - log.Error("Cannot update credentials: ", err) + log.WithError(err).Error("Cannot refresh user credentials") } else { u.creds = credentials } diff --git a/internal/bridge/user_credentials_test.go b/internal/bridge/user_credentials_test.go index 04cfe322..688b43e1 100644 --- a/internal/bridge/user_credentials_test.go +++ b/internal/bridge/user_credentials_test.go @@ -174,13 +174,13 @@ func TestCheckBridgeLoginLoggedOut(t *testing.T) { defer m.ctrl.Finish() m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) - user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") + m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient) + user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")) m.pmapiClient.EXPECT().Addresses().Return(nil) - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) - _ = user.init(nil, m.pmapiClient) + _ = user.init(nil) defer cleanUpUserData(user) diff --git a/internal/bridge/user_new_test.go b/internal/bridge/user_new_test.go index 0c038a2e..1067ab30 100644 --- a/internal/bridge/user_new_test.go +++ b/internal/bridge/user_new_test.go @@ -44,7 +44,6 @@ func TestNewUserBridgeOutdated(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.credentialsStore.EXPECT().Logout("user").Return(nil).AnyTimes() m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication).AnyTimes() - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "").AnyTimes() m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication) m.pmapiClient.EXPECT().Addresses().Return(nil) @@ -58,7 +57,6 @@ func TestNewUserNoInternetConnection(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable).AnyTimes() - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.eventListener.EXPECT().Emit(events.InternetOffEvent, "").AnyTimes() m.pmapiClient.EXPECT().Addresses().Return(nil) @@ -75,12 +73,10 @@ func TestNewUserAuthRefreshFails(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.credentialsStore.EXPECT().Logout("user").Return(nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")).AnyTimes() - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.pmapiClient.EXPECT().Logout().Return(nil) - m.pmapiClient.EXPECT().SetAuths(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") @@ -96,7 +92,6 @@ func TestNewUserUnlockFails(t *testing.T) { m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password")) @@ -104,7 +99,6 @@ func TestNewUserUnlockFails(t *testing.T) { m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.pmapiClient.EXPECT().Logout().Return(nil) - m.pmapiClient.EXPECT().SetAuths(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") @@ -120,7 +114,6 @@ func TestNewUserUnlockAddressesFails(t *testing.T) { m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) @@ -129,7 +122,6 @@ func TestNewUserUnlockAddressesFails(t *testing.T) { m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.pmapiClient.EXPECT().Logout().Return(nil) - m.pmapiClient.EXPECT().SetAuths(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") @@ -144,7 +136,6 @@ func TestNewUser(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index d13dcfeb..092f3642 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -30,7 +30,6 @@ func testNewUser(m mocks) *User { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) @@ -57,14 +56,12 @@ func testNewUserForLogout(m mocks) *User { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - m.pmapiClient.EXPECT().SetAuths(gomock.Any()) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil) m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil) // These may or may not be hit depending on how fast the log out happens. - m.pmapiClient.EXPECT().SetAuths(nil).AnyTimes() m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil).AnyTimes() m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}).AnyTimes() m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil) diff --git a/internal/frontend/qt/frontend.go b/internal/frontend/qt/frontend.go index 00b76c3c..4b4ca5ab 100644 --- a/internal/frontend/qt/frontend.go +++ b/internal/frontend/qt/frontend.go @@ -527,11 +527,11 @@ func (s *FrontendQt) toggleAllowProxy() { if s.preferences.GetBool(preferences.AllowProxyKey) { s.preferences.SetBool(preferences.AllowProxyKey, false) - bridge.DisallowDoH() + s.bridge.DisallowProxy() s.Qml.SetIsProxyAllowed(false) } else { s.preferences.SetBool(preferences.AllowProxyKey, true) - bridge.AllowDoH() + s.bridge.AllowProxy() s.Qml.SetIsProxyAllowed(true) } } diff --git a/internal/pmapifactory/pmapi_prod.go b/internal/pmapifactory/pmapi_prod.go index f4db4c11..3b1c5dd2 100644 --- a/internal/pmapifactory/pmapi_prod.go +++ b/internal/pmapifactory/pmapi_prod.go @@ -39,7 +39,7 @@ func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig { } func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) { - logrus.Info("Setting dialer with pinning") + logrus.Info("Setting ClientManager to create clients with key pinning") pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion) @@ -47,5 +47,5 @@ func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, lis listener.Emit(events.TLSCertIssue, "") } - cm.SetClientRoundTripper(pin.TransportWithPinning()) + cm.SetRoundTripper(pin.TransportWithPinning()) } diff --git a/internal/store/mailbox_counts.go b/internal/store/mailbox_counts.go index 21bef9a7..956e0e69 100644 --- a/internal/store/mailbox_counts.go +++ b/internal/store/mailbox_counts.go @@ -217,7 +217,7 @@ func (store *Store) txGetOnAPICounts(tx *bolt.Tx) ([]*mailboxCounts, error) { // createOrUpdateOnAPICounts will change only on-API-counts. func (store *Store) createOrUpdateOnAPICounts(mailboxCountsOnAPI []*pmapi.MessagesCount) error { - store.log.WithField("apiCounts", mailboxCountsOnAPI).Debug("Updating API counts") + store.log.Debug("Updating API counts") tx := func(tx *bolt.Tx) error { countsBkt := tx.Bucket(countsBucket) diff --git a/pkg/config/config.go b/pkg/config/config.go index 1e88721a..7968e88d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -19,12 +19,9 @@ package config import ( "io/ioutil" - "net" - "net/http" "os" "path/filepath" "strings" - "time" "github.com/ProtonMail/go-appdir" "github.com/ProtonMail/proton-bridge/pkg/pmapi" @@ -73,11 +70,7 @@ func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirs apiConfig: &pmapi.ClientConfig{ AppVersion: strings.Title(appName) + "_" + version, ClientID: appName, - Transport: &http.Transport{ - DialContext: (&net.Dialer{Timeout: 3 * time.Second}).DialContext, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - }, + SentryDSN: "https://bacfb56338a7471a9fede610046afdda:ab437b0d13f54602a0f5feb684e6d319@api.protonmail.ch/reports/sentry/8", }, } } diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index 38f62df1..867a052c 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -118,11 +118,18 @@ type Auth struct { TwoFA *TwoFactorInfo `json:"2FA,omitempty"` } +// UID returns the session UID from the Auth. +// Only Auths generated from the /auth route will have the UID. +// Auths generated from /auth/refresh are not required to. func (s *Auth) UID() string { return s.uid } func (s *Auth) GenToken() string { + if s == nil { + return "" + } + return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken) } @@ -147,7 +154,9 @@ type AuthRes struct { AccessToken string TokenType string - UID string + + // UID is the session UID. This is only present in an initial Auth (/auth), not in a refreshed Auth (/auth/refresh). + UID string ServerProof string } @@ -196,18 +205,21 @@ type AuthRefreshReq struct { } func (c *Client) sendAuth(auth *Auth) { - go func() { - c.log.Debug("Client is sending auth to ClientManager") + c.log.Debug("Client is sending auth to ClientManager") + if auth != nil { + // UID is only provided in the initial /auth, not during /auth/refresh + if auth.UID() != "" { + c.uid = auth.UID() + } + c.accessToken = auth.accessToken + } + + go func() { c.cm.getClientAuthChannel() <- ClientAuth{ UserID: c.userID, Auth: auth, } - - if auth != nil { - c.uid = auth.UID() - c.accessToken = auth.accessToken - } }() } @@ -446,6 +458,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) } auth = res.getAuth() + c.sendAuth(auth) return auth, err @@ -456,6 +469,8 @@ func (c *Client) Logout() { c.cm.LogoutClient(c.userID) } +// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not. + // logout logs the current user out. func (c *Client) logout() (err error) { req, err := c.NewRequest("DELETE", "/auth", nil) diff --git a/pkg/pmapi/auth_test.go b/pkg/pmapi/auth_test.go index 05ad840c..995b2c39 100644 --- a/pkg/pmapi/auth_test.go +++ b/pkg/pmapi/auth_test.go @@ -279,6 +279,7 @@ func TestClient_AuthRefresh(t *testing.T) { exp := &Auth{} *exp = *testAuth + exp.uid = "" // AuthRefresh will not return UID (only Auth returns the UID). exp.accessToken = testAccessToken exp.KeySalt = "" exp.EventID = "" @@ -329,12 +330,20 @@ func TestClient_Logout(t *testing.T) { }, ) defer finish() + c.uid = testUID c.accessToken = testAccessToken c.Logout() - // TODO: Check that the client is logged out and sensitive data is cleared eventually. + r.Eventually(t, func() bool { + // TODO: Use a method like IsConnected() which returns whether the client was logged out or not. + return c.accessToken == "" && + c.uid == "" && + c.kr == nil && + c.addresses == nil && + c.user == nil + }, 10*time.Second, 10*time.Millisecond) } func TestClient_DoUnauthorized(t *testing.T) { @@ -354,7 +363,6 @@ func TestClient_DoUnauthorized(t *testing.T) { c.uid = testUID c.accessToken = testAccessTokenOld - c.expiresAt = aLongTimeAgo c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken req, err := c.NewRequest("GET", "/", nil) diff --git a/pkg/pmapi/client.go b/pkg/pmapi/client.go index 3e86e1ce..8a46b50c 100644 --- a/pkg/pmapi/client.go +++ b/pkg/pmapi/client.go @@ -79,11 +79,6 @@ type ClientConfig struct { // The sentry DSN. SentryDSN string - // Transport specifies the mechanism by which individual HTTP requests are made. - // If nil, http.DefaultTransport is used. - // TODO: This could be removed entirely and set in the client manager via SetClientRoundTripper. - Transport http.RoundTripper - // Timeout specifies the timeout from request to getting response headers to our API. // Passed to http.Client, empty means no timeout. Timeout time.Duration @@ -120,7 +115,7 @@ type Client struct { func newClient(cm *ClientManager, userID string) *Client { return &Client{ cm: cm, - hc: getHTTPClient(cm.GetConfig()), + hc: getHTTPClient(cm.GetConfig(), cm.GetRoundTripper()), userID: userID, requestLocker: &sync.Mutex{}, keyLocker: &sync.Mutex{}, @@ -128,32 +123,12 @@ func newClient(cm *ClientManager, userID string) *Client { } } -// getHTTPClient returns a http client configured by the given client config. -func getHTTPClient(cfg *ClientConfig) (hc *http.Client) { - hc = &http.Client{Timeout: cfg.Timeout} - - if cfg.Transport == nil { - if defaultTransport != nil { - hc.Transport = defaultTransport - } - return +// getHTTPClient returns a http client configured by the given client config and using the given transport. +func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper) (hc *http.Client) { + return &http.Client{ + Timeout: cfg.Timeout, + Transport: rt, } - - // In future use Clone here. - // https://go-review.googlesource.com/c/go/+/174597/ - if cfgTransport, ok := cfg.Transport.(*http.Transport); ok { - transport := &http.Transport{} - *transport = *cfgTransport //nolint - if transport.Proxy == nil { - transport.Proxy = http.ProxyFromEnvironment - } - hc.Transport = transport - return - } - - hc.Transport = cfg.Transport - - return hc } // Do makes an API request. It does not check for HTTP status code errors. diff --git a/pkg/pmapi/client_test.go b/pkg/pmapi/client_test.go index b110ba39..6e3c5050 100644 --- a/pkg/pmapi/client_test.go +++ b/pkg/pmapi/client_test.go @@ -36,9 +36,8 @@ var testClientConfig = &ClientConfig{ MinSpeed: 256, } -func newTestClient() *Client { - c := newClient(NewClientManager(testClientConfig), "tester") - return c +func newTestClient(cm *ClientManager) *Client { + return cm.GetClient("tester") } func TestClient_Do(t *testing.T) { diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index dbc0d91c..be2dc46c 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -1,6 +1,7 @@ package pmapi import ( + "fmt" "net/http" "sync" "time" @@ -10,27 +11,31 @@ import ( "github.com/sirupsen/logrus" ) -var proxyUseDuration = 24 * time.Hour +var defaultProxyUseDuration = 24 * time.Hour // ClientManager is a manager of clients. type ClientManager struct { - config *ClientConfig + config *ClientConfig + roundTripper http.RoundTripper clients map[string]*Client clientsLocker sync.Locker - tokens map[string]string - tokenExpirations map[string]*tokenExpiration - tokensLocker sync.Locker + tokens map[string]string + tokensLocker sync.Locker - url string - urlLocker sync.Locker + expirations map[string]*tokenExpiration + expirationsLocker sync.Locker + + host, scheme string + hostLocker sync.Locker bridgeAuths chan ClientAuth clientAuths chan ClientAuth - allowProxy bool - proxyProvider *proxyProvider + allowProxy bool + proxyProvider *proxyProvider + proxyUseDuration time.Duration } type ClientAuth struct { @@ -50,22 +55,27 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { } cm = &ClientManager{ - config: config, + config: config, + roundTripper: http.DefaultTransport, clients: make(map[string]*Client), clientsLocker: &sync.Mutex{}, - tokens: make(map[string]string), - tokenExpirations: make(map[string]*tokenExpiration), - tokensLocker: &sync.Mutex{}, + tokens: make(map[string]string), + tokensLocker: &sync.Mutex{}, - url: RootURL, - urlLocker: &sync.Mutex{}, + expirations: make(map[string]*tokenExpiration), + expirationsLocker: &sync.Mutex{}, + + host: RootURL, + scheme: RootScheme, + hostLocker: &sync.Mutex{}, bridgeAuths: make(chan ClientAuth), clientAuths: make(chan ClientAuth), - proxyProvider: newProxyProvider(dohProviders, proxyQuery), + proxyProvider: newProxyProvider(dohProviders, proxyQuery), + proxyUseDuration: defaultProxyUseDuration, } go cm.forwardClientAuths() @@ -73,10 +83,14 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { return } -// SetClientRoundTripper sets the roundtripper used by clients created by this client manager. -func (cm *ClientManager) SetClientRoundTripper(rt http.RoundTripper) { - logrus.Info("Setting client roundtripper") - cm.config.Transport = rt +// SetRoundTripper sets the roundtripper used by clients created by this client manager. +func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) { + cm.roundTripper = rt +} + +// GetRoundTripper sets the roundtripper used by clients created by this client manager. +func (cm *ClientManager) GetRoundTripper() (rt http.RoundTripper) { + return cm.roundTripper } // GetClient returns a client for the given userID. @@ -91,6 +105,17 @@ func (cm *ClientManager) GetClient(userID string) *Client { return cm.clients[userID] } +// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created. +func (cm *ClientManager) GetAnonymousClient() *Client { + if client, ok := cm.clients[""]; ok { + client.Logout() + } + + cm.clients[""] = newClient(cm, "") + + return cm.clients[""] +} + // LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared. func (cm *ClientManager) LogoutClient(userID string) { client, ok := cm.clients[userID] @@ -104,7 +129,6 @@ func (cm *ClientManager) LogoutClient(userID string) { go func() { if err := client.logout(); err != nil { // TODO: Try again! This should loop until it succeeds (might fail the first time due to internet). - logrus.WithError(err).Error("Client logout failed, not trying again") } client.clearSensitiveData() cm.clearToken(userID) @@ -113,52 +137,69 @@ func (cm *ClientManager) LogoutClient(userID string) { return } -// GetRootURL returns the root URL to make requests to. -// It does not include the protocol i.e. no "https://". -func (cm *ClientManager) GetRootURL() string { - cm.urlLocker.Lock() - defer cm.urlLocker.Unlock() +// GetHost returns the host to make requests to. +// It does not include the protocol i.e. no "https://" (use GetScheme for that). +func (cm *ClientManager) GetHost() string { + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() - return cm.url + return cm.host +} + +// GetScheme returns the scheme with which to make requests to the host. +func (cm *ClientManager) GetScheme() string { + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() + + return cm.scheme +} + +// GetRootURL returns the full root URL (scheme+host). +func (cm *ClientManager) GetRootURL() string { + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() + + return fmt.Sprintf("%v://%v", cm.scheme, cm.host) } // IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be. func (cm *ClientManager) IsProxyAllowed() bool { - cm.urlLocker.Lock() - defer cm.urlLocker.Unlock() + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() return cm.allowProxy } // AllowProxy allows the client manager to switch clients over to a proxy if need be. func (cm *ClientManager) AllowProxy() { - cm.urlLocker.Lock() - defer cm.urlLocker.Unlock() + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() cm.allowProxy = true } // DisallowProxy prevents the client manager from switching clients over to a proxy if need be. func (cm *ClientManager) DisallowProxy() { - cm.urlLocker.Lock() - defer cm.urlLocker.Unlock() + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() cm.allowProxy = false - cm.url = RootURL + cm.host = RootURL } // IsProxyEnabled returns whether we are currently proxying requests. func (cm *ClientManager) IsProxyEnabled() bool { - cm.urlLocker.Lock() - defer cm.urlLocker.Unlock() + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() - return cm.url != RootURL + return cm.host != RootURL } -// FindProxy returns a usable proxy server. +// SwitchToProxy returns a usable proxy server. +// TODO: Perhaps the name could be better -- we aren't only switching to a proxy but also to the standard API. func (cm *ClientManager) SwitchToProxy() (proxy string, err error) { - cm.urlLocker.Lock() - defer cm.urlLocker.Unlock() + cm.hostLocker.Lock() + defer cm.hostLocker.Unlock() logrus.Info("Attempting to switch to a proxy") @@ -169,9 +210,16 @@ func (cm *ClientManager) SwitchToProxy() (proxy string, err error) { logrus.WithField("proxy", proxy).Info("Switching to a proxy") - cm.url = proxy + // If the host is currently the RootURL, it's the first time we are enabling a proxy. + // This means we want to disable it again in 24 hours. + if cm.host == RootURL { + go func() { + <-time.After(cm.proxyUseDuration) + cm.host = RootURL + }() + } - // TODO: Disable again after 24 hours. + cm.host = proxy return } @@ -183,6 +231,9 @@ func (cm *ClientManager) GetConfig() *ClientConfig { // GetToken returns the token for the given userID. func (cm *ClientManager) GetToken(userID string) string { + cm.tokensLocker.Lock() + defer cm.tokensLocker.Unlock() + return cm.tokens[userID] } @@ -208,6 +259,11 @@ func (cm *ClientManager) forwardClientAuths() { // setToken sets the token for the given userID with the given expiration time. func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) { + // We don't want to set tokens of anonymous clients. + if userID == "" { + return + } + cm.tokensLocker.Lock() defer cm.tokensLocker.Unlock() @@ -221,12 +277,15 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration // setTokenExpiration will ensure the token is refreshed if it expires. // If the token already has an expiration time set, it is replaced. func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) { - if exp, ok := cm.tokenExpirations[userID]; ok { + cm.expirationsLocker.Lock() + defer cm.expirationsLocker.Unlock() + + if exp, ok := cm.expirations[userID]; ok { exp.timer.Stop() close(exp.cancel) } - cm.tokenExpirations[userID] = &tokenExpiration{ + cm.expirations[userID] = &tokenExpiration{ timer: time.NewTimer(expiration), cancel: make(chan struct{}), } @@ -262,7 +321,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) { } func (cm *ClientManager) watchTokenExpiration(userID string) { - expiration := cm.tokenExpirations[userID] + expiration := cm.expirations[userID] select { case <-expiration.timer.C: @@ -270,6 +329,6 @@ func (cm *ClientManager) watchTokenExpiration(userID string) { cm.clients[userID].AuthRefresh(cm.tokens[userID]) case <-expiration.cancel: - logrus.WithField("userID", userID).Info("Auth was refreshed before it expired, cancelling this watcher") + logrus.WithField("userID", userID).Info("Auth was refreshed before it expired") } } diff --git a/pkg/pmapi/config.go b/pkg/pmapi/config.go index 7e1bb421..7cb1c7a1 100644 --- a/pkg/pmapi/config.go +++ b/pkg/pmapi/config.go @@ -26,8 +26,12 @@ import ( // // This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod. // Default is pmapi_prod. +// +// It should not contain the protocol! The protocol should be in RootScheme. var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals] +var RootScheme = "https" + // CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program // version and email client. // e.g. Bridge/1.0.4 (Windows) MicrosoftOutlook/16.0.9330.2087 diff --git a/pkg/pmapi/config_dev.go b/pkg/pmapi/config_dev.go index 77a2e6c6..42c627ea 100644 --- a/pkg/pmapi/config_dev.go +++ b/pkg/pmapi/config_dev.go @@ -21,4 +21,5 @@ package pmapi func init() { RootURL = "dev.protonmail.com/api" + RootScheme = "https" } diff --git a/pkg/pmapi/config_local.go b/pkg/pmapi/config_local.go index 9544e4f2..c952ba76 100644 --- a/pkg/pmapi/config_local.go +++ b/pkg/pmapi/config_local.go @@ -28,6 +28,7 @@ func init() { // Use port above 1000 which doesn't need root access to start anything on it. // Now the port is rounded pi. :-) RootURL = "127.0.0.1:3142/api" + RootScheme = "http" // TLS certificate is self-signed defaultTransport = &http.Transport{ diff --git a/pkg/pmapi/contacts_test.go b/pkg/pmapi/contacts_test.go index 1aae9367..524a400b 100644 --- a/pkg/pmapi/contacts_test.go +++ b/pkg/pmapi/contacts_test.go @@ -654,7 +654,7 @@ var testCardsCleartext = []Card{ } func TestClient_Encrypt(t *testing.T) { - c := newTestClient() + c := newTestClient(NewClientManager(testClientConfig)) c.kr = testPrivateKeyRing cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext) @@ -668,7 +668,7 @@ func TestClient_Encrypt(t *testing.T) { } func TestClient_Decrypt(t *testing.T) { - c := newTestClient() + c := newTestClient(NewClientManager(testClientConfig)) c.kr = testPrivateKeyRing cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted) diff --git a/pkg/pmapi/dialer_with_proxy.go b/pkg/pmapi/dialer_with_proxy.go index 503a2c26..9a3c7eee 100644 --- a/pkg/pmapi/dialer_with_proxy.go +++ b/pkg/pmapi/dialer_with_proxy.go @@ -297,7 +297,7 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn // If DoH is not allowed, give up. Or, if we are dialing something other than the API // (e.g. we dial protonmail.com/... to check for updates), there's also no point in // continuing since a proxy won't help us reach that. - if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() { + if !p.cm.IsProxyAllowed() || host != p.cm.GetHost() { p.log.WithField("address", address).Debug("Aborting dial, cannot switch to a proxy") return } diff --git a/pkg/pmapi/dialer_with_proxy_test.go b/pkg/pmapi/dialer_with_proxy_test.go index bd26d0e4..78a90a21 100644 --- a/pkg/pmapi/dialer_with_proxy_test.go +++ b/pkg/pmapi/dialer_with_proxy_test.go @@ -23,26 +23,27 @@ import ( "testing" ) -const liveAPI = "https://api.protonmail.ch" +const liveAPI = "api.protonmail.ch" var testLiveConfig = &ClientConfig{ AppVersion: "Bridge_1.2.4-test", ClientID: "Bridge", } -func newTestDialerWithPinning() (*int, *DialerWithPinning) { +func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) { called := 0 - p := NewPMAPIPinning(testLiveConfig.AppVersion) + p := NewDialerWithPinning(cm, testLiveConfig.AppVersion) p.ReportCertIssueLocal = func() { called++ } - testLiveConfig.Transport = p.TransportWithPinning() + cm.SetRoundTripper(p.TransportWithPinning()) return &called, p } func TestTLSPinValid(t *testing.T) { - called, _ := newTestDialerWithPinning() - - RootURL = liveAPI - client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name()) + cm := NewClientManager(testLiveConfig) + cm.host = liveAPI + RootScheme = "https" + called, _ := setTestDialerWithPinning(cm) + client := cm.GetClient("pmapi" + t.Name()) _, err := client.AuthInfo("this.address.is.disabled") Ok(t, err) @@ -51,12 +52,13 @@ func TestTLSPinValid(t *testing.T) { } func TestTLSPinBackup(t *testing.T) { - called, p := newTestDialerWithPinning() + cm := NewClientManager(testLiveConfig) + cm.host = liveAPI + called, p := setTestDialerWithPinning(cm) p.report.KnownPins[1] = p.report.KnownPins[0] p.report.KnownPins[0] = "" - RootURL = liveAPI - client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name()) + client := cm.GetClient("pmapi" + t.Name()) _, err := client.AuthInfo("this.address.is.disabled") Ok(t, err) @@ -65,19 +67,21 @@ func TestTLSPinBackup(t *testing.T) { } func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused] - called, p := newTestDialerWithPinning() + cm := NewClientManager(testLiveConfig) + cm.host = liveAPI + + called, p := setTestDialerWithPinning(cm) for i := 0; i < len(p.report.KnownPins); i++ { p.report.KnownPins[i] = "testing" } - RootURL = liveAPI - client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name()) + client := cm.GetClient("pmapi" + t.Name()) _, err := client.AuthInfo("this.address.is.disabled") Ok(t, err) // check that it will be called only once per session - client = newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name()) + client = cm.GetClient("pmapi" + t.Name()) _, err = client.AuthInfo("this.address.is.disabled") Ok(t, err) @@ -85,20 +89,22 @@ func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused] } func _TestTLSPinInvalid(t *testing.T) { // nolint[unused] + cm := NewClientManager(testLiveConfig) + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0) })) defer ts.Close() - called, _ := newTestDialerWithPinning() + called, _ := setTestDialerWithPinning(cm) - client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name()) + client := cm.GetClient("pmapi" + t.Name()) - RootURL = liveAPI + cm.host = liveAPI _, err := client.AuthInfo("this.address.is.disabled") Ok(t, err) - RootURL = ts.URL + cm.host = ts.URL _, err = client.AuthInfo("this.address.is.disabled") Assert(t, err != nil, "error is expected but have %v", err) @@ -106,20 +112,23 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused] } func _TestTLSSignedCertWrongPublicKey(t *testing.T) { // nolint[unused] - _, dialer := newTestDialerWithPinning() + cm := NewClientManager(testLiveConfig) + _, dialer := setTestDialerWithPinning(cm) _, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443") Assert(t, err != nil, "expected dial to fail because of wrong public key: ", err.Error()) } func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused] - _, dialer := newTestDialerWithPinning() + cm := NewClientManager(testLiveConfig) + _, dialer := setTestDialerWithPinning(cm) dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`) _, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443") Assert(t, err == nil, "expected dial to succeed because public key is known and cert is signed by CA: ", err.Error()) } func _TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused] - _, dialer := newTestDialerWithPinning() + cm := NewClientManager(testLiveConfig) + _, dialer := setTestDialerWithPinning(cm) dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`) _, err := dialer.dialAndCheckFingerprints("tcp", "self-signed.badssl.com:443") Assert(t, err == nil, "expected dial to succeed because public key is known despite cert being self-signed: ", err.Error()) diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go index 41d902eb..2af15940 100644 --- a/pkg/pmapi/proxy.go +++ b/pkg/pmapi/proxy.go @@ -72,8 +72,9 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { // return } -// findProxy returns a new proxy domain which is not equal to the current RootURL. +// findProxy returns a new working proxy domain. This includes the standard API. // It returns an error if the process takes longer than ProxySearchTime. +// TODO: Perhaps the name can be better -- we might also return the standard API. func (p *proxyProvider) findProxy() (proxy string, err error) { if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) { return "", errors.New("not looking for a proxy, too soon") @@ -88,6 +89,12 @@ func (p *proxyProvider) findProxy() (proxy string, err error) { logrus.WithError(err).Warn("Failed to refresh proxy cache, cache may be out of date") } + // We want to switch back to the RootURL if possible. + if p.canReach(RootURL) { + proxyResult <- RootURL + return + } + for _, proxy := range p.proxyCache { if p.canReach(proxy) { proxyResult <- proxy @@ -114,6 +121,7 @@ func (p *proxyProvider) findProxy() (proxy string, err error) { } // refreshProxyCache loads the latest proxies from the known providers. +// It includes the standard API. func (p *proxyProvider) refreshProxyCache() error { logrus.Info("Refreshing proxy cache") @@ -121,9 +129,6 @@ func (p *proxyProvider) refreshProxyCache() error { if proxies, err := p.dohLookup(p.query, provider); err == nil { p.proxyCache = proxies - // We also want to allow bridge to switch back to the standard API at any time. - p.proxyCache = append(p.proxyCache, RootURL) - logrus.WithField("proxies", proxies).Info("Available proxies") return nil diff --git a/pkg/pmapi/proxy_test.go b/pkg/pmapi/proxy_test.go index d14718c4..19a89c4d 100644 --- a/pkg/pmapi/proxy_test.go +++ b/pkg/pmapi/proxy_test.go @@ -122,23 +122,27 @@ func TestProxyProvider_UseProxy(t *testing.T) { blockAPI() defer unblockAPI() + cm := NewClientManager(testClientConfig) + proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() p := newProxyProvider([]string{"not used"}, "not used") + cm.proxyProvider = p + p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } - - url, err := p.findProxy() + url, err := cm.SwitchToProxy() require.NoError(t, err) - - p.useProxy(url) - require.Equal(t, proxy.URL, GlobalGetRootURL()) + require.Equal(t, proxy.URL, url) + require.Equal(t, proxy.URL, cm.GetHost()) } func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { blockAPI() defer unblockAPI() + cm := NewClientManager(testClientConfig) + proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy1.Close() proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -147,101 +151,106 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { defer proxy3.Close() p := newProxyProvider([]string{"not used"}, "not used") + cm.proxyProvider = p p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil } - url, err := p.findProxy() + url, err := cm.SwitchToProxy() require.NoError(t, err) - p.useProxy(url) - require.Equal(t, proxy1.URL, GlobalGetRootURL()) + require.Equal(t, proxy1.URL, url) + require.Equal(t, proxy1.URL, cm.GetHost()) // Have to wait so as to not get rejected. time.Sleep(proxyLookupWait) p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil } - url, err = p.findProxy() + url, err = cm.SwitchToProxy() require.NoError(t, err) - p.useProxy(url) - require.Equal(t, proxy2.URL, GlobalGetRootURL()) + require.Equal(t, proxy2.URL, url) + require.Equal(t, proxy2.URL, cm.GetHost()) // Have to wait so as to not get rejected. time.Sleep(proxyLookupWait) p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil } - url, err = p.findProxy() + url, err = cm.SwitchToProxy() require.NoError(t, err) - p.useProxy(url) - require.Equal(t, proxy3.URL, GlobalGetRootURL()) + require.Equal(t, proxy3.URL, url) + require.Equal(t, proxy3.URL, cm.GetHost()) } func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) { blockAPI() defer unblockAPI() + cm := NewClientManager(testClientConfig) + proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() p := newProxyProvider([]string{"not used"}, "not used") - p.useDuration = time.Second - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + cm.proxyProvider = p + cm.proxyUseDuration = time.Second - url, err := p.findProxy() + p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + url, err := cm.SwitchToProxy() require.NoError(t, err) require.Equal(t, proxy.URL, url) - - p.useProxy(url) - require.Equal(t, proxy.URL, GlobalGetRootURL()) + require.Equal(t, proxy.URL, cm.GetHost()) time.Sleep(2 * time.Second) - require.Equal(t, globalOriginalURL, GlobalGetRootURL()) + require.Equal(t, RootURL, cm.GetHost()) } func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) { - // Don't block the API here because we want it to be working so the test can find it. + blockAPI() defer unblockAPI() + cm := NewClientManager(testClientConfig) + proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + cm.proxyProvider = p - url, err := p.findProxy() + p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + url, err := cm.SwitchToProxy() require.NoError(t, err) require.Equal(t, proxy.URL, url) + require.Equal(t, proxy.URL, cm.GetHost()) - p.useProxy(url) - require.Equal(t, proxy.URL, GlobalGetRootURL()) - - // Simulate that the proxy stops working. + // Simulate that the proxy stops working and that the standard api is reachable again. proxy.Close() + unblockAPI() time.Sleep(proxyLookupWait) // We should now find the original API URL if it is working again. - url, err = p.findProxy() + url, err = cm.SwitchToProxy() require.NoError(t, err) - require.Equal(t, globalOriginalURL, url) - - p.useProxy(url) - require.Equal(t, globalOriginalURL, GlobalGetRootURL()) + require.Equal(t, RootURL, url) + require.Equal(t, RootURL, cm.GetHost()) } func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) { blockAPI() defer unblockAPI() + cm := NewClientManager(testClientConfig) + proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy1.Close() proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy2.Close() p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } + cm.proxyProvider = p // Find a proxy. - url, err := p.findProxy() + p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } + url, err := cm.SwitchToProxy() require.NoError(t, err) - p.useProxy(url) - require.Equal(t, proxy1.URL, GlobalGetRootURL()) + require.Equal(t, proxy1.URL, url) + require.Equal(t, proxy1.URL, cm.GetHost()) // Have to wait so as to not get rejected. time.Sleep(proxyLookupWait) @@ -250,10 +259,10 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl proxy1.Close() // Should switch to the second proxy because both the first proxy and the protonmail API are blocked. - url, err = p.findProxy() + url, err = cm.SwitchToProxy() require.NoError(t, err) - p.useProxy(url) - require.Equal(t, proxy2.URL, GlobalGetRootURL()) + require.Equal(t, proxy2.URL, url) + require.Equal(t, proxy2.URL, cm.GetHost()) } func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { @@ -289,16 +298,14 @@ func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) } // testAPIURLBackup is used to hold the globalOriginalURL because we clear it for test purposes and need to restore it. -var testAPIURLBackup = globalOriginalURL +var testAPIURLBackup = RootURL // blockAPI prevents tests from reaching the standard API, forcing them to find a proxy. func blockAPI() { - globalSetRootURL("") - globalOriginalURL = "" + RootURL = "" } // unblockAPI allow tests to reach the standard API again. func unblockAPI() { - globalOriginalURL = testAPIURLBackup - globalSetRootURL(globalOriginalURL) + RootURL = testAPIURLBackup } diff --git a/pkg/pmapi/req.go b/pkg/pmapi/req.go index b8d8a4f3..ce73bc8b 100644 --- a/pkg/pmapi/req.go +++ b/pkg/pmapi/req.go @@ -28,7 +28,8 @@ import ( // NewRequest creates a new request. func (c *Client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) { // TODO: Support other protocols (localhost needs http not https). - req, err = http.NewRequest(method, "https://"+c.cm.GetRootURL()+path, body) + req, err = http.NewRequest(method, c.cm.GetRootURL()+path, body) + if req != nil { req.Header.Set("User-Agent", CurrentUserAgent) } diff --git a/pkg/pmapi/sentry_test.go b/pkg/pmapi/sentry_test.go index 8591df3c..efe946bd 100644 --- a/pkg/pmapi/sentry_test.go +++ b/pkg/pmapi/sentry_test.go @@ -25,7 +25,8 @@ import ( ) func TestSentryCrashReport(t *testing.T) { - c := newClient(NewClientManager(testClientConfig), "bridgetest") + cm := NewClientManager(testClientConfig) + c := cm.GetClient("bridgetest") if err := c.ReportSentryCrash(errors.New("Testing crash report - api proxy; goroutines with threads, find origin")); err != nil { t.Fatal("Expected no error while report, but have", err) } diff --git a/pkg/pmapi/server_test.go b/pkg/pmapi/server_test.go index e6da92e8..a616b5f5 100644 --- a/pkg/pmapi/server_test.go +++ b/pkg/pmapi/server_test.go @@ -22,6 +22,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "reflect" @@ -72,15 +73,24 @@ func Equals(tb testing.TB, exp, act interface{}) { // newTestServer is old function and should be replaced everywhere by newTestServerCallbacks. func newTestServer(h http.Handler) (*httptest.Server, *Client) { s := httptest.NewServer(h) - RootURL = s.URL - return s, newTestClient() + serverURL, err := url.Parse(s.URL) + if err != nil { + panic(err) + } + + cm := NewClientManager(testClientConfig) + cm.host = serverURL.Host + cm.scheme = serverURL.Scheme + + return s, newTestClient(cm) } func newTestServerCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.ResponseWriter, *http.Request) string) (func(), *Client) { reqNum := 0 _, file, line, _ := runtime.Caller(1) file = filepath.Base(file) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reqNum++ if reqNum > len(callbacks) { @@ -95,7 +105,12 @@ func newTestServerCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.Re writeJSONResponsefromFile(tb, w, response, reqNum-1) } })) - RootURL = server.URL + + serverURL, err := url.Parse(server.URL) + if err != nil { + panic(err) + } + finish := func() { server.CloseClientConnections() // Closing without waiting for finishing requests. if reqNum != len(callbacks) { @@ -106,7 +121,12 @@ func newTestServerCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.Re tb.Error("server failed") } } - return finish, newTestClient() + + cm := NewClientManager(testClientConfig) + cm.host = serverURL.Host + cm.scheme = serverURL.Scheme + + return finish, newTestClient(cm) } func checkMethodAndPath(r *http.Request, method, path string) error { diff --git a/pkg/pmapi/testdata/routes/auth/refresh/post_resp_has_uid.json b/pkg/pmapi/testdata/routes/auth/refresh/post_resp_has_uid.json index cdf3fb50..dfe67223 100644 --- a/pkg/pmapi/testdata/routes/auth/refresh/post_resp_has_uid.json +++ b/pkg/pmapi/testdata/routes/auth/refresh/post_resp_has_uid.json @@ -3,8 +3,8 @@ "AccessToken": "de0423049b44243afeec7d9c1d99be7b46da1e8a", "ExpiresIn": 360000, "TokenType": "Bearer", - "Uid": "differentUID", - "UID": "differentUID", + "Uid": "729ad6012421d67ad26950dc898bebe3a6e3caa2", + "UID": "729ad6012421d67ad26950dc898bebe3a6e3caa2", "Scope": "full mail payments reset keys", "RefreshToken": "b894b4c4f20003f12d486900d8b88c7d68e67235" -} \ No newline at end of file +} diff --git a/pkg/pmapi/users.go b/pkg/pmapi/users.go index eae443fe..15e1e15b 100644 --- a/pkg/pmapi/users.go +++ b/pkg/pmapi/users.go @@ -121,7 +121,7 @@ func (c *Client) UpdateUser() (user *User, err error) { return user, err } -// CurrentUser return currently active user or user will be updated. +// CurrentUser returns currently active user or user will be updated. func (c *Client) CurrentUser() (user *User, err error) { if c.user != nil && len(c.addresses) != 0 { user = c.user diff --git a/test/context/bridge.go b/test/context/bridge.go index 699460ef..15b0be3d 100644 --- a/test/context/bridge.go +++ b/test/context/bridge.go @@ -21,10 +21,8 @@ import ( "os" "runtime" - "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/pkg/listener" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" ) // GetBridge returns bridge instance. @@ -35,7 +33,10 @@ func (ctx *TestContext) GetBridge() *bridge.Bridge { // withBridgeInstance creates a bridge instance for use in the test. // Every TestContext has this by default and thus this doesn't need to be exported. func (ctx *TestContext) withBridgeInstance() { - ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, ctx.clientManager) + pmapiFactory := func(userID string) bridge.PMAPIProvider { + return ctx.pmapiController.GetClient(userID) + } + ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, pmapiFactory) ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data") } @@ -60,7 +61,7 @@ func newBridgeInstance( cfg *fakeConfig, credStore bridge.CredentialsStorer, eventListener listener.Listener, - clientManager *pmapi.ClientManager, + pmapiFactory bridge.PMAPIProviderFactory, ) *bridge.Bridge { version := os.Getenv("VERSION") bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "") @@ -68,7 +69,7 @@ func newBridgeInstance( panicHandler := &panicHandler{t: t} pref := preferences.New(cfg) - return bridge.New(cfg, pref, panicHandler, eventListener, version, clientManager, credStore) + return bridge.New(cfg, pref, panicHandler, eventListener, version, pmapiFactory, credStore) } // SetLastBridgeError sets the last error that occurred while executing a bridge action. diff --git a/test/context/config.go b/test/context/config.go index c932b47d..8cb3838a 100644 --- a/test/context/config.go +++ b/test/context/config.go @@ -28,6 +28,7 @@ import ( type fakeConfig struct { dir string + tm *pmapi.TokenManager } // newFakeConfig creates a temporary folder for files. @@ -40,6 +41,7 @@ func newFakeConfig() *fakeConfig { return &fakeConfig{ dir: dir, + tm: pmapi.NewTokenManager(), } } @@ -51,6 +53,8 @@ func (c *fakeConfig) GetAPIConfig() *pmapi.ClientConfig { AppVersion: "Bridge_" + os.Getenv("VERSION"), ClientID: "bridge", SentryDSN: "", + // TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere. + TokenManager: c.tm, } } func (c *fakeConfig) GetDBDir() string { diff --git a/test/context/context.go b/test/context/context.go index d8f0cda3..1d788d69 100644 --- a/test/context/context.go +++ b/test/context/context.go @@ -21,7 +21,6 @@ package context import ( "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/pkg/listener" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/test/accounts" "github.com/ProtonMail/proton-bridge/test/mocks" "github.com/sirupsen/logrus" @@ -58,9 +57,6 @@ type TestContext struct { smtpClients map[string]*mocks.SMTPClient smtpLastResponses map[string]*mocks.SMTPResponse - // PMAPI related variables. - clientManager *pmapi.ClientManager - // These are the cleanup steps executed when Cleanup() is called. cleanupSteps []*Cleaner @@ -74,20 +70,17 @@ func New() *TestContext { cfg := newFakeConfig() - cm := pmapi.NewClientManager(cfg.GetAPIConfig()) - ctx := &TestContext{ t: &bddT{}, cfg: cfg, listener: listener.New(), - pmapiController: newPMAPIController(cm), + pmapiController: newPMAPIController(), testAccounts: newTestAccounts(), credStore: newFakeCredStore(), imapClients: make(map[string]*mocks.IMAPClient), imapLastResponses: make(map[string]*mocks.IMAPResponse), smtpClients: make(map[string]*mocks.SMTPClient), smtpLastResponses: make(map[string]*mocks.SMTPResponse), - clientManager: cm, logger: logrus.StandardLogger(), } diff --git a/test/context/credentials.go b/test/context/credentials.go index 0926b2c9..b9527d92 100644 --- a/test/context/credentials.go +++ b/test/context/credentials.go @@ -81,6 +81,15 @@ func (c *fakeCredStore) UpdateEmails(userID string, emails []string) error { return nil } +func (c *fakeCredStore) UpdatePassword(userID, password string) error { + creds, err := c.Get(userID) + if err != nil { + return err + } + creds.MailboxPassword = password + return nil +} + func (c *fakeCredStore) UpdateToken(userID, apiToken string) error { creds, err := c.Get(userID) if err != nil { diff --git a/test/context/pmapi_controller.go b/test/context/pmapi_controller.go index 1ad6ff6a..3c2b7469 100644 --- a/test/context/pmapi_controller.go +++ b/test/context/pmapi_controller.go @@ -40,12 +40,12 @@ type PMAPIController interface { GetCalls(method, path string) [][]byte } -func newPMAPIController(cm *pmapi.ClientManager) PMAPIController { +func newPMAPIController() PMAPIController { switch os.Getenv(EnvName) { case EnvFake: return newFakePMAPIController() case EnvLive: - return newLivePMAPIController(cm) + return newLivePMAPIController() default: panic("unknown env") } @@ -67,8 +67,8 @@ func (s *fakePMAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider return s.Controller.GetClient(userID) } -func newLivePMAPIController(cm *pmapi.ClientManager) PMAPIController { - return newLiveAPIControllerWrap(liveapi.NewController(cm)) +func newLivePMAPIController() PMAPIController { + return newLiveAPIControllerWrap(liveapi.NewController()) } type liveAPIControllerWrap struct { diff --git a/test/fakeapi/auth.go b/test/fakeapi/auth.go index 36e07dea..ed56b40f 100644 --- a/test/fakeapi/auth.go +++ b/test/fakeapi/auth.go @@ -141,12 +141,13 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) { return auth, nil } -func (api *FakePMAPI) Logout() { +func (api *FakePMAPI) Logout() error { if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil { - return + return err } // Logout will also emit change to auth channel api.sendAuth(nil) api.controller.deleteSession(api.uid) api.unsetUser() + return nil } diff --git a/test/liveapi/controller.go b/test/liveapi/controller.go index 97140a95..f78b4afe 100644 --- a/test/liveapi/controller.go +++ b/test/liveapi/controller.go @@ -18,7 +18,9 @@ package liveapi import ( + "fmt" "net/http" + "os" "sync" "github.com/ProtonMail/proton-bridge/pkg/pmapi" @@ -30,31 +32,31 @@ type Controller struct { calls []*fakeCall pmapiByUsername map[string]*pmapi.Client messageIDsByUsername map[string][]string - clientManager *pmapi.ClientManager // State controlled by test. noInternetConnection bool } -func NewController(cm *pmapi.ClientManager) *Controller { - cntrl := &Controller{ +func NewController() *Controller { + return &Controller{ lock: &sync.RWMutex{}, calls: []*fakeCall{}, pmapiByUsername: map[string]*pmapi.Client{}, messageIDsByUsername: map[string][]string{}, - clientManager: cm, noInternetConnection: false, } - - cntrl.clientManager.SetClientRoundTripper(&fakeTransport{ - cntrl: cntrl, - transport: http.DefaultTransport, - }) - - return cntrl } func (cntrl *Controller) GetClient(userID string) *pmapi.Client { - return cntrl.clientManager.GetClient(userID) + cfg := &pmapi.ClientConfig{ + AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")), + ClientID: "bridge-test", + Transport: &fakeTransport{ + cntrl: cntrl, + transport: http.DefaultTransport, + }, + TokenManager: pmapi.NewTokenManager(), + } + return pmapi.NewClient(cfg, userID) } diff --git a/test/liveapi/users.go b/test/liveapi/users.go index 7b3958f1..602f4adf 100644 --- a/test/liveapi/users.go +++ b/test/liveapi/users.go @@ -18,7 +18,9 @@ package liveapi import ( - "github.com/ProtonMail/proton-bridge/pkg/pmapi" + "fmt" + "os" + "github.com/cucumber/godog" "github.com/pkg/errors" ) @@ -28,7 +30,11 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, return godog.ErrPending } - client := cntrl.GetClient(user.ID) + client := pmapi.NewClient(&pmapi.ClientConfig{ + AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")), + ClientID: "bridge-cntrl", + TokenManager: pmapi.NewTokenManager(), + }, user.ID) authInfo, err := client.AuthInfo(user.Name) if err != nil { @@ -55,6 +61,5 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, } cntrl.pmapiByUsername[user.Name] = client - return nil }