feat: refresh expired access tokens in one goroutine

This commit is contained in:
James Houlahan
2020-04-16 14:44:59 +02:00
parent 40e96b9d1e
commit 3f32fd95e0
6 changed files with 56 additions and 34 deletions

View File

@ -5,10 +5,11 @@
package mocks package mocks
import ( import (
reflect "reflect"
credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials" credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
) )
// MockConfiger is a mock of Configer interface // MockConfiger is a mock of Configer interface

View File

@ -528,7 +528,6 @@ func (u *User) Logout() (err error) {
u.wasKeyringUnlocked = false u.wasKeyringUnlocked = false
u.unlockingKeyringLock.Unlock() u.unlockingKeyringLock.Unlock()
// TODO: Is this necessary or could it be done by ClientManager when a nil auth is received?
u.client().Logout() u.client().Logout()
if err = u.credStorer.Logout(u.userID); err != nil { if err = u.credStorer.Logout(u.userID); err != nil {

View File

@ -74,11 +74,8 @@ func (l *listener) Add(eventName string, channel chan<- string) {
if l.channels == nil { if l.channels == nil {
l.channels = make(map[string][]chan<- string) l.channels = make(map[string][]chan<- string)
} }
if _, ok := l.channels[eventName]; ok {
l.channels[eventName] = append(l.channels[eventName], channel) l.channels[eventName] = append(l.channels[eventName], channel)
} else {
l.channels[eventName] = []chan<- string{channel}
}
} }
// Remove removes an event listener. // Remove removes an event listener.

View File

@ -470,7 +470,7 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
return auth, err return auth, err
} }
// TODO: Should this even be a client method? Or just a method on the client manager? // Logout instructs the client manager to log this client out.
func (c *client) Logout() { func (c *client) Logout() {
c.cm.LogoutClient(c.userID) c.cm.LogoutClient(c.userID)
} }

View File

@ -27,6 +27,7 @@ type ClientManager struct {
tokensLocker sync.Locker tokensLocker sync.Locker
expirations map[string]*tokenExpiration expirations map[string]*tokenExpiration
expiredTokens chan string
expirationsLocker sync.Locker expirationsLocker sync.Locker
bridgeAuths chan ClientAuth bridgeAuths chan ClientAuth
@ -76,6 +77,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
tokensLocker: &sync.Mutex{}, tokensLocker: &sync.Mutex{},
expirations: make(map[string]*tokenExpiration), expirations: make(map[string]*tokenExpiration),
expiredTokens: make(chan string),
expirationsLocker: &sync.Mutex{}, expirationsLocker: &sync.Mutex{},
host: RootURL, host: RootURL,
@ -97,6 +99,8 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
go cm.forwardClientAuths() go cm.forwardClientAuths()
go cm.watchTokenExpirations()
return cm return cm
} }
@ -131,8 +135,10 @@ func (cm *ClientManager) GetAnonymousClient() Client {
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared. // LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
func (cm *ClientManager) LogoutClient(userID string) { func (cm *ClientManager) LogoutClient(userID string) {
client, ok := cm.clients[userID] cm.clientsLocker.Lock()
defer cm.clientsLocker.Unlock()
client, ok := cm.clients[userID]
if !ok { if !ok {
return return
} }
@ -140,13 +146,16 @@ func (cm *ClientManager) LogoutClient(userID string) {
delete(cm.clients, userID) delete(cm.clients, userID)
go func() { go func() {
if !strings.HasPrefix(userID, "anonymous-") { defer client.ClearData()
for client.DeleteAuth() == ErrAPINotReachable { defer cm.clearToken(userID)
cm.log.Warn("Logging out client failed because API was not reachable, retrying...")
} if strings.HasPrefix(userID, "anonymous-") {
return
}
for client.DeleteAuth() == ErrAPINotReachable {
cm.log.Warn("Logging out client failed because API was not reachable, retrying...")
} }
client.ClearData()
cm.clearToken(userID)
}() }()
} }
@ -281,9 +290,6 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
cm.tokens[userID] = token cm.tokens[userID] = token
cm.setTokenExpiration(userID, expiration) cm.setTokenExpiration(userID, expiration)
// TODO: This should be one go routine per all tokens.
go cm.watchTokenExpiration(userID)
} }
// setTokenExpiration will ensure the token is refreshed if it expires. // setTokenExpiration will ensure the token is refreshed if it expires.
@ -292,6 +298,9 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
cm.expirationsLocker.Lock() cm.expirationsLocker.Lock()
defer cm.expirationsLocker.Unlock() defer cm.expirationsLocker.Unlock()
// Reduce the expiration by one minute so we can do the refresh with enough time to spare.
expiration -= time.Minute
if exp, ok := cm.expirations[userID]; ok { if exp, ok := cm.expirations[userID]; ok {
exp.timer.Stop() exp.timer.Stop()
close(exp.cancel) close(exp.cancel)
@ -301,6 +310,16 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
timer: time.NewTimer(expiration), timer: time.NewTimer(expiration),
cancel: make(chan struct{}), cancel: make(chan struct{}),
} }
go func(expiration *tokenExpiration) {
select {
case <-expiration.timer.C:
cm.expiredTokens <- userID
case <-expiration.cancel:
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
}
}(cm.expirations[userID])
} }
func (cm *ClientManager) clearToken(userID string) { func (cm *ClientManager) clearToken(userID string) {
@ -324,30 +343,35 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
} }
// If the auth is nil, we should clear the token. // If the auth is nil, we should clear the token.
// TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself.
if ca.Auth == nil { if ca.Auth == nil {
cm.clearToken(ca.UserID) cm.clearToken(ca.UserID)
go cm.LogoutClient(ca.UserID)
return return
} }
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second) cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
} }
func (cm *ClientManager) watchTokenExpiration(userID string) { func (cm *ClientManager) watchTokenExpirations() {
cm.expirationsLocker.Lock() for userID := range cm.expiredTokens {
expiration := cm.expirations[userID] log := cm.log.WithField("userID", userID)
cm.expirationsLocker.Unlock()
select { log.Info("Auth token expired! Refreshing")
case <-expiration.timer.C:
cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing") client, ok := cm.clients[userID]
if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil { if !ok {
cm.log.WithField("userID", userID). log.Warn("Can't refresh expired token because there is no such client")
WithError(err). continue
Error("Token refresh failed before expiration")
} }
case <-expiration.cancel: token, ok := cm.tokens[userID]
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired") if !ok {
log.Warn("Can't refresh expired token because there is no such token")
continue
}
if _, err := client.AuthRefresh(token); err != nil {
log.WithError(err).Error("Failed to refresh expired token")
}
} }
} }

View File

@ -5,11 +5,12 @@
package mocks package mocks
import ( import (
io "io"
reflect "reflect"
crypto "github.com/ProtonMail/gopenpgp/crypto" crypto "github.com/ProtonMail/gopenpgp/crypto"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
io "io"
reflect "reflect"
) )
// MockClient is a mock of Client interface // MockClient is a mock of Client interface