Files
proton-bridge/pkg/pmapi/clientmanager.go
2020-04-21 08:36:39 +00:00

441 lines
11 KiB
Go

package pmapi
import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// ClientManager keeps track of one API client per user ID together with the
// access tokens, token-expiration timers and host settings those clients share.
type ClientManager struct {
	// newClient builds the Client for a user ID. NewClientManager installs a
	// constructor for pmapi clients; SetClientConstructor can replace it to
	// create other types of clients (e.g. for integration tests).
	newClient func(userID string) Client

	config       *ClientConfig
	roundTripper http.RoundTripper

	// clients maps user IDs to their clients; clientsLocker is the lock taken around it.
	clients       map[string]Client
	clientsLocker sync.Locker

	// tokens maps user IDs to access tokens; tokensLocker is the lock taken around it.
	tokens       map[string]string
	tokensLocker sync.Locker

	// expirations holds the pending expiration timer of each user's token.
	// User IDs whose timer fired are announced on expiredTokens.
	expirations       map[string]*tokenExpiration
	expiredTokens     chan string
	expirationsLocker sync.Locker

	// clientAuths receives auths from clients; after the manager has handled
	// them they are forwarded on bridgeAuths.
	bridgeAuths chan ClientAuth
	clientAuths chan ClientAuth

	// host and scheme together form the root URL; hostLocker is taken around
	// host, scheme and allowProxy.
	host       string
	scheme     string
	hostLocker sync.RWMutex

	allowProxy       bool
	proxyProvider    *proxyProvider
	proxyUseDuration time.Duration

	// idGen numbers the user IDs handed to anonymous clients.
	idGen idGen

	log *logrus.Entry
}
// idGen is a simple counter used to hand out sequential integer IDs.
type idGen int

// next increments the counter and returns its new value, so the first call on
// a zero-valued idGen yields 1. It performs a plain, unsynchronized increment.
func (i *idGen) next() int {
	*i++
	return int(*i)
}
// ClientAuth holds an API auth produced by a Client for a specific user.
// handleClientAuth treats a nil Auth as a request to clear the user's token
// and log the client out.
type ClientAuth struct {
// UserID identifies the user (and therefore the managed client) the auth belongs to.
UserID string
// Auth is the auth itself; nil signals that the authorisation should be cleared.
Auth *Auth
}
// tokenExpiration tracks the pending expiration of a single access token.
type tokenExpiration struct {
	// timer fires when the token is due for a refresh.
	timer *time.Timer
	// cancel is closed when the expiration is superseded before the timer fires.
	cancel chan struct{}
}
// NewClientManager creates a ClientManager which manages clients configured
// with the given client config.
//
// The manager starts out pointing at RootURL/rootScheme, uses
// http.DefaultTransport and builds regular pmapi clients. Two background
// goroutines are started: one forwarding client auths to the bridge auth
// channel and one watching for expired tokens.
func NewClientManager(config *ClientConfig) (cm *ClientManager) {
	cm = &ClientManager{
		config:       config,
		roundTripper: http.DefaultTransport,

		clients:       make(map[string]Client),
		clientsLocker: new(sync.Mutex),

		tokens:       make(map[string]string),
		tokensLocker: new(sync.Mutex),

		expirations:       make(map[string]*tokenExpiration),
		expiredTokens:     make(chan string),
		expirationsLocker: new(sync.Mutex),

		// hostLocker is left at its zero value, which is an unlocked RWMutex.
		host:   RootURL,
		scheme: rootScheme,

		bridgeAuths: make(chan ClientAuth),
		clientAuths: make(chan ClientAuth),

		proxyProvider:    newProxyProvider(dohProviders, proxyQuery),
		proxyUseDuration: proxyUseDuration,

		log: logrus.WithField("pkg", "pmapi-manager"),
	}

	// The default constructor needs the manager itself, so it can only be
	// installed once cm exists.
	cm.newClient = func(userID string) Client {
		return newClient(cm, userID)
	}

	go cm.forwardClientAuths()
	go cm.watchTokenExpirations()

	return cm
}
// SetClientConstructor replaces the function used to build new clients, e.g.
// to create other types of clients in integration tests. Clients that already
// exist are not affected; the field is assigned without taking any lock.
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
	cm.newClient = f
}
// SetRoundTripper sets the round tripper used by clients created by this
// client manager. CheckConnection also builds its HTTP client from it.
// The field is assigned without taking any lock.
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
	cm.roundTripper = rt
}
// GetClient returns the client registered for userID, creating and
// registering one with the configured constructor if none exists yet.
func (cm *ClientManager) GetClient(userID string) Client {
	cm.clientsLocker.Lock()
	defer cm.clientsLocker.Unlock()

	client, exists := cm.clients[userID]
	if !exists {
		client = cm.newClient(userID)
		cm.clients[userID] = client
	}

	return client
}
// GetAnonymousClient creates and returns a new anonymous client. Every call
// registers a fresh client under a unique "anonymous-<n>" user ID; anonymous
// clients created earlier are left untouched.
func (cm *ClientManager) GetAnonymousClient() Client {
	// The ID counter is a plain int, so it is advanced while holding
	// clientsLocker. Without this, concurrent callers could race on the
	// counter, obtain the same ID and end up sharing one client.
	// GetClient cannot be called from here because it takes the same lock.
	cm.clientsLocker.Lock()
	defer cm.clientsLocker.Unlock()

	userID := fmt.Sprintf("anonymous-%v", cm.idGen.next())

	// Mirror GetClient: reuse an existing entry in the unlikely case that
	// this ID was already registered explicitly.
	if client, ok := cm.clients[userID]; ok {
		return client
	}

	client := cm.newClient(userID)
	cm.clients[userID] = client

	return client
}
// LogoutClient stops managing the client with the given userID and logs it out
// in the background. It does nothing if no such client is managed.
//
// The background goroutine deletes the auth of non-anonymous clients, retrying
// for as long as the API is reported unreachable, and finally clears the
// user's token followed by the client's data.
func (cm *ClientManager) LogoutClient(userID string) {
	cm.clientsLocker.Lock()
	defer cm.clientsLocker.Unlock()

	client, managed := cm.clients[userID]
	if !managed {
		return
	}

	delete(cm.clients, userID)

	go func() {
		// Deferred calls run in reverse order: the token is cleared first,
		// then the client's data.
		defer client.ClearData()
		defer cm.clearToken(userID)

		// Anonymous clients have no auth to delete.
		if !strings.HasPrefix(userID, "anonymous-") {
			for client.DeleteAuth() == ErrAPINotReachable {
				cm.log.Warn("Logging out client failed because API was not reachable, retrying...")
			}
		}
	}()
}
// GetRootURL returns the full root URL, i.e. the current scheme and host
// joined by "://".
func (cm *ClientManager) GetRootURL() string {
	cm.hostLocker.RLock()
	defer cm.hostLocker.RUnlock()

	return cm.scheme + "://" + cm.host
}
// getHost returns the host requests are currently made to.
// The scheme (e.g. "https://") is not part of the returned value.
func (cm *ClientManager) getHost() string {
	cm.hostLocker.RLock()
	defer cm.hostLocker.RUnlock()

	host := cm.host

	return host
}
// IsProxyAllowed reports whether switching to a proxy has been allowed
// (see AllowProxy and DisallowProxy).
func (cm *ClientManager) IsProxyAllowed() bool {
	cm.hostLocker.RLock()
	defer cm.hostLocker.RUnlock()

	allowed := cm.allowProxy

	return allowed
}
// AllowProxy records that the client manager may switch clients over to a
// proxy if need be. It does not itself change the current host.
func (cm *ClientManager) AllowProxy() {
	cm.hostLocker.Lock()
	cm.allowProxy = true
	cm.hostLocker.Unlock()
}
// DisallowProxy forbids switching clients over to a proxy and immediately
// points the manager back at RootURL.
func (cm *ClientManager) DisallowProxy() {
	cm.hostLocker.Lock()
	defer cm.hostLocker.Unlock()

	// Both fields change under the same write lock, so readers never observe
	// one without the other.
	cm.host = RootURL
	cm.allowProxy = false
}
// IsProxyEnabled reports whether requests are currently being proxied, which
// is the case whenever the current host differs from RootURL.
func (cm *ClientManager) IsProxyEnabled() bool {
	cm.hostLocker.RLock()
	defer cm.hostLocker.RUnlock()

	proxying := cm.host != RootURL

	return proxying
}
// switchToReachableServer asks the proxy provider for a reachable server
// (either a proxy or the standard API) and makes it the current host.
//
// It returns the chosen host. If no reachable server is found, the host is
// left unchanged and the wrapped provider error is returned.
//
// hostLocker is held for the whole call, including the provider lookup.
func (cm *ClientManager) switchToReachableServer() (proxy string, err error) {
	cm.hostLocker.Lock()
	defer cm.hostLocker.Unlock()

	logrus.Info("Attempting to switch to a proxy")

	if proxy, err = cm.proxyProvider.findReachableServer(); err != nil {
		return proxy, errors.Wrap(err, "failed to find a usable proxy")
	}

	logrus.WithField("proxy", proxy).Info("Switching to a proxy")

	// If the host is currently the RootURL, it's the first time we are enabling a proxy.
	// This means we want to switch back to RootURL once proxyUseDuration has elapsed.
	if cm.host == RootURL {
		useDuration := cm.proxyUseDuration

		go func() {
			<-time.After(useDuration)

			// The host is shared with every reader and writer that goes
			// through hostLocker, so the reset must take the lock as well.
			cm.hostLocker.Lock()
			defer cm.hostLocker.Unlock()

			cm.host = RootURL
		}()
	}

	cm.host = proxy

	return proxy, nil
}
// GetToken returns the token stored for the given userID, or the empty string
// if none is stored.
func (cm *ClientManager) GetToken(userID string) string {
	cm.tokensLocker.Lock()
	defer cm.tokensLocker.Unlock()

	token := cm.tokens[userID]

	return token
}
// GetAuthUpdateChannel returns the channel on which client auths can be
// received once the manager has handled them (see forwardClientAuths).
func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth {
	updates := cm.bridgeAuths

	return updates
}
// GetClientAuthChannel returns the channel on which clients should send their
// auths; the manager consumes it in forwardClientAuths.
func (cm *ClientManager) GetClientAuthChannel() chan ClientAuth {
	incoming := cm.clientAuths

	return incoming
}
// Errors for possible connection issues, as reported by CheckConnection.
var (
// ErrNoInternetConnection is returned when neither protonstatus nor the API could be reached.
ErrNoInternetConnection = errors.New("no internet connection")
// ErrCanNotReachAPI is returned when protonstatus is reachable but the API is not.
ErrCanNotReachAPI = errors.New("can not reach PM API")
)
// CheckConnection probes protonstatus.com and the API in parallel and reports
// connectivity problems:
//
//   - nil if the API is reachable (a warning is logged if protonstatus is not),
//   - ErrCanNotReachAPI if protonstatus is reachable but the API is not,
//   - ErrNoInternetConnection if neither is reachable.
//
// This should be moved to the ConnectionManager when it is implemented.
func (cm *ClientManager) CheckConnection() error {
	client := getHTTPClient(cm.config, cm.roundTripper)

	// Run both probes concurrently so that their timeouts do not add up.
	statusResult := make(chan error)
	apiResult := make(chan error)

	// Check protonstatus.com without SSL for performance reasons. vpn_status endpoint is fast and
	// returns only OK; this endpoint is not known by the public. We check the connection only.
	go checkConnection(client, "http://protonstatus.com/vpn_status", statusResult)

	// Check of API reachability also uses a fast endpoint.
	go checkConnection(client, cm.GetRootURL()+"/tests/ping", apiResult)

	statusReachable := <-statusResult == nil
	apiReachable := <-apiResult == nil

	switch {
	case statusReachable && apiReachable:
		return nil

	case statusReachable:
		cm.log.Error("ProtonStatus is reachable but API is not")
		return ErrCanNotReachAPI

	case apiReachable:
		cm.log.Warn("API is reachable but protonstatus is not")
		return nil

	default:
		cm.log.Error("Both ProtonStatus and API are unreachable")
		return ErrNoInternetConnection
	}
}
// checkConnection performs a GET request against url and sends exactly one
// value on errorChannel: the request error if the request failed, an
// "HTTP status code N" error if the response status is not 200, and nil
// otherwise. The response body is closed without being read.
func checkConnection(client *http.Client, url string, errorChannel chan error) {
	var result error

	resp, err := client.Get(url)

	switch {
	case err != nil:
		result = err

	default:
		// Only the status matters; a failure to close the body is deliberately ignored.
		_ = resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			result = fmt.Errorf("HTTP status code %d", resp.StatusCode)
		}
	}

	errorChannel <- result
}
// forwardClientAuths consumes the client auth channel until it is closed.
// Each auth is first handled by the manager itself and then passed on, in
// order, on the bridge auth channel.
func (cm *ClientManager) forwardClientAuths() {
	for {
		auth, open := <-cm.clientAuths
		if !open {
			return
		}

		logrus.Debug("ClientManager received auth from client")
		cm.handleClientAuth(auth)

		logrus.Debug("ClientManager is forwarding auth to bridge")
		cm.bridgeAuths <- auth
	}
}
// SetTokenIfUnset stores token for userID unless a token is already stored for
// that user. No expiration is scheduled for a token set this way.
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
	cm.tokensLocker.Lock()
	defer cm.tokensLocker.Unlock()

	if _, alreadySet := cm.tokens[userID]; !alreadySet {
		logrus.WithField("userID", userID).Info("Setting token because it is currently unset")
		cm.tokens[userID] = token
	}
}
// setToken stores token for userID, overwriting any previous one, and
// (re)schedules its expiration after the given duration.
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
	cm.tokensLocker.Lock()
	defer cm.tokensLocker.Unlock()

	logrus.WithField("userID", userID).Info("Updating token")

	cm.tokens[userID] = token

	// The expiration is scheduled while tokensLocker is still held.
	cm.setTokenExpiration(userID, expiration)
}
// setTokenExpiration schedules a refresh of the user's token shortly before it
// expires. An expiration already scheduled for the user is cancelled and
// replaced.
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
	cm.expirationsLocker.Lock()
	defer cm.expirationsLocker.Unlock()

	// Reduce the expiration by one minute so we can do the refresh with enough time to spare.
	expiration -= time.Minute

	// Cancel whatever was scheduled before: stop its timer and release the
	// goroutine waiting on it.
	if previous, scheduled := cm.expirations[userID]; scheduled {
		previous.timer.Stop()
		close(previous.cancel)
	}

	current := &tokenExpiration{
		timer:  time.NewTimer(expiration),
		cancel: make(chan struct{}),
	}
	cm.expirations[userID] = current

	// Wait for either the timer or a cancellation, whichever comes first.
	go func() {
		select {
		case <-current.timer.C:
			cm.expiredTokens <- userID

		case <-current.cancel:
			logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
		}
	}()
}
// clearToken removes the token stored for userID, if any.
// A scheduled expiration for that user is not touched.
func (cm *ClientManager) clearToken(userID string) {
	cm.tokensLocker.Lock()
	defer cm.tokensLocker.Unlock()

	logrus.WithField("userID", userID).Info("Clearing token")

	delete(cm.tokens, userID)
}
// handleClientAuth updates or clears a client's authorisation based on an
// auth received from it:
//
//   - auths for clients that are not managed are ignored,
//   - a nil auth clears the user's token and logs the client out,
//   - any other auth replaces the user's token and schedules its expiration.
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
	cm.clientsLocker.Lock()
	defer cm.clientsLocker.Unlock()

	_, managed := cm.clients[ca.UserID]

	switch {
	case !managed:
		// If we aren't managing this client, there's nothing to do.
		logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client")

	case ca.Auth == nil:
		// If the auth is nil, we should clear the token.
		cm.clearToken(ca.UserID)

		// LogoutClient takes clientsLocker, which is held here, so it has to
		// run in its own goroutine.
		go cm.LogoutClient(ca.UserID)

	default:
		lifetime := time.Duration(ca.Auth.ExpiresIn) * time.Second
		cm.setToken(ca.UserID, ca.Auth.GenToken(), lifetime)
	}
}
// watchTokenExpirations refreshes any tokens which are about to expire.
//
// It consumes user IDs from the expired-tokens channel until that channel is
// closed. For each ID it looks up the user's client and current token and asks
// the client to refresh the auth. Expirations for users without a client or
// token are logged and skipped; refresh failures are logged.
func (cm *ClientManager) watchTokenExpirations() {
	for userID := range cm.expiredTokens {
		log := cm.log.WithField("userID", userID)

		log.Info("Auth token expired! Refreshing")

		// The clients and tokens maps are mutated by other goroutines under
		// their respective locks, so they must be read under those locks too.
		// Each lock is released right after the lookup so that neither is
		// held while the refresh is in flight.
		cm.clientsLocker.Lock()
		client, ok := cm.clients[userID]
		cm.clientsLocker.Unlock()

		if !ok {
			log.Warn("Can't refresh expired token because there is no such client")
			continue
		}

		cm.tokensLocker.Lock()
		token, ok := cm.tokens[userID]
		cm.tokensLocker.Unlock()

		if !ok {
			log.Warn("Can't refresh expired token because there is no such token")
			continue
		}

		if _, err := client.AuthRefresh(token); err != nil {
			log.WithError(err).Error("Failed to refresh expired token")
		}
	}
}