mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 20:56:51 +00:00
feat: improve login flow
This commit is contained in:
2
Makefile
2
Makefile
@ -149,7 +149,7 @@ coverage: test
|
|||||||
go tool cover -html=/tmp/coverage.out -o=coverage.html
|
go tool cover -html=/tmp/coverage.out -o=coverage.html
|
||||||
|
|
||||||
mocks:
|
mocks:
|
||||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,PMAPIProvider,CredentialsStorer > internal/bridge/mocks/mocks.go
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/bridge Configer,PreferenceProvider,PanicHandler,ClientManager,PMAPIProvider,CredentialsStorer > internal/bridge/mocks/mocks.go
|
||||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser > internal/store/mocks/mocks.go
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser > internal/store/mocks/mocks.go
|
||||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,6 @@
|
|||||||
package bridge
|
package bridge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -33,6 +32,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/pkg/errors"
|
||||||
logrus "github.com/sirupsen/logrus"
|
logrus "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ type Bridge struct {
|
|||||||
panicHandler PanicHandler
|
panicHandler PanicHandler
|
||||||
events listener.Listener
|
events listener.Listener
|
||||||
version string
|
version string
|
||||||
clientManager *pmapi.ClientManager
|
clientManager ClientManager
|
||||||
credStorer CredentialsStorer
|
credStorer CredentialsStorer
|
||||||
storeCache *store.Cache
|
storeCache *store.Cache
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ func New(
|
|||||||
panicHandler PanicHandler,
|
panicHandler PanicHandler,
|
||||||
eventListener listener.Listener,
|
eventListener listener.Listener,
|
||||||
version string,
|
version string,
|
||||||
clientManager *pmapi.ClientManager,
|
clientManager ClientManager,
|
||||||
credStorer CredentialsStorer,
|
credStorer CredentialsStorer,
|
||||||
) *Bridge {
|
) *Bridge {
|
||||||
log.Trace("Creating new bridge")
|
log.Trace("Creating new bridge")
|
||||||
@ -107,7 +107,7 @@ func New(
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer panicHandler.HandlePanic()
|
defer panicHandler.HandlePanic()
|
||||||
b.watchUserAuths()
|
b.watchAPIAuths()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if b.credStorer == nil {
|
if b.credStorer == nil {
|
||||||
@ -178,16 +178,21 @@ func (b *Bridge) watchBridgeOutdated() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// watchUserAuths receives auths from the client manager and sends them to the appropriate user.
|
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
|
||||||
func (b *Bridge) watchUserAuths() {
|
func (b *Bridge) watchAPIAuths() {
|
||||||
for auth := range b.clientManager.GetBridgeAuthChannel() {
|
for auth := range b.clientManager.GetBridgeAuthChannel() {
|
||||||
logrus.Debug("Bridge received auth from ClientManager")
|
logrus.Debug("Bridge received auth from ClientManager")
|
||||||
|
|
||||||
if user, ok := b.hasUser(auth.UserID); ok {
|
user, ok := b.hasUser(auth.UserID)
|
||||||
logrus.Debug("Bridge is forwarding auth to user")
|
if !ok {
|
||||||
user.AuthorizeWithAPIAuth(auth.Auth)
|
|
||||||
} else {
|
|
||||||
logrus.Info("User is not added to bridge yet")
|
logrus.Info("User is not added to bridge yet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth.Auth != nil {
|
||||||
|
user.updateAuthToken(auth.Auth)
|
||||||
|
} else {
|
||||||
|
user.logout()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -209,113 +214,140 @@ func (b *Bridge) closeAllConnections() {
|
|||||||
// * In case user `auth.HasMailboxPassword()`, ask for it, otherwise use `password`
|
// * In case user `auth.HasMailboxPassword()`, ask for it, otherwise use `password`
|
||||||
// and then finish the login procedure.
|
// and then finish the login procedure.
|
||||||
// user, err := bridge.FinishLogin(client, auth, mailboxPassword)
|
// user, err := bridge.FinishLogin(client, auth, mailboxPassword)
|
||||||
func (b *Bridge) Login(username, password string) (loginClient PMAPIProvider, auth *pmapi.Auth, err error) {
|
func (b *Bridge) Login(username, password string) (authClient PMAPIProvider, auth *pmapi.Auth, err error) {
|
||||||
log.WithField("username", username).Trace("Logging in to bridge")
|
|
||||||
|
|
||||||
b.crashBandicoot(username)
|
b.crashBandicoot(username)
|
||||||
|
|
||||||
// We need to use "login" client because we need userID to properly assign access tokens into token manager.
|
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet.
|
||||||
loginClient = b.clientManager.GetClient("login")
|
authClient = b.clientManager.GetAnonymousClient()
|
||||||
|
|
||||||
authInfo, err := loginClient.AuthInfo(username)
|
authInfo, err := authClient.AuthInfo(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
|
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
|
||||||
return nil, nil, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth, err = loginClient.Auth(username, password, authInfo); err != nil {
|
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
|
||||||
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
|
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
|
||||||
return loginClient, auth, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return loginClient, auth, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
||||||
// See `Login` for more details of the login flow.
|
// See `Login` for more details of the login flow.
|
||||||
func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen]
|
func (b *Bridge) FinishLogin(authClient PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen]
|
||||||
log.Trace("Finishing bridge login")
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err == pmapi.ErrUpgradeApplication {
|
if err == pmapi.ErrUpgradeApplication {
|
||||||
b.events.Emit(events.UpgradeApplicationEvent, "")
|
b.events.Emit(events.UpgradeApplicationEvent, "")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
b.lock.Lock()
|
apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword)
|
||||||
defer b.lock.Unlock()
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Failed to get API user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
defer loginClient.Logout()
|
if user, err = b.GetUser(apiUser.ID); err == nil {
|
||||||
|
if err = b.connectExistingUser(user, auth, hashedPassword); err != nil {
|
||||||
|
log.WithError(err).Error("Failed to connect existing user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err = b.addNewUser(apiUser, auth, hashedPassword); err != nil {
|
||||||
|
log.WithError(err).Error("Failed to add new user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mbPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
|
b.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||||
|
|
||||||
|
return b.GetUser(apiUser.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectExistingUser connects an existing bridge user to the bridge.
|
||||||
|
func (b *Bridge) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassword string) (err error) {
|
||||||
|
if user.IsConnected() {
|
||||||
|
return errors.New("user is already connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the user's password in the cred store in case they changed it.
|
||||||
|
if err = b.credStorer.UpdatePassword(user.ID(), hashedPassword); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to update password of user in credentials store")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := b.clientManager.GetClient(user.ID())
|
||||||
|
|
||||||
|
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to refresh auth token of new client")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = b.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to update token of user in credentials store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = user.init(b.idleUpdates); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to initialise user")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNewUser adds a new bridge user to the bridge.
|
||||||
|
func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) {
|
||||||
|
client := b.clientManager.GetClient(user.ID)
|
||||||
|
|
||||||
|
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to refresh token in new client")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user, err = client.UpdateUser(); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to update API user")
|
||||||
|
}
|
||||||
|
|
||||||
|
activeEmails := client.Addresses().ActiveEmails()
|
||||||
|
|
||||||
|
if _, err = b.credStorer.Add(user.ID, user.Name, auth.GenToken(), hashedPassword, activeEmails); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to add user to credentials store")
|
||||||
|
}
|
||||||
|
|
||||||
|
bridgeUser, err := newUser(b.panicHandler, user.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to create user")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The user needs to be part of the users list in order for it to receive an auth during initialisation.
|
||||||
|
// TODO: If adding the user fails, we don't want to leave it there.
|
||||||
|
b.users = append(b.users, bridgeUser)
|
||||||
|
|
||||||
|
if err = bridgeUser.init(b.idleUpdates); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to initialise user")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAPIUser(client PMAPIProvider, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) {
|
||||||
|
hashedPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Could not hash mailbox password")
|
log.WithError(err).Error("Could not hash mailbox password")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = loginClient.Unlock(mbPassword); err != nil {
|
if _, err = client.Unlock(hashedPassword); err != nil {
|
||||||
log.WithError(err).Error("Could not decrypt keyring")
|
log.WithError(err).Error("Could not decrypt keyring")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiUser, err := loginClient.CurrentUser()
|
if user, err = client.UpdateUser(); err != nil {
|
||||||
if err != nil {
|
log.WithError(err).Error("Could not update API user")
|
||||||
log.WithError(err).Error("Could not get login API user")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, hasUser := b.hasUser(apiUser.ID)
|
return
|
||||||
|
|
||||||
// If the user exists and is logged in, we don't want to do anything.
|
|
||||||
if hasUser && user.IsConnected() {
|
|
||||||
err = errors.New("user is already logged in")
|
|
||||||
log.WithError(err).Warn("User is already logged in")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
apiClient := b.clientManager.GetClient(apiUser.ID)
|
|
||||||
auth, err = apiClient.AuthRefresh(auth.GenToken())
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Error("Could not refresh token in new client")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// We load the current user again because it should now have addresses loaded.
|
|
||||||
apiUser, err = apiClient.CurrentUser()
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Error("Could not get current API user")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
activeEmails := apiClient.Addresses().ActiveEmails()
|
|
||||||
if _, err = b.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), mbPassword, activeEmails); err != nil {
|
|
||||||
log.WithError(err).Error("Could not add user to credentials store")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it's a new user, generate the user object.
|
|
||||||
if !hasUser {
|
|
||||||
user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
|
|
||||||
if err != nil {
|
|
||||||
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.users = append(b.users, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up the user auth and store (which we do for both new and existing users).
|
|
||||||
if err = user.init(b.idleUpdates); err != nil {
|
|
||||||
log.WithField("user", user.userID).WithError(err).Error("Could not initialise user")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasUser {
|
|
||||||
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel))
|
|
||||||
}
|
|
||||||
|
|
||||||
b.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
|
||||||
|
|
||||||
return user, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsers returns all added users into keychain (even logged out users).
|
// GetUsers returns all added users into keychain (even logged out users).
|
||||||
@ -326,8 +358,7 @@ func (b *Bridge) GetUsers() []*User {
|
|||||||
return b.users
|
return b.users
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUser returns a user by `query` which is compared to users' ID, username
|
// GetUser returns a user by `query` which is compared to users' ID, username or any attached e-mail address.
|
||||||
// or any attached e-mail address.
|
|
||||||
func (b *Bridge) GetUser(query string) (*User, error) {
|
func (b *Bridge) GetUser(query string) (*User, error) {
|
||||||
b.crashBandicoot(query)
|
b.crashBandicoot(query)
|
||||||
|
|
||||||
|
|||||||
@ -73,7 +73,6 @@ func TestNewBridgeWithConnectedUserWithBadToken(t *testing.T) {
|
|||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
|
|||||||
@ -129,6 +129,7 @@ type mocks struct {
|
|||||||
PanicHandler *bridgemocks.MockPanicHandler
|
PanicHandler *bridgemocks.MockPanicHandler
|
||||||
prefProvider *bridgemocks.MockPreferenceProvider
|
prefProvider *bridgemocks.MockPreferenceProvider
|
||||||
pmapiClient *bridgemocks.MockPMAPIProvider
|
pmapiClient *bridgemocks.MockPMAPIProvider
|
||||||
|
clientManager *bridgemocks.MockClientManager
|
||||||
credentialsStore *bridgemocks.MockCredentialsStorer
|
credentialsStore *bridgemocks.MockCredentialsStorer
|
||||||
eventListener *MockListener
|
eventListener *MockListener
|
||||||
|
|
||||||
@ -208,14 +209,10 @@ func testNewBridge(t *testing.T, m mocks) *Bridge {
|
|||||||
m.prefProvider.EXPECT().GetBool(preferences.AllowProxyKey).Return(false).AnyTimes()
|
m.prefProvider.EXPECT().GetBool(preferences.AllowProxyKey).Return(false).AnyTimes()
|
||||||
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
|
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
|
||||||
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
|
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any()).AnyTimes()
|
|
||||||
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||||
pmapiClientFactory := func(userID string) PMAPIProvider {
|
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient)
|
||||||
log.WithField("userID", userID).Info("Creating new pmclient")
|
|
||||||
return m.pmapiClient
|
|
||||||
}
|
|
||||||
|
|
||||||
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", pmapiClientFactory, m.credentialsStore)
|
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore)
|
||||||
|
|
||||||
waitForEvents()
|
waitForEvents()
|
||||||
|
|
||||||
|
|||||||
@ -125,6 +125,20 @@ func (s *Store) UpdateEmails(userID string, emails []string) error {
|
|||||||
return s.saveCredentials(credentials)
|
return s.saveCredentials(credentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdatePassword(userID, password string) error {
|
||||||
|
storeLocker.Lock()
|
||||||
|
defer storeLocker.Unlock()
|
||||||
|
|
||||||
|
credentials, err := s.get(userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials.MailboxPassword = password
|
||||||
|
|
||||||
|
return s.saveCredentials(credentials)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateToken(userID, apiToken string) error {
|
func (s *Store) UpdateToken(userID, apiToken string) error {
|
||||||
storeLocker.Lock()
|
storeLocker.Lock()
|
||||||
defer storeLocker.Unlock()
|
defer storeLocker.Unlock()
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,PMAPIProvider,CredentialsStorer)
|
// Source: github.com/ProtonMail/proton-bridge/internal/bridge (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,PMAPIProvider,CredentialsStorer)
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
// Package mocks is a generated GoMock package.
|
||||||
package mocks
|
package mocks
|
||||||
@ -205,6 +205,95 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockClientManager is a mock of ClientManager interface
|
||||||
|
type MockClientManager struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockClientManagerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||||
|
type MockClientManagerMockRecorder struct {
|
||||||
|
mock *MockClientManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockClientManager creates a new mock instance
|
||||||
|
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||||
|
mock := &MockClientManager{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowProxy mocks base method
|
||||||
|
func (m *MockClientManager) AllowProxy() {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "AllowProxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowProxy indicates an expected call of AllowProxy
|
||||||
|
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisallowProxy mocks base method
|
||||||
|
func (m *MockClientManager) DisallowProxy() {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "DisallowProxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisallowProxy indicates an expected call of DisallowProxy
|
||||||
|
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAnonymousClient mocks base method
|
||||||
|
func (m *MockClientManager) GetAnonymousClient() *pmapi.Client {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetAnonymousClient")
|
||||||
|
ret0, _ := ret[0].(*pmapi.Client)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAnonymousClient indicates an expected call of GetAnonymousClient
|
||||||
|
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBridgeAuthChannel mocks base method
|
||||||
|
func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetBridgeAuthChannel")
|
||||||
|
ret0, _ := ret[0].(chan pmapi.ClientAuth)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel
|
||||||
|
func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient mocks base method
|
||||||
|
func (m *MockClientManager) GetClient(arg0 string) *pmapi.Client {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||||
|
ret0, _ := ret[0].(*pmapi.Client)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient indicates an expected call of GetClient
|
||||||
|
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
// MockPMAPIProvider is a mock of PMAPIProvider interface
|
// MockPMAPIProvider is a mock of PMAPIProvider interface
|
||||||
type MockPMAPIProvider struct {
|
type MockPMAPIProvider struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
@ -615,11 +704,9 @@ func (mr *MockPMAPIProviderMockRecorder) ListMessages(arg0 interface{}) *gomock.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Logout mocks base method
|
// Logout mocks base method
|
||||||
func (m *MockPMAPIProvider) Logout() error {
|
func (m *MockPMAPIProvider) Logout() {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Logout")
|
m.ctrl.Call(m, "Logout")
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logout indicates an expected call of Logout
|
// Logout indicates an expected call of Logout
|
||||||
@ -700,18 +787,6 @@ func (mr *MockPMAPIProviderMockRecorder) SendSimpleMetric(arg0, arg1, arg2 inter
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSimpleMetric", reflect.TypeOf((*MockPMAPIProvider)(nil).SendSimpleMetric), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSimpleMetric", reflect.TypeOf((*MockPMAPIProvider)(nil).SendSimpleMetric), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAuths mocks base method
|
|
||||||
func (m *MockPMAPIProvider) SetAuths(arg0 chan<- *pmapi.Auth) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetAuths", arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAuths indicates an expected call of SetAuths
|
|
||||||
func (mr *MockPMAPIProviderMockRecorder) SetAuths(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAuths", reflect.TypeOf((*MockPMAPIProvider)(nil).SetAuths), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnlabelMessages mocks base method
|
// UnlabelMessages mocks base method
|
||||||
func (m *MockPMAPIProvider) UnlabelMessages(arg0 []string, arg1 string) error {
|
func (m *MockPMAPIProvider) UnlabelMessages(arg0 []string, arg1 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@ -909,6 +984,20 @@ func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEmails", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateEmails), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEmails", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateEmails), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePassword mocks base method
|
||||||
|
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePassword indicates an expected call of UpdatePassword
|
||||||
|
func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdatePassword), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateToken mocks base method
|
// UpdateToken mocks base method
|
||||||
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
|
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@ -43,7 +43,12 @@ type PanicHandler interface {
|
|||||||
HandlePanic()
|
HandlePanic()
|
||||||
}
|
}
|
||||||
|
|
||||||
type Clientman interface {
|
type ClientManager interface {
|
||||||
|
GetClient(userID string) *pmapi.Client
|
||||||
|
GetAnonymousClient() *pmapi.Client
|
||||||
|
GetBridgeAuthChannel() chan pmapi.ClientAuth
|
||||||
|
AllowProxy()
|
||||||
|
DisallowProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
type PMAPIProvider interface {
|
type PMAPIProvider interface {
|
||||||
@ -100,6 +105,7 @@ type CredentialsStorer interface {
|
|||||||
Get(userID string) (*credentials.Credentials, error)
|
Get(userID string) (*credentials.Credentials, error)
|
||||||
SwitchAddressMode(userID string) error
|
SwitchAddressMode(userID string) error
|
||||||
UpdateEmails(userID string, emails []string) error
|
UpdateEmails(userID string, emails []string) error
|
||||||
|
UpdatePassword(userID, password string) error
|
||||||
UpdateToken(userID, apiToken string) error
|
UpdateToken(userID, apiToken string) error
|
||||||
Logout(userID string) error
|
Logout(userID string) error
|
||||||
Delete(userID string) error
|
Delete(userID string) error
|
||||||
|
|||||||
@ -41,7 +41,7 @@ type User struct {
|
|||||||
log *logrus.Entry
|
log *logrus.Entry
|
||||||
panicHandler PanicHandler
|
panicHandler PanicHandler
|
||||||
listener listener.Listener
|
listener listener.Listener
|
||||||
clientMan *pmapi.ClientManager
|
clientMan ClientManager
|
||||||
credStorer CredentialsStorer
|
credStorer CredentialsStorer
|
||||||
|
|
||||||
imapUpdatesChannel chan interface{}
|
imapUpdatesChannel chan interface{}
|
||||||
@ -66,7 +66,7 @@ func newUser(
|
|||||||
userID string,
|
userID string,
|
||||||
eventListener listener.Listener,
|
eventListener listener.Listener,
|
||||||
credStorer CredentialsStorer,
|
credStorer CredentialsStorer,
|
||||||
clientMan *pmapi.ClientManager,
|
clientMan ClientManager,
|
||||||
storeCache *store.Cache,
|
storeCache *store.Cache,
|
||||||
storeDir string,
|
storeDir string,
|
||||||
) (u *User, err error) {
|
) (u *User, err error) {
|
||||||
@ -243,20 +243,9 @@ func (u *User) authorizeAndUnlock() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) AuthorizeWithAPIAuth(auth *pmapi.Auth) {
|
func (u *User) updateAuthToken(auth *pmapi.Auth) {
|
||||||
u.lock.Lock()
|
|
||||||
defer u.lock.Unlock()
|
|
||||||
|
|
||||||
u.log.Debug("User received auth from bridge")
|
u.log.Debug("User received auth from bridge")
|
||||||
|
|
||||||
if auth == nil {
|
|
||||||
if err := u.logout(); err != nil {
|
|
||||||
u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API")
|
|
||||||
}
|
|
||||||
u.isAuthorized = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
|
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
|
||||||
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||||
return
|
return
|
||||||
@ -510,11 +499,16 @@ func (u *User) logout() error {
|
|||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
wasConnected := u.creds.IsConnected()
|
wasConnected := u.creds.IsConnected()
|
||||||
u.lock.Unlock()
|
u.lock.Unlock()
|
||||||
|
|
||||||
err := u.Logout()
|
err := u.Logout()
|
||||||
|
|
||||||
if wasConnected {
|
if wasConnected {
|
||||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||||
u.listener.Emit(events.UserRefreshEvent, u.userID)
|
u.listener.Emit(events.UserRefreshEvent, u.userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
u.isAuthorized = false
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -534,6 +528,7 @@ func (u *User) Logout() (err error) {
|
|||||||
u.wasKeyringUnlocked = false
|
u.wasKeyringUnlocked = false
|
||||||
u.unlockingKeyringLock.Unlock()
|
u.unlockingKeyringLock.Unlock()
|
||||||
|
|
||||||
|
// TODO: Is this necessary or could it be done by ClientManager when a nil auth is received?
|
||||||
u.client().Logout()
|
u.client().Logout()
|
||||||
|
|
||||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||||
@ -550,6 +545,7 @@ func (u *User) Logout() (err error) {
|
|||||||
u.closeEventLoop()
|
u.closeEventLoop()
|
||||||
|
|
||||||
u.closeAllConnections()
|
u.closeAllConnections()
|
||||||
|
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@ -557,7 +553,7 @@ func (u *User) Logout() (err error) {
|
|||||||
|
|
||||||
func (u *User) refreshFromCredentials() {
|
func (u *User) refreshFromCredentials() {
|
||||||
if credentials, err := u.credStorer.Get(u.userID); err != nil {
|
if credentials, err := u.credStorer.Get(u.userID); err != nil {
|
||||||
log.Error("Cannot update credentials: ", err)
|
log.WithError(err).Error("Cannot refresh user credentials")
|
||||||
} else {
|
} else {
|
||||||
u.creds = credentials
|
u.creds = credentials
|
||||||
}
|
}
|
||||||
|
|||||||
@ -174,13 +174,13 @@ func TestCheckBridgeLoginLoggedOut(t *testing.T) {
|
|||||||
defer m.ctrl.Finish()
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.pmapiClient, m.storeCache, "/tmp")
|
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient)
|
||||||
|
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
|
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized"))
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||||
_ = user.init(nil, m.pmapiClient)
|
_ = user.init(nil)
|
||||||
|
|
||||||
defer cleanUpUserData(user)
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
|||||||
@ -44,7 +44,6 @@ func TestNewUserBridgeOutdated(t *testing.T) {
|
|||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil).AnyTimes()
|
m.credentialsStore.EXPECT().Logout("user").Return(nil).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication).AnyTimes()
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "").AnyTimes()
|
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "").AnyTimes()
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication)
|
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication)
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
||||||
@ -58,7 +57,6 @@ func TestNewUserNoInternetConnection(t *testing.T) {
|
|||||||
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable).AnyTimes()
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
m.eventListener.EXPECT().Emit(events.InternetOffEvent, "").AnyTimes()
|
m.eventListener.EXPECT().Emit(events.InternetOffEvent, "").AnyTimes()
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
m.pmapiClient.EXPECT().Addresses().Return(nil)
|
||||||
@ -75,12 +73,10 @@ func TestNewUserAuthRefreshFails(t *testing.T) {
|
|||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")).AnyTimes()
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
@ -96,7 +92,6 @@ func TestNewUserUnlockFails(t *testing.T) {
|
|||||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password"))
|
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password"))
|
||||||
@ -104,7 +99,6 @@ func TestNewUserUnlockFails(t *testing.T) {
|
|||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
@ -120,7 +114,6 @@ func TestNewUserUnlockAddressesFails(t *testing.T) {
|
|||||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||||
@ -129,7 +122,6 @@ func TestNewUserUnlockAddressesFails(t *testing.T) {
|
|||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||||
m.pmapiClient.EXPECT().Logout().Return(nil)
|
m.pmapiClient.EXPECT().Logout().Return(nil)
|
||||||
m.pmapiClient.EXPECT().SetAuths(nil)
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
@ -144,7 +136,6 @@ func TestNewUser(t *testing.T) {
|
|||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||||
|
|||||||
@ -30,7 +30,6 @@ func testNewUser(m mocks) *User {
|
|||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||||
@ -57,14 +56,12 @@ func testNewUserForLogout(m mocks) *User {
|
|||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||||
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
m.credentialsStore.EXPECT().UpdateToken("user", ":reftok").Return(nil)
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().SetAuths(gomock.Any())
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil)
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil)
|
||||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil)
|
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil)
|
||||||
|
|
||||||
// These may or may not be hit depending on how fast the log out happens.
|
// These may or may not be hit depending on how fast the log out happens.
|
||||||
m.pmapiClient.EXPECT().SetAuths(nil).AnyTimes()
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil).AnyTimes()
|
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}).AnyTimes()
|
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil)
|
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil)
|
||||||
|
|||||||
@ -527,11 +527,11 @@ func (s *FrontendQt) toggleAllowProxy() {
|
|||||||
|
|
||||||
if s.preferences.GetBool(preferences.AllowProxyKey) {
|
if s.preferences.GetBool(preferences.AllowProxyKey) {
|
||||||
s.preferences.SetBool(preferences.AllowProxyKey, false)
|
s.preferences.SetBool(preferences.AllowProxyKey, false)
|
||||||
bridge.DisallowDoH()
|
s.bridge.DisallowProxy()
|
||||||
s.Qml.SetIsProxyAllowed(false)
|
s.Qml.SetIsProxyAllowed(false)
|
||||||
} else {
|
} else {
|
||||||
s.preferences.SetBool(preferences.AllowProxyKey, true)
|
s.preferences.SetBool(preferences.AllowProxyKey, true)
|
||||||
bridge.AllowDoH()
|
s.bridge.AllowProxy()
|
||||||
s.Qml.SetIsProxyAllowed(true)
|
s.Qml.SetIsProxyAllowed(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -39,7 +39,7 @@ func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) {
|
func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) {
|
||||||
logrus.Info("Setting dialer with pinning")
|
logrus.Info("Setting ClientManager to create clients with key pinning")
|
||||||
|
|
||||||
pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion)
|
pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion)
|
||||||
|
|
||||||
@ -47,5 +47,5 @@ func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, lis
|
|||||||
listener.Emit(events.TLSCertIssue, "")
|
listener.Emit(events.TLSCertIssue, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.SetClientRoundTripper(pin.TransportWithPinning())
|
cm.SetRoundTripper(pin.TransportWithPinning())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -217,7 +217,7 @@ func (store *Store) txGetOnAPICounts(tx *bolt.Tx) ([]*mailboxCounts, error) {
|
|||||||
|
|
||||||
// createOrUpdateOnAPICounts will change only on-API-counts.
|
// createOrUpdateOnAPICounts will change only on-API-counts.
|
||||||
func (store *Store) createOrUpdateOnAPICounts(mailboxCountsOnAPI []*pmapi.MessagesCount) error {
|
func (store *Store) createOrUpdateOnAPICounts(mailboxCountsOnAPI []*pmapi.MessagesCount) error {
|
||||||
store.log.WithField("apiCounts", mailboxCountsOnAPI).Debug("Updating API counts")
|
store.log.Debug("Updating API counts")
|
||||||
|
|
||||||
tx := func(tx *bolt.Tx) error {
|
tx := func(tx *bolt.Tx) error {
|
||||||
countsBkt := tx.Bucket(countsBucket)
|
countsBkt := tx.Bucket(countsBucket)
|
||||||
|
|||||||
@ -19,12 +19,9 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/go-appdir"
|
"github.com/ProtonMail/go-appdir"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
@ -73,11 +70,7 @@ func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirs
|
|||||||
apiConfig: &pmapi.ClientConfig{
|
apiConfig: &pmapi.ClientConfig{
|
||||||
AppVersion: strings.Title(appName) + "_" + version,
|
AppVersion: strings.Title(appName) + "_" + version,
|
||||||
ClientID: appName,
|
ClientID: appName,
|
||||||
Transport: &http.Transport{
|
SentryDSN: "https://bacfb56338a7471a9fede610046afdda:ab437b0d13f54602a0f5feb684e6d319@api.protonmail.ch/reports/sentry/8",
|
||||||
DialContext: (&net.Dialer{Timeout: 3 * time.Second}).DialContext,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ResponseHeaderTimeout: 10 * time.Second,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -118,11 +118,18 @@ type Auth struct {
|
|||||||
TwoFA *TwoFactorInfo `json:"2FA,omitempty"`
|
TwoFA *TwoFactorInfo `json:"2FA,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UID returns the session UID from the Auth.
|
||||||
|
// Only Auths generated from the /auth route will have the UID.
|
||||||
|
// Auths generated from /auth/refresh are not required to.
|
||||||
func (s *Auth) UID() string {
|
func (s *Auth) UID() string {
|
||||||
return s.uid
|
return s.uid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Auth) GenToken() string {
|
func (s *Auth) GenToken() string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken)
|
return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,7 +154,9 @@ type AuthRes struct {
|
|||||||
|
|
||||||
AccessToken string
|
AccessToken string
|
||||||
TokenType string
|
TokenType string
|
||||||
UID string
|
|
||||||
|
// UID is the session UID. This is only present in an initial Auth (/auth), not in a refreshed Auth (/auth/refresh).
|
||||||
|
UID string
|
||||||
|
|
||||||
ServerProof string
|
ServerProof string
|
||||||
}
|
}
|
||||||
@ -196,18 +205,21 @@ type AuthRefreshReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) sendAuth(auth *Auth) {
|
func (c *Client) sendAuth(auth *Auth) {
|
||||||
go func() {
|
c.log.Debug("Client is sending auth to ClientManager")
|
||||||
c.log.Debug("Client is sending auth to ClientManager")
|
|
||||||
|
|
||||||
|
if auth != nil {
|
||||||
|
// UID is only provided in the initial /auth, not during /auth/refresh
|
||||||
|
if auth.UID() != "" {
|
||||||
|
c.uid = auth.UID()
|
||||||
|
}
|
||||||
|
c.accessToken = auth.accessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
c.cm.getClientAuthChannel() <- ClientAuth{
|
c.cm.getClientAuthChannel() <- ClientAuth{
|
||||||
UserID: c.userID,
|
UserID: c.userID,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth != nil {
|
|
||||||
c.uid = auth.UID()
|
|
||||||
c.accessToken = auth.accessToken
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -446,6 +458,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
auth = res.getAuth()
|
auth = res.getAuth()
|
||||||
|
|
||||||
c.sendAuth(auth)
|
c.sendAuth(auth)
|
||||||
|
|
||||||
return auth, err
|
return auth, err
|
||||||
@ -456,6 +469,8 @@ func (c *Client) Logout() {
|
|||||||
c.cm.LogoutClient(c.userID)
|
c.cm.LogoutClient(c.userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Need a method like IsConnected() to be able to detect whether a client is logged in or not.
|
||||||
|
|
||||||
// logout logs the current user out.
|
// logout logs the current user out.
|
||||||
func (c *Client) logout() (err error) {
|
func (c *Client) logout() (err error) {
|
||||||
req, err := c.NewRequest("DELETE", "/auth", nil)
|
req, err := c.NewRequest("DELETE", "/auth", nil)
|
||||||
|
|||||||
@ -279,6 +279,7 @@ func TestClient_AuthRefresh(t *testing.T) {
|
|||||||
|
|
||||||
exp := &Auth{}
|
exp := &Auth{}
|
||||||
*exp = *testAuth
|
*exp = *testAuth
|
||||||
|
exp.uid = "" // AuthRefresh will not return UID (only Auth returns the UID).
|
||||||
exp.accessToken = testAccessToken
|
exp.accessToken = testAccessToken
|
||||||
exp.KeySalt = ""
|
exp.KeySalt = ""
|
||||||
exp.EventID = ""
|
exp.EventID = ""
|
||||||
@ -329,12 +330,20 @@ func TestClient_Logout(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
defer finish()
|
defer finish()
|
||||||
|
|
||||||
c.uid = testUID
|
c.uid = testUID
|
||||||
c.accessToken = testAccessToken
|
c.accessToken = testAccessToken
|
||||||
|
|
||||||
c.Logout()
|
c.Logout()
|
||||||
|
|
||||||
// TODO: Check that the client is logged out and sensitive data is cleared eventually.
|
r.Eventually(t, func() bool {
|
||||||
|
// TODO: Use a method like IsConnected() which returns whether the client was logged out or not.
|
||||||
|
return c.accessToken == "" &&
|
||||||
|
c.uid == "" &&
|
||||||
|
c.kr == nil &&
|
||||||
|
c.addresses == nil &&
|
||||||
|
c.user == nil
|
||||||
|
}, 10*time.Second, 10*time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_DoUnauthorized(t *testing.T) {
|
func TestClient_DoUnauthorized(t *testing.T) {
|
||||||
@ -354,7 +363,6 @@ func TestClient_DoUnauthorized(t *testing.T) {
|
|||||||
|
|
||||||
c.uid = testUID
|
c.uid = testUID
|
||||||
c.accessToken = testAccessTokenOld
|
c.accessToken = testAccessTokenOld
|
||||||
c.expiresAt = aLongTimeAgo
|
|
||||||
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
|
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
|
||||||
|
|
||||||
req, err := c.NewRequest("GET", "/", nil)
|
req, err := c.NewRequest("GET", "/", nil)
|
||||||
|
|||||||
@ -79,11 +79,6 @@ type ClientConfig struct {
|
|||||||
// The sentry DSN.
|
// The sentry DSN.
|
||||||
SentryDSN string
|
SentryDSN string
|
||||||
|
|
||||||
// Transport specifies the mechanism by which individual HTTP requests are made.
|
|
||||||
// If nil, http.DefaultTransport is used.
|
|
||||||
// TODO: This could be removed entirely and set in the client manager via SetClientRoundTripper.
|
|
||||||
Transport http.RoundTripper
|
|
||||||
|
|
||||||
// Timeout specifies the timeout from request to getting response headers to our API.
|
// Timeout specifies the timeout from request to getting response headers to our API.
|
||||||
// Passed to http.Client, empty means no timeout.
|
// Passed to http.Client, empty means no timeout.
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
@ -120,7 +115,7 @@ type Client struct {
|
|||||||
func newClient(cm *ClientManager, userID string) *Client {
|
func newClient(cm *ClientManager, userID string) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
cm: cm,
|
cm: cm,
|
||||||
hc: getHTTPClient(cm.GetConfig()),
|
hc: getHTTPClient(cm.GetConfig(), cm.GetRoundTripper()),
|
||||||
userID: userID,
|
userID: userID,
|
||||||
requestLocker: &sync.Mutex{},
|
requestLocker: &sync.Mutex{},
|
||||||
keyLocker: &sync.Mutex{},
|
keyLocker: &sync.Mutex{},
|
||||||
@ -128,32 +123,12 @@ func newClient(cm *ClientManager, userID string) *Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getHTTPClient returns a http client configured by the given client config.
|
// getHTTPClient returns a http client configured by the given client config and using the given transport.
|
||||||
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
|
func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper) (hc *http.Client) {
|
||||||
hc = &http.Client{Timeout: cfg.Timeout}
|
return &http.Client{
|
||||||
|
Timeout: cfg.Timeout,
|
||||||
if cfg.Transport == nil {
|
Transport: rt,
|
||||||
if defaultTransport != nil {
|
|
||||||
hc.Transport = defaultTransport
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// In future use Clone here.
|
|
||||||
// https://go-review.googlesource.com/c/go/+/174597/
|
|
||||||
if cfgTransport, ok := cfg.Transport.(*http.Transport); ok {
|
|
||||||
transport := &http.Transport{}
|
|
||||||
*transport = *cfgTransport //nolint
|
|
||||||
if transport.Proxy == nil {
|
|
||||||
transport.Proxy = http.ProxyFromEnvironment
|
|
||||||
}
|
|
||||||
hc.Transport = transport
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hc.Transport = cfg.Transport
|
|
||||||
|
|
||||||
return hc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do makes an API request. It does not check for HTTP status code errors.
|
// Do makes an API request. It does not check for HTTP status code errors.
|
||||||
|
|||||||
@ -36,9 +36,8 @@ var testClientConfig = &ClientConfig{
|
|||||||
MinSpeed: 256,
|
MinSpeed: 256,
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestClient() *Client {
|
func newTestClient(cm *ClientManager) *Client {
|
||||||
c := newClient(NewClientManager(testClientConfig), "tester")
|
return cm.GetClient("tester")
|
||||||
return c
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_Do(t *testing.T) {
|
func TestClient_Do(t *testing.T) {
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package pmapi
|
package pmapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -10,27 +11,31 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var proxyUseDuration = 24 * time.Hour
|
var defaultProxyUseDuration = 24 * time.Hour
|
||||||
|
|
||||||
// ClientManager is a manager of clients.
|
// ClientManager is a manager of clients.
|
||||||
type ClientManager struct {
|
type ClientManager struct {
|
||||||
config *ClientConfig
|
config *ClientConfig
|
||||||
|
roundTripper http.RoundTripper
|
||||||
|
|
||||||
clients map[string]*Client
|
clients map[string]*Client
|
||||||
clientsLocker sync.Locker
|
clientsLocker sync.Locker
|
||||||
|
|
||||||
tokens map[string]string
|
tokens map[string]string
|
||||||
tokenExpirations map[string]*tokenExpiration
|
tokensLocker sync.Locker
|
||||||
tokensLocker sync.Locker
|
|
||||||
|
|
||||||
url string
|
expirations map[string]*tokenExpiration
|
||||||
urlLocker sync.Locker
|
expirationsLocker sync.Locker
|
||||||
|
|
||||||
|
host, scheme string
|
||||||
|
hostLocker sync.Locker
|
||||||
|
|
||||||
bridgeAuths chan ClientAuth
|
bridgeAuths chan ClientAuth
|
||||||
clientAuths chan ClientAuth
|
clientAuths chan ClientAuth
|
||||||
|
|
||||||
allowProxy bool
|
allowProxy bool
|
||||||
proxyProvider *proxyProvider
|
proxyProvider *proxyProvider
|
||||||
|
proxyUseDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientAuth struct {
|
type ClientAuth struct {
|
||||||
@ -50,22 +55,27 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cm = &ClientManager{
|
cm = &ClientManager{
|
||||||
config: config,
|
config: config,
|
||||||
|
roundTripper: http.DefaultTransport,
|
||||||
|
|
||||||
clients: make(map[string]*Client),
|
clients: make(map[string]*Client),
|
||||||
clientsLocker: &sync.Mutex{},
|
clientsLocker: &sync.Mutex{},
|
||||||
|
|
||||||
tokens: make(map[string]string),
|
tokens: make(map[string]string),
|
||||||
tokenExpirations: make(map[string]*tokenExpiration),
|
tokensLocker: &sync.Mutex{},
|
||||||
tokensLocker: &sync.Mutex{},
|
|
||||||
|
|
||||||
url: RootURL,
|
expirations: make(map[string]*tokenExpiration),
|
||||||
urlLocker: &sync.Mutex{},
|
expirationsLocker: &sync.Mutex{},
|
||||||
|
|
||||||
|
host: RootURL,
|
||||||
|
scheme: RootScheme,
|
||||||
|
hostLocker: &sync.Mutex{},
|
||||||
|
|
||||||
bridgeAuths: make(chan ClientAuth),
|
bridgeAuths: make(chan ClientAuth),
|
||||||
clientAuths: make(chan ClientAuth),
|
clientAuths: make(chan ClientAuth),
|
||||||
|
|
||||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||||
|
proxyUseDuration: defaultProxyUseDuration,
|
||||||
}
|
}
|
||||||
|
|
||||||
go cm.forwardClientAuths()
|
go cm.forwardClientAuths()
|
||||||
@ -73,10 +83,14 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetClientRoundTripper sets the roundtripper used by clients created by this client manager.
|
// SetRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||||
func (cm *ClientManager) SetClientRoundTripper(rt http.RoundTripper) {
|
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
|
||||||
logrus.Info("Setting client roundtripper")
|
cm.roundTripper = rt
|
||||||
cm.config.Transport = rt
|
}
|
||||||
|
|
||||||
|
// GetRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||||
|
func (cm *ClientManager) GetRoundTripper() (rt http.RoundTripper) {
|
||||||
|
return cm.roundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClient returns a client for the given userID.
|
// GetClient returns a client for the given userID.
|
||||||
@ -91,6 +105,17 @@ func (cm *ClientManager) GetClient(userID string) *Client {
|
|||||||
return cm.clients[userID]
|
return cm.clients[userID]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
|
||||||
|
func (cm *ClientManager) GetAnonymousClient() *Client {
|
||||||
|
if client, ok := cm.clients[""]; ok {
|
||||||
|
client.Logout()
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.clients[""] = newClient(cm, "")
|
||||||
|
|
||||||
|
return cm.clients[""]
|
||||||
|
}
|
||||||
|
|
||||||
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||||
func (cm *ClientManager) LogoutClient(userID string) {
|
func (cm *ClientManager) LogoutClient(userID string) {
|
||||||
client, ok := cm.clients[userID]
|
client, ok := cm.clients[userID]
|
||||||
@ -104,7 +129,6 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
|||||||
go func() {
|
go func() {
|
||||||
if err := client.logout(); err != nil {
|
if err := client.logout(); err != nil {
|
||||||
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
|
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
|
||||||
logrus.WithError(err).Error("Client logout failed, not trying again")
|
|
||||||
}
|
}
|
||||||
client.clearSensitiveData()
|
client.clearSensitiveData()
|
||||||
cm.clearToken(userID)
|
cm.clearToken(userID)
|
||||||
@ -113,52 +137,69 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRootURL returns the root URL to make requests to.
|
// GetHost returns the host to make requests to.
|
||||||
// It does not include the protocol i.e. no "https://".
|
// It does not include the protocol i.e. no "https://" (use GetScheme for that).
|
||||||
func (cm *ClientManager) GetRootURL() string {
|
func (cm *ClientManager) GetHost() string {
|
||||||
cm.urlLocker.Lock()
|
cm.hostLocker.Lock()
|
||||||
defer cm.urlLocker.Unlock()
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
return cm.url
|
return cm.host
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetScheme returns the scheme with which to make requests to the host.
|
||||||
|
func (cm *ClientManager) GetScheme() string {
|
||||||
|
cm.hostLocker.Lock()
|
||||||
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
|
return cm.scheme
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRootURL returns the full root URL (scheme+host).
|
||||||
|
func (cm *ClientManager) GetRootURL() string {
|
||||||
|
cm.hostLocker.Lock()
|
||||||
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
|
return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||||
cm.urlLocker.Lock()
|
cm.hostLocker.Lock()
|
||||||
defer cm.urlLocker.Unlock()
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
return cm.allowProxy
|
return cm.allowProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
||||||
func (cm *ClientManager) AllowProxy() {
|
func (cm *ClientManager) AllowProxy() {
|
||||||
cm.urlLocker.Lock()
|
cm.hostLocker.Lock()
|
||||||
defer cm.urlLocker.Unlock()
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
cm.allowProxy = true
|
cm.allowProxy = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
||||||
func (cm *ClientManager) DisallowProxy() {
|
func (cm *ClientManager) DisallowProxy() {
|
||||||
cm.urlLocker.Lock()
|
cm.hostLocker.Lock()
|
||||||
defer cm.urlLocker.Unlock()
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
cm.allowProxy = false
|
cm.allowProxy = false
|
||||||
cm.url = RootURL
|
cm.host = RootURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsProxyEnabled returns whether we are currently proxying requests.
|
// IsProxyEnabled returns whether we are currently proxying requests.
|
||||||
func (cm *ClientManager) IsProxyEnabled() bool {
|
func (cm *ClientManager) IsProxyEnabled() bool {
|
||||||
cm.urlLocker.Lock()
|
cm.hostLocker.Lock()
|
||||||
defer cm.urlLocker.Unlock()
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
return cm.url != RootURL
|
return cm.host != RootURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindProxy returns a usable proxy server.
|
// SwitchToProxy returns a usable proxy server.
|
||||||
|
// TODO: Perhaps the name could be better -- we aren't only switching to a proxy but also to the standard API.
|
||||||
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
||||||
cm.urlLocker.Lock()
|
cm.hostLocker.Lock()
|
||||||
defer cm.urlLocker.Unlock()
|
defer cm.hostLocker.Unlock()
|
||||||
|
|
||||||
logrus.Info("Attempting to switch to a proxy")
|
logrus.Info("Attempting to switch to a proxy")
|
||||||
|
|
||||||
@ -169,9 +210,16 @@ func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
|||||||
|
|
||||||
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
|
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
|
||||||
|
|
||||||
cm.url = proxy
|
// If the host is currently the RootURL, it's the first time we are enabling a proxy.
|
||||||
|
// This means we want to disable it again in 24 hours.
|
||||||
|
if cm.host == RootURL {
|
||||||
|
go func() {
|
||||||
|
<-time.After(cm.proxyUseDuration)
|
||||||
|
cm.host = RootURL
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Disable again after 24 hours.
|
cm.host = proxy
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -183,6 +231,9 @@ func (cm *ClientManager) GetConfig() *ClientConfig {
|
|||||||
|
|
||||||
// GetToken returns the token for the given userID.
|
// GetToken returns the token for the given userID.
|
||||||
func (cm *ClientManager) GetToken(userID string) string {
|
func (cm *ClientManager) GetToken(userID string) string {
|
||||||
|
cm.tokensLocker.Lock()
|
||||||
|
defer cm.tokensLocker.Unlock()
|
||||||
|
|
||||||
return cm.tokens[userID]
|
return cm.tokens[userID]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,6 +259,11 @@ func (cm *ClientManager) forwardClientAuths() {
|
|||||||
|
|
||||||
// setToken sets the token for the given userID with the given expiration time.
|
// setToken sets the token for the given userID with the given expiration time.
|
||||||
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||||
|
// We don't want to set tokens of anonymous clients.
|
||||||
|
if userID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
cm.tokensLocker.Lock()
|
cm.tokensLocker.Lock()
|
||||||
defer cm.tokensLocker.Unlock()
|
defer cm.tokensLocker.Unlock()
|
||||||
|
|
||||||
@ -221,12 +277,15 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
|
|||||||
// setTokenExpiration will ensure the token is refreshed if it expires.
|
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||||
// If the token already has an expiration time set, it is replaced.
|
// If the token already has an expiration time set, it is replaced.
|
||||||
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
|
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
|
||||||
if exp, ok := cm.tokenExpirations[userID]; ok {
|
cm.expirationsLocker.Lock()
|
||||||
|
defer cm.expirationsLocker.Unlock()
|
||||||
|
|
||||||
|
if exp, ok := cm.expirations[userID]; ok {
|
||||||
exp.timer.Stop()
|
exp.timer.Stop()
|
||||||
close(exp.cancel)
|
close(exp.cancel)
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.tokenExpirations[userID] = &tokenExpiration{
|
cm.expirations[userID] = &tokenExpiration{
|
||||||
timer: time.NewTimer(expiration),
|
timer: time.NewTimer(expiration),
|
||||||
cancel: make(chan struct{}),
|
cancel: make(chan struct{}),
|
||||||
}
|
}
|
||||||
@ -262,7 +321,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||||
expiration := cm.tokenExpirations[userID]
|
expiration := cm.expirations[userID]
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-expiration.timer.C:
|
case <-expiration.timer.C:
|
||||||
@ -270,6 +329,6 @@ func (cm *ClientManager) watchTokenExpiration(userID string) {
|
|||||||
cm.clients[userID].AuthRefresh(cm.tokens[userID])
|
cm.clients[userID].AuthRefresh(cm.tokens[userID])
|
||||||
|
|
||||||
case <-expiration.cancel:
|
case <-expiration.cancel:
|
||||||
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired, cancelling this watcher")
|
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,8 +26,12 @@ import (
|
|||||||
//
|
//
|
||||||
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
|
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
|
||||||
// Default is pmapi_prod.
|
// Default is pmapi_prod.
|
||||||
|
//
|
||||||
|
// It should not contain the protocol! The protocol should be in RootScheme.
|
||||||
var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
|
var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
|
||||||
|
|
||||||
|
var RootScheme = "https"
|
||||||
|
|
||||||
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
|
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
|
||||||
// version and email client.
|
// version and email client.
|
||||||
// e.g. Bridge/1.0.4 (Windows) MicrosoftOutlook/16.0.9330.2087
|
// e.g. Bridge/1.0.4 (Windows) MicrosoftOutlook/16.0.9330.2087
|
||||||
|
|||||||
@ -21,4 +21,5 @@ package pmapi
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RootURL = "dev.protonmail.com/api"
|
RootURL = "dev.protonmail.com/api"
|
||||||
|
RootScheme = "https"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,6 +28,7 @@ func init() {
|
|||||||
// Use port above 1000 which doesn't need root access to start anything on it.
|
// Use port above 1000 which doesn't need root access to start anything on it.
|
||||||
// Now the port is rounded pi. :-)
|
// Now the port is rounded pi. :-)
|
||||||
RootURL = "127.0.0.1:3142/api"
|
RootURL = "127.0.0.1:3142/api"
|
||||||
|
RootScheme = "http"
|
||||||
|
|
||||||
// TLS certificate is self-signed
|
// TLS certificate is self-signed
|
||||||
defaultTransport = &http.Transport{
|
defaultTransport = &http.Transport{
|
||||||
|
|||||||
@ -654,7 +654,7 @@ var testCardsCleartext = []Card{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_Encrypt(t *testing.T) {
|
func TestClient_Encrypt(t *testing.T) {
|
||||||
c := newTestClient()
|
c := newTestClient(NewClientManager(testClientConfig))
|
||||||
c.kr = testPrivateKeyRing
|
c.kr = testPrivateKeyRing
|
||||||
|
|
||||||
cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext)
|
cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext)
|
||||||
@ -668,7 +668,7 @@ func TestClient_Encrypt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_Decrypt(t *testing.T) {
|
func TestClient_Decrypt(t *testing.T) {
|
||||||
c := newTestClient()
|
c := newTestClient(NewClientManager(testClientConfig))
|
||||||
c.kr = testPrivateKeyRing
|
c.kr = testPrivateKeyRing
|
||||||
|
|
||||||
cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted)
|
cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted)
|
||||||
|
|||||||
@ -297,7 +297,7 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn
|
|||||||
// If DoH is not allowed, give up. Or, if we are dialing something other than the API
|
// If DoH is not allowed, give up. Or, if we are dialing something other than the API
|
||||||
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in
|
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in
|
||||||
// continuing since a proxy won't help us reach that.
|
// continuing since a proxy won't help us reach that.
|
||||||
if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() {
|
if !p.cm.IsProxyAllowed() || host != p.cm.GetHost() {
|
||||||
p.log.WithField("address", address).Debug("Aborting dial, cannot switch to a proxy")
|
p.log.WithField("address", address).Debug("Aborting dial, cannot switch to a proxy")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,26 +23,27 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
const liveAPI = "https://api.protonmail.ch"
|
const liveAPI = "api.protonmail.ch"
|
||||||
|
|
||||||
var testLiveConfig = &ClientConfig{
|
var testLiveConfig = &ClientConfig{
|
||||||
AppVersion: "Bridge_1.2.4-test",
|
AppVersion: "Bridge_1.2.4-test",
|
||||||
ClientID: "Bridge",
|
ClientID: "Bridge",
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestDialerWithPinning() (*int, *DialerWithPinning) {
|
func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) {
|
||||||
called := 0
|
called := 0
|
||||||
p := NewPMAPIPinning(testLiveConfig.AppVersion)
|
p := NewDialerWithPinning(cm, testLiveConfig.AppVersion)
|
||||||
p.ReportCertIssueLocal = func() { called++ }
|
p.ReportCertIssueLocal = func() { called++ }
|
||||||
testLiveConfig.Transport = p.TransportWithPinning()
|
cm.SetRoundTripper(p.TransportWithPinning())
|
||||||
return &called, p
|
return &called, p
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSPinValid(t *testing.T) {
|
func TestTLSPinValid(t *testing.T) {
|
||||||
called, _ := newTestDialerWithPinning()
|
cm := NewClientManager(testLiveConfig)
|
||||||
|
cm.host = liveAPI
|
||||||
RootURL = liveAPI
|
RootScheme = "https"
|
||||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
called, _ := setTestDialerWithPinning(cm)
|
||||||
|
client := cm.GetClient("pmapi" + t.Name())
|
||||||
|
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
@ -51,12 +52,13 @@ func TestTLSPinValid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSPinBackup(t *testing.T) {
|
func TestTLSPinBackup(t *testing.T) {
|
||||||
called, p := newTestDialerWithPinning()
|
cm := NewClientManager(testLiveConfig)
|
||||||
|
cm.host = liveAPI
|
||||||
|
called, p := setTestDialerWithPinning(cm)
|
||||||
p.report.KnownPins[1] = p.report.KnownPins[0]
|
p.report.KnownPins[1] = p.report.KnownPins[0]
|
||||||
p.report.KnownPins[0] = ""
|
p.report.KnownPins[0] = ""
|
||||||
|
|
||||||
RootURL = liveAPI
|
client := cm.GetClient("pmapi" + t.Name())
|
||||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
|
||||||
|
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
@ -65,19 +67,21 @@ func TestTLSPinBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
|
func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
|
||||||
called, p := newTestDialerWithPinning()
|
cm := NewClientManager(testLiveConfig)
|
||||||
|
cm.host = liveAPI
|
||||||
|
|
||||||
|
called, p := setTestDialerWithPinning(cm)
|
||||||
for i := 0; i < len(p.report.KnownPins); i++ {
|
for i := 0; i < len(p.report.KnownPins); i++ {
|
||||||
p.report.KnownPins[i] = "testing"
|
p.report.KnownPins[i] = "testing"
|
||||||
}
|
}
|
||||||
|
|
||||||
RootURL = liveAPI
|
client := cm.GetClient("pmapi" + t.Name())
|
||||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
|
||||||
|
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
|
|
||||||
// check that it will be called only once per session
|
// check that it will be called only once per session
|
||||||
client = newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
client = cm.GetClient("pmapi" + t.Name())
|
||||||
_, err = client.AuthInfo("this.address.is.disabled")
|
_, err = client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
|
|
||||||
@ -85,20 +89,22 @@ func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
|
|||||||
}
|
}
|
||||||
|
|
||||||
func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
|
func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
|
||||||
|
cm := NewClientManager(testLiveConfig)
|
||||||
|
|
||||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0)
|
writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
called, _ := newTestDialerWithPinning()
|
called, _ := setTestDialerWithPinning(cm)
|
||||||
|
|
||||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
client := cm.GetClient("pmapi" + t.Name())
|
||||||
|
|
||||||
RootURL = liveAPI
|
cm.host = liveAPI
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
|
|
||||||
RootURL = ts.URL
|
cm.host = ts.URL
|
||||||
_, err = client.AuthInfo("this.address.is.disabled")
|
_, err = client.AuthInfo("this.address.is.disabled")
|
||||||
Assert(t, err != nil, "error is expected but have %v", err)
|
Assert(t, err != nil, "error is expected but have %v", err)
|
||||||
|
|
||||||
@ -106,20 +112,23 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
|
|||||||
}
|
}
|
||||||
|
|
||||||
func _TestTLSSignedCertWrongPublicKey(t *testing.T) { // nolint[unused]
|
func _TestTLSSignedCertWrongPublicKey(t *testing.T) { // nolint[unused]
|
||||||
_, dialer := newTestDialerWithPinning()
|
cm := NewClientManager(testLiveConfig)
|
||||||
|
_, dialer := setTestDialerWithPinning(cm)
|
||||||
_, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443")
|
_, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443")
|
||||||
Assert(t, err != nil, "expected dial to fail because of wrong public key: ", err.Error())
|
Assert(t, err != nil, "expected dial to fail because of wrong public key: ", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
|
func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
|
||||||
_, dialer := newTestDialerWithPinning()
|
cm := NewClientManager(testLiveConfig)
|
||||||
|
_, dialer := setTestDialerWithPinning(cm)
|
||||||
dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
|
dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
|
||||||
_, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443")
|
_, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443")
|
||||||
Assert(t, err == nil, "expected dial to succeed because public key is known and cert is signed by CA: ", err.Error())
|
Assert(t, err == nil, "expected dial to succeed because public key is known and cert is signed by CA: ", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func _TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
|
func _TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
|
||||||
_, dialer := newTestDialerWithPinning()
|
cm := NewClientManager(testLiveConfig)
|
||||||
|
_, dialer := setTestDialerWithPinning(cm)
|
||||||
dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
|
dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
|
||||||
_, err := dialer.dialAndCheckFingerprints("tcp", "self-signed.badssl.com:443")
|
_, err := dialer.dialAndCheckFingerprints("tcp", "self-signed.badssl.com:443")
|
||||||
Assert(t, err == nil, "expected dial to succeed because public key is known despite cert being self-signed: ", err.Error())
|
Assert(t, err == nil, "expected dial to succeed because public key is known despite cert being self-signed: ", err.Error())
|
||||||
|
|||||||
@ -72,8 +72,9 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { //
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// findProxy returns a new proxy domain which is not equal to the current RootURL.
|
// findProxy returns a new working proxy domain. This includes the standard API.
|
||||||
// It returns an error if the process takes longer than ProxySearchTime.
|
// It returns an error if the process takes longer than ProxySearchTime.
|
||||||
|
// TODO: Perhaps the name can be better -- we might also return the standard API.
|
||||||
func (p *proxyProvider) findProxy() (proxy string, err error) {
|
func (p *proxyProvider) findProxy() (proxy string, err error) {
|
||||||
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
|
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
|
||||||
return "", errors.New("not looking for a proxy, too soon")
|
return "", errors.New("not looking for a proxy, too soon")
|
||||||
@ -88,6 +89,12 @@ func (p *proxyProvider) findProxy() (proxy string, err error) {
|
|||||||
logrus.WithError(err).Warn("Failed to refresh proxy cache, cache may be out of date")
|
logrus.WithError(err).Warn("Failed to refresh proxy cache, cache may be out of date")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We want to switch back to the RootURL if possible.
|
||||||
|
if p.canReach(RootURL) {
|
||||||
|
proxyResult <- RootURL
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for _, proxy := range p.proxyCache {
|
for _, proxy := range p.proxyCache {
|
||||||
if p.canReach(proxy) {
|
if p.canReach(proxy) {
|
||||||
proxyResult <- proxy
|
proxyResult <- proxy
|
||||||
@ -114,6 +121,7 @@ func (p *proxyProvider) findProxy() (proxy string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// refreshProxyCache loads the latest proxies from the known providers.
|
// refreshProxyCache loads the latest proxies from the known providers.
|
||||||
|
// It includes the standard API.
|
||||||
func (p *proxyProvider) refreshProxyCache() error {
|
func (p *proxyProvider) refreshProxyCache() error {
|
||||||
logrus.Info("Refreshing proxy cache")
|
logrus.Info("Refreshing proxy cache")
|
||||||
|
|
||||||
@ -121,9 +129,6 @@ func (p *proxyProvider) refreshProxyCache() error {
|
|||||||
if proxies, err := p.dohLookup(p.query, provider); err == nil {
|
if proxies, err := p.dohLookup(p.query, provider); err == nil {
|
||||||
p.proxyCache = proxies
|
p.proxyCache = proxies
|
||||||
|
|
||||||
// We also want to allow bridge to switch back to the standard API at any time.
|
|
||||||
p.proxyCache = append(p.proxyCache, RootURL)
|
|
||||||
|
|
||||||
logrus.WithField("proxies", proxies).Info("Available proxies")
|
logrus.WithField("proxies", proxies).Info("Available proxies")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -122,23 +122,27 @@ func TestProxyProvider_UseProxy(t *testing.T) {
|
|||||||
blockAPI()
|
blockAPI()
|
||||||
defer unblockAPI()
|
defer unblockAPI()
|
||||||
|
|
||||||
|
cm := NewClientManager(testClientConfig)
|
||||||
|
|
||||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
defer proxy.Close()
|
defer proxy.Close()
|
||||||
|
|
||||||
p := newProxyProvider([]string{"not used"}, "not used")
|
p := newProxyProvider([]string{"not used"}, "not used")
|
||||||
|
cm.proxyProvider = p
|
||||||
|
|
||||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||||
|
url, err := cm.SwitchToProxy()
|
||||||
url, err := p.findProxy()
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, proxy.URL, url)
|
||||||
p.useProxy(url)
|
require.Equal(t, proxy.URL, cm.GetHost())
|
||||||
require.Equal(t, proxy.URL, GlobalGetRootURL())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
|
func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
|
||||||
blockAPI()
|
blockAPI()
|
||||||
defer unblockAPI()
|
defer unblockAPI()
|
||||||
|
|
||||||
|
cm := NewClientManager(testClientConfig)
|
||||||
|
|
||||||
proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
defer proxy1.Close()
|
defer proxy1.Close()
|
||||||
proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
@ -147,101 +151,106 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
|
|||||||
defer proxy3.Close()
|
defer proxy3.Close()
|
||||||
|
|
||||||
p := newProxyProvider([]string{"not used"}, "not used")
|
p := newProxyProvider([]string{"not used"}, "not used")
|
||||||
|
cm.proxyProvider = p
|
||||||
|
|
||||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
||||||
url, err := p.findProxy()
|
url, err := cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
p.useProxy(url)
|
require.Equal(t, proxy1.URL, url)
|
||||||
require.Equal(t, proxy1.URL, GlobalGetRootURL())
|
require.Equal(t, proxy1.URL, cm.GetHost())
|
||||||
|
|
||||||
// Have to wait so as to not get rejected.
|
// Have to wait so as to not get rejected.
|
||||||
time.Sleep(proxyLookupWait)
|
time.Sleep(proxyLookupWait)
|
||||||
|
|
||||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
|
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
|
||||||
url, err = p.findProxy()
|
url, err = cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
p.useProxy(url)
|
require.Equal(t, proxy2.URL, url)
|
||||||
require.Equal(t, proxy2.URL, GlobalGetRootURL())
|
require.Equal(t, proxy2.URL, cm.GetHost())
|
||||||
|
|
||||||
// Have to wait so as to not get rejected.
|
// Have to wait so as to not get rejected.
|
||||||
time.Sleep(proxyLookupWait)
|
time.Sleep(proxyLookupWait)
|
||||||
|
|
||||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
|
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
|
||||||
url, err = p.findProxy()
|
url, err = cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
p.useProxy(url)
|
require.Equal(t, proxy3.URL, url)
|
||||||
require.Equal(t, proxy3.URL, GlobalGetRootURL())
|
require.Equal(t, proxy3.URL, cm.GetHost())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
|
func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
|
||||||
blockAPI()
|
blockAPI()
|
||||||
defer unblockAPI()
|
defer unblockAPI()
|
||||||
|
|
||||||
|
cm := NewClientManager(testClientConfig)
|
||||||
|
|
||||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
defer proxy.Close()
|
defer proxy.Close()
|
||||||
|
|
||||||
p := newProxyProvider([]string{"not used"}, "not used")
|
p := newProxyProvider([]string{"not used"}, "not used")
|
||||||
p.useDuration = time.Second
|
cm.proxyProvider = p
|
||||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
cm.proxyUseDuration = time.Second
|
||||||
|
|
||||||
url, err := p.findProxy()
|
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||||
|
url, err := cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, proxy.URL, url)
|
require.Equal(t, proxy.URL, url)
|
||||||
|
require.Equal(t, proxy.URL, cm.GetHost())
|
||||||
p.useProxy(url)
|
|
||||||
require.Equal(t, proxy.URL, GlobalGetRootURL())
|
|
||||||
|
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
require.Equal(t, globalOriginalURL, GlobalGetRootURL())
|
require.Equal(t, RootURL, cm.GetHost())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
||||||
// Don't block the API here because we want it to be working so the test can find it.
|
blockAPI()
|
||||||
defer unblockAPI()
|
defer unblockAPI()
|
||||||
|
|
||||||
|
cm := NewClientManager(testClientConfig)
|
||||||
|
|
||||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
defer proxy.Close()
|
defer proxy.Close()
|
||||||
|
|
||||||
p := newProxyProvider([]string{"not used"}, "not used")
|
p := newProxyProvider([]string{"not used"}, "not used")
|
||||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
cm.proxyProvider = p
|
||||||
|
|
||||||
url, err := p.findProxy()
|
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||||
|
url, err := cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, proxy.URL, url)
|
require.Equal(t, proxy.URL, url)
|
||||||
|
require.Equal(t, proxy.URL, cm.GetHost())
|
||||||
|
|
||||||
p.useProxy(url)
|
// Simulate that the proxy stops working and that the standard api is reachable again.
|
||||||
require.Equal(t, proxy.URL, GlobalGetRootURL())
|
|
||||||
|
|
||||||
// Simulate that the proxy stops working.
|
|
||||||
proxy.Close()
|
proxy.Close()
|
||||||
|
unblockAPI()
|
||||||
time.Sleep(proxyLookupWait)
|
time.Sleep(proxyLookupWait)
|
||||||
|
|
||||||
// We should now find the original API URL if it is working again.
|
// We should now find the original API URL if it is working again.
|
||||||
url, err = p.findProxy()
|
url, err = cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, globalOriginalURL, url)
|
require.Equal(t, RootURL, url)
|
||||||
|
require.Equal(t, RootURL, cm.GetHost())
|
||||||
p.useProxy(url)
|
|
||||||
require.Equal(t, globalOriginalURL, GlobalGetRootURL())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
|
func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
|
||||||
blockAPI()
|
blockAPI()
|
||||||
defer unblockAPI()
|
defer unblockAPI()
|
||||||
|
|
||||||
|
cm := NewClientManager(testClientConfig)
|
||||||
|
|
||||||
proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
defer proxy1.Close()
|
defer proxy1.Close()
|
||||||
proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
defer proxy2.Close()
|
defer proxy2.Close()
|
||||||
|
|
||||||
p := newProxyProvider([]string{"not used"}, "not used")
|
p := newProxyProvider([]string{"not used"}, "not used")
|
||||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
cm.proxyProvider = p
|
||||||
|
|
||||||
// Find a proxy.
|
// Find a proxy.
|
||||||
url, err := p.findProxy()
|
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
||||||
|
url, err := cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
p.useProxy(url)
|
require.Equal(t, proxy1.URL, url)
|
||||||
require.Equal(t, proxy1.URL, GlobalGetRootURL())
|
require.Equal(t, proxy1.URL, cm.GetHost())
|
||||||
|
|
||||||
// Have to wait so as to not get rejected.
|
// Have to wait so as to not get rejected.
|
||||||
time.Sleep(proxyLookupWait)
|
time.Sleep(proxyLookupWait)
|
||||||
@ -250,10 +259,10 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl
|
|||||||
proxy1.Close()
|
proxy1.Close()
|
||||||
|
|
||||||
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
|
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
|
||||||
url, err = p.findProxy()
|
url, err = cm.SwitchToProxy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
p.useProxy(url)
|
require.Equal(t, proxy2.URL, url)
|
||||||
require.Equal(t, proxy2.URL, GlobalGetRootURL())
|
require.Equal(t, proxy2.URL, cm.GetHost())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||||
@ -289,16 +298,14 @@ func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// testAPIURLBackup is used to hold the globalOriginalURL because we clear it for test purposes and need to restore it.
|
// testAPIURLBackup is used to hold the globalOriginalURL because we clear it for test purposes and need to restore it.
|
||||||
var testAPIURLBackup = globalOriginalURL
|
var testAPIURLBackup = RootURL
|
||||||
|
|
||||||
// blockAPI prevents tests from reaching the standard API, forcing them to find a proxy.
|
// blockAPI prevents tests from reaching the standard API, forcing them to find a proxy.
|
||||||
func blockAPI() {
|
func blockAPI() {
|
||||||
globalSetRootURL("")
|
RootURL = ""
|
||||||
globalOriginalURL = ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// unblockAPI allow tests to reach the standard API again.
|
// unblockAPI allow tests to reach the standard API again.
|
||||||
func unblockAPI() {
|
func unblockAPI() {
|
||||||
globalOriginalURL = testAPIURLBackup
|
RootURL = testAPIURLBackup
|
||||||
globalSetRootURL(globalOriginalURL)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,7 +28,8 @@ import (
|
|||||||
// NewRequest creates a new request.
|
// NewRequest creates a new request.
|
||||||
func (c *Client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
|
func (c *Client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
|
||||||
// TODO: Support other protocols (localhost needs http not https).
|
// TODO: Support other protocols (localhost needs http not https).
|
||||||
req, err = http.NewRequest(method, "https://"+c.cm.GetRootURL()+path, body)
|
req, err = http.NewRequest(method, c.cm.GetRootURL()+path, body)
|
||||||
|
|
||||||
if req != nil {
|
if req != nil {
|
||||||
req.Header.Set("User-Agent", CurrentUserAgent)
|
req.Header.Set("User-Agent", CurrentUserAgent)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -25,7 +25,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSentryCrashReport(t *testing.T) {
|
func TestSentryCrashReport(t *testing.T) {
|
||||||
c := newClient(NewClientManager(testClientConfig), "bridgetest")
|
cm := NewClientManager(testClientConfig)
|
||||||
|
c := cm.GetClient("bridgetest")
|
||||||
if err := c.ReportSentryCrash(errors.New("Testing crash report - api proxy; goroutines with threads, find origin")); err != nil {
|
if err := c.ReportSentryCrash(errors.New("Testing crash report - api proxy; goroutines with threads, find origin")); err != nil {
|
||||||
t.Fatal("Expected no error while report, but have", err)
|
t.Fatal("Expected no error while report, but have", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,6 +22,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -72,15 +73,24 @@ func Equals(tb testing.TB, exp, act interface{}) {
|
|||||||
// newTestServer is old function and should be replaced everywhere by newTestServerCallbacks.
|
// newTestServer is old function and should be replaced everywhere by newTestServerCallbacks.
|
||||||
func newTestServer(h http.Handler) (*httptest.Server, *Client) {
|
func newTestServer(h http.Handler) (*httptest.Server, *Client) {
|
||||||
s := httptest.NewServer(h)
|
s := httptest.NewServer(h)
|
||||||
RootURL = s.URL
|
|
||||||
|
|
||||||
return s, newTestClient()
|
serverURL, err := url.Parse(s.URL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cm := NewClientManager(testClientConfig)
|
||||||
|
cm.host = serverURL.Host
|
||||||
|
cm.scheme = serverURL.Scheme
|
||||||
|
|
||||||
|
return s, newTestClient(cm)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestServerCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.ResponseWriter, *http.Request) string) (func(), *Client) {
|
func newTestServerCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.ResponseWriter, *http.Request) string) (func(), *Client) {
|
||||||
reqNum := 0
|
reqNum := 0
|
||||||
_, file, line, _ := runtime.Caller(1)
|
_, file, line, _ := runtime.Caller(1)
|
||||||
file = filepath.Base(file)
|
file = filepath.Base(file)
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
reqNum++
|
reqNum++
|
||||||
if reqNum > len(callbacks) {
|
if reqNum > len(callbacks) {
|
||||||
@ -95,7 +105,12 @@ func newTestServerCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.Re
|
|||||||
writeJSONResponsefromFile(tb, w, response, reqNum-1)
|
writeJSONResponsefromFile(tb, w, response, reqNum-1)
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
RootURL = server.URL
|
|
||||||
|
serverURL, err := url.Parse(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
finish := func() {
|
finish := func() {
|
||||||
server.CloseClientConnections() // Closing without waiting for finishing requests.
|
server.CloseClientConnections() // Closing without waiting for finishing requests.
|
||||||
if reqNum != len(callbacks) {
|
if reqNum != len(callbacks) {
|
||||||
@ -106,7 +121,12 @@ func newTestServerCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.Re
|
|||||||
tb.Error("server failed")
|
tb.Error("server failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return finish, newTestClient()
|
|
||||||
|
cm := NewClientManager(testClientConfig)
|
||||||
|
cm.host = serverURL.Host
|
||||||
|
cm.scheme = serverURL.Scheme
|
||||||
|
|
||||||
|
return finish, newTestClient(cm)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkMethodAndPath(r *http.Request, method, path string) error {
|
func checkMethodAndPath(r *http.Request, method, path string) error {
|
||||||
|
|||||||
@ -3,8 +3,8 @@
|
|||||||
"AccessToken": "de0423049b44243afeec7d9c1d99be7b46da1e8a",
|
"AccessToken": "de0423049b44243afeec7d9c1d99be7b46da1e8a",
|
||||||
"ExpiresIn": 360000,
|
"ExpiresIn": 360000,
|
||||||
"TokenType": "Bearer",
|
"TokenType": "Bearer",
|
||||||
"Uid": "differentUID",
|
"Uid": "729ad6012421d67ad26950dc898bebe3a6e3caa2",
|
||||||
"UID": "differentUID",
|
"UID": "729ad6012421d67ad26950dc898bebe3a6e3caa2",
|
||||||
"Scope": "full mail payments reset keys",
|
"Scope": "full mail payments reset keys",
|
||||||
"RefreshToken": "b894b4c4f20003f12d486900d8b88c7d68e67235"
|
"RefreshToken": "b894b4c4f20003f12d486900d8b88c7d68e67235"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -121,7 +121,7 @@ func (c *Client) UpdateUser() (user *User, err error) {
|
|||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CurrentUser return currently active user or user will be updated.
|
// CurrentUser returns currently active user or user will be updated.
|
||||||
func (c *Client) CurrentUser() (user *User, err error) {
|
func (c *Client) CurrentUser() (user *User, err error) {
|
||||||
if c.user != nil && len(c.addresses) != 0 {
|
if c.user != nil && len(c.addresses) != 0 {
|
||||||
user = c.user
|
user = c.user
|
||||||
|
|||||||
@ -21,10 +21,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetBridge returns bridge instance.
|
// GetBridge returns bridge instance.
|
||||||
@ -35,7 +33,10 @@ func (ctx *TestContext) GetBridge() *bridge.Bridge {
|
|||||||
// withBridgeInstance creates a bridge instance for use in the test.
|
// withBridgeInstance creates a bridge instance for use in the test.
|
||||||
// Every TestContext has this by default and thus this doesn't need to be exported.
|
// Every TestContext has this by default and thus this doesn't need to be exported.
|
||||||
func (ctx *TestContext) withBridgeInstance() {
|
func (ctx *TestContext) withBridgeInstance() {
|
||||||
ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, ctx.clientManager)
|
pmapiFactory := func(userID string) bridge.PMAPIProvider {
|
||||||
|
return ctx.pmapiController.GetClient(userID)
|
||||||
|
}
|
||||||
|
ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, pmapiFactory)
|
||||||
ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data")
|
ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,7 +61,7 @@ func newBridgeInstance(
|
|||||||
cfg *fakeConfig,
|
cfg *fakeConfig,
|
||||||
credStore bridge.CredentialsStorer,
|
credStore bridge.CredentialsStorer,
|
||||||
eventListener listener.Listener,
|
eventListener listener.Listener,
|
||||||
clientManager *pmapi.ClientManager,
|
pmapiFactory bridge.PMAPIProviderFactory,
|
||||||
) *bridge.Bridge {
|
) *bridge.Bridge {
|
||||||
version := os.Getenv("VERSION")
|
version := os.Getenv("VERSION")
|
||||||
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")
|
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")
|
||||||
@ -68,7 +69,7 @@ func newBridgeInstance(
|
|||||||
panicHandler := &panicHandler{t: t}
|
panicHandler := &panicHandler{t: t}
|
||||||
pref := preferences.New(cfg)
|
pref := preferences.New(cfg)
|
||||||
|
|
||||||
return bridge.New(cfg, pref, panicHandler, eventListener, version, clientManager, credStore)
|
return bridge.New(cfg, pref, panicHandler, eventListener, version, pmapiFactory, credStore)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLastBridgeError sets the last error that occurred while executing a bridge action.
|
// SetLastBridgeError sets the last error that occurred while executing a bridge action.
|
||||||
|
|||||||
@ -28,6 +28,7 @@ import (
|
|||||||
|
|
||||||
type fakeConfig struct {
|
type fakeConfig struct {
|
||||||
dir string
|
dir string
|
||||||
|
tm *pmapi.TokenManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// newFakeConfig creates a temporary folder for files.
|
// newFakeConfig creates a temporary folder for files.
|
||||||
@ -40,6 +41,7 @@ func newFakeConfig() *fakeConfig {
|
|||||||
|
|
||||||
return &fakeConfig{
|
return &fakeConfig{
|
||||||
dir: dir,
|
dir: dir,
|
||||||
|
tm: pmapi.NewTokenManager(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,6 +53,8 @@ func (c *fakeConfig) GetAPIConfig() *pmapi.ClientConfig {
|
|||||||
AppVersion: "Bridge_" + os.Getenv("VERSION"),
|
AppVersion: "Bridge_" + os.Getenv("VERSION"),
|
||||||
ClientID: "bridge",
|
ClientID: "bridge",
|
||||||
SentryDSN: "",
|
SentryDSN: "",
|
||||||
|
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
|
||||||
|
TokenManager: c.tm,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (c *fakeConfig) GetDBDir() string {
|
func (c *fakeConfig) GetDBDir() string {
|
||||||
|
|||||||
@ -21,7 +21,6 @@ package context
|
|||||||
import (
|
import (
|
||||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
|
||||||
"github.com/ProtonMail/proton-bridge/test/accounts"
|
"github.com/ProtonMail/proton-bridge/test/accounts"
|
||||||
"github.com/ProtonMail/proton-bridge/test/mocks"
|
"github.com/ProtonMail/proton-bridge/test/mocks"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -58,9 +57,6 @@ type TestContext struct {
|
|||||||
smtpClients map[string]*mocks.SMTPClient
|
smtpClients map[string]*mocks.SMTPClient
|
||||||
smtpLastResponses map[string]*mocks.SMTPResponse
|
smtpLastResponses map[string]*mocks.SMTPResponse
|
||||||
|
|
||||||
// PMAPI related variables.
|
|
||||||
clientManager *pmapi.ClientManager
|
|
||||||
|
|
||||||
// These are the cleanup steps executed when Cleanup() is called.
|
// These are the cleanup steps executed when Cleanup() is called.
|
||||||
cleanupSteps []*Cleaner
|
cleanupSteps []*Cleaner
|
||||||
|
|
||||||
@ -74,20 +70,17 @@ func New() *TestContext {
|
|||||||
|
|
||||||
cfg := newFakeConfig()
|
cfg := newFakeConfig()
|
||||||
|
|
||||||
cm := pmapi.NewClientManager(cfg.GetAPIConfig())
|
|
||||||
|
|
||||||
ctx := &TestContext{
|
ctx := &TestContext{
|
||||||
t: &bddT{},
|
t: &bddT{},
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
listener: listener.New(),
|
listener: listener.New(),
|
||||||
pmapiController: newPMAPIController(cm),
|
pmapiController: newPMAPIController(),
|
||||||
testAccounts: newTestAccounts(),
|
testAccounts: newTestAccounts(),
|
||||||
credStore: newFakeCredStore(),
|
credStore: newFakeCredStore(),
|
||||||
imapClients: make(map[string]*mocks.IMAPClient),
|
imapClients: make(map[string]*mocks.IMAPClient),
|
||||||
imapLastResponses: make(map[string]*mocks.IMAPResponse),
|
imapLastResponses: make(map[string]*mocks.IMAPResponse),
|
||||||
smtpClients: make(map[string]*mocks.SMTPClient),
|
smtpClients: make(map[string]*mocks.SMTPClient),
|
||||||
smtpLastResponses: make(map[string]*mocks.SMTPResponse),
|
smtpLastResponses: make(map[string]*mocks.SMTPResponse),
|
||||||
clientManager: cm,
|
|
||||||
logger: logrus.StandardLogger(),
|
logger: logrus.StandardLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -81,6 +81,15 @@ func (c *fakeCredStore) UpdateEmails(userID string, emails []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *fakeCredStore) UpdatePassword(userID, password string) error {
|
||||||
|
creds, err := c.Get(userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
creds.MailboxPassword = password
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *fakeCredStore) UpdateToken(userID, apiToken string) error {
|
func (c *fakeCredStore) UpdateToken(userID, apiToken string) error {
|
||||||
creds, err := c.Get(userID)
|
creds, err := c.Get(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -40,12 +40,12 @@ type PMAPIController interface {
|
|||||||
GetCalls(method, path string) [][]byte
|
GetCalls(method, path string) [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPMAPIController(cm *pmapi.ClientManager) PMAPIController {
|
func newPMAPIController() PMAPIController {
|
||||||
switch os.Getenv(EnvName) {
|
switch os.Getenv(EnvName) {
|
||||||
case EnvFake:
|
case EnvFake:
|
||||||
return newFakePMAPIController()
|
return newFakePMAPIController()
|
||||||
case EnvLive:
|
case EnvLive:
|
||||||
return newLivePMAPIController(cm)
|
return newLivePMAPIController()
|
||||||
default:
|
default:
|
||||||
panic("unknown env")
|
panic("unknown env")
|
||||||
}
|
}
|
||||||
@ -67,8 +67,8 @@ func (s *fakePMAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider
|
|||||||
return s.Controller.GetClient(userID)
|
return s.Controller.GetClient(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLivePMAPIController(cm *pmapi.ClientManager) PMAPIController {
|
func newLivePMAPIController() PMAPIController {
|
||||||
return newLiveAPIControllerWrap(liveapi.NewController(cm))
|
return newLiveAPIControllerWrap(liveapi.NewController())
|
||||||
}
|
}
|
||||||
|
|
||||||
type liveAPIControllerWrap struct {
|
type liveAPIControllerWrap struct {
|
||||||
|
|||||||
@ -141,12 +141,13 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) {
|
|||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *FakePMAPI) Logout() {
|
func (api *FakePMAPI) Logout() error {
|
||||||
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
|
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
// Logout will also emit change to auth channel
|
// Logout will also emit change to auth channel
|
||||||
api.sendAuth(nil)
|
api.sendAuth(nil)
|
||||||
api.controller.deleteSession(api.uid)
|
api.controller.deleteSession(api.uid)
|
||||||
api.unsetUser()
|
api.unsetUser()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,9 @@
|
|||||||
package liveapi
|
package liveapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
@ -30,31 +32,31 @@ type Controller struct {
|
|||||||
calls []*fakeCall
|
calls []*fakeCall
|
||||||
pmapiByUsername map[string]*pmapi.Client
|
pmapiByUsername map[string]*pmapi.Client
|
||||||
messageIDsByUsername map[string][]string
|
messageIDsByUsername map[string][]string
|
||||||
clientManager *pmapi.ClientManager
|
|
||||||
|
|
||||||
// State controlled by test.
|
// State controlled by test.
|
||||||
noInternetConnection bool
|
noInternetConnection bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewController(cm *pmapi.ClientManager) *Controller {
|
func NewController() *Controller {
|
||||||
cntrl := &Controller{
|
return &Controller{
|
||||||
lock: &sync.RWMutex{},
|
lock: &sync.RWMutex{},
|
||||||
calls: []*fakeCall{},
|
calls: []*fakeCall{},
|
||||||
pmapiByUsername: map[string]*pmapi.Client{},
|
pmapiByUsername: map[string]*pmapi.Client{},
|
||||||
messageIDsByUsername: map[string][]string{},
|
messageIDsByUsername: map[string][]string{},
|
||||||
clientManager: cm,
|
|
||||||
|
|
||||||
noInternetConnection: false,
|
noInternetConnection: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
cntrl.clientManager.SetClientRoundTripper(&fakeTransport{
|
|
||||||
cntrl: cntrl,
|
|
||||||
transport: http.DefaultTransport,
|
|
||||||
})
|
|
||||||
|
|
||||||
return cntrl
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cntrl *Controller) GetClient(userID string) *pmapi.Client {
|
func (cntrl *Controller) GetClient(userID string) *pmapi.Client {
|
||||||
return cntrl.clientManager.GetClient(userID)
|
cfg := &pmapi.ClientConfig{
|
||||||
|
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
|
||||||
|
ClientID: "bridge-test",
|
||||||
|
Transport: &fakeTransport{
|
||||||
|
cntrl: cntrl,
|
||||||
|
transport: http.DefaultTransport,
|
||||||
|
},
|
||||||
|
TokenManager: pmapi.NewTokenManager(),
|
||||||
|
}
|
||||||
|
return pmapi.NewClient(cfg, userID)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,9 @@
|
|||||||
package liveapi
|
package liveapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/cucumber/godog"
|
"github.com/cucumber/godog"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
@ -28,7 +30,11 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
|
|||||||
return godog.ErrPending
|
return godog.ErrPending
|
||||||
}
|
}
|
||||||
|
|
||||||
client := cntrl.GetClient(user.ID)
|
client := pmapi.NewClient(&pmapi.ClientConfig{
|
||||||
|
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
|
||||||
|
ClientID: "bridge-cntrl",
|
||||||
|
TokenManager: pmapi.NewTokenManager(),
|
||||||
|
}, user.ID)
|
||||||
|
|
||||||
authInfo, err := client.AuthInfo(user.Name)
|
authInfo, err := client.AuthInfo(user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -55,6 +61,5 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
|
|||||||
}
|
}
|
||||||
|
|
||||||
cntrl.pmapiByUsername[user.Name] = client
|
cntrl.pmapiByUsername[user.Name] = client
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user