Other(refactor): Use normal value + mutex for user.apiUser

This commit is contained in:
James Houlahan
2022-10-26 23:13:38 +02:00
parent 8749d5dc7d
commit 83339da26c
3 changed files with 31 additions and 25 deletions

View File

@ -25,6 +25,7 @@ import (
"github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
@ -61,13 +62,15 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error
// handleUserEvent handles the given user event. // handleUserEvent handles the given user event.
func (user *User) handleUserEvent(_ context.Context, userEvent liteapi.User) error { func (user *User) handleUserEvent(_ context.Context, userEvent liteapi.User) error {
user.apiUser.Save(userEvent) return safe.LockRet(func() error {
user.apiUser = userEvent
user.eventCh.Enqueue(events.UserChanged{ user.eventCh.Enqueue(events.UserChanged{
UserID: user.ID(), UserID: user.ID(),
}) })
return nil return nil
}, &user.apiUserLock)
} }
// handleAddressEvents handles the given address events. // handleAddressEvents handles the given address events.

View File

@ -21,19 +21,20 @@ import (
"fmt" "fmt"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
) )
func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error { func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error {
return user.apiUser.LoadErr(func(apiUser liteapi.User) error { return safe.RLockRet(func() error {
userKR, err := apiUser.Keys.Unlock(user.vault.KeyPass(), nil) userKR, err := user.apiUser.Keys.Unlock(user.vault.KeyPass(), nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to unlock user keys: %w", err) return fmt.Errorf("failed to unlock user keys: %w", err)
} }
defer userKR.ClearPrivateParams() defer userKR.ClearPrivateParams()
return fn(userKR) return fn(userKR)
}) }, &user.apiUserLock)
} }
func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error { func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error {

View File

@ -52,7 +52,9 @@ type User struct {
client *liteapi.Client client *liteapi.Client
eventCh *queue.QueuedChannel[events.Event] eventCh *queue.QueuedChannel[events.Event]
apiUser *safe.Value[liteapi.User] apiUser liteapi.User
apiUserLock sync.RWMutex
apiAddrs *safe.Map[string, liteapi.Address] apiAddrs *safe.Map[string, liteapi.Address]
apiLabels *safe.Map[string, liteapi.Label] apiLabels *safe.Map[string, liteapi.Label]
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
@ -131,7 +133,7 @@ func New(
client: client, client: client,
eventCh: queue.NewQueuedChannel[events.Event](0, 0), eventCh: queue.NewQueuedChannel[events.Event](0, 0),
apiUser: safe.NewValue(apiUser), apiUser: apiUser,
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
apiLabels: safe.NewMapFrom(groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), nil), apiLabels: safe.NewMapFrom(groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), nil),
updateCh: safe.NewMapFrom(updateCh, nil), updateCh: safe.NewMapFrom(updateCh, nil),
@ -196,29 +198,29 @@ func New(
// ID returns the user's ID. // ID returns the user's ID.
func (user *User) ID() string { func (user *User) ID() string {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string { return safe.RLockRet(func() string {
return apiUser.ID return user.apiUser.ID
}) }, &user.apiUserLock)
} }
// Name returns the user's username. // Name returns the user's username.
func (user *User) Name() string { func (user *User) Name() string {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string { return safe.RLockRet(func() string {
return apiUser.Name return user.apiUser.Name
}) }, &user.apiUserLock)
} }
// Match matches the given query against the user's username and email addresses. // Match matches the given query against the user's username and email addresses.
func (user *User) Match(query string) bool { func (user *User) Match(query string) bool {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) bool { return safe.RLockRet(func() bool {
if query == apiUser.Name { if query == user.apiUser.Name {
return true return true
} }
return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool { return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool {
return addr.Email == query return addr.Email == query
}) })
}) }, &user.apiUserLock)
} }
// Emails returns all the user's email addresses via the callback. // Emails returns all the user's email addresses via the callback.
@ -305,16 +307,16 @@ func (user *User) BridgePass() []byte {
// UsedSpace returns the total space used by the user on the API. // UsedSpace returns the total space used by the user on the API.
func (user *User) UsedSpace() int { func (user *User) UsedSpace() int {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int { return safe.RLockRet(func() int {
return apiUser.UsedSpace return user.apiUser.UsedSpace
}) }, &user.apiUserLock)
} }
// MaxSpace returns the amount of space the user can use on the API. // MaxSpace returns the amount of space the user can use on the API.
func (user *User) MaxSpace() int { func (user *User) MaxSpace() int {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int { return safe.RLockRet(func() int {
return apiUser.MaxSpace return user.apiUser.MaxSpace
}) }, &user.apiUserLock)
} }
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change). // GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change).