From 8f15041d8f24a3be031510a02b5e57582c3aa2ee Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Wed, 22 Apr 2020 15:07:35 +0200 Subject: [PATCH] fix: race condition when updating user auth --- pkg/pmapi/auth.go | 7 +------ pkg/pmapi/clientmanager.go | 28 +++++++++------------------- pkg/pmapi/pmapi_test_exports.go | 5 ----- test/fakeapi/fakeapi.go | 4 +--- 4 files changed, 11 insertions(+), 33 deletions(-) diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index 8647edeb..fb550c25 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -217,12 +217,7 @@ func (c *client) sendAuth(auth *Auth) { c.accessToken = auth.accessToken } - go func(auth ClientAuth) { - c.cm.clientAuths <- auth - }(ClientAuth{ - UserID: c.userID, - Auth: auth, - }) + c.cm.HandleAuth(ClientAuth{UserID: c.userID, Auth: auth}) } // AuthInfo gets authentication info for a user. diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index 9057b847..03c17eeb 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -32,8 +32,7 @@ type ClientManager struct { expiredTokens chan string expirationsLocker sync.Locker - clientAuths chan ClientAuth // auths received by clients from the API are received here and handled. - forwardedAuths chan ClientAuth // once auths are handled, they are forwarded on this channel. + authUpdates chan ClientAuth host, scheme string hostLocker sync.RWMutex @@ -86,8 +85,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { scheme: rootScheme, hostLocker: sync.RWMutex{}, - forwardedAuths: make(chan ClientAuth), - clientAuths: make(chan ClientAuth), + authUpdates: make(chan ClientAuth), proxyProvider: newProxyProvider(dohProviders, proxyQuery), proxyUseDuration: proxyUseDuration, @@ -99,8 +97,6 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { return newClient(cm, userID) } - go cm.forwardClientAuths() - go cm.watchTokenExpirations() return cm @@ -260,7 +256,7 @@ func (cm *ClientManager) GetToken(userID string) string { // GetAuthUpdateChannel returns a channel on which client auths can be received. func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth { - return cm.forwardedAuths + return cm.authUpdates } // Errors for possible connection issues @@ -325,16 +321,6 @@ func checkConnection(client *http.Client, url string, errorChannel chan error) { errorChannel <- nil } -// forwardClientAuths handles all incoming auths from clients before forwarding them on the forwarded auths channel. -func (cm *ClientManager) forwardClientAuths() { - for auth := range cm.clientAuths { - logrus.Debug("ClientManager received auth from client") - cm.handleClientAuth(auth) - logrus.Debug("ClientManager is forwarding auth") - cm.forwardedAuths <- auth - } -} - // setTokenIfUnset sets the token for the given userID if it wasn't already set. // The set token does not expire. func (cm *ClientManager) setTokenIfUnset(userID, token string) { @@ -401,8 +387,8 @@ func (cm *ClientManager) clearToken(userID string) { delete(cm.tokens, userID) } -// handleClientAuth updates or clears client authorisation based on auths received. -func (cm *ClientManager) handleClientAuth(ca ClientAuth) { +// HandleAuth updates or clears client authorisation based on auths received and then forwards the auth onwards. +func (cm *ClientManager) HandleAuth(ca ClientAuth) { cm.clientsLocker.Lock() defer cm.clientsLocker.Unlock() @@ -420,6 +406,10 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) { } cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second) + + logrus.Debug("ClientManager is forwarding auth update...") + cm.authUpdates <- ca + logrus.Debug("Auth update was forwarded") } // watchTokenExpirations refreshes any tokens which are about to expire. diff --git a/pkg/pmapi/pmapi_test_exports.go b/pkg/pmapi/pmapi_test_exports.go index 5eb50032..de1efaad 100644 --- a/pkg/pmapi/pmapi_test_exports.go +++ b/pkg/pmapi/pmapi_test_exports.go @@ -21,8 +21,3 @@ package pmapi func (s *Auth) DANGEROUSLYSetUID(uid string) { s.uid = uid } - -// GetClientAuthChannel returns a channel on which clients should send auths. -func (cm *ClientManager) GetClientAuthChannel() chan ClientAuth { - return cm.clientAuths -} diff --git a/test/fakeapi/fakeapi.go b/test/fakeapi/fakeapi.go index 5234cf19..7005d14b 100644 --- a/test/fakeapi/fakeapi.go +++ b/test/fakeapi/fakeapi.go @@ -108,9 +108,7 @@ func (api *FakePMAPI) checkInternetAndRecordCall(method method, path string, req } func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) { - go func(clientAuth pmapi.ClientAuth) { - api.controller.clientManager.GetClientAuthChannel() <- clientAuth - }(pmapi.ClientAuth{UserID: api.user.ID, Auth: auth}) + api.controller.clientManager.HandleAuth(pmapi.ClientAuth{UserID: api.user.ID, Auth: auth}) } func (api *FakePMAPI) setUser(username string) error {