mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-15 14:56:42 +00:00
fix: don't delete uid of anonymous clients
This commit is contained in:
@ -246,6 +246,10 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
|||||||
return false, errors.Wrap(err, "failed to get event")
|
return false, errors.Wrap(err, "failed to get event")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if event == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
l = l.WithField("newEventID", event.EventID)
|
l = l.WithField("newEventID", event.EventID)
|
||||||
|
|
||||||
if !loop.hasInternet {
|
if !loop.hasInternet {
|
||||||
|
|||||||
@ -19,6 +19,7 @@ package store
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
@ -53,6 +54,12 @@ func (store *Store) TestGetStoreFilePath() string {
|
|||||||
|
|
||||||
// TestDumpDB will dump store database content.
|
// TestDumpDB will dump store database content.
|
||||||
func (store *Store) TestDumpDB(tb assert.TestingT) {
|
func (store *Store) TestDumpDB(tb assert.TestingT) {
|
||||||
|
if store == nil || store.db == nil {
|
||||||
|
fmt.Printf(">>>>>>>> NIL STORE / DB <<<<<\n\n")
|
||||||
|
assert.NoError(tb, errors.New("store or database is nil"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
dumpCounts := true
|
dumpCounts := true
|
||||||
fmt.Printf(">>>>>>>> DUMP %s <<<<<\n\n", store.db.Path())
|
fmt.Printf(">>>>>>>> DUMP %s <<<<<\n\n", store.db.Path())
|
||||||
|
|
||||||
|
|||||||
@ -205,19 +205,23 @@ type AuthRefreshReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) sendAuth(auth *Auth) {
|
func (c *client) sendAuth(auth *Auth) {
|
||||||
c.log.Debug("Client is sending auth to ClientManager")
|
if auth != nil {
|
||||||
|
c.log.WithField("auth", *auth).Debug("Client is sending auth to ClientManager")
|
||||||
|
} else {
|
||||||
|
c.log.Debug("Client is sending nil auth to ClientManager")
|
||||||
|
}
|
||||||
|
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
c.uid = auth.UID()
|
c.uid = auth.UID()
|
||||||
c.accessToken = auth.accessToken
|
c.accessToken = auth.accessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func(auth ClientAuth) {
|
||||||
c.cm.GetClientAuthChannel() <- ClientAuth{
|
c.cm.GetClientAuthChannel() <- auth
|
||||||
UserID: c.userID,
|
}(ClientAuth{
|
||||||
Auth: auth,
|
UserID: c.userID,
|
||||||
}
|
Auth: auth,
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthInfo gets authentication info for a user.
|
// AuthInfo gets authentication info for a user.
|
||||||
|
|||||||
@ -42,6 +42,15 @@ type ClientManager struct {
|
|||||||
allowProxy bool
|
allowProxy bool
|
||||||
proxyProvider *proxyProvider
|
proxyProvider *proxyProvider
|
||||||
proxyUseDuration time.Duration
|
proxyUseDuration time.Duration
|
||||||
|
|
||||||
|
idGen idGen
|
||||||
|
}
|
||||||
|
|
||||||
|
type idGen int
|
||||||
|
|
||||||
|
func (i *idGen) next() int {
|
||||||
|
(*i)++
|
||||||
|
return int(*i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientAuth holds an API auth produced by a Client for a specific user.
|
// ClientAuth holds an API auth produced by a Client for a specific user.
|
||||||
@ -117,16 +126,7 @@ func (cm *ClientManager) GetClient(userID string) Client {
|
|||||||
|
|
||||||
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
|
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created.
|
||||||
func (cm *ClientManager) GetAnonymousClient() Client {
|
func (cm *ClientManager) GetAnonymousClient() Client {
|
||||||
cm.clientsLocker.Lock()
|
return cm.GetClient(fmt.Sprintf("anonymous-%v", cm.idGen.next()))
|
||||||
defer cm.clientsLocker.Unlock()
|
|
||||||
|
|
||||||
if client, ok := cm.clients[""]; ok {
|
|
||||||
client.DeleteAuth()
|
|
||||||
}
|
|
||||||
|
|
||||||
cm.clients[""] = cm.newClient("")
|
|
||||||
|
|
||||||
return cm.clients[""]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||||
@ -266,11 +266,6 @@ func (cm *ClientManager) forwardClientAuths() {
|
|||||||
|
|
||||||
// setToken sets the token for the given userID with the given expiration time.
|
// setToken sets the token for the given userID with the given expiration time.
|
||||||
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||||
// We don't want to set tokens of anonymous clients.
|
|
||||||
if userID == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cm.tokensLocker.Lock()
|
cm.tokensLocker.Lock()
|
||||||
defer cm.tokensLocker.Unlock()
|
defer cm.tokensLocker.Unlock()
|
||||||
|
|
||||||
@ -279,6 +274,8 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
|
|||||||
cm.tokens[userID] = token
|
cm.tokens[userID] = token
|
||||||
|
|
||||||
cm.setTokenExpiration(userID, expiration)
|
cm.setTokenExpiration(userID, expiration)
|
||||||
|
|
||||||
|
go cm.watchTokenExpiration(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setTokenExpiration will ensure the token is refreshed if it expires.
|
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||||
@ -296,8 +293,6 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
|
|||||||
timer: time.NewTimer(expiration),
|
timer: time.NewTimer(expiration),
|
||||||
cancel: make(chan struct{}),
|
cancel: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
go cm.watchTokenExpiration(userID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ClientManager) clearToken(userID string) {
|
func (cm *ClientManager) clearToken(userID string) {
|
||||||
@ -311,6 +306,9 @@ func (cm *ClientManager) clearToken(userID string) {
|
|||||||
|
|
||||||
// handleClientAuth updates or clears client authorisation based on auths received.
|
// handleClientAuth updates or clears client authorisation based on auths received.
|
||||||
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||||
|
cm.clientsLocker.Lock()
|
||||||
|
defer cm.clientsLocker.Unlock()
|
||||||
|
|
||||||
// If we aren't managing this client, there's nothing to do.
|
// If we aren't managing this client, there's nothing to do.
|
||||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||||
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
|
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
|
||||||
@ -328,7 +326,9 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||||
|
cm.expirationsLocker.Lock()
|
||||||
expiration := cm.expirations[userID]
|
expiration := cm.expirations[userID]
|
||||||
|
cm.expirationsLocker.Unlock()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-expiration.timer.C:
|
case <-expiration.timer.C:
|
||||||
|
|||||||
@ -116,8 +116,6 @@ func (c *client) UpdateUser() (user *User, err error) {
|
|||||||
c.addresses = tmpList
|
c.addresses = tmpList
|
||||||
}
|
}
|
||||||
|
|
||||||
c.log.WithField("userID", user.ID).Info("Updated user")
|
|
||||||
|
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -180,7 +180,9 @@ func hasAPIAuth(accountName string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return internalError(err, "getting user %s", account.Username())
|
return internalError(err, "getting user %s", account.Username())
|
||||||
}
|
}
|
||||||
a.True(ctx.GetTestingT(), bridgeUser.HasAPIAuth())
|
a.Eventually(ctx.GetTestingT(), func() bool {
|
||||||
|
return bridgeUser.HasAPIAuth()
|
||||||
|
}, 5*time.Second, 10*time.Millisecond)
|
||||||
return ctx.GetTestingError()
|
return ctx.GetTestingError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -27,6 +27,9 @@ func (api *FakePMAPI) GetEvent(eventID string) (*pmapi.Event, error) {
|
|||||||
}
|
}
|
||||||
// Request for empty ID returns the latest event.
|
// Request for empty ID returns the latest event.
|
||||||
if eventID == "" {
|
if eventID == "" {
|
||||||
|
if len(api.events) == 0 {
|
||||||
|
return &pmapi.Event{EventID: ""}, nil
|
||||||
|
}
|
||||||
return api.events[len(api.events)-1], nil
|
return api.events[len(api.events)-1], nil
|
||||||
}
|
}
|
||||||
// Otherwise it tries to find specific ID and return all next events merged into one.
|
// Otherwise it tries to find specific ID and return all next events merged into one.
|
||||||
|
|||||||
@ -104,12 +104,9 @@ func (api *FakePMAPI) checkInternetAndRecordCall(method method, path string, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) {
|
func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) {
|
||||||
go func() {
|
go func(clientAuth pmapi.ClientAuth) {
|
||||||
api.controller.clientManager.GetClientAuthChannel() <- pmapi.ClientAuth{
|
api.controller.clientManager.GetClientAuthChannel() <- clientAuth
|
||||||
UserID: api.user.ID,
|
}(pmapi.ClientAuth{UserID: api.user.ID, Auth: auth})
|
||||||
Auth: auth,
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *FakePMAPI) setUser(username string) error {
|
func (api *FakePMAPI) setUser(username string) error {
|
||||||
|
|||||||
Reference in New Issue
Block a user