From 0bc99dbd4fc919cfa4b8e59d44d33dac91ee7f4c Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Wed, 26 Oct 2022 23:48:18 +0200 Subject: [PATCH] Other(refactor): Use normal value + mutex for user.apiAddrs --- internal/user/events.go | 142 ++++++++++++++++++++----------------- internal/user/imap.go | 42 +++++------ internal/user/keys.go | 108 +++++++++------------------- internal/user/keys_test.go | 20 +----- internal/user/smtp.go | 82 +++++++++++---------- internal/user/sync.go | 71 ++++++++++--------- internal/user/types.go | 28 +++++++- internal/user/user.go | 124 ++++++++++++++++++-------------- 8 files changed, 311 insertions(+), 306 deletions(-) diff --git a/internal/user/events.go b/internal/user/events.go index 0f224ea3..4774ef5c 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -99,77 +99,87 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea } func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { - if had := user.apiAddrs.Set(event.Address.ID, event.Address); had { - return fmt.Errorf("address %q already exists", event.Address.ID) - } + return safe.LockRet(func() error { + if _, ok := user.apiAddrs[event.Address.ID]; ok { + return fmt.Errorf("address %q already exists", event.ID) + } - switch user.vault.AddressMode() { - case vault.CombinedMode: - user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) { - user.updateCh.SetFrom(event.Address.ID, addrID) + user.apiAddrs[event.Address.ID] = event.Address + + switch user.vault.AddressMode() { + case vault.CombinedMode: + primAddr, err := getAddrIdx(user.apiAddrs, 0) + if err != nil { + return fmt.Errorf("failed to get primary address: %w", err) + } + + user.updateCh.SetFrom(event.Address.ID, primAddr.ID) + + case vault.SplitMode: + user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0)) + } + + user.eventCh.Enqueue(events.UserAddressCreated{ + UserID: user.ID(), + AddressID: event.Address.ID, + Email: event.Address.Email, }) - case vault.SplitMode: - user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0)) - } - - user.eventCh.Enqueue(events.UserAddressCreated{ - UserID: user.ID(), - AddressID: event.Address.ID, - Email: event.Address.Email, - }) - - if user.vault.AddressMode() == vault.SplitMode { - if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error { - return syncLabels(ctx, user.client, updateCh) - }); !ok { - return fmt.Errorf("no such address %q", event.Address.ID) - } else if err != nil { - return fmt.Errorf("failed to sync labels to new address: %w", err) + if user.vault.AddressMode() == vault.SplitMode { + if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error { + return syncLabels(ctx, user.client, updateCh) + }); !ok { + return fmt.Errorf("no such address %q", event.Address.ID) + } else if err != nil { + return fmt.Errorf("failed to sync labels to new address: %w", err) + } } - } - return nil + return nil + }, &user.apiAddrsLock) } func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.AddressEvent) error { //nolint:unparam - if had := user.apiAddrs.Set(event.Address.ID, event.Address); !had { - return fmt.Errorf("address %q does not exist", event.Address.ID) - } + return safe.LockRet(func() error { + if _, ok := user.apiAddrs[event.Address.ID]; ok { + return fmt.Errorf("address %q already exists", event.ID) + } - user.eventCh.Enqueue(events.UserAddressUpdated{ - UserID: user.ID(), - AddressID: event.Address.ID, - Email: event.Address.Email, + user.apiAddrs[event.Address.ID] = event.Address + + user.eventCh.Enqueue(events.UserAddressUpdated{ + UserID: user.ID(), + AddressID: event.Address.ID, + Email: event.Address.Email, + }) + + return nil }) - - return nil } func (user *User) handleDeleteAddressEvent(_ context.Context, event liteapi.AddressEvent) error { - var email string - - if ok := user.apiAddrs.GetDelete(event.ID, func(apiAddr liteapi.Address) { - email = apiAddr.Email - }); !ok { - return fmt.Errorf("no such address %q", event.ID) - } - - if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) { - if user.vault.AddressMode() == vault.SplitMode { - updateCh.CloseAndDiscardQueued() + return safe.LockRet(func() error { + addr, ok := user.apiAddrs[event.ID] + if !ok { + return fmt.Errorf("address %q does not exist", event.ID) } - }); !ok { - return fmt.Errorf("no such address %q", event.ID) - } - user.eventCh.Enqueue(events.UserAddressDeleted{ - UserID: user.ID(), - AddressID: event.ID, - Email: email, + if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) { + if user.vault.AddressMode() == vault.SplitMode { + updateCh.CloseAndDiscardQueued() + } + }); !ok { + return fmt.Errorf("no such address %q", event.ID) + } + + user.eventCh.Enqueue(events.UserAddressDeleted{ + UserID: user.ID(), + AddressID: event.ID, + Email: addr.Email, + }) + + return nil }) - - return nil } // handleLabelEvents handles the given label events. @@ -254,18 +264,20 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me return fmt.Errorf("failed to get full message: %w", err) } - return user.withAddrKR(event.Message.AddressID, func(_, addrKR *crypto.KeyRing) error { - buildRes, err := buildRFC822(full, addrKR) - if err != nil { - return fmt.Errorf("failed to build RFC822 message: %w", err) - } + return safe.RLockRet(func() error { + return withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + buildRes, err := buildRFC822(full, addrKR) + if err != nil { + return fmt.Errorf("failed to build RFC822 message: %w", err) + } - user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) { - updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update)) + user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) { + updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update)) + }) + + return nil }) - - return nil - }) + }, &user.apiUserLock, &user.apiAddrsLock) } func (user *User) handleUpdateMessageEvent(_ context.Context, event liteapi.MessageEvent) error { //nolint:unparam diff --git a/internal/user/imap.go b/internal/user/imap.go index b1d403e1..9a0e391b 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -373,29 +373,31 @@ func (conn *imapConnector) importMessage( ) (imap.Message, []byte, error) { var full liteapi.FullMessage - if err := conn.withAddrKR(conn.addrID, func(_, addrKR *crypto.KeyRing) error { - res, err := stream.Collect(ctx, conn.client.ImportMessages(ctx, addrKR, 1, 1, []liteapi.ImportReq{{ - Metadata: liteapi.ImportMetadata{ - AddressID: conn.addrID, - LabelIDs: labelIDs, - Unread: liteapi.Bool(unread), - Flags: flags, - }, - Message: literal, - }}...)) - if err != nil { - return fmt.Errorf("failed to import message: %w", err) - } + if err := safe.RLockRet(func() error { + return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + res, err := stream.Collect(ctx, conn.client.ImportMessages(ctx, addrKR, 1, 1, []liteapi.ImportReq{{ + Metadata: liteapi.ImportMetadata{ + AddressID: conn.addrID, + LabelIDs: labelIDs, + Unread: liteapi.Bool(unread), + Flags: flags, + }, + Message: literal, + }}...)) + if err != nil { + return fmt.Errorf("failed to import message: %w", err) + } - if full, err = conn.client.GetFullMessage(ctx, res[0].MessageID); err != nil { - return fmt.Errorf("failed to fetch message: %w", err) - } + if full, err = conn.client.GetFullMessage(ctx, res[0].MessageID); err != nil { + return fmt.Errorf("failed to fetch message: %w", err) + } - if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil { - return fmt.Errorf("failed to build message: %w", err) - } + if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil { + return fmt.Errorf("failed to build message: %w", err) + } - return nil + return nil + }) }); err != nil { return imap.Message{}, nil, err } diff --git a/internal/user/keys.go b/internal/user/keys.go index e3809d22..5874cd6d 100644 --- a/internal/user/keys.go +++ b/internal/user/keys.go @@ -21,85 +21,43 @@ import ( "fmt" "github.com/ProtonMail/gopenpgp/v2/crypto" - "github.com/ProtonMail/proton-bridge/v2/internal/safe" "gitlab.protontech.ch/go/liteapi" ) -func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error { - return safe.RLockRet(func() error { - userKR, err := user.apiUser.Keys.Unlock(user.vault.KeyPass(), nil) +func withAddrKR(apiUser liteapi.User, apiAddr liteapi.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error { + userKR, err := apiUser.Keys.Unlock(keyPass, nil) + if err != nil { + return fmt.Errorf("failed to unlock user keys: %w", err) + } + defer userKR.ClearPrivateParams() + + addrKR, err := apiAddr.Keys.Unlock(keyPass, userKR) + if err != nil { + return fmt.Errorf("failed to unlock address keys: %w", err) + } + defer addrKR.ClearPrivateParams() + + return fn(userKR, addrKR) +} + +func withAddrKRs(apiUser liteapi.User, apiAddr map[string]liteapi.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error { + userKR, err := apiUser.Keys.Unlock(keyPass, nil) + if err != nil { + return fmt.Errorf("failed to unlock user keys: %w", err) + } + defer userKR.ClearPrivateParams() + + addrKRs := make(map[string]*crypto.KeyRing, len(apiAddr)) + + for addrID, apiAddr := range apiAddr { + addrKR, err := apiAddr.Keys.Unlock(keyPass, userKR) if err != nil { - return fmt.Errorf("failed to unlock user keys: %w", err) + return fmt.Errorf("failed to unlock address keys: %w", err) } - defer userKR.ClearPrivateParams() + defer addrKR.ClearPrivateParams() - return fn(userKR) - }, &user.apiUserLock) -} - -func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error { - return user.withUserKR(func(userKR *crypto.KeyRing) error { - if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error { - addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR) - if err != nil { - return fmt.Errorf("failed to unlock address keys: %w", err) - } - defer userKR.ClearPrivateParams() - - return fn(userKR, addrKR) - }); !ok { - return fmt.Errorf("no such address %q", addrID) - } else if err != nil { - return err - } - - return nil - }) -} - -func (user *User) withAddrKRByEmail(email string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error { - return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { - addrID, err := getAddrID(apiAddrs, email) - if err != nil { - return fmt.Errorf("failed to get address ID: %w", err) - } - - return user.withUserKR(func(userKR *crypto.KeyRing) error { - if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error { - addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR) - if err != nil { - return fmt.Errorf("failed to unlock address keys: %w", err) - } - defer userKR.ClearPrivateParams() - - return fn(userKR, addrKR) - }); !ok { - return fmt.Errorf("no such address %q", addrID) - } else if err != nil { - return err - } - - return nil - }) - }) -} - -func (user *User) withAddrKRs(fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error { - return user.withUserKR(func(userKR *crypto.KeyRing) error { - return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { - addrKRs := make(map[string]*crypto.KeyRing) - - for _, apiAddr := range apiAddrs { - addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR) - if err != nil { - return fmt.Errorf("failed to unlock address keys: %w", err) - } - defer userKR.ClearPrivateParams() - - addrKRs[apiAddr.ID] = addrKR - } - - return fn(userKR, addrKRs) - }) - }) + addrKRs[addrID] = addrKR + } + + return fn(userKR, addrKRs) } diff --git a/internal/user/keys_test.go b/internal/user/keys_test.go index da678b97..ab322977 100644 --- a/internal/user/keys_test.go +++ b/internal/user/keys_test.go @@ -27,24 +27,6 @@ import ( "gitlab.protontech.ch/go/liteapi/server" ) -func BenchmarkUserKeyRing(b *testing.B) { - b.StopTimer() - - withAPI(b, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { - withAccount(b, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { - withUser(b, ctx, s, m, "username", "password", func(user *User) { - b.StartTimer() - - for i := 0; i < b.N; i++ { - require.NoError(b, user.withUserKR(func(userKR *crypto.KeyRing) error { - return nil - })) - } - }) - }) - }) -} - func BenchmarkAddrKeyRing(b *testing.B) { b.StopTimer() @@ -54,7 +36,7 @@ func BenchmarkAddrKeyRing(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - require.NoError(b, user.withAddrKR(addrIDs[0], func(userKR, addrKR *crypto.KeyRing) error { + require.NoError(b, withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { return nil })) } diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 78d0d334..d11ec63a 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -32,6 +32,7 @@ import ( "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-rfc5322" "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/pkg/message" "github.com/ProtonMail/proton-bridge/v2/pkg/message/parser" @@ -85,54 +86,61 @@ func (user *User) sendMail(authID string, emails []string, from string, to []str return fmt.Errorf("failed to get mail settings: %w", err) } - return user.withAddrKRByEmail(from, func(userKR, addrKR *crypto.KeyRing) error { - // Use the first key for encrypting the message. - addrKR, err := addrKR.FirstKey() + return safe.LockRet(func() error { + addrID, err := getAddrID(user.apiAddrs, from) if err != nil { - return fmt.Errorf("failed to get first key: %w", err) + return err } - // If we have to attach the public key, do it now. - if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { - key, err := addrKR.GetKey(0) + return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error { + // Use the first key for encrypting the message. + addrKR, err := addrKR.FirstKey() if err != nil { - return fmt.Errorf("failed to get sending key: %w", err) + return fmt.Errorf("failed to get first key: %w", err) } - pubKey, err := key.GetArmoredPublicKey() - if err != nil { - return fmt.Errorf("failed to get public key: %w", err) + // If we have to attach the public key, do it now. + if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { + key, err := addrKR.GetKey(0) + if err != nil { + return fmt.Errorf("failed to get sending key: %w", err) + } + + pubKey, err := key.GetArmoredPublicKey() + if err != nil { + return fmt.Errorf("failed to get public key: %w", err) + } + + parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) } - parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) - } + // Parse the message we want to send (after we have attached the public key). + message, err := message.ParseWithParser(parser) + if err != nil { + return fmt.Errorf("failed to parse message: %w", err) + } - // Parse the message we want to send (after we have attached the public key). - message, err := message.ParseWithParser(parser) - if err != nil { - return fmt.Errorf("failed to parse message: %w", err) - } + // Send the message using the correct key. + sent, err := sendWithKey( + ctx, + user.client, + authID, + user.vault.AddressMode(), + settings, + userKR, addrKR, + emails, from, to, + message, + ) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } - // Send the message using the correct key. - sent, err := sendWithKey( - ctx, - user.client, - authID, - user.vault.AddressMode(), - settings, - userKR, addrKR, - emails, from, to, - message, - ) - if err != nil { - return fmt.Errorf("failed to send message: %w", err) - } + // If the message was successfully sent, we can update the message ID in the record. + user.sendHash.addMessageID(hash, sent.ID) - // If the message was successfully sent, we can update the message ID in the record. - user.sendHash.addMessageID(hash, sent.ID) - - return nil - }) + return nil + }) + }, &user.apiUserLock, &user.apiAddrsLock) } // sendWithKey sends the message with the given address key. diff --git a/internal/user/sync.go b/internal/user/sync.go index 9a8584c4..d056aa7b 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -29,6 +29,7 @@ import ( "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" @@ -72,45 +73,47 @@ func (user *User) doSync(ctx context.Context) error { } func (user *User) sync(ctx context.Context) error { - return user.withAddrKRs(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { - if !user.vault.SyncStatus().HasLabels { - user.log.Debug("Syncing labels") + return safe.RLockRet(func() error { + return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { + if !user.vault.SyncStatus().HasLabels { + user.log.Debug("Syncing labels") - if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error { - return syncLabels(ctx, user.client, xslices.Unique(updateCh)...) - }); err != nil { - return fmt.Errorf("failed to sync labels: %w", err) + if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error { + return syncLabels(ctx, user.client, xslices.Unique(updateCh)...) + }); err != nil { + return fmt.Errorf("failed to sync labels: %w", err) + } + + if err := user.vault.SetHasLabels(true); err != nil { + return fmt.Errorf("failed to set has labels: %w", err) + } + + user.log.Debug("Synced labels") + } else { + user.log.Debug("Labels are already synced, skipping") } - if err := user.vault.SetHasLabels(true); err != nil { - return fmt.Errorf("failed to set has labels: %w", err) + if !user.vault.SyncStatus().HasMessages { + user.log.Debug("Syncing messages") + + if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error { + return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh) + }); err != nil { + return fmt.Errorf("failed to sync messages: %w", err) + } + + if err := user.vault.SetHasMessages(true); err != nil { + return fmt.Errorf("failed to set has messages: %w", err) + } + + user.log.Debug("Synced messages") + } else { + user.log.Debug("Messages are already synced, skipping") } - user.log.Debug("Synced labels") - } else { - user.log.Debug("Labels are already synced, skipping") - } - - if !user.vault.SyncStatus().HasMessages { - user.log.Debug("Syncing messages") - - if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error { - return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh) - }); err != nil { - return fmt.Errorf("failed to sync messages: %w", err) - } - - if err := user.vault.SetHasMessages(true); err != nil { - return fmt.Errorf("failed to set has messages: %w", err) - } - - user.log.Debug("Synced messages") - } else { - user.log.Debug("Messages are already synced, skipping") - } - - return nil - }) + return nil + }) + }, &user.apiUserLock, &user.apiAddrsLock) } func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error { diff --git a/internal/user/types.go b/internal/user/types.go index ebe4833b..17a79a22 100644 --- a/internal/user/types.go +++ b/internal/user/types.go @@ -24,6 +24,8 @@ import ( "strings" "gitlab.protontech.ch/go/liteapi" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" ) // mapTo converts the slice to the given type. @@ -82,7 +84,7 @@ func hexDecode(b []byte) ([]byte, error) { } // getAddrID returns the address ID for the given email address. -func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) { +func getAddrID(apiAddrs map[string]liteapi.Address, email string) (string, error) { for _, addr := range apiAddrs { if strings.EqualFold(addr.Email, sanitizeEmail(email)) { return addr.ID, nil @@ -91,3 +93,27 @@ func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) { return "", fmt.Errorf("address %s not found", email) } + +// getAddrIdx returns the address with the given index. +func getAddrIdx(apiAddrs map[string]liteapi.Address, idx int) (liteapi.Address, error) { + sorted := sortSlice(maps.Values(apiAddrs), func(a, b liteapi.Address) bool { + return a.Order < b.Order + }) + + if idx < 0 || idx >= len(sorted) { + return liteapi.Address{}, fmt.Errorf("address index %d out of range", idx) + } + + return sorted[idx], nil +} + +// sortSlice returns the given slice sorted by the given comparator. +func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item { + sorted := make([]Item, len(items)) + + copy(sorted, items) + + slices.SortFunc(sorted, less) + + return sorted +} diff --git a/internal/user/user.go b/internal/user/user.go index 1fa9819a..f9890f81 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -38,6 +38,7 @@ import ( "github.com/bradenaw/juniper/xsync" "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" + "golang.org/x/exp/maps" ) var ( @@ -55,7 +56,9 @@ type User struct { apiUser liteapi.User apiUserLock sync.RWMutex - apiAddrs *safe.Map[string, liteapi.Address] + apiAddrs map[string]liteapi.Address + apiAddrsLock sync.RWMutex + apiLabels *safe.Map[string, liteapi.Label] updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] sendHash *sendRecorder @@ -134,7 +137,7 @@ func New( eventCh: queue.NewQueuedChannel[events.Event](0, 0), apiUser: apiUser, - apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), + apiAddrs: groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), apiLabels: safe.NewMapFrom(groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), nil), updateCh: safe.NewMapFrom(updateCh, nil), sendHash: newSendRecorder(sendEntryExpiry), @@ -217,19 +220,23 @@ func (user *User) Match(query string) bool { return true } - return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool { - return addr.Email == query - }) - }, &user.apiUserLock) + for _, addr := range user.apiAddrs { + if query == addr.Email { + return true + } + } + + return false + }, &user.apiUserLock, &user.apiAddrsLock) } // Emails returns all the user's email addresses via the callback. func (user *User) Emails() []string { - return safe.MapValuesRet(user.apiAddrs, func(apiAddrs []liteapi.Address) []string { - return xslices.Map(apiAddrs, func(addr liteapi.Address) string { + return safe.RLockRet(func() []string { + return xslices.Map(maps.Values(user.apiAddrs), func(addr liteapi.Address) string { return addr.Email }) - }) + }, &user.apiAddrsLock) } // GetAddressMode returns the user's current address mode. @@ -242,37 +249,39 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er user.abortable.Abort() defer user.goSync() - user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { - for _, updateCh := range xslices.Unique(updateCh) { - updateCh.CloseAndDiscardQueued() + return safe.RLockRet(func() error { + user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { + for _, updateCh := range xslices.Unique(updateCh) { + updateCh.CloseAndDiscardQueued() + } + }) + + user.updateCh.Clear() + + switch mode { + case vault.CombinedMode: + primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) + + for addrID := range user.apiAddrs { + user.updateCh.Set(addrID, primaryUpdateCh) + } + + case vault.SplitMode: + for addrID := range user.apiAddrs { + user.updateCh.Set(addrID, queue.NewQueuedChannel[imap.Update](0, 0)) + } } - }) - user.updateCh.Clear() + if err := user.vault.SetAddressMode(mode); err != nil { + return fmt.Errorf("failed to set address mode: %w", err) + } - switch mode { - case vault.CombinedMode: - primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) + if err := user.vault.ClearSyncStatus(); err != nil { + return fmt.Errorf("failed to clear sync status: %w", err) + } - user.apiAddrs.IterKeys(func(addrID string) { - user.updateCh.Set(addrID, primaryUpdateCh) - }) - - case vault.SplitMode: - user.apiAddrs.IterKeys(func(addrID string) { - user.updateCh.Set(addrID, queue.NewQueuedChannel[imap.Update](0, 0)) - }) - } - - if err := user.vault.SetAddressMode(mode); err != nil { - return fmt.Errorf("failed to set address mode: %w", err) - } - - if err := user.vault.ClearSyncStatus(); err != nil { - return fmt.Errorf("failed to clear sync status: %w", err) - } - - return nil + return nil + }, &user.apiAddrsLock) } // GetGluonIDs returns the users gluon IDs. @@ -334,21 +343,26 @@ func (user *User) NewIMAPConnector(addrID string) connector.Connector { // In combined mode, this is just the user's primary address. // In split mode, this is all the user's addresses. func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) { - imapConn := make(map[string]connector.Connector) + return safe.RLockRetErr(func() (map[string]connector.Connector, error) { + imapConn := make(map[string]connector.Connector) - switch user.vault.AddressMode() { - case vault.CombinedMode: - user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) { - imapConn[addrID] = newIMAPConnector(user, addrID) - }) + switch user.vault.AddressMode() { + case vault.CombinedMode: + primAddr, err := getAddrIdx(user.apiAddrs, 0) + if err != nil { + return nil, fmt.Errorf("failed to get primary address: %w", err) + } - case vault.SplitMode: - user.apiAddrs.IterKeys(func(addrID string) { - imapConn[addrID] = newIMAPConnector(user, addrID) - }) - } + imapConn[primAddr.ID] = newIMAPConnector(user, primAddr.ID) - return imapConn, nil + case vault.SplitMode: + for addrID := range user.apiAddrs { + imapConn[addrID] = newIMAPConnector(user, addrID) + } + } + + return imapConn, nil + }, &user.apiAddrsLock) } // SendMail sends an email from the given address to the given recipients. @@ -357,17 +371,17 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader) return ErrInvalidRecipient } - return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { - if _, err := getAddrID(apiAddrs, from); err != nil { + return safe.RLockRet(func() error { + if _, err := getAddrID(user.apiAddrs, from); err != nil { return ErrInvalidReturnPath } - emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string { + emails := xslices.Map(maps.Values(user.apiAddrs), func(addr liteapi.Address) string { return addr.Email }) return user.sendMail(authID, emails, from, to, r) - }) + }, &user.apiAddrsLock) } // CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user. @@ -382,15 +396,15 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) { return "", fmt.Errorf("invalid password") } - return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) { - for _, addr := range apiAddrs { + return safe.RLockRetErr(func() (string, error) { + for _, addr := range user.apiAddrs { if strings.EqualFold(addr.Email, email) { return addr.ID, nil } } return "", fmt.Errorf("invalid email") - }) + }, &user.apiAddrsLock) } // OnStatusUp is called when the connection goes up.