forked from Silverfish/proton-bridge
feat: improve login flow
This commit is contained in:
@ -19,7 +19,6 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -33,6 +32,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pkg/errors"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -48,7 +48,7 @@ type Bridge struct {
|
||||
panicHandler PanicHandler
|
||||
events listener.Listener
|
||||
version string
|
||||
clientManager *pmapi.ClientManager
|
||||
clientManager ClientManager
|
||||
credStorer CredentialsStorer
|
||||
storeCache *store.Cache
|
||||
|
||||
@ -76,7 +76,7 @@ func New(
|
||||
panicHandler PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
version string,
|
||||
clientManager *pmapi.ClientManager,
|
||||
clientManager ClientManager,
|
||||
credStorer CredentialsStorer,
|
||||
) *Bridge {
|
||||
log.Trace("Creating new bridge")
|
||||
@ -107,7 +107,7 @@ func New(
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
b.watchUserAuths()
|
||||
b.watchAPIAuths()
|
||||
}()
|
||||
|
||||
if b.credStorer == nil {
|
||||
@ -178,16 +178,21 @@ func (b *Bridge) watchBridgeOutdated() {
|
||||
}
|
||||
}
|
||||
|
||||
// watchUserAuths receives auths from the client manager and sends them to the appropriate user.
|
||||
func (b *Bridge) watchUserAuths() {
|
||||
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
|
||||
func (b *Bridge) watchAPIAuths() {
|
||||
for auth := range b.clientManager.GetBridgeAuthChannel() {
|
||||
logrus.Debug("Bridge received auth from ClientManager")
|
||||
|
||||
if user, ok := b.hasUser(auth.UserID); ok {
|
||||
logrus.Debug("Bridge is forwarding auth to user")
|
||||
user.AuthorizeWithAPIAuth(auth.Auth)
|
||||
} else {
|
||||
user, ok := b.hasUser(auth.UserID)
|
||||
if !ok {
|
||||
logrus.Info("User is not added to bridge yet")
|
||||
continue
|
||||
}
|
||||
|
||||
if auth.Auth != nil {
|
||||
user.updateAuthToken(auth.Auth)
|
||||
} else {
|
||||
user.logout()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -209,113 +214,140 @@ func (b *Bridge) closeAllConnections() {
|
||||
// * In case user `auth.HasMailboxPassword()`, ask for it, otherwise use `password`
|
||||
// and then finish the login procedure.
|
||||
// user, err := bridge.FinishLogin(client, auth, mailboxPassword)
|
||||
func (b *Bridge) Login(username, password string) (loginClient PMAPIProvider, auth *pmapi.Auth, err error) {
|
||||
log.WithField("username", username).Trace("Logging in to bridge")
|
||||
|
||||
func (b *Bridge) Login(username, password string) (authClient PMAPIProvider, auth *pmapi.Auth, err error) {
|
||||
b.crashBandicoot(username)
|
||||
|
||||
// We need to use "login" client because we need userID to properly assign access tokens into token manager.
|
||||
loginClient = b.clientManager.GetClient("login")
|
||||
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet.
|
||||
authClient = b.clientManager.GetAnonymousClient()
|
||||
|
||||
authInfo, err := loginClient.AuthInfo(username)
|
||||
authInfo, err := authClient.AuthInfo(username)
|
||||
if err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
|
||||
return nil, nil, err
|
||||
return
|
||||
}
|
||||
|
||||
if auth, err = loginClient.Auth(username, password, authInfo); err != nil {
|
||||
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
|
||||
return loginClient, auth, err
|
||||
return
|
||||
}
|
||||
|
||||
return loginClient, auth, nil
|
||||
return
|
||||
}
|
||||
|
||||
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
||||
// See `Login` for more details of the login flow.
|
||||
func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen]
|
||||
log.Trace("Finishing bridge login")
|
||||
|
||||
func (b *Bridge) FinishLogin(authClient PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen]
|
||||
defer func() {
|
||||
if err == pmapi.ErrUpgradeApplication {
|
||||
b.events.Emit(events.UpgradeApplicationEvent, "")
|
||||
}
|
||||
}()
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to get API user")
|
||||
return
|
||||
}
|
||||
|
||||
defer loginClient.Logout()
|
||||
if user, err = b.GetUser(apiUser.ID); err == nil {
|
||||
if err = b.connectExistingUser(user, auth, hashedPassword); err != nil {
|
||||
log.WithError(err).Error("Failed to connect existing user")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err = b.addNewUser(apiUser, auth, hashedPassword); err != nil {
|
||||
log.WithError(err).Error("Failed to add new user")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mbPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
|
||||
b.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||
|
||||
return b.GetUser(apiUser.ID)
|
||||
}
|
||||
|
||||
// connectExistingUser connects an existing bridge user to the bridge.
|
||||
func (b *Bridge) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassword string) (err error) {
|
||||
if user.IsConnected() {
|
||||
return errors.New("user is already connected")
|
||||
}
|
||||
|
||||
// Update the user's password in the cred store in case they changed it.
|
||||
if err = b.credStorer.UpdatePassword(user.ID(), hashedPassword); err != nil {
|
||||
return errors.Wrap(err, "failed to update password of user in credentials store")
|
||||
}
|
||||
|
||||
client := b.clientManager.GetClient(user.ID())
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh auth token of new client")
|
||||
}
|
||||
|
||||
if err = b.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to update token of user in credentials store")
|
||||
}
|
||||
|
||||
if err = user.init(b.idleUpdates); err != nil {
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// addNewUser adds a new bridge user to the bridge.
|
||||
func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) {
|
||||
client := b.clientManager.GetClient(user.ID)
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh token in new client")
|
||||
}
|
||||
|
||||
if user, err = client.UpdateUser(); err != nil {
|
||||
return errors.Wrap(err, "failed to update API user")
|
||||
}
|
||||
|
||||
activeEmails := client.Addresses().ActiveEmails()
|
||||
|
||||
if _, err = b.credStorer.Add(user.ID, user.Name, auth.GenToken(), hashedPassword, activeEmails); err != nil {
|
||||
return errors.Wrap(err, "failed to add user to credentials store")
|
||||
}
|
||||
|
||||
bridgeUser, err := newUser(b.panicHandler, user.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create user")
|
||||
}
|
||||
|
||||
// The user needs to be part of the users list in order for it to receive an auth during initialisation.
|
||||
// TODO: If adding the user fails, we don't want to leave it there.
|
||||
b.users = append(b.users, bridgeUser)
|
||||
|
||||
if err = bridgeUser.init(b.idleUpdates); err != nil {
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
}
|
||||
|
||||
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getAPIUser(client PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) {
|
||||
hashedPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not hash mailbox password")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = loginClient.Unlock(mbPassword); err != nil {
|
||||
if _, err = client.Unlock(hashedPassword); err != nil {
|
||||
log.WithError(err).Error("Could not decrypt keyring")
|
||||
return
|
||||
}
|
||||
|
||||
apiUser, err := loginClient.CurrentUser()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not get login API user")
|
||||
if user, err = client.UpdateUser(); err != nil {
|
||||
log.WithError(err).Error("Could not update API user")
|
||||
return
|
||||
}
|
||||
|
||||
user, hasUser := b.hasUser(apiUser.ID)
|
||||
|
||||
// If the user exists and is logged in, we don't want to do anything.
|
||||
if hasUser && user.IsConnected() {
|
||||
err = errors.New("user is already logged in")
|
||||
log.WithError(err).Warn("User is already logged in")
|
||||
return
|
||||
}
|
||||
|
||||
apiClient := b.clientManager.GetClient(apiUser.ID)
|
||||
auth, err = apiClient.AuthRefresh(auth.GenToken())
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not refresh token in new client")
|
||||
return
|
||||
}
|
||||
|
||||
// We load the current user again because it should now have addresses loaded.
|
||||
apiUser, err = apiClient.CurrentUser()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not get current API user")
|
||||
return
|
||||
}
|
||||
|
||||
activeEmails := apiClient.Addresses().ActiveEmails()
|
||||
if _, err = b.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), mbPassword, activeEmails); err != nil {
|
||||
log.WithError(err).Error("Could not add user to credentials store")
|
||||
return
|
||||
}
|
||||
|
||||
// If it's a new user, generate the user object.
|
||||
if !hasUser {
|
||||
user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
|
||||
if err != nil {
|
||||
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
|
||||
return
|
||||
}
|
||||
b.users = append(b.users, user)
|
||||
}
|
||||
|
||||
// Set up the user auth and store (which we do for both new and existing users).
|
||||
if err = user.init(b.idleUpdates); err != nil {
|
||||
log.WithField("user", user.userID).WithError(err).Error("Could not initialise user")
|
||||
return
|
||||
}
|
||||
|
||||
if !hasUser {
|
||||
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel))
|
||||
}
|
||||
|
||||
b.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||
|
||||
return user, err
|
||||
return
|
||||
}
|
||||
|
||||
// GetUsers returns all added users into keychain (even logged out users).
|
||||
@ -326,8 +358,7 @@ func (b *Bridge) GetUsers() []*User {
|
||||
return b.users
|
||||
}
|
||||
|
||||
// GetUser returns a user by `query` which is compared to users' ID, username
|
||||
// or any attached e-mail address.
|
||||
// GetUser returns a user by `query` which is compared to users' ID, username or any attached e-mail address.
|
||||
func (b *Bridge) GetUser(query string) (*User, error) {
|
||||
b.crashBandicoot(query)
|
||||
|
||||
|
||||
@ -73,7 +73,6 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) {
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
@ -129,6 +129,7 @@ type mocks struct {
|
||||
PanicHandler *bridgemocks.MockPanicHandler
|
||||
prefProvider *bridgemocks.MockPreferenceProvider
|
||||
pmapiClient *bridgemocks.MockPMAPIProvider
|
||||
clientManager *bridgemocks.MockClientManager
|
||||
credentialsStore *bridgemocks.MockCredentialsStorer
|
||||
eventListener *MockListener
|
||||
|
||||
@ -208,14 +209,10 @@ func testNewBridge(t *testing.T, m mocks) *Bridge {
|
||||
m.prefProvider.EXPECT().GetBool(preferences.AllowProxyKey).Return(false).AnyTimes()
|
||||
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
|
||||
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any()).AnyTimes()
|
||||
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
pmapiClientFactory := func(userID string) PMAPIProvider {
|
||||
log.WithField("userID", userID).Info("Creating new pmclient")
|
||||
return m.pmapiClient
|
||||
}
|
||||
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient)
|
||||
|
||||
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", pmapiClientFactory, m.credentialsStore)
|
||||
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
|
||||
@ -125,6 +125,20 @@ func (s *Store) UpdateEmails(userID string, emails []string) error {
|
||||
return s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdatePassword(userID, password string) error {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credentials.MailboxPassword = password
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateToken(userID, apiToken string) error {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,PMAPIProvider,CredentialsStorer)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,PMAPIProvider,CredentialsStorer)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -205,6 +205,95 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method
|
||||
func (m *MockClientManager) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy
|
||||
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method
|
||||
func (m *MockClientManager) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy
|
||||
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// GetAnonymousClient mocks base method
|
||||
func (m *MockClientManager) GetAnonymousClient() *pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAnonymousClient")
|
||||
ret0, _ := ret[0].(*pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAnonymousClient indicates an expected call of GetAnonymousClient
|
||||
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
|
||||
}
|
||||
|
||||
// GetBridgeAuthChannel mocks base method
|
||||
func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetBridgeAuthChannel")
|
||||
ret0, _ := ret[0].(chan pmapi.ClientAuth)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel
|
||||
func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel))
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) *pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(*pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockPMAPIProvider is a mock of PMAPIProvider interface
|
||||
type MockPMAPIProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -615,11 +704,9 @@ func (mr *MockPMAPIProviderMockRecorder) ListMessages(arg0 interface{}) *gomock.
|
||||
}
|
||||
|
||||
// Logout mocks base method
|
||||
func (m *MockPMAPIProvider) Logout() error {
|
||||
func (m *MockPMAPIProvider) Logout() {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Logout")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
m.ctrl.Call(m, "Logout")
|
||||
}
|
||||
|
||||
// Logout indicates an expected call of Logout
|
||||
@ -700,18 +787,6 @@ func (mr *MockPMAPIProviderMockRecorder) SendSimpleMetric(arg0, arg1, arg2 inter
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSimpleMetric", reflect.TypeOf((*MockPMAPIProvider)(nil).SendSimpleMetric), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetAuths mocks base method
|
||||
func (m *MockPMAPIProvider) SetAuths(arg0 chan<- *pmapi.Auth) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetAuths", arg0)
|
||||
}
|
||||
|
||||
// SetAuths indicates an expected call of SetAuths
|
||||
func (mr *MockPMAPIProviderMockRecorder) SetAuths(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAuths", reflect.TypeOf((*MockPMAPIProvider)(nil).SetAuths), arg0)
|
||||
}
|
||||
|
||||
// UnlabelMessages mocks base method
|
||||
func (m *MockPMAPIProvider) UnlabelMessages(arg0 []string, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@ -909,6 +984,20 @@ func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEmails", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateEmails), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpdatePassword mocks base method
|
||||
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdatePassword indicates an expected call of UpdatePassword
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdatePassword), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpdateToken mocks base method
|
||||
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -43,7 +43,12 @@ type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
|
||||
type Clientman interface {
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) *pmapi.Client
|
||||
GetAnonymousClient() *pmapi.Client
|
||||
GetBridgeAuthChannel() chan pmapi.ClientAuth
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
}
|
||||
|
||||
type PMAPIProvider interface {
|
||||
@ -100,6 +105,7 @@ type CredentialsStorer interface {
|
||||
Get(userID string) (*credentials.Credentials, error)
|
||||
SwitchAddressMode(userID string) error
|
||||
UpdateEmails(userID string, emails []string) error
|
||||
UpdatePassword(userID, password string) error
|
||||
UpdateToken(userID, apiToken string) error
|
||||
Logout(userID string) error
|
||||
Delete(userID string) error
|
||||
|
||||
@ -41,7 +41,7 @@ type User struct {
|
||||
log *logrus.Entry
|
||||
panicHandler PanicHandler
|
||||
listener listener.Listener
|
||||
clientMan *pmapi.ClientManager
|
||||
clientMan ClientManager
|
||||
credStorer CredentialsStorer
|
||||
|
||||
imapUpdatesChannel chan interface{}
|
||||
@ -66,7 +66,7 @@ func newUser(
|
||||
userID string,
|
||||
eventListener listener.Listener,
|
||||
credStorer CredentialsStorer,
|
||||
clientMan *pmapi.ClientManager,
|
||||
clientMan ClientManager,
|
||||
storeCache *store.Cache,
|
||||
storeDir string,
|
||||
) (u *User, err error) {
|
||||
@ -243,20 +243,9 @@ func (u *User) authorizeAndUnlock() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) AuthorizeWithAPIAuth(auth *pmapi.Auth) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
func (u *User) updateAuthToken(auth *pmapi.Auth) {
|
||||
u.log.Debug("User received auth from bridge")
|
||||
|
||||
if auth == nil {
|
||||
if err := u.logout(); err != nil {
|
||||
u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API")
|
||||
}
|
||||
u.isAuthorized = false
|
||||
return
|
||||
}
|
||||
|
||||
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
|
||||
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||
return
|
||||
@ -510,11 +499,16 @@ func (u *User) logout() error {
|
||||
u.lock.Lock()
|
||||
wasConnected := u.creds.IsConnected()
|
||||
u.lock.Unlock()
|
||||
|
||||
err := u.Logout()
|
||||
|
||||
if wasConnected {
|
||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||
u.listener.Emit(events.UserRefreshEvent, u.userID)
|
||||
}
|
||||
|
||||
u.isAuthorized = false
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -534,6 +528,7 @@ func (u *User) Logout() (err error) {
|
||||
u.wasKeyringUnlocked = false
|
||||
u.unlockingKeyringLock.Unlock()
|
||||
|
||||
// TODO: Is this necessary or could it be done by ClientManager when a nil auth is received?
|
||||
u.client().Logout()
|
||||
|
||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||
@ -550,6 +545,7 @@ func (u *User) Logout() (err error) {
|
||||
u.closeEventLoop()
|
||||
|
||||
u.closeAllConnections()
|
||||
|
||||
runtime.GC()
|
||||
|
||||
return err
|
||||
@ -557,7 +553,7 @@ func (u *User) Logout() (err error) {
|
||||
|
||||
func (u *User) refreshFromCredentials() {
|
||||
if credentials, err := u.credStorer.Get(u.userID); err != nil {
|
||||
log.Error("Cannot update credentials: ", err)
|
||||
log.WithError(err).Error("Cannot refresh user credentials")
|
||||
} else {
|
||||
u.creds = credentials
|
||||
}
|
||||
|
||||
@ -174,13 +174,13 @@ func TestCheckBridgeLoginLoggedOut(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
|
||||
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient)
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
_ = user.init(nil, m.pmapiClient)
|
||||
_ = user.init(nil)
|
||||
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
|
||||
@ -44,7 +44,6 @@ func TestNewUserBridgeOutdated(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication).AnyTimes()
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "").AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication)
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
||||
@ -58,7 +57,6 @@ func TestNewUserNoInternetConnection(t *testing.T) {
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable).AnyTimes()
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
m.eventListener.EXPECT().Emit(events.InternetOffEvent, "").AnyTimes()
|
||||
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
||||
@ -75,12 +73,10 @@ func TestNewUserAuthRefreshFails(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")).AnyTimes()
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
@ -96,7 +92,6 @@ func TestNewUserUnlockFails(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password"))
|
||||
@ -104,7 +99,6 @@ func TestNewUserUnlockFails(t *testing.T) {
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
@ -120,7 +114,6 @@ func TestNewUserUnlockAddressesFails(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||
@ -129,7 +122,6 @@ func TestNewUserUnlockAddressesFails(t *testing.T) {
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
@ -144,7 +136,6 @@ func TestNewUser(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||
|
||||
@ -30,7 +30,6 @@ func testNewUser(m mocks) *User {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||
@ -57,14 +56,12 @@ func testNewUserForLogout(m mocks) *User {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||
|
||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil)
|
||||
|
||||
// These may or may not be hit depending on how fast the log out happens.
|
||||
m.pmapiClient.EXPECT().SetAuths(nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}).AnyTimes()
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil)
|
||||
|
||||
@ -527,11 +527,11 @@ func (s *FrontendQt) toggleAllowProxy() {
|
||||
|
||||
if s.preferences.GetBool(preferences.AllowProxyKey) {
|
||||
s.preferences.SetBool(preferences.AllowProxyKey, false)
|
||||
bridge.DisallowDoH()
|
||||
s.bridge.DisallowProxy()
|
||||
s.Qml.SetIsProxyAllowed(false)
|
||||
} else {
|
||||
s.preferences.SetBool(preferences.AllowProxyKey, true)
|
||||
bridge.AllowDoH()
|
||||
s.bridge.AllowProxy()
|
||||
s.Qml.SetIsProxyAllowed(true)
|
||||
}
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
|
||||
}
|
||||
|
||||
func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) {
|
||||
logrus.Info("Setting dialer with pinning")
|
||||
logrus.Info("Setting ClientManager to create clients with key pinning")
|
||||
|
||||
pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion)
|
||||
|
||||
@ -47,5 +47,5 @@ func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, lis
|
||||
listener.Emit(events.TLSCertIssue, "")
|
||||
}
|
||||
|
||||
cm.SetClientRoundTripper(pin.TransportWithPinning())
|
||||
cm.SetRoundTripper(pin.TransportWithPinning())
|
||||
}
|
||||
|
||||
@ -217,7 +217,7 @@ func (store *Store) txGetOnAPICounts(tx *bolt.Tx) ([]*mailboxCounts, error) {
|
||||
|
||||
// createOrUpdateOnAPICounts will change only on-API-counts.
|
||||
func (store *Store) createOrUpdateOnAPICounts(mailboxCountsOnAPI []*pmapi.MessagesCount) error {
|
||||
store.log.WithField("apiCounts", mailboxCountsOnAPI).Debug("Updating API counts")
|
||||
store.log.Debug("Updating API counts")
|
||||
|
||||
tx := func(tx *bolt.Tx) error {
|
||||
countsBkt := tx.Bucket(countsBucket)
|
||||
|
||||
Reference in New Issue
Block a user