From fd80848fcdcc452b863d5eba7b46f5f64273552c Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Thu, 27 Oct 2022 00:50:26 +0200 Subject: [PATCH] Other(refactor): Use normal value + mutex for user.updateCh --- internal/user/events.go | 71 +++++++++---------- internal/user/imap.go | 14 ++-- internal/user/smtp.go | 105 ---------------------------- internal/user/sync.go | 11 ++- internal/user/user.go | 150 +++++++++++++++++++++++++++++++++------- 5 files changed, 165 insertions(+), 186 deletions(-) diff --git a/internal/user/events.go b/internal/user/events.go index 522d0963..d7628c8e 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -113,18 +113,14 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad return fmt.Errorf("failed to get primary address: %w", err) } - user.updateCh.SetFrom(event.Address.ID, primAddr.ID) + user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID] case vault.SplitMode: - user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0)) + user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0) } if user.vault.AddressMode() == vault.SplitMode { - if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error { - return syncLabels(ctx, user.client, updateCh) - }); !ok { - return fmt.Errorf("no such address %q", event.Address.ID) - } else if err != nil { + if err := syncLabels(ctx, user.client, user.updateCh[event.Address.ID]); err != nil { return fmt.Errorf("failed to sync labels to new address: %w", err) } } @@ -136,7 +132,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad }) return nil - }, &user.apiAddrsLock) + }, &user.apiAddrsLock, &user.updateChLock) } func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.AddressEvent) error { //nolint:unparam @@ -154,7 +150,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.Addr }) return nil - }) + }, &user.apiAddrsLock) } func (user *User) handleDeleteAddressEvent(_ context.Context, event liteapi.AddressEvent) error { @@ -164,14 +160,13 @@ func (user *User) handleDeleteAddressEvent(_ context.Context, event liteapi.Addr return fmt.Errorf("address %q does not exist", event.ID) } - if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) { - if user.vault.AddressMode() == vault.SplitMode { - updateCh.CloseAndDiscardQueued() - } - }); !ok { - return fmt.Errorf("no such address %q", event.ID) + if user.vault.AddressMode() == vault.SplitMode { + user.updateCh[event.ID].CloseAndDiscardQueued() + delete(user.updateCh, event.ID) } + delete(user.apiAddrs, event.ID) + user.eventCh.Enqueue(events.UserAddressDeleted{ UserID: user.ID(), AddressID: event.ID, @@ -179,7 +174,7 @@ func (user *User) handleDeleteAddressEvent(_ context.Context, event liteapi.Addr }) return nil - }) + }, &user.apiAddrsLock, &user.updateChLock) } // handleLabelEvents handles the given label events. @@ -214,9 +209,9 @@ func (user *User) handleCreateLabelEvent(_ context.Context, event liteapi.LabelE user.apiLabels[event.Label.ID] = event.Label - user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { + for _, updateCh := range user.updateCh { updateCh.Enqueue(newMailboxCreatedUpdate(imap.MailboxID(event.ID), getMailboxName(event.Label))) - }) + } user.eventCh.Enqueue(events.UserLabelCreated{ UserID: user.ID(), @@ -225,7 +220,7 @@ func (user *User) handleCreateLabelEvent(_ context.Context, event liteapi.LabelE }) return nil - }, &user.apiLabelsLock) + }, &user.apiLabelsLock, &user.updateChLock) } func (user *User) handleUpdateLabelEvent(_ context.Context, event liteapi.LabelEvent) error { //nolint:unparam @@ -236,9 +231,9 @@ func (user *User) handleUpdateLabelEvent(_ context.Context, event liteapi.LabelE user.apiLabels[event.Label.ID] = event.Label - user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { + for _, updateCh := range user.updateCh { updateCh.Enqueue(imap.NewMailboxUpdated(imap.MailboxID(event.ID), getMailboxName(event.Label))) - }) + } user.eventCh.Enqueue(events.UserLabelUpdated{ UserID: user.ID(), @@ -247,7 +242,7 @@ func (user *User) handleUpdateLabelEvent(_ context.Context, event liteapi.LabelE }) return nil - }, &user.apiLabelsLock) + }, &user.apiLabelsLock, &user.updateChLock) } func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelEvent) error { //nolint:unparam @@ -259,9 +254,9 @@ func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelE delete(user.apiLabels, event.ID) - user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { + for _, updateCh := range user.updateCh { updateCh.Enqueue(imap.NewMailboxDeleted(imap.MailboxID(event.ID))) - }) + } user.eventCh.Enqueue(events.UserLabelDeleted{ UserID: user.ID(), @@ -270,7 +265,7 @@ func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelE }) return nil - }, &user.apiLabelsLock) + }, &user.apiLabelsLock, &user.updateChLock) } // handleMessageEvents handles the given message events. @@ -308,28 +303,26 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me return fmt.Errorf("failed to build RFC822 message: %w", err) } - user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) { - updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update)) - }) + user.updateCh[full.AddressID].Enqueue(imap.NewMessagesCreated(buildRes.update)) return nil }) - }, &user.apiUserLock, &user.apiAddrsLock) + }, &user.apiUserLock, &user.apiAddrsLock, &user.updateChLock) } func (user *User) handleUpdateMessageEvent(_ context.Context, event liteapi.MessageEvent) error { //nolint:unparam - update := imap.NewMessageMailboxesUpdated( - imap.MessageID(event.ID), - mapTo[string, imap.MailboxID](xslices.Filter(event.Message.LabelIDs, wantLabelID)), - event.Message.Seen(), - event.Message.Starred(), - ) + return safe.RLockRet(func() error { + update := imap.NewMessageMailboxesUpdated( + imap.MessageID(event.ID), + mapTo[string, imap.MailboxID](xslices.Filter(event.Message.LabelIDs, wantLabelID)), + event.Message.Seen(), + event.Message.Starred(), + ) - user.updateCh.Get(event.Message.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) { - updateCh.Enqueue(update) - }) + user.updateCh[event.Message.AddressID].Enqueue(update) - return nil + return nil + }, &user.updateChLock) } func getMailboxName(label liteapi.Label) []string { diff --git a/internal/user/imap.go b/internal/user/imap.go index f1310292..107c8bd1 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -24,7 +24,6 @@ import ( "time" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/safe" @@ -349,14 +348,9 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [ // GetUpdates returns a stream of updates that the gluon server should apply. // It is recommended that the returned channel is buffered with at least constants.ChannelBufferCount. func (conn *imapConnector) GetUpdates() <-chan imap.Update { - updateCh, ok := safe.MapGetRet(conn.updateCh, conn.addrID, func(updateCh *queue.QueuedChannel[imap.Update]) <-chan imap.Update { - return updateCh.GetChannel() - }) - if !ok { - panic(fmt.Sprintf("update channel for %q not found", conn.addrID)) - } - - return updateCh + return safe.RLockRet(func() <-chan imap.Update { + return conn.updateCh[conn.addrID].GetChannel() + }, &conn.updateChLock) } // GetUIDValidity returns the default UID validity for this user. @@ -413,7 +407,7 @@ func (conn *imapConnector) importMessage( return nil }) - }); err != nil { + }, &conn.apiUserLock, &conn.apiAddrsLock); err != nil { return imap.Message{}, nil, err } diff --git a/internal/user/smtp.go b/internal/user/smtp.go index d11ec63a..9edc8243 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -18,21 +18,17 @@ package user import ( - "bytes" "context" "encoding/base64" "fmt" - "io" "net/mail" "net/url" "runtime" "strings" - "time" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-rfc5322" "github.com/ProtonMail/gopenpgp/v2/crypto" - "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/pkg/message" "github.com/ProtonMail/proton-bridge/v2/pkg/message/parser" @@ -42,107 +38,6 @@ import ( "golang.org/x/exp/slices" ) -func (user *User) sendMail(authID string, emails []string, from string, to []string, r io.Reader) error { //nolint:funlen - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Read the message to send. - b, err := io.ReadAll(r) - if err != nil { - return fmt.Errorf("failed to read message: %w", err) - } - - // Compute the hash of the message (to match it against SMTP messages). - hash, err := getMessageHash(b) - if err != nil { - return err - } - - // Check if we already tried to send this message recently. - if ok, err := user.sendHash.tryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second)); err != nil { - return fmt.Errorf("failed to check send hash: %w", err) - } else if !ok { - user.log.Warn("A duplicate message was already sent recently, skipping") - return nil - } - - // If we fail to send this message, we should remove the hash from the send recorder. - defer user.sendHash.removeOnFail(hash) - - // Create a new message parser from the reader. - parser, err := parser.New(bytes.NewReader(b)) - if err != nil { - return fmt.Errorf("failed to create parser: %w", err) - } - - // If the message contains a sender, use it instead of the one from the return path. - if sender, ok := getMessageSender(parser); ok { - from = sender - } - - // Load the user's mail settings. - settings, err := user.client.GetMailSettings(ctx) - if err != nil { - return fmt.Errorf("failed to get mail settings: %w", err) - } - - return safe.LockRet(func() error { - addrID, err := getAddrID(user.apiAddrs, from) - if err != nil { - return err - } - - return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error { - // Use the first key for encrypting the message. - addrKR, err := addrKR.FirstKey() - if err != nil { - return fmt.Errorf("failed to get first key: %w", err) - } - - // If we have to attach the public key, do it now. - if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { - key, err := addrKR.GetKey(0) - if err != nil { - return fmt.Errorf("failed to get sending key: %w", err) - } - - pubKey, err := key.GetArmoredPublicKey() - if err != nil { - return fmt.Errorf("failed to get public key: %w", err) - } - - parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) - } - - // Parse the message we want to send (after we have attached the public key). - message, err := message.ParseWithParser(parser) - if err != nil { - return fmt.Errorf("failed to parse message: %w", err) - } - - // Send the message using the correct key. - sent, err := sendWithKey( - ctx, - user.client, - authID, - user.vault.AddressMode(), - settings, - userKR, addrKR, - emails, from, to, - message, - ) - if err != nil { - return fmt.Errorf("failed to send message: %w", err) - } - - // If the message was successfully sent, we can update the message ID in the record. - user.sendHash.addMessageID(hash, sent.ID) - - return nil - }) - }, &user.apiUserLock, &user.apiAddrsLock) -} - // sendWithKey sends the message with the given address key. func sendWithKey( //nolint:funlen ctx context.Context, diff --git a/internal/user/sync.go b/internal/user/sync.go index d056aa7b..e3c9d599 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -35,6 +35,7 @@ import ( "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" "gitlab.protontech.ch/go/liteapi" + "golang.org/x/exp/maps" ) const ( @@ -78,9 +79,7 @@ func (user *User) sync(ctx context.Context) error { if !user.vault.SyncStatus().HasLabels { user.log.Debug("Syncing labels") - if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error { - return syncLabels(ctx, user.client, xslices.Unique(updateCh)...) - }); err != nil { + if err := syncLabels(ctx, user.client, xslices.Unique(maps.Values(user.updateCh))...); err != nil { return fmt.Errorf("failed to sync labels: %w", err) } @@ -96,9 +95,7 @@ func (user *User) sync(ctx context.Context) error { if !user.vault.SyncStatus().HasMessages { user.log.Debug("Syncing messages") - if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error { - return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh) - }); err != nil { + if err := syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, user.updateCh, user.eventCh); err != nil { return fmt.Errorf("failed to sync messages: %w", err) } @@ -113,7 +110,7 @@ func (user *User) sync(ctx context.Context) error { return nil }) - }, &user.apiUserLock, &user.apiAddrsLock) + }, &user.apiUserLock, &user.apiAddrsLock, &user.updateChLock) } func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error { diff --git a/internal/user/user.go b/internal/user/user.go index bc18bd01..c6e5d68c 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -18,6 +18,7 @@ package user import ( + "bytes" "context" "crypto/subtle" "fmt" @@ -30,10 +31,13 @@ import ( "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/queue" + "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/ProtonMail/proton-bridge/v2/pkg/message" + "github.com/ProtonMail/proton-bridge/v2/pkg/message/parser" "github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xsync" "github.com/sirupsen/logrus" @@ -49,9 +53,10 @@ var ( type User struct { log *logrus.Entry - vault *vault.User - client *liteapi.Client - eventCh *queue.QueuedChannel[events.Event] + vault *vault.User + client *liteapi.Client + eventCh *queue.QueuedChannel[events.Event] + sendHash *sendRecorder apiUser liteapi.User apiUserLock sync.RWMutex @@ -62,8 +67,8 @@ type User struct { apiLabels map[string]liteapi.Label apiLabelsLock sync.RWMutex - updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] - sendHash *sendRecorder + updateCh map[string]*queue.QueuedChannel[imap.Update] + updateChLock sync.RWMutex tasks *xsync.Group abortable async.Abortable @@ -134,15 +139,15 @@ func New( user := &User{ log: logrus.WithField("userID", apiUser.ID), - vault: encVault, - client: client, - eventCh: queue.NewQueuedChannel[events.Event](0, 0), + vault: encVault, + client: client, + eventCh: queue.NewQueuedChannel[events.Event](0, 0), + sendHash: newSendRecorder(sendEntryExpiry), apiUser: apiUser, apiAddrs: groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), apiLabels: groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), - updateCh: safe.NewMapFrom(updateCh, nil), - sendHash: newSendRecorder(sendEntryExpiry), + updateCh: updateCh, tasks: xsync.NewGroup(context.Background()), @@ -251,26 +256,24 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er user.abortable.Abort() defer user.goSync() - return safe.RLockRet(func() error { - user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { - for _, updateCh := range xslices.Unique(updateCh) { - updateCh.CloseAndDiscardQueued() - } - }) + return safe.LockRet(func() error { + for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) { + updateCh.CloseAndDiscardQueued() + } - user.updateCh.Clear() + user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update]) switch mode { case vault.CombinedMode: primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) for addrID := range user.apiAddrs { - user.updateCh.Set(addrID, primaryUpdateCh) + user.updateCh[addrID] = primaryUpdateCh } case vault.SplitMode: for addrID := range user.apiAddrs { - user.updateCh.Set(addrID, queue.NewQueuedChannel[imap.Update](0, 0)) + user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) } } @@ -283,7 +286,7 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er } return nil - }, &user.apiAddrsLock) + }, &user.apiAddrsLock, &user.updateChLock) } // GetGluonIDs returns the users gluon IDs. @@ -368,7 +371,12 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) { } // SendMail sends an email from the given address to the given recipients. +// +// nolint:funlen func (user *User) SendMail(authID string, from string, to []string, r io.Reader) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if len(to) == 0 { return ErrInvalidRecipient } @@ -382,8 +390,100 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader) return addr.Email }) - return user.sendMail(authID, emails, from, to, r) - }, &user.apiAddrsLock) + // Read the message to send. + b, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("failed to read message: %w", err) + } + + // Compute the hash of the message (to match it against SMTP messages). + hash, err := getMessageHash(b) + if err != nil { + return err + } + + // Check if we already tried to send this message recently. + if ok, err := user.sendHash.tryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second)); err != nil { + return fmt.Errorf("failed to check send hash: %w", err) + } else if !ok { + user.log.Warn("A duplicate message was already sent recently, skipping") + return nil + } + + // If we fail to send this message, we should remove the hash from the send recorder. + defer user.sendHash.removeOnFail(hash) + + // Create a new message parser from the reader. + parser, err := parser.New(bytes.NewReader(b)) + if err != nil { + return fmt.Errorf("failed to create parser: %w", err) + } + + // If the message contains a sender, use it instead of the one from the return path. + if sender, ok := getMessageSender(parser); ok { + from = sender + } + + // Load the user's mail settings. + settings, err := user.client.GetMailSettings(ctx) + if err != nil { + return fmt.Errorf("failed to get mail settings: %w", err) + } + + addrID, err := getAddrID(user.apiAddrs, from) + if err != nil { + return err + } + + return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error { + // Use the first key for encrypting the message. + addrKR, err := addrKR.FirstKey() + if err != nil { + return fmt.Errorf("failed to get first key: %w", err) + } + + // If we have to attach the public key, do it now. + if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { + key, err := addrKR.GetKey(0) + if err != nil { + return fmt.Errorf("failed to get sending key: %w", err) + } + + pubKey, err := key.GetArmoredPublicKey() + if err != nil { + return fmt.Errorf("failed to get public key: %w", err) + } + + parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) + } + + // Parse the message we want to send (after we have attached the public key). + message, err := message.ParseWithParser(parser) + if err != nil { + return fmt.Errorf("failed to parse message: %w", err) + } + + // Send the message using the correct key. + sent, err := sendWithKey( + ctx, + user.client, + authID, + user.vault.AddressMode(), + settings, + userKR, addrKR, + emails, from, to, + message, + ) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + // If the message was successfully sent, we can update the message ID in the record. + user.sendHash.addMessageID(hash, sent.ID) + + return nil + }) + }, &user.apiUserLock, &user.apiAddrsLock) } // CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user. @@ -445,11 +545,11 @@ func (user *User) Close() { user.client.Close() // Close the user's update channels. - user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { - for _, updateCh := range xslices.Unique(updateCh) { + safe.RLock(func() { + for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) { updateCh.CloseAndDiscardQueued() } - }) + }, &user.updateChLock) // Close the user's notify channel. user.eventCh.CloseAndDiscardQueued()