feat: make store use ClientManager

This commit is contained in:
James Houlahan
2020-04-07 09:55:28 +02:00
parent f269be4291
commit 042c340881
43 changed files with 414 additions and 264 deletions

View File

@ -464,15 +464,13 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
return auth, err
}
// Logout instructs the client manager to log out this client.
// TODO: Should this even be a client method? Or just a method on the client manager?
func (c *client) Logout() {
c.cm.LogoutClient(c.userID)
}
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
// logout logs the current user out.
func (c *client) logout() (err error) {
// DeleteAuth deletes the API session.
func (c *client) DeleteAuth() (err error) {
req, err := c.NewRequest("DELETE", "/auth", nil)
if err != nil {
return
@ -490,7 +488,10 @@ func (c *client) logout() (err error) {
return
}
func (c *client) clearSensitiveData() {
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
// ClearData clears sensitive data from the client.
func (c *client) ClearData() {
c.uid = ""
c.accessToken = ""
c.kr = nil

View File

@ -98,20 +98,28 @@ type Client interface {
Auth(username, password string, info *AuthInfo) (*Auth, error)
AuthInfo(username string) (*AuthInfo, error)
AuthRefresh(token string) (*Auth, error)
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error)
UnlockAddresses(passphrase []byte) error
Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
Logout()
DeleteAuth() error
ClearData()
CurrentUser() (*User, error)
UpdateUser() (*User, error)
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error)
UnlockAddresses(passphrase []byte) error
GetAddresses() (addresses AddressList, err error)
Addresses() AddressList
Logout()
GetEvent(eventID string) (*Event, error)
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
CountMessages(addressID string) ([]*MessagesCount, error)
ListMessages(filter *MessagesFilter) ([]*Message, int, error)
GetMessage(apiID string) (*Message, error)
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
DeleteMessages(apiIDs []string) error
LabelMessages(apiIDs []string, labelID string) error
UnlabelMessages(apiIDs []string, labelID string) error
@ -128,20 +136,17 @@ type Client interface {
SendSimpleMetric(category, action, label string) error
ReportSentryCrash(reportErr error) (err error)
Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
GetMailSettings() (MailSettings, error)
GetContactEmailByEmail(string, int, int) ([]ContactEmail, error)
GetContactByID(string) (Contact, error)
DecryptAndVerifyCards([]Card) ([]Card, error)
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
DeleteAttachment(attID string) (err error)
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
GetAttachment(id string) (att io.ReadCloser, err error)
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
DeleteAttachment(attID string) (err error)
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
}
// client is a client of the protonmail API. It implements the Client interface.

View File

@ -15,10 +15,17 @@ var defaultProxyUseDuration = 24 * time.Hour
// ClientManager is a manager of clients.
type ClientManager struct {
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
// create other types of clients (e.g. for integration tests).
newClient func(userID string) Client
config *ClientConfig
roundTripper http.RoundTripper
clients map[string]*client
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
// But that screws up other things like not being able to clear sensitive info during logout
// unless the client interface contains a method for that.
clients map[string]Client
clientsLocker sync.Locker
tokens map[string]string
@ -38,11 +45,13 @@ type ClientManager struct {
proxyUseDuration time.Duration
}
// ClientAuth holds an API auth produced by a Client for a specific user.
type ClientAuth struct {
UserID string
Auth *Auth
}
// tokenExpiration manages the expiration of an access token.
type tokenExpiration struct {
timer *time.Timer
cancel chan (struct{})
@ -58,7 +67,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
config: config,
roundTripper: http.DefaultTransport,
clients: make(map[string]*client),
clients: make(map[string]Client),
clientsLocker: &sync.Mutex{},
tokens: make(map[string]string),
@ -78,11 +87,19 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
proxyUseDuration: defaultProxyUseDuration,
}
cm.newClient = func(userID string) Client {
return newClient(cm, userID)
}
go cm.forwardClientAuths()
return
}
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
cm.newClient = f
}
// SetRoundTripper sets the roundtripper used by clients created by this client manager.
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
cm.roundTripper = rt
@ -100,7 +117,7 @@ func (cm *ClientManager) GetClient(userID string) Client {
return client
}
cm.clients[userID] = newClient(cm, userID)
cm.clients[userID] = cm.newClient(userID)
return cm.clients[userID]
}
@ -108,10 +125,10 @@ func (cm *ClientManager) GetClient(userID string) Client {
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
func (cm *ClientManager) GetAnonymousClient() Client {
if client, ok := cm.clients[""]; ok {
client.Logout()
client.DeleteAuth()
}
cm.clients[""] = newClient(cm, "")
cm.clients[""] = cm.newClient("")
return cm.clients[""]
}
@ -127,10 +144,10 @@ func (cm *ClientManager) LogoutClient(userID string) {
delete(cm.clients, userID)
go func() {
if err := client.logout(); err != nil {
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
if err := client.DeleteAuth(); err != nil {
// TODO: Retry if the request failed.
}
client.clearSensitiveData()
client.ClearData()
cm.clearToken(userID)
}()

View File

@ -256,6 +256,21 @@ func (mr *MockClientMockRecorder) EmptyFolder(arg0, arg1 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyFolder", reflect.TypeOf((*MockClient)(nil).EmptyFolder), arg0, arg1)
}
// GetAddresses mocks base method
func (m *MockClient) GetAddresses() (pmapi.AddressList, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAddresses")
ret0, _ := ret[0].(pmapi.AddressList)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAddresses indicates an expected call of GetAddresses
func (mr *MockClientMockRecorder) GetAddresses() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockClient)(nil).GetAddresses))
}
// GetAttachment mocks base method
func (m *MockClient) GetAttachment(arg0 string) (io.ReadCloser, error) {
m.ctrl.T.Helper()