feat: implement token expiration watcher

This commit is contained in:
James Houlahan
2020-04-02 14:10:15 +02:00
parent ce29d4d74e
commit 941e09079c
15 changed files with 149 additions and 93 deletions

View File

@ -3,12 +3,15 @@ package pmapi
import (
"net/http"
"sync"
"time"
"github.com/getsentry/raven-go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
var proxyUseDuration = 24 * time.Hour
// ClientManager is a manager of clients.
type ClientManager struct {
config *ClientConfig
@ -16,8 +19,9 @@ type ClientManager struct {
clients map[string]*Client
clientsLocker sync.Locker
tokens map[string]string
tokensLocker sync.Locker
tokens map[string]string
tokenExpirations map[string]*tokenExpiration
tokensLocker sync.Locker
url string
urlLocker sync.Locker
@ -34,6 +38,11 @@ type ClientAuth struct {
Auth *Auth
}
type tokenExpiration struct {
timer *time.Timer
cancel chan (struct{})
}
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
func NewClientManager(config *ClientConfig) (cm *ClientManager) {
if err := raven.SetDSN(config.SentryDSN); err != nil {
@ -46,8 +55,9 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
clients: make(map[string]*Client),
clientsLocker: &sync.Mutex{},
tokens: make(map[string]string),
tokensLocker: &sync.Mutex{},
tokens: make(map[string]string),
tokenExpirations: make(map[string]*tokenExpiration),
tokensLocker: &sync.Mutex{},
url: RootURL,
urlLocker: &sync.Mutex{},
@ -112,43 +122,56 @@ func (cm *ClientManager) GetRootURL() string {
return cm.url
}
// SetRootURL sets the root URL to make requests to.
func (cm *ClientManager) SetRootURL(url string) {
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
func (cm *ClientManager) IsProxyAllowed() bool {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
logrus.WithField("url", url).Info("Changing to a new root URL")
cm.url = url
}
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
func (cm *ClientManager) IsProxyAllowed() bool {
return cm.allowProxy
}
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
func (cm *ClientManager) AllowProxy() {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
cm.allowProxy = true
}
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
func (cm *ClientManager) DisallowProxy() {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
cm.allowProxy = false
cm.url = RootURL
}
// IsProxyEnabled returns whether we are currently proxying requests.
func (cm *ClientManager) IsProxyEnabled() bool {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
return cm.url != RootURL
}
// FindProxy returns a usable proxy server.
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
logrus.Info("Attempting gto switch to a proxy")
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
logrus.Info("Attempting to switch to a proxy")
if proxy, err = cm.proxyProvider.findProxy(); err != nil {
err = errors.Wrap(err, "failed to find usable proxy")
err = errors.Wrap(err, "failed to find a usable proxy")
return
}
cm.SetRootURL(proxy)
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
// TODO: Disable after 24 hours.
cm.url = proxy
// TODO: Disable again after 24 hours.
return
}
@ -165,7 +188,7 @@ func (cm *ClientManager) GetToken(userID string) string {
// GetBridgeAuthChannel returns a channel on which client auths can be received.
func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth {
return cm.clientAuths
return cm.bridgeAuths
}
// getClientAuthChannel returns a channel on which clients should send auths.
@ -176,25 +199,46 @@ func (cm *ClientManager) getClientAuthChannel() chan ClientAuth {
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel.
func (cm *ClientManager) forwardClientAuths() {
for auth := range cm.clientAuths {
logrus.Debug("ClientManager received auth from client")
cm.handleClientAuth(auth)
logrus.Debug("ClientManager is forwarding auth to bridge")
cm.bridgeAuths <- auth
}
}
func (cm *ClientManager) setToken(userID, token string) {
// setToken sets the token for the given userID with the given expiration time.
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).Info("Updating refresh token")
logrus.WithField("userID", userID).Info("Updating token")
cm.tokens[userID] = token
cm.setTokenExpiration(userID, expiration)
}
// setTokenExpiration will ensure the token is refreshed if it expires.
// If the token already has an expiration time set, it is replaced.
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
if exp, ok := cm.tokenExpirations[userID]; ok {
exp.timer.Stop()
close(exp.cancel)
}
cm.tokenExpirations[userID] = &tokenExpiration{
timer: time.NewTimer(expiration),
cancel: make(chan struct{}),
}
go cm.watchTokenExpiration(userID)
}
func (cm *ClientManager) clearToken(userID string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).Info("Clearing refresh token")
logrus.WithField("userID", userID).Info("Clearing token")
delete(cm.tokens, userID)
}
@ -203,6 +247,7 @@ func (cm *ClientManager) clearToken(userID string) {
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
// If we aren't managing this client, there's nothing to do.
if _, ok := cm.clients[ca.UserID]; !ok {
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
return
}
@ -213,5 +258,18 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
return
}
cm.setToken(ca.UserID, ca.Auth.GenToken())
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
}
func (cm *ClientManager) watchTokenExpiration(userID string) {
expiration := cm.tokenExpirations[userID]
select {
case <-expiration.timer.C:
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing")
cm.clients[userID].AuthRefresh(cm.tokens[userID])
case <-expiration.cancel:
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired, cancelling this watcher")
}
}