diff --git a/internal/safe/map.go b/internal/safe/map.go index 39a2d58f..f3f08228 100644 --- a/internal/safe/map.go +++ b/internal/safe/map.go @@ -64,6 +64,13 @@ func (m *Map[Key, Val]) Set(key Key, val Val) { m.data[key] = val } +func (m *Map[Key, Val]) Delete(key Key) { + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.data, key) +} + func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) { m.lock.RLock() defer m.lock.RUnlock() diff --git a/internal/user/events.go b/internal/user/events.go index cf5a1219..761db23f 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -6,6 +6,7 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/queue" + "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" @@ -57,7 +58,7 @@ func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) e user.apiUser.Set(userEvent) - user.userKR = userKR + user.userKR.Set(userKR) user.eventCh.Enqueue(events.UserChanged{ UserID: user.ID(), @@ -92,7 +93,9 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea } func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { - addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR) + addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) { + return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR) + }) if err != nil { return fmt.Errorf("failed to unlock address keys: %w", err) } @@ -104,7 +107,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad user.apiAddrs.Set(apiAddrs) - user.addrKRs[event.Address.ID] = addrKR + user.addrKRs.Set(event.Address.ID, addrKR) user.eventCh.Enqueue(events.UserAddressCreated{ UserID: user.ID(), @@ -124,7 +127,9 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad } func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { - addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR) + addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) { + return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR) + }) if err != nil { return fmt.Errorf("failed to unlock address keys: %w", err) } @@ -136,7 +141,7 @@ func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.Ad user.apiAddrs.Set(apiAddrs) - user.addrKRs[event.Address.ID] = addrKR + user.addrKRs.Set(event.Address.ID, addrKR) user.eventCh.Enqueue(events.UserAddressUpdated{ UserID: user.ID(), @@ -162,7 +167,7 @@ func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.Ad user.apiAddrs.Set(apiAddrs) - delete(user.addrKRs, event.ID) + user.addrKRs.Delete(event.ID) if len(user.updateCh) > 1 { user.updateCh[event.ID].Close() @@ -264,7 +269,16 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me return fmt.Errorf("failed to get full message: %w", err) } - buildRes, err := buildRFC822(ctx, full, user.addrKRs) + buildRes, err := safe.GetMapErr( + user.addrKRs, + full.AddressID, + func(addrKR *crypto.KeyRing) (*buildRes, error) { + return buildRFC822(ctx, full, addrKR) + }, + func() (*buildRes, error) { + return nil, fmt.Errorf("address keyring not found") + }, + ) if err != nil { return fmt.Errorf("failed to build RFC822: %w", err) } diff --git a/internal/user/smtp.go b/internal/user/smtp.go index b06eea21..3a0e8193 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -31,8 +31,8 @@ type smtpSession struct { // authID holds the ID of the address that the SMTP client authenticated with to send the message. authID string - // fromAddrID is the ID of the current sending address (taken from the return path). - fromAddrID string + // from is the current sending address (taken from the return path). + from string // to holds all to for the current message. to []string @@ -57,7 +57,7 @@ func (session *smtpSession) Reset() { logrus.Info("SMTP session reset") // Clear the from and to fields. - session.fromAddrID = "" + session.from = "" session.to = nil } @@ -93,12 +93,11 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error { } } - fromAddrID, err := getAddrID(apiAddrs, from) - if err != nil { + if _, err := getAddrID(apiAddrs, sanitizeEmail(from)); err != nil { return fmt.Errorf("invalid return path: %w", err) } - session.fromAddrID = fromAddrID + session.from = from return nil }) @@ -123,8 +122,11 @@ func (session *smtpSession) Rcpt(to string) error { func (session *smtpSession) Data(r io.Reader) error { logrus.Info("SMTP session data") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + switch { - case session.fromAddrID == "": + case session.from == "": return ErrInvalidReturnPath case len(session.to) == 0: @@ -136,53 +138,61 @@ func (session *smtpSession) Data(r io.Reader) error { return fmt.Errorf("failed to create parser: %w", err) } - // If the message contains a sender, use it instead of the one from the return path. - session.apiAddrs.Get(func(apiAddrs []liteapi.Address) { - if sender, ok := getMessageSender(parser); ok { - for _, addr := range apiAddrs { - if strings.EqualFold(addr.Email, sanitizeEmail(sender)) { - session.fromAddrID = addr.ID - } - } - } - }) - - addrKR, ok := session.addrKRs[session.fromAddrID] - if !ok { - return ErrMissingAddrKey - } - - firstAddrKR, err := addrKR.FirstKey() - if err != nil { - return fmt.Errorf("failed to get first key: %w", err) - } - - from, err := safe.GetSliceErr(session.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) { - email, err := getAddrEmail(apiAddrs, session.fromAddrID) - if err != nil { - return "", fmt.Errorf("failed to get address email: %w", err) - } - - return sanitizeEmail(email), nil - }) - if err != nil { - return fmt.Errorf("failed to get address email: %w", err) - } - message, err := safe.GetSliceErr(session.apiAddrs, func(apiAddrs []liteapi.Address) (liteapi.Message, error) { - return safe.GetTypeErr(session.settings, func(settings liteapi.MailSettings) (liteapi.Message, error) { - return sendWithKey( - session.client, - session.authID, - session.vault.AddressMode(), - apiAddrs, - settings, - session.userKR, - firstAddrKR, - parser, - from, - session.to, - ) + addrID, err := getAddrID(apiAddrs, session.from) + if err != nil { + return liteapi.Message{}, fmt.Errorf("invalid return path: %w", err) + } + + return safe.GetMapErr(session.addrKRs, addrID, func(addrKR *crypto.KeyRing) (liteapi.Message, error) { + return safe.GetTypeErr(session.settings, func(settings liteapi.MailSettings) (liteapi.Message, error) { + // Use the first key for encrypting the message. + addrKR, err := addrKR.FirstKey() + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to get first key: %w", err) + } + + // If the message contains a sender, use it instead of the one from the return path. + if sender, ok := getMessageSender(parser); ok { + session.from = sender + } + + // If we have to attach the public key, do it now. + if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { + key, err := addrKR.GetKey(0) + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to get sending key: %w", err) + } + + pubKey, err := key.GetArmoredPublicKey() + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to get public key: %w", err) + } + + parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) + } + + message, err := message.ParseWithParser(parser) + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to parse message: %w", err) + } + + return sendWithKey( + ctx, + session.client, + session.authID, + session.vault.AddressMode(), + apiAddrs, + settings, + session.userKR, + addrKR, + session.from, + session.to, + message, + ) + }) + }, func() (liteapi.Message, error) { + return liteapi.Message{}, ErrMissingAddrKey }) }) if err != nil { @@ -196,53 +206,67 @@ func (session *smtpSession) Data(r io.Reader) error { // sendWithKey sends the message with the given address key. func sendWithKey( + ctx context.Context, client *liteapi.Client, authAddrID string, addrMode vault.AddressMode, apiAddrs []liteapi.Address, settings liteapi.MailSettings, - userKR, addrKR *crypto.KeyRing, - parser *parser.Parser, + userKR *safe.Value[*crypto.KeyRing], + addrKR *crypto.KeyRing, from string, to []string, + message message.Message, ) (liteapi.Message, error) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { - key, err := addrKR.GetKey(0) - if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to get user public key: %w", err) - } - - pubKey, err := key.GetArmoredPublicKey() - if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to get user public key: %w", err) - } - - parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) - } - - message, err := message.ParseWithParser(parser) - if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to parse message: %w", err) - } - - if err := sanitizeParsedMessage(&message, apiAddrs, from, to); err != nil { - return liteapi.Message{}, fmt.Errorf("failed to sanitize message: %w", err) - } - parentID, err := getParentID(ctx, client, authAddrID, addrMode, message.References) if err != nil { return liteapi.Message{}, fmt.Errorf("failed to get parent ID: %w", err) } - draft, attKeys, err := createDraftWithAttachments(ctx, client, addrKR, message, parentID) - if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to create draft: %w", err) + var decBody string + + switch message.MIMEType { + case rfc822.TextHTML: + decBody = string(message.RichBody) + + case rfc822.TextPlain: + decBody = string(message.PlainBody) } - recipients, err := getRecipients(ctx, client, userKR, settings, message.Recipients(), message.MIMEType) + encBody, err := addrKR.Encrypt(crypto.NewPlainMessageFromString(decBody), nil) + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to encrypt message body: %w", err) + } + + armBody, err := encBody.GetArmored() + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to get armored message body: %w", err) + } + + draft, err := createDraft(ctx, client, apiAddrs, from, to, parentID, liteapi.DraftTemplate{ + Subject: message.Subject, + Body: armBody, + MIMEType: message.MIMEType, + + Sender: message.Sender, + ToList: message.ToList, + CCList: message.CCList, + BCCList: message.BCCList, + + ExternalID: message.ExternalID, + }) + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to create attachments: %w", err) + } + + attKeys, err := createAttachments(ctx, client, addrKR, draft.ID, message.Attachments) + if err != nil { + return liteapi.Message{}, fmt.Errorf("failed to create attachments: %w", err) + } + + recipients, err := safe.GetTypeErr(userKR, func(userKR *crypto.KeyRing) (recipients, error) { + return getRecipients(ctx, client, userKR, settings, draft) + }) if err != nil { return liteapi.Message{}, fmt.Errorf("failed to get recipients: %w", err) } @@ -260,38 +284,6 @@ func sendWithKey( return res, nil } -func sanitizeParsedMessage(message *message.Message, apiAddrs []liteapi.Address, from string, to []string) error { - // Check sender: set the sender in the parsed message if it's missing. - if message.Sender == nil { - message.Sender = &mail.Address{Address: from} - } else if message.Sender.Address == "" { - message.Sender.Address = from - } - - // Check that the sending address is owned by the user, and if so, properly capitalize it. - if idx := xslices.IndexFunc(apiAddrs, func(addr liteapi.Address) bool { - return strings.EqualFold(addr.Email, sanitizeEmail(message.Sender.Address)) - }); idx < 0 { - return fmt.Errorf("address %q is not owned by user", message.Sender.Address) - } else { - message.Sender.Address = constructEmail(message.Sender.Address, apiAddrs[idx].Email) - } - - // Check ToList: ensure that ToList only contains addresses we actually plan to send to. - message.ToList = xslices.Filter(message.ToList, func(addr *mail.Address) bool { - return slices.Contains(to, addr.Address) - }) - - // Check BCCList: any recipients not present in the ToList or CCList are BCC recipients. - for _, recipient := range to { - if !slices.Contains(message.Recipients(), recipient) { - message.BCCList = append(message.BCCList, &mail.Address{Address: recipient}) - } - } - - return nil -} - func getParentID( ctx context.Context, client *liteapi.Client, @@ -362,48 +354,49 @@ func getParentID( return parentID, nil } -func createDraftWithAttachments( +func createDraft( ctx context.Context, client *liteapi.Client, - addrKR *crypto.KeyRing, - message message.Message, + apiAddrs []liteapi.Address, + from string, + to []string, parentID string, -) (liteapi.Message, map[string]*crypto.SessionKey, error) { - encBody, err := addrKR.Encrypt(crypto.NewPlainMessageFromString(string(message.RichBody)), nil) - if err != nil { - return liteapi.Message{}, nil, fmt.Errorf("failed to encrypt message body: %w", err) + template liteapi.DraftTemplate, +) (liteapi.Message, error) { + // Check sender: set the sender if it's missing. + if template.Sender == nil { + template.Sender = &mail.Address{Address: from} + } else if template.Sender.Address == "" { + template.Sender.Address = from } - armBody, err := encBody.GetArmored() - if err != nil { - return liteapi.Message{}, nil, fmt.Errorf("failed to armor message body: %w", err) + // Check that the sending address is owned by the user, and if so, sanitize it. + if idx := xslices.IndexFunc(apiAddrs, func(addr liteapi.Address) bool { + return strings.EqualFold(addr.Email, sanitizeEmail(template.Sender.Address)) + }); idx < 0 { + return liteapi.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address) + } else { + template.Sender.Address = constructEmail(template.Sender.Address, apiAddrs[idx].Email) } - draft, err := client.CreateDraft(ctx, liteapi.CreateDraftReq{ - Message: liteapi.DraftTemplate{ - Subject: message.Subject, - Sender: message.Sender, - ToList: message.ToList, - CCList: message.CCList, - BCCList: message.BCCList, - Body: armBody, - MIMEType: message.MIMEType, + // Check ToList: ensure that ToList only contains addresses we actually plan to send to. + template.ToList = xslices.Filter(template.ToList, func(addr *mail.Address) bool { + return slices.Contains(to, addr.Address) + }) - ExternalID: message.ExternalID, - }, + // Check BCCList: any recipients not present in the ToList or CCList are BCC recipients. + for _, recipient := range to { + if !slices.Contains(xslices.Map(xslices.Join(template.ToList, template.CCList, template.BCCList), func(addr *mail.Address) string { + return addr.Address + }), recipient) { + template.BCCList = append(template.BCCList, &mail.Address{Address: recipient}) + } + } + return client.CreateDraft(ctx, liteapi.CreateDraftReq{ + Message: template, ParentID: parentID, }) - if err != nil { - return liteapi.Message{}, nil, fmt.Errorf("failed to create draft: %w", err) - } - - attKeys, err := createAttachments(ctx, client, addrKR, draft.ID, message.Attachments) - if err != nil { - return liteapi.Message{}, nil, fmt.Errorf("failed to create attachments: %w", err) - } - - return draft, attKeys, nil } func createAttachments( @@ -473,11 +466,24 @@ func getRecipients( client *liteapi.Client, userKR *crypto.KeyRing, settings liteapi.MailSettings, - addresses []string, - mimeType rfc822.MIMEType, + draft liteapi.Message, ) (recipients, error) { - prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, address string) (liteapi.SendPreferences, error) { - return getSendPrefs(ctx, client, userKR, settings, address, mimeType) + addresses := xslices.Map(xslices.Join(draft.ToList, draft.CCList, draft.BCCList), func(addr *mail.Address) string { + return addr.Address + }) + + prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (liteapi.SendPreferences, error) { + pubKeys, recType, err := client.GetPublicKeys(ctx, recipient) + if err != nil { + return liteapi.SendPreferences{}, fmt.Errorf("failed to get public keys: %w", err) + } + + contactSettings, err := getContactSettings(ctx, client, userKR, recipient) + if err != nil { + return liteapi.SendPreferences{}, fmt.Errorf("failed to get contact settings: %w", err) + } + + return buildSendPrefs(contactSettings, settings, pubKeys, draft.MIMEType, recType == liteapi.RecipientTypeInternal) }) if err != nil { return nil, fmt.Errorf("failed to get send preferences: %w", err) @@ -492,27 +498,6 @@ func getRecipients( return recipients, nil } -func getSendPrefs( - ctx context.Context, - client *liteapi.Client, - userKR *crypto.KeyRing, - settings liteapi.MailSettings, - recipient string, - mimeType rfc822.MIMEType, -) (liteapi.SendPreferences, error) { - pubKeys, recType, err := client.GetPublicKeys(ctx, recipient) - if err != nil { - return liteapi.SendPreferences{}, fmt.Errorf("failed to get public keys: %w", err) - } - - contactSettings, err := getContactSettings(ctx, client, userKR, recipient) - if err != nil { - return liteapi.SendPreferences{}, fmt.Errorf("failed to get contact settings: %w", err) - } - - return buildSendPrefs(contactSettings, settings, pubKeys, mimeType, recType == liteapi.RecipientTypeInternal) -} - func getContactSettings( ctx context.Context, client *liteapi.Client, diff --git a/internal/user/sync.go b/internal/user/sync.go index 1a64c9bb..ace40437 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -9,6 +9,8 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/queue" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" @@ -124,7 +126,16 @@ func (user *User) syncMessages(ctx context.Context) error { return metadata.ID })...), func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) { - return buildRFC822(ctx, full, user.addrKRs) + return safe.GetMapErr( + user.addrKRs, + full.AddressID, + func(addrKR *crypto.KeyRing) (*buildRes, error) { + return buildRFC822(ctx, full, addrKR) + }, + func() (*buildRes, error) { + return nil, fmt.Errorf("address keyring not found") + }, + ) }, ), maxBatchSize) defer buildCh.Close() diff --git a/internal/user/sync_build.go b/internal/user/sync_build.go index 73544e55..d16e636f 100644 --- a/internal/user/sync_build.go +++ b/internal/user/sync_build.go @@ -30,8 +30,8 @@ func defaultJobOpts() message.JobOptions { } } -func buildRFC822(ctx context.Context, full liteapi.FullMessage, addrKRs map[string]*crypto.KeyRing) (*buildRes, error) { - literal, err := message.BuildRFC822(addrKRs[full.AddressID], full.Message, full.AttData, defaultJobOpts()) +func buildRFC822(ctx context.Context, full liteapi.FullMessage, addrKR *crypto.KeyRing) (*buildRes, error) { + literal, err := message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()) if err != nil { return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err) } diff --git a/internal/user/user.go b/internal/user/user.go index 113c5a79..90514522 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -35,8 +35,8 @@ type User struct { apiAddrs *safe.Slice[liteapi.Address] settings *safe.Value[liteapi.MailSettings] - userKR *crypto.KeyRing - addrKRs map[string]*crypto.KeyRing + userKR *safe.Value[*crypto.KeyRing] + addrKRs *safe.Map[string, *crypto.KeyRing] updateCh map[string]*queue.QueuedChannel[imap.Update] syncStopCh chan struct{} @@ -94,8 +94,8 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU apiAddrs: safe.NewSlice(apiAddrs), settings: safe.NewValue(settings), - userKR: userKR, - addrKRs: addrKRs, + userKR: safe.NewValue(userKR), + addrKRs: safe.NewMap(addrKRs), updateCh: updateCh, syncStopCh: make(chan struct{}), diff --git a/pkg/message/parser.go b/pkg/message/parser.go index 5422bd6c..ac85d45f 100644 --- a/pkg/message/parser.go +++ b/pkg/message/parser.go @@ -60,18 +60,6 @@ type Message struct { ExternalID string } -func (m *Message) Recipients() []string { - var recipients []string - - for _, addresses := range [][]*mail.Address{m.ToList, m.CCList, m.BCCList} { - recipients = append(recipients, xslices.Map(addresses, func(address *mail.Address) string { - return address.Address - })...) - } - - return recipients -} - type Attachment struct { Header mail.Header Name string diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index dcb0ef16..23fd6dca 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -38,11 +38,13 @@ func (t *testCtx) startBridge() error { return fmt.Errorf("vault is corrupt") } + // Create the underlying cookie jar. jar, err := cookiejar.New(nil) if err != nil { return err } + // Create the persisting cookie jar. persister, err := cookies.NewCookieJar(jar, vault) if err != nil { return err