mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
feat: implement token expiration watcher
This commit is contained in:
@ -24,7 +24,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pmcrypto "github.com/ProtonMail/gopenpgp/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/srp"
|
||||
@ -197,15 +196,19 @@ type AuthRefreshReq struct {
|
||||
}
|
||||
|
||||
func (c *Client) sendAuth(auth *Auth) {
|
||||
c.cm.getClientAuthChannel() <- ClientAuth{
|
||||
UserID: c.userID,
|
||||
Auth: auth,
|
||||
}
|
||||
go func() {
|
||||
c.log.Debug("Client is sending auth to ClientManager")
|
||||
|
||||
if auth != nil {
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
}
|
||||
c.cm.getClientAuthChannel() <- ClientAuth{
|
||||
UserID: c.userID,
|
||||
Auth: auth,
|
||||
}
|
||||
|
||||
if auth != nil {
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// AuthInfo gets authentication info for a user.
|
||||
@ -324,7 +327,6 @@ func (c *Client) Auth(username, password string, info *AuthInfo) (auth *Auth, er
|
||||
}
|
||||
}
|
||||
|
||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
||||
return auth, err
|
||||
}
|
||||
|
||||
@ -445,7 +447,6 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
||||
|
||||
auth = res.getAuth()
|
||||
c.sendAuth(auth)
|
||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
||||
|
||||
return auth, err
|
||||
}
|
||||
|
||||
@ -81,6 +81,7 @@ type ClientConfig struct {
|
||||
|
||||
// Transport specifies the mechanism by which individual HTTP requests are made.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
// TODO: This could be removed entirely and set in the client manager via SetClientRoundTripper.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// Timeout specifies the timeout from request to getting response headers to our API.
|
||||
@ -108,7 +109,6 @@ type Client struct {
|
||||
requestLocker sync.Locker
|
||||
keyLocker sync.Locker
|
||||
|
||||
expiresAt time.Time
|
||||
user *User
|
||||
addresses AddressList
|
||||
kr *pmcrypto.KeyRing
|
||||
|
||||
@ -3,12 +3,15 @@ package pmapi
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/raven-go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var proxyUseDuration = 24 * time.Hour
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
config *ClientConfig
|
||||
@ -16,8 +19,9 @@ type ClientManager struct {
|
||||
clients map[string]*Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
tokens map[string]string
|
||||
tokensLocker sync.Locker
|
||||
tokens map[string]string
|
||||
tokenExpirations map[string]*tokenExpiration
|
||||
tokensLocker sync.Locker
|
||||
|
||||
url string
|
||||
urlLocker sync.Locker
|
||||
@ -34,6 +38,11 @@ type ClientAuth struct {
|
||||
Auth *Auth
|
||||
}
|
||||
|
||||
type tokenExpiration struct {
|
||||
timer *time.Timer
|
||||
cancel chan (struct{})
|
||||
}
|
||||
|
||||
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
||||
func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
||||
@ -46,8 +55,9 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
clients: make(map[string]*Client),
|
||||
clientsLocker: &sync.Mutex{},
|
||||
|
||||
tokens: make(map[string]string),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
tokens: make(map[string]string),
|
||||
tokenExpirations: make(map[string]*tokenExpiration),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
|
||||
url: RootURL,
|
||||
urlLocker: &sync.Mutex{},
|
||||
@ -112,43 +122,56 @@ func (cm *ClientManager) GetRootURL() string {
|
||||
return cm.url
|
||||
}
|
||||
|
||||
// SetRootURL sets the root URL to make requests to.
|
||||
func (cm *ClientManager) SetRootURL(url string) {
|
||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
|
||||
logrus.WithField("url", url).Info("Changing to a new root URL")
|
||||
|
||||
cm.url = url
|
||||
}
|
||||
|
||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||
return cm.allowProxy
|
||||
}
|
||||
|
||||
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
||||
func (cm *ClientManager) AllowProxy() {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
|
||||
cm.allowProxy = true
|
||||
}
|
||||
|
||||
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
||||
func (cm *ClientManager) DisallowProxy() {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
|
||||
cm.allowProxy = false
|
||||
cm.url = RootURL
|
||||
}
|
||||
|
||||
// IsProxyEnabled returns whether we are currently proxying requests.
|
||||
func (cm *ClientManager) IsProxyEnabled() bool {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
|
||||
return cm.url != RootURL
|
||||
}
|
||||
|
||||
// FindProxy returns a usable proxy server.
|
||||
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
||||
logrus.Info("Attempting gto switch to a proxy")
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
|
||||
logrus.Info("Attempting to switch to a proxy")
|
||||
|
||||
if proxy, err = cm.proxyProvider.findProxy(); err != nil {
|
||||
err = errors.Wrap(err, "failed to find usable proxy")
|
||||
err = errors.Wrap(err, "failed to find a usable proxy")
|
||||
return
|
||||
}
|
||||
|
||||
cm.SetRootURL(proxy)
|
||||
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
|
||||
|
||||
// TODO: Disable after 24 hours.
|
||||
cm.url = proxy
|
||||
|
||||
// TODO: Disable again after 24 hours.
|
||||
|
||||
return
|
||||
}
|
||||
@ -165,7 +188,7 @@ func (cm *ClientManager) GetToken(userID string) string {
|
||||
|
||||
// GetBridgeAuthChannel returns a channel on which client auths can be received.
|
||||
func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth {
|
||||
return cm.clientAuths
|
||||
return cm.bridgeAuths
|
||||
}
|
||||
|
||||
// getClientAuthChannel returns a channel on which clients should send auths.
|
||||
@ -176,25 +199,46 @@ func (cm *ClientManager) getClientAuthChannel() chan ClientAuth {
|
||||
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel.
|
||||
func (cm *ClientManager) forwardClientAuths() {
|
||||
for auth := range cm.clientAuths {
|
||||
logrus.Debug("ClientManager received auth from client")
|
||||
cm.handleClientAuth(auth)
|
||||
logrus.Debug("ClientManager is forwarding auth to bridge")
|
||||
cm.bridgeAuths <- auth
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ClientManager) setToken(userID, token string) {
|
||||
// setToken sets the token for the given userID with the given expiration time.
|
||||
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Info("Updating refresh token")
|
||||
logrus.WithField("userID", userID).Info("Updating token")
|
||||
|
||||
cm.tokens[userID] = token
|
||||
|
||||
cm.setTokenExpiration(userID, expiration)
|
||||
}
|
||||
|
||||
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||
// If the token already has an expiration time set, it is replaced.
|
||||
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
|
||||
if exp, ok := cm.tokenExpirations[userID]; ok {
|
||||
exp.timer.Stop()
|
||||
close(exp.cancel)
|
||||
}
|
||||
|
||||
cm.tokenExpirations[userID] = &tokenExpiration{
|
||||
timer: time.NewTimer(expiration),
|
||||
cancel: make(chan struct{}),
|
||||
}
|
||||
|
||||
go cm.watchTokenExpiration(userID)
|
||||
}
|
||||
|
||||
func (cm *ClientManager) clearToken(userID string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Info("Clearing refresh token")
|
||||
logrus.WithField("userID", userID).Info("Clearing token")
|
||||
|
||||
delete(cm.tokens, userID)
|
||||
}
|
||||
@ -203,6 +247,7 @@ func (cm *ClientManager) clearToken(userID string) {
|
||||
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
// If we aren't managing this client, there's nothing to do.
|
||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
|
||||
return
|
||||
}
|
||||
|
||||
@ -213,5 +258,18 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
return
|
||||
}
|
||||
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken())
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
|
||||
}
|
||||
|
||||
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||
expiration := cm.tokenExpirations[userID]
|
||||
|
||||
select {
|
||||
case <-expiration.timer.C:
|
||||
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
||||
cm.clients[userID].AuthRefresh(cm.tokens[userID])
|
||||
|
||||
case <-expiration.cancel:
|
||||
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired, cancelling this watcher")
|
||||
}
|
||||
}
|
||||
|
||||
@ -282,13 +282,15 @@ func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (c
|
||||
func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) {
|
||||
p.log.Info("Dialing with proxy fallback")
|
||||
|
||||
var host, port string
|
||||
if host, port, err = net.SplitHostPort(address); err != nil {
|
||||
// Try to dial, and if it succeeds, then just return.
|
||||
if conn, err = p.dial(network, address); err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to dial, and if it succeeds, then just return.
|
||||
if conn, err = p.dial(network, address); err == nil {
|
||||
p.log.WithField("address", address).WithError(err).Error("Dialing failed")
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -296,18 +298,21 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn
|
||||
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in
|
||||
// continuing since a proxy won't help us reach that.
|
||||
if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() {
|
||||
p.log.WithField("useProxy", p.cm.IsProxyAllowed()).Info("Dial failed but not switching to proxy")
|
||||
p.log.WithField("address", address).Debug("Aborting dial, cannot switch to a proxy")
|
||||
return
|
||||
}
|
||||
|
||||
// Switch to a proxy and retry the dial.
|
||||
var proxy string
|
||||
|
||||
if proxy, err = p.cm.SwitchToProxy(); err != nil {
|
||||
proxy, err := p.cm.SwitchToProxy()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return p.dial(network, net.JoinHostPort(proxy, port))
|
||||
proxyAddress := net.JoinHostPort(proxy, port)
|
||||
|
||||
p.log.WithField("address", proxyAddress).Debug("Trying dial again using a proxy")
|
||||
|
||||
return p.dial(network, proxyAddress)
|
||||
}
|
||||
|
||||
// dial returns a connection to the given address using the given network.
|
||||
|
||||
@ -51,7 +51,6 @@ type proxyProvider struct {
|
||||
query string // The query string used to find proxies.
|
||||
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
|
||||
|
||||
useDuration time.Duration // How much time to use the proxy before returning to the original API.
|
||||
findTimeout, lookupTimeout time.Duration // Timeouts for DNS query and proxy search.
|
||||
|
||||
lastLookup time.Time // The time at which we last attempted to find a proxy.
|
||||
@ -63,7 +62,6 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { //
|
||||
p = &proxyProvider{
|
||||
providers: providers,
|
||||
query: query,
|
||||
useDuration: proxyRevertTime,
|
||||
findTimeout: proxySearchTimeout,
|
||||
lookupTimeout: proxyQueryTimeout,
|
||||
}
|
||||
@ -148,6 +146,7 @@ func (p *proxyProvider) canReach(url string) bool {
|
||||
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec]
|
||||
|
||||
if _, err := pinger.R().Get("/tests/ping"); err != nil {
|
||||
logrus.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@ -109,7 +109,6 @@ func (c *Client) UpdateUser() (user *User, err error) {
|
||||
}
|
||||
|
||||
c.user = user
|
||||
c.log.Infoln("update user:", user.ID)
|
||||
raven.SetUserContext(&raven.User{ID: user.ID})
|
||||
|
||||
var tmpList AddressList
|
||||
@ -117,6 +116,8 @@ func (c *Client) UpdateUser() (user *User, err error) {
|
||||
c.addresses = tmpList
|
||||
}
|
||||
|
||||
c.log.WithField("userID", user.ID).Info("Updated user")
|
||||
|
||||
return user, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user