From debd374d75eec3fa1745fe03e774162ffd6734c2 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Thu, 9 Apr 2020 15:33:23 +0200 Subject: [PATCH] fix: don't delete uid of anonymous clients --- internal/store/event_loop.go | 4 ++++ internal/store/store_test_exports.go | 7 ++++++ pkg/pmapi/auth.go | 18 +++++++++------ pkg/pmapi/clientmanager.go | 34 ++++++++++++++-------------- pkg/pmapi/users.go | 2 -- test/bridge_checks_test.go | 4 +++- test/fakeapi/events.go | 3 +++ test/fakeapi/fakeapi.go | 9 +++----- 8 files changed, 48 insertions(+), 33 deletions(-) diff --git a/internal/store/event_loop.go b/internal/store/event_loop.go index 7077e393..bdda76da 100644 --- a/internal/store/event_loop.go +++ b/internal/store/event_loop.go @@ -246,6 +246,10 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun return false, errors.Wrap(err, "failed to get event") } + if event == nil { + return + } + l = l.WithField("newEventID", event.EventID) if !loop.hasInternet { diff --git a/internal/store/store_test_exports.go b/internal/store/store_test_exports.go index 7248c249..e10005a9 100644 --- a/internal/store/store_test_exports.go +++ b/internal/store/store_test_exports.go @@ -19,6 +19,7 @@ package store import ( "encoding/json" + "errors" "fmt" "github.com/ProtonMail/proton-bridge/pkg/pmapi" @@ -53,6 +54,12 @@ func (store *Store) TestGetStoreFilePath() string { // TestDumpDB will dump store database content. func (store *Store) TestDumpDB(tb assert.TestingT) { + if store == nil || store.db == nil { + fmt.Printf(">>>>>>>> NIL STORE / DB <<<<<\n\n") + assert.NoError(tb, errors.New("store or database is nil")) + return + } + dumpCounts := true fmt.Printf(">>>>>>>> DUMP %s <<<<<\n\n", store.db.Path()) diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index 3126241a..18d6a49e 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -205,19 +205,23 @@ type AuthRefreshReq struct { } func (c *client) sendAuth(auth *Auth) { - c.log.Debug("Client is sending auth to ClientManager") + if auth != nil { + c.log.WithField("auth", *auth).Debug("Client is sending auth to ClientManager") + } else { + c.log.Debug("Client is sending nil auth to ClientManager") + } if auth != nil { c.uid = auth.UID() c.accessToken = auth.accessToken } - go func() { - c.cm.GetClientAuthChannel() <- ClientAuth{ - UserID: c.userID, - Auth: auth, - } - }() + go func(auth ClientAuth) { + c.cm.GetClientAuthChannel() <- auth + }(ClientAuth{ + UserID: c.userID, + Auth: auth, + }) } // AuthInfo gets authentication info for a user. diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index fa8bd0ed..694579c0 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -42,6 +42,15 @@ type ClientManager struct { allowProxy bool proxyProvider *proxyProvider proxyUseDuration time.Duration + + idGen idGen +} + +type idGen int + +func (i *idGen) next() int { + (*i)++ + return int(*i) } // ClientAuth holds an API auth produced by a Client for a specific user. @@ -117,16 +126,7 @@ func (cm *ClientManager) GetClient(userID string) Client { // GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created. func (cm *ClientManager) GetAnonymousClient() Client { - cm.clientsLocker.Lock() - defer cm.clientsLocker.Unlock() - - if client, ok := cm.clients[""]; ok { - client.DeleteAuth() - } - - cm.clients[""] = cm.newClient("") - - return cm.clients[""] + return cm.GetClient(fmt.Sprintf("anonymous-%v", cm.idGen.next())) } // LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared. @@ -266,11 +266,6 @@ func (cm *ClientManager) forwardClientAuths() { // setToken sets the token for the given userID with the given expiration time. func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) { - // We don't want to set tokens of anonymous clients. - if userID == "" { - return - } - cm.tokensLocker.Lock() defer cm.tokensLocker.Unlock() @@ -279,6 +274,8 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration cm.tokens[userID] = token cm.setTokenExpiration(userID, expiration) + + go cm.watchTokenExpiration(userID) } // setTokenExpiration will ensure the token is refreshed if it expires. @@ -296,8 +293,6 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat timer: time.NewTimer(expiration), cancel: make(chan struct{}), } - - go cm.watchTokenExpiration(userID) } func (cm *ClientManager) clearToken(userID string) { @@ -311,6 +306,9 @@ func (cm *ClientManager) clearToken(userID string) { // handleClientAuth updates or clears client authorisation based on auths received. func (cm *ClientManager) handleClientAuth(ca ClientAuth) { + cm.clientsLocker.Lock() + defer cm.clientsLocker.Unlock() + // If we aren't managing this client, there's nothing to do. if _, ok := cm.clients[ca.UserID]; !ok { logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client") @@ -328,7 +326,9 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) { } func (cm *ClientManager) watchTokenExpiration(userID string) { + cm.expirationsLocker.Lock() expiration := cm.expirations[userID] + cm.expirationsLocker.Unlock() select { case <-expiration.timer.C: diff --git a/pkg/pmapi/users.go b/pkg/pmapi/users.go index b999d229..3e20c63d 100644 --- a/pkg/pmapi/users.go +++ b/pkg/pmapi/users.go @@ -116,8 +116,6 @@ func (c *client) UpdateUser() (user *User, err error) { c.addresses = tmpList } - c.log.WithField("userID", user.ID).Info("Updated user") - return user, err } diff --git a/test/bridge_checks_test.go b/test/bridge_checks_test.go index d38bd16e..3a25b57c 100644 --- a/test/bridge_checks_test.go +++ b/test/bridge_checks_test.go @@ -180,7 +180,9 @@ func hasAPIAuth(accountName string) error { if err != nil { return internalError(err, "getting user %s", account.Username()) } - a.True(ctx.GetTestingT(), bridgeUser.HasAPIAuth()) + a.Eventually(ctx.GetTestingT(), func() bool { + return bridgeUser.HasAPIAuth() + }, 5*time.Second, 10*time.Millisecond) return ctx.GetTestingError() } diff --git a/test/fakeapi/events.go b/test/fakeapi/events.go index dd0904d4..6ec03863 100644 --- a/test/fakeapi/events.go +++ b/test/fakeapi/events.go @@ -27,6 +27,9 @@ func (api *FakePMAPI) GetEvent(eventID string) (*pmapi.Event, error) { } // Request for empty ID returns the latest event. if eventID == "" { + if len(api.events) == 0 { + return &pmapi.Event{EventID: ""}, nil + } return api.events[len(api.events)-1], nil } // Otherwise it tries to find specific ID and return all next events merged into one. diff --git a/test/fakeapi/fakeapi.go b/test/fakeapi/fakeapi.go index aa8d54bd..74e79908 100644 --- a/test/fakeapi/fakeapi.go +++ b/test/fakeapi/fakeapi.go @@ -104,12 +104,9 @@ func (api *FakePMAPI) checkInternetAndRecordCall(method method, path string, req } func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) { - go func() { - api.controller.clientManager.GetClientAuthChannel() <- pmapi.ClientAuth{ - UserID: api.user.ID, - Auth: auth, - } - }() + go func(clientAuth pmapi.ClientAuth) { + api.controller.clientManager.GetClientAuthChannel() <- clientAuth + }(pmapi.ClientAuth{UserID: api.user.ID, Auth: auth}) } func (api *FakePMAPI) setUser(username string) error {