mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
feat: make store use ClientManager
This commit is contained in:
4
Makefile
4
Makefile
@ -149,8 +149,8 @@ coverage: test
|
||||
go tool cover -html=/tmp/coverage.out -o=coverage.html
|
||||
|
||||
mocks:
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,CredentialsStorer > internal/bridge/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser > internal/store/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer > internal/bridge/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,ClientManager,BridgeUser > internal/store/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client > pkg/pmapi/mocks/mocks.go
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ type Bridge struct {
|
||||
panicHandler PanicHandler
|
||||
events listener.Listener
|
||||
version string
|
||||
clientManager *pmapi.ClientManager
|
||||
clientManager ClientManager
|
||||
credStorer CredentialsStorer
|
||||
storeCache *store.Cache
|
||||
|
||||
@ -76,7 +76,7 @@ func New(
|
||||
panicHandler PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
version string,
|
||||
clientManager *pmapi.ClientManager,
|
||||
clientManager ClientManager,
|
||||
credStorer CredentialsStorer,
|
||||
) *Bridge {
|
||||
log.Trace("Creating new bridge")
|
||||
@ -185,7 +185,7 @@ func (b *Bridge) watchAPIAuths() {
|
||||
|
||||
user, ok := b.hasUser(auth.UserID)
|
||||
if !ok {
|
||||
logrus.Info("User is not added to bridge yet")
|
||||
logrus.WithField("userID", auth.UserID).Info("User not available for auth update")
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ func TestBridgeFinishLoginBadPassword(t *testing.T) {
|
||||
// Set up mocks for FinishLogin.
|
||||
err := errors.New("bad password")
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err)
|
||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
|
||||
checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err)
|
||||
}
|
||||
|
||||
@ -56,6 +56,7 @@ func TestNewBridgeWithDisconnectedUser(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient)
|
||||
|
||||
checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
@ -66,13 +67,14 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) {
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
@ -84,10 +86,14 @@ func TestNewBridgeWithConnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil)
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
// Set up mocks for store initialisation for the authorized user.
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil)
|
||||
@ -97,10 +103,6 @@ func TestNewBridgeWithConnectedUser(t *testing.T) {
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||
|
||||
checkBridgeNew(t, m, []*credentials.Credentials{testCredentials})
|
||||
}
|
||||
|
||||
|
||||
@ -129,11 +129,11 @@ type mocks struct {
|
||||
config *bridgemocks.MockConfiger
|
||||
PanicHandler *bridgemocks.MockPanicHandler
|
||||
prefProvider *bridgemocks.MockPreferenceProvider
|
||||
clientManager *bridgemocks.MockClientManager
|
||||
credentialsStore *bridgemocks.MockCredentialsStorer
|
||||
eventListener *MockListener
|
||||
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
clientManager *pmapimocks.MockClientManager
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
|
||||
storeCache *store.Cache
|
||||
}
|
||||
@ -151,11 +151,11 @@ func initMocks(t *testing.T) mocks {
|
||||
config: bridgemocks.NewMockConfiger(mockCtrl),
|
||||
PanicHandler: bridgemocks.NewMockPanicHandler(mockCtrl),
|
||||
prefProvider: bridgemocks.NewMockPreferenceProvider(mockCtrl),
|
||||
clientManager: bridgemocks.NewMockClientManager(mockCtrl),
|
||||
credentialsStore: bridgemocks.NewMockCredentialsStorer(mockCtrl),
|
||||
eventListener: NewMockListener(mockCtrl),
|
||||
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
clientManager: pmapimocks.NewMockClientManager(mockCtrl),
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
|
||||
storeCache: store.NewCache(cacheFile.Name()),
|
||||
}
|
||||
@ -214,7 +214,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge {
|
||||
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
|
||||
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
|
||||
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient)
|
||||
m.clientManager.EXPECT().GetBridgeAuthChannel().Return(make(chan *pmapi.ClientAuth))
|
||||
|
||||
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore)
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,CredentialsStorer)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -203,6 +203,95 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method
|
||||
func (m *MockClientManager) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy
|
||||
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method
|
||||
func (m *MockClientManager) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy
|
||||
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// GetAnonymousClient mocks base method
|
||||
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAnonymousClient")
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAnonymousClient indicates an expected call of GetAnonymousClient
|
||||
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
|
||||
}
|
||||
|
||||
// GetBridgeAuthChannel mocks base method
|
||||
func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetBridgeAuthChannel")
|
||||
ret0, _ := ret[0].(chan pmapi.ClientAuth)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel
|
||||
func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel))
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockCredentialsStorer is a mock of CredentialsStorer interface
|
||||
type MockCredentialsStorer struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@ -51,3 +51,11 @@ type CredentialsStorer interface {
|
||||
Logout(userID string) error
|
||||
Delete(userID string) error
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
GetAnonymousClient() pmapi.Client
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
GetBridgeAuthChannel() chan pmapi.ClientAuth
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ type User struct {
|
||||
log *logrus.Entry
|
||||
panicHandler PanicHandler
|
||||
listener listener.Listener
|
||||
clientManager *pmapi.ClientManager
|
||||
clientManager ClientManager
|
||||
credStorer CredentialsStorer
|
||||
|
||||
imapUpdatesChannel chan interface{}
|
||||
@ -66,7 +66,7 @@ func newUser(
|
||||
userID string,
|
||||
eventListener listener.Listener,
|
||||
credStorer CredentialsStorer,
|
||||
clientManager *pmapi.ClientManager,
|
||||
clientManager ClientManager,
|
||||
storeCache *store.Cache,
|
||||
storeDir string,
|
||||
) (u *User, err error) {
|
||||
@ -139,7 +139,7 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
|
||||
}
|
||||
u.store = nil
|
||||
}
|
||||
store, err := store.New(u.panicHandler, u, u.client(), u.listener, u.storePath, u.storeCache)
|
||||
store, err := store.New(u.panicHandler, u, u.clientManager, u.listener, u.storePath, u.storeCache)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create store")
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ func TestNewUserNoCredentialsStore(t *testing.T) {
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
|
||||
|
||||
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
|
||||
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
a.Error(t, err)
|
||||
}
|
||||
|
||||
@ -153,10 +153,10 @@ func TestNewUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func checkNewUser(m mocks) {
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
_ = user.init(nil, m.pmapiClient)
|
||||
_ = user.init(nil)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
@ -164,10 +164,10 @@ func checkNewUser(m mocks) {
|
||||
}
|
||||
|
||||
func checkNewUserDisconnected(m mocks) {
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
_ = user.init(nil, m.pmapiClient)
|
||||
_ = user.init(nil)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
|
||||
@ -43,10 +43,10 @@ func testNewUser(m mocks) *User {
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
err = user.init(nil, m.pmapiClient)
|
||||
err = user.init(nil)
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
return user
|
||||
@ -69,10 +69,10 @@ func testNewUserForLogout(m mocks) *User {
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
err = user.init(nil, m.pmapiClient)
|
||||
err = user.init(nil)
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
return user
|
||||
|
||||
@ -109,5 +109,9 @@ func (storeAddress *Address) AddressID() string {
|
||||
|
||||
// APIAddress returns the `pmapi.Address` struct.
|
||||
func (storeAddress *Address) APIAddress() *pmapi.Address {
|
||||
return storeAddress.store.api.Addresses().ByEmail(storeAddress.address)
|
||||
return storeAddress.client().Addresses().ByEmail(storeAddress.address)
|
||||
}
|
||||
|
||||
func (storeAddress *Address) client() pmapi.Client {
|
||||
return storeAddress.store.client()
|
||||
}
|
||||
|
||||
@ -42,13 +42,12 @@ type eventLoop struct {
|
||||
|
||||
log *logrus.Entry
|
||||
|
||||
store *Store
|
||||
apiClient PMAPIProvider
|
||||
user BridgeUser
|
||||
events listener.Listener
|
||||
store *Store
|
||||
user BridgeUser
|
||||
events listener.Listener
|
||||
}
|
||||
|
||||
func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser, events listener.Listener) *eventLoop {
|
||||
func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop {
|
||||
eventLog := log.WithField("userID", user.ID())
|
||||
eventLog.Trace("Creating new event loop")
|
||||
|
||||
@ -60,10 +59,9 @@ func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser
|
||||
|
||||
log: eventLog,
|
||||
|
||||
store: store,
|
||||
apiClient: api,
|
||||
user: user,
|
||||
events: events,
|
||||
store: store,
|
||||
user: user,
|
||||
events: events,
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,10 +69,14 @@ func (loop *eventLoop) IsRunning() bool {
|
||||
return loop.isRunning
|
||||
}
|
||||
|
||||
func (loop *eventLoop) client() pmapi.Client {
|
||||
return loop.store.client()
|
||||
}
|
||||
|
||||
func (loop *eventLoop) setFirstEventID() (err error) {
|
||||
loop.log.Info("Setting first event ID")
|
||||
|
||||
event, err := loop.apiClient.GetEvent("")
|
||||
event, err := loop.client().GetEvent("")
|
||||
if err != nil {
|
||||
loop.log.WithError(err).Error("Could not get latest event ID")
|
||||
return
|
||||
@ -240,7 +242,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
loop.pollCounter++
|
||||
|
||||
var event *pmapi.Event
|
||||
if event, err = loop.apiClient.GetEvent(loop.currentEventID); err != nil {
|
||||
if event, err = loop.client().GetEvent(loop.currentEventID); err != nil {
|
||||
return false, errors.Wrap(err, "failed to get event")
|
||||
}
|
||||
|
||||
@ -321,7 +323,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
|
||||
log.Debug("Processing address change event")
|
||||
|
||||
// Get old addresses for comparisons before updating user.
|
||||
oldList := loop.apiClient.Addresses()
|
||||
oldList := loop.client().Addresses()
|
||||
|
||||
if err = loop.user.UpdateUser(); err != nil {
|
||||
if logoutErr := loop.user.Logout(); logoutErr != nil {
|
||||
@ -363,7 +365,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
|
||||
}
|
||||
}
|
||||
|
||||
if err = loop.store.createOrUpdateAddressInfo(loop.apiClient.Addresses()); err != nil {
|
||||
if err = loop.store.createOrUpdateAddressInfo(loop.client().Addresses()); err != nil {
|
||||
return errors.Wrap(err, "failed to update address IDs in store")
|
||||
}
|
||||
|
||||
@ -430,7 +432,7 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
|
||||
|
||||
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
|
||||
|
||||
if msg, err = loop.apiClient.GetMessage(message.ID); err != nil {
|
||||
if msg, err = loop.client().GetMessage(message.ID); err != nil {
|
||||
if err == pmapi.ErrNoSuchAPIID {
|
||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||
err = nil
|
||||
|
||||
@ -39,21 +39,21 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
|
||||
// Doesn't matter which IDs are used.
|
||||
// This test is trying to see whether event loop will immediately process
|
||||
// next event if there is `More` of them.
|
||||
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
EventID: "event50",
|
||||
More: 1,
|
||||
}, nil),
|
||||
m.api.EXPECT().GetEvent("event50").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{
|
||||
EventID: "event70",
|
||||
More: 0,
|
||||
}, nil),
|
||||
m.api.EXPECT().GetEvent("event70").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{
|
||||
EventID: "event71",
|
||||
More: 0,
|
||||
}, nil),
|
||||
)
|
||||
m.newStoreNoEvents(true)
|
||||
m.api.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
m.client.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
|
||||
// Event loop runs in goroutine and will be stopped by deferred mock clearing.
|
||||
go m.store.eventLoop.start()
|
||||
@ -78,12 +78,12 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
|
||||
newSubject := "new subject"
|
||||
|
||||
// First sync will add message with old subject to database.
|
||||
m.api.EXPECT().GetMessage("msg1").Return(&pmapi.Message{
|
||||
m.client.EXPECT().GetMessage("msg1").Return(&pmapi.Message{
|
||||
ID: "msg1",
|
||||
Subject: subject,
|
||||
}, nil)
|
||||
// Event will update the subject.
|
||||
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
EventID: "event1",
|
||||
Messages: []*pmapi.EventMessage{{
|
||||
EventItem: pmapi.EventItem{
|
||||
|
||||
@ -258,8 +258,8 @@ func (storeMailbox *Mailbox) pollNow() {
|
||||
}
|
||||
|
||||
// api is a proxy for the store's `PMAPIProvider`.
|
||||
func (storeMailbox *Mailbox) api() PMAPIProvider {
|
||||
return storeMailbox.store.api
|
||||
func (storeMailbox *Mailbox) client() pmapi.Client {
|
||||
return storeMailbox.store.client()
|
||||
}
|
||||
|
||||
// update is a proxy for the store's db's `Update`.
|
||||
|
||||
@ -37,7 +37,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
|
||||
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
|
||||
// wrapping it.
|
||||
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
|
||||
msg, err := storeMailbox.api().GetMessage(apiID)
|
||||
msg, err := storeMailbox.client().GetMessage(apiID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -62,7 +62,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
|
||||
LabelIDs: labelIDs,
|
||||
}
|
||||
|
||||
res, err := storeMailbox.api().Import([]*pmapi.ImportMsgReq{importReqs})
|
||||
res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs})
|
||||
if err == nil && len(res) > 0 {
|
||||
msg.ID = res[0].MessageID
|
||||
}
|
||||
@ -79,7 +79,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Labeling messages")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().LabelMessages(apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// UnlabelMessages removes the label by calling an API.
|
||||
@ -92,7 +92,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Unlabeling messages")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// MarkMessagesRead marks the message read by calling an API.
|
||||
@ -116,7 +116,7 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
|
||||
ids = append(ids, apiID)
|
||||
}
|
||||
}
|
||||
return storeMailbox.api().MarkMessagesRead(ids)
|
||||
return storeMailbox.client().MarkMessagesRead(ids)
|
||||
}
|
||||
|
||||
// MarkMessagesUnread marks the message unread by calling an API.
|
||||
@ -128,7 +128,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unread")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().MarkMessagesUnread(apiIDs)
|
||||
return storeMailbox.client().MarkMessagesUnread(apiIDs)
|
||||
}
|
||||
|
||||
// MarkMessagesStarred adds the Starred label by calling an API.
|
||||
@ -141,7 +141,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as starred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().LabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesUnstarred removes the Starred label by calling an API.
|
||||
@ -154,7 +154,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unstarred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().UnlabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// DeleteMessages deletes messages.
|
||||
@ -197,21 +197,21 @@ func (storeMailbox *Mailbox) DeleteMessages(apiIDs []string) error {
|
||||
}
|
||||
}
|
||||
if len(messageIDsToUnlabel) > 0 {
|
||||
if err := storeMailbox.api().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
log.WithError(err).Warning("Cannot unlabel before deleting")
|
||||
}
|
||||
}
|
||||
if len(messageIDsToDelete) > 0 {
|
||||
if err := storeMailbox.api().DeleteMessages(messageIDsToDelete); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case pmapi.DraftLabel:
|
||||
if err := storeMailbox.api().DeleteMessages(apiIDs); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if err := storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,7 +28,6 @@ import (
|
||||
// a specific mailbox with helper functions to get IMAP UID, sequence
|
||||
// numbers and similar.
|
||||
type Message struct {
|
||||
api PMAPIProvider
|
||||
msg *pmapi.Message
|
||||
|
||||
store *Store
|
||||
@ -37,7 +36,6 @@ type Message struct {
|
||||
|
||||
func newStoreMessage(storeMailbox *Mailbox, msg *pmapi.Message) *Message {
|
||||
return &Message{
|
||||
api: storeMailbox.store.api,
|
||||
msg: msg,
|
||||
store: storeMailbox.store,
|
||||
storeMailbox: storeMailbox,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -7,6 +7,7 @@ package mocks
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
@ -45,6 +46,43 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockBridgeUser is a mock of BridgeUser interface
|
||||
type MockBridgeUser struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@ -89,10 +89,10 @@ var (
|
||||
|
||||
// Store is local user storage, which handles the synchronization between IMAP and PM API.
|
||||
type Store struct {
|
||||
panicHandler PanicHandler
|
||||
eventLoop *eventLoop
|
||||
user BridgeUser
|
||||
api PMAPIProvider
|
||||
panicHandler PanicHandler
|
||||
eventLoop *eventLoop
|
||||
user BridgeUser
|
||||
clientManager ClientManager
|
||||
|
||||
log *logrus.Entry
|
||||
|
||||
@ -111,13 +111,13 @@ type Store struct {
|
||||
func New(
|
||||
panicHandler PanicHandler,
|
||||
user BridgeUser,
|
||||
api PMAPIProvider,
|
||||
clientManager ClientManager,
|
||||
events listener.Listener,
|
||||
path string,
|
||||
cache *Cache,
|
||||
) (store *Store, err error) {
|
||||
if user == nil || api == nil || events == nil || cache == nil {
|
||||
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, api, events, cache)
|
||||
if user == nil || clientManager == nil || events == nil || cache == nil {
|
||||
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache)
|
||||
}
|
||||
|
||||
l := log.WithField("user", user.ID())
|
||||
@ -138,14 +138,14 @@ func New(
|
||||
}
|
||||
|
||||
store = &Store{
|
||||
panicHandler: panicHandler,
|
||||
api: api,
|
||||
user: user,
|
||||
cache: cache,
|
||||
filePath: path,
|
||||
db: bdb,
|
||||
lock: &sync.RWMutex{},
|
||||
log: l,
|
||||
panicHandler: panicHandler,
|
||||
clientManager: clientManager,
|
||||
user: user,
|
||||
cache: cache,
|
||||
filePath: path,
|
||||
db: bdb,
|
||||
lock: &sync.RWMutex{},
|
||||
log: l,
|
||||
}
|
||||
|
||||
if err = store.init(firstInit); err != nil {
|
||||
@ -158,7 +158,7 @@ func New(
|
||||
}
|
||||
|
||||
if user.IsConnected() {
|
||||
store.eventLoop = newEventLoop(cache, store, api, user, events)
|
||||
store.eventLoop = newEventLoop(cache, store, user, events)
|
||||
go func() {
|
||||
defer store.panicHandler.HandlePanic()
|
||||
store.eventLoop.start()
|
||||
@ -261,10 +261,14 @@ func (store *Store) init(firstInit bool) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *Store) client() pmapi.Client {
|
||||
return store.clientManager.GetClient(store.UserID())
|
||||
}
|
||||
|
||||
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
|
||||
// the API is unavailable for whatever reason it tries to fetch the labels locally.
|
||||
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
|
||||
if labels, err = store.api.ListLabels(); err != nil {
|
||||
if labels, err = store.client().ListLabels(); err != nil {
|
||||
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
|
||||
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
|
||||
store.log.WithError(err).Error("Cannot list local labels")
|
||||
|
||||
@ -24,9 +24,9 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
bridgemocks "github.com/ProtonMail/proton-bridge/internal/bridge/mocks"
|
||||
storeMocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
|
||||
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -43,12 +43,13 @@ const (
|
||||
type mocksForStore struct {
|
||||
tb testing.TB
|
||||
|
||||
ctrl *gomock.Controller
|
||||
events *storeMocks.MockListener
|
||||
api *bridgemocks.MockPMAPIProvider
|
||||
user *storeMocks.MockBridgeUser
|
||||
panicHandler *storeMocks.MockPanicHandler
|
||||
store *Store
|
||||
ctrl *gomock.Controller
|
||||
events *storemocks.MockListener
|
||||
user *storemocks.MockBridgeUser
|
||||
client *pmapimocks.MockClient
|
||||
clientManager *storemocks.MockClientManager
|
||||
panicHandler *storemocks.MockPanicHandler
|
||||
store *Store
|
||||
|
||||
tmpDir string
|
||||
cache *Cache
|
||||
@ -57,12 +58,13 @@ type mocksForStore struct {
|
||||
func initMocks(tb testing.TB) (*mocksForStore, func()) {
|
||||
ctrl := gomock.NewController(tb)
|
||||
mocks := &mocksForStore{
|
||||
tb: tb,
|
||||
ctrl: ctrl,
|
||||
events: storeMocks.NewMockListener(ctrl),
|
||||
api: bridgemocks.NewMockPMAPIProvider(ctrl),
|
||||
user: storeMocks.NewMockBridgeUser(ctrl),
|
||||
panicHandler: storeMocks.NewMockPanicHandler(ctrl),
|
||||
tb: tb,
|
||||
ctrl: ctrl,
|
||||
events: storemocks.NewMockListener(ctrl),
|
||||
user: storemocks.NewMockBridgeUser(ctrl),
|
||||
client: pmapimocks.NewMockClient(ctrl),
|
||||
clientManager: storemocks.NewMockClientManager(ctrl),
|
||||
panicHandler: storemocks.NewMockPanicHandler(ctrl),
|
||||
}
|
||||
|
||||
// Called during clean-up.
|
||||
@ -92,13 +94,15 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
|
||||
mocks.user.EXPECT().IsConnected().Return(true)
|
||||
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
|
||||
|
||||
mocks.api.EXPECT().Addresses().Return(pmapi.AddressList{
|
||||
mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client)
|
||||
|
||||
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
|
||||
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
|
||||
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
|
||||
})
|
||||
mocks.api.EXPECT().ListLabels()
|
||||
mocks.api.EXPECT().CountMessages("")
|
||||
mocks.api.EXPECT().GetEvent(gomock.Any()).
|
||||
mocks.client.EXPECT().ListLabels()
|
||||
mocks.client.EXPECT().CountMessages("")
|
||||
mocks.client.EXPECT().GetEvent(gomock.Any()).
|
||||
Return(&pmapi.Event{
|
||||
EventID: "latestEventID",
|
||||
}, nil).AnyTimes()
|
||||
@ -106,7 +110,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
|
||||
// We want to wait until first sync has finished.
|
||||
firstSyncWaiter := sync.WaitGroup{}
|
||||
firstSyncWaiter.Add(1)
|
||||
mocks.api.EXPECT().
|
||||
mocks.client.EXPECT().
|
||||
ListMessages(gomock.Any()).
|
||||
DoAndReturn(func(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
|
||||
firstSyncWaiter.Done()
|
||||
@ -117,7 +121,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
|
||||
mocks.store, err = New(
|
||||
mocks.panicHandler,
|
||||
mocks.user,
|
||||
mocks.api,
|
||||
mocks.clientManager,
|
||||
mocks.events,
|
||||
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
|
||||
mocks.cache,
|
||||
|
||||
@ -17,42 +17,14 @@
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
|
||||
type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
|
||||
// PMAPIProvider is subset of pmapi.Client for use by the Store.
|
||||
type PMAPIProvider interface {
|
||||
CurrentUser() (*pmapi.User, error)
|
||||
Addresses() pmapi.AddressList
|
||||
|
||||
GetEvent(eventID string) (*pmapi.Event, error)
|
||||
|
||||
CountMessages(addressID string) ([]*pmapi.MessagesCount, error)
|
||||
ListMessages(filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
|
||||
GetMessage(apiID string) (*pmapi.Message, error)
|
||||
Import([]*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error)
|
||||
DeleteMessages(apiIDs []string) error
|
||||
LabelMessages(apiIDs []string, labelID string) error
|
||||
UnlabelMessages(apiIDs []string, labelID string) error
|
||||
MarkMessagesRead(apiIDs []string) error
|
||||
MarkMessagesUnread(apiIDs []string) error
|
||||
|
||||
CreateDraft(m *pmapi.Message, parent string, action int) (created *pmapi.Message, err error)
|
||||
CreateAttachment(att *pmapi.Attachment, r io.Reader, sig io.Reader) (created *pmapi.Attachment, err error)
|
||||
SendMessage(messageID string, req *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error)
|
||||
|
||||
ListLabels() ([]*pmapi.Label, error)
|
||||
CreateLabel(label *pmapi.Label) (*pmapi.Label, error)
|
||||
UpdateLabel(label *pmapi.Label) (*pmapi.Label, error)
|
||||
DeleteLabel(labelID string) error
|
||||
EmptyFolder(labelID string, addressID string) error
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
}
|
||||
|
||||
// BridgeUser is subset of bridge.User for use by the Store.
|
||||
|
||||
@ -24,7 +24,7 @@ func (store *Store) UserID() string {
|
||||
|
||||
// GetSpace returns used and total space in bytes.
|
||||
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
apiUser, err := store.api.CurrentUser()
|
||||
apiUser, err := store.client().CurrentUser()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
@ -33,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
|
||||
// GetMaxUpload returns max size of attachment in bytes.
|
||||
func (store *Store) GetMaxUpload() (uint, error) {
|
||||
apiUser, err := store.api.CurrentUser()
|
||||
apiUser, err := store.client().CurrentUser()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -61,7 +61,7 @@ func (store *Store) GetAddressInfo() (addrs []AddressInfo, err error) {
|
||||
}
|
||||
|
||||
// Store does not have address info yet, need to build it first from API.
|
||||
addressList := store.api.Addresses()
|
||||
addressList := store.client().Addresses()
|
||||
if addressList == nil {
|
||||
err = errors.New("addresses unavailable")
|
||||
store.log.WithError(err).Error("Could not get user addresses from API")
|
||||
|
||||
@ -55,7 +55,7 @@ func (store *Store) createMailbox(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := store.api.CreateLabel(&pmapi.Label{
|
||||
_, err := store.client().CreateLabel(&pmapi.Label{
|
||||
Name: name,
|
||||
Color: color,
|
||||
Exclusive: exclusive,
|
||||
@ -133,7 +133,7 @@ func (store *Store) leastUsedColor() string {
|
||||
func (store *Store) updateMailbox(labelID, newName, color string) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
_, err := store.api.UpdateLabel(&pmapi.Label{
|
||||
_, err := store.client().UpdateLabel(&pmapi.Label{
|
||||
ID: labelID,
|
||||
Name: newName,
|
||||
Color: color,
|
||||
@ -150,15 +150,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
|
||||
var err error
|
||||
switch labelID {
|
||||
case pmapi.SpamLabel:
|
||||
err = store.api.EmptyFolder(pmapi.SpamLabel, addressID)
|
||||
err = store.client().EmptyFolder(pmapi.SpamLabel, addressID)
|
||||
case pmapi.TrashLabel:
|
||||
err = store.api.EmptyFolder(pmapi.TrashLabel, addressID)
|
||||
err = store.client().EmptyFolder(pmapi.TrashLabel, addressID)
|
||||
default:
|
||||
err = fmt.Errorf("cannot empty mailbox %v", labelID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return store.api.DeleteLabel(labelID)
|
||||
return store.client().DeleteLabel(labelID)
|
||||
}
|
||||
|
||||
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
|
||||
@ -173,7 +173,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
labels, err := store.api.ListLabels()
|
||||
labels, err := store.client().ListLabels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -54,7 +54,7 @@ func (store *Store) CreateDraft(
|
||||
message.Attachments = nil
|
||||
|
||||
draftAction := store.getDraftAction(message)
|
||||
draft, err := store.api.CreateDraft(message, parentID, draftAction)
|
||||
draft, err := store.client().CreateDraft(message, parentID, draftAction)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create draft")
|
||||
}
|
||||
@ -105,7 +105,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att
|
||||
return nil, errors.Wrap(err, "failed to encrypt attachment")
|
||||
}
|
||||
|
||||
createdAttachment, err := store.api.CreateAttachment(attachment, encReader, sigReader)
|
||||
createdAttachment, err := store.client().CreateAttachment(attachment, encReader, sigReader)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create attachment")
|
||||
}
|
||||
@ -116,7 +116,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att
|
||||
// SendMessage sends the message.
|
||||
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
_, _, err := store.api.SendMessage(messageID, req)
|
||||
_, _, err := store.client().SendMessage(messageID, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
|
||||
|
||||
// updateCountsFromServer will download and set the counts.
|
||||
func (store *Store) updateCountsFromServer() error {
|
||||
counts, err := store.api.CountMessages("")
|
||||
counts, err := store.client().CountMessages("")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot update counts from server")
|
||||
}
|
||||
@ -144,7 +144,8 @@ func (store *Store) triggerSync() {
|
||||
|
||||
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
|
||||
|
||||
err := syncAllMail(store.panicHandler, store, store.api, syncState)
|
||||
// TODO: Is it okay to pass in a client directly? What if it is logged out in the meantime?
|
||||
err := syncAllMail(store.panicHandler, store, store.client(), syncState)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Store sync failed")
|
||||
return
|
||||
|
||||
@ -464,15 +464,13 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
||||
return auth, err
|
||||
}
|
||||
|
||||
// Logout instructs the client manager to log out this client.
|
||||
// TODO: Should this even be a client method? Or just a method on the client manager?
|
||||
func (c *client) Logout() {
|
||||
c.cm.LogoutClient(c.userID)
|
||||
}
|
||||
|
||||
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
|
||||
|
||||
// logout logs the current user out.
|
||||
func (c *client) logout() (err error) {
|
||||
// DeleteAuth deletes the API session.
|
||||
func (c *client) DeleteAuth() (err error) {
|
||||
req, err := c.NewRequest("DELETE", "/auth", nil)
|
||||
if err != nil {
|
||||
return
|
||||
@ -490,7 +488,10 @@ func (c *client) logout() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) clearSensitiveData() {
|
||||
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
|
||||
|
||||
// ClearData clears sensitive data from the client.
|
||||
func (c *client) ClearData() {
|
||||
c.uid = ""
|
||||
c.accessToken = ""
|
||||
c.kr = nil
|
||||
|
||||
@ -98,20 +98,28 @@ type Client interface {
|
||||
Auth(username, password string, info *AuthInfo) (*Auth, error)
|
||||
AuthInfo(username string) (*AuthInfo, error)
|
||||
AuthRefresh(token string) (*Auth, error)
|
||||
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error)
|
||||
UnlockAddresses(passphrase []byte) error
|
||||
Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
|
||||
Logout()
|
||||
DeleteAuth() error
|
||||
ClearData()
|
||||
|
||||
CurrentUser() (*User, error)
|
||||
UpdateUser() (*User, error)
|
||||
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error)
|
||||
UnlockAddresses(passphrase []byte) error
|
||||
|
||||
GetAddresses() (addresses AddressList, err error)
|
||||
Addresses() AddressList
|
||||
|
||||
Logout()
|
||||
|
||||
GetEvent(eventID string) (*Event, error)
|
||||
|
||||
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
|
||||
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
|
||||
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
|
||||
|
||||
CountMessages(addressID string) ([]*MessagesCount, error)
|
||||
ListMessages(filter *MessagesFilter) ([]*Message, int, error)
|
||||
GetMessage(apiID string) (*Message, error)
|
||||
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
|
||||
DeleteMessages(apiIDs []string) error
|
||||
LabelMessages(apiIDs []string, labelID string) error
|
||||
UnlabelMessages(apiIDs []string, labelID string) error
|
||||
@ -128,20 +136,17 @@ type Client interface {
|
||||
SendSimpleMetric(category, action, label string) error
|
||||
ReportSentryCrash(reportErr error) (err error)
|
||||
|
||||
Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
|
||||
|
||||
GetMailSettings() (MailSettings, error)
|
||||
GetContactEmailByEmail(string, int, int) ([]ContactEmail, error)
|
||||
GetContactByID(string) (Contact, error)
|
||||
DecryptAndVerifyCards([]Card) ([]Card, error)
|
||||
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
|
||||
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
|
||||
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
|
||||
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
|
||||
DeleteAttachment(attID string) (err error)
|
||||
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
|
||||
|
||||
GetAttachment(id string) (att io.ReadCloser, err error)
|
||||
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
|
||||
DeleteAttachment(attID string) (err error)
|
||||
|
||||
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
|
||||
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
|
||||
}
|
||||
|
||||
// client is a client of the protonmail API. It implements the Client interface.
|
||||
|
||||
@ -15,10 +15,17 @@ var defaultProxyUseDuration = 24 * time.Hour
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
|
||||
// create other types of clients (e.g. for integration tests).
|
||||
newClient func(userID string) Client
|
||||
|
||||
config *ClientConfig
|
||||
roundTripper http.RoundTripper
|
||||
|
||||
clients map[string]*client
|
||||
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
|
||||
// But that screws up other things like not being able to clear sensitive info during logout
|
||||
// unless the client interface contains a method for that.
|
||||
clients map[string]Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
tokens map[string]string
|
||||
@ -38,11 +45,13 @@ type ClientManager struct {
|
||||
proxyUseDuration time.Duration
|
||||
}
|
||||
|
||||
// ClientAuth holds an API auth produced by a Client for a specific user.
|
||||
type ClientAuth struct {
|
||||
UserID string
|
||||
Auth *Auth
|
||||
}
|
||||
|
||||
// tokenExpiration manages the expiration of an access token.
|
||||
type tokenExpiration struct {
|
||||
timer *time.Timer
|
||||
cancel chan (struct{})
|
||||
@ -58,7 +67,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
config: config,
|
||||
roundTripper: http.DefaultTransport,
|
||||
|
||||
clients: make(map[string]*client),
|
||||
clients: make(map[string]Client),
|
||||
clientsLocker: &sync.Mutex{},
|
||||
|
||||
tokens: make(map[string]string),
|
||||
@ -78,11 +87,19 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
proxyUseDuration: defaultProxyUseDuration,
|
||||
}
|
||||
|
||||
cm.newClient = func(userID string) Client {
|
||||
return newClient(cm, userID)
|
||||
}
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
|
||||
cm.newClient = f
|
||||
}
|
||||
|
||||
// SetRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
|
||||
cm.roundTripper = rt
|
||||
@ -100,7 +117,7 @@ func (cm *ClientManager) GetClient(userID string) Client {
|
||||
return client
|
||||
}
|
||||
|
||||
cm.clients[userID] = newClient(cm, userID)
|
||||
cm.clients[userID] = cm.newClient(userID)
|
||||
|
||||
return cm.clients[userID]
|
||||
}
|
||||
@ -108,10 +125,10 @@ func (cm *ClientManager) GetClient(userID string) Client {
|
||||
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
|
||||
func (cm *ClientManager) GetAnonymousClient() Client {
|
||||
if client, ok := cm.clients[""]; ok {
|
||||
client.Logout()
|
||||
client.DeleteAuth()
|
||||
}
|
||||
|
||||
cm.clients[""] = newClient(cm, "")
|
||||
cm.clients[""] = cm.newClient("")
|
||||
|
||||
return cm.clients[""]
|
||||
}
|
||||
@ -127,10 +144,10 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
||||
delete(cm.clients, userID)
|
||||
|
||||
go func() {
|
||||
if err := client.logout(); err != nil {
|
||||
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
|
||||
if err := client.DeleteAuth(); err != nil {
|
||||
// TODO: Retry if the request failed.
|
||||
}
|
||||
client.clearSensitiveData()
|
||||
client.ClearData()
|
||||
cm.clearToken(userID)
|
||||
}()
|
||||
|
||||
|
||||
@ -256,6 +256,21 @@ func (mr *MockClientMockRecorder) EmptyFolder(arg0, arg1 interface{}) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyFolder", reflect.TypeOf((*MockClient)(nil).EmptyFolder), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetAddresses mocks base method
|
||||
func (m *MockClient) GetAddresses() (pmapi.AddressList, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAddresses")
|
||||
ret0, _ := ret[0].(pmapi.AddressList)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAddresses indicates an expected call of GetAddresses
|
||||
func (mr *MockClientMockRecorder) GetAddresses() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockClient)(nil).GetAddresses))
|
||||
}
|
||||
|
||||
// GetAttachment mocks base method
|
||||
func (m *MockClient) GetAttachment(arg0 string) (io.ReadCloser, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -21,9 +21,9 @@ import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
// GetBridge returns bridge instance.
|
||||
@ -34,10 +34,7 @@ func (ctx *TestContext) GetBridge() *bridge.Bridge {
|
||||
// withBridgeInstance creates a bridge instance for use in the test.
|
||||
// Every TestContext has this by default and thus this doesn't need to be exported.
|
||||
func (ctx *TestContext) withBridgeInstance() {
|
||||
pmapiFactory := func(userID string) pmapi.Client {
|
||||
return ctx.pmapiController.GetClient(userID)
|
||||
}
|
||||
ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, pmapiFactory)
|
||||
ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, ctx.clientManager)
|
||||
ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data")
|
||||
}
|
||||
|
||||
@ -62,7 +59,7 @@ func newBridgeInstance(
|
||||
cfg *fakeConfig,
|
||||
credStore bridge.CredentialsStorer,
|
||||
eventListener listener.Listener,
|
||||
pmapiFactory bridge.PMAPIProviderFactory,
|
||||
clientManager bridge.ClientManager,
|
||||
) *bridge.Bridge {
|
||||
version := os.Getenv("VERSION")
|
||||
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")
|
||||
@ -70,7 +67,7 @@ func newBridgeInstance(
|
||||
panicHandler := &panicHandler{t: t}
|
||||
pref := preferences.New(cfg)
|
||||
|
||||
return bridge.New(cfg, pref, panicHandler, eventListener, version, pmapiFactory, credStore)
|
||||
return bridge.New(cfg, pref, panicHandler, eventListener, version, clientManager, credStore)
|
||||
}
|
||||
|
||||
// SetLastBridgeError sets the last error that occurred while executing a bridge action.
|
||||
|
||||
@ -28,7 +28,6 @@ import (
|
||||
|
||||
type fakeConfig struct {
|
||||
dir string
|
||||
tm *pmapi.TokenManager
|
||||
}
|
||||
|
||||
// newFakeConfig creates a temporary folder for files.
|
||||
@ -41,7 +40,6 @@ func newFakeConfig() *fakeConfig {
|
||||
|
||||
return &fakeConfig{
|
||||
dir: dir,
|
||||
tm: pmapi.NewTokenManager(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,8 +51,6 @@ func (c *fakeConfig) GetAPIConfig() *pmapi.ClientConfig {
|
||||
AppVersion: "Bridge_" + os.Getenv("VERSION"),
|
||||
ClientID: "bridge",
|
||||
SentryDSN: "",
|
||||
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
|
||||
TokenManager: c.tm,
|
||||
}
|
||||
}
|
||||
func (c *fakeConfig) GetDBDir() string {
|
||||
|
||||
@ -21,6 +21,7 @@ package context
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/test/accounts"
|
||||
"github.com/ProtonMail/proton-bridge/test/mocks"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -44,6 +45,7 @@ type TestContext struct {
|
||||
bridge *bridge.Bridge
|
||||
bridgeLastError error
|
||||
credStore bridge.CredentialsStorer
|
||||
clientManager *pmapi.ClientManager
|
||||
|
||||
// IMAP related variables.
|
||||
imapAddr string
|
||||
@ -70,11 +72,14 @@ func New() *TestContext {
|
||||
|
||||
cfg := newFakeConfig()
|
||||
|
||||
cm := pmapi.NewClientManager(cfg.GetAPIConfig())
|
||||
|
||||
ctx := &TestContext{
|
||||
t: &bddT{},
|
||||
cfg: cfg,
|
||||
listener: listener.New(),
|
||||
pmapiController: newPMAPIController(),
|
||||
pmapiController: newPMAPIController(cm),
|
||||
clientManager: cm,
|
||||
testAccounts: newTestAccounts(),
|
||||
credStore: newFakeCredStore(),
|
||||
imapClients: make(map[string]*mocks.IMAPClient),
|
||||
|
||||
@ -20,14 +20,12 @@ package context
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/test/fakeapi"
|
||||
"github.com/ProtonMail/proton-bridge/test/liveapi"
|
||||
)
|
||||
|
||||
type PMAPIController interface {
|
||||
GetClient(userID string) bridge.PMAPIProvider
|
||||
TurnInternetConnectionOff()
|
||||
TurnInternetConnectionOn()
|
||||
AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error
|
||||
@ -40,19 +38,19 @@ type PMAPIController interface {
|
||||
GetCalls(method, path string) [][]byte
|
||||
}
|
||||
|
||||
func newPMAPIController() PMAPIController {
|
||||
func newPMAPIController(cm *pmapi.ClientManager) PMAPIController {
|
||||
switch os.Getenv(EnvName) {
|
||||
case EnvFake:
|
||||
return newFakePMAPIController()
|
||||
return newFakePMAPIController(cm)
|
||||
case EnvLive:
|
||||
return newLivePMAPIController()
|
||||
return newLivePMAPIController(cm)
|
||||
default:
|
||||
panic("unknown env")
|
||||
}
|
||||
}
|
||||
|
||||
func newFakePMAPIController() PMAPIController {
|
||||
return newFakePMAPIControllerWrap(fakeapi.NewController())
|
||||
func newFakePMAPIController(cm *pmapi.ClientManager) PMAPIController {
|
||||
return newFakePMAPIControllerWrap(fakeapi.NewController(cm))
|
||||
}
|
||||
|
||||
type fakePMAPIControllerWrap struct {
|
||||
@ -63,12 +61,8 @@ func newFakePMAPIControllerWrap(controller *fakeapi.Controller) PMAPIController
|
||||
return &fakePMAPIControllerWrap{Controller: controller}
|
||||
}
|
||||
|
||||
func (s *fakePMAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider {
|
||||
return s.Controller.GetClient(userID)
|
||||
}
|
||||
|
||||
func newLivePMAPIController() PMAPIController {
|
||||
return newLiveAPIControllerWrap(liveapi.NewController())
|
||||
func newLivePMAPIController(cm *pmapi.ClientManager) PMAPIController {
|
||||
return newLiveAPIControllerWrap(liveapi.NewController(cm))
|
||||
}
|
||||
|
||||
type liveAPIControllerWrap struct {
|
||||
@ -78,7 +72,3 @@ type liveAPIControllerWrap struct {
|
||||
func newLiveAPIControllerWrap(controller *liveapi.Controller) PMAPIController {
|
||||
return &liveAPIControllerWrap{Controller: controller}
|
||||
}
|
||||
|
||||
func (s *liveAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider {
|
||||
return s.Controller.GetClient(userID)
|
||||
}
|
||||
|
||||
@ -45,3 +45,11 @@ func (api *FakePMAPI) CreateAttachment(attachment *pmapi.Attachment, data io.Rea
|
||||
attachment.KeyPackets = base64.StdEncoding.EncodeToString(bytes)
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) DeleteAttachment(attachmentID string) error {
|
||||
if err := api.checkAndRecordCall(GET, "/attachments/"+attachmentID, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -141,13 +141,23 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) Logout() error {
|
||||
func (api *FakePMAPI) Logout() {
|
||||
_ = api.DeleteAuth()
|
||||
api.ClearData()
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) DeleteAuth() error {
|
||||
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Logout will also emit change to auth channel
|
||||
api.sendAuth(nil)
|
||||
api.controller.deleteSession(api.uid)
|
||||
api.unsetUser()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) ClearData() {
|
||||
api.controller.deleteSession(api.uid)
|
||||
api.unsetUser()
|
||||
}
|
||||
|
||||
@ -44,8 +44,8 @@ type Controller struct {
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
func NewController() *Controller {
|
||||
return &Controller{
|
||||
func NewController(cm *pmapi.ClientManager) (cntrl *Controller) {
|
||||
cntrl = &Controller{
|
||||
lock: &sync.RWMutex{},
|
||||
fakeAPIs: []*FakePMAPI{},
|
||||
calls: []*fakeCall{},
|
||||
@ -62,10 +62,12 @@ func NewController() *Controller {
|
||||
|
||||
log: logrus.WithField("pkg", "fakeapi-controller"),
|
||||
}
|
||||
}
|
||||
|
||||
func (cntrl *Controller) GetClient(userID string) *FakePMAPI {
|
||||
fakeAPI := New(cntrl)
|
||||
cntrl.fakeAPIs = append(cntrl.fakeAPIs, fakeAPI)
|
||||
return fakeAPI
|
||||
cm.SetClientConstructor(func(userID string) pmapi.Client {
|
||||
fakeAPI := New(cntrl)
|
||||
cntrl.fakeAPIs = append(cntrl.fakeAPIs, fakeAPI)
|
||||
return fakeAPI
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -42,3 +42,7 @@ func (api *FakePMAPI) SendSimpleMetric(category, action, label string) error {
|
||||
v.Set("Label", label)
|
||||
return api.checkInternetAndRecordCall(GET, "/metrics?"+v.Encode(), nil)
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) ReportSentryCrash(reportErr error) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -50,6 +50,10 @@ func (api *FakePMAPI) UpdateUser() (*pmapi.User, error) {
|
||||
return api.user, nil
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) GetAddresses() (pmapi.AddressList, error) {
|
||||
return *api.addresses, nil
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) Addresses() pmapi.AddressList {
|
||||
return *api.addresses
|
||||
}
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func cleanup(client *pmapi.Client) error {
|
||||
func cleanup(client pmapi.Client) error {
|
||||
if err := cleanSystemFolders(client); err != nil {
|
||||
return errors.Wrap(err, "failed to clean system folders")
|
||||
}
|
||||
@ -37,7 +37,7 @@ func cleanup(client *pmapi.Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanSystemFolders(client *pmapi.Client) error {
|
||||
func cleanSystemFolders(client pmapi.Client) error {
|
||||
for _, labelID := range []string{pmapi.InboxLabel, pmapi.SentLabel, pmapi.ArchiveLabel, pmapi.AllMailLabel, pmapi.DraftLabel} {
|
||||
for {
|
||||
messages, total, err := client.ListMessages(&pmapi.MessagesFilter{
|
||||
@ -69,7 +69,7 @@ func cleanSystemFolders(client *pmapi.Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanCustomLables(client *pmapi.Client) error {
|
||||
func cleanCustomLables(client pmapi.Client) error {
|
||||
labels, err := client.ListLabels()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to list labels")
|
||||
@ -87,7 +87,7 @@ func cleanCustomLables(client *pmapi.Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanTrash(client *pmapi.Client) error {
|
||||
func cleanTrash(client pmapi.Client) error {
|
||||
for {
|
||||
_, total, err := client.ListMessages(&pmapi.MessagesFilter{
|
||||
PageSize: 1,
|
||||
@ -110,7 +110,7 @@ func cleanTrash(client *pmapi.Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func emptyFolder(client *pmapi.Client, labelID string) error {
|
||||
func emptyFolder(client pmapi.Client, labelID string) error {
|
||||
err := client.EmptyFolder(labelID, "")
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@ -18,9 +18,7 @@
|
||||
package liveapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
@ -30,33 +28,27 @@ type Controller struct {
|
||||
// Internal states.
|
||||
lock *sync.RWMutex
|
||||
calls []*fakeCall
|
||||
pmapiByUsername map[string]*pmapi.Client
|
||||
messageIDsByUsername map[string][]string
|
||||
clientManager *pmapi.ClientManager
|
||||
|
||||
// State controlled by test.
|
||||
noInternetConnection bool
|
||||
}
|
||||
|
||||
func NewController() *Controller {
|
||||
return &Controller{
|
||||
func NewController(cm *pmapi.ClientManager) (cntrl *Controller) {
|
||||
cntrl = &Controller{
|
||||
lock: &sync.RWMutex{},
|
||||
calls: []*fakeCall{},
|
||||
pmapiByUsername: map[string]*pmapi.Client{},
|
||||
messageIDsByUsername: map[string][]string{},
|
||||
clientManager: cm,
|
||||
|
||||
noInternetConnection: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (cntrl *Controller) GetClient(userID string) *pmapi.Client {
|
||||
cfg := &pmapi.ClientConfig{
|
||||
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
|
||||
ClientID: "bridge-test",
|
||||
Transport: &fakeTransport{
|
||||
cntrl: cntrl,
|
||||
transport: http.DefaultTransport,
|
||||
},
|
||||
TokenManager: pmapi.NewTokenManager(),
|
||||
}
|
||||
return pmapi.NewClient(cfg, userID)
|
||||
cm.SetRoundTripper(&fakeTransport{
|
||||
cntrl: cntrl,
|
||||
transport: http.DefaultTransport,
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -36,11 +36,7 @@ var systemLabelNameToID = map[string]string{ //nolint[gochecknoglobals]
|
||||
}
|
||||
|
||||
func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error {
|
||||
client, ok := cntrl.pmapiByUsername[username]
|
||||
if !ok {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
}
|
||||
|
||||
client := cntrl.clientManager.GetClient(username)
|
||||
label.Exclusive = getLabelExclusive(label.Name)
|
||||
label.Name = getLabelNameWithoutPrefix(label.Name)
|
||||
label.Color = pmapi.LabelColors[0]
|
||||
@ -67,11 +63,7 @@ func (cntrl *Controller) getLabelID(username, labelName string) (string, error)
|
||||
return labelID, nil
|
||||
}
|
||||
|
||||
client, ok := cntrl.pmapiByUsername[username]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("user %s does not exist", username)
|
||||
}
|
||||
|
||||
client := cntrl.clientManager.GetClient(username)
|
||||
labels, err := client.ListLabels()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to list labels")
|
||||
|
||||
@ -31,10 +31,7 @@ import (
|
||||
)
|
||||
|
||||
func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message) error {
|
||||
client, ok := cntrl.pmapiByUsername[username]
|
||||
if !ok {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
}
|
||||
client := cntrl.clientManager.GetClient(username)
|
||||
|
||||
body, err := buildMessage(client, message)
|
||||
if err != nil {
|
||||
@ -64,7 +61,7 @@ func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildMessage(client *pmapi.Client, message *pmapi.Message) (*bytes.Buffer, error) {
|
||||
func buildMessage(client pmapi.Client, message *pmapi.Message) (*bytes.Buffer, error) {
|
||||
if err := encryptMessage(client, message); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to encrypt message")
|
||||
}
|
||||
@ -79,7 +76,7 @@ func buildMessage(client *pmapi.Client, message *pmapi.Message) (*bytes.Buffer,
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func encryptMessage(client *pmapi.Client, message *pmapi.Message) error {
|
||||
func encryptMessage(client pmapi.Client, message *pmapi.Message) error {
|
||||
addresses, err := client.GetAddresses()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get address")
|
||||
|
||||
@ -18,9 +18,7 @@
|
||||
package liveapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ProtonMail/bridge/pkg/pmapi"
|
||||
"github.com/cucumber/godog"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@ -30,11 +28,7 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
|
||||
return godog.ErrPending
|
||||
}
|
||||
|
||||
client := pmapi.NewClient(&pmapi.ClientConfig{
|
||||
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
|
||||
ClientID: "bridge-cntrl",
|
||||
TokenManager: pmapi.NewTokenManager(),
|
||||
}, user.ID)
|
||||
client := cntrl.clientManager.GetClient(user.ID)
|
||||
|
||||
authInfo, err := client.AuthInfo(user.Name)
|
||||
if err != nil {
|
||||
@ -60,6 +54,5 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
|
||||
return errors.Wrap(err, "failed to clean user")
|
||||
}
|
||||
|
||||
cntrl.pmapiByUsername[user.Name] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user