mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
feat: central auth channel for clients
This commit is contained in:
@ -43,14 +43,14 @@ var (
|
|||||||
|
|
||||||
// Bridge is a struct handling users.
|
// Bridge is a struct handling users.
|
||||||
type Bridge struct {
|
type Bridge struct {
|
||||||
config Configer
|
config Configer
|
||||||
pref PreferenceProvider
|
pref PreferenceProvider
|
||||||
panicHandler PanicHandler
|
panicHandler PanicHandler
|
||||||
events listener.Listener
|
events listener.Listener
|
||||||
version string
|
version string
|
||||||
clientMan *pmapi.ClientManager
|
clientManager *pmapi.ClientManager
|
||||||
credStorer CredentialsStorer
|
credStorer CredentialsStorer
|
||||||
storeCache *store.Cache
|
storeCache *store.Cache
|
||||||
|
|
||||||
// users is a list of accounts that have been added to bridge.
|
// users is a list of accounts that have been added to bridge.
|
||||||
// They are stored sorted in the credentials store in the order
|
// They are stored sorted in the credentials store in the order
|
||||||
@ -76,22 +76,22 @@ func New(
|
|||||||
panicHandler PanicHandler,
|
panicHandler PanicHandler,
|
||||||
eventListener listener.Listener,
|
eventListener listener.Listener,
|
||||||
version string,
|
version string,
|
||||||
clientMan *pmapi.ClientManager,
|
clientManager *pmapi.ClientManager,
|
||||||
credStorer CredentialsStorer,
|
credStorer CredentialsStorer,
|
||||||
) *Bridge {
|
) *Bridge {
|
||||||
log.Trace("Creating new bridge")
|
log.Trace("Creating new bridge")
|
||||||
|
|
||||||
b := &Bridge{
|
b := &Bridge{
|
||||||
config: config,
|
config: config,
|
||||||
pref: pref,
|
pref: pref,
|
||||||
panicHandler: panicHandler,
|
panicHandler: panicHandler,
|
||||||
events: eventListener,
|
events: eventListener,
|
||||||
version: version,
|
version: version,
|
||||||
clientMan: clientMan,
|
clientManager: clientManager,
|
||||||
credStorer: credStorer,
|
credStorer: credStorer,
|
||||||
storeCache: store.NewCache(config.GetIMAPCachePath()),
|
storeCache: store.NewCache(config.GetIMAPCachePath()),
|
||||||
idleUpdates: make(chan interface{}),
|
idleUpdates: make(chan interface{}),
|
||||||
lock: sync.RWMutex{},
|
lock: sync.RWMutex{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow DoH before starting bridge if the user has previously set this setting.
|
// Allow DoH before starting bridge if the user has previously set this setting.
|
||||||
@ -105,6 +105,11 @@ func New(
|
|||||||
b.watchBridgeOutdated()
|
b.watchBridgeOutdated()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer panicHandler.HandlePanic()
|
||||||
|
b.watchUserAuths()
|
||||||
|
}()
|
||||||
|
|
||||||
if b.credStorer == nil {
|
if b.credStorer == nil {
|
||||||
log.Error("Bridge has no credentials store")
|
log.Error("Bridge has no credentials store")
|
||||||
} else if err := b.loadUsersFromCredentialsStore(); err != nil {
|
} else if err := b.loadUsersFromCredentialsStore(); err != nil {
|
||||||
@ -148,7 +153,7 @@ func (b *Bridge) loadUsersFromCredentialsStore() (err error) {
|
|||||||
for _, userID := range userIDs {
|
for _, userID := range userIDs {
|
||||||
l := log.WithField("user", userID)
|
l := log.WithField("user", userID)
|
||||||
|
|
||||||
user, newUserErr := newUser(b.panicHandler, userID, b.events, b.credStorer, b.clientMan, b.storeCache, b.config.GetDBDir())
|
user, newUserErr := newUser(b.panicHandler, userID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
|
||||||
if newUserErr != nil {
|
if newUserErr != nil {
|
||||||
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
|
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
|
||||||
continue
|
continue
|
||||||
@ -173,6 +178,18 @@ func (b *Bridge) watchBridgeOutdated() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) watchUserAuths() {
|
||||||
|
for auth := range b.clientManager.GetBridgeAuthChannel() {
|
||||||
|
user, ok := b.hasUser(auth.UserID)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
user.ReceiveAPIAuth(auth.Auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Bridge) closeAllConnections() {
|
func (b *Bridge) closeAllConnections() {
|
||||||
for _, user := range b.users {
|
for _, user := range b.users {
|
||||||
user.closeAllConnections()
|
user.closeAllConnections()
|
||||||
@ -195,9 +212,8 @@ func (b *Bridge) Login(username, password string) (loginClient PMAPIProvider, au
|
|||||||
|
|
||||||
b.crashBandicoot(username)
|
b.crashBandicoot(username)
|
||||||
|
|
||||||
// We need to use "login" client because we need userID to properly
|
// We need to use "login" client because we need userID to properly assign access tokens into token manager.
|
||||||
// assign access tokens into token manager.
|
loginClient = b.clientManager.GetClient("login")
|
||||||
loginClient = b.clientMan.GetClient("login")
|
|
||||||
|
|
||||||
authInfo, err := loginClient.AuthInfo(username)
|
authInfo, err := loginClient.AuthInfo(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -227,29 +243,22 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||||||
b.lock.Lock()
|
b.lock.Lock()
|
||||||
defer b.lock.Unlock()
|
defer b.lock.Unlock()
|
||||||
|
|
||||||
|
defer loginClient.Logout()
|
||||||
|
|
||||||
mbPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
|
mbPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Could not hash mailbox password")
|
log.WithError(err).Error("Could not hash mailbox password")
|
||||||
if logoutErr := loginClient.Logout(); logoutErr != nil {
|
|
||||||
log.WithError(logoutErr).Error("Clean login session after hash password failed.")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = loginClient.Unlock(mbPassword); err != nil {
|
if _, err = loginClient.Unlock(mbPassword); err != nil {
|
||||||
log.WithError(err).Error("Could not decrypt keyring")
|
log.WithError(err).Error("Could not decrypt keyring")
|
||||||
if logoutErr := loginClient.Logout(); logoutErr != nil {
|
|
||||||
log.WithError(logoutErr).Error("Clean login session after unlock failed.")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiUser, err := loginClient.CurrentUser()
|
apiUser, err := loginClient.CurrentUser()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Could not get login API user")
|
log.WithError(err).Error("Could not get login API user")
|
||||||
if logoutErr := loginClient.Logout(); logoutErr != nil {
|
|
||||||
log.WithError(logoutErr).Error("Clean login session after get current user failed.")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,20 +268,13 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||||||
if hasUser && user.IsConnected() {
|
if hasUser && user.IsConnected() {
|
||||||
err = errors.New("user is already logged in")
|
err = errors.New("user is already logged in")
|
||||||
log.WithError(err).Warn("User is already logged in")
|
log.WithError(err).Warn("User is already logged in")
|
||||||
if logoutErr := loginClient.Logout(); logoutErr != nil {
|
|
||||||
log.WithError(logoutErr).Warn("Could not discard auth generated during second login")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiToken := auth.UID() + ":" + auth.RefreshToken
|
apiClient := b.clientManager.GetClient(apiUser.ID)
|
||||||
apiClient := b.clientMan.GetClient(apiUser.ID)
|
auth, err = apiClient.AuthRefresh(auth.GenToken())
|
||||||
auth, err = apiClient.AuthRefresh(apiToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Could refresh token in new client")
|
log.WithError(err).Error("Could refresh token in new client")
|
||||||
if logoutErr := loginClient.Logout(); logoutErr != nil {
|
|
||||||
log.WithError(logoutErr).Warn("Could not discard auth generated after auth refresh")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,22 +282,18 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||||||
apiUser, err = apiClient.CurrentUser()
|
apiUser, err = apiClient.CurrentUser()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Could not get current API user")
|
log.WithError(err).Error("Could not get current API user")
|
||||||
if logoutErr := loginClient.Logout(); logoutErr != nil {
|
|
||||||
log.WithError(logoutErr).Error("Clean login session after get current user failed.")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiToken = auth.UID() + ":" + auth.RefreshToken
|
|
||||||
activeEmails := apiClient.Addresses().ActiveEmails()
|
activeEmails := apiClient.Addresses().ActiveEmails()
|
||||||
if _, err = b.credStorer.Add(apiUser.ID, apiUser.Name, apiToken, mbPassword, activeEmails); err != nil {
|
if _, err = b.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), mbPassword, activeEmails); err != nil {
|
||||||
log.WithError(err).Error("Could not add user to credentials store")
|
log.WithError(err).Error("Could not add user to credentials store")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If it's a new user, generate the user object.
|
// If it's a new user, generate the user object.
|
||||||
if !hasUser {
|
if !hasUser {
|
||||||
user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientMan, b.storeCache, b.config.GetDBDir())
|
user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
|
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
|
||||||
return
|
return
|
||||||
@ -405,8 +403,8 @@ func (b *Bridge) DeleteUser(userID string, clearStore bool) error {
|
|||||||
|
|
||||||
// ReportBug reports a new bug from the user.
|
// ReportBug reports a new bug from the user.
|
||||||
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||||
c := b.clientMan.GetClient("bug_reporter")
|
c := b.clientManager.GetClient("bug_reporter")
|
||||||
defer func() { _ = c.Logout() }()
|
defer c.Logout()
|
||||||
|
|
||||||
title := "[Bridge] Bug"
|
title := "[Bridge] Bug"
|
||||||
if err := c.ReportBugWithEmailClient(
|
if err := c.ReportBugWithEmailClient(
|
||||||
@ -429,8 +427,8 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address,
|
|||||||
|
|
||||||
// SendMetric sends a metric. We don't want to return any errors, only log them.
|
// SendMetric sends a metric. We don't want to return any errors, only log them.
|
||||||
func (b *Bridge) SendMetric(m m.Metric) {
|
func (b *Bridge) SendMetric(m m.Metric) {
|
||||||
c := b.clientMan.GetClient("metric_reporter")
|
c := b.clientManager.GetClient("metric_reporter")
|
||||||
defer func() { _ = c.Logout() }()
|
defer c.Logout()
|
||||||
|
|
||||||
cat, act, lab := m.Get()
|
cat, act, lab := m.Get()
|
||||||
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
|
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
|
||||||
|
|||||||
@ -47,7 +47,6 @@ type Clientman interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PMAPIProvider interface {
|
type PMAPIProvider interface {
|
||||||
SetAuths(auths chan<- *pmapi.Auth)
|
|
||||||
Auth(username, password string, info *pmapi.AuthInfo) (*pmapi.Auth, error)
|
Auth(username, password string, info *pmapi.AuthInfo) (*pmapi.Auth, error)
|
||||||
AuthInfo(username string) (*pmapi.AuthInfo, error)
|
AuthInfo(username string) (*pmapi.AuthInfo, error)
|
||||||
AuthRefresh(token string) (*pmapi.Auth, error)
|
AuthRefresh(token string) (*pmapi.Auth, error)
|
||||||
@ -56,7 +55,8 @@ type PMAPIProvider interface {
|
|||||||
CurrentUser() (*pmapi.User, error)
|
CurrentUser() (*pmapi.User, error)
|
||||||
UpdateUser() (*pmapi.User, error)
|
UpdateUser() (*pmapi.User, error)
|
||||||
Addresses() pmapi.AddressList
|
Addresses() pmapi.AddressList
|
||||||
Logout() error
|
|
||||||
|
Logout()
|
||||||
|
|
||||||
GetEvent(eventID string) (*pmapi.Event, error)
|
GetEvent(eventID string) (*pmapi.Event, error)
|
||||||
|
|
||||||
|
|||||||
@ -53,9 +53,8 @@ type User struct {
|
|||||||
userID string
|
userID string
|
||||||
creds *credentials.Credentials
|
creds *credentials.Credentials
|
||||||
|
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
authChannel chan *pmapi.Auth
|
isAuthorized bool
|
||||||
hasAPIAuth bool
|
|
||||||
|
|
||||||
unlockingKeyringLock sync.Mutex
|
unlockingKeyringLock sync.Mutex
|
||||||
wasKeyringUnlocked bool
|
wasKeyringUnlocked bool
|
||||||
@ -116,15 +115,6 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
u.creds = creds
|
u.creds = creds
|
||||||
|
|
||||||
// Set up the auth channel on which auths from the api client are sent.
|
|
||||||
u.authChannel = make(chan *pmapi.Auth)
|
|
||||||
u.client().SetAuths(u.authChannel)
|
|
||||||
u.hasAPIAuth = false
|
|
||||||
go func() {
|
|
||||||
defer u.panicHandler.HandlePanic()
|
|
||||||
u.watchAPIClientAuths()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Try to authorise the user if they aren't already authorised.
|
// Try to authorise the user if they aren't already authorised.
|
||||||
// Note: we still allow users to set up bridge if the internet is off.
|
// Note: we still allow users to set up bridge if the internet is off.
|
||||||
if authErr := u.authorizeIfNecessary(false); authErr != nil {
|
if authErr := u.authorizeIfNecessary(false); authErr != nil {
|
||||||
@ -169,7 +159,7 @@ func (u *User) SetIMAPIdleUpdateChannel() {
|
|||||||
|
|
||||||
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
|
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
|
||||||
// If user is not already connected to the api auth channel (for example there was no internet during start),
|
// If user is not already connected to the api auth channel (for example there was no internet during start),
|
||||||
// it tries to connect it. See `connectToAuthChannel` for more info.
|
// it tries to connect it.
|
||||||
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
|
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
|
||||||
// If user is connected and has an auth channel, then perfect, nothing to do here.
|
// If user is connected and has an auth channel, then perfect, nothing to do here.
|
||||||
if u.creds.IsConnected() && u.HasAPIAuth() {
|
if u.creds.IsConnected() && u.HasAPIAuth() {
|
||||||
@ -236,11 +226,9 @@ func (u *User) authorizeAndUnlock() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, err := u.client().AuthRefresh(u.creds.APIToken)
|
if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "failed to refresh API auth")
|
return errors.Wrap(err, "failed to refresh API auth")
|
||||||
}
|
}
|
||||||
u.authChannel <- auth
|
|
||||||
|
|
||||||
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
||||||
return errors.Wrap(err, "failed to unlock user")
|
return errors.Wrap(err, "failed to unlock user")
|
||||||
@ -253,32 +241,34 @@ func (u *User) authorizeAndUnlock() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// See `connectToAPIClientAuthChannel` for more info.
|
func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) {
|
||||||
func (u *User) watchAPIClientAuths() {
|
if auth == nil {
|
||||||
for auth := range u.authChannel {
|
if err := u.logout(); err != nil {
|
||||||
if auth != nil {
|
|
||||||
newRefreshToken := auth.UID() + ":" + auth.RefreshToken
|
|
||||||
u.updateAPIToken(newRefreshToken)
|
|
||||||
u.hasAPIAuth = true
|
|
||||||
} else if err := u.logout(); err != nil {
|
|
||||||
u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API")
|
u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API")
|
||||||
}
|
}
|
||||||
|
u.isAuthorized = false
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
u.updateAPIToken(auth.GenToken())
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateAPIToken is helper for updating the token in keychain. It's not supposed to be
|
// updateAPIToken is helper for updating the token in keychain. It's not supposed to be
|
||||||
// called directly from other parts of the code--only from `watchAPIClientAuths`.
|
// called directly from other parts of the code, only from `ReceiveAPIAuth`.
|
||||||
func (u *User) updateAPIToken(newRefreshToken string) {
|
func (u *User) updateAPIToken(newRefreshToken string) {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
defer u.lock.Unlock()
|
defer u.lock.Unlock()
|
||||||
|
|
||||||
u.log.Info("Saving refresh token")
|
u.log.WithField("token", newRefreshToken).Info("Saving token to credentials store")
|
||||||
|
|
||||||
if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil {
|
if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil {
|
||||||
u.log.WithError(err).Error("Cannot update refresh token in credentials store")
|
u.log.WithError(err).Error("Cannot update refresh token in credentials store")
|
||||||
} else {
|
return
|
||||||
u.refreshFromCredentials()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
u.refreshFromCredentials()
|
||||||
|
|
||||||
|
u.isAuthorized = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// clearStore removes the database.
|
// clearStore removes the database.
|
||||||
@ -548,18 +538,7 @@ func (u *User) Logout() (err error) {
|
|||||||
u.wasKeyringUnlocked = false
|
u.wasKeyringUnlocked = false
|
||||||
u.unlockingKeyringLock.Unlock()
|
u.unlockingKeyringLock.Unlock()
|
||||||
|
|
||||||
if err = u.client().Logout(); err != nil {
|
u.client().Logout()
|
||||||
u.log.WithError(err).Warn("Could not log user out from API client")
|
|
||||||
}
|
|
||||||
u.client().SetAuths(nil)
|
|
||||||
|
|
||||||
// Logout needs to stop auth channel so when user logs back in
|
|
||||||
// it can register again with new client.
|
|
||||||
// Note: be careful to not close channel twice.
|
|
||||||
if u.authChannel != nil {
|
|
||||||
close(u.authChannel)
|
|
||||||
u.authChannel = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||||
u.log.WithError(err).Warn("Could not log user out from credentials store")
|
u.log.WithError(err).Warn("Could not log user out from credentials store")
|
||||||
@ -617,5 +596,5 @@ func (u *User) GetStore() *store.Store {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) HasAPIAuth() bool {
|
func (u *User) HasAPIAuth() bool {
|
||||||
return u.hasAPIAuth
|
return u.isAuthorized
|
||||||
}
|
}
|
||||||
|
|||||||
@ -66,7 +66,7 @@ func HandlePanic(cfg *Config, output string) {
|
|||||||
if !cfg.IsDevMode() {
|
if !cfg.IsDevMode() {
|
||||||
// TODO: Is it okay to just create a throwaway client like this?
|
// TODO: Is it okay to just create a throwaway client like this?
|
||||||
c := pmapi.NewClientManager(cfg.GetAPIConfig()).GetClient("no-user-id")
|
c := pmapi.NewClientManager(cfg.GetAPIConfig()).GetClient("no-user-id")
|
||||||
defer func() { _ = c.Logout() }()
|
defer c.Logout()
|
||||||
|
|
||||||
if err := c.ReportSentryCrash(fmt.Errorf(output)); err != nil {
|
if err := c.ReportSentryCrash(fmt.Errorf(output)); err != nil {
|
||||||
log.Error("Sentry crash report failed: ", err)
|
log.Error("Sentry crash report failed: ", err)
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -122,6 +123,10 @@ func (s *Auth) UID() string {
|
|||||||
return s.uid
|
return s.uid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Auth) GenToken() string {
|
||||||
|
return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Auth) HasTwoFactor() bool {
|
func (s *Auth) HasTwoFactor() bool {
|
||||||
if s.TwoFA == nil {
|
if s.TwoFA == nil {
|
||||||
return false
|
return false
|
||||||
@ -191,9 +196,16 @@ type AuthRefreshReq struct {
|
|||||||
State string
|
State string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAuths sets auths channel.
|
func (c *Client) sendAuth(auth *Auth) {
|
||||||
func (c *Client) SetAuths(auths chan<- *Auth) {
|
c.cm.getClientAuthChannel() <- ClientAuth{
|
||||||
c.auths = auths
|
UserID: c.userID,
|
||||||
|
Auth: auth,
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth != nil {
|
||||||
|
c.uid = auth.UID()
|
||||||
|
c.accessToken = auth.accessToken
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthInfo gets authentication info for a user.
|
// AuthInfo gets authentication info for a user.
|
||||||
@ -301,13 +313,7 @@ func (c *Client) Auth(username, password string, info *AuthInfo) (auth *Auth, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
auth = authRes.getAuth()
|
auth = authRes.getAuth()
|
||||||
c.uid = auth.UID()
|
c.sendAuth(auth)
|
||||||
c.accessToken = auth.accessToken
|
|
||||||
|
|
||||||
if c.auths != nil {
|
|
||||||
c.auths <- auth
|
|
||||||
}
|
|
||||||
c.cm.SetToken(c.userID, c.uid+":"+auth.RefreshToken)
|
|
||||||
|
|
||||||
// Auth has to be fully unlocked to get key salt. During `Auth` it can happen
|
// Auth has to be fully unlocked to get key salt. During `Auth` it can happen
|
||||||
// only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`.
|
// only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`.
|
||||||
@ -403,7 +409,8 @@ func (c *Client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
|
|||||||
func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
||||||
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
||||||
// That way we can try again later (see handleUnauthorizedStatus).
|
// That way we can try again later (see handleUnauthorizedStatus).
|
||||||
c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
// TODO:
|
||||||
|
// c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||||
|
|
||||||
split := strings.Split(uidAndRefreshToken, ":")
|
split := strings.Split(uidAndRefreshToken, ":")
|
||||||
if len(split) != 2 {
|
if len(split) != 2 {
|
||||||
@ -437,22 +444,18 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
auth = res.getAuth()
|
auth = res.getAuth()
|
||||||
// UID should never change after auth, see backend-communication#11
|
c.sendAuth(auth)
|
||||||
auth.uid = c.uid
|
|
||||||
if c.auths != nil {
|
|
||||||
c.auths <- auth
|
|
||||||
}
|
|
||||||
|
|
||||||
c.uid = auth.UID()
|
|
||||||
c.accessToken = auth.accessToken
|
|
||||||
c.cm.SetToken(c.userID, c.uid+":"+res.RefreshToken)
|
|
||||||
|
|
||||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
return auth, err
|
return auth, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logout logs the current user out.
|
func (c *Client) Logout() {
|
||||||
func (c *Client) Logout() (err error) {
|
c.cm.LogoutClient(c.userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// logout logs the current user out.
|
||||||
|
func (c *Client) logout() (err error) {
|
||||||
req, err := NewRequest("DELETE", "/auth", nil)
|
req, err := NewRequest("DELETE", "/auth", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -467,23 +470,13 @@ func (c *Client) Logout() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// This can trigger a deadlock! We don't want to do it if the above requests failed (GODT-154).
|
return
|
||||||
// That's why it's not in the deferred statement above.
|
}
|
||||||
if c.auths != nil {
|
|
||||||
c.auths <- nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This should ideally be deferred at the top of this method so that it is executed
|
func (c *Client) clearSensitiveData() {
|
||||||
// regardless of what happens, but we currently don't have a way to prevent ourselves
|
|
||||||
// from using a logged out client. So for now, it's down here, as it was in Charles release.
|
|
||||||
// defer func() {
|
|
||||||
c.uid = ""
|
c.uid = ""
|
||||||
c.accessToken = ""
|
c.accessToken = ""
|
||||||
c.kr = nil
|
c.kr = nil
|
||||||
// c.addresses = nil
|
c.addresses = nil
|
||||||
c.user = nil
|
c.user = nil
|
||||||
c.cm.ClearToken(c.userID)
|
|
||||||
// }()
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -99,10 +99,8 @@ type ClientConfig struct {
|
|||||||
|
|
||||||
// Client to communicate with API.
|
// Client to communicate with API.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
auths chan<- *Auth // Channel that sends Auth responses back to the bridge.
|
|
||||||
|
|
||||||
client *http.Client
|
|
||||||
cm *ClientManager
|
cm *ClientManager
|
||||||
|
client *http.Client
|
||||||
|
|
||||||
uid string
|
uid string
|
||||||
accessToken string
|
accessToken string
|
||||||
@ -121,39 +119,42 @@ type Client struct {
|
|||||||
// newClient creates a new API client.
|
// newClient creates a new API client.
|
||||||
func newClient(cm *ClientManager, userID string) *Client {
|
func newClient(cm *ClientManager, userID string) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
|
||||||
client: getHTTPClient(cm.GetConfig()),
|
|
||||||
cm: cm,
|
cm: cm,
|
||||||
|
client: getHTTPClient(cm.GetConfig()),
|
||||||
userID: userID,
|
userID: userID,
|
||||||
requestLocker: &sync.Mutex{},
|
requestLocker: &sync.Mutex{},
|
||||||
keyLocker: &sync.Mutex{},
|
keyLocker: &sync.Mutex{},
|
||||||
|
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getHTTPClient returns a http client configured by the given client config.
|
// getHTTPClient returns a http client configured by the given client config.
|
||||||
func getHTTPClient(cfg *ClientConfig) *http.Client {
|
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
|
||||||
hc := &http.Client{
|
hc = &http.Client{Timeout: cfg.Timeout}
|
||||||
Timeout: cfg.Timeout,
|
|
||||||
|
if cfg.Transport == nil && defaultTransport == nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Transport != nil {
|
if defaultTransport != nil {
|
||||||
cfgTransport, ok := cfg.Transport.(*http.Transport)
|
|
||||||
if ok {
|
|
||||||
// In future use Clone here.
|
|
||||||
// https://go-review.googlesource.com/c/go/+/174597/
|
|
||||||
transport := &http.Transport{}
|
|
||||||
*transport = *cfgTransport //nolint
|
|
||||||
if transport.Proxy == nil {
|
|
||||||
transport.Proxy = http.ProxyFromEnvironment
|
|
||||||
}
|
|
||||||
hc.Transport = transport
|
|
||||||
} else {
|
|
||||||
hc.Transport = cfg.Transport
|
|
||||||
}
|
|
||||||
} else if defaultTransport != nil {
|
|
||||||
hc.Transport = defaultTransport
|
hc.Transport = defaultTransport
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// In future use Clone here.
|
||||||
|
// https://go-review.googlesource.com/c/go/+/174597/
|
||||||
|
if cfgTransport, ok := cfg.Transport.(*http.Transport); ok {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
*transport = *cfgTransport //nolint
|
||||||
|
if transport.Proxy == nil {
|
||||||
|
transport.Proxy = http.ProxyFromEnvironment
|
||||||
|
}
|
||||||
|
hc.Transport = transport
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.Transport = cfg.Transport
|
||||||
|
|
||||||
return hc
|
return hc
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -400,30 +401,20 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
|
|||||||
|
|
||||||
func (c *Client) refreshAccessToken() (err error) {
|
func (c *Client) refreshAccessToken() (err error) {
|
||||||
c.log.Debug("Refreshing token")
|
c.log.Debug("Refreshing token")
|
||||||
|
|
||||||
refreshToken := c.cm.GetToken(c.userID)
|
refreshToken := c.cm.GetToken(c.userID)
|
||||||
c.log.WithField("token", refreshToken).Info("Current refresh token")
|
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
if c.auths != nil {
|
c.sendAuth(nil)
|
||||||
c.auths <- nil
|
|
||||||
}
|
|
||||||
c.cm.ClearToken(c.userID)
|
|
||||||
return ErrInvalidToken
|
return ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, err := c.AuthRefresh(refreshToken)
|
if _, err := c.AuthRefresh(refreshToken); err != nil {
|
||||||
if err != nil {
|
c.sendAuth(nil)
|
||||||
c.log.WithError(err).WithField("auths", c.auths).Debug("Token refreshing failed")
|
return err
|
||||||
// The refresh failed, so we should log the user out.
|
|
||||||
// A nil value in the Auths channel will trigger this.
|
|
||||||
if c.auths != nil {
|
|
||||||
c.auths <- nil
|
|
||||||
}
|
|
||||||
c.cm.ClearToken(c.userID)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
c.uid = auth.UID()
|
|
||||||
c.accessToken = auth.accessToken
|
return
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
|
func (c *Client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
|
||||||
|
|||||||
@ -1,30 +1,50 @@
|
|||||||
package pmapi
|
package pmapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientManager is a manager of clients.
|
// ClientManager is a manager of clients.
|
||||||
type ClientManager struct {
|
type ClientManager struct {
|
||||||
// TODO: Lockers.
|
clients map[string]*Client
|
||||||
|
clientsLocker sync.Locker
|
||||||
|
|
||||||
clients map[string]*Client
|
tokens map[string]string
|
||||||
tokens map[string]string
|
tokensLocker sync.Locker
|
||||||
config *ClientConfig
|
|
||||||
|
config *ClientConfig
|
||||||
|
|
||||||
|
bridgeAuths chan ClientAuth
|
||||||
|
clientAuths chan ClientAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientAuth struct {
|
||||||
|
UserID string
|
||||||
|
Auth *Auth
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
||||||
func NewClientManager(config *ClientConfig) *ClientManager {
|
func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||||
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
||||||
logrus.WithError(err).Error("Could not set up sentry DSN")
|
logrus.WithError(err).Error("Could not set up sentry DSN")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ClientManager{
|
cm = &ClientManager{
|
||||||
clients: make(map[string]*Client),
|
clients: make(map[string]*Client),
|
||||||
tokens: make(map[string]string),
|
clientsLocker: &sync.Mutex{},
|
||||||
config: config,
|
tokens: make(map[string]string),
|
||||||
|
tokensLocker: &sync.Mutex{},
|
||||||
|
config: config,
|
||||||
|
bridgeAuths: make(chan ClientAuth),
|
||||||
|
clientAuths: make(chan ClientAuth),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go cm.forwardClientAuths()
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClient returns a client for the given userID.
|
// GetClient returns a client for the given userID.
|
||||||
@ -39,6 +59,28 @@ func (cm *ClientManager) GetClient(userID string) *Client {
|
|||||||
return cm.clients[userID]
|
return cm.clients[userID]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||||
|
func (cm *ClientManager) LogoutClient(userID string) {
|
||||||
|
client, ok := cm.clients[userID]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(cm.clients, userID)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := client.logout(); err != nil {
|
||||||
|
// TODO: Try again!
|
||||||
|
logrus.WithError(err).Error("Client logout failed, not trying again")
|
||||||
|
}
|
||||||
|
client.clearSensitiveData()
|
||||||
|
cm.clearToken(userID)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// GetConfig returns the config used to configure clients.
|
// GetConfig returns the config used to configure clients.
|
||||||
func (cm *ClientManager) GetConfig() *ClientConfig {
|
func (cm *ClientManager) GetConfig() *ClientConfig {
|
||||||
return cm.config
|
return cm.config
|
||||||
@ -49,21 +91,52 @@ func (cm *ClientManager) GetToken(userID string) string {
|
|||||||
return cm.tokens[userID]
|
return cm.tokens[userID]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToken sets the token for the given userID.
|
// GetBridgeAuthChannel returns a channel on which client auths can be received.
|
||||||
func (cm *ClientManager) SetToken(userID, token string) {
|
func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth {
|
||||||
|
return cm.clientAuths
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientAuthChannel returns a channel on which clients should send auths.
|
||||||
|
func (cm *ClientManager) getClientAuthChannel() chan ClientAuth {
|
||||||
|
return cm.clientAuths
|
||||||
|
}
|
||||||
|
|
||||||
|
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel.
|
||||||
|
func (cm *ClientManager) forwardClientAuths() {
|
||||||
|
for auth := range cm.clientAuths {
|
||||||
|
cm.handleClientAuth(auth)
|
||||||
|
cm.bridgeAuths <- auth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ClientManager) setToken(userID, token string) {
|
||||||
|
cm.tokensLocker.Lock()
|
||||||
|
defer cm.tokensLocker.Unlock()
|
||||||
|
|
||||||
|
logrus.WithField("userID", userID).WithField("token", token).Info("Updating refresh token")
|
||||||
|
|
||||||
cm.tokens[userID] = token
|
cm.tokens[userID] = token
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTokenIfUnset sets the token for the given userID if it does not yet have a token.
|
func (cm *ClientManager) clearToken(userID string) {
|
||||||
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
|
cm.tokensLocker.Lock()
|
||||||
if _, ok := cm.tokens[userID]; ok {
|
defer cm.tokensLocker.Unlock()
|
||||||
|
|
||||||
|
logrus.WithField("userID", userID).Info("Clearing refresh token")
|
||||||
|
|
||||||
|
delete(cm.tokens, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleClientAuth
|
||||||
|
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||||
|
// TODO: Maybe want to logout the client in case of nil auth.
|
||||||
|
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.tokens[userID] = token
|
if ca.Auth == nil {
|
||||||
}
|
cm.clearToken(ca.UserID)
|
||||||
|
} else {
|
||||||
// ClearToken clears the token of the given userID.
|
cm.setToken(ca.UserID, ca.Auth.GenToken())
|
||||||
func (cm *ClientManager) ClearToken(userID string) {
|
}
|
||||||
delete(cm.tokens, userID)
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user