feat: make store use ClientManager

This commit is contained in:
James Houlahan
2020-04-07 09:55:28 +02:00
parent f269be4291
commit 042c340881
43 changed files with 414 additions and 264 deletions

View File

@ -149,8 +149,8 @@ coverage: test
go tool cover -html=/tmp/coverage.out -o=coverage.html go tool cover -html=/tmp/coverage.out -o=coverage.html
mocks: mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,CredentialsStorer > internal/bridge/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer > internal/bridge/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser > internal/store/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,ClientManager,BridgeUser > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client > pkg/pmapi/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client > pkg/pmapi/mocks/mocks.go

View File

@ -48,7 +48,7 @@ type Bridge struct {
panicHandler PanicHandler panicHandler PanicHandler
events listener.Listener events listener.Listener
version string version string
clientManager *pmapi.ClientManager clientManager ClientManager
credStorer CredentialsStorer credStorer CredentialsStorer
storeCache *store.Cache storeCache *store.Cache
@ -76,7 +76,7 @@ func New(
panicHandler PanicHandler, panicHandler PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
version string, version string,
clientManager *pmapi.ClientManager, clientManager ClientManager,
credStorer CredentialsStorer, credStorer CredentialsStorer,
) *Bridge { ) *Bridge {
log.Trace("Creating new bridge") log.Trace("Creating new bridge")
@ -185,7 +185,7 @@ func (b *Bridge) watchAPIAuths() {
user, ok := b.hasUser(auth.UserID) user, ok := b.hasUser(auth.UserID)
if !ok { if !ok {
logrus.Info("User is not added to bridge yet") logrus.WithField("userID", auth.UserID).Info("User not available for auth update")
continue continue
} }

View File

@ -39,7 +39,7 @@ func TestBridgeFinishLoginBadPassword(t *testing.T) {
// Set up mocks for FinishLogin. // Set up mocks for FinishLogin.
err := errors.New("bad password") err := errors.New("bad password")
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err) m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err)
m.pmapiClient.EXPECT().Logout().Return(nil) m.pmapiClient.EXPECT().Logout()
checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err) checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err)
} }

View File

@ -56,6 +56,7 @@ func TestNewBridgeWithDisconnectedUser(t *testing.T) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2)
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")) m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
m.pmapiClient.EXPECT().Addresses().Return(nil) m.pmapiClient.EXPECT().Addresses().Return(nil)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient)
checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
} }
@ -66,13 +67,14 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) {
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")) m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout().Return(nil) m.pmapiClient.EXPECT().Logout()
m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
@ -84,10 +86,14 @@ func TestNewBridgeWithConnectedUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil) m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil) m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil) m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
// Set up mocks for store initialisation for the authorized user. // Set up mocks for store initialisation for the authorized user.
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil) m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil)
@ -97,10 +103,6 @@ func TestNewBridgeWithConnectedUser(t *testing.T) {
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil) m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
checkBridgeNew(t, m, []*credentials.Credentials{testCredentials}) checkBridgeNew(t, m, []*credentials.Credentials{testCredentials})
} }

View File

@ -129,11 +129,11 @@ type mocks struct {
config *bridgemocks.MockConfiger config *bridgemocks.MockConfiger
PanicHandler *bridgemocks.MockPanicHandler PanicHandler *bridgemocks.MockPanicHandler
prefProvider *bridgemocks.MockPreferenceProvider prefProvider *bridgemocks.MockPreferenceProvider
clientManager *bridgemocks.MockClientManager
credentialsStore *bridgemocks.MockCredentialsStorer credentialsStore *bridgemocks.MockCredentialsStorer
eventListener *MockListener eventListener *MockListener
pmapiClient *pmapimocks.MockClient pmapiClient *pmapimocks.MockClient
clientManager *pmapimocks.MockClientManager
storeCache *store.Cache storeCache *store.Cache
} }
@ -151,11 +151,11 @@ func initMocks(t *testing.T) mocks {
config: bridgemocks.NewMockConfiger(mockCtrl), config: bridgemocks.NewMockConfiger(mockCtrl),
PanicHandler: bridgemocks.NewMockPanicHandler(mockCtrl), PanicHandler: bridgemocks.NewMockPanicHandler(mockCtrl),
prefProvider: bridgemocks.NewMockPreferenceProvider(mockCtrl), prefProvider: bridgemocks.NewMockPreferenceProvider(mockCtrl),
clientManager: bridgemocks.NewMockClientManager(mockCtrl),
credentialsStore: bridgemocks.NewMockCredentialsStorer(mockCtrl), credentialsStore: bridgemocks.NewMockCredentialsStorer(mockCtrl),
eventListener: NewMockListener(mockCtrl), eventListener: NewMockListener(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl), pmapiClient: pmapimocks.NewMockClient(mockCtrl),
clientManager: pmapimocks.NewMockClientManager(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()), storeCache: store.NewCache(cacheFile.Name()),
} }
@ -214,7 +214,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge {
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes() m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes() m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient) m.clientManager.EXPECT().GetBridgeAuthChannel().Return(make(chan *pmapi.ClientAuth))
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore) bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore)

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,CredentialsStorer) // Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
@ -203,6 +203,95 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
} }
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method
func (m *MockClientManager) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
}
// DisallowProxy mocks base method
func (m *MockClientManager) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
}
// GetAnonymousClient mocks base method
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAnonymousClient")
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetAnonymousClient indicates an expected call of GetAnonymousClient
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
}
// GetBridgeAuthChannel mocks base method
func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBridgeAuthChannel")
ret0, _ := ret[0].(chan pmapi.ClientAuth)
return ret0
}
// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel
func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel))
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockCredentialsStorer is a mock of CredentialsStorer interface // MockCredentialsStorer is a mock of CredentialsStorer interface
type MockCredentialsStorer struct { type MockCredentialsStorer struct {
ctrl *gomock.Controller ctrl *gomock.Controller

View File

@ -51,3 +51,11 @@ type CredentialsStorer interface {
Logout(userID string) error Logout(userID string) error
Delete(userID string) error Delete(userID string) error
} }
type ClientManager interface {
GetClient(userID string) pmapi.Client
GetAnonymousClient() pmapi.Client
AllowProxy()
DisallowProxy()
GetBridgeAuthChannel() chan pmapi.ClientAuth
}

View File

@ -41,7 +41,7 @@ type User struct {
log *logrus.Entry log *logrus.Entry
panicHandler PanicHandler panicHandler PanicHandler
listener listener.Listener listener listener.Listener
clientManager *pmapi.ClientManager clientManager ClientManager
credStorer CredentialsStorer credStorer CredentialsStorer
imapUpdatesChannel chan interface{} imapUpdatesChannel chan interface{}
@ -66,7 +66,7 @@ func newUser(
userID string, userID string,
eventListener listener.Listener, eventListener listener.Listener,
credStorer CredentialsStorer, credStorer CredentialsStorer,
clientManager *pmapi.ClientManager, clientManager ClientManager,
storeCache *store.Cache, storeCache *store.Cache,
storeDir string, storeDir string,
) (u *User, err error) { ) (u *User, err error) {
@ -139,7 +139,7 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
} }
u.store = nil u.store = nil
} }
store, err := store.New(u.panicHandler, u, u.client(), u.listener, u.storePath, u.storeCache) store, err := store.New(u.panicHandler, u, u.clientManager, u.listener, u.storePath, u.storeCache)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to create store") return errors.Wrap(err, "failed to create store")
} }

