From 042c3408811ebc93146e71e8ad3adc92dddc0faf Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Tue, 7 Apr 2020 09:55:28 +0200 Subject: [PATCH] feat: make store use ClientManager --- Makefile | 4 +- internal/bridge/bridge.go | 6 +- internal/bridge/bridge_login_test.go | 2 +- internal/bridge/bridge_new_test.go | 14 +++-- internal/bridge/bridge_test.go | 10 +-- internal/bridge/mocks/mocks.go | 91 +++++++++++++++++++++++++++- internal/bridge/types.go | 8 +++ internal/bridge/user.go | 6 +- internal/bridge/user_new_test.go | 10 +-- internal/bridge/user_test.go | 8 +-- internal/store/address.go | 6 +- internal/store/event_loop.go | 30 ++++----- internal/store/event_loop_test.go | 12 ++-- internal/store/mailbox.go | 4 +- internal/store/mailbox_message.go | 24 ++++---- internal/store/message.go | 2 - internal/store/mocks/mocks.go | 40 +++++++++++- internal/store/store.go | 38 ++++++------ internal/store/store_test.go | 44 ++++++++------ internal/store/types.go | 34 +---------- internal/store/user.go | 4 +- internal/store/user_address_info.go | 2 +- internal/store/user_mailbox.go | 12 ++-- internal/store/user_message.go | 6 +- internal/store/user_sync.go | 5 +- pkg/pmapi/auth.go | 13 ++-- pkg/pmapi/client.go | 31 ++++++---- pkg/pmapi/clientmanager.go | 33 +++++++--- pkg/pmapi/mocks/mocks.go | 15 +++++ test/context/bridge.go | 11 ++-- test/context/config.go | 4 -- test/context/context.go | 7 ++- test/context/pmapi_controller.go | 24 +++----- test/fakeapi/attachments.go | 8 +++ test/fakeapi/auth.go | 16 ++++- test/fakeapi/controller.go | 16 ++--- test/fakeapi/reports.go | 4 ++ test/fakeapi/user.go | 4 ++ test/liveapi/cleanup.go | 10 +-- test/liveapi/controller.go | 28 +++------ test/liveapi/labels.go | 12 +--- test/liveapi/messages.go | 9 +-- test/liveapi/users.go | 11 +--- 43 files changed, 414 insertions(+), 264 deletions(-) diff --git a/Makefile b/Makefile index e4505c9f..7bf4d29c 100644 --- a/Makefile +++ b/Makefile @@ -149,8 +149,8 @@ coverage: test go tool cover -html=/tmp/coverage.out -o=coverage.html mocks: - mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,CredentialsStorer > internal/bridge/mocks/mocks.go - mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser > internal/store/mocks/mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer > internal/bridge/mocks/mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,ClientManager,BridgeUser > internal/store/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client > pkg/pmapi/mocks/mocks.go diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index f0b89994..ac95a385 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -48,7 +48,7 @@ type Bridge struct { panicHandler PanicHandler events listener.Listener version string - clientManager *pmapi.ClientManager + clientManager ClientManager credStorer CredentialsStorer storeCache *store.Cache @@ -76,7 +76,7 @@ func New( panicHandler PanicHandler, eventListener listener.Listener, version string, - clientManager *pmapi.ClientManager, + clientManager ClientManager, credStorer CredentialsStorer, ) *Bridge { log.Trace("Creating new bridge") @@ -185,7 +185,7 @@ func (b *Bridge) watchAPIAuths() { user, ok := b.hasUser(auth.UserID) if !ok { - logrus.Info("User is not added to bridge yet") + logrus.WithField("userID", auth.UserID).Info("User not available for auth update") continue } diff --git a/internal/bridge/bridge_login_test.go b/internal/bridge/bridge_login_test.go index 3b07c657..fc9d376b 100644 --- a/internal/bridge/bridge_login_test.go +++ b/internal/bridge/bridge_login_test.go @@ -39,7 +39,7 @@ func TestBridgeFinishLoginBadPassword(t *testing.T) { // Set up mocks for FinishLogin. err := errors.New("bad password") m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err) - m.pmapiClient.EXPECT().Logout().Return(nil) + m.pmapiClient.EXPECT().Logout() checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err) } diff --git a/internal/bridge/bridge_new_test.go b/internal/bridge/bridge_new_test.go index ec6214d1..f7d5bf20 100644 --- a/internal/bridge/bridge_new_test.go +++ b/internal/bridge/bridge_new_test.go @@ -56,6 +56,7 @@ func TestNewBridgeWithDisconnectedUser(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2) m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")) m.pmapiClient.EXPECT().Addresses().Return(nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient) checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) } @@ -66,13 +67,14 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) { m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.credentialsStore.EXPECT().Logout("user").Return(nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")) m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.pmapiClient.EXPECT().Logout().Return(nil) + m.pmapiClient.EXPECT().Logout() m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") @@ -84,10 +86,14 @@ func TestNewBridgeWithConnectedUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() + m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) + m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) + m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) + m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) // Set up mocks for store initialisation for the authorized user. m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) @@ -97,10 +103,6 @@ func TestNewBridgeWithConnectedUser(t *testing.T) { m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil) - checkBridgeNew(t, m, []*credentials.Credentials{testCredentials}) } diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 4999668d..40c9185d 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -129,11 +129,11 @@ type mocks struct { config *bridgemocks.MockConfiger PanicHandler *bridgemocks.MockPanicHandler prefProvider *bridgemocks.MockPreferenceProvider + clientManager *bridgemocks.MockClientManager credentialsStore *bridgemocks.MockCredentialsStorer eventListener *MockListener - pmapiClient *pmapimocks.MockClient - clientManager *pmapimocks.MockClientManager + pmapiClient *pmapimocks.MockClient storeCache *store.Cache } @@ -151,11 +151,11 @@ func initMocks(t *testing.T) mocks { config: bridgemocks.NewMockConfiger(mockCtrl), PanicHandler: bridgemocks.NewMockPanicHandler(mockCtrl), prefProvider: bridgemocks.NewMockPreferenceProvider(mockCtrl), + clientManager: bridgemocks.NewMockClientManager(mockCtrl), credentialsStore: bridgemocks.NewMockCredentialsStorer(mockCtrl), eventListener: NewMockListener(mockCtrl), - pmapiClient: pmapimocks.NewMockClient(mockCtrl), - clientManager: pmapimocks.NewMockClientManager(mockCtrl), + pmapiClient: pmapimocks.NewMockClient(mockCtrl), storeCache: store.NewCache(cacheFile.Name()), } @@ -214,7 +214,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge { m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes() m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes() m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) - m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient) + m.clientManager.EXPECT().GetBridgeAuthChannel().Return(make(chan *pmapi.ClientAuth)) bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore) diff --git a/internal/bridge/mocks/mocks.go b/internal/bridge/mocks/mocks.go index b08f73c7..bd4e9cda 100644 --- a/internal/bridge/mocks/mocks.go +++ b/internal/bridge/mocks/mocks.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,CredentialsStorer) +// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer) // Package mocks is a generated GoMock package. package mocks @@ -203,6 +203,95 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) } +// MockClientManager is a mock of ClientManager interface +type MockClientManager struct { + ctrl *gomock.Controller + recorder *MockClientManagerMockRecorder +} + +// MockClientManagerMockRecorder is the mock recorder for MockClientManager +type MockClientManagerMockRecorder struct { + mock *MockClientManager +} + +// NewMockClientManager creates a new mock instance +func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager { + mock := &MockClientManager{ctrl: ctrl} + mock.recorder = &MockClientManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder { + return m.recorder +} + +// AllowProxy mocks base method +func (m *MockClientManager) AllowProxy() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AllowProxy") +} + +// AllowProxy indicates an expected call of AllowProxy +func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy)) +} + +// DisallowProxy mocks base method +func (m *MockClientManager) DisallowProxy() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisallowProxy") +} + +// DisallowProxy indicates an expected call of DisallowProxy +func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy)) +} + +// GetAnonymousClient mocks base method +func (m *MockClientManager) GetAnonymousClient() pmapi.Client { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAnonymousClient") + ret0, _ := ret[0].(pmapi.Client) + return ret0 +} + +// GetAnonymousClient indicates an expected call of GetAnonymousClient +func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient)) +} + +// GetBridgeAuthChannel mocks base method +func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBridgeAuthChannel") + ret0, _ := ret[0].(chan pmapi.ClientAuth) + return ret0 +} + +// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel +func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel)) +} + +// GetClient mocks base method +func (m *MockClientManager) GetClient(arg0 string) pmapi.Client { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClient", arg0) + ret0, _ := ret[0].(pmapi.Client) + return ret0 +} + +// GetClient indicates an expected call of GetClient +func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0) +} + // MockCredentialsStorer is a mock of CredentialsStorer interface type MockCredentialsStorer struct { ctrl *gomock.Controller diff --git a/internal/bridge/types.go b/internal/bridge/types.go index c728df49..ea630ee4 100644 --- a/internal/bridge/types.go +++ b/internal/bridge/types.go @@ -51,3 +51,11 @@ type CredentialsStorer interface { Logout(userID string) error Delete(userID string) error } + +type ClientManager interface { + GetClient(userID string) pmapi.Client + GetAnonymousClient() pmapi.Client + AllowProxy() + DisallowProxy() + GetBridgeAuthChannel() chan pmapi.ClientAuth +} diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 9279f435..41a7d160 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -41,7 +41,7 @@ type User struct { log *logrus.Entry panicHandler PanicHandler listener listener.Listener - clientManager *pmapi.ClientManager + clientManager ClientManager credStorer CredentialsStorer imapUpdatesChannel chan interface{} @@ -66,7 +66,7 @@ func newUser( userID string, eventListener listener.Listener, credStorer CredentialsStorer, - clientManager *pmapi.ClientManager, + clientManager ClientManager, storeCache *store.Cache, storeDir string, ) (u *User, err error) { @@ -139,7 +139,7 @@ func (u *User) init(idleUpdates chan interface{}) (err error) { } u.store = nil } - store, err := store.New(u.panicHandler, u, u.client(), u.listener, u.storePath, u.storeCache) + store, err := store.New(u.panicHandler, u, u.clientManager, u.listener, u.storePath, u.storeCache) if err != nil { return errors.Wrap(err, "failed to create store") } diff --git a/internal/bridge/user_new_test.go b/internal/bridge/user_new_test.go index 1067ab30..8344a124 100644 --- a/internal/bridge/user_new_test.go +++ b/internal/bridge/user_new_test.go @@ -33,7 +33,7 @@ func TestNewUserNoCredentialsStore(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail")) - _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") + _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") a.Error(t, err) } @@ -153,10 +153,10 @@ func TestNewUser(t *testing.T) { } func checkNewUser(m mocks) { - user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") + user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") defer cleanUpUserData(user) - _ = user.init(nil, m.pmapiClient) + _ = user.init(nil) waitForEvents() @@ -164,10 +164,10 @@ func checkNewUser(m mocks) { } func checkNewUserDisconnected(m mocks) { - user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") + user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") defer cleanUpUserData(user) - _ = user.init(nil, m.pmapiClient) + _ = user.init(nil) waitForEvents() diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index 092f3642..a39dc4e3 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -43,10 +43,10 @@ func testNewUser(m mocks) *User { m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() - user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") + user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") assert.NoError(m.t, err) - err = user.init(nil, m.pmapiClient) + err = user.init(nil) assert.NoError(m.t, err) return user @@ -69,10 +69,10 @@ func testNewUserForLogout(m mocks) *User { m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() - user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") + user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp") assert.NoError(m.t, err) - err = user.init(nil, m.pmapiClient) + err = user.init(nil) assert.NoError(m.t, err) return user diff --git a/internal/store/address.go b/internal/store/address.go index cb32c3f3..d890742b 100644 --- a/internal/store/address.go +++ b/internal/store/address.go @@ -109,5 +109,9 @@ func (storeAddress *Address) AddressID() string { // APIAddress returns the `pmapi.Address` struct. func (storeAddress *Address) APIAddress() *pmapi.Address { - return storeAddress.store.api.Addresses().ByEmail(storeAddress.address) + return storeAddress.client().Addresses().ByEmail(storeAddress.address) +} + +func (storeAddress *Address) client() pmapi.Client { + return storeAddress.store.client() } diff --git a/internal/store/event_loop.go b/internal/store/event_loop.go index 0238c983..7077e393 100644 --- a/internal/store/event_loop.go +++ b/internal/store/event_loop.go @@ -42,13 +42,12 @@ type eventLoop struct { log *logrus.Entry - store *Store - apiClient PMAPIProvider - user BridgeUser - events listener.Listener + store *Store + user BridgeUser + events listener.Listener } -func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser, events listener.Listener) *eventLoop { +func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop { eventLog := log.WithField("userID", user.ID()) eventLog.Trace("Creating new event loop") @@ -60,10 +59,9 @@ func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser log: eventLog, - store: store, - apiClient: api, - user: user, - events: events, + store: store, + user: user, + events: events, } } @@ -71,10 +69,14 @@ func (loop *eventLoop) IsRunning() bool { return loop.isRunning } +func (loop *eventLoop) client() pmapi.Client { + return loop.store.client() +} + func (loop *eventLoop) setFirstEventID() (err error) { loop.log.Info("Setting first event ID") - event, err := loop.apiClient.GetEvent("") + event, err := loop.client().GetEvent("") if err != nil { loop.log.WithError(err).Error("Could not get latest event ID") return @@ -240,7 +242,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun loop.pollCounter++ var event *pmapi.Event - if event, err = loop.apiClient.GetEvent(loop.currentEventID); err != nil { + if event, err = loop.client().GetEvent(loop.currentEventID); err != nil { return false, errors.Wrap(err, "failed to get event") } @@ -321,7 +323,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap log.Debug("Processing address change event") // Get old addresses for comparisons before updating user. - oldList := loop.apiClient.Addresses() + oldList := loop.client().Addresses() if err = loop.user.UpdateUser(); err != nil { if logoutErr := loop.user.Logout(); logoutErr != nil { @@ -363,7 +365,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap } } - if err = loop.store.createOrUpdateAddressInfo(loop.apiClient.Addresses()); err != nil { + if err = loop.store.createOrUpdateAddressInfo(loop.client().Addresses()); err != nil { return errors.Wrap(err, "failed to update address IDs in store") } @@ -430,7 +432,7 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...") - if msg, err = loop.apiClient.GetMessage(message.ID); err != nil { + if msg, err = loop.client().GetMessage(message.ID); err != nil { if err == pmapi.ErrNoSuchAPIID { msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API") err = nil diff --git a/internal/store/event_loop_test.go b/internal/store/event_loop_test.go index 7bb67053..037b8ae5 100644 --- a/internal/store/event_loop_test.go +++ b/internal/store/event_loop_test.go @@ -39,21 +39,21 @@ func TestEventLoopProcessMoreEvents(t *testing.T) { // Doesn't matter which IDs are used. // This test is trying to see whether event loop will immediately process // next event if there is `More` of them. - m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{ + m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{ EventID: "event50", More: 1, }, nil), - m.api.EXPECT().GetEvent("event50").Return(&pmapi.Event{ + m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{ EventID: "event70", More: 0, }, nil), - m.api.EXPECT().GetEvent("event70").Return(&pmapi.Event{ + m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{ EventID: "event71", More: 0, }, nil), ) m.newStoreNoEvents(true) - m.api.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() + m.client.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() // Event loop runs in goroutine and will be stopped by deferred mock clearing. go m.store.eventLoop.start() @@ -78,12 +78,12 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) { newSubject := "new subject" // First sync will add message with old subject to database. - m.api.EXPECT().GetMessage("msg1").Return(&pmapi.Message{ + m.client.EXPECT().GetMessage("msg1").Return(&pmapi.Message{ ID: "msg1", Subject: subject, }, nil) // Event will update the subject. - m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{ + m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{ EventID: "event1", Messages: []*pmapi.EventMessage{{ EventItem: pmapi.EventItem{ diff --git a/internal/store/mailbox.go b/internal/store/mailbox.go index b8589b49..967775e0 100644 --- a/internal/store/mailbox.go +++ b/internal/store/mailbox.go @@ -258,8 +258,8 @@ func (storeMailbox *Mailbox) pollNow() { } // api is a proxy for the store's `PMAPIProvider`. -func (storeMailbox *Mailbox) api() PMAPIProvider { - return storeMailbox.store.api +func (storeMailbox *Mailbox) client() pmapi.Client { + return storeMailbox.store.client() } // update is a proxy for the store's db's `Update`. diff --git a/internal/store/mailbox_message.go b/internal/store/mailbox_message.go index 3e2fcf7b..55adfee0 100644 --- a/internal/store/mailbox_message.go +++ b/internal/store/mailbox_message.go @@ -37,7 +37,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) { // FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message // wrapping it. func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) { - msg, err := storeMailbox.api().GetMessage(apiID) + msg, err := storeMailbox.client().GetMessage(apiID) if err != nil { return nil, err } @@ -62,7 +62,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe LabelIDs: labelIDs, } - res, err := storeMailbox.api().Import([]*pmapi.ImportMsgReq{importReqs}) + res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs}) if err == nil && len(res) > 0 { msg.ID = res[0].MessageID } @@ -79,7 +79,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Labeling messages") defer storeMailbox.pollNow() - return storeMailbox.api().LabelMessages(apiIDs, storeMailbox.labelID) + return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID) } // UnlabelMessages removes the label by calling an API. @@ -92,7 +92,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Unlabeling messages") defer storeMailbox.pollNow() - return storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID) + return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID) } // MarkMessagesRead marks the message read by calling an API. @@ -116,7 +116,7 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error { ids = append(ids, apiID) } } - return storeMailbox.api().MarkMessagesRead(ids) + return storeMailbox.client().MarkMessagesRead(ids) } // MarkMessagesUnread marks the message unread by calling an API. @@ -128,7 +128,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Marking messages as unread") defer storeMailbox.pollNow() - return storeMailbox.api().MarkMessagesUnread(apiIDs) + return storeMailbox.client().MarkMessagesUnread(apiIDs) } // MarkMessagesStarred adds the Starred label by calling an API. @@ -141,7 +141,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Marking messages as starred") defer storeMailbox.pollNow() - return storeMailbox.api().LabelMessages(apiIDs, pmapi.StarredLabel) + return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel) } // MarkMessagesUnstarred removes the Starred label by calling an API. @@ -154,7 +154,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Marking messages as unstarred") defer storeMailbox.pollNow() - return storeMailbox.api().UnlabelMessages(apiIDs, pmapi.StarredLabel) + return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel) } // DeleteMessages deletes messages. @@ -197,21 +197,21 @@ func (storeMailbox *Mailbox) DeleteMessages(apiIDs []string) error { } } if len(messageIDsToUnlabel) > 0 { - if err := storeMailbox.api().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil { + if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil { log.WithError(err).Warning("Cannot unlabel before deleting") } } if len(messageIDsToDelete) > 0 { - if err := storeMailbox.api().DeleteMessages(messageIDsToDelete); err != nil { + if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil { return err } } case pmapi.DraftLabel: - if err := storeMailbox.api().DeleteMessages(apiIDs); err != nil { + if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil { return err } default: - if err := storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil { + if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil { return err } } diff --git a/internal/store/message.go b/internal/store/message.go index 8be26e86..ac1d89c5 100644 --- a/internal/store/message.go +++ b/internal/store/message.go @@ -28,7 +28,6 @@ import ( // a specific mailbox with helper functions to get IMAP UID, sequence // numbers and similar. type Message struct { - api PMAPIProvider msg *pmapi.Message store *Store @@ -37,7 +36,6 @@ type Message struct { func newStoreMessage(storeMailbox *Mailbox, msg *pmapi.Message) *Message { return &Message{ - api: storeMailbox.store.api, msg: msg, store: storeMailbox.store, storeMailbox: storeMailbox, diff --git a/internal/store/mocks/mocks.go b/internal/store/mocks/mocks.go index 486cfaef..8fd5c2d6 100644 --- a/internal/store/mocks/mocks.go +++ b/internal/store/mocks/mocks.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser) +// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser) // Package mocks is a generated GoMock package. package mocks @@ -7,6 +7,7 @@ package mocks import ( reflect "reflect" + pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" ) @@ -45,6 +46,43 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) } +// MockClientManager is a mock of ClientManager interface +type MockClientManager struct { + ctrl *gomock.Controller + recorder *MockClientManagerMockRecorder +} + +// MockClientManagerMockRecorder is the mock recorder for MockClientManager +type MockClientManagerMockRecorder struct { + mock *MockClientManager +} + +// NewMockClientManager creates a new mock instance +func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager { + mock := &MockClientManager{ctrl: ctrl} + mock.recorder = &MockClientManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder { + return m.recorder +} + +// GetClient mocks base method +func (m *MockClientManager) GetClient(arg0 string) pmapi.Client { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClient", arg0) + ret0, _ := ret[0].(pmapi.Client) + return ret0 +} + +// GetClient indicates an expected call of GetClient +func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0) +} + // MockBridgeUser is a mock of BridgeUser interface type MockBridgeUser struct { ctrl *gomock.Controller diff --git a/internal/store/store.go b/internal/store/store.go index 5952890c..354b0d96 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -89,10 +89,10 @@ var ( // Store is local user storage, which handles the synchronization between IMAP and PM API. type Store struct { - panicHandler PanicHandler - eventLoop *eventLoop - user BridgeUser - api PMAPIProvider + panicHandler PanicHandler + eventLoop *eventLoop + user BridgeUser + clientManager ClientManager log *logrus.Entry @@ -111,13 +111,13 @@ type Store struct { func New( panicHandler PanicHandler, user BridgeUser, - api PMAPIProvider, + clientManager ClientManager, events listener.Listener, path string, cache *Cache, ) (store *Store, err error) { - if user == nil || api == nil || events == nil || cache == nil { - return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, api, events, cache) + if user == nil || clientManager == nil || events == nil || cache == nil { + return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache) } l := log.WithField("user", user.ID()) @@ -138,14 +138,14 @@ func New( } store = &Store{ - panicHandler: panicHandler, - api: api, - user: user, - cache: cache, - filePath: path, - db: bdb, - lock: &sync.RWMutex{}, - log: l, + panicHandler: panicHandler, + clientManager: clientManager, + user: user, + cache: cache, + filePath: path, + db: bdb, + lock: &sync.RWMutex{}, + log: l, } if err = store.init(firstInit); err != nil { @@ -158,7 +158,7 @@ func New( } if user.IsConnected() { - store.eventLoop = newEventLoop(cache, store, api, user, events) + store.eventLoop = newEventLoop(cache, store, user, events) go func() { defer store.panicHandler.HandlePanic() store.eventLoop.start() @@ -261,10 +261,14 @@ func (store *Store) init(firstInit bool) (err error) { return err } +func (store *Store) client() pmapi.Client { + return store.clientManager.GetClient(store.UserID()) +} + // initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if // the API is unavailable for whatever reason it tries to fetch the labels locally. func (store *Store) initCounts() (labels []*pmapi.Label, err error) { - if labels, err = store.api.ListLabels(); err != nil { + if labels, err = store.client().ListLabels(); err != nil { store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.") if labels, err = store.getLabelsFromLocalStorage(); err != nil { store.log.WithError(err).Error("Cannot list local labels") diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 2859fbde..2957b936 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -24,9 +24,9 @@ import ( "sync" "testing" - bridgemocks "github.com/ProtonMail/proton-bridge/internal/bridge/mocks" - storeMocks "github.com/ProtonMail/proton-bridge/internal/store/mocks" + storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks" "github.com/ProtonMail/proton-bridge/pkg/pmapi" + pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -43,12 +43,13 @@ const ( type mocksForStore struct { tb testing.TB - ctrl *gomock.Controller - events *storeMocks.MockListener - api *bridgemocks.MockPMAPIProvider - user *storeMocks.MockBridgeUser - panicHandler *storeMocks.MockPanicHandler - store *Store + ctrl *gomock.Controller + events *storemocks.MockListener + user *storemocks.MockBridgeUser + client *pmapimocks.MockClient + clientManager *storemocks.MockClientManager + panicHandler *storemocks.MockPanicHandler + store *Store tmpDir string cache *Cache @@ -57,12 +58,13 @@ type mocksForStore struct { func initMocks(tb testing.TB) (*mocksForStore, func()) { ctrl := gomock.NewController(tb) mocks := &mocksForStore{ - tb: tb, - ctrl: ctrl, - events: storeMocks.NewMockListener(ctrl), - api: bridgemocks.NewMockPMAPIProvider(ctrl), - user: storeMocks.NewMockBridgeUser(ctrl), - panicHandler: storeMocks.NewMockPanicHandler(ctrl), + tb: tb, + ctrl: ctrl, + events: storemocks.NewMockListener(ctrl), + user: storemocks.NewMockBridgeUser(ctrl), + client: pmapimocks.NewMockClient(ctrl), + clientManager: storemocks.NewMockClientManager(ctrl), + panicHandler: storemocks.NewMockPanicHandler(ctrl), } // Called during clean-up. @@ -92,13 +94,15 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar mocks.user.EXPECT().IsConnected().Return(true) mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode) - mocks.api.EXPECT().Addresses().Return(pmapi.AddressList{ + mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client) + + mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{ {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive}, {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive}, }) - mocks.api.EXPECT().ListLabels() - mocks.api.EXPECT().CountMessages("") - mocks.api.EXPECT().GetEvent(gomock.Any()). + mocks.client.EXPECT().ListLabels() + mocks.client.EXPECT().CountMessages("") + mocks.client.EXPECT().GetEvent(gomock.Any()). Return(&pmapi.Event{ EventID: "latestEventID", }, nil).AnyTimes() @@ -106,7 +110,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar // We want to wait until first sync has finished. firstSyncWaiter := sync.WaitGroup{} firstSyncWaiter.Add(1) - mocks.api.EXPECT(). + mocks.client.EXPECT(). ListMessages(gomock.Any()). DoAndReturn(func(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) { firstSyncWaiter.Done() @@ -117,7 +121,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar mocks.store, err = New( mocks.panicHandler, mocks.user, - mocks.api, + mocks.clientManager, mocks.events, filepath.Join(mocks.tmpDir, "mailbox-test.db"), mocks.cache, diff --git a/internal/store/types.go b/internal/store/types.go index 9319a651..6470c3dd 100644 --- a/internal/store/types.go +++ b/internal/store/types.go @@ -17,42 +17,14 @@ package store -import ( - "io" - - "github.com/ProtonMail/proton-bridge/pkg/pmapi" -) +import "github.com/ProtonMail/proton-bridge/pkg/pmapi" type PanicHandler interface { HandlePanic() } -// PMAPIProvider is subset of pmapi.Client for use by the Store. -type PMAPIProvider interface { - CurrentUser() (*pmapi.User, error) - Addresses() pmapi.AddressList - - GetEvent(eventID string) (*pmapi.Event, error) - - CountMessages(addressID string) ([]*pmapi.MessagesCount, error) - ListMessages(filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error) - GetMessage(apiID string) (*pmapi.Message, error) - Import([]*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error) - DeleteMessages(apiIDs []string) error - LabelMessages(apiIDs []string, labelID string) error - UnlabelMessages(apiIDs []string, labelID string) error - MarkMessagesRead(apiIDs []string) error - MarkMessagesUnread(apiIDs []string) error - - CreateDraft(m *pmapi.Message, parent string, action int) (created *pmapi.Message, err error) - CreateAttachment(att *pmapi.Attachment, r io.Reader, sig io.Reader) (created *pmapi.Attachment, err error) - SendMessage(messageID string, req *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error) - - ListLabels() ([]*pmapi.Label, error) - CreateLabel(label *pmapi.Label) (*pmapi.Label, error) - UpdateLabel(label *pmapi.Label) (*pmapi.Label, error) - DeleteLabel(labelID string) error - EmptyFolder(labelID string, addressID string) error +type ClientManager interface { + GetClient(userID string) pmapi.Client } // BridgeUser is subset of bridge.User for use by the Store. diff --git a/internal/store/user.go b/internal/store/user.go index 6bbf90c9..b48f2a34 100644 --- a/internal/store/user.go +++ b/internal/store/user.go @@ -24,7 +24,7 @@ func (store *Store) UserID() string { // GetSpace returns used and total space in bytes. func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) { - apiUser, err := store.api.CurrentUser() + apiUser, err := store.client().CurrentUser() if err != nil { return 0, 0, err } @@ -33,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) { // GetMaxUpload returns max size of attachment in bytes. func (store *Store) GetMaxUpload() (uint, error) { - apiUser, err := store.api.CurrentUser() + apiUser, err := store.client().CurrentUser() if err != nil { return 0, err } diff --git a/internal/store/user_address_info.go b/internal/store/user_address_info.go index 6abde2be..f777a9f7 100644 --- a/internal/store/user_address_info.go +++ b/internal/store/user_address_info.go @@ -61,7 +61,7 @@ func (store *Store) GetAddressInfo() (addrs []AddressInfo, err error) { } // Store does not have address info yet, need to build it first from API. - addressList := store.api.Addresses() + addressList := store.client().Addresses() if addressList == nil { err = errors.New("addresses unavailable") store.log.WithError(err).Error("Could not get user addresses from API") diff --git a/internal/store/user_mailbox.go b/internal/store/user_mailbox.go index 0e341f40..db8e9b9b 100644 --- a/internal/store/user_mailbox.go +++ b/internal/store/user_mailbox.go @@ -55,7 +55,7 @@ func (store *Store) createMailbox(name string) error { return nil } - _, err := store.api.CreateLabel(&pmapi.Label{ + _, err := store.client().CreateLabel(&pmapi.Label{ Name: name, Color: color, Exclusive: exclusive, @@ -133,7 +133,7 @@ func (store *Store) leastUsedColor() string { func (store *Store) updateMailbox(labelID, newName, color string) error { defer store.eventLoop.pollNow() - _, err := store.api.UpdateLabel(&pmapi.Label{ + _, err := store.client().UpdateLabel(&pmapi.Label{ ID: labelID, Name: newName, Color: color, @@ -150,15 +150,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error { var err error switch labelID { case pmapi.SpamLabel: - err = store.api.EmptyFolder(pmapi.SpamLabel, addressID) + err = store.client().EmptyFolder(pmapi.SpamLabel, addressID) case pmapi.TrashLabel: - err = store.api.EmptyFolder(pmapi.TrashLabel, addressID) + err = store.client().EmptyFolder(pmapi.TrashLabel, addressID) default: err = fmt.Errorf("cannot empty mailbox %v", labelID) } return err } - return store.api.DeleteLabel(labelID) + return store.client().DeleteLabel(labelID) } func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error { @@ -173,7 +173,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro return nil } - labels, err := store.api.ListLabels() + labels, err := store.client().ListLabels() if err != nil { return err } diff --git a/internal/store/user_message.go b/internal/store/user_message.go index 9fc91c0d..4b58910b 100644 --- a/internal/store/user_message.go +++ b/internal/store/user_message.go @@ -54,7 +54,7 @@ func (store *Store) CreateDraft( message.Attachments = nil draftAction := store.getDraftAction(message) - draft, err := store.api.CreateDraft(message, parentID, draftAction) + draft, err := store.client().CreateDraft(message, parentID, draftAction) if err != nil { return nil, nil, errors.Wrap(err, "failed to create draft") } @@ -105,7 +105,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att return nil, errors.Wrap(err, "failed to encrypt attachment") } - createdAttachment, err := store.api.CreateAttachment(attachment, encReader, sigReader) + createdAttachment, err := store.client().CreateAttachment(attachment, encReader, sigReader) if err != nil { return nil, errors.Wrap(err, "failed to create attachment") } @@ -116,7 +116,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att // SendMessage sends the message. func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error { defer store.eventLoop.pollNow() - _, _, err := store.api.SendMessage(messageID, req) + _, _, err := store.client().SendMessage(messageID, req) return err } diff --git a/internal/store/user_sync.go b/internal/store/user_sync.go index 221eb629..c4962d1b 100644 --- a/internal/store/user_sync.go +++ b/internal/store/user_sync.go @@ -34,7 +34,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted" // updateCountsFromServer will download and set the counts. func (store *Store) updateCountsFromServer() error { - counts, err := store.api.CountMessages("") + counts, err := store.client().CountMessages("") if err != nil { return errors.Wrap(err, "cannot update counts from server") } @@ -144,7 +144,8 @@ func (store *Store) triggerSync() { store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started") - err := syncAllMail(store.panicHandler, store, store.api, syncState) + // TODO: Is it okay to pass in a client directly? What if it is logged out in the meantime? + err := syncAllMail(store.panicHandler, store, store.client(), syncState) if err != nil { log.WithError(err).Error("Store sync failed") return diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index 20d57847..9eba98d5 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -464,15 +464,13 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) return auth, err } -// Logout instructs the client manager to log out this client. +// TODO: Should this even be a client method? Or just a method on the client manager? func (c *client) Logout() { c.cm.LogoutClient(c.userID) } -// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not. - -// logout logs the current user out. -func (c *client) logout() (err error) { +// DeleteAuth deletes the API session. +func (c *client) DeleteAuth() (err error) { req, err := c.NewRequest("DELETE", "/auth", nil) if err != nil { return @@ -490,7 +488,10 @@ func (c *client) logout() (err error) { return } -func (c *client) clearSensitiveData() { +// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not. + +// ClearData clears sensitive data from the client. +func (c *client) ClearData() { c.uid = "" c.accessToken = "" c.kr = nil diff --git a/pkg/pmapi/client.go b/pkg/pmapi/client.go index a0b09f32..a5e5c57e 100644 --- a/pkg/pmapi/client.go +++ b/pkg/pmapi/client.go @@ -98,20 +98,28 @@ type Client interface { Auth(username, password string, info *AuthInfo) (*Auth, error) AuthInfo(username string) (*AuthInfo, error) AuthRefresh(token string) (*Auth, error) - Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error) - UnlockAddresses(passphrase []byte) error + Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) + Logout() + DeleteAuth() error + ClearData() + CurrentUser() (*User, error) UpdateUser() (*User, error) + Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error) + UnlockAddresses(passphrase []byte) error + + GetAddresses() (addresses AddressList, err error) Addresses() AddressList - Logout() - GetEvent(eventID string) (*Event, error) + SendMessage(string, *SendMessageReq) (sent, parent *Message, err error) + CreateDraft(m *Message, parent string, action int) (created *Message, err error) + Import([]*ImportMsgReq) ([]*ImportMsgRes, error) + CountMessages(addressID string) ([]*MessagesCount, error) ListMessages(filter *MessagesFilter) ([]*Message, int, error) GetMessage(apiID string) (*Message, error) - Import([]*ImportMsgReq) ([]*ImportMsgRes, error) DeleteMessages(apiIDs []string) error LabelMessages(apiIDs []string, labelID string) error UnlabelMessages(apiIDs []string, labelID string) error @@ -128,20 +136,17 @@ type Client interface { SendSimpleMetric(category, action, label string) error ReportSentryCrash(reportErr error) (err error) - Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) - GetMailSettings() (MailSettings, error) GetContactEmailByEmail(string, int, int) ([]ContactEmail, error) GetContactByID(string) (Contact, error) DecryptAndVerifyCards([]Card) ([]Card, error) - GetPublicKeysForEmail(string) ([]PublicKey, bool, error) - SendMessage(string, *SendMessageReq) (sent, parent *Message, err error) - CreateDraft(m *Message, parent string, action int) (created *Message, err error) - CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) - DeleteAttachment(attID string) (err error) - KeyRingForAddressID(string) (kr *pmcrypto.KeyRing) GetAttachment(id string) (att io.ReadCloser, err error) + CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) + DeleteAttachment(attID string) (err error) + + KeyRingForAddressID(string) (kr *pmcrypto.KeyRing) + GetPublicKeysForEmail(string) ([]PublicKey, bool, error) } // client is a client of the protonmail API. It implements the Client interface. diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index ae4989b5..821dde51 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -15,10 +15,17 @@ var defaultProxyUseDuration = 24 * time.Hour // ClientManager is a manager of clients. type ClientManager struct { + // newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to + // create other types of clients (e.g. for integration tests). + newClient func(userID string) Client + config *ClientConfig roundTripper http.RoundTripper - clients map[string]*client + // TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests. + // But that screws up other things like not being able to clear sensitive info during logout + // unless the client interface contains a method for that. + clients map[string]Client clientsLocker sync.Locker tokens map[string]string @@ -38,11 +45,13 @@ type ClientManager struct { proxyUseDuration time.Duration } +// ClientAuth holds an API auth produced by a Client for a specific user. type ClientAuth struct { UserID string Auth *Auth } +// tokenExpiration manages the expiration of an access token. type tokenExpiration struct { timer *time.Timer cancel chan (struct{}) @@ -58,7 +67,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { config: config, roundTripper: http.DefaultTransport, - clients: make(map[string]*client), + clients: make(map[string]Client), clientsLocker: &sync.Mutex{}, tokens: make(map[string]string), @@ -78,11 +87,19 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { proxyUseDuration: defaultProxyUseDuration, } + cm.newClient = func(userID string) Client { + return newClient(cm, userID) + } + go cm.forwardClientAuths() return } +func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) { + cm.newClient = f +} + // SetRoundTripper sets the roundtripper used by clients created by this client manager. func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) { cm.roundTripper = rt @@ -100,7 +117,7 @@ func (cm *ClientManager) GetClient(userID string) Client { return client } - cm.clients[userID] = newClient(cm, userID) + cm.clients[userID] = cm.newClient(userID) return cm.clients[userID] } @@ -108,10 +125,10 @@ func (cm *ClientManager) GetClient(userID string) Client { // GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created. func (cm *ClientManager) GetAnonymousClient() Client { if client, ok := cm.clients[""]; ok { - client.Logout() + client.DeleteAuth() } - cm.clients[""] = newClient(cm, "") + cm.clients[""] = cm.newClient("") return cm.clients[""] } @@ -127,10 +144,10 @@ func (cm *ClientManager) LogoutClient(userID string) { delete(cm.clients, userID) go func() { - if err := client.logout(); err != nil { - // TODO: Try again! This should loop until it succeeds (might fail the first time due to internet). + if err := client.DeleteAuth(); err != nil { + // TODO: Retry if the request failed. } - client.clearSensitiveData() + client.ClearData() cm.clearToken(userID) }() diff --git a/pkg/pmapi/mocks/mocks.go b/pkg/pmapi/mocks/mocks.go index 1e2a6b7f..21f66594 100644 --- a/pkg/pmapi/mocks/mocks.go +++ b/pkg/pmapi/mocks/mocks.go @@ -256,6 +256,21 @@ func (mr *MockClientMockRecorder) EmptyFolder(arg0, arg1 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyFolder", reflect.TypeOf((*MockClient)(nil).EmptyFolder), arg0, arg1) } +// GetAddresses mocks base method +func (m *MockClient) GetAddresses() (pmapi.AddressList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAddresses") + ret0, _ := ret[0].(pmapi.AddressList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAddresses indicates an expected call of GetAddresses +func (mr *MockClientMockRecorder) GetAddresses() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockClient)(nil).GetAddresses)) +} + // GetAttachment mocks base method func (m *MockClient) GetAttachment(arg0 string) (io.ReadCloser, error) { m.ctrl.T.Helper() diff --git a/test/context/bridge.go b/test/context/bridge.go index b420dbfd..4b088519 100644 --- a/test/context/bridge.go +++ b/test/context/bridge.go @@ -21,9 +21,9 @@ import ( "os" "runtime" + "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/pkg/listener" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" ) // GetBridge returns bridge instance. @@ -34,10 +34,7 @@ func (ctx *TestContext) GetBridge() *bridge.Bridge { // withBridgeInstance creates a bridge instance for use in the test. // Every TestContext has this by default and thus this doesn't need to be exported. func (ctx *TestContext) withBridgeInstance() { - pmapiFactory := func(userID string) pmapi.Client { - return ctx.pmapiController.GetClient(userID) - } - ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, pmapiFactory) + ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, ctx.clientManager) ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data") } @@ -62,7 +59,7 @@ func newBridgeInstance( cfg *fakeConfig, credStore bridge.CredentialsStorer, eventListener listener.Listener, - pmapiFactory bridge.PMAPIProviderFactory, + clientManager bridge.ClientManager, ) *bridge.Bridge { version := os.Getenv("VERSION") bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "") @@ -70,7 +67,7 @@ func newBridgeInstance( panicHandler := &panicHandler{t: t} pref := preferences.New(cfg) - return bridge.New(cfg, pref, panicHandler, eventListener, version, pmapiFactory, credStore) + return bridge.New(cfg, pref, panicHandler, eventListener, version, clientManager, credStore) } // SetLastBridgeError sets the last error that occurred while executing a bridge action. diff --git a/test/context/config.go b/test/context/config.go index 8cb3838a..c932b47d 100644 --- a/test/context/config.go +++ b/test/context/config.go @@ -28,7 +28,6 @@ import ( type fakeConfig struct { dir string - tm *pmapi.TokenManager } // newFakeConfig creates a temporary folder for files. @@ -41,7 +40,6 @@ func newFakeConfig() *fakeConfig { return &fakeConfig{ dir: dir, - tm: pmapi.NewTokenManager(), } } @@ -53,8 +51,6 @@ func (c *fakeConfig) GetAPIConfig() *pmapi.ClientConfig { AppVersion: "Bridge_" + os.Getenv("VERSION"), ClientID: "bridge", SentryDSN: "", - // TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere. - TokenManager: c.tm, } } func (c *fakeConfig) GetDBDir() string { diff --git a/test/context/context.go b/test/context/context.go index 1d788d69..d9222bf7 100644 --- a/test/context/context.go +++ b/test/context/context.go @@ -21,6 +21,7 @@ package context import ( "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/pkg/listener" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/test/accounts" "github.com/ProtonMail/proton-bridge/test/mocks" "github.com/sirupsen/logrus" @@ -44,6 +45,7 @@ type TestContext struct { bridge *bridge.Bridge bridgeLastError error credStore bridge.CredentialsStorer + clientManager *pmapi.ClientManager // IMAP related variables. imapAddr string @@ -70,11 +72,14 @@ func New() *TestContext { cfg := newFakeConfig() + cm := pmapi.NewClientManager(cfg.GetAPIConfig()) + ctx := &TestContext{ t: &bddT{}, cfg: cfg, listener: listener.New(), - pmapiController: newPMAPIController(), + pmapiController: newPMAPIController(cm), + clientManager: cm, testAccounts: newTestAccounts(), credStore: newFakeCredStore(), imapClients: make(map[string]*mocks.IMAPClient), diff --git a/test/context/pmapi_controller.go b/test/context/pmapi_controller.go index 3c2b7469..f3474116 100644 --- a/test/context/pmapi_controller.go +++ b/test/context/pmapi_controller.go @@ -20,14 +20,12 @@ package context import ( "os" - "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/test/fakeapi" "github.com/ProtonMail/proton-bridge/test/liveapi" ) type PMAPIController interface { - GetClient(userID string) bridge.PMAPIProvider TurnInternetConnectionOff() TurnInternetConnectionOn() AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error @@ -40,19 +38,19 @@ type PMAPIController interface { GetCalls(method, path string) [][]byte } -func newPMAPIController() PMAPIController { +func newPMAPIController(cm *pmapi.ClientManager) PMAPIController { switch os.Getenv(EnvName) { case EnvFake: - return newFakePMAPIController() + return newFakePMAPIController(cm) case EnvLive: - return newLivePMAPIController() + return newLivePMAPIController(cm) default: panic("unknown env") } } -func newFakePMAPIController() PMAPIController { - return newFakePMAPIControllerWrap(fakeapi.NewController()) +func newFakePMAPIController(cm *pmapi.ClientManager) PMAPIController { + return newFakePMAPIControllerWrap(fakeapi.NewController(cm)) } type fakePMAPIControllerWrap struct { @@ -63,12 +61,8 @@ func newFakePMAPIControllerWrap(controller *fakeapi.Controller) PMAPIController return &fakePMAPIControllerWrap{Controller: controller} } -func (s *fakePMAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider { - return s.Controller.GetClient(userID) -} - -func newLivePMAPIController() PMAPIController { - return newLiveAPIControllerWrap(liveapi.NewController()) +func newLivePMAPIController(cm *pmapi.ClientManager) PMAPIController { + return newLiveAPIControllerWrap(liveapi.NewController(cm)) } type liveAPIControllerWrap struct { @@ -78,7 +72,3 @@ type liveAPIControllerWrap struct { func newLiveAPIControllerWrap(controller *liveapi.Controller) PMAPIController { return &liveAPIControllerWrap{Controller: controller} } - -func (s *liveAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider { - return s.Controller.GetClient(userID) -} diff --git a/test/fakeapi/attachments.go b/test/fakeapi/attachments.go index f8ea6092..e4a34b96 100644 --- a/test/fakeapi/attachments.go +++ b/test/fakeapi/attachments.go @@ -45,3 +45,11 @@ func (api *FakePMAPI) CreateAttachment(attachment *pmapi.Attachment, data io.Rea attachment.KeyPackets = base64.StdEncoding.EncodeToString(bytes) return attachment, nil } + +func (api *FakePMAPI) DeleteAttachment(attachmentID string) error { + if err := api.checkAndRecordCall(GET, "/attachments/"+attachmentID, nil); err != nil { + return err + } + + return nil +} diff --git a/test/fakeapi/auth.go b/test/fakeapi/auth.go index ed56b40f..95bf911d 100644 --- a/test/fakeapi/auth.go +++ b/test/fakeapi/auth.go @@ -141,13 +141,23 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) { return auth, nil } -func (api *FakePMAPI) Logout() error { +func (api *FakePMAPI) Logout() { + _ = api.DeleteAuth() + api.ClearData() +} + +func (api *FakePMAPI) DeleteAuth() error { if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil { return err } + // Logout will also emit change to auth channel api.sendAuth(nil) - api.controller.deleteSession(api.uid) - api.unsetUser() + return nil } + +func (api *FakePMAPI) ClearData() { + api.controller.deleteSession(api.uid) + api.unsetUser() +} diff --git a/test/fakeapi/controller.go b/test/fakeapi/controller.go index b83e313b..023bbb86 100644 --- a/test/fakeapi/controller.go +++ b/test/fakeapi/controller.go @@ -44,8 +44,8 @@ type Controller struct { log *logrus.Entry } -func NewController() *Controller { - return &Controller{ +func NewController(cm *pmapi.ClientManager) (cntrl *Controller) { + cntrl = &Controller{ lock: &sync.RWMutex{}, fakeAPIs: []*FakePMAPI{}, calls: []*fakeCall{}, @@ -62,10 +62,12 @@ func NewController() *Controller { log: logrus.WithField("pkg", "fakeapi-controller"), } -} -func (cntrl *Controller) GetClient(userID string) *FakePMAPI { - fakeAPI := New(cntrl) - cntrl.fakeAPIs = append(cntrl.fakeAPIs, fakeAPI) - return fakeAPI + cm.SetClientConstructor(func(userID string) pmapi.Client { + fakeAPI := New(cntrl) + cntrl.fakeAPIs = append(cntrl.fakeAPIs, fakeAPI) + return fakeAPI + }) + + return } diff --git a/test/fakeapi/reports.go b/test/fakeapi/reports.go index 3652c407..ce39335b 100644 --- a/test/fakeapi/reports.go +++ b/test/fakeapi/reports.go @@ -42,3 +42,7 @@ func (api *FakePMAPI) SendSimpleMetric(category, action, label string) error { v.Set("Label", label) return api.checkInternetAndRecordCall(GET, "/metrics?"+v.Encode(), nil) } + +func (api *FakePMAPI) ReportSentryCrash(reportErr error) (err error) { + return nil +} diff --git a/test/fakeapi/user.go b/test/fakeapi/user.go index 6ce24290..2cb1a9b4 100644 --- a/test/fakeapi/user.go +++ b/test/fakeapi/user.go @@ -50,6 +50,10 @@ func (api *FakePMAPI) UpdateUser() (*pmapi.User, error) { return api.user, nil } +func (api *FakePMAPI) GetAddresses() (pmapi.AddressList, error) { + return *api.addresses, nil +} + func (api *FakePMAPI) Addresses() pmapi.AddressList { return *api.addresses } diff --git a/test/liveapi/cleanup.go b/test/liveapi/cleanup.go index 2f4db367..17836078 100644 --- a/test/liveapi/cleanup.go +++ b/test/liveapi/cleanup.go @@ -24,7 +24,7 @@ import ( "github.com/pkg/errors" ) -func cleanup(client *pmapi.Client) error { +func cleanup(client pmapi.Client) error { if err := cleanSystemFolders(client); err != nil { return errors.Wrap(err, "failed to clean system folders") } @@ -37,7 +37,7 @@ func cleanup(client *pmapi.Client) error { return nil } -func cleanSystemFolders(client *pmapi.Client) error { +func cleanSystemFolders(client pmapi.Client) error { for _, labelID := range []string{pmapi.InboxLabel, pmapi.SentLabel, pmapi.ArchiveLabel, pmapi.AllMailLabel, pmapi.DraftLabel} { for { messages, total, err := client.ListMessages(&pmapi.MessagesFilter{ @@ -69,7 +69,7 @@ func cleanSystemFolders(client *pmapi.Client) error { return nil } -func cleanCustomLables(client *pmapi.Client) error { +func cleanCustomLables(client pmapi.Client) error { labels, err := client.ListLabels() if err != nil { return errors.Wrap(err, "failed to list labels") @@ -87,7 +87,7 @@ func cleanCustomLables(client *pmapi.Client) error { return nil } -func cleanTrash(client *pmapi.Client) error { +func cleanTrash(client pmapi.Client) error { for { _, total, err := client.ListMessages(&pmapi.MessagesFilter{ PageSize: 1, @@ -110,7 +110,7 @@ func cleanTrash(client *pmapi.Client) error { return nil } -func emptyFolder(client *pmapi.Client, labelID string) error { +func emptyFolder(client pmapi.Client, labelID string) error { err := client.EmptyFolder(labelID, "") if err != nil { return err diff --git a/test/liveapi/controller.go b/test/liveapi/controller.go index f78b4afe..1b3ba52c 100644 --- a/test/liveapi/controller.go +++ b/test/liveapi/controller.go @@ -18,9 +18,7 @@ package liveapi import ( - "fmt" "net/http" - "os" "sync" "github.com/ProtonMail/proton-bridge/pkg/pmapi" @@ -30,33 +28,27 @@ type Controller struct { // Internal states. lock *sync.RWMutex calls []*fakeCall - pmapiByUsername map[string]*pmapi.Client messageIDsByUsername map[string][]string + clientManager *pmapi.ClientManager // State controlled by test. noInternetConnection bool } -func NewController() *Controller { - return &Controller{ +func NewController(cm *pmapi.ClientManager) (cntrl *Controller) { + cntrl = &Controller{ lock: &sync.RWMutex{}, calls: []*fakeCall{}, - pmapiByUsername: map[string]*pmapi.Client{}, messageIDsByUsername: map[string][]string{}, + clientManager: cm, noInternetConnection: false, } -} -func (cntrl *Controller) GetClient(userID string) *pmapi.Client { - cfg := &pmapi.ClientConfig{ - AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")), - ClientID: "bridge-test", - Transport: &fakeTransport{ - cntrl: cntrl, - transport: http.DefaultTransport, - }, - TokenManager: pmapi.NewTokenManager(), - } - return pmapi.NewClient(cfg, userID) + cm.SetRoundTripper(&fakeTransport{ + cntrl: cntrl, + transport: http.DefaultTransport, + }) + + return } diff --git a/test/liveapi/labels.go b/test/liveapi/labels.go index 1c7be0be..b715516c 100644 --- a/test/liveapi/labels.go +++ b/test/liveapi/labels.go @@ -36,11 +36,7 @@ var systemLabelNameToID = map[string]string{ //nolint[gochecknoglobals] } func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error { - client, ok := cntrl.pmapiByUsername[username] - if !ok { - return fmt.Errorf("user %s does not exist", username) - } - + client := cntrl.clientManager.GetClient(username) label.Exclusive = getLabelExclusive(label.Name) label.Name = getLabelNameWithoutPrefix(label.Name) label.Color = pmapi.LabelColors[0] @@ -67,11 +63,7 @@ func (cntrl *Controller) getLabelID(username, labelName string) (string, error) return labelID, nil } - client, ok := cntrl.pmapiByUsername[username] - if !ok { - return "", fmt.Errorf("user %s does not exist", username) - } - + client := cntrl.clientManager.GetClient(username) labels, err := client.ListLabels() if err != nil { return "", errors.Wrap(err, "failed to list labels") diff --git a/test/liveapi/messages.go b/test/liveapi/messages.go index 08ccdf82..3dcc8db7 100644 --- a/test/liveapi/messages.go +++ b/test/liveapi/messages.go @@ -31,10 +31,7 @@ import ( ) func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message) error { - client, ok := cntrl.pmapiByUsername[username] - if !ok { - return fmt.Errorf("user %s does not exist", username) - } + client := cntrl.clientManager.GetClient(username) body, err := buildMessage(client, message) if err != nil { @@ -64,7 +61,7 @@ func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message) return nil } -func buildMessage(client *pmapi.Client, message *pmapi.Message) (*bytes.Buffer, error) { +func buildMessage(client pmapi.Client, message *pmapi.Message) (*bytes.Buffer, error) { if err := encryptMessage(client, message); err != nil { return nil, errors.Wrap(err, "failed to encrypt message") } @@ -79,7 +76,7 @@ func buildMessage(client *pmapi.Client, message *pmapi.Message) (*bytes.Buffer, return body, nil } -func encryptMessage(client *pmapi.Client, message *pmapi.Message) error { +func encryptMessage(client pmapi.Client, message *pmapi.Message) error { addresses, err := client.GetAddresses() if err != nil { return errors.Wrap(err, "failed to get address") diff --git a/test/liveapi/users.go b/test/liveapi/users.go index 602f4adf..4238a091 100644 --- a/test/liveapi/users.go +++ b/test/liveapi/users.go @@ -18,9 +18,7 @@ package liveapi import ( - "fmt" - "os" - + "github.com/ProtonMail/bridge/pkg/pmapi" "github.com/cucumber/godog" "github.com/pkg/errors" ) @@ -30,11 +28,7 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, return godog.ErrPending } - client := pmapi.NewClient(&pmapi.ClientConfig{ - AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")), - ClientID: "bridge-cntrl", - TokenManager: pmapi.NewTokenManager(), - }, user.ID) + client := cntrl.clientManager.GetClient(user.ID) authInfo, err := client.AuthInfo(user.Name) if err != nil { @@ -60,6 +54,5 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, return errors.Wrap(err, "failed to clean user") } - cntrl.pmapiByUsername[user.Name] = client return nil }