mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 12:46:46 +00:00
fix: race condition when updating user auth
This commit is contained in:
committed by
Michal Horejsek
parent
51846efed5
commit
8f15041d8f
@ -217,12 +217,7 @@ func (c *client) sendAuth(auth *Auth) {
|
||||
c.accessToken = auth.accessToken
|
||||
}
|
||||
|
||||
go func(auth ClientAuth) {
|
||||
c.cm.clientAuths <- auth
|
||||
}(ClientAuth{
|
||||
UserID: c.userID,
|
||||
Auth: auth,
|
||||
})
|
||||
c.cm.HandleAuth(ClientAuth{UserID: c.userID, Auth: auth})
|
||||
}
|
||||
|
||||
// AuthInfo gets authentication info for a user.
|
||||
|
||||
@ -32,8 +32,7 @@ type ClientManager struct {
|
||||
expiredTokens chan string
|
||||
expirationsLocker sync.Locker
|
||||
|
||||
clientAuths chan ClientAuth // auths received by clients from the API are received here and handled.
|
||||
forwardedAuths chan ClientAuth // once auths are handled, they are forwarded on this channel.
|
||||
authUpdates chan ClientAuth
|
||||
|
||||
host, scheme string
|
||||
hostLocker sync.RWMutex
|
||||
@ -86,8 +85,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
scheme: rootScheme,
|
||||
hostLocker: sync.RWMutex{},
|
||||
|
||||
forwardedAuths: make(chan ClientAuth),
|
||||
clientAuths: make(chan ClientAuth),
|
||||
authUpdates: make(chan ClientAuth),
|
||||
|
||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||
proxyUseDuration: proxyUseDuration,
|
||||
@ -99,8 +97,6 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
return newClient(cm, userID)
|
||||
}
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
|
||||
go cm.watchTokenExpirations()
|
||||
|
||||
return cm
|
||||
@ -260,7 +256,7 @@ func (cm *ClientManager) GetToken(userID string) string {
|
||||
|
||||
// GetAuthUpdateChannel returns a channel on which client auths can be received.
|
||||
func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth {
|
||||
return cm.forwardedAuths
|
||||
return cm.authUpdates
|
||||
}
|
||||
|
||||
// Errors for possible connection issues
|
||||
@ -325,16 +321,6 @@ func checkConnection(client *http.Client, url string, errorChannel chan error) {
|
||||
errorChannel <- nil
|
||||
}
|
||||
|
||||
// forwardClientAuths handles all incoming auths from clients before forwarding them on the forwarded auths channel.
|
||||
func (cm *ClientManager) forwardClientAuths() {
|
||||
for auth := range cm.clientAuths {
|
||||
logrus.Debug("ClientManager received auth from client")
|
||||
cm.handleClientAuth(auth)
|
||||
logrus.Debug("ClientManager is forwarding auth")
|
||||
cm.forwardedAuths <- auth
|
||||
}
|
||||
}
|
||||
|
||||
// setTokenIfUnset sets the token for the given userID if it wasn't already set.
|
||||
// The set token does not expire.
|
||||
func (cm *ClientManager) setTokenIfUnset(userID, token string) {
|
||||
@ -401,8 +387,8 @@ func (cm *ClientManager) clearToken(userID string) {
|
||||
delete(cm.tokens, userID)
|
||||
}
|
||||
|
||||
// handleClientAuth updates or clears client authorisation based on auths received.
|
||||
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
// HandleAuth updates or clears client authorisation based on auths received and then forwards the auth onwards.
|
||||
func (cm *ClientManager) HandleAuth(ca ClientAuth) {
|
||||
cm.clientsLocker.Lock()
|
||||
defer cm.clientsLocker.Unlock()
|
||||
|
||||
@ -420,6 +406,10 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
}
|
||||
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
|
||||
|
||||
logrus.Debug("ClientManager is forwarding auth update...")
|
||||
cm.authUpdates <- ca
|
||||
logrus.Debug("Auth update was forwarded")
|
||||
}
|
||||
|
||||
// watchTokenExpirations refreshes any tokens which are about to expire.
|
||||
|
||||
@ -21,8 +21,3 @@ package pmapi
|
||||
func (s *Auth) DANGEROUSLYSetUID(uid string) {
|
||||
s.uid = uid
|
||||
}
|
||||
|
||||
// GetClientAuthChannel returns a channel on which clients should send auths.
|
||||
func (cm *ClientManager) GetClientAuthChannel() chan ClientAuth {
|
||||
return cm.clientAuths
|
||||
}
|
||||
|
||||
@ -108,9 +108,7 @@ func (api *FakePMAPI) checkInternetAndRecordCall(method method, path string, req
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) {
|
||||
go func(clientAuth pmapi.ClientAuth) {
|
||||
api.controller.clientManager.GetClientAuthChannel() <- clientAuth
|
||||
}(pmapi.ClientAuth{UserID: api.user.ID, Auth: auth})
|
||||
api.controller.clientManager.HandleAuth(pmapi.ClientAuth{UserID: api.user.ID, Auth: auth})
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) setUser(username string) error {
|
||||
|
||||
Reference in New Issue
Block a user