fix: don't delete uid of anonymous clients

This commit is contained in:
James Houlahan
2020-04-09 15:33:23 +02:00
parent ed8595fa5b
commit debd374d75
8 changed files with 48 additions and 33 deletions

View File

@ -246,6 +246,10 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
return false, errors.Wrap(err, "failed to get event") return false, errors.Wrap(err, "failed to get event")
} }
if event == nil {
return
}
l = l.WithField("newEventID", event.EventID) l = l.WithField("newEventID", event.EventID)
if !loop.hasInternet { if !loop.hasInternet {

View File

@ -19,6 +19,7 @@ package store
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
@ -53,6 +54,12 @@ func (store *Store) TestGetStoreFilePath() string {
// TestDumpDB will dump store database content. // TestDumpDB will dump store database content.
func (store *Store) TestDumpDB(tb assert.TestingT) { func (store *Store) TestDumpDB(tb assert.TestingT) {
if store == nil || store.db == nil {
fmt.Printf(">>>>>>>> NIL STORE / DB <<<<<\n\n")
assert.NoError(tb, errors.New("store or database is nil"))
return
}
dumpCounts := true dumpCounts := true
fmt.Printf(">>>>>>>> DUMP %s <<<<<\n\n", store.db.Path()) fmt.Printf(">>>>>>>> DUMP %s <<<<<\n\n", store.db.Path())

View File

@ -205,19 +205,23 @@ type AuthRefreshReq struct {
} }
func (c *client) sendAuth(auth *Auth) { func (c *client) sendAuth(auth *Auth) {
c.log.Debug("Client is sending auth to ClientManager") if auth != nil {
c.log.WithField("auth", *auth).Debug("Client is sending auth to ClientManager")
} else {
c.log.Debug("Client is sending nil auth to ClientManager")
}
if auth != nil { if auth != nil {
c.uid = auth.UID() c.uid = auth.UID()
c.accessToken = auth.accessToken c.accessToken = auth.accessToken
} }
go func() { go func(auth ClientAuth) {
c.cm.GetClientAuthChannel() <- ClientAuth{ c.cm.GetClientAuthChannel() <- auth
UserID: c.userID, }(ClientAuth{
Auth: auth, UserID: c.userID,
} Auth: auth,
}() })
} }
// AuthInfo gets authentication info for a user. // AuthInfo gets authentication info for a user.

View File

@ -42,6 +42,15 @@ type ClientManager struct {
allowProxy bool allowProxy bool
proxyProvider *proxyProvider proxyProvider *proxyProvider
proxyUseDuration time.Duration proxyUseDuration time.Duration
idGen idGen
}
type idGen int
func (i *idGen) next() int {
(*i)++
return int(*i)
} }
// ClientAuth holds an API auth produced by a Client for a specific user. // ClientAuth holds an API auth produced by a Client for a specific user.
@ -117,16 +126,7 @@ func (cm *ClientManager) GetClient(userID string) Client {
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created. // GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
func (cm *ClientManager) GetAnonymousClient() Client { func (cm *ClientManager) GetAnonymousClient() Client {
cm.clientsLocker.Lock() return cm.GetClient(fmt.Sprintf("anonymous-%v", cm.idGen.next()))
defer cm.clientsLocker.Unlock()
if client, ok := cm.clients[""]; ok {
client.DeleteAuth()
}
cm.clients[""] = cm.newClient("")
return cm.clients[""]
} }
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared. // LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
@ -266,11 +266,6 @@ func (cm *ClientManager) forwardClientAuths() {
// setToken sets the token for the given userID with the given expiration time. // setToken sets the token for the given userID with the given expiration time.
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) { func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
// We don't want to set tokens of anonymous clients.
if userID == "" {
return
}
cm.tokensLocker.Lock() cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock() defer cm.tokensLocker.Unlock()
@ -279,6 +274,8 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
cm.tokens[userID] = token cm.tokens[userID] = token
cm.setTokenExpiration(userID, expiration) cm.setTokenExpiration(userID, expiration)
go cm.watchTokenExpiration(userID)
} }
// setTokenExpiration will ensure the token is refreshed if it expires. // setTokenExpiration will ensure the token is refreshed if it expires.
@ -296,8 +293,6 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
timer: time.NewTimer(expiration), timer: time.NewTimer(expiration),
cancel: make(chan struct{}), cancel: make(chan struct{}),
} }
go cm.watchTokenExpiration(userID)
} }
func (cm *ClientManager) clearToken(userID string) { func (cm *ClientManager) clearToken(userID string) {
@ -311,6 +306,9 @@ func (cm *ClientManager) clearToken(userID string) {
// handleClientAuth updates or clears client authorisation based on auths received. // handleClientAuth updates or clears client authorisation based on auths received.
func (cm *ClientManager) handleClientAuth(ca ClientAuth) { func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
cm.clientsLocker.Lock()
defer cm.clientsLocker.Unlock()
// If we aren't managing this client, there's nothing to do. // If we aren't managing this client, there's nothing to do.
if _, ok := cm.clients[ca.UserID]; !ok { if _, ok := cm.clients[ca.UserID]; !ok {
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client") logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
@ -328,7 +326,9 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
} }
func (cm *ClientManager) watchTokenExpiration(userID string) { func (cm *ClientManager) watchTokenExpiration(userID string) {
cm.expirationsLocker.Lock()
expiration := cm.expirations[userID] expiration := cm.expirations[userID]
cm.expirationsLocker.Unlock()
select { select {
case <-expiration.timer.C: case <-expiration.timer.C:

View File

@ -116,8 +116,6 @@ func (c *client) UpdateUser() (user *User, err error) {
c.addresses = tmpList c.addresses = tmpList
} }
c.log.WithField("userID", user.ID).Info("Updated user")
return user, err return user, err
} }

View File

@ -180,7 +180,9 @@ func hasAPIAuth(accountName string) error {
if err != nil { if err != nil {
return internalError(err, "getting user %s", account.Username()) return internalError(err, "getting user %s", account.Username())
} }
a.True(ctx.GetTestingT(), bridgeUser.HasAPIAuth()) a.Eventually(ctx.GetTestingT(), func() bool {
return bridgeUser.HasAPIAuth()
}, 5*time.Second, 10*time.Millisecond)
return ctx.GetTestingError() return ctx.GetTestingError()
} }

View File

@ -27,6 +27,9 @@ func (api *FakePMAPI) GetEvent(eventID string) (*pmapi.Event, error) {
} }
// Request for empty ID returns the latest event. // Request for empty ID returns the latest event.
if eventID == "" { if eventID == "" {
if len(api.events) == 0 {
return &pmapi.Event{EventID: ""}, nil
}
return api.events[len(api.events)-1], nil return api.events[len(api.events)-1], nil
} }
// Otherwise it tries to find specific ID and return all next events merged into one. // Otherwise it tries to find specific ID and return all next events merged into one.

View File

@ -104,12 +104,9 @@ func (api *FakePMAPI) checkInternetAndRecordCall(method method, path string, req
} }
func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) { func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) {
go func() { go func(clientAuth pmapi.ClientAuth) {
api.controller.clientManager.GetClientAuthChannel() <- pmapi.ClientAuth{ api.controller.clientManager.GetClientAuthChannel() <- clientAuth
UserID: api.user.ID, }(pmapi.ClientAuth{UserID: api.user.ID, Auth: auth})
Auth: auth,
}
}()
} }
func (api *FakePMAPI) setUser(username string) error { func (api *FakePMAPI) setUser(username string) error {