mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-11 05:06:51 +00:00
Fixing unit tests for client manager.
* [x] pmapi: refresh auth uid won't change
* [x] bridge tests:
* update mocks
* delete auth when FinishLogin fails
* check for mailbox password
* add `gomock.InOrder` for better test control
* [x] fix linter issues except TODOs
* [x] make rootScheme unexported
* [x] store tests: update mocks
This commit is contained in:
@ -424,8 +424,7 @@ func (c *client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
|
||||
func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
||||
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
||||
// That way we can try again later (see handleUnauthorizedStatus).
|
||||
// TODO:
|
||||
// c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||
c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||
|
||||
split := strings.Split(uidAndRefreshToken, ":")
|
||||
if len(split) != 2 {
|
||||
|
||||
@ -33,8 +33,6 @@ import (
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var aLongTimeAgo = time.Unix(233431200, 0)
|
||||
|
||||
var testIdentity = &pmcrypto.Identity{
|
||||
Name: "UserID",
|
||||
Email: "",
|
||||
@ -276,10 +274,11 @@ func TestClient_AuthRefresh(t *testing.T) {
|
||||
|
||||
auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken)
|
||||
Ok(t, err)
|
||||
Equals(t, testUID, c.uid)
|
||||
|
||||
exp := &Auth{}
|
||||
*exp = *testAuth
|
||||
exp.uid = "" // AuthRefresh will not return UID (only Auth returns the UID).
|
||||
exp.uid = testUID // AuthRefresh will not return UID (only Auth returns the UID) we should set testUID to be able to generate token, see `GetToken`
|
||||
exp.accessToken = testAccessToken
|
||||
exp.KeySalt = ""
|
||||
exp.EventID = ""
|
||||
|
||||
@ -3,6 +3,7 @@ package pmapi
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -10,8 +11,6 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var defaultProxyUseDuration = 24 * time.Hour
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
|
||||
@ -21,9 +20,6 @@ type ClientManager struct {
|
||||
config *ClientConfig
|
||||
roundTripper http.RoundTripper
|
||||
|
||||
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
|
||||
// But that screws up other things like not being able to clear sensitive info during logout
|
||||
// unless the client interface contains a method for that.
|
||||
clients map[string]Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
@ -33,17 +29,19 @@ type ClientManager struct {
|
||||
expirations map[string]*tokenExpiration
|
||||
expirationsLocker sync.Locker
|
||||
|
||||
host, scheme string
|
||||
hostLocker sync.Locker
|
||||
|
||||
bridgeAuths chan ClientAuth
|
||||
clientAuths chan ClientAuth
|
||||
|
||||
host, scheme string
|
||||
hostLocker sync.RWMutex
|
||||
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
proxyUseDuration time.Duration
|
||||
|
||||
idGen idGen
|
||||
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
type idGen int
|
||||
@ -81,14 +79,16 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
expirationsLocker: &sync.Mutex{},
|
||||
|
||||
host: RootURL,
|
||||
scheme: RootScheme,
|
||||
hostLocker: &sync.Mutex{},
|
||||
scheme: rootScheme,
|
||||
hostLocker: sync.RWMutex{},
|
||||
|
||||
bridgeAuths: make(chan ClientAuth),
|
||||
clientAuths: make(chan ClientAuth),
|
||||
|
||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||
proxyUseDuration: defaultProxyUseDuration,
|
||||
proxyUseDuration: proxyUseDuration,
|
||||
|
||||
log: logrus.WithField("pkg", "pmapi-manager"),
|
||||
}
|
||||
|
||||
cm.newClient = func(userID string) Client {
|
||||
@ -97,7 +97,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
|
||||
return
|
||||
return cm
|
||||
}
|
||||
|
||||
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
|
||||
@ -140,20 +140,20 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
||||
delete(cm.clients, userID)
|
||||
|
||||
go func() {
|
||||
if err := client.DeleteAuth(); err != nil {
|
||||
// TODO: Retry if the request failed.
|
||||
if !strings.HasPrefix(userID, "anonymous-") {
|
||||
if err := client.DeleteAuth(); err != nil {
|
||||
// TODO: Retry if the request failed.
|
||||
}
|
||||
}
|
||||
client.ClearData()
|
||||
cm.clearToken(userID)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetRootURL returns the full root URL (scheme+host).
|
||||
func (cm *ClientManager) GetRootURL() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
|
||||
}
|
||||
@ -161,24 +161,16 @@ func (cm *ClientManager) GetRootURL() string {
|
||||
// getHost returns the host to make requests to.
|
||||
// It does not include the protocol i.e. no "https://" (use getScheme for that).
|
||||
func (cm *ClientManager) getHost() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.host
|
||||
}
|
||||
|
||||
// getScheme returns the scheme with which to make requests to the host.
|
||||
func (cm *ClientManager) getScheme() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
return cm.scheme
|
||||
}
|
||||
|
||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.allowProxy
|
||||
}
|
||||
@ -202,8 +194,8 @@ func (cm *ClientManager) DisallowProxy() {
|
||||
|
||||
// IsProxyEnabled returns whether we are currently proxying requests.
|
||||
func (cm *ClientManager) IsProxyEnabled() bool {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.host != RootURL
|
||||
}
|
||||
@ -264,6 +256,21 @@ func (cm *ClientManager) forwardClientAuths() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetTokenIfUnset sets the token for the given userID if it wasn't already set.
|
||||
// The token does not expire.
|
||||
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
if _, ok := cm.tokens[userID]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.WithField("userID", userID).Info("Setting token because it is currently unset")
|
||||
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
||||
// setToken sets the token for the given userID with the given expiration time.
|
||||
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||
cm.tokensLocker.Lock()
|
||||
@ -275,6 +282,7 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
|
||||
|
||||
cm.setTokenExpiration(userID, expiration)
|
||||
|
||||
// TODO: This should be one go routine per all tokens.
|
||||
go cm.watchTokenExpiration(userID)
|
||||
}
|
||||
|
||||
@ -311,7 +319,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
|
||||
// If we aren't managing this client, there's nothing to do.
|
||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
|
||||
logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client")
|
||||
return
|
||||
}
|
||||
|
||||
@ -332,8 +340,12 @@ func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||
|
||||
select {
|
||||
case <-expiration.timer.C:
|
||||
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
||||
cm.clients[userID].AuthRefresh(cm.tokens[userID])
|
||||
cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
||||
if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil {
|
||||
cm.log.WithField("userID", userID).
|
||||
WithError(err).
|
||||
Error("Token refresh failed before expiration")
|
||||
}
|
||||
|
||||
case <-expiration.cancel:
|
||||
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
|
||||
|
||||
@ -27,10 +27,9 @@ import (
|
||||
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
|
||||
// Default is pmapi_prod.
|
||||
//
|
||||
// It should not contain the protocol! The protocol should be in RootScheme.
|
||||
// It must not contain the protocol! The protocol should be in rootScheme.
|
||||
var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
|
||||
|
||||
var RootScheme = "https"
|
||||
var rootScheme = "https" //nolint[gochecknoglobals]
|
||||
|
||||
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
|
||||
// version and email client.
|
||||
|
||||
@ -21,5 +21,5 @@ package pmapi
|
||||
|
||||
func init() {
|
||||
RootURL = "dev.protonmail.com/api"
|
||||
RootScheme = "https"
|
||||
rootScheme = "https"
|
||||
}
|
||||
|
||||
@ -28,7 +28,7 @@ func init() {
|
||||
// Use port above 1000 which doesn't need root access to start anything on it.
|
||||
// Now the port is rounded pi. :-)
|
||||
RootURL = "127.0.0.1:3142/api"
|
||||
RootScheme = "http"
|
||||
rootScheme = "http"
|
||||
|
||||
// TLS certificate is self-signed
|
||||
defaultTransport = &http.Transport{
|
||||
|
||||
@ -41,7 +41,7 @@ func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) {
|
||||
func TestTLSPinValid(t *testing.T) {
|
||||
cm := NewClientManager(testLiveConfig)
|
||||
cm.host = liveAPI
|
||||
RootScheme = "https"
|
||||
rootScheme = "https"
|
||||
called, _ := setTestDialerWithPinning(cm)
|
||||
client := cm.GetClient("pmapi" + t.Name())
|
||||
|
||||
|
||||
@ -5,12 +5,11 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
|
||||
crypto "github.com/ProtonMail/gopenpgp/crypto"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface
|
||||
@ -110,6 +109,18 @@ func (mr *MockClientMockRecorder) AuthRefresh(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRefresh", reflect.TypeOf((*MockClient)(nil).AuthRefresh), arg0)
|
||||
}
|
||||
|
||||
// ClearData mocks base method
|
||||
func (m *MockClient) ClearData() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ClearData")
|
||||
}
|
||||
|
||||
// ClearData indicates an expected call of ClearData
|
||||
func (mr *MockClientMockRecorder) ClearData() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearData", reflect.TypeOf((*MockClient)(nil).ClearData))
|
||||
}
|
||||
|
||||
// CountMessages mocks base method
|
||||
func (m *MockClient) CountMessages(arg0 string) ([]*pmapi.MessagesCount, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@ -214,6 +225,20 @@ func (mr *MockClientMockRecorder) DeleteAttachment(arg0 interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAttachment", reflect.TypeOf((*MockClient)(nil).DeleteAttachment), arg0)
|
||||
}
|
||||
|
||||
// DeleteAuth mocks base method
|
||||
func (m *MockClient) DeleteAuth() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAuth")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAuth indicates an expected call of DeleteAuth
|
||||
func (mr *MockClientMockRecorder) DeleteAuth() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuth", reflect.TypeOf((*MockClient)(nil).DeleteAuth))
|
||||
}
|
||||
|
||||
// DeleteLabel mocks base method
|
||||
func (m *MockClient) DeleteLabel(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -30,7 +30,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
proxyRevertTime = 24 * time.Hour
|
||||
proxyUseDuration = 24 * time.Hour
|
||||
proxySearchTimeout = 30 * time.Second
|
||||
proxyQueryTimeout = 10 * time.Second
|
||||
proxyLookupWait = 5 * time.Second
|
||||
|
||||
@ -27,7 +27,6 @@ import (
|
||||
|
||||
// NewRequest creates a new request.
|
||||
func (c *client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
|
||||
// TODO: Support other protocols (localhost needs http not https).
|
||||
req, err = http.NewRequest(method, c.cm.GetRootURL()+path, body)
|
||||
|
||||
if req != nil {
|
||||
|
||||
Reference in New Issue
Block a user