feat: central auth channel for clients

This commit is contained in:
James Houlahan
2020-04-01 15:16:36 +02:00
parent 0a55fac29a
commit f239e8f3bf
7 changed files with 227 additions and 193 deletions

View File

@ -43,14 +43,14 @@ var (
// Bridge is a struct handling users. // Bridge is a struct handling users.
type Bridge struct { type Bridge struct {
config Configer config Configer
pref PreferenceProvider pref PreferenceProvider
panicHandler PanicHandler panicHandler PanicHandler
events listener.Listener events listener.Listener
version string version string
clientMan *pmapi.ClientManager clientManager *pmapi.ClientManager
credStorer CredentialsStorer credStorer CredentialsStorer
storeCache *store.Cache storeCache *store.Cache
// users is a list of accounts that have been added to bridge. // users is a list of accounts that have been added to bridge.
// They are stored sorted in the credentials store in the order // They are stored sorted in the credentials store in the order
@ -76,22 +76,22 @@ func New(
panicHandler PanicHandler, panicHandler PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
version string, version string,
clientMan *pmapi.ClientManager, clientManager *pmapi.ClientManager,
credStorer CredentialsStorer, credStorer CredentialsStorer,
) *Bridge { ) *Bridge {
log.Trace("Creating new bridge") log.Trace("Creating new bridge")
b := &Bridge{ b := &Bridge{
config: config, config: config,
pref: pref, pref: pref,
panicHandler: panicHandler, panicHandler: panicHandler,
events: eventListener, events: eventListener,
version: version, version: version,
clientMan: clientMan, clientManager: clientManager,
credStorer: credStorer, credStorer: credStorer,
storeCache: store.NewCache(config.GetIMAPCachePath()), storeCache: store.NewCache(config.GetIMAPCachePath()),
idleUpdates: make(chan interface{}), idleUpdates: make(chan interface{}),
lock: sync.RWMutex{}, lock: sync.RWMutex{},
} }
// Allow DoH before starting bridge if the user has previously set this setting. // Allow DoH before starting bridge if the user has previously set this setting.
@ -105,6 +105,11 @@ func New(
b.watchBridgeOutdated() b.watchBridgeOutdated()
}() }()
go func() {
defer panicHandler.HandlePanic()
b.watchUserAuths()
}()
if b.credStorer == nil { if b.credStorer == nil {
log.Error("Bridge has no credentials store") log.Error("Bridge has no credentials store")
} else if err := b.loadUsersFromCredentialsStore(); err != nil { } else if err := b.loadUsersFromCredentialsStore(); err != nil {
@ -148,7 +153,7 @@ func (b *Bridge) loadUsersFromCredentialsStore() (err error) {
for _, userID := range userIDs { for _, userID := range userIDs {
l := log.WithField("user", userID) l := log.WithField("user", userID)
user, newUserErr := newUser(b.panicHandler, userID, b.events, b.credStorer, b.clientMan, b.storeCache, b.config.GetDBDir()) user, newUserErr := newUser(b.panicHandler, userID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
if newUserErr != nil { if newUserErr != nil {
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping") l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
continue continue
@ -173,6 +178,18 @@ func (b *Bridge) watchBridgeOutdated() {
} }
} }
func (b *Bridge) watchUserAuths() {
for auth := range b.clientManager.GetBridgeAuthChannel() {
user, ok := b.hasUser(auth.UserID)
if !ok {
continue
}
user.ReceiveAPIAuth(auth.Auth)
}
}
func (b *Bridge) closeAllConnections() { func (b *Bridge) closeAllConnections() {
for _, user := range b.users { for _, user := range b.users {
user.closeAllConnections() user.closeAllConnections()
@ -195,9 +212,8 @@ func (b *Bridge) Login(username, password string) (loginClient PMAPIProvider, au
b.crashBandicoot(username) b.crashBandicoot(username)
// We need to use "login" client because we need userID to properly // We need to use "login" client because we need userID to properly assign access tokens into token manager.
// assign access tokens into token manager. loginClient = b.clientManager.GetClient("login")
loginClient = b.clientMan.GetClient("login")
authInfo, err := loginClient.AuthInfo(username) authInfo, err := loginClient.AuthInfo(username)
if err != nil { if err != nil {
@ -227,29 +243,22 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
b.lock.Lock() b.lock.Lock()
defer b.lock.Unlock() defer b.lock.Unlock()
defer loginClient.Logout()
mbPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt) mbPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
if err != nil { if err != nil {
log.WithError(err).Error("Could not hash mailbox password") log.WithError(err).Error("Could not hash mailbox password")
if logoutErr := loginClient.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Clean login session after hash password failed.")
}
return return
} }
if _, err = loginClient.Unlock(mbPassword); err != nil { if _, err = loginClient.Unlock(mbPassword); err != nil {
log.WithError(err).Error("Could not decrypt keyring") log.WithError(err).Error("Could not decrypt keyring")
if logoutErr := loginClient.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Clean login session after unlock failed.")
}
return return
} }
apiUser, err := loginClient.CurrentUser() apiUser, err := loginClient.CurrentUser()
if err != nil { if err != nil {
log.WithError(err).Error("Could not get login API user") log.WithError(err).Error("Could not get login API user")
if logoutErr := loginClient.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Clean login session after get current user failed.")
}
return return
} }
@ -259,20 +268,13 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
if hasUser && user.IsConnected() { if hasUser && user.IsConnected() {
err = errors.New("user is already logged in") err = errors.New("user is already logged in")
log.WithError(err).Warn("User is already logged in") log.WithError(err).Warn("User is already logged in")
if logoutErr := loginClient.Logout(); logoutErr != nil {
log.WithError(logoutErr).Warn("Could not discard auth generated during second login")
}
return return
} }
apiToken := auth.UID() + ":" + auth.RefreshToken apiClient := b.clientManager.GetClient(apiUser.ID)
apiClient := b.clientMan.GetClient(apiUser.ID) auth, err = apiClient.AuthRefresh(auth.GenToken())
auth, err = apiClient.AuthRefresh(apiToken)
if err != nil { if err != nil {
log.WithError(err).Error("Could refresh token in new client") log.WithError(err).Error("Could refresh token in new client")
if logoutErr := loginClient.Logout(); logoutErr != nil {
log.WithError(logoutErr).Warn("Could not discard auth generated after auth refresh")
}
return return
} }
@ -280,22 +282,18 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
apiUser, err = apiClient.CurrentUser() apiUser, err = apiClient.CurrentUser()
if err != nil { if err != nil {
log.WithError(err).Error("Could not get current API user") log.WithError(err).Error("Could not get current API user")
if logoutErr := loginClient.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Clean login session after get current user failed.")
}
return return
} }
apiToken = auth.UID() + ":" + auth.RefreshToken
activeEmails := apiClient.Addresses().ActiveEmails() activeEmails := apiClient.Addresses().ActiveEmails()
if _, err = b.credStorer.Add(apiUser.ID, apiUser.Name, apiToken, mbPassword, activeEmails); err != nil { if _, err = b.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), mbPassword, activeEmails); err != nil {
log.WithError(err).Error("Could not add user to credentials store") log.WithError(err).Error("Could not add user to credentials store")
return return
} }
// If it's a new user, generate the user object. // If it's a new user, generate the user object.
if !hasUser { if !hasUser {
user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientMan, b.storeCache, b.config.GetDBDir()) user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientManager, b.storeCache, b.config.GetDBDir())
if err != nil { if err != nil {
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user") log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
return return
@ -405,8 +403,8 @@ func (b *Bridge) DeleteUser(userID string, clearStore bool) error {
// ReportBug reports a new bug from the user. // ReportBug reports a new bug from the user.
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error { func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
c := b.clientMan.GetClient("bug_reporter") c := b.clientManager.GetClient("bug_reporter")
defer func() { _ = c.Logout() }() defer c.Logout()
title := "[Bridge] Bug" title := "[Bridge] Bug"
if err := c.ReportBugWithEmailClient( if err := c.ReportBugWithEmailClient(
@ -429,8 +427,8 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address,
// SendMetric sends a metric. We don't want to return any errors, only log them. // SendMetric sends a metric. We don't want to return any errors, only log them.
func (b *Bridge) SendMetric(m m.Metric) { func (b *Bridge) SendMetric(m m.Metric) {
c := b.clientMan.GetClient("metric_reporter") c := b.clientManager.GetClient("metric_reporter")
defer func() { _ = c.Logout() }() defer c.Logout()
cat, act, lab := m.Get() cat, act, lab := m.Get()
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil { if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {

View File

@ -47,7 +47,6 @@ type Clientman interface {
} }
type PMAPIProvider interface { type PMAPIProvider interface {
SetAuths(auths chan<- *pmapi.Auth)
Auth(username, password string, info *pmapi.AuthInfo) (*pmapi.Auth, error) Auth(username, password string, info *pmapi.AuthInfo) (*pmapi.Auth, error)
AuthInfo(username string) (*pmapi.AuthInfo, error) AuthInfo(username string) (*pmapi.AuthInfo, error)
AuthRefresh(token string) (*pmapi.Auth, error) AuthRefresh(token string) (*pmapi.Auth, error)
@ -56,7 +55,8 @@ type PMAPIProvider interface {
CurrentUser() (*pmapi.User, error) CurrentUser() (*pmapi.User, error)
UpdateUser() (*pmapi.User, error) UpdateUser() (*pmapi.User, error)
Addresses() pmapi.AddressList Addresses() pmapi.AddressList
Logout() error
Logout()
GetEvent(eventID string) (*pmapi.Event, error) GetEvent(eventID string) (*pmapi.Event, error)

View File

@ -53,9 +53,8 @@ type User struct {
userID string userID string
creds *credentials.Credentials creds *credentials.Credentials
lock sync.RWMutex lock sync.RWMutex
authChannel chan *pmapi.Auth isAuthorized bool
hasAPIAuth bool
unlockingKeyringLock sync.Mutex unlockingKeyringLock sync.Mutex
wasKeyringUnlocked bool wasKeyringUnlocked bool
@ -116,15 +115,6 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
} }
u.creds = creds u.creds = creds
// Set up the auth channel on which auths from the api client are sent.
u.authChannel = make(chan *pmapi.Auth)
u.client().SetAuths(u.authChannel)
u.hasAPIAuth = false
go func() {
defer u.panicHandler.HandlePanic()
u.watchAPIClientAuths()
}()
// Try to authorise the user if they aren't already authorised. // Try to authorise the user if they aren't already authorised.
// Note: we still allow users to set up bridge if the internet is off. // Note: we still allow users to set up bridge if the internet is off.
if authErr := u.authorizeIfNecessary(false); authErr != nil { if authErr := u.authorizeIfNecessary(false); authErr != nil {
@ -169,7 +159,7 @@ func (u *User) SetIMAPIdleUpdateChannel() {
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel. // authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
// If user is not already connected to the api auth channel (for example there was no internet during start), // If user is not already connected to the api auth channel (for example there was no internet during start),
// it tries to connect it. See `connectToAuthChannel` for more info. // it tries to connect it.
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) { func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
// If user is connected and has an auth channel, then perfect, nothing to do here. // If user is connected and has an auth channel, then perfect, nothing to do here.
if u.creds.IsConnected() && u.HasAPIAuth() { if u.creds.IsConnected() && u.HasAPIAuth() {
@ -236,11 +226,9 @@ func (u *User) authorizeAndUnlock() (err error) {
return nil return nil
} }
auth, err := u.client().AuthRefresh(u.creds.APIToken) if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
if err != nil {
return errors.Wrap(err, "failed to refresh API auth") return errors.Wrap(err, "failed to refresh API auth")
} }
u.authChannel <- auth
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil { if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
return errors.Wrap(err, "failed to unlock user") return errors.Wrap(err, "failed to unlock user")
@ -253,32 +241,34 @@ func (u *User) authorizeAndUnlock() (err error) {
return nil return nil
} }
// See `connectToAPIClientAuthChannel` for more info. func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) {
func (u *User) watchAPIClientAuths() { if auth == nil {
for auth := range u.authChannel { if err := u.logout(); err != nil {
if auth != nil {
newRefreshToken := auth.UID() + ":" + auth.RefreshToken
u.updateAPIToken(newRefreshToken)
u.hasAPIAuth = true
} else if err := u.logout(); err != nil {
u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API") u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API")
} }
u.isAuthorized = false
return
} }
u.updateAPIToken(auth.GenToken())
} }
// updateAPIToken is helper for updating the token in keychain. It's not supposed to be // updateAPIToken is helper for updating the token in keychain. It's not supposed to be
// called directly from other parts of the code--only from `watchAPIClientAuths`. // called directly from other parts of the code, only from `ReceiveAPIAuth`.
func (u *User) updateAPIToken(newRefreshToken string) { func (u *User) updateAPIToken(newRefreshToken string) {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
u.log.Info("Saving refresh token") u.log.WithField("token", newRefreshToken).Info("Saving token to credentials store")
if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil { if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil {
u.log.WithError(err).Error("Cannot update refresh token in credentials store") u.log.WithError(err).Error("Cannot update refresh token in credentials store")
} else { return
u.refreshFromCredentials()
} }
u.refreshFromCredentials()
u.isAuthorized = true
} }
// clearStore removes the database. // clearStore removes the database.
@ -548,18 +538,7 @@ func (u *User) Logout() (err error) {
u.wasKeyringUnlocked = false u.wasKeyringUnlocked = false
u.unlockingKeyringLock.Unlock() u.unlockingKeyringLock.Unlock()
if err = u.client().Logout(); err != nil { u.client().Logout()
u.log.WithError(err).Warn("Could not log user out from API client")
}
u.client().SetAuths(nil)
// Logout needs to stop auth channel so when user logs back in
// it can register again with new client.
// Note: be careful to not close channel twice.
if u.authChannel != nil {
close(u.authChannel)
u.authChannel = nil
}
if err = u.credStorer.Logout(u.userID); err != nil { if err = u.credStorer.Logout(u.userID); err != nil {
u.log.WithError(err).Warn("Could not log user out from credentials store") u.log.WithError(err).Warn("Could not log user out from credentials store")
@ -617,5 +596,5 @@ func (u *User) GetStore() *store.Store {
} }
func (u *User) HasAPIAuth() bool { func (u *User) HasAPIAuth() bool {
return u.hasAPIAuth return u.isAuthorized
} }

View File

@ -66,7 +66,7 @@ func HandlePanic(cfg *Config, output string) {
if !cfg.IsDevMode() { if !cfg.IsDevMode() {
// TODO: Is it okay to just create a throwaway client like this? // TODO: Is it okay to just create a throwaway client like this?
c := pmapi.NewClientManager(cfg.GetAPIConfig()).GetClient("no-user-id") c := pmapi.NewClientManager(cfg.GetAPIConfig()).GetClient("no-user-id")
defer func() { _ = c.Logout() }() defer c.Logout()
if err := c.ReportSentryCrash(fmt.Errorf(output)); err != nil { if err := c.ReportSentryCrash(fmt.Errorf(output)); err != nil {
log.Error("Sentry crash report failed: ", err) log.Error("Sentry crash report failed: ", err)

View File

@ -21,6 +21,7 @@ import (
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -122,6 +123,10 @@ func (s *Auth) UID() string {
return s.uid return s.uid
} }
func (s *Auth) GenToken() string {
return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken)
}
func (s *Auth) HasTwoFactor() bool { func (s *Auth) HasTwoFactor() bool {
if s.TwoFA == nil { if s.TwoFA == nil {
return false return false
@ -191,9 +196,16 @@ type AuthRefreshReq struct {
State string State string
} }
// SetAuths sets auths channel. func (c *Client) sendAuth(auth *Auth) {
func (c *Client) SetAuths(auths chan<- *Auth) { c.cm.getClientAuthChannel() <- ClientAuth{
c.auths = auths UserID: c.userID,
Auth: auth,
}
if auth != nil {
c.uid = auth.UID()
c.accessToken = auth.accessToken
}
} }
// AuthInfo gets authentication info for a user. // AuthInfo gets authentication info for a user.
@ -301,13 +313,7 @@ func (c *Client) Auth(username, password string, info *AuthInfo) (auth *Auth, er
} }
auth = authRes.getAuth() auth = authRes.getAuth()
c.uid = auth.UID() c.sendAuth(auth)
c.accessToken = auth.accessToken
if c.auths != nil {
c.auths <- auth
}
c.cm.SetToken(c.userID, c.uid+":"+auth.RefreshToken)
// Auth has to be fully unlocked to get key salt. During `Auth` it can happen // Auth has to be fully unlocked to get key salt. During `Auth` it can happen
// only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`. // only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`.
@ -403,7 +409,8 @@ func (c *Client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) { func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
// If we don't yet have a saved access token, save this one in case the refresh fails! // If we don't yet have a saved access token, save this one in case the refresh fails!
// That way we can try again later (see handleUnauthorizedStatus). // That way we can try again later (see handleUnauthorizedStatus).
c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken) // TODO:
// c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
split := strings.Split(uidAndRefreshToken, ":") split := strings.Split(uidAndRefreshToken, ":")
if len(split) != 2 { if len(split) != 2 {
@ -437,22 +444,18 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
} }
auth = res.getAuth() auth = res.getAuth()
// UID should never change after auth, see backend-communication#11 c.sendAuth(auth)
auth.uid = c.uid
if c.auths != nil {
c.auths <- auth
}
c.uid = auth.UID()
c.accessToken = auth.accessToken
c.cm.SetToken(c.userID, c.uid+":"+res.RefreshToken)
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second) c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
return auth, err return auth, err
} }
// Logout logs the current user out. func (c *Client) Logout() {
func (c *Client) Logout() (err error) { c.cm.LogoutClient(c.userID)
}
// logout logs the current user out.
func (c *Client) logout() (err error) {
req, err := NewRequest("DELETE", "/auth", nil) req, err := NewRequest("DELETE", "/auth", nil)
if err != nil { if err != nil {
return return
@ -467,23 +470,13 @@ func (c *Client) Logout() (err error) {
return return
} }
// This can trigger a deadlock! We don't want to do it if the above requests failed (GODT-154). return
// That's why it's not in the deferred statement above. }
if c.auths != nil {
c.auths <- nil
}
// This should ideally be deferred at the top of this method so that it is executed func (c *Client) clearSensitiveData() {
// regardless of what happens, but we currently don't have a way to prevent ourselves
// from using a logged out client. So for now, it's down here, as it was in Charles release.
// defer func() {
c.uid = "" c.uid = ""
c.accessToken = "" c.accessToken = ""
c.kr = nil c.kr = nil
// c.addresses = nil c.addresses = nil
c.user = nil c.user = nil
c.cm.ClearToken(c.userID)
// }()
return err
} }

View File

@ -99,10 +99,8 @@ type ClientConfig struct {
// Client to communicate with API. // Client to communicate with API.
type Client struct { type Client struct {
auths chan<- *Auth // Channel that sends Auth responses back to the bridge.
client *http.Client
cm *ClientManager cm *ClientManager
client *http.Client
uid string uid string
accessToken string accessToken string
@ -121,39 +119,42 @@ type Client struct {
// newClient creates a new API client. // newClient creates a new API client.
func newClient(cm *ClientManager, userID string) *Client { func newClient(cm *ClientManager, userID string) *Client {
return &Client{ return &Client{
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
client: getHTTPClient(cm.GetConfig()),
cm: cm, cm: cm,
client: getHTTPClient(cm.GetConfig()),
userID: userID, userID: userID,
requestLocker: &sync.Mutex{}, requestLocker: &sync.Mutex{},
keyLocker: &sync.Mutex{}, keyLocker: &sync.Mutex{},
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
} }
} }
// getHTTPClient returns a http client configured by the given client config. // getHTTPClient returns a http client configured by the given client config.
func getHTTPClient(cfg *ClientConfig) *http.Client { func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
hc := &http.Client{ hc = &http.Client{Timeout: cfg.Timeout}
Timeout: cfg.Timeout,
if cfg.Transport == nil && defaultTransport == nil {
return
} }
if cfg.Transport != nil { if defaultTransport != nil {
cfgTransport, ok := cfg.Transport.(*http.Transport)
if ok {
// In future use Clone here.
// https://go-review.googlesource.com/c/go/+/174597/
transport := &http.Transport{}
*transport = *cfgTransport //nolint
if transport.Proxy == nil {
transport.Proxy = http.ProxyFromEnvironment
}
hc.Transport = transport
} else {
hc.Transport = cfg.Transport
}
} else if defaultTransport != nil {
hc.Transport = defaultTransport hc.Transport = defaultTransport
return
} }
// In future use Clone here.
// https://go-review.googlesource.com/c/go/+/174597/
if cfgTransport, ok := cfg.Transport.(*http.Transport); ok {
transport := &http.Transport{}
*transport = *cfgTransport //nolint
if transport.Proxy == nil {
transport.Proxy = http.ProxyFromEnvironment
}
hc.Transport = transport
return
}
hc.Transport = cfg.Transport
return hc return hc
} }
@ -400,30 +401,20 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
func (c *Client) refreshAccessToken() (err error) { func (c *Client) refreshAccessToken() (err error) {
c.log.Debug("Refreshing token") c.log.Debug("Refreshing token")
refreshToken := c.cm.GetToken(c.userID) refreshToken := c.cm.GetToken(c.userID)
c.log.WithField("token", refreshToken).Info("Current refresh token")
if refreshToken == "" { if refreshToken == "" {
if c.auths != nil { c.sendAuth(nil)
c.auths <- nil
}
c.cm.ClearToken(c.userID)
return ErrInvalidToken return ErrInvalidToken
} }
auth, err := c.AuthRefresh(refreshToken) if _, err := c.AuthRefresh(refreshToken); err != nil {
if err != nil { c.sendAuth(nil)
c.log.WithError(err).WithField("auths", c.auths).Debug("Token refreshing failed") return err
// The refresh failed, so we should log the user out.
// A nil value in the Auths channel will trigger this.
if c.auths != nil {
c.auths <- nil
}
c.cm.ClearToken(c.userID)
return
} }
c.uid = auth.UID()
c.accessToken = auth.accessToken return
return err
} }
func (c *Client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) { func (c *Client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {

View File

@ -1,30 +1,50 @@
package pmapi package pmapi
import ( import (
"sync"
"github.com/getsentry/raven-go" "github.com/getsentry/raven-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// ClientManager is a manager of clients. // ClientManager is a manager of clients.
type ClientManager struct { type ClientManager struct {
// TODO: Lockers. clients map[string]*Client
clientsLocker sync.Locker
clients map[string]*Client tokens map[string]string
tokens map[string]string tokensLocker sync.Locker
config *ClientConfig
config *ClientConfig
bridgeAuths chan ClientAuth
clientAuths chan ClientAuth
}
type ClientAuth struct {
UserID string
Auth *Auth
} }
// NewClientManager creates a new ClientMan which manages clients configured with the given client config. // NewClientManager creates a new ClientMan which manages clients configured with the given client config.
func NewClientManager(config *ClientConfig) *ClientManager { func NewClientManager(config *ClientConfig) (cm *ClientManager) {
if err := raven.SetDSN(config.SentryDSN); err != nil { if err := raven.SetDSN(config.SentryDSN); err != nil {
logrus.WithError(err).Error("Could not set up sentry DSN") logrus.WithError(err).Error("Could not set up sentry DSN")
} }
return &ClientManager{ cm = &ClientManager{
clients: make(map[string]*Client), clients: make(map[string]*Client),
tokens: make(map[string]string), clientsLocker: &sync.Mutex{},
config: config, tokens: make(map[string]string),
tokensLocker: &sync.Mutex{},
config: config,
bridgeAuths: make(chan ClientAuth),
clientAuths: make(chan ClientAuth),
} }
go cm.forwardClientAuths()
return
} }
// GetClient returns a client for the given userID. // GetClient returns a client for the given userID.
@ -39,6 +59,28 @@ func (cm *ClientManager) GetClient(userID string) *Client {
return cm.clients[userID] return cm.clients[userID]
} }
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
func (cm *ClientManager) LogoutClient(userID string) {
client, ok := cm.clients[userID]
if !ok {
return
}
delete(cm.clients, userID)
go func() {
if err := client.logout(); err != nil {
// TODO: Try again!
logrus.WithError(err).Error("Client logout failed, not trying again")
}
client.clearSensitiveData()
cm.clearToken(userID)
}()
return
}
// GetConfig returns the config used to configure clients. // GetConfig returns the config used to configure clients.
func (cm *ClientManager) GetConfig() *ClientConfig { func (cm *ClientManager) GetConfig() *ClientConfig {
return cm.config return cm.config
@ -49,21 +91,52 @@ func (cm *ClientManager) GetToken(userID string) string {
return cm.tokens[userID] return cm.tokens[userID]
} }
// SetToken sets the token for the given userID. // GetBridgeAuthChannel returns a channel on which client auths can be received.
func (cm *ClientManager) SetToken(userID, token string) { func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth {
return cm.clientAuths
}
// getClientAuthChannel returns a channel on which clients should send auths.
func (cm *ClientManager) getClientAuthChannel() chan ClientAuth {
return cm.clientAuths
}
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel.
func (cm *ClientManager) forwardClientAuths() {
for auth := range cm.clientAuths {
cm.handleClientAuth(auth)
cm.bridgeAuths <- auth
}
}
func (cm *ClientManager) setToken(userID, token string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).WithField("token", token).Info("Updating refresh token")
cm.tokens[userID] = token cm.tokens[userID] = token
} }
// SetTokenIfUnset sets the token for the given userID if it does not yet have a token. func (cm *ClientManager) clearToken(userID string) {
func (cm *ClientManager) SetTokenIfUnset(userID, token string) { cm.tokensLocker.Lock()
if _, ok := cm.tokens[userID]; ok { defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).Info("Clearing refresh token")
delete(cm.tokens, userID)
}
// handleClientAuth
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
// TODO: Maybe want to logout the client in case of nil auth.
if _, ok := cm.clients[ca.UserID]; !ok {
return return
} }
cm.tokens[userID] = token if ca.Auth == nil {
} cm.clearToken(ca.UserID)
} else {
// ClearToken clears the token of the given userID. cm.setToken(ca.UserID, ca.Auth.GenToken())
func (cm *ClientManager) ClearToken(userID string) { }
delete(cm.tokens, userID)
} }