View File

@ -33,7 +33,7 @@ func TestNewUserNoCredentialsStore(t *testing.T) {
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail")) m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
a.Error(t, err) a.Error(t, err)
} }
@ -153,10 +153,10 @@ func TestNewUser(t *testing.T) {
} }
func checkNewUser(m mocks) { func checkNewUser(m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
defer cleanUpUserData(user) defer cleanUpUserData(user)
_ = user.init(nil, m.pmapiClient) _ = user.init(nil)
waitForEvents() waitForEvents()
@ -164,10 +164,10 @@ func checkNewUser(m mocks) {
} }
func checkNewUserDisconnected(m mocks) { func checkNewUserDisconnected(m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
defer cleanUpUserData(user) defer cleanUpUserData(user)
_ = user.init(nil, m.pmapiClient) _ = user.init(nil)
waitForEvents() waitForEvents()

View File

@ -43,10 +43,10 @@ func testNewUser(m mocks) *User {
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil) m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err) assert.NoError(m.t, err)
err = user.init(nil, m.pmapiClient) err = user.init(nil)
assert.NoError(m.t, err) assert.NoError(m.t, err)
return user return user
@ -69,10 +69,10 @@ func testNewUserForLogout(m mocks) *User {
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp") user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err) assert.NoError(m.t, err)
err = user.init(nil, m.pmapiClient) err = user.init(nil)
assert.NoError(m.t, err) assert.NoError(m.t, err)
return user return user

View File

@ -109,5 +109,9 @@ func (storeAddress *Address) AddressID() string {
// APIAddress returns the `pmapi.Address` struct. // APIAddress returns the `pmapi.Address` struct.
func (storeAddress *Address) APIAddress() *pmapi.Address { func (storeAddress *Address) APIAddress() *pmapi.Address {
return storeAddress.store.api.Addresses().ByEmail(storeAddress.address) return storeAddress.client().Addresses().ByEmail(storeAddress.address)
}
func (storeAddress *Address) client() pmapi.Client {
return storeAddress.store.client()
} }

View File

@ -43,12 +43,11 @@ type eventLoop struct {
log *logrus.Entry log *logrus.Entry
store *Store store *Store
apiClient PMAPIProvider
user BridgeUser user BridgeUser
events listener.Listener events listener.Listener
} }
func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser, events listener.Listener) *eventLoop { func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop {
eventLog := log.WithField("userID", user.ID()) eventLog := log.WithField("userID", user.ID())
eventLog.Trace("Creating new event loop") eventLog.Trace("Creating new event loop")
@ -61,7 +60,6 @@ func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser
log: eventLog, log: eventLog,
store: store, store: store,
apiClient: api,
user: user, user: user,
events: events, events: events,
} }
@ -71,10 +69,14 @@ func (loop *eventLoop) IsRunning() bool {
return loop.isRunning return loop.isRunning
} }
func (loop *eventLoop) client() pmapi.Client {
return loop.store.client()
}
func (loop *eventLoop) setFirstEventID() (err error) { func (loop *eventLoop) setFirstEventID() (err error) {
loop.log.Info("Setting first event ID") loop.log.Info("Setting first event ID")
event, err := loop.apiClient.GetEvent("") event, err := loop.client().GetEvent("")
if err != nil { if err != nil {
loop.log.WithError(err).Error("Could not get latest event ID") loop.log.WithError(err).Error("Could not get latest event ID")
return return
@ -240,7 +242,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
loop.pollCounter++ loop.pollCounter++
var event *pmapi.Event var event *pmapi.Event
if event, err = loop.apiClient.GetEvent(loop.currentEventID); err != nil { if event, err = loop.client().GetEvent(loop.currentEventID); err != nil {
return false, errors.Wrap(err, "failed to get event") return false, errors.Wrap(err, "failed to get event")
} }
@ -321,7 +323,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
log.Debug("Processing address change event") log.Debug("Processing address change event")
// Get old addresses for comparisons before updating user. // Get old addresses for comparisons before updating user.
oldList := loop.apiClient.Addresses() oldList := loop.client().Addresses()
if err = loop.user.UpdateUser(); err != nil { if err = loop.user.UpdateUser(); err != nil {
if logoutErr := loop.user.Logout(); logoutErr != nil { if logoutErr := loop.user.Logout(); logoutErr != nil {
@ -363,7 +365,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
} }
} }
if err = loop.store.createOrUpdateAddressInfo(loop.apiClient.Addresses()); err != nil { if err = loop.store.createOrUpdateAddressInfo(loop.client().Addresses()); err != nil {
return errors.Wrap(err, "failed to update address IDs in store") return errors.Wrap(err, "failed to update address IDs in store")
} }
@ -430,7 +432,7 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...") msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
if msg, err = loop.apiClient.GetMessage(message.ID); err != nil { if msg, err = loop.client().GetMessage(message.ID); err != nil {
if err == pmapi.ErrNoSuchAPIID { if err == pmapi.ErrNoSuchAPIID {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API") msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil err = nil

View File

@ -39,21 +39,21 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
// Doesn't matter which IDs are used. // Doesn't matter which IDs are used.
// This test is trying to see whether event loop will immediately process // This test is trying to see whether event loop will immediately process
// next event if there is `More` of them. // next event if there is `More` of them.
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{ m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
EventID: "event50", EventID: "event50",
More: 1, More: 1,
}, nil), }, nil),
m.api.EXPECT().GetEvent("event50").Return(&pmapi.Event{ m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{
EventID: "event70", EventID: "event70",
More: 0, More: 0,
}, nil), }, nil),
m.api.EXPECT().GetEvent("event70").Return(&pmapi.Event{ m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{
EventID: "event71", EventID: "event71",
More: 0, More: 0,
}, nil), }, nil),
) )
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
m.api.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() m.client.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
// Event loop runs in goroutine and will be stopped by deferred mock clearing. // Event loop runs in goroutine and will be stopped by deferred mock clearing.
go m.store.eventLoop.start() go m.store.eventLoop.start()
@ -78,12 +78,12 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
newSubject := "new subject" newSubject := "new subject"
// First sync will add message with old subject to database. // First sync will add message with old subject to database.
m.api.EXPECT().GetMessage("msg1").Return(&pmapi.Message{ m.client.EXPECT().GetMessage("msg1").Return(&pmapi.Message{
ID: "msg1", ID: "msg1",
Subject: subject, Subject: subject,
}, nil) }, nil)
// Event will update the subject. // Event will update the subject.
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{ m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
EventID: "event1", EventID: "event1",
Messages: []*pmapi.EventMessage{{ Messages: []*pmapi.EventMessage{{
EventItem: pmapi.EventItem{ EventItem: pmapi.EventItem{

View File

@ -258,8 +258,8 @@ func (storeMailbox *Mailbox) pollNow() {
} }
// api is a proxy for the store's `PMAPIProvider`. // api is a proxy for the store's `PMAPIProvider`.
func (storeMailbox *Mailbox) api() PMAPIProvider { func (storeMailbox *Mailbox) client() pmapi.Client {
return storeMailbox.store.api return storeMailbox.store.client()
} }
// update is a proxy for the store's db's `Update`. // update is a proxy for the store's db's `Update`.

View File

@ -37,7 +37,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message // FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
// wrapping it. // wrapping it.
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) { func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.api().GetMessage(apiID) msg, err := storeMailbox.client().GetMessage(apiID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,7 +62,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
LabelIDs: labelIDs, LabelIDs: labelIDs,
} }
res, err := storeMailbox.api().Import([]*pmapi.ImportMsgReq{importReqs}) res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs})
if err == nil && len(res) > 0 { if err == nil && len(res) > 0 {
msg.ID = res[0].MessageID msg.ID = res[0].MessageID
} }
@ -79,7 +79,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Labeling messages") }).Trace("Labeling messages")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.api().LabelMessages(apiIDs, storeMailbox.labelID) return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID)
} }
// UnlabelMessages removes the label by calling an API. // UnlabelMessages removes the label by calling an API.
@ -92,7 +92,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Unlabeling messages") }).Trace("Unlabeling messages")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID) return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID)
} }
// MarkMessagesRead marks the message read by calling an API. // MarkMessagesRead marks the message read by calling an API.
@ -116,7 +116,7 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
ids = append(ids, apiID) ids = append(ids, apiID)
} }
} }
return storeMailbox.api().MarkMessagesRead(ids) return storeMailbox.client().MarkMessagesRead(ids)
} }
// MarkMessagesUnread marks the message unread by calling an API. // MarkMessagesUnread marks the message unread by calling an API.
@ -128,7 +128,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as unread") }).Trace("Marking messages as unread")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.api().MarkMessagesUnread(apiIDs) return storeMailbox.client().MarkMessagesUnread(apiIDs)
} }
// MarkMessagesStarred adds the Starred label by calling an API. // MarkMessagesStarred adds the Starred label by calling an API.
@ -141,7 +141,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as starred") }).Trace("Marking messages as starred")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.api().LabelMessages(apiIDs, pmapi.StarredLabel) return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel)
} }
// MarkMessagesUnstarred removes the Starred label by calling an API. // MarkMessagesUnstarred removes the Starred label by calling an API.
@ -154,7 +154,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as unstarred") }).Trace("Marking messages as unstarred")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.api().UnlabelMessages(apiIDs, pmapi.StarredLabel) return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel)
} }
// DeleteMessages deletes messages. // DeleteMessages deletes messages.
@ -197,21 +197,21 @@ func (storeMailbox *Mailbox) DeleteMessages(apiIDs []string) error {
} }
} }
if len(messageIDsToUnlabel) > 0 { if len(messageIDsToUnlabel) > 0 {
if err := storeMailbox.api().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil { if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
log.WithError(err).Warning("Cannot unlabel before deleting") log.WithError(err).Warning("Cannot unlabel before deleting")
} }
} }
if len(messageIDsToDelete) > 0 { if len(messageIDsToDelete) > 0 {
if err := storeMailbox.api().DeleteMessages(messageIDsToDelete); err != nil { if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil {
return err return err
} }
} }
case pmapi.DraftLabel: case pmapi.DraftLabel:
if err := storeMailbox.api().DeleteMessages(apiIDs); err != nil { if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil {
return err return err
} }
default: default:
if err := storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil { if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
return err return err
} }
} }

