From 83339da26cdd7b16fe0751b5d4c2ba65e4476b75 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Wed, 26 Oct 2022 23:13:38 +0200 Subject: [PATCH] Other(refactor): Use normal value + mutex for user.apiUser --- internal/user/events.go | 13 ++++++++----- internal/user/keys.go | 7 ++++--- internal/user/user.go | 36 +++++++++++++++++++----------------- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/internal/user/events.go b/internal/user/events.go index fab08722..0f224ea3 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -25,6 +25,7 @@ import ( "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" "gitlab.protontech.ch/go/liteapi" @@ -61,13 +62,15 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error // handleUserEvent handles the given user event. func (user *User) handleUserEvent(_ context.Context, userEvent liteapi.User) error { - user.apiUser.Save(userEvent) + return safe.LockRet(func() error { + user.apiUser = userEvent - user.eventCh.Enqueue(events.UserChanged{ - UserID: user.ID(), - }) + user.eventCh.Enqueue(events.UserChanged{ + UserID: user.ID(), + }) - return nil + return nil + }, &user.apiUserLock) } // handleAddressEvents handles the given address events. diff --git a/internal/user/keys.go b/internal/user/keys.go index 961458c2..e3809d22 100644 --- a/internal/user/keys.go +++ b/internal/user/keys.go @@ -21,19 +21,20 @@ import ( "fmt" "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "gitlab.protontech.ch/go/liteapi" ) func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error { - return user.apiUser.LoadErr(func(apiUser liteapi.User) error { - userKR, err := apiUser.Keys.Unlock(user.vault.KeyPass(), nil) + return safe.RLockRet(func() error { + userKR, err := user.apiUser.Keys.Unlock(user.vault.KeyPass(), nil) if err != nil { return fmt.Errorf("failed to unlock user keys: %w", err) } defer userKR.ClearPrivateParams() return fn(userKR) - }) + }, &user.apiUserLock) } func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error { diff --git a/internal/user/user.go b/internal/user/user.go index 5a63fd80..1fa9819a 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -52,7 +52,9 @@ type User struct { client *liteapi.Client eventCh *queue.QueuedChannel[events.Event] - apiUser *safe.Value[liteapi.User] + apiUser liteapi.User + apiUserLock sync.RWMutex + apiAddrs *safe.Map[string, liteapi.Address] apiLabels *safe.Map[string, liteapi.Label] updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] @@ -131,7 +133,7 @@ func New( client: client, eventCh: queue.NewQueuedChannel[events.Event](0, 0), - apiUser: safe.NewValue(apiUser), + apiUser: apiUser, apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), apiLabels: safe.NewMapFrom(groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), nil), updateCh: safe.NewMapFrom(updateCh, nil), @@ -196,29 +198,29 @@ func New( // ID returns the user's ID. func (user *User) ID() string { - return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string { - return apiUser.ID - }) + return safe.RLockRet(func() string { + return user.apiUser.ID + }, &user.apiUserLock) } // Name returns the user's username. func (user *User) Name() string { - return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string { - return apiUser.Name - }) + return safe.RLockRet(func() string { + return user.apiUser.Name + }, &user.apiUserLock) } // Match matches the given query against the user's username and email addresses. func (user *User) Match(query string) bool { - return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) bool { - if query == apiUser.Name { + return safe.RLockRet(func() bool { + if query == user.apiUser.Name { return true } return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool { return addr.Email == query }) - }) + }, &user.apiUserLock) } // Emails returns all the user's email addresses via the callback. @@ -305,16 +307,16 @@ func (user *User) BridgePass() []byte { // UsedSpace returns the total space used by the user on the API. func (user *User) UsedSpace() int { - return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int { - return apiUser.UsedSpace - }) + return safe.RLockRet(func() int { + return user.apiUser.UsedSpace + }, &user.apiUserLock) } // MaxSpace returns the amount of space the user can use on the API. func (user *User) MaxSpace() int { - return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int { - return apiUser.MaxSpace - }) + return safe.RLockRet(func() int { + return user.apiUser.MaxSpace + }, &user.apiUserLock) } // GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change).