From 3f32fd95e078f385e52e45563b3a2e5fb4c66d5c Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Thu, 16 Apr 2020 14:44:59 +0200 Subject: [PATCH] feat: refresh expired access tokens in one goroutine --- internal/bridge/mocks/mocks.go | 3 +- internal/bridge/user.go | 1 - pkg/listener/listener.go | 7 +--- pkg/pmapi/auth.go | 2 +- pkg/pmapi/clientmanager.go | 72 ++++++++++++++++++++++------------ pkg/pmapi/mocks/mocks.go | 5 ++- 6 files changed, 56 insertions(+), 34 deletions(-) diff --git a/internal/bridge/mocks/mocks.go b/internal/bridge/mocks/mocks.go index d1256dfb..54de9505 100644 --- a/internal/bridge/mocks/mocks.go +++ b/internal/bridge/mocks/mocks.go @@ -5,10 +5,11 @@ package mocks import ( + reflect "reflect" + credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" - reflect "reflect" ) // MockConfiger is a mock of Configer interface diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 32080616..1fbcad62 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -528,7 +528,6 @@ func (u *User) Logout() (err error) { u.wasKeyringUnlocked = false u.unlockingKeyringLock.Unlock() - // TODO: Is this necessary or could it be done by ClientManager when a nil auth is received? u.client().Logout() if err = u.credStorer.Logout(u.userID); err != nil { diff --git a/pkg/listener/listener.go b/pkg/listener/listener.go index 66eeef37..0ea4524c 100644 --- a/pkg/listener/listener.go +++ b/pkg/listener/listener.go @@ -74,11 +74,8 @@ func (l *listener) Add(eventName string, channel chan<- string) { if l.channels == nil { l.channels = make(map[string][]chan<- string) } - if _, ok := l.channels[eventName]; ok { - l.channels[eventName] = append(l.channels[eventName], channel) - } else { - l.channels[eventName] = []chan<- string{channel} - } + + l.channels[eventName] = append(l.channels[eventName], channel) } // Remove removes an event listener. diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index 1ef2a820..6b3976d5 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -470,7 +470,7 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) return auth, err } -// TODO: Should this even be a client method? Or just a method on the client manager? +// Logout instructs the client manager to log this client out. func (c *client) Logout() { c.cm.LogoutClient(c.userID) } diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index da53509d..f2e59192 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -27,6 +27,7 @@ type ClientManager struct { tokensLocker sync.Locker expirations map[string]*tokenExpiration + expiredTokens chan string expirationsLocker sync.Locker bridgeAuths chan ClientAuth @@ -76,6 +77,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { tokensLocker: &sync.Mutex{}, expirations: make(map[string]*tokenExpiration), + expiredTokens: make(chan string), expirationsLocker: &sync.Mutex{}, host: RootURL, @@ -97,6 +99,8 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { go cm.forwardClientAuths() + go cm.watchTokenExpirations() + return cm } @@ -131,8 +135,10 @@ func (cm *ClientManager) GetAnonymousClient() Client { // LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared. func (cm *ClientManager) LogoutClient(userID string) { - client, ok := cm.clients[userID] + cm.clientsLocker.Lock() + defer cm.clientsLocker.Unlock() + client, ok := cm.clients[userID] if !ok { return } @@ -140,13 +146,16 @@ func (cm *ClientManager) LogoutClient(userID string) { delete(cm.clients, userID) go func() { - if !strings.HasPrefix(userID, "anonymous-") { - for client.DeleteAuth() == ErrAPINotReachable { - cm.log.Warn("Logging out client failed because API was not reachable, retrying...") - } + defer client.ClearData() + defer cm.clearToken(userID) + + if strings.HasPrefix(userID, "anonymous-") { + return + } + + for client.DeleteAuth() == ErrAPINotReachable { + cm.log.Warn("Logging out client failed because API was not reachable, retrying...") } - client.ClearData() - cm.clearToken(userID) }() } @@ -281,9 +290,6 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration cm.tokens[userID] = token cm.setTokenExpiration(userID, expiration) - - // TODO: This should be one go routine per all tokens. - go cm.watchTokenExpiration(userID) } // setTokenExpiration will ensure the token is refreshed if it expires. @@ -292,6 +298,9 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat cm.expirationsLocker.Lock() defer cm.expirationsLocker.Unlock() + // Reduce the expiration by one minute so we can do the refresh with enough time to spare. + expiration -= time.Minute + if exp, ok := cm.expirations[userID]; ok { exp.timer.Stop() close(exp.cancel) @@ -301,6 +310,16 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat timer: time.NewTimer(expiration), cancel: make(chan struct{}), } + + go func(expiration *tokenExpiration) { + select { + case <-expiration.timer.C: + cm.expiredTokens <- userID + + case <-expiration.cancel: + logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired") + } + }(cm.expirations[userID]) } func (cm *ClientManager) clearToken(userID string) { @@ -324,30 +343,35 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) { } // If the auth is nil, we should clear the token. - // TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself. if ca.Auth == nil { cm.clearToken(ca.UserID) + go cm.LogoutClient(ca.UserID) return } cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second) } -func (cm *ClientManager) watchTokenExpiration(userID string) { - cm.expirationsLocker.Lock() - expiration := cm.expirations[userID] - cm.expirationsLocker.Unlock() +func (cm *ClientManager) watchTokenExpirations() { + for userID := range cm.expiredTokens { + log := cm.log.WithField("userID", userID) - select { - case <-expiration.timer.C: - cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing") - if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil { - cm.log.WithField("userID", userID). - WithError(err). - Error("Token refresh failed before expiration") + log.Info("Auth token expired! Refreshing") + + client, ok := cm.clients[userID] + if !ok { + log.Warn("Can't refresh expired token because there is no such client") + continue } - case <-expiration.cancel: - logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired") + token, ok := cm.tokens[userID] + if !ok { + log.Warn("Can't refresh expired token because there is no such token") + continue + } + + if _, err := client.AuthRefresh(token); err != nil { + log.WithError(err).Error("Failed to refresh expired token") + } } } diff --git a/pkg/pmapi/mocks/mocks.go b/pkg/pmapi/mocks/mocks.go index f9090e31..b06cc0e3 100644 --- a/pkg/pmapi/mocks/mocks.go +++ b/pkg/pmapi/mocks/mocks.go @@ -5,11 +5,12 @@ package mocks import ( + io "io" + reflect "reflect" + crypto "github.com/ProtonMail/gopenpgp/crypto" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" - io "io" - reflect "reflect" ) // MockClient is a mock of Client interface