View File

@ -28,7 +28,6 @@ import (
// a specific mailbox with helper functions to get IMAP UID, sequence // a specific mailbox with helper functions to get IMAP UID, sequence
// numbers and similar. // numbers and similar.
type Message struct { type Message struct {
api PMAPIProvider
msg *pmapi.Message msg *pmapi.Message
store *Store store *Store
@ -37,7 +36,6 @@ type Message struct {
func newStoreMessage(storeMailbox *Mailbox, msg *pmapi.Message) *Message { func newStoreMessage(storeMailbox *Mailbox, msg *pmapi.Message) *Message {
return &Message{ return &Message{
api: storeMailbox.store.api,
msg: msg, msg: msg,
store: storeMailbox.store, store: storeMailbox.store,
storeMailbox: storeMailbox, storeMailbox: storeMailbox,

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser) // Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
@ -7,6 +7,7 @@ package mocks
import ( import (
reflect "reflect" reflect "reflect"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
@ -45,6 +46,43 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
} }
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockBridgeUser is a mock of BridgeUser interface // MockBridgeUser is a mock of BridgeUser interface
type MockBridgeUser struct { type MockBridgeUser struct {
ctrl *gomock.Controller ctrl *gomock.Controller

View File

@ -92,7 +92,7 @@ type Store struct {
panicHandler PanicHandler panicHandler PanicHandler
eventLoop *eventLoop eventLoop *eventLoop
user BridgeUser user BridgeUser
api PMAPIProvider clientManager ClientManager
log *logrus.Entry log *logrus.Entry
@ -111,13 +111,13 @@ type Store struct {
func New( func New(
panicHandler PanicHandler, panicHandler PanicHandler,
user BridgeUser, user BridgeUser,
api PMAPIProvider, clientManager ClientManager,
events listener.Listener, events listener.Listener,
path string, path string,
cache *Cache, cache *Cache,
) (store *Store, err error) { ) (store *Store, err error) {
if user == nil || api == nil || events == nil || cache == nil { if user == nil || clientManager == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, api, events, cache) return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache)
} }
l := log.WithField("user", user.ID()) l := log.WithField("user", user.ID())
@ -139,7 +139,7 @@ func New(
store = &Store{ store = &Store{
panicHandler: panicHandler, panicHandler: panicHandler,
api: api, clientManager: clientManager,
user: user, user: user,
cache: cache, cache: cache,
filePath: path, filePath: path,
@ -158,7 +158,7 @@ func New(
} }
if user.IsConnected() { if user.IsConnected() {
store.eventLoop = newEventLoop(cache, store, api, user, events) store.eventLoop = newEventLoop(cache, store, user, events)
go func() { go func() {
defer store.panicHandler.HandlePanic() defer store.panicHandler.HandlePanic()
store.eventLoop.start() store.eventLoop.start()
@ -261,10 +261,14 @@ func (store *Store) init(firstInit bool) (err error) {
return err return err
} }
func (store *Store) client() pmapi.Client {
return store.clientManager.GetClient(store.UserID())
}
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if // initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
// the API is unavailable for whatever reason it tries to fetch the labels locally. // the API is unavailable for whatever reason it tries to fetch the labels locally.
func (store *Store) initCounts() (labels []*pmapi.Label, err error) { func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
if labels, err = store.api.ListLabels(); err != nil { if labels, err = store.client().ListLabels(); err != nil {
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.") store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
if labels, err = store.getLabelsFromLocalStorage(); err != nil { if labels, err = store.getLabelsFromLocalStorage(); err != nil {
store.log.WithError(err).Error("Cannot list local labels") store.log.WithError(err).Error("Cannot list local labels")

View File

@ -24,9 +24,9 @@ import (
"sync" "sync"
"testing" "testing"
bridgemocks "github.com/ProtonMail/proton-bridge/internal/bridge/mocks" storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
storeMocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -44,10 +44,11 @@ type mocksForStore struct {
tb testing.TB tb testing.TB
ctrl *gomock.Controller ctrl *gomock.Controller
events *storeMocks.MockListener events *storemocks.MockListener
api *bridgemocks.MockPMAPIProvider user *storemocks.MockBridgeUser
user *storeMocks.MockBridgeUser client *pmapimocks.MockClient
panicHandler *storeMocks.MockPanicHandler clientManager *storemocks.MockClientManager
panicHandler *storemocks.MockPanicHandler
store *Store store *Store
tmpDir string tmpDir string
@ -59,10 +60,11 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
mocks := &mocksForStore{ mocks := &mocksForStore{
tb: tb, tb: tb,
ctrl: ctrl, ctrl: ctrl,
events: storeMocks.NewMockListener(ctrl), events: storemocks.NewMockListener(ctrl),
api: bridgemocks.NewMockPMAPIProvider(ctrl), user: storemocks.NewMockBridgeUser(ctrl),
user: storeMocks.NewMockBridgeUser(ctrl), client: pmapimocks.NewMockClient(ctrl),
panicHandler: storeMocks.NewMockPanicHandler(ctrl), clientManager: storemocks.NewMockClientManager(ctrl),
panicHandler: storemocks.NewMockPanicHandler(ctrl),
} }
// Called during clean-up. // Called during clean-up.
@ -92,13 +94,15 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
mocks.user.EXPECT().IsConnected().Return(true) mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode) mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.api.EXPECT().Addresses().Return(pmapi.AddressList{ mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client)
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive}, {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive}, {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
}) })
mocks.api.EXPECT().ListLabels() mocks.client.EXPECT().ListLabels()
mocks.api.EXPECT().CountMessages("") mocks.client.EXPECT().CountMessages("")
mocks.api.EXPECT().GetEvent(gomock.Any()). mocks.client.EXPECT().GetEvent(gomock.Any()).
Return(&pmapi.Event{ Return(&pmapi.Event{
EventID: "latestEventID", EventID: "latestEventID",
}, nil).AnyTimes() }, nil).AnyTimes()
@ -106,7 +110,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
// We want to wait until first sync has finished. // We want to wait until first sync has finished.
firstSyncWaiter := sync.WaitGroup{} firstSyncWaiter := sync.WaitGroup{}
firstSyncWaiter.Add(1) firstSyncWaiter.Add(1)
mocks.api.EXPECT(). mocks.client.EXPECT().
ListMessages(gomock.Any()). ListMessages(gomock.Any()).
DoAndReturn(func(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) { DoAndReturn(func(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
firstSyncWaiter.Done() firstSyncWaiter.Done()
@ -117,7 +121,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
mocks.store, err = New( mocks.store, err = New(
mocks.panicHandler, mocks.panicHandler,
mocks.user, mocks.user,
mocks.api, mocks.clientManager,
mocks.events, mocks.events,
filepath.Join(mocks.tmpDir, "mailbox-test.db"), filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache, mocks.cache,

View File

@ -17,42 +17,14 @@
package store package store
import ( import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"io"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type PanicHandler interface { type PanicHandler interface {
HandlePanic() HandlePanic()
} }
// PMAPIProvider is subset of pmapi.Client for use by the Store. type ClientManager interface {
type PMAPIProvider interface { GetClient(userID string) pmapi.Client
CurrentUser() (*pmapi.User, error)
Addresses() pmapi.AddressList
GetEvent(eventID string) (*pmapi.Event, error)
CountMessages(addressID string) ([]*pmapi.MessagesCount, error)
ListMessages(filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
GetMessage(apiID string) (*pmapi.Message, error)
Import([]*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error)
DeleteMessages(apiIDs []string) error
LabelMessages(apiIDs []string, labelID string) error
UnlabelMessages(apiIDs []string, labelID string) error
MarkMessagesRead(apiIDs []string) error
MarkMessagesUnread(apiIDs []string) error
CreateDraft(m *pmapi.Message, parent string, action int) (created *pmapi.Message, err error)
CreateAttachment(att *pmapi.Attachment, r io.Reader, sig io.Reader) (created *pmapi.Attachment, err error)
SendMessage(messageID string, req *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error)
ListLabels() ([]*pmapi.Label, error)
CreateLabel(label *pmapi.Label) (*pmapi.Label, error)
UpdateLabel(label *pmapi.Label) (*pmapi.Label, error)
DeleteLabel(labelID string) error
EmptyFolder(labelID string, addressID string) error
} }
// BridgeUser is subset of bridge.User for use by the Store. // BridgeUser is subset of bridge.User for use by the Store.

View File

@ -24,7 +24,7 @@ func (store *Store) UserID() string {
// GetSpace returns used and total space in bytes. // GetSpace returns used and total space in bytes.
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) { func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
apiUser, err := store.api.CurrentUser() apiUser, err := store.client().CurrentUser()
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
@ -33,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
// GetMaxUpload returns max size of attachment in bytes. // GetMaxUpload returns max size of attachment in bytes.
func (store *Store) GetMaxUpload() (uint, error) { func (store *Store) GetMaxUpload() (uint, error) {
apiUser, err := store.api.CurrentUser() apiUser, err := store.client().CurrentUser()
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -61,7 +61,7 @@ func (store *Store) GetAddressInfo() (addrs []AddressInfo, err error) {
} }
// Store does not have address info yet, need to build it first from API. // Store does not have address info yet, need to build it first from API.
addressList := store.api.Addresses() addressList := store.client().Addresses()
if addressList == nil { if addressList == nil {
err = errors.New("addresses unavailable") err = errors.New("addresses unavailable")
store.log.WithError(err).Error("Could not get user addresses from API") store.log.WithError(err).Error("Could not get user addresses from API")

View File

@ -55,7 +55,7 @@ func (store *Store) createMailbox(name string) error {
return nil return nil
} }
_, err := store.api.CreateLabel(&pmapi.Label{ _, err := store.client().CreateLabel(&pmapi.Label{
Name: name, Name: name,
Color: color, Color: color,
Exclusive: exclusive, Exclusive: exclusive,
@ -133,7 +133,7 @@ func (store *Store) leastUsedColor() string {
func (store *Store) updateMailbox(labelID, newName, color string) error { func (store *Store) updateMailbox(labelID, newName, color string) error {
defer store.eventLoop.pollNow() defer store.eventLoop.pollNow()
_, err := store.api.UpdateLabel(&pmapi.Label{ _, err := store.client().UpdateLabel(&pmapi.Label{
ID: labelID, ID: labelID,
Name: newName, Name: newName,
Color: color, Color: color,
@ -150,15 +150,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
var err error var err error
switch labelID { switch labelID {
case pmapi.SpamLabel: case pmapi.SpamLabel:
err = store.api.EmptyFolder(pmapi.SpamLabel, addressID) err = store.client().EmptyFolder(pmapi.SpamLabel, addressID)
case pmapi.TrashLabel: case pmapi.TrashLabel:
err = store.api.EmptyFolder(pmapi.TrashLabel, addressID) err = store.client().EmptyFolder(pmapi.TrashLabel, addressID)
default: default:
err = fmt.Errorf("cannot empty mailbox %v", labelID) err = fmt.Errorf("cannot empty mailbox %v", labelID)
} }
return err return err
} }
return store.api.DeleteLabel(labelID) return store.client().DeleteLabel(labelID)
} }
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error { func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
@ -173,7 +173,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
return nil return nil
} }
labels, err := store.api.ListLabels() labels, err := store.client().ListLabels()
if err != nil { if err != nil {
return err return err
} }

View File

@ -54,7 +54,7 @@ func (store *Store) CreateDraft(
message.Attachments = nil message.Attachments = nil
draftAction := store.getDraftAction(message) draftAction := store.getDraftAction(message)
draft, err := store.api.CreateDraft(message, parentID, draftAction) draft, err := store.client().CreateDraft(message, parentID, draftAction)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "failed to create draft") return nil, nil, errors.Wrap(err, "failed to create draft")
} }
@ -105,7 +105,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att
return nil, errors.Wrap(err, "failed to encrypt attachment") return nil, errors.Wrap(err, "failed to encrypt attachment")
} }
createdAttachment, err := store.api.CreateAttachment(attachment, encReader, sigReader) createdAttachment, err := store.client().CreateAttachment(attachment, encReader, sigReader)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to create attachment") return nil, errors.Wrap(err, "failed to create attachment")
} }
@ -116,7 +116,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att
// SendMessage sends the message. // SendMessage sends the message.
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error { func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
defer store.eventLoop.pollNow() defer store.eventLoop.pollNow()
_, _, err := store.api.SendMessage(messageID, req) _, _, err := store.client().SendMessage(messageID, req)
return err return err
} }

View File

@ -34,7 +34,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
// updateCountsFromServer will download and set the counts. // updateCountsFromServer will download and set the counts.
func (store *Store) updateCountsFromServer() error { func (store *Store) updateCountsFromServer() error {
counts, err := store.api.CountMessages("") counts, err := store.client().CountMessages("")
if err != nil { if err != nil {
return errors.Wrap(err, "cannot update counts from server") return errors.Wrap(err, "cannot update counts from server")
} }
@ -144,7 +144,8 @@ func (store *Store) triggerSync() {
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started") store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
err := syncAllMail(store.panicHandler, store, store.api, syncState) // TODO: Is it okay to pass in a client directly? What if it is logged out in the meantime?
err := syncAllMail(store.panicHandler, store, store.client(), syncState)
if err != nil { if err != nil {
log.WithError(err).Error("Store sync failed") log.WithError(err).Error("Store sync failed")
return return

View File

@ -464,15 +464,13 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
return auth, err return auth, err
} }
// Logout instructs the client manager to log out this client. // TODO: Should this even be a client method? Or just a method on the client manager?
func (c *client) Logout() { func (c *client) Logout() {
c.cm.LogoutClient(c.userID) c.cm.LogoutClient(c.userID)
} }
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not. // DeleteAuth deletes the API session.
func (c *client) DeleteAuth() (err error) {
// logout logs the current user out.
func (c *client) logout() (err error) {
req, err := c.NewRequest("DELETE", "/auth", nil) req, err := c.NewRequest("DELETE", "/auth", nil)
if err != nil { if err != nil {
return return
@ -490,7 +488,10 @@ func (c *client) logout() (err error) {
return return
} }
func (c *client) clearSensitiveData() { // TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
// ClearData clears sensitive data from the client.
func (c *client) ClearData() {
c.uid = "" c.uid = ""
c.accessToken = "" c.accessToken = ""
c.kr = nil c.kr = nil

View File

@ -98,20 +98,28 @@ type Client interface {
Auth(username, password string, info *AuthInfo) (*Auth, error) Auth(username, password string, info *AuthInfo) (*Auth, error)
AuthInfo(username string) (*AuthInfo, error) AuthInfo(username string) (*AuthInfo, error)
AuthRefresh(token string) (*Auth, error) AuthRefresh(token string) (*Auth, error)
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error) Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
UnlockAddresses(passphrase []byte) error Logout()
DeleteAuth() error
ClearData()
CurrentUser() (*User, error) CurrentUser() (*User, error)
UpdateUser() (*User, error) UpdateUser() (*User, error)
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error)
UnlockAddresses(passphrase []byte) error
GetAddresses() (addresses AddressList, err error)
Addresses() AddressList Addresses() AddressList
Logout()
GetEvent(eventID string) (*Event, error) GetEvent(eventID string) (*Event, error)
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
CountMessages(addressID string) ([]*MessagesCount, error) CountMessages(addressID string) ([]*MessagesCount, error)
ListMessages(filter *MessagesFilter) ([]*Message, int, error) ListMessages(filter *MessagesFilter) ([]*Message, int, error)
GetMessage(apiID string) (*Message, error) GetMessage(apiID string) (*Message, error)
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
DeleteMessages(apiIDs []string) error DeleteMessages(apiIDs []string) error
LabelMessages(apiIDs []string, labelID string) error LabelMessages(apiIDs []string, labelID string) error
UnlabelMessages(apiIDs []string, labelID string) error UnlabelMessages(apiIDs []string, labelID string) error
@ -128,20 +136,17 @@ type Client interface {
SendSimpleMetric(category, action, label string) error SendSimpleMetric(category, action, label string) error
ReportSentryCrash(reportErr error) (err error) ReportSentryCrash(reportErr error) (err error)
Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
GetMailSettings() (MailSettings, error) GetMailSettings() (MailSettings, error)
GetContactEmailByEmail(string, int, int) ([]ContactEmail, error) GetContactEmailByEmail(string, int, int) ([]ContactEmail, error)
GetContactByID(string) (Contact, error) GetContactByID(string) (Contact, error)
DecryptAndVerifyCards([]Card) ([]Card, error) DecryptAndVerifyCards([]Card) ([]Card, error)
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
DeleteAttachment(attID string) (err error)
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
GetAttachment(id string) (att io.ReadCloser, err error) GetAttachment(id string) (att io.ReadCloser, err error)
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
DeleteAttachment(attID string) (err error)
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
} }
// client is a client of the protonmail API. It implements the Client interface. // client is a client of the protonmail API. It implements the Client interface.

View File

@ -15,10 +15,17 @@ var defaultProxyUseDuration = 24 * time.Hour
// ClientManager is a manager of clients. // ClientManager is a manager of clients.
type ClientManager struct { type ClientManager struct {
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
// create other types of clients (e.g. for integration tests).
newClient func(userID string) Client
config *ClientConfig config *ClientConfig
roundTripper http.RoundTripper roundTripper http.RoundTripper
clients map[string]*client // TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
// But that screws up other things like not being able to clear sensitive info during logout
// unless the client interface contains a method for that.
clients map[string]Client
clientsLocker sync.Locker clientsLocker sync.Locker
tokens map[string]string tokens map[string]string
@ -38,11 +45,13 @@ type ClientManager struct {
proxyUseDuration time.Duration proxyUseDuration time.Duration
} }
// ClientAuth holds an API auth produced by a Client for a specific user.
type ClientAuth struct { type ClientAuth struct {
UserID string UserID string
Auth *Auth Auth *Auth
} }
// tokenExpiration manages the expiration of an access token.
type tokenExpiration struct { type tokenExpiration struct {
timer *time.Timer timer *time.Timer
cancel chan (struct{}) cancel chan (struct{})
@ -58,7 +67,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
config: config, config: config,
roundTripper: http.DefaultTransport, roundTripper: http.DefaultTransport,
clients: make(map[string]*client), clients: make(map[string]Client),
clientsLocker: &sync.Mutex{}, clientsLocker: &sync.Mutex{},
tokens: make(map[string]string), tokens: make(map[string]string),
@ -78,11 +87,19 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
proxyUseDuration: defaultProxyUseDuration, proxyUseDuration: defaultProxyUseDuration,
} }
cm.newClient = func(userID string) Client {
return newClient(cm, userID)
}
go cm.forwardClientAuths() go cm.forwardClientAuths()
return return
} }
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
cm.newClient = f
}
// SetRoundTripper sets the roundtripper used by clients created by this client manager. // SetRoundTripper sets the roundtripper used by clients created by this client manager.
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) { func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
cm.roundTripper = rt cm.roundTripper = rt
@ -100,7 +117,7 @@ func (cm *ClientManager) GetClient(userID string) Client {
return client return client
} }
cm.clients[userID] = newClient(cm, userID) cm.clients[userID] = cm.newClient(userID)
return cm.clients[userID] return cm.clients[userID]
} }
@ -108,10 +125,10 @@ func (cm *ClientManager) GetClient(userID string) Client {
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created. // GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
func (cm *ClientManager) GetAnonymousClient() Client { func (cm *ClientManager) GetAnonymousClient() Client {
if client, ok := cm.clients[""]; ok { if client, ok := cm.clients[""]; ok {
client.Logout() client.DeleteAuth()
} }
cm.clients[""] = newClient(cm, "") cm.clients[""] = cm.newClient("")
return cm.clients[""] return cm.clients[""]
} }
@ -127,10 +144,10 @@ func (cm *ClientManager) LogoutClient(userID string) {
delete(cm.clients, userID) delete(cm.clients, userID)
go func() { go func() {
if err := client.logout(); err != nil { if err := client.DeleteAuth(); err != nil {
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet). // TODO: Retry if the request failed.
} }
client.clearSensitiveData() client.ClearData()
cm.clearToken(userID) cm.clearToken(userID)
}() }()

View File

@ -256,6 +256,21 @@ func (mr *MockClientMockRecorder) EmptyFolder(arg0, arg1 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyFolder", reflect.TypeOf((*MockClient)(nil).EmptyFolder), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyFolder", reflect.TypeOf((*MockClient)(nil).EmptyFolder), arg0, arg1)
} }
// GetAddresses mocks base method
func (m *MockClient) GetAddresses() (pmapi.AddressList, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAddresses")
ret0, _ := ret[0].(pmapi.AddressList)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAddresses indicates an expected call of GetAddresses
func (mr *MockClientMockRecorder) GetAddresses() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockClient)(nil).GetAddresses))
}
// GetAttachment mocks base method // GetAttachment mocks base method
func (m *MockClient) GetAttachment(arg0 string) (io.ReadCloser, error) { func (m *MockClient) GetAttachment(arg0 string) (io.ReadCloser, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -21,9 +21,9 @@ import (
"os" "os"
"runtime" "runtime"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
) )
// GetBridge returns bridge instance. // GetBridge returns bridge instance.
@ -34,10 +34,7 @@ func (ctx *TestContext) GetBridge() *bridge.Bridge {
// withBridgeInstance creates a bridge instance for use in the test. // withBridgeInstance creates a bridge instance for use in the test.
// Every TestContext has this by default and thus this doesn't need to be exported. // Every TestContext has this by default and thus this doesn't need to be exported.
func (ctx *TestContext) withBridgeInstance() { func (ctx *TestContext) withBridgeInstance() {
pmapiFactory := func(userID string) pmapi.Client { ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, ctx.clientManager)
return ctx.pmapiController.GetClient(userID)
}
ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, pmapiFactory)
ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data") ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data")
} }
@ -62,7 +59,7 @@ func newBridgeInstance(
cfg *fakeConfig, cfg *fakeConfig,
credStore bridge.CredentialsStorer, credStore bridge.CredentialsStorer,
eventListener listener.Listener, eventListener listener.Listener,
pmapiFactory bridge.PMAPIProviderFactory, clientManager bridge.ClientManager,
) *bridge.Bridge { ) *bridge.Bridge {
version := os.Getenv("VERSION") version := os.Getenv("VERSION")
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "") bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")
@ -70,7 +67,7 @@ func newBridgeInstance(
panicHandler := &panicHandler{t: t} panicHandler := &panicHandler{t: t}
pref := preferences.New(cfg) pref := preferences.New(cfg)
return bridge.New(cfg, pref, panicHandler, eventListener, version, pmapiFactory, credStore) return bridge.New(cfg, pref, panicHandler, eventListener, version, clientManager, credStore)
} }
// SetLastBridgeError sets the last error that occurred while executing a bridge action. // SetLastBridgeError sets the last error that occurred while executing a bridge action.

View File

@ -28,7 +28,6 @@ import (
type fakeConfig struct { type fakeConfig struct {
dir string dir string
tm *pmapi.TokenManager
} }
// newFakeConfig creates a temporary folder for files. // newFakeConfig creates a temporary folder for files.
@ -41,7 +40,6 @@ func newFakeConfig() *fakeConfig {
return &fakeConfig{ return &fakeConfig{
dir: dir, dir: dir,
tm: pmapi.NewTokenManager(),
} }
} }
@ -53,8 +51,6 @@ func (c *fakeConfig) GetAPIConfig() *pmapi.ClientConfig {
AppVersion: "Bridge_" + os.Getenv("VERSION"), AppVersion: "Bridge_" + os.Getenv("VERSION"),
ClientID: "bridge", ClientID: "bridge",
SentryDSN: "", SentryDSN: "",
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
TokenManager: c.tm,
} }
} }
func (c *fakeConfig) GetDBDir() string { func (c *fakeConfig) GetDBDir() string {

View File

@ -21,6 +21,7 @@ package context
import ( import (
"github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/test/accounts" "github.com/ProtonMail/proton-bridge/test/accounts"
"github.com/ProtonMail/proton-bridge/test/mocks" "github.com/ProtonMail/proton-bridge/test/mocks"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -44,6 +45,7 @@ type TestContext struct {
bridge *bridge.Bridge bridge *bridge.Bridge
bridgeLastError error bridgeLastError error
credStore bridge.CredentialsStorer credStore bridge.CredentialsStorer
clientManager *pmapi.ClientManager
// IMAP related variables. // IMAP related variables.
imapAddr string imapAddr string
@ -70,11 +72,14 @@ func New() *TestContext {
cfg := newFakeConfig() cfg := newFakeConfig()
cm := pmapi.NewClientManager(cfg.GetAPIConfig())
ctx := &TestContext{ ctx := &TestContext{
t: &bddT{}, t: &bddT{},
cfg: cfg, cfg: cfg,
listener: listener.New(), listener: listener.New(),
pmapiController: newPMAPIController(), pmapiController: newPMAPIController(cm),
clientManager: cm,
testAccounts: newTestAccounts(), testAccounts: newTestAccounts(),
credStore: newFakeCredStore(), credStore: newFakeCredStore(),
imapClients: make(map[string]*mocks.IMAPClient), imapClients: make(map[string]*mocks.IMAPClient),

View File

@ -20,14 +20,12 @@ package context
import ( import (
"os" "os"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/test/fakeapi" "github.com/ProtonMail/proton-bridge/test/fakeapi"
"github.com/ProtonMail/proton-bridge/test/liveapi" "github.com/ProtonMail/proton-bridge/test/liveapi"
) )
type PMAPIController interface { type PMAPIController interface {
GetClient(userID string) bridge.PMAPIProvider
TurnInternetConnectionOff() TurnInternetConnectionOff()
TurnInternetConnectionOn() TurnInternetConnectionOn()
AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error
@ -40,19 +38,19 @@ type PMAPIController interface {
GetCalls(method, path string) [][]byte GetCalls(method, path string) [][]byte
} }
func newPMAPIController() PMAPIController { func newPMAPIController(cm *pmapi.ClientManager) PMAPIController {
switch os.Getenv(EnvName) { switch os.Getenv(EnvName) {
case EnvFake: case EnvFake:
return newFakePMAPIController() return newFakePMAPIController(cm)
case EnvLive: case EnvLive:
return newLivePMAPIController() return newLivePMAPIController(cm)
default: default:
panic("unknown env") panic("unknown env")
} }
} }
func newFakePMAPIController() PMAPIController { func newFakePMAPIController(cm *pmapi.ClientManager) PMAPIController {
return newFakePMAPIControllerWrap(fakeapi.NewController()) return newFakePMAPIControllerWrap(fakeapi.NewController(cm))
} }
type fakePMAPIControllerWrap struct { type fakePMAPIControllerWrap struct {
@ -63,12 +61,8 @@ func newFakePMAPIControllerWrap(controller *fakeapi.Controller) PMAPIController
return &fakePMAPIControllerWrap{Controller: controller} return &fakePMAPIControllerWrap{Controller: controller}
} }
func (s *fakePMAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider { func newLivePMAPIController(cm *pmapi.ClientManager) PMAPIController {
return s.Controller.GetClient(userID) return newLiveAPIControllerWrap(liveapi.NewController(cm))
}
func newLivePMAPIController() PMAPIController {
return newLiveAPIControllerWrap(liveapi.NewController())
} }
type liveAPIControllerWrap struct { type liveAPIControllerWrap struct {
@ -78,7 +72,3 @@ type liveAPIControllerWrap struct {
func newLiveAPIControllerWrap(controller *liveapi.Controller) PMAPIController { func newLiveAPIControllerWrap(controller *liveapi.Controller) PMAPIController {
return &liveAPIControllerWrap{Controller: controller} return &liveAPIControllerWrap{Controller: controller}
} }
func (s *liveAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider {
return s.Controller.GetClient(userID)
}

View File

@ -45,3 +45,11 @@ func (api *FakePMAPI) CreateAttachment(attachment *pmapi.Attachment, data io.Rea
attachment.KeyPackets = base64.StdEncoding.EncodeToString(bytes) attachment.KeyPackets = base64.StdEncoding.EncodeToString(bytes)
return attachment, nil return attachment, nil
} }
func (api *FakePMAPI) DeleteAttachment(attachmentID string) error {
if err := api.checkAndRecordCall(GET, "/attachments/"+attachmentID, nil); err != nil {
return err
}
return nil
}

View File

@ -141,13 +141,23 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) {
return auth, nil return auth, nil
} }
func (api *FakePMAPI) Logout() error { func (api *FakePMAPI) Logout() {
_ = api.DeleteAuth()
api.ClearData()
}
func (api *FakePMAPI) DeleteAuth() error {
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil { if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
return err return err
} }
// Logout will also emit change to auth channel // Logout will also emit change to auth channel
api.sendAuth(nil) api.sendAuth(nil)
api.controller.deleteSession(api.uid)
api.unsetUser()
return nil return nil
} }
func (api *FakePMAPI) ClearData() {
api.controller.deleteSession(api.uid)
api.unsetUser()
}

View File

@ -44,8 +44,8 @@ type Controller struct {
log *logrus.Entry log *logrus.Entry
} }
func NewController() *Controller { func NewController(cm *pmapi.ClientManager) (cntrl *Controller) {
return &Controller{ cntrl = &Controller{
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
fakeAPIs: []*FakePMAPI{}, fakeAPIs: []*FakePMAPI{},
calls: []*fakeCall{}, calls: []*fakeCall{},
@ -62,10 +62,12 @@ func NewController() *Controller {
log: logrus.WithField("pkg", "fakeapi-controller"), log: logrus.WithField("pkg", "fakeapi-controller"),
} }
}
func (cntrl *Controller) GetClient(userID string) *FakePMAPI { cm.SetClientConstructor(func(userID string) pmapi.Client {
fakeAPI := New(cntrl) fakeAPI := New(cntrl)
cntrl.fakeAPIs = append(cntrl.fakeAPIs, fakeAPI) cntrl.fakeAPIs = append(cntrl.fakeAPIs, fakeAPI)
return fakeAPI return fakeAPI
})
return
} }

View File

@ -42,3 +42,7 @@ func (api *FakePMAPI) SendSimpleMetric(category, action, label string) error {
v.Set("Label", label) v.Set("Label", label)
return api.checkInternetAndRecordCall(GET, "/metrics?"+v.Encode(), nil) return api.checkInternetAndRecordCall(GET, "/metrics?"+v.Encode(), nil)
} }
func (api *FakePMAPI) ReportSentryCrash(reportErr error) (err error) {
return nil
}

View File

@ -50,6 +50,10 @@ func (api *FakePMAPI) UpdateUser() (*pmapi.User, error) {
return api.user, nil return api.user, nil
} }
func (api *FakePMAPI) GetAddresses() (pmapi.AddressList, error) {
return *api.addresses, nil
}
func (api *FakePMAPI) Addresses() pmapi.AddressList { func (api *FakePMAPI) Addresses() pmapi.AddressList {
return *api.addresses return *api.addresses
} }

View File

@ -24,7 +24,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func cleanup(client *pmapi.Client) error { func cleanup(client pmapi.Client) error {
if err := cleanSystemFolders(client); err != nil { if err := cleanSystemFolders(client); err != nil {
return errors.Wrap(err, "failed to clean system folders") return errors.Wrap(err, "failed to clean system folders")
} }
@ -37,7 +37,7 @@ func cleanup(client *pmapi.Client) error {
return nil return nil
} }
func cleanSystemFolders(client *pmapi.Client) error { func cleanSystemFolders(client pmapi.Client) error {
for _, labelID := range []string{pmapi.InboxLabel, pmapi.SentLabel, pmapi.ArchiveLabel, pmapi.AllMailLabel, pmapi.DraftLabel} { for _, labelID := range []string{pmapi.InboxLabel, pmapi.SentLabel, pmapi.ArchiveLabel, pmapi.AllMailLabel, pmapi.DraftLabel} {
for { for {
messages, total, err := client.ListMessages(&pmapi.MessagesFilter{ messages, total, err := client.ListMessages(&pmapi.MessagesFilter{
@ -69,7 +69,7 @@ func cleanSystemFolders(client *pmapi.Client) error {
return nil return nil
} }
func cleanCustomLables(client *pmapi.Client) error { func cleanCustomLables(client pmapi.Client) error {
labels, err := client.ListLabels() labels, err := client.ListLabels()
if err != nil { if err != nil {
return errors.Wrap(err, "failed to list labels") return errors.Wrap(err, "failed to list labels")
@ -87,7 +87,7 @@ func cleanCustomLables(client *pmapi.Client) error {
return nil return nil
} }
func cleanTrash(client *pmapi.Client) error { func cleanTrash(client pmapi.Client) error {
for { for {
_, total, err := client.ListMessages(&pmapi.MessagesFilter{ _, total, err := client.ListMessages(&pmapi.MessagesFilter{
PageSize: 1, PageSize: 1,
@ -110,7 +110,7 @@ func cleanTrash(client *pmapi.Client) error {
return nil return nil
} }
func emptyFolder(client *pmapi.Client, labelID string) error { func emptyFolder(client pmapi.Client, labelID string) error {
err := client.EmptyFolder(labelID, "") err := client.EmptyFolder(labelID, "")
if err != nil { if err != nil {
return err return err

View File

@ -18,9 +18,7 @@
package liveapi package liveapi
import ( import (
"fmt"
"net/http" "net/http"
"os"
"sync" "sync"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
@ -30,33 +28,27 @@ type Controller struct {
// Internal states. // Internal states.
lock *sync.RWMutex lock *sync.RWMutex
calls []*fakeCall calls []*fakeCall
pmapiByUsername map[string]*pmapi.Client
messageIDsByUsername map[string][]string messageIDsByUsername map[string][]string
clientManager *pmapi.ClientManager
// State controlled by test. // State controlled by test.
noInternetConnection bool noInternetConnection bool
} }
func NewController() *Controller { func NewController(cm *pmapi.ClientManager) (cntrl *Controller) {
return &Controller{ cntrl = &Controller{
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
calls: []*fakeCall{}, calls: []*fakeCall{},
pmapiByUsername: map[string]*pmapi.Client{},
messageIDsByUsername: map[string][]string{}, messageIDsByUsername: map[string][]string{},
clientManager: cm,
noInternetConnection: false, noInternetConnection: false,
} }
}
func (cntrl *Controller) GetClient(userID string) *pmapi.Client { cm.SetRoundTripper(&fakeTransport{
cfg := &pmapi.ClientConfig{
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
ClientID: "bridge-test",
Transport: &fakeTransport{
cntrl: cntrl, cntrl: cntrl,
transport: http.DefaultTransport, transport: http.DefaultTransport,
}, })
TokenManager: pmapi.NewTokenManager(),
} return
return pmapi.NewClient(cfg, userID)
} }

View File

@ -36,11 +36,7 @@ var systemLabelNameToID = map[string]string{ //nolint[gochecknoglobals]
} }
func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error { func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error {
client, ok := cntrl.pmapiByUsername[username] client := cntrl.clientManager.GetClient(username)
if !ok {
return fmt.Errorf("user %s does not exist", username)
}
label.Exclusive = getLabelExclusive(label.Name) label.Exclusive = getLabelExclusive(label.Name)
label.Name = getLabelNameWithoutPrefix(label.Name) label.Name = getLabelNameWithoutPrefix(label.Name)
label.Color = pmapi.LabelColors[0] label.Color = pmapi.LabelColors[0]
@ -67,11 +63,7 @@ func (cntrl *Controller) getLabelID(username, labelName string) (string, error)
return labelID, nil return labelID, nil
} }
client, ok := cntrl.pmapiByUsername[username] client := cntrl.clientManager.GetClient(username)
if !ok {
return "", fmt.Errorf("user %s does not exist", username)
}
labels, err := client.ListLabels() labels, err := client.ListLabels()
if err != nil { if err != nil {
return "", errors.Wrap(err, "failed to list labels") return "", errors.Wrap(err, "failed to list labels")

View File

@ -31,10 +31,7 @@ import (
) )
func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message) error { func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message) error {
client, ok := cntrl.pmapiByUsername[username] client := cntrl.clientManager.GetClient(username)
if !ok {
return fmt.Errorf("user %s does not exist", username)
}
body, err := buildMessage(client, message) body, err := buildMessage(client, message)
if err != nil { if err != nil {
@ -64,7 +61,7 @@ func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message)
return nil return nil
} }
func buildMessage(client *pmapi.Client, message *pmapi.Message) (*bytes.Buffer, error) { func buildMessage(client pmapi.Client, message *pmapi.Message) (*bytes.Buffer, error) {
if err := encryptMessage(client, message); err != nil { if err := encryptMessage(client, message); err != nil {
return nil, errors.Wrap(err, "failed to encrypt message") return nil, errors.Wrap(err, "failed to encrypt message")
} }
@ -79,7 +76,7 @@ func buildMessage(client *pmapi.Client, message *pmapi.Message) (*bytes.Buffer,
return body, nil return body, nil
} }
func encryptMessage(client *pmapi.Client, message *pmapi.Message) error { func encryptMessage(client pmapi.Client, message *pmapi.Message) error {
addresses, err := client.GetAddresses() addresses, err := client.GetAddresses()
if err != nil { if err != nil {
return errors.Wrap(err, "failed to get address") return errors.Wrap(err, "failed to get address")

View File

@ -18,9 +18,7 @@
package liveapi package liveapi
import ( import (
"fmt" "github.com/ProtonMail/bridge/pkg/pmapi"
"os"
"github.com/cucumber/godog" "github.com/cucumber/godog"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -30,11 +28,7 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
return godog.ErrPending return godog.ErrPending
} }
client := pmapi.NewClient(&pmapi.ClientConfig{ client := cntrl.clientManager.GetClient(user.ID)
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
ClientID: "bridge-cntrl",
TokenManager: pmapi.NewTokenManager(),
}, user.ID)
authInfo, err := client.AuthInfo(user.Name) authInfo, err := client.AuthInfo(user.Name)
if err != nil { if err != nil {
@ -60,6 +54,5 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
return errors.Wrap(err, "failed to clean user") return errors.Wrap(err, "failed to clean user")
} }
cntrl.pmapiByUsername[user.Name] = client
return nil return nil
} }