mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2026-02-04 08:18:34 +00:00
feat: refresh expired access tokens in one goroutine
This commit is contained in:
@ -5,10 +5,11 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials"
|
credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials"
|
||||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
reflect "reflect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockConfiger is a mock of Configer interface
|
// MockConfiger is a mock of Configer interface
|
||||||
|
|||||||
@ -528,7 +528,6 @@ func (u *User) Logout() (err error) {
|
|||||||
u.wasKeyringUnlocked = false
|
u.wasKeyringUnlocked = false
|
||||||
u.unlockingKeyringLock.Unlock()
|
u.unlockingKeyringLock.Unlock()
|
||||||
|
|
||||||
// TODO: Is this necessary or could it be done by ClientManager when a nil auth is received?
|
|
||||||
u.client().Logout()
|
u.client().Logout()
|
||||||
|
|
||||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||||
|
|||||||
@ -74,11 +74,8 @@ func (l *listener) Add(eventName string, channel chan<- string) {
|
|||||||
if l.channels == nil {
|
if l.channels == nil {
|
||||||
l.channels = make(map[string][]chan<- string)
|
l.channels = make(map[string][]chan<- string)
|
||||||
}
|
}
|
||||||
if _, ok := l.channels[eventName]; ok {
|
|
||||||
l.channels[eventName] = append(l.channels[eventName], channel)
|
l.channels[eventName] = append(l.channels[eventName], channel)
|
||||||
} else {
|
|
||||||
l.channels[eventName] = []chan<- string{channel}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove removes an event listener.
|
// Remove removes an event listener.
|
||||||
|
|||||||
@ -470,7 +470,7 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||||||
return auth, err
|
return auth, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Should this even be a client method? Or just a method on the client manager?
|
// Logout instructs the client manager to log this client out.
|
||||||
func (c *client) Logout() {
|
func (c *client) Logout() {
|
||||||
c.cm.LogoutClient(c.userID)
|
c.cm.LogoutClient(c.userID)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,6 +27,7 @@ type ClientManager struct {
|
|||||||
tokensLocker sync.Locker
|
tokensLocker sync.Locker
|
||||||
|
|
||||||
expirations map[string]*tokenExpiration
|
expirations map[string]*tokenExpiration
|
||||||
|
expiredTokens chan string
|
||||||
expirationsLocker sync.Locker
|
expirationsLocker sync.Locker
|
||||||
|
|
||||||
bridgeAuths chan ClientAuth
|
bridgeAuths chan ClientAuth
|
||||||
@ -76,6 +77,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||||||
tokensLocker: &sync.Mutex{},
|
tokensLocker: &sync.Mutex{},
|
||||||
|
|
||||||
expirations: make(map[string]*tokenExpiration),
|
expirations: make(map[string]*tokenExpiration),
|
||||||
|
expiredTokens: make(chan string),
|
||||||
expirationsLocker: &sync.Mutex{},
|
expirationsLocker: &sync.Mutex{},
|
||||||
|
|
||||||
host: RootURL,
|
host: RootURL,
|
||||||
@ -97,6 +99,8 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||||||
|
|
||||||
go cm.forwardClientAuths()
|
go cm.forwardClientAuths()
|
||||||
|
|
||||||
|
go cm.watchTokenExpirations()
|
||||||
|
|
||||||
return cm
|
return cm
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,8 +135,10 @@ func (cm *ClientManager) GetAnonymousClient() Client {
|
|||||||
|
|
||||||
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||||
func (cm *ClientManager) LogoutClient(userID string) {
|
func (cm *ClientManager) LogoutClient(userID string) {
|
||||||
client, ok := cm.clients[userID]
|
cm.clientsLocker.Lock()
|
||||||
|
defer cm.clientsLocker.Unlock()
|
||||||
|
|
||||||
|
client, ok := cm.clients[userID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -140,13 +146,16 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
|||||||
delete(cm.clients, userID)
|
delete(cm.clients, userID)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if !strings.HasPrefix(userID, "anonymous-") {
|
defer client.ClearData()
|
||||||
for client.DeleteAuth() == ErrAPINotReachable {
|
defer cm.clearToken(userID)
|
||||||
cm.log.Warn("Logging out client failed because API was not reachable, retrying...")
|
|
||||||
}
|
if strings.HasPrefix(userID, "anonymous-") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for client.DeleteAuth() == ErrAPINotReachable {
|
||||||
|
cm.log.Warn("Logging out client failed because API was not reachable, retrying...")
|
||||||
}
|
}
|
||||||
client.ClearData()
|
|
||||||
cm.clearToken(userID)
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -281,9 +290,6 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
|
|||||||
cm.tokens[userID] = token
|
cm.tokens[userID] = token
|
||||||
|
|
||||||
cm.setTokenExpiration(userID, expiration)
|
cm.setTokenExpiration(userID, expiration)
|
||||||
|
|
||||||
// TODO: This should be one go routine per all tokens.
|
|
||||||
go cm.watchTokenExpiration(userID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// setTokenExpiration will ensure the token is refreshed if it expires.
|
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||||
@ -292,6 +298,9 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
|
|||||||
cm.expirationsLocker.Lock()
|
cm.expirationsLocker.Lock()
|
||||||
defer cm.expirationsLocker.Unlock()
|
defer cm.expirationsLocker.Unlock()
|
||||||
|
|
||||||
|
// Reduce the expiration by one minute so we can do the refresh with enough time to spare.
|
||||||
|
expiration -= time.Minute
|
||||||
|
|
||||||
if exp, ok := cm.expirations[userID]; ok {
|
if exp, ok := cm.expirations[userID]; ok {
|
||||||
exp.timer.Stop()
|
exp.timer.Stop()
|
||||||
close(exp.cancel)
|
close(exp.cancel)
|
||||||
@ -301,6 +310,16 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
|
|||||||
timer: time.NewTimer(expiration),
|
timer: time.NewTimer(expiration),
|
||||||
cancel: make(chan struct{}),
|
cancel: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func(expiration *tokenExpiration) {
|
||||||
|
select {
|
||||||
|
case <-expiration.timer.C:
|
||||||
|
cm.expiredTokens <- userID
|
||||||
|
|
||||||
|
case <-expiration.cancel:
|
||||||
|
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
|
||||||
|
}
|
||||||
|
}(cm.expirations[userID])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ClientManager) clearToken(userID string) {
|
func (cm *ClientManager) clearToken(userID string) {
|
||||||
@ -324,30 +343,35 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the auth is nil, we should clear the token.
|
// If the auth is nil, we should clear the token.
|
||||||
// TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself.
|
|
||||||
if ca.Auth == nil {
|
if ca.Auth == nil {
|
||||||
cm.clearToken(ca.UserID)
|
cm.clearToken(ca.UserID)
|
||||||
|
go cm.LogoutClient(ca.UserID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
|
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
func (cm *ClientManager) watchTokenExpirations() {
|
||||||
cm.expirationsLocker.Lock()
|
for userID := range cm.expiredTokens {
|
||||||
expiration := cm.expirations[userID]
|
log := cm.log.WithField("userID", userID)
|
||||||
cm.expirationsLocker.Unlock()
|
|
||||||
|
|
||||||
select {
|
log.Info("Auth token expired! Refreshing")
|
||||||
case <-expiration.timer.C:
|
|
||||||
cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
client, ok := cm.clients[userID]
|
||||||
if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil {
|
if !ok {
|
||||||
cm.log.WithField("userID", userID).
|
log.Warn("Can't refresh expired token because there is no such client")
|
||||||
WithError(err).
|
continue
|
||||||
Error("Token refresh failed before expiration")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-expiration.cancel:
|
token, ok := cm.tokens[userID]
|
||||||
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
|
if !ok {
|
||||||
|
log.Warn("Can't refresh expired token because there is no such token")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.AuthRefresh(token); err != nil {
|
||||||
|
log.WithError(err).Error("Failed to refresh expired token")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,11 +5,12 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
io "io"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
crypto "github.com/ProtonMail/gopenpgp/crypto"
|
crypto "github.com/ProtonMail/gopenpgp/crypto"
|
||||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
io "io"
|
|
||||||
reflect "reflect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockClient is a mock of Client interface
|
// MockClient is a mock of Client interface
|
||||||
|
|||||||
Reference in New Issue
Block a user