mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-17 23:56:56 +00:00
feat: improve login flow
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@ -10,27 +11,31 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var proxyUseDuration = 24 * time.Hour
|
||||
var defaultProxyUseDuration = 24 * time.Hour
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
config *ClientConfig
|
||||
config *ClientConfig
|
||||
roundTripper http.RoundTripper
|
||||
|
||||
clients map[string]*Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
tokens map[string]string
|
||||
tokenExpirations map[string]*tokenExpiration
|
||||
tokensLocker sync.Locker
|
||||
tokens map[string]string
|
||||
tokensLocker sync.Locker
|
||||
|
||||
url string
|
||||
urlLocker sync.Locker
|
||||
expirations map[string]*tokenExpiration
|
||||
expirationsLocker sync.Locker
|
||||
|
||||
host, scheme string
|
||||
hostLocker sync.Locker
|
||||
|
||||
bridgeAuths chan ClientAuth
|
||||
clientAuths chan ClientAuth
|
||||
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
proxyUseDuration time.Duration
|
||||
}
|
||||
|
||||
type ClientAuth struct {
|
||||
@ -50,22 +55,27 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
}
|
||||
|
||||
cm = &ClientManager{
|
||||
config: config,
|
||||
config: config,
|
||||
roundTripper: http.DefaultTransport,
|
||||
|
||||
clients: make(map[string]*Client),
|
||||
clientsLocker: &sync.Mutex{},
|
||||
|
||||
tokens: make(map[string]string),
|
||||
tokenExpirations: make(map[string]*tokenExpiration),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
tokens: make(map[string]string),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
|
||||
url: RootURL,
|
||||
urlLocker: &sync.Mutex{},
|
||||
expirations: make(map[string]*tokenExpiration),
|
||||
expirationsLocker: &sync.Mutex{},
|
||||
|
||||
host: RootURL,
|
||||
scheme: RootScheme,
|
||||
hostLocker: &sync.Mutex{},
|
||||
|
||||
bridgeAuths: make(chan ClientAuth),
|
||||
clientAuths: make(chan ClientAuth),
|
||||
|
||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||
proxyUseDuration: defaultProxyUseDuration,
|
||||
}
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
@ -73,10 +83,14 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
return
|
||||
}
|
||||
|
||||
// SetClientRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||
func (cm *ClientManager) SetClientRoundTripper(rt http.RoundTripper) {
|
||||
logrus.Info("Setting client roundtripper")
|
||||
cm.config.Transport = rt
|
||||
// SetRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
|
||||
cm.roundTripper = rt
|
||||
}
|
||||
|
||||
// GetRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||
func (cm *ClientManager) GetRoundTripper() (rt http.RoundTripper) {
|
||||
return cm.roundTripper
|
||||
}
|
||||
|
||||
// GetClient returns a client for the given userID.
|
||||
@ -91,6 +105,17 @@ func (cm *ClientManager) GetClient(userID string) *Client {
|
||||
return cm.clients[userID]
|
||||
}
|
||||
|
||||
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
|
||||
func (cm *ClientManager) GetAnonymousClient() *Client {
|
||||
if client, ok := cm.clients[""]; ok {
|
||||
client.Logout()
|
||||
}
|
||||
|
||||
cm.clients[""] = newClient(cm, "")
|
||||
|
||||
return cm.clients[""]
|
||||
}
|
||||
|
||||
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||
func (cm *ClientManager) LogoutClient(userID string) {
|
||||
client, ok := cm.clients[userID]
|
||||
@ -104,7 +129,6 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
||||
go func() {
|
||||
if err := client.logout(); err != nil {
|
||||
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
|
||||
logrus.WithError(err).Error("Client logout failed, not trying again")
|
||||
}
|
||||
client.clearSensitiveData()
|
||||
cm.clearToken(userID)
|
||||
@ -113,52 +137,69 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
||||
return
|
||||
}
|
||||
|
||||
// GetRootURL returns the root URL to make requests to.
|
||||
// It does not include the protocol i.e. no "https://".
|
||||
func (cm *ClientManager) GetRootURL() string {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
// GetHost returns the host to make requests to.
|
||||
// It does not include the protocol i.e. no "https://" (use GetScheme for that).
|
||||
func (cm *ClientManager) GetHost() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
return cm.url
|
||||
return cm.host
|
||||
}
|
||||
|
||||
// GetScheme returns the scheme with which to make requests to the host.
|
||||
func (cm *ClientManager) GetScheme() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
return cm.scheme
|
||||
}
|
||||
|
||||
// GetRootURL returns the full root URL (scheme+host).
|
||||
func (cm *ClientManager) GetRootURL() string {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
|
||||
}
|
||||
|
||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
return cm.allowProxy
|
||||
}
|
||||
|
||||
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
||||
func (cm *ClientManager) AllowProxy() {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
cm.allowProxy = true
|
||||
}
|
||||
|
||||
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
||||
func (cm *ClientManager) DisallowProxy() {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
cm.allowProxy = false
|
||||
cm.url = RootURL
|
||||
cm.host = RootURL
|
||||
}
|
||||
|
||||
// IsProxyEnabled returns whether we are currently proxying requests.
|
||||
func (cm *ClientManager) IsProxyEnabled() bool {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
return cm.url != RootURL
|
||||
return cm.host != RootURL
|
||||
}
|
||||
|
||||
// FindProxy returns a usable proxy server.
|
||||
// SwitchToProxy returns a usable proxy server.
|
||||
// TODO: Perhaps the name could be better -- we aren't only switching to a proxy but also to the standard API.
|
||||
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
logrus.Info("Attempting to switch to a proxy")
|
||||
|
||||
@ -169,9 +210,16 @@ func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
||||
|
||||
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
|
||||
|
||||
cm.url = proxy
|
||||
// If the host is currently the RootURL, it's the first time we are enabling a proxy.
|
||||
// This means we want to disable it again in 24 hours.
|
||||
if cm.host == RootURL {
|
||||
go func() {
|
||||
<-time.After(cm.proxyUseDuration)
|
||||
cm.host = RootURL
|
||||
}()
|
||||
}
|
||||
|
||||
// TODO: Disable again after 24 hours.
|
||||
cm.host = proxy
|
||||
|
||||
return
|
||||
}
|
||||
@ -183,6 +231,9 @@ func (cm *ClientManager) GetConfig() *ClientConfig {
|
||||
|
||||
// GetToken returns the token for the given userID.
|
||||
func (cm *ClientManager) GetToken(userID string) string {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
return cm.tokens[userID]
|
||||
}
|
||||
|
||||
@ -208,6 +259,11 @@ func (cm *ClientManager) forwardClientAuths() {
|
||||
|
||||
// setToken sets the token for the given userID with the given expiration time.
|
||||
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||
// We don't want to set tokens of anonymous clients.
|
||||
if userID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
@ -221,12 +277,15 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
|
||||
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||
// If the token already has an expiration time set, it is replaced.
|
||||
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
|
||||
if exp, ok := cm.tokenExpirations[userID]; ok {
|
||||
cm.expirationsLocker.Lock()
|
||||
defer cm.expirationsLocker.Unlock()
|
||||
|
||||
if exp, ok := cm.expirations[userID]; ok {
|
||||
exp.timer.Stop()
|
||||
close(exp.cancel)
|
||||
}
|
||||
|
||||
cm.tokenExpirations[userID] = &tokenExpiration{
|
||||
cm.expirations[userID] = &tokenExpiration{
|
||||
timer: time.NewTimer(expiration),
|
||||
cancel: make(chan struct{}),
|
||||
}
|
||||
@ -262,7 +321,7 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
}
|
||||
|
||||
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||
expiration := cm.tokenExpirations[userID]
|
||||
expiration := cm.expirations[userID]
|
||||
|
||||
select {
|
||||
case <-expiration.timer.C:
|
||||
@ -270,6 +329,6 @@ func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||
cm.clients[userID].AuthRefresh(cm.tokens[userID])
|
||||
|
||||
case <-expiration.cancel:
|
||||
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired, cancelling this watcher")
|
||||
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user