Fixing unit tests for client manager.

* [x] pmapi: refresh auth uid won't change
* [x] bridge tests:
    * update mocks
    * delete auth when FinishLogin fails
    * check for mailbox password
    * add `gomock.InOrder` for better test control
* [x] fix linter issues except TODOs
* [x] make rootScheme unexported
* [x] store tests: update mocks
This commit is contained in:
Jakub
2020-04-14 07:54:11 +02:00
committed by James Houlahan
parent debd374d75
commit 80f4e1e346
25 changed files with 537 additions and 364 deletions

View File

@ -424,8 +424,7 @@ func (c *client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
// If we don't yet have a saved access token, save this one in case the refresh fails!
// That way we can try again later (see handleUnauthorizedStatus).
// TODO:
// c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
split := strings.Split(uidAndRefreshToken, ":")
if len(split) != 2 {

View File

@ -33,8 +33,6 @@ import (
r "github.com/stretchr/testify/require"
)
var aLongTimeAgo = time.Unix(233431200, 0)
var testIdentity = &pmcrypto.Identity{
Name: "UserID",
Email: "",
@ -276,10 +274,11 @@ func TestClient_AuthRefresh(t *testing.T) {
auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken)
Ok(t, err)
Equals(t, testUID, c.uid)
exp := &Auth{}
*exp = *testAuth
exp.uid = "" // AuthRefresh will not return UID (only Auth returns the UID).
exp.uid = testUID // AuthRefresh will not return UID (only Auth returns the UID) we should set testUID to be able to generate token, see `GetToken`
exp.accessToken = testAccessToken
exp.KeySalt = ""
exp.EventID = ""

View File

@ -3,6 +3,7 @@ package pmapi
import (
"fmt"
"net/http"
"strings"
"sync"
"time"
@ -10,8 +11,6 @@ import (
"github.com/sirupsen/logrus"
)
var defaultProxyUseDuration = 24 * time.Hour
// ClientManager is a manager of clients.
type ClientManager struct {
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
@ -21,9 +20,6 @@ type ClientManager struct {
config *ClientConfig
roundTripper http.RoundTripper
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests.
// But that screws up other things like not being able to clear sensitive info during logout
// unless the client interface contains a method for that.
clients map[string]Client
clientsLocker sync.Locker
@ -33,17 +29,19 @@ type ClientManager struct {
expirations map[string]*tokenExpiration
expirationsLocker sync.Locker
host, scheme string
hostLocker sync.Locker
bridgeAuths chan ClientAuth
clientAuths chan ClientAuth
host, scheme string
hostLocker sync.RWMutex
allowProxy bool
proxyProvider *proxyProvider
proxyUseDuration time.Duration
idGen idGen
log *logrus.Entry
}
type idGen int
@ -81,14 +79,16 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
expirationsLocker: &sync.Mutex{},
host: RootURL,
scheme: RootScheme,
hostLocker: &sync.Mutex{},
scheme: rootScheme,
hostLocker: sync.RWMutex{},
bridgeAuths: make(chan ClientAuth),
clientAuths: make(chan ClientAuth),
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
proxyUseDuration: defaultProxyUseDuration,
proxyUseDuration: proxyUseDuration,
log: logrus.WithField("pkg", "pmapi-manager"),
}
cm.newClient = func(userID string) Client {
@ -97,7 +97,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
go cm.forwardClientAuths()
return
return cm
}
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
@ -140,20 +140,20 @@ func (cm *ClientManager) LogoutClient(userID string) {
delete(cm.clients, userID)
go func() {
if err := client.DeleteAuth(); err != nil {
// TODO: Retry if the request failed.
if !strings.HasPrefix(userID, "anonymous-") {
if err := client.DeleteAuth(); err != nil {
// TODO: Retry if the request failed.
}
}
client.ClearData()
cm.clearToken(userID)
}()
return
}
// GetRootURL returns the full root URL (scheme+host).
func (cm *ClientManager) GetRootURL() string {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
}
@ -161,24 +161,16 @@ func (cm *ClientManager) GetRootURL() string {
// getHost returns the host to make requests to.
// It does not include the protocol i.e. no "https://" (use getScheme for that).
func (cm *ClientManager) getHost() string {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return cm.host
}
// getScheme returns the scheme with which to make requests to the host.
func (cm *ClientManager) getScheme() string {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
return cm.scheme
}
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
func (cm *ClientManager) IsProxyAllowed() bool {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return cm.allowProxy
}
@ -202,8 +194,8 @@ func (cm *ClientManager) DisallowProxy() {
// IsProxyEnabled returns whether we are currently proxying requests.
func (cm *ClientManager) IsProxyEnabled() bool {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return cm.host != RootURL
}
@ -264,6 +256,21 @@ func (cm *ClientManager) forwardClientAuths() {
}
}
// SetTokenIfUnset sets the token for the given userID if it wasn't already set.
// The token does not expire.
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
if _, ok := cm.tokens[userID]; ok {
return
}
logrus.WithField("userID", userID).Info("Setting token because it is currently unset")
cm.tokens[userID] = token
}
// setToken sets the token for the given userID with the given expiration time.
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
cm.tokensLocker.Lock()
@ -275,6 +282,7 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
cm.setTokenExpiration(userID, expiration)
// TODO: This should be one go routine per all tokens.
go cm.watchTokenExpiration(userID)
}
@ -311,7 +319,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
// If we aren't managing this client, there's nothing to do.
if _, ok := cm.clients[ca.UserID]; !ok {
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client")
return
}
@ -332,8 +340,12 @@ func (cm *ClientManager) watchTokenExpiration(userID string) {
select {
case <-expiration.timer.C:
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing")
cm.clients[userID].AuthRefresh(cm.tokens[userID])
cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing")
if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil {
cm.log.WithField("userID", userID).
WithError(err).
Error("Token refresh failed before expiration")
}
case <-expiration.cancel:
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")

View File

@ -27,10 +27,9 @@ import (
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
// Default is pmapi_prod.
//
// It should not contain the protocol! The protocol should be in RootScheme.
// It must not contain the protocol! The protocol should be in rootScheme.
var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
var RootScheme = "https"
var rootScheme = "https" //nolint[gochecknoglobals]
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
// version and email client.

View File

@ -21,5 +21,5 @@ package pmapi
func init() {
RootURL = "dev.protonmail.com/api"
RootScheme = "https"
rootScheme = "https"
}

View File

@ -28,7 +28,7 @@ func init() {
// Use port above 1000 which doesn't need root access to start anything on it.
// Now the port is rounded pi. :-)
RootURL = "127.0.0.1:3142/api"
RootScheme = "http"
rootScheme = "http"
// TLS certificate is self-signed
defaultTransport = &http.Transport{

View File

@ -41,7 +41,7 @@ func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) {
func TestTLSPinValid(t *testing.T) {
cm := NewClientManager(testLiveConfig)
cm.host = liveAPI
RootScheme = "https"
rootScheme = "https"
called, _ := setTestDialerWithPinning(cm)
client := cm.GetClient("pmapi" + t.Name())

View File

@ -5,12 +5,11 @@
package mocks
import (
io "io"
reflect "reflect"
crypto "github.com/ProtonMail/gopenpgp/crypto"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
io "io"
reflect "reflect"
)
// MockClient is a mock of Client interface
@ -110,6 +109,18 @@ func (mr *MockClientMockRecorder) AuthRefresh(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRefresh", reflect.TypeOf((*MockClient)(nil).AuthRefresh), arg0)
}
// ClearData mocks base method
func (m *MockClient) ClearData() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ClearData")
}
// ClearData indicates an expected call of ClearData
func (mr *MockClientMockRecorder) ClearData() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearData", reflect.TypeOf((*MockClient)(nil).ClearData))
}
// CountMessages mocks base method
func (m *MockClient) CountMessages(arg0 string) ([]*pmapi.MessagesCount, error) {
m.ctrl.T.Helper()
@ -214,6 +225,20 @@ func (mr *MockClientMockRecorder) DeleteAttachment(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAttachment", reflect.TypeOf((*MockClient)(nil).DeleteAttachment), arg0)
}
// DeleteAuth mocks base method
func (m *MockClient) DeleteAuth() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuth")
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAuth indicates an expected call of DeleteAuth
func (mr *MockClientMockRecorder) DeleteAuth() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuth", reflect.TypeOf((*MockClient)(nil).DeleteAuth))
}
// DeleteLabel mocks base method
func (m *MockClient) DeleteLabel(arg0 string) error {
m.ctrl.T.Helper()

View File

@ -30,7 +30,7 @@ import (
)
const (
proxyRevertTime = 24 * time.Hour
proxyUseDuration = 24 * time.Hour
proxySearchTimeout = 30 * time.Second
proxyQueryTimeout = 10 * time.Second
proxyLookupWait = 5 * time.Second

View File

@ -27,7 +27,6 @@ import (
// NewRequest creates a new request.
func (c *client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
// TODO: Support other protocols (localhost needs http not https).
req, err = http.NewRequest(method, c.cm.GetRootURL()+path, body)
if req != nil {