feat: make store use ClientManager

This commit is contained in:
James Houlahan
2020-04-07 09:55:28 +02:00
parent f269be4291
commit 042c340881
43 changed files with 414 additions and 264 deletions

View File

@ -48,7 +48,7 @@ type Bridge struct {
panicHandler PanicHandler
events listener.Listener
version string
clientManager *pmapi.ClientManager
clientManager ClientManager
credStorer CredentialsStorer
storeCache *store.Cache
@ -76,7 +76,7 @@ func New(
panicHandler PanicHandler,
eventListener listener.Listener,
version string,
clientManager *pmapi.ClientManager,
clientManager ClientManager,
credStorer CredentialsStorer,
) *Bridge {
log.Trace("Creating new bridge")
@ -185,7 +185,7 @@ func (b *Bridge) watchAPIAuths() {
user, ok := b.hasUser(auth.UserID)
if !ok {
logrus.Info("User is not added to bridge yet")
logrus.WithField("userID", auth.UserID).Info("User not available for auth update")
continue
}

View File

@ -39,7 +39,7 @@ func TestBridgeFinishLoginBadPassword(t *testing.T) {
// Set up mocks for FinishLogin.
err := errors.New("bad password")
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err)
m.pmapiClient.EXPECT().Logout().Return(nil)
m.pmapiClient.EXPECT().Logout()
checkBridgeFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err)
}

View File

@ -56,6 +56,7 @@ func TestNewBridgeWithDisconnectedUser(t *testing.T) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil).Times(2)
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
m.pmapiClient.EXPECT().Addresses().Return(nil)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient)
checkBridgeNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
@ -66,13 +67,14 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) {
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout().Return(nil)
m.pmapiClient.EXPECT().Logout()
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
@ -84,10 +86,14 @@ func TestNewBridgeWithConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil)
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil)
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
// Set up mocks for store initialisation for the authorized user.
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil)
@ -97,10 +103,6 @@ func TestNewBridgeWithConnectedUser(t *testing.T) {
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
checkBridgeNew(t, m, []*credentials.Credentials{testCredentials})
}

View File

@ -129,11 +129,11 @@ type mocks struct {
config *bridgemocks.MockConfiger
PanicHandler *bridgemocks.MockPanicHandler
prefProvider *bridgemocks.MockPreferenceProvider
clientManager *bridgemocks.MockClientManager
credentialsStore *bridgemocks.MockCredentialsStorer
eventListener *MockListener
pmapiClient *pmapimocks.MockClient
clientManager *pmapimocks.MockClientManager
pmapiClient *pmapimocks.MockClient
storeCache *store.Cache
}
@ -151,11 +151,11 @@ func initMocks(t *testing.T) mocks {
config: bridgemocks.NewMockConfiger(mockCtrl),
PanicHandler: bridgemocks.NewMockPanicHandler(mockCtrl),
prefProvider: bridgemocks.NewMockPreferenceProvider(mockCtrl),
clientManager: bridgemocks.NewMockClientManager(mockCtrl),
credentialsStore: bridgemocks.NewMockCredentialsStorer(mockCtrl),
eventListener: NewMockListener(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
clientManager: pmapimocks.NewMockClientManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
}
@ -214,7 +214,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge {
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient)
m.clientManager.EXPECT().GetBridgeAuthChannel().Return(make(chan *pmapi.ClientAuth))
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore)

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,CredentialsStorer)
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer)
// Package mocks is a generated GoMock package.
package mocks
@ -203,6 +203,95 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
}
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method
func (m *MockClientManager) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
}
// DisallowProxy mocks base method
func (m *MockClientManager) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
}
// GetAnonymousClient mocks base method
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAnonymousClient")
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetAnonymousClient indicates an expected call of GetAnonymousClient
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
}
// GetBridgeAuthChannel mocks base method
func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBridgeAuthChannel")
ret0, _ := ret[0].(chan pmapi.ClientAuth)
return ret0
}
// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel
func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel))
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockCredentialsStorer is a mock of CredentialsStorer interface
type MockCredentialsStorer struct {
ctrl *gomock.Controller

View File

@ -51,3 +51,11 @@ type CredentialsStorer interface {
Logout(userID string) error
Delete(userID string) error
}
type ClientManager interface {
GetClient(userID string) pmapi.Client
GetAnonymousClient() pmapi.Client
AllowProxy()
DisallowProxy()
GetBridgeAuthChannel() chan pmapi.ClientAuth
}

View File

@ -41,7 +41,7 @@ type User struct {
log *logrus.Entry
panicHandler PanicHandler
listener listener.Listener
clientManager *pmapi.ClientManager
clientManager ClientManager
credStorer CredentialsStorer
imapUpdatesChannel chan interface{}
@ -66,7 +66,7 @@ func newUser(
userID string,
eventListener listener.Listener,
credStorer CredentialsStorer,
clientManager *pmapi.ClientManager,
clientManager ClientManager,
storeCache *store.Cache,
storeDir string,
) (u *User, err error) {
@ -139,7 +139,7 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
}
u.store = nil
}
store, err := store.New(u.panicHandler, u, u.client(), u.listener, u.storePath, u.storeCache)
store, err := store.New(u.panicHandler, u, u.clientManager, u.listener, u.storePath, u.storeCache)
if err != nil {
return errors.Wrap(err, "failed to create store")
}

View File

@ -33,7 +33,7 @@ func TestNewUserNoCredentialsStore(t *testing.T) {
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
a.Error(t, err)
}
@ -153,10 +153,10 @@ func TestNewUser(t *testing.T) {
}
func checkNewUser(m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
defer cleanUpUserData(user)
_ = user.init(nil, m.pmapiClient)
_ = user.init(nil)
waitForEvents()
@ -164,10 +164,10 @@ func checkNewUser(m mocks) {
}
func checkNewUserDisconnected(m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
defer cleanUpUserData(user)
_ = user.init(nil, m.pmapiClient)
_ = user.init(nil)
waitForEvents()

View File

@ -43,10 +43,10 @@ func testNewUser(m mocks) *User {
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err)
err = user.init(nil, m.pmapiClient)
err = user.init(nil)
assert.NoError(m.t, err)
return user
@ -69,10 +69,10 @@ func testNewUserForLogout(m mocks) *User {
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err)
err = user.init(nil, m.pmapiClient)
err = user.init(nil)
assert.NoError(m.t, err)
return user

View File

@ -109,5 +109,9 @@ func (storeAddress *Address) AddressID() string {
// APIAddress returns the `pmapi.Address` struct.
func (storeAddress *Address) APIAddress() *pmapi.Address {
return storeAddress.store.api.Addresses().ByEmail(storeAddress.address)
return storeAddress.client().Addresses().ByEmail(storeAddress.address)
}
func (storeAddress *Address) client() pmapi.Client {
return storeAddress.store.client()
}

View File

@ -42,13 +42,12 @@ type eventLoop struct {
log *logrus.Entry
store *Store
apiClient PMAPIProvider
user BridgeUser
events listener.Listener
store *Store
user BridgeUser
events listener.Listener
}
func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser, events listener.Listener) *eventLoop {
func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop {
eventLog := log.WithField("userID", user.ID())
eventLog.Trace("Creating new event loop")
@ -60,10 +59,9 @@ func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser
log: eventLog,
store: store,
apiClient: api,
user: user,
events: events,
store: store,
user: user,
events: events,
}
}
@ -71,10 +69,14 @@ func (loop *eventLoop) IsRunning() bool {
return loop.isRunning
}
func (loop *eventLoop) client() pmapi.Client {
return loop.store.client()
}
func (loop *eventLoop) setFirstEventID() (err error) {
loop.log.Info("Setting first event ID")
event, err := loop.apiClient.GetEvent("")
event, err := loop.client().GetEvent("")
if err != nil {
loop.log.WithError(err).Error("Could not get latest event ID")
return
@ -240,7 +242,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
loop.pollCounter++
var event *pmapi.Event
if event, err = loop.apiClient.GetEvent(loop.currentEventID); err != nil {
if event, err = loop.client().GetEvent(loop.currentEventID); err != nil {
return false, errors.Wrap(err, "failed to get event")
}
@ -321,7 +323,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
log.Debug("Processing address change event")
// Get old addresses for comparisons before updating user.
oldList := loop.apiClient.Addresses()
oldList := loop.client().Addresses()
if err = loop.user.UpdateUser(); err != nil {
if logoutErr := loop.user.Logout(); logoutErr != nil {
@ -363,7 +365,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
}
}
if err = loop.store.createOrUpdateAddressInfo(loop.apiClient.Addresses()); err != nil {
if err = loop.store.createOrUpdateAddressInfo(loop.client().Addresses()); err != nil {
return errors.Wrap(err, "failed to update address IDs in store")
}
@ -430,7 +432,7 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
if msg, err = loop.apiClient.GetMessage(message.ID); err != nil {
if msg, err = loop.client().GetMessage(message.ID); err != nil {
if err == pmapi.ErrNoSuchAPIID {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil

View File

@ -39,21 +39,21 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
// Doesn't matter which IDs are used.
// This test is trying to see whether event loop will immediately process
// next event if there is `More` of them.
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
EventID: "event50",
More: 1,
}, nil),
m.api.EXPECT().GetEvent("event50").Return(&pmapi.Event{
m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{
EventID: "event70",
More: 0,
}, nil),
m.api.EXPECT().GetEvent("event70").Return(&pmapi.Event{
m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{
EventID: "event71",
More: 0,
}, nil),
)
m.newStoreNoEvents(true)
m.api.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
m.client.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
// Event loop runs in goroutine and will be stopped by deferred mock clearing.
go m.store.eventLoop.start()
@ -78,12 +78,12 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
newSubject := "new subject"
// First sync will add message with old subject to database.
m.api.EXPECT().GetMessage("msg1").Return(&pmapi.Message{
m.client.EXPECT().GetMessage("msg1").Return(&pmapi.Message{
ID: "msg1",
Subject: subject,
}, nil)
// Event will update the subject.
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
EventID: "event1",
Messages: []*pmapi.EventMessage{{
EventItem: pmapi.EventItem{

View File

@ -258,8 +258,8 @@ func (storeMailbox *Mailbox) pollNow() {
}
// api is a proxy for the store's `PMAPIProvider`.
func (storeMailbox *Mailbox) api() PMAPIProvider {
return storeMailbox.store.api
func (storeMailbox *Mailbox) client() pmapi.Client {
return storeMailbox.store.client()
}
// update is a proxy for the store's db's `Update`.

View File

@ -37,7 +37,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
// wrapping it.
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.api().GetMessage(apiID)
msg, err := storeMailbox.client().GetMessage(apiID)
if err != nil {
return nil, err
}
@ -62,7 +62,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
LabelIDs: labelIDs,
}
res, err := storeMailbox.api().Import([]*pmapi.ImportMsgReq{importReqs})
res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs})
if err == nil && len(res) > 0 {
msg.ID = res[0].MessageID
}
@ -79,7 +79,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Labeling messages")
defer storeMailbox.pollNow()
return storeMailbox.api().LabelMessages(apiIDs, storeMailbox.labelID)
return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID)
}
// UnlabelMessages removes the label by calling an API.
@ -92,7 +92,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Unlabeling messages")
defer storeMailbox.pollNow()
return storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID)
return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID)
}
// MarkMessagesRead marks the message read by calling an API.
@ -116,7 +116,7 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
ids = append(ids, apiID)
}
}
return storeMailbox.api().MarkMessagesRead(ids)
return storeMailbox.client().MarkMessagesRead(ids)
}
// MarkMessagesUnread marks the message unread by calling an API.
@ -128,7 +128,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unread")
defer storeMailbox.pollNow()
return storeMailbox.api().MarkMessagesUnread(apiIDs)
return storeMailbox.client().MarkMessagesUnread(apiIDs)
}
// MarkMessagesStarred adds the Starred label by calling an API.
@ -141,7 +141,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as starred")
defer storeMailbox.pollNow()
return storeMailbox.api().LabelMessages(apiIDs, pmapi.StarredLabel)
return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel)
}
// MarkMessagesUnstarred removes the Starred label by calling an API.
@ -154,7 +154,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unstarred")
defer storeMailbox.pollNow()
return storeMailbox.api().UnlabelMessages(apiIDs, pmapi.StarredLabel)
return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel)
}
// DeleteMessages deletes messages.
@ -197,21 +197,21 @@ func (storeMailbox *Mailbox) DeleteMessages(apiIDs []string) error {
}
}
if len(messageIDsToUnlabel) > 0 {
if err := storeMailbox.api().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
log.WithError(err).Warning("Cannot unlabel before deleting")
}
}
if len(messageIDsToDelete) > 0 {
if err := storeMailbox.api().DeleteMessages(messageIDsToDelete); err != nil {
if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil {
return err
}
}
case pmapi.DraftLabel:
if err := storeMailbox.api().DeleteMessages(apiIDs); err != nil {
if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil {
return err
}
default:
if err := storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
return err
}
}

View File

@ -28,7 +28,6 @@ import (
// a specific mailbox with helper functions to get IMAP UID, sequence
// numbers and similar.
type Message struct {
api PMAPIProvider
msg *pmapi.Message
store *Store
@ -37,7 +36,6 @@ type Message struct {
func newStoreMessage(storeMailbox *Mailbox, msg *pmapi.Message) *Message {
return &Message{
api: storeMailbox.store.api,
msg: msg,
store: storeMailbox.store,
storeMailbox: storeMailbox,

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser)
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser)
// Package mocks is a generated GoMock package.
package mocks
@ -7,6 +7,7 @@ package mocks
import (
reflect "reflect"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
)
@ -45,6 +46,43 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
}
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockBridgeUser is a mock of BridgeUser interface
type MockBridgeUser struct {
ctrl *gomock.Controller

View File

@ -89,10 +89,10 @@ var (
// Store is local user storage, which handles the synchronization between IMAP and PM API.
type Store struct {
panicHandler PanicHandler
eventLoop *eventLoop
user BridgeUser
api PMAPIProvider
panicHandler PanicHandler
eventLoop *eventLoop
user BridgeUser
clientManager ClientManager
log *logrus.Entry
@ -111,13 +111,13 @@ type Store struct {
func New(
panicHandler PanicHandler,
user BridgeUser,
api PMAPIProvider,
clientManager ClientManager,
events listener.Listener,
path string,
cache *Cache,
) (store *Store, err error) {
if user == nil || api == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, api, events, cache)
if user == nil || clientManager == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache)
}
l := log.WithField("user", user.ID())
@ -138,14 +138,14 @@ func New(
}
store = &Store{
panicHandler: panicHandler,
api: api,
user: user,
cache: cache,
filePath: path,
db: bdb,
lock: &sync.RWMutex{},
log: l,
panicHandler: panicHandler,
clientManager: clientManager,
user: user,
cache: cache,
filePath: path,
db: bdb,
lock: &sync.RWMutex{},
log: l,
}
if err = store.init(firstInit); err != nil {
@ -158,7 +158,7 @@ func New(
}
if user.IsConnected() {
store.eventLoop = newEventLoop(cache, store, api, user, events)
store.eventLoop = newEventLoop(cache, store, user, events)
go func() {
defer store.panicHandler.HandlePanic()
store.eventLoop.start()
@ -261,10 +261,14 @@ func (store *Store) init(firstInit bool) (err error) {
return err
}
func (store *Store) client() pmapi.Client {
return store.clientManager.GetClient(store.UserID())
}
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
// the API is unavailable for whatever reason it tries to fetch the labels locally.
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
if labels, err = store.api.ListLabels(); err != nil {
if labels, err = store.client().ListLabels(); err != nil {
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
store.log.WithError(err).Error("Cannot list local labels")

View File

@ -24,9 +24,9 @@ import (
"sync"
"testing"
bridgemocks "github.com/ProtonMail/proton-bridge/internal/bridge/mocks"
storeMocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
@ -43,12 +43,13 @@ const (
type mocksForStore struct {
tb testing.TB
ctrl *gomock.Controller
events *storeMocks.MockListener
api *bridgemocks.MockPMAPIProvider
user *storeMocks.MockBridgeUser
panicHandler *storeMocks.MockPanicHandler
store *Store
ctrl *gomock.Controller
events *storemocks.MockListener
user *storemocks.MockBridgeUser
client *pmapimocks.MockClient
clientManager *storemocks.MockClientManager
panicHandler *storemocks.MockPanicHandler
store *Store
tmpDir string
cache *Cache
@ -57,12 +58,13 @@ type mocksForStore struct {
func initMocks(tb testing.TB) (*mocksForStore, func()) {
ctrl := gomock.NewController(tb)
mocks := &mocksForStore{
tb: tb,
ctrl: ctrl,
events: storeMocks.NewMockListener(ctrl),
api: bridgemocks.NewMockPMAPIProvider(ctrl),
user: storeMocks.NewMockBridgeUser(ctrl),
panicHandler: storeMocks.NewMockPanicHandler(ctrl),
tb: tb,
ctrl: ctrl,
events: storemocks.NewMockListener(ctrl),
user: storemocks.NewMockBridgeUser(ctrl),
client: pmapimocks.NewMockClient(ctrl),
clientManager: storemocks.NewMockClientManager(ctrl),
panicHandler: storemocks.NewMockPanicHandler(ctrl),
}
// Called during clean-up.
@ -92,13 +94,15 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.api.EXPECT().Addresses().Return(pmapi.AddressList{
mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client)
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
})
mocks.api.EXPECT().ListLabels()
mocks.api.EXPECT().CountMessages("")
mocks.api.EXPECT().GetEvent(gomock.Any()).
mocks.client.EXPECT().ListLabels()
mocks.client.EXPECT().CountMessages("")
mocks.client.EXPECT().GetEvent(gomock.Any()).
Return(&pmapi.Event{
EventID: "latestEventID",
}, nil).AnyTimes()
@ -106,7 +110,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
// We want to wait until first sync has finished.
firstSyncWaiter := sync.WaitGroup{}
firstSyncWaiter.Add(1)
mocks.api.EXPECT().
mocks.client.EXPECT().
ListMessages(gomock.Any()).
DoAndReturn(func(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
firstSyncWaiter.Done()
@ -117,7 +121,7 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unpar
mocks.store, err = New(
mocks.panicHandler,
mocks.user,
mocks.api,
mocks.clientManager,
mocks.events,
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache,

View File

@ -17,42 +17,14 @@
package store
import (
"io"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
type PanicHandler interface {
HandlePanic()
}
// PMAPIProvider is subset of pmapi.Client for use by the Store.
type PMAPIProvider interface {
CurrentUser() (*pmapi.User, error)
Addresses() pmapi.AddressList
GetEvent(eventID string) (*pmapi.Event, error)
CountMessages(addressID string) ([]*pmapi.MessagesCount, error)
ListMessages(filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
GetMessage(apiID string) (*pmapi.Message, error)
Import([]*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error)
DeleteMessages(apiIDs []string) error
LabelMessages(apiIDs []string, labelID string) error
UnlabelMessages(apiIDs []string, labelID string) error
MarkMessagesRead(apiIDs []string) error
MarkMessagesUnread(apiIDs []string) error
CreateDraft(m *pmapi.Message, parent string, action int) (created *pmapi.Message, err error)
CreateAttachment(att *pmapi.Attachment, r io.Reader, sig io.Reader) (created *pmapi.Attachment, err error)
SendMessage(messageID string, req *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error)
ListLabels() ([]*pmapi.Label, error)
CreateLabel(label *pmapi.Label) (*pmapi.Label, error)
UpdateLabel(label *pmapi.Label) (*pmapi.Label, error)
DeleteLabel(labelID string) error
EmptyFolder(labelID string, addressID string) error
type ClientManager interface {
GetClient(userID string) pmapi.Client
}
// BridgeUser is subset of bridge.User for use by the Store.

View File

@ -24,7 +24,7 @@ func (store *Store) UserID() string {
// GetSpace returns used and total space in bytes.
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
apiUser, err := store.api.CurrentUser()
apiUser, err := store.client().CurrentUser()
if err != nil {
return 0, 0, err
}
@ -33,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
// GetMaxUpload returns max size of attachment in bytes.
func (store *Store) GetMaxUpload() (uint, error) {
apiUser, err := store.api.CurrentUser()
apiUser, err := store.client().CurrentUser()
if err != nil {
return 0, err
}

View File

@ -61,7 +61,7 @@ func (store *Store) GetAddressInfo() (addrs []AddressInfo, err error) {
}
// Store does not have address info yet, need to build it first from API.
addressList := store.api.Addresses()
addressList := store.client().Addresses()
if addressList == nil {
err = errors.New("addresses unavailable")
store.log.WithError(err).Error("Could not get user addresses from API")

View File

@ -55,7 +55,7 @@ func (store *Store) createMailbox(name string) error {
return nil
}
_, err := store.api.CreateLabel(&pmapi.Label{
_, err := store.client().CreateLabel(&pmapi.Label{
Name: name,
Color: color,
Exclusive: exclusive,
@ -133,7 +133,7 @@ func (store *Store) leastUsedColor() string {
func (store *Store) updateMailbox(labelID, newName, color string) error {
defer store.eventLoop.pollNow()
_, err := store.api.UpdateLabel(&pmapi.Label{
_, err := store.client().UpdateLabel(&pmapi.Label{
ID: labelID,
Name: newName,
Color: color,
@ -150,15 +150,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
var err error
switch labelID {
case pmapi.SpamLabel:
err = store.api.EmptyFolder(pmapi.SpamLabel, addressID)
err = store.client().EmptyFolder(pmapi.SpamLabel, addressID)
case pmapi.TrashLabel:
err = store.api.EmptyFolder(pmapi.TrashLabel, addressID)
err = store.client().EmptyFolder(pmapi.TrashLabel, addressID)
default:
err = fmt.Errorf("cannot empty mailbox %v", labelID)
}
return err
}
return store.api.DeleteLabel(labelID)
return store.client().DeleteLabel(labelID)
}
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
@ -173,7 +173,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
return nil
}
labels, err := store.api.ListLabels()
labels, err := store.client().ListLabels()
if err != nil {
return err
}

View File

@ -54,7 +54,7 @@ func (store *Store) CreateDraft(
message.Attachments = nil
draftAction := store.getDraftAction(message)
draft, err := store.api.CreateDraft(message, parentID, draftAction)
draft, err := store.client().CreateDraft(message, parentID, draftAction)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create draft")
}
@ -105,7 +105,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att
return nil, errors.Wrap(err, "failed to encrypt attachment")
}
createdAttachment, err := store.api.CreateAttachment(attachment, encReader, sigReader)
createdAttachment, err := store.client().CreateAttachment(attachment, encReader, sigReader)
if err != nil {
return nil, errors.Wrap(err, "failed to create attachment")
}
@ -116,7 +116,7 @@ func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Att
// SendMessage sends the message.
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
defer store.eventLoop.pollNow()
_, _, err := store.api.SendMessage(messageID, req)
_, _, err := store.client().SendMessage(messageID, req)
return err
}

View File

@ -34,7 +34,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
// updateCountsFromServer will download and set the counts.
func (store *Store) updateCountsFromServer() error {
counts, err := store.api.CountMessages("")
counts, err := store.client().CountMessages("")
if err != nil {
return errors.Wrap(err, "cannot update counts from server")
}
@ -144,7 +144,8 @@ func (store *Store) triggerSync() {
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
err := syncAllMail(store.panicHandler, store, store.api, syncState)
// TODO: Is it okay to pass in a client directly? What if it is logged out in the meantime?
err := syncAllMail(store.panicHandler, store, store.client(), syncState)
if err != nil {
log.WithError(err).Error("Store sync failed")
return