mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-17 15:46:44 +00:00
Fixing unit tests for client manager.
* [x] pmapi: refresh auth uid won't change
* [x] bridge tests:
* update mocks
* delete auth when FinishLogin fails
* check for mailbox password
* add `gomock.InOrder` for better test control
* [x] fix linter issues except TODOs
* [x] make rootScheme unexported
* [x] store tests: update mocks
This commit is contained in:
@ -3,6 +3,7 @@ package pmapi
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -10,8 +11,6 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var defaultProxyUseDuration = 24 * time.Hour
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
|
||||
@ -21,9 +20,6 @@ type ClientManager struct {
|
||||
config *ClientConfig
|
||||
roundTripper http.RoundTripper
|
||||
|
||||
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
|
||||
// But that screws up other things like not being able to clear sensitive info during logout
|
||||
// unless the client interface contains a method for that.
|
||||
clients map[string]Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
@ -33,17 +29,19 @@ type ClientManager struct {
|
||||
expirations map[string]*tokenExpiration
|
||||
expirationsLocker sync.Locker
|
||||
|
||||
host, scheme string
|
||||
hostLocker sync.Locker
|
||||
|
||||
bridgeAuths chan ClientAuth
|
||||
clientAuths chan ClientAuth
|
||||
|
||||
host, scheme string
|
||||
hostLocker sync.RWMutex
|
||||
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
proxyUseDuration time.Duration
|
||||
|
||||
idGen idGen
|
||||
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
type idGen int
|
||||
@ -81,14 +79,16 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
expirationsLocker: &sync.Mutex{},
|
||||
|
||||
host: RootURL,
|
||||
scheme: RootScheme,
|
||||
hostLocker: &sync.Mutex{},
|
||||
scheme: rootScheme,
|
||||
hostLocker: sync.RWMutex{},
|
||||
|
||||
bridgeAuths: make(chan ClientAuth),
|
||||
clientAuths: make(chan ClientAuth),
|
||||
|
||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||
proxyUseDuration: defaultProxyUseDuration,
|
||||
proxyUseDuration: proxyUseDuration,
|
||||
|
||||
log: logrus.WithField("pkg", "pmapi-manager"),
|
||||
}
|
||||
|
||||
cm.newClient = func(userID string) Client {
|
||||
@ -97,7 +97,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
|
||||
return
|
||||
return cm
|
||||
}
|
||||
|
||||
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
|
||||
@ -140,20 +140,20 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
||||
delete(cm.clients, userID)
|
||||
|
||||
go func() {
|
||||
if err := client.DeleteAuth(); err != nil {
|
||||
// TODO: Retry if the request failed.
|
||||
if !strings.HasPrefix(userID, "anonymous-") {
|
||||
if err := client.DeleteAuth(); err != nil {
|
||||
// TODO: Retry if the request failed.
|
||||
}
|
||||
}
|
||||
client.ClearData()
|
||||
cm.clearToken(userID)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetRootURL returns the full root URL (scheme+host).
|
||||
func (cm *ClientManager) GetRootURL() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
|
||||
}
|
||||
@ -161,24 +161,16 @@ func (cm *ClientManager) GetRootURL() string {
|
||||
// getHost returns the host to make requests to.
|
||||
// It does not include the protocol i.e. no "https://" (use getScheme for that).
|
||||
func (cm *ClientManager) getHost() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.host
|
||||
}
|
||||
|
||||
// getScheme returns the scheme with which to make requests to the host.
|
||||
func (cm *ClientManager) getScheme() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
return cm.scheme
|
||||
}
|
||||
|
||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.allowProxy
|
||||
}
|
||||
@ -202,8 +194,8 @@ func (cm *ClientManager) DisallowProxy() {
|
||||
|
||||
// IsProxyEnabled returns whether we are currently proxying requests.
|
||||
func (cm *ClientManager) IsProxyEnabled() bool {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.host != RootURL
|
||||
}
|
||||
@ -264,6 +256,21 @@ func (cm *ClientManager) forwardClientAuths() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetTokenIfUnset sets the token for the given userID if it wasn't already set.
|
||||
// The token does not expire.
|
||||
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
if _, ok := cm.tokens[userID]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.WithField("userID", userID).Info("Setting token because it is currently unset")
|
||||
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
||||
// setToken sets the token for the given userID with the given expiration time.
|
||||
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||
cm.tokensLocker.Lock()
|
||||
@ -275,6 +282,7 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
|
||||
|
||||
cm.setTokenExpiration(userID, expiration)
|
||||
|
||||
// TODO: This should be one go routine per all tokens.
|
||||
go cm.watchTokenExpiration(userID)
|
||||
}
|
||||
|
||||
@ -311,7 +319,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
|
||||
// If we aren't managing this client, there's nothing to do.
|
||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
|
||||
logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client")
|
||||
return
|
||||
}
|
||||
|
||||
@ -332,8 +340,12 @@ func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||
|
||||
select {
|
||||
case <-expiration.timer.C:
|
||||
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
||||
cm.clients[userID].AuthRefresh(cm.tokens[userID])
|
||||
cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
||||
if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil {
|
||||
cm.log.WithField("userID", userID).
|
||||
WithError(err).
|
||||
Error("Token refresh failed before expiration")
|
||||
}
|
||||
|
||||
case <-expiration.cancel:
|
||||
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
|
||||
|
||||
Reference in New Issue
Block a user