mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
feat: make store use ClientManager
This commit is contained in:
@ -464,15 +464,13 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
||||
return auth, err
|
||||
}
|
||||
|
||||
// Logout instructs the client manager to log out this client.
|
||||
// TODO: Should this even be a client method? Or just a method on the client manager?
|
||||
func (c *client) Logout() {
|
||||
c.cm.LogoutClient(c.userID)
|
||||
}
|
||||
|
||||
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
|
||||
|
||||
// logout logs the current user out.
|
||||
func (c *client) logout() (err error) {
|
||||
// DeleteAuth deletes the API session.
|
||||
func (c *client) DeleteAuth() (err error) {
|
||||
req, err := c.NewRequest("DELETE", "/auth", nil)
|
||||
if err != nil {
|
||||
return
|
||||
@ -490,7 +488,10 @@ func (c *client) logout() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) clearSensitiveData() {
|
||||
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
|
||||
|
||||
// ClearData clears sensitive data from the client.
|
||||
func (c *client) ClearData() {
|
||||
c.uid = ""
|
||||
c.accessToken = ""
|
||||
c.kr = nil
|
||||
|
||||
@ -98,20 +98,28 @@ type Client interface {
|
||||
Auth(username, password string, info *AuthInfo) (*Auth, error)
|
||||
AuthInfo(username string) (*AuthInfo, error)
|
||||
AuthRefresh(token string) (*Auth, error)
|
||||
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error)
|
||||
UnlockAddresses(passphrase []byte) error
|
||||
Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
|
||||
Logout()
|
||||
DeleteAuth() error
|
||||
ClearData()
|
||||
|
||||
CurrentUser() (*User, error)
|
||||
UpdateUser() (*User, error)
|
||||
Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error)
|
||||
UnlockAddresses(passphrase []byte) error
|
||||
|
||||
GetAddresses() (addresses AddressList, err error)
|
||||
Addresses() AddressList
|
||||
|
||||
Logout()
|
||||
|
||||
GetEvent(eventID string) (*Event, error)
|
||||
|
||||
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
|
||||
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
|
||||
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
|
||||
|
||||
CountMessages(addressID string) ([]*MessagesCount, error)
|
||||
ListMessages(filter *MessagesFilter) ([]*Message, int, error)
|
||||
GetMessage(apiID string) (*Message, error)
|
||||
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
|
||||
DeleteMessages(apiIDs []string) error
|
||||
LabelMessages(apiIDs []string, labelID string) error
|
||||
UnlabelMessages(apiIDs []string, labelID string) error
|
||||
@ -128,20 +136,17 @@ type Client interface {
|
||||
SendSimpleMetric(category, action, label string) error
|
||||
ReportSentryCrash(reportErr error) (err error)
|
||||
|
||||
Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error)
|
||||
|
||||
GetMailSettings() (MailSettings, error)
|
||||
GetContactEmailByEmail(string, int, int) ([]ContactEmail, error)
|
||||
GetContactByID(string) (Contact, error)
|
||||
DecryptAndVerifyCards([]Card) ([]Card, error)
|
||||
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
|
||||
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
|
||||
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
|
||||
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
|
||||
DeleteAttachment(attID string) (err error)
|
||||
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
|
||||
|
||||
GetAttachment(id string) (att io.ReadCloser, err error)
|
||||
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
|
||||
DeleteAttachment(attID string) (err error)
|
||||
|
||||
KeyRingForAddressID(string) (kr *pmcrypto.KeyRing)
|
||||
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
|
||||
}
|
||||
|
||||
// client is a client of the protonmail API. It implements the Client interface.
|
||||
|
||||
@ -15,10 +15,17 @@ var defaultProxyUseDuration = 24 * time.Hour
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
|
||||
// create other types of clients (e.g. for integration tests).
|
||||
newClient func(userID string) Client
|
||||
|
||||
config *ClientConfig
|
||||
roundTripper http.RoundTripper
|
||||
|
||||
clients map[string]*client
|
||||
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
|
||||
// But that screws up other things like not being able to clear sensitive info during logout
|
||||
// unless the client interface contains a method for that.
|
||||
clients map[string]Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
tokens map[string]string
|
||||
@ -38,11 +45,13 @@ type ClientManager struct {
|
||||
proxyUseDuration time.Duration
|
||||
}
|
||||
|
||||
// ClientAuth holds an API auth produced by a Client for a specific user.
|
||||
type ClientAuth struct {
|
||||
UserID string
|
||||
Auth *Auth
|
||||
}
|
||||
|
||||
// tokenExpiration manages the expiration of an access token.
|
||||
type tokenExpiration struct {
|
||||
timer *time.Timer
|
||||
cancel chan (struct{})
|
||||
@ -58,7 +67,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
config: config,
|
||||
roundTripper: http.DefaultTransport,
|
||||
|
||||
clients: make(map[string]*client),
|
||||
clients: make(map[string]Client),
|
||||
clientsLocker: &sync.Mutex{},
|
||||
|
||||
tokens: make(map[string]string),
|
||||
@ -78,11 +87,19 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
proxyUseDuration: defaultProxyUseDuration,
|
||||
}
|
||||
|
||||
cm.newClient = func(userID string) Client {
|
||||
return newClient(cm, userID)
|
||||
}
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
|
||||
cm.newClient = f
|
||||
}
|
||||
|
||||
// SetRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
|
||||
cm.roundTripper = rt
|
||||
@ -100,7 +117,7 @@ func (cm *ClientManager) GetClient(userID string) Client {
|
||||
return client
|
||||
}
|
||||
|
||||
cm.clients[userID] = newClient(cm, userID)
|
||||
cm.clients[userID] = cm.newClient(userID)
|
||||
|
||||
return cm.clients[userID]
|
||||
}
|
||||
@ -108,10 +125,10 @@ func (cm *ClientManager) GetClient(userID string) Client {
|
||||
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
|
||||
func (cm *ClientManager) GetAnonymousClient() Client {
|
||||
if client, ok := cm.clients[""]; ok {
|
||||
client.Logout()
|
||||
client.DeleteAuth()
|
||||
}
|
||||
|
||||
cm.clients[""] = newClient(cm, "")
|
||||
cm.clients[""] = cm.newClient("")
|
||||
|
||||
return cm.clients[""]
|
||||
}
|
||||
@ -127,10 +144,10 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
||||
delete(cm.clients, userID)
|
||||
|
||||
go func() {
|
||||
if err := client.logout(); err != nil {
|
||||
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
|
||||
if err := client.DeleteAuth(); err != nil {
|
||||
// TODO: Retry if the request failed.
|
||||
}
|
||||
client.clearSensitiveData()
|
||||
client.ClearData()
|
||||
cm.clearToken(userID)
|
||||
}()
|
||||
|
||||
|
||||
@ -256,6 +256,21 @@ func (mr *MockClientMockRecorder) EmptyFolder(arg0, arg1 interface{}) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyFolder", reflect.TypeOf((*MockClient)(nil).EmptyFolder), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetAddresses mocks base method
|
||||
func (m *MockClient) GetAddresses() (pmapi.AddressList, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAddresses")
|
||||
ret0, _ := ret[0].(pmapi.AddressList)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAddresses indicates an expected call of GetAddresses
|
||||
func (mr *MockClientMockRecorder) GetAddresses() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockClient)(nil).GetAddresses))
|
||||
}
|
||||
|
||||
// GetAttachment mocks base method
|
||||
func (m *MockClient) GetAttachment(arg0 string) (io.ReadCloser, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Reference in New Issue
Block a user