mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-23 10:26:44 +00:00
feat(GODT-2799): SMTP Service
Refactor code to isolate the SMTP functionality in a dedicated SMTP service for each user as discussed in the Bridge Service Architecture RFC. Some shared types have been moved from `user` to `usertypes` so that they can be shared with Service and User Code. Finally due to lack of recursive imports, the user data SMTP needs access to is hidden behind an interface until the User Identity service is implemented.
This commit is contained in:
@ -35,6 +35,7 @@ import (
|
||||
"github.com/ProtonMail/gopenpgp/v2/constants"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/xmaps"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
@ -62,7 +63,7 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A
|
||||
result := make(map[string]AccountMailboxMap)
|
||||
|
||||
mode := user.GetAddressMode()
|
||||
primaryAddrID, err := getPrimaryAddr(user.apiAddrs)
|
||||
primaryAddrID, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get primary addr for user: %w", err)
|
||||
}
|
||||
@ -178,7 +179,7 @@ func (user *User) DebugDownloadMessages(
|
||||
return fmt.Errorf("failed to create directory '%v':%w", msgDir, err)
|
||||
}
|
||||
|
||||
message, err := user.client.GetFullMessage(ctx, msg.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
message, err := user.client.GetFullMessage(ctx, msg.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download message '%v':%w", msg.ID, err)
|
||||
}
|
||||
@ -187,7 +188,7 @@ func (user *User) DebugDownloadMessages(
|
||||
return err
|
||||
}
|
||||
|
||||
if err := withAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
switch {
|
||||
case len(message.Attachments) > 0:
|
||||
return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData)
|
||||
|
||||
@ -20,8 +20,6 @@ package user
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoSuchAddress = errors.New("no such address")
|
||||
ErrInvalidReturnPath = errors.New("invalid return path")
|
||||
ErrInvalidRecipient = errors.New("invalid recipient")
|
||||
ErrMissingAddrKey = errors.New("missing address key")
|
||||
ErrNoSuchAddress = errors.New("no such address")
|
||||
ErrMissingAddrKey = errors.New("missing address key")
|
||||
)
|
||||
|
||||
@ -34,6 +34,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -114,8 +115,8 @@ func (user *User) syncUserAddressesLabelsAndClearSync(ctx context.Context, cance
|
||||
|
||||
// Update the API info in the user.
|
||||
user.apiUser = apiUser
|
||||
user.apiAddrs = groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
|
||||
user.apiLabels = groupBy(apiLabels, func(label proton.Label) string { return label.ID })
|
||||
user.apiAddrs = usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
|
||||
user.apiLabels = usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID })
|
||||
|
||||
// Clear sync status; we want to sync everything again.
|
||||
if err := user.clearSyncStatus(); err != nil {
|
||||
@ -208,7 +209,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
|
||||
// If the address is enabled, we need to hook it up to the update channels.
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
||||
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
@ -267,7 +268,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
|
||||
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
||||
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
@ -585,7 +586,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
"subject": logging.Sensitive(message.Subject),
|
||||
}).Info("Handling message created event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
@ -599,7 +600,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
var update imap.Update
|
||||
|
||||
if err := withAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||
|
||||
if res.err != nil {
|
||||
@ -652,7 +653,7 @@ func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.Mes
|
||||
|
||||
update := imap.NewMessageMailboxesUpdated(
|
||||
imap.MessageID(message.ID),
|
||||
mapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
|
||||
usertypes.MapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
|
||||
flags,
|
||||
)
|
||||
|
||||
@ -693,7 +694,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot
|
||||
"isDraft": event.Message.IsDraft(),
|
||||
}).Info("Handling draft or sent updated event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
@ -706,7 +707,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot
|
||||
|
||||
var update imap.Update
|
||||
|
||||
if err := withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||
|
||||
if res.err != nil {
|
||||
@ -827,7 +828,7 @@ func safePublishMessageUpdate(user *User, addressID string, update imap.Update)
|
||||
v, ok := user.updateCh[addressID]
|
||||
if !ok {
|
||||
if user.GetAddressMode() == vault.CombinedMode {
|
||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
||||
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
|
||||
@ -33,6 +33,8 @@ import (
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
|
||||
@ -288,26 +290,26 @@ func (conn *imapConnector) CreateMessage(
|
||||
}
|
||||
|
||||
// Compute the hash of the message (to match it against SMTP messages).
|
||||
hash, err := getMessageHash(literal)
|
||||
hash, err := sendrecorder.GetMessageHash(literal)
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, err
|
||||
}
|
||||
|
||||
// Check if we already tried to send this message recently.
|
||||
if messageID, ok, err := conn.sendHash.hasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
|
||||
if messageID, ok, err := conn.sendHash.HasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err)
|
||||
} else if ok {
|
||||
conn.log.WithField("messageID", messageID).Warn("Message already sent")
|
||||
|
||||
// Query the server-side message.
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
// Build the message as it is on the server.
|
||||
if err := safe.RLockRet(func() error {
|
||||
return withAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
var err error
|
||||
|
||||
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
|
||||
@ -378,14 +380,14 @@ func (conn *imapConnector) CreateMessage(
|
||||
}
|
||||
|
||||
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return safe.RLockRetErr(func() ([]byte, error) {
|
||||
var literal []byte
|
||||
err := withAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
err := usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
l, buildErr := message.BuildRFC822(addrKR, msg.Message, msg.AttData, defaultJobOpts())
|
||||
if buildErr != nil {
|
||||
return buildErr
|
||||
@ -408,7 +410,7 @@ func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs
|
||||
return connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(mailboxID))
|
||||
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(mailboxID))
|
||||
}
|
||||
|
||||
// RemoveMessagesFromMailbox unlabels the given messages with the given label ID.
|
||||
@ -419,7 +421,7 @@ func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messag
|
||||
return connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
msgIDs := mapTo[imap.MessageID, string](messageIDs)
|
||||
msgIDs := usertypes.MapTo[imap.MessageID, string](messageIDs)
|
||||
if err := conn.client.UnlabelMessages(ctx, msgIDs, string(mailboxID)); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -461,12 +463,12 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
|
||||
return result
|
||||
}()
|
||||
|
||||
if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
|
||||
if err := conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
|
||||
return false, fmt.Errorf("labeling messages: %w", err)
|
||||
}
|
||||
|
||||
if shouldExpungeOldLocation {
|
||||
if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
|
||||
if err := conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
|
||||
return false, fmt.Errorf("unlabeling messages: %w", err)
|
||||
}
|
||||
}
|
||||
@ -479,10 +481,10 @@ func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []im
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if seen {
|
||||
return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...)
|
||||
return conn.client.MarkMessagesRead(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
|
||||
}
|
||||
|
||||
return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...)
|
||||
return conn.client.MarkMessagesUnread(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
|
||||
}
|
||||
|
||||
// MarkMessagesFlagged sets the flagged value of the given messages.
|
||||
@ -490,10 +492,10 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if flagged {
|
||||
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||
}
|
||||
|
||||
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||
return conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||
}
|
||||
|
||||
// GetUpdates returns a stream of updates that the gluon server should apply.
|
||||
@ -535,7 +537,7 @@ func (conn *imapConnector) importMessage(
|
||||
var full proton.FullMessage
|
||||
|
||||
if err := safe.RLockRet(func() error {
|
||||
return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
var messageID string
|
||||
|
||||
if slices.Contains(labelIDs, proton.DraftsLabel) {
|
||||
@ -571,7 +573,7 @@ func (conn *imapConnector) importMessage(
|
||||
|
||||
var err error
|
||||
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
return fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@ -1,68 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func withAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error {
|
||||
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
addrKR, err := apiAddr.Keys.Unlock(keyPass, userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer addrKR.ClearPrivateParams()
|
||||
|
||||
return fn(userKR, addrKR)
|
||||
}
|
||||
|
||||
func withAddrKRs(apiUser proton.User, apiAddr map[string]proton.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
addrKRs := make(map[string]*crypto.KeyRing, len(apiAddr))
|
||||
|
||||
for addrID, apiAddr := range apiAddr {
|
||||
addrKR, err := apiAddr.Keys.Unlock(keyPass, userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer addrKR.ClearPrivateParams()
|
||||
|
||||
if addrKR.CountDecryptionEntities() == 0 {
|
||||
logrus.WithField("addressID", addrID).Warn("Address keyring has no decryption entities")
|
||||
}
|
||||
|
||||
addrKRs[addrID] = addrKR
|
||||
}
|
||||
|
||||
return fn(userKR, addrKRs)
|
||||
}
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -36,7 +37,7 @@ func BenchmarkAddrKeyRing(b *testing.B) {
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
require.NoError(b, withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
require.NoError(b, usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
@ -1,291 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const sendEntryExpiry = 30 * time.Minute
|
||||
|
||||
type SendRecorderID uint64
|
||||
|
||||
type sendRecorder struct {
|
||||
expiry time.Duration
|
||||
|
||||
entries map[string][]*sendEntry
|
||||
entriesLock sync.Mutex
|
||||
cancelIDCounter uint64
|
||||
}
|
||||
|
||||
func newSendRecorder(expiry time.Duration) *sendRecorder {
|
||||
return &sendRecorder{
|
||||
expiry: expiry,
|
||||
entries: make(map[string][]*sendEntry),
|
||||
}
|
||||
}
|
||||
|
||||
type sendEntry struct {
|
||||
srID SendRecorderID
|
||||
msgID string
|
||||
toList []string
|
||||
exp time.Time
|
||||
waitCh chan struct{}
|
||||
waitChClosed bool
|
||||
}
|
||||
|
||||
func (s *sendEntry) closeWaitChannel() {
|
||||
if !s.waitChClosed {
|
||||
close(s.waitCh)
|
||||
s.waitChClosed = true
|
||||
}
|
||||
}
|
||||
|
||||
// tryInsertWait tries to insert the given message into the send recorder.
|
||||
// If an entry already exists but it was not sent yet, it waits.
|
||||
// It returns whether an entry could be inserted and an error if it times out while waiting.
|
||||
func (h *sendRecorder) tryInsertWait(
|
||||
ctx context.Context,
|
||||
hash string,
|
||||
toList []string,
|
||||
deadline time.Time,
|
||||
) (SendRecorderID, bool, error) {
|
||||
// If we successfully inserted the hash, we can return true.
|
||||
srID, waitCh, ok := h.tryInsert(hash, toList)
|
||||
if ok {
|
||||
return srID, true, nil
|
||||
}
|
||||
|
||||
// A message with this hash is already being sent; wait for it.
|
||||
_, wasSent, err := h.wait(ctx, hash, waitCh, srID, deadline)
|
||||
if err != nil {
|
||||
return 0, false, fmt.Errorf("failed to wait for message to be sent: %w", err)
|
||||
}
|
||||
|
||||
// If the message failed to send, try to insert it again.
|
||||
if !wasSent {
|
||||
return h.tryInsertWait(ctx, hash, toList, deadline)
|
||||
}
|
||||
|
||||
return srID, false, nil
|
||||
}
|
||||
|
||||
// hasEntryWait returns whether the given message already exists in the send recorder.
|
||||
// If it does, it waits for its ID to be known, then returns it and true.
|
||||
// If no entry exists, or it times out while waiting for its ID to be known, it returns false.
|
||||
func (h *sendRecorder) hasEntryWait(ctx context.Context,
|
||||
hash string,
|
||||
deadline time.Time,
|
||||
toList []string,
|
||||
) (string, bool, error) {
|
||||
srID, waitCh, found := h.getEntryWaitInfo(hash, toList)
|
||||
if !found {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
messageID, wasSent, err := h.wait(ctx, hash, waitCh, srID, deadline)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return "", false, nil
|
||||
} else if err != nil {
|
||||
return "", false, fmt.Errorf("failed to wait for message to be sent: %w", err)
|
||||
}
|
||||
|
||||
if wasSent {
|
||||
return messageID, true, nil
|
||||
}
|
||||
|
||||
return h.hasEntryWait(ctx, hash, deadline, toList)
|
||||
}
|
||||
|
||||
func (h *sendRecorder) removeExpiredUnsafe() {
|
||||
for hash, entry := range h.entries {
|
||||
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
|
||||
return !t.exp.Before(time.Now())
|
||||
})
|
||||
|
||||
if len(remaining) == 0 {
|
||||
delete(h.entries, hash)
|
||||
} else {
|
||||
h.entries[hash] = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sendRecorder) tryInsert(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
h.removeExpiredUnsafe()
|
||||
|
||||
entries, ok := h.entries[hash]
|
||||
if ok {
|
||||
for _, entry := range entries {
|
||||
if matchToList(entry.toList, toList) {
|
||||
return entry.srID, entry.waitCh, false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cancelID := h.newSendRecorderID()
|
||||
waitCh := make(chan struct{})
|
||||
|
||||
h.entries[hash] = append(entries, &sendEntry{
|
||||
srID: cancelID,
|
||||
exp: time.Now().Add(h.expiry),
|
||||
toList: toList,
|
||||
waitCh: waitCh,
|
||||
})
|
||||
|
||||
return cancelID, waitCh, true
|
||||
}
|
||||
|
||||
func (h *sendRecorder) getEntryWaitInfo(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
h.removeExpiredUnsafe()
|
||||
|
||||
if entries, ok := h.entries[hash]; ok {
|
||||
for _, e := range entries {
|
||||
if matchToList(e.toList, toList) {
|
||||
return e.srID, e.waitCh, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
// signalMessageSent should be called after a message has been successfully sent.
|
||||
func (h *sendRecorder) signalMessageSent(hash string, srID SendRecorderID, msgID string) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
entries, ok := h.entries[hash]
|
||||
if ok {
|
||||
for _, entry := range entries {
|
||||
if entry.srID == srID {
|
||||
entry.msgID = msgID
|
||||
entry.closeWaitChannel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
||||
}
|
||||
|
||||
func (h *sendRecorder) removeOnFail(hash string, id SendRecorderID) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
entries, ok := h.entries[hash]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for idx, entry := range entries {
|
||||
if entry.srID == id && entry.msgID == "" {
|
||||
entry.closeWaitChannel()
|
||||
|
||||
remaining := xslices.Remove(entries, idx, 1)
|
||||
if len(remaining) != 0 {
|
||||
h.entries[hash] = remaining
|
||||
} else {
|
||||
delete(h.entries, hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sendRecorder) wait(
|
||||
ctx context.Context,
|
||||
hash string,
|
||||
waitCh <-chan struct{},
|
||||
srID SendRecorderID,
|
||||
deadline time.Time,
|
||||
) (string, bool, error) {
|
||||
ctx, cancel := context.WithDeadline(ctx, deadline)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", false, ctx.Err()
|
||||
|
||||
case <-waitCh:
|
||||
// ...
|
||||
}
|
||||
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
if entry, ok := h.entries[hash]; ok {
|
||||
for _, e := range entry {
|
||||
if e.srID == srID {
|
||||
return e.msgID, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (h *sendRecorder) newSendRecorderID() SendRecorderID {
|
||||
h.cancelIDCounter++
|
||||
return SendRecorderID(h.cancelIDCounter)
|
||||
}
|
||||
|
||||
// getMessageHash returns the hash of the given message.
|
||||
// This takes into account:
|
||||
// - the Subject header,
|
||||
// - the From/To/Cc headers,
|
||||
// - the Content-Type header of each (leaf) part,
|
||||
// - the Content-Disposition header of each (leaf) part,
|
||||
// - the (decoded) body of each part.
|
||||
func getMessageHash(b []byte) (string, error) {
|
||||
return rfc822.GetMessageHash(b)
|
||||
}
|
||||
|
||||
func matchToList(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range a {
|
||||
if !slices.Contains(b, a[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for i := range b {
|
||||
if !slices.Contains(a, b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@ -1,471 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSendHasher_Insert(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srdID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash1)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.signalMessageSent(hash1, srdID1, "abc")
|
||||
|
||||
// Inserting a message with the same hash should return false.
|
||||
srdID2, _, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, srdID1, srdID2)
|
||||
|
||||
// Inserting a message with a different hash should return true.
|
||||
srdID3, hash2, ok, err := testTryInsert(h, literal2, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash2)
|
||||
require.NotEqual(t, srdID3, srdID1)
|
||||
}
|
||||
|
||||
func TestSendHasher_Insert_Expired(t *testing.T) {
|
||||
h := newSendRecorder(time.Second)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash1)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.signalMessageSent(hash1, srID1, "abc")
|
||||
|
||||
// Wait for the entry to expire.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Inserting a message with the same hash should return true because the previous entry has since expired.
|
||||
srID2, hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
|
||||
// The hashes should be the same.
|
||||
require.Equal(t, hash1, hash2)
|
||||
|
||||
// Send IDs should differ
|
||||
require.NotEqual(t, srID2, srID1)
|
||||
}
|
||||
|
||||
func TestSendHasher_Insert_DifferentToList(t *testing.T) {
|
||||
h := newSendRecorder(time.Second)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def"}...)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash1)
|
||||
|
||||
// Insert the same message into the hasher but with a different to list.
|
||||
srID2, hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def", "ghi"}...)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash2)
|
||||
require.NotEqual(t, srID1, srID2)
|
||||
|
||||
// Insert the same message into the hasher but with the same to list.
|
||||
_, _, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def", "ghi"}...)
|
||||
require.Error(t, err)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSendHasher_Wait_SendSuccess(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.signalMessageSent(hash, srID1, "abc")
|
||||
}()
|
||||
|
||||
// Inserting a message with the same hash should fail.
|
||||
srID2, _, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, srID1, srID2)
|
||||
}
|
||||
|
||||
func TestSendHasher_Wait_SendFail(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate failing to send the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.removeOnFail(hash, srID1)
|
||||
}()
|
||||
|
||||
// Inserting a message with the same hash should succeed because the first message failed to send.
|
||||
srID2, hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, srID2, srID1)
|
||||
|
||||
// The hashes should be the same.
|
||||
require.Equal(t, hash, hash2)
|
||||
}
|
||||
|
||||
func TestSendHasher_Wait_Timeout(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// We should fail to insert because the message is not sent within the timeout period.
|
||||
_, _, _, err = testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSendHasher_HasEntry(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.signalMessageSent(hash, srID1, "abc")
|
||||
|
||||
// The message was already sent; we should find it in the hasher.
|
||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "abc", messageID)
|
||||
}
|
||||
|
||||
func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.signalMessageSent(hash, srID1, "abc")
|
||||
}()
|
||||
|
||||
// The message was already sent; we should find it in the hasher.
|
||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "abc", messageID)
|
||||
}
|
||||
|
||||
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
||||
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
|
||||
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
|
||||
// inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2,
|
||||
// resulting in a crash.
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
|
||||
// to attempt to send the same message.
|
||||
h.signalMessageSent(hash, srID1, "abc")
|
||||
h.signalMessageSent(hash, srID1, "abc")
|
||||
|
||||
// The message was already sent; we should find it in the hasher.
|
||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "abc", messageID)
|
||||
}
|
||||
|
||||
func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
srID2, hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash2)
|
||||
require.Equal(t, hash, hash2)
|
||||
|
||||
// Should map to different requests
|
||||
require.NotEqual(t, srID2, srID1)
|
||||
}
|
||||
|
||||
func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSend(t *testing.T) {
|
||||
// Check that if we send the same message twice with different recipients and the second message is somehow
|
||||
// sent before the first, ensure that we check if the message was sent we wait on the correct object.
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Minute), "Receiver <receiver@pm.me>")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
srID2, hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Microsecond), "Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash2)
|
||||
require.Equal(t, hash, hash2)
|
||||
|
||||
// simulate message sent
|
||||
h.signalMessageSent(hash2, srID2, "newID")
|
||||
|
||||
// Simulate Wait on message 2
|
||||
_, ok, err = h.hasEntryWait(context.Background(), hash2, time.Now().Add(time.Second), []string{"Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>"})
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate failing to send the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.removeOnFail(hash, srID1)
|
||||
}()
|
||||
|
||||
// The message failed to send; we should not find it in the hasher.
|
||||
_, ok, err = testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSendHasher_HasEntry_Timeout(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// The message is never sent; we should not find it in the hasher.
|
||||
_, ok, err = testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSendHasher_HasEntry_Expired(t *testing.T) {
|
||||
h := newSendRecorder(time.Second)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.signalMessageSent(hash, srID1, "abc")
|
||||
|
||||
// Wait for the entry to expire.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// The entry has expired; we should not find it in the hasher.
|
||||
_, ok, err = testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
const literal1 = `From: Sender <sender@pm.me>
|
||||
To: Receiver <receiver@pm.me>
|
||||
Content-Type: multipart/mixed; boundary=longrandomstring
|
||||
|
||||
--longrandomstring
|
||||
|
||||
body
|
||||
--longrandomstring
|
||||
Content-Disposition: attachment; filename="attname.txt"
|
||||
|
||||
attachment
|
||||
--longrandomstring--
|
||||
`
|
||||
const literal2 = `From: Sender <sender@pm.me>
|
||||
To: Receiver <receiver@pm.me>
|
||||
Content-Type: multipart/mixed; boundary=longrandomstring
|
||||
|
||||
--longrandomstring
|
||||
|
||||
body
|
||||
--longrandomstring
|
||||
Content-Disposition: attachment; filename="attname2.txt"
|
||||
|
||||
attachment
|
||||
--longrandomstring--
|
||||
`
|
||||
|
||||
func TestGetMessageHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lit1, lit2 []byte
|
||||
wantEqual bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
lit1: []byte{},
|
||||
lit2: []byte{},
|
||||
wantEqual: true,
|
||||
},
|
||||
{
|
||||
name: "same to",
|
||||
lit1: []byte("To: someone@pm.me\r\n\r\nHello world!"),
|
||||
lit2: []byte("To: someone@pm.me\r\n\r\nHello world!"),
|
||||
wantEqual: true,
|
||||
},
|
||||
{
|
||||
name: "different to",
|
||||
lit1: []byte("To: someone@pm.me\r\n\r\nHello world!"),
|
||||
lit2: []byte("To: another@pm.me\r\n\r\nHello world!"),
|
||||
wantEqual: false,
|
||||
},
|
||||
{
|
||||
name: "same from",
|
||||
lit1: []byte("From: someone@pm.me\r\n\r\nHello world!"),
|
||||
lit2: []byte("From: someone@pm.me\r\n\r\nHello world!"),
|
||||
wantEqual: true,
|
||||
},
|
||||
{
|
||||
name: "different from",
|
||||
lit1: []byte("From: someone@pm.me\r\n\r\nHello world!"),
|
||||
lit2: []byte("From: another@pm.me\r\n\r\nHello world!"),
|
||||
wantEqual: false,
|
||||
},
|
||||
{
|
||||
name: "same subject",
|
||||
lit1: []byte("Subject: Hello world!\r\n\r\nHello world!"),
|
||||
lit2: []byte("Subject: Hello world!\r\n\r\nHello world!"),
|
||||
wantEqual: true,
|
||||
},
|
||||
{
|
||||
name: "different subject",
|
||||
lit1: []byte("Subject: Hello world!\r\n\r\nHello world!"),
|
||||
lit2: []byte("Subject: Goodbye world!\r\n\r\nHello world!"),
|
||||
wantEqual: false,
|
||||
},
|
||||
{
|
||||
name: "same plaintext body",
|
||||
lit1: []byte("To: someone@pm.me\r\nContent-Type: text/plain\r\n\r\nHello world!"),
|
||||
lit2: []byte("To: someone@pm.me\r\nContent-Type: text/plain\r\n\r\nHello world!"),
|
||||
wantEqual: true,
|
||||
},
|
||||
{
|
||||
name: "different plaintext body",
|
||||
lit1: []byte("To: someone@pm.me\r\nContent-Type: text/plain\r\n\r\nHello world!"),
|
||||
lit2: []byte("To: someone@pm.me\r\nContent-Type: text/plain\r\n\r\nGoodbye world!"),
|
||||
wantEqual: false,
|
||||
},
|
||||
{
|
||||
name: "different attachment filenames",
|
||||
lit1: []byte(literal1),
|
||||
lit2: []byte(literal2),
|
||||
wantEqual: false,
|
||||
},
|
||||
{
|
||||
name: "different date and message ID should still match",
|
||||
lit1: []byte("To: a@b.c\r\nDate: Fri, 13 Aug 1982\r\nMessage-Id: 1@b.c\r\n\r\nHello"),
|
||||
lit2: []byte("To: a@b.c\r\nDate: Sat, 14 Aug 1982\r\nMessage-Id: 2@b.c\r\n\r\nHello"),
|
||||
wantEqual: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash1, err := getMessageHash(tt.lit1)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, err := getMessageHash(tt.lit2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantEqual {
|
||||
require.Equal(t, hash1, hash2)
|
||||
} else {
|
||||
require.NotEqual(t, hash1, hash2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList ...string) (SendRecorderID, string, bool, error) { //nolint:unparam
|
||||
hash, err := getMessageHash([]byte(literal))
|
||||
if err != nil {
|
||||
return 0, "", false, err
|
||||
}
|
||||
|
||||
srID, ok, err := h.tryInsertWait(context.Background(), hash, toList, deadline)
|
||||
if err != nil {
|
||||
return 0, "", false, err
|
||||
}
|
||||
|
||||
return srID, hash, ok, nil
|
||||
}
|
||||
|
||||
func testHasEntry(h *sendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam
|
||||
hash, err := getMessageHash([]byte(literal))
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
return h.hasEntryWait(context.Background(), hash, deadline, toList)
|
||||
}
|
||||
@ -1,586 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/mail"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/gluon/rfc5322"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// sendMail sends an email from the given address to the given recipients.
|
||||
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
|
||||
defer async.HandlePanic(user.panicHandler)
|
||||
|
||||
return safe.RLockRet(func() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if _, err := getAddrID(user.apiAddrs, from); err != nil {
|
||||
return ErrInvalidReturnPath
|
||||
}
|
||||
|
||||
emails := xslices.Map(maps.Values(user.apiAddrs), func(addr proton.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
|
||||
// Read the message to send.
|
||||
b, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read message: %w", err)
|
||||
}
|
||||
|
||||
// If running a QA build, dump to disk.
|
||||
if err := debugDumpToDisk(b); err != nil {
|
||||
user.log.WithError(err).Warn("Failed to dump message to disk")
|
||||
}
|
||||
|
||||
// Compute the hash of the message (to match it against SMTP messages).
|
||||
hash, err := getMessageHash(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we already tried to send this message recently.
|
||||
srID, ok, err := user.sendHash.tryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check send hash: %w", err)
|
||||
} else if !ok {
|
||||
user.log.Warn("A duplicate message was already sent recently, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we fail to send this message, we should remove the hash from the send recorder.
|
||||
defer user.sendHash.removeOnFail(hash, srID)
|
||||
|
||||
// Create a new message parser from the reader.
|
||||
parser, err := parser.New(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create parser: %w", err)
|
||||
}
|
||||
|
||||
// If the message contains a sender, use it instead of the one from the return path.
|
||||
if sender, ok := getMessageSender(parser); ok {
|
||||
from = sender
|
||||
}
|
||||
|
||||
// Load the user's mail settings.
|
||||
settings, err := user.client.GetMailSettings(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get mail settings: %w", err)
|
||||
}
|
||||
|
||||
addrID, err := getAddrID(user.apiAddrs, from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
|
||||
// Use the first key for encrypting the message.
|
||||
addrKR, err := addrKR.FirstKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get first key: %w", err)
|
||||
}
|
||||
|
||||
// Ensure that there is always a text/html or text/plain body part. This is required by the API. If none
|
||||
// exists and empty text part will be added.
|
||||
parser.AttachEmptyTextPartIfNoneExists()
|
||||
|
||||
// If we have to attach the public key, do it now.
|
||||
if settings.AttachPublicKey {
|
||||
key, err := addrKR.GetKey(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get sending key: %w", err)
|
||||
}
|
||||
|
||||
pubKey, err := key.GetArmoredPublicKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
||||
}
|
||||
|
||||
// Parse the message we want to send (after we have attached the public key).
|
||||
message, err := message.ParseWithParser(parser, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
|
||||
// Send the message using the correct key.
|
||||
sent, err := user.sendWithKey(
|
||||
ctx,
|
||||
user.client,
|
||||
user.reporter,
|
||||
authID,
|
||||
user.vault.AddressMode(),
|
||||
settings,
|
||||
userKR, addrKR,
|
||||
emails, from, to,
|
||||
message,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
// If the message was successfully sent, we can update the message ID in the record.
|
||||
user.sendHash.signalMessageSent(hash, srID, sent.ID)
|
||||
|
||||
return nil
|
||||
})
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.eventLock)
|
||||
}
|
||||
|
||||
// sendWithKey sends the message with the given address key.
|
||||
func (user *User) sendWithKey(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
sentry reporter.Reporter,
|
||||
authAddrID string,
|
||||
addrMode vault.AddressMode,
|
||||
settings proton.MailSettings,
|
||||
userKR, addrKR *crypto.KeyRing,
|
||||
emails []string,
|
||||
from string,
|
||||
to []string,
|
||||
message message.Message,
|
||||
) (proton.Message, error) {
|
||||
references := message.References
|
||||
if message.InReplyTo != "" {
|
||||
references = append(references, message.InReplyTo)
|
||||
}
|
||||
parentID, err := getParentID(ctx, client, authAddrID, addrMode, references)
|
||||
if err != nil {
|
||||
if err := sentry.ReportMessageWithContext("Failed to get parent ID", reporter.Context{
|
||||
"error": err,
|
||||
"references": message.References,
|
||||
}); err != nil {
|
||||
logrus.WithError(err).Error("Failed to report error")
|
||||
}
|
||||
|
||||
logrus.WithError(err).Warn("Failed to get parent ID")
|
||||
}
|
||||
|
||||
var decBody string
|
||||
|
||||
// nolint:exhaustive
|
||||
switch message.MIMEType {
|
||||
case rfc822.TextHTML:
|
||||
decBody = string(message.RichBody)
|
||||
|
||||
case rfc822.TextPlain:
|
||||
decBody = string(message.PlainBody)
|
||||
|
||||
default:
|
||||
return proton.Message{}, fmt.Errorf("unsupported MIME type: %v", message.MIMEType)
|
||||
}
|
||||
|
||||
draft, err := createDraft(ctx, client, addrKR, emails, from, to, parentID, message.InReplyTo, proton.DraftTemplate{
|
||||
Subject: message.Subject,
|
||||
Body: decBody,
|
||||
MIMEType: message.MIMEType,
|
||||
|
||||
Sender: message.Sender,
|
||||
ToList: message.ToList,
|
||||
CCList: message.CCList,
|
||||
BCCList: message.BCCList,
|
||||
|
||||
ExternalID: message.ExternalID,
|
||||
})
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
recipients, err := user.getRecipients(ctx, client, userKR, settings, draft)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
||||
}
|
||||
|
||||
req, err := createSendReq(addrKR, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to create packages: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.SendDraft(ctx, draft.ID, req)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to send draft: %w", err)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func getParentID(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
authAddrID string,
|
||||
addrMode vault.AddressMode,
|
||||
references []string,
|
||||
) (string, error) {
|
||||
var (
|
||||
parentID string
|
||||
internal []string
|
||||
external []string
|
||||
)
|
||||
|
||||
// Collect all the internal and external references of the message.
|
||||
for _, ref := range references {
|
||||
if strings.Contains(ref, message.InternalIDDomain) {
|
||||
internal = append(internal, strings.TrimSuffix(ref, "@"+message.InternalIDDomain))
|
||||
} else {
|
||||
external = append(external, ref)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find a parent ID in the internal references.
|
||||
for _, internal := range internal {
|
||||
var addrID string
|
||||
|
||||
if addrMode == vault.SplitMode {
|
||||
addrID = authAddrID
|
||||
}
|
||||
|
||||
metadata, err := client.GetMessageMetadata(ctx, proton.MessageFilter{
|
||||
ID: []string{internal},
|
||||
AddressID: addrID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get message metadata: %w", err)
|
||||
}
|
||||
|
||||
for _, metadata := range metadata {
|
||||
if !metadata.IsDraft() {
|
||||
parentID = metadata.ID
|
||||
} else if err := client.DeleteMessage(ctx, metadata.ID); err != nil {
|
||||
return "", fmt.Errorf("failed to delete message: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no parent was found, try to find it in the last external reference.
|
||||
// There can be multiple messages with the same external ID; in this case, we first look if
|
||||
// there is a single one sent by this account (with the `MessageFlagSent` flag set), if yes,
|
||||
// then pick that, otherwise don't pick any parent.
|
||||
if parentID == "" && len(external) > 0 {
|
||||
var addrID string
|
||||
|
||||
if addrMode == vault.SplitMode {
|
||||
addrID = authAddrID
|
||||
}
|
||||
|
||||
metadata, err := client.GetMessageMetadata(ctx, proton.MessageFilter{
|
||||
ExternalID: external[len(external)-1],
|
||||
AddressID: addrID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get message metadata: %w", err)
|
||||
}
|
||||
|
||||
switch len(metadata) {
|
||||
case 1:
|
||||
// found exactly one parent
|
||||
parentID = metadata[0].ID
|
||||
case 0:
|
||||
// found no parents
|
||||
default:
|
||||
// found multiple parents, search through metadata to try to find a singular parent that
|
||||
// was sent by this account.
|
||||
for _, metadata := range metadata {
|
||||
if metadata.Flags.Has(proton.MessageFlagSent) {
|
||||
parentID = metadata.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parentID, nil
|
||||
}
|
||||
|
||||
func createDraft(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
addrKR *crypto.KeyRing,
|
||||
emails []string,
|
||||
from string,
|
||||
to []string,
|
||||
parentID string,
|
||||
replyToID string,
|
||||
template proton.DraftTemplate,
|
||||
) (proton.Message, error) {
|
||||
// Check sender: set the sender if it's missing.
|
||||
if template.Sender == nil {
|
||||
template.Sender = &mail.Address{Address: from}
|
||||
} else if template.Sender.Address == "" {
|
||||
template.Sender.Address = from
|
||||
}
|
||||
|
||||
// Check that the sending address is owned by the user, and if so, sanitize it.
|
||||
if idx := xslices.IndexFunc(emails, func(email string) bool {
|
||||
return strings.EqualFold(email, sanitizeEmail(template.Sender.Address))
|
||||
}); idx < 0 {
|
||||
return proton.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address)
|
||||
} else { //nolint:revive
|
||||
template.Sender.Address = constructEmail(template.Sender.Address, emails[idx])
|
||||
}
|
||||
|
||||
// Check ToList: ensure that ToList only contains addresses we actually plan to send to.
|
||||
template.ToList = xslices.Filter(template.ToList, func(addr *mail.Address) bool {
|
||||
return slices.Contains(to, addr.Address)
|
||||
})
|
||||
|
||||
// Check BCCList: any recipients not present in the ToList or CCList are BCC recipients.
|
||||
for _, recipient := range to {
|
||||
if !slices.Contains(xslices.Map(xslices.Join(template.ToList, template.CCList, template.BCCList), func(addr *mail.Address) string {
|
||||
return addr.Address
|
||||
}), recipient) {
|
||||
template.BCCList = append(template.BCCList, &mail.Address{Address: recipient})
|
||||
}
|
||||
}
|
||||
|
||||
var action proton.CreateDraftAction
|
||||
|
||||
if len(replyToID) > 0 {
|
||||
action = proton.ReplyAction
|
||||
} else {
|
||||
action = proton.ForwardAction
|
||||
}
|
||||
|
||||
return client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{
|
||||
Message: template,
|
||||
ParentID: parentID,
|
||||
Action: action,
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) createAttachments(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
addrKR *crypto.KeyRing,
|
||||
draftID string,
|
||||
attachments []message.Attachment,
|
||||
) (map[string]*crypto.SessionKey, error) {
|
||||
type attKey struct {
|
||||
attID string
|
||||
key *crypto.SessionKey
|
||||
}
|
||||
|
||||
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
|
||||
defer async.HandlePanic(user.panicHandler)
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"name": logging.Sensitive(att.Name),
|
||||
"contentID": att.ContentID,
|
||||
"disposition": att.Disposition,
|
||||
"mime-type": att.MIMEType,
|
||||
}).Debug("Uploading attachment")
|
||||
|
||||
switch att.Disposition {
|
||||
case proton.InlineDisposition:
|
||||
// Some clients use inline disposition but don't set a content ID. Our API doesn't support this.
|
||||
// We could generate our own content ID, but for simplicity, we just set the disposition to attachment.
|
||||
if att.ContentID == "" {
|
||||
att.Disposition = proton.AttachmentDisposition
|
||||
}
|
||||
|
||||
case proton.AttachmentDisposition:
|
||||
// Nothing to do.
|
||||
|
||||
default:
|
||||
// Some clients leave the content disposition empty or use unsupported values.
|
||||
// We default to inline disposition if a content ID is set, and to attachment disposition otherwise.
|
||||
if att.ContentID != "" {
|
||||
att.Disposition = proton.InlineDisposition
|
||||
} else {
|
||||
att.Disposition = proton.AttachmentDisposition
|
||||
}
|
||||
}
|
||||
|
||||
// Exclude name from params since this is already provided using Filename.
|
||||
delete(att.MIMEParams, "name")
|
||||
delete(att.MIMEParams, "filename")
|
||||
|
||||
attachment, err := client.UploadAttachment(ctx, addrKR, proton.CreateAttachmentReq{
|
||||
Filename: att.Name,
|
||||
MessageID: draftID,
|
||||
MIMEType: rfc822.MIMEType(mime.FormatMediaType(att.MIMEType, att.MIMEParams)),
|
||||
Disposition: att.Disposition,
|
||||
ContentID: att.ContentID,
|
||||
Body: att.Data,
|
||||
})
|
||||
if err != nil {
|
||||
return attKey{}, fmt.Errorf("failed to upload attachment: %w", err)
|
||||
}
|
||||
|
||||
keyPacket, err := base64.StdEncoding.DecodeString(attachment.KeyPackets)
|
||||
if err != nil {
|
||||
return attKey{}, fmt.Errorf("failed to decode key packets: %w", err)
|
||||
}
|
||||
|
||||
key, err := addrKR.DecryptSessionKey(keyPacket)
|
||||
if err != nil {
|
||||
return attKey{}, fmt.Errorf("failed to decrypt session key: %w", err)
|
||||
}
|
||||
|
||||
return attKey{attID: attachment.ID, key: key}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
attKeys := make(map[string]*crypto.SessionKey)
|
||||
|
||||
for _, key := range keys {
|
||||
attKeys[key.attID] = key.key
|
||||
}
|
||||
|
||||
return attKeys, nil
|
||||
}
|
||||
|
||||
func (user *User) getRecipients(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
userKR *crypto.KeyRing,
|
||||
settings proton.MailSettings,
|
||||
draft proton.Message,
|
||||
) (recipients, error) {
|
||||
addresses := xslices.Map(xslices.Join(draft.ToList, draft.CCList, draft.BCCList), func(addr *mail.Address) string {
|
||||
return addr.Address
|
||||
})
|
||||
|
||||
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
|
||||
defer async.HandlePanic(user.panicHandler)
|
||||
|
||||
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
|
||||
if err != nil {
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err)
|
||||
}
|
||||
|
||||
contactSettings, err := getContactSettings(ctx, client, userKR, recipient)
|
||||
if err != nil {
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get contact settings for %v: %w", recipient, err)
|
||||
}
|
||||
|
||||
return buildSendPrefs(contactSettings, settings, pubKeys, draft.MIMEType, recType == proton.RecipientTypeInternal)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get send preferences: %w", err)
|
||||
}
|
||||
|
||||
recipients := make(recipients)
|
||||
|
||||
for idx, pref := range prefs {
|
||||
recipients[addresses[idx]] = pref
|
||||
}
|
||||
|
||||
return recipients, nil
|
||||
}
|
||||
|
||||
func getContactSettings(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
userKR *crypto.KeyRing,
|
||||
recipient string,
|
||||
) (proton.ContactSettings, error) {
|
||||
contacts, err := client.GetAllContactEmails(ctx, recipient)
|
||||
if err != nil {
|
||||
return proton.ContactSettings{}, fmt.Errorf("failed to get contact data: %w", err)
|
||||
}
|
||||
|
||||
idx := xslices.IndexFunc(contacts, func(contact proton.ContactEmail) bool {
|
||||
return contact.Email == recipient
|
||||
})
|
||||
|
||||
if idx < 0 {
|
||||
return proton.ContactSettings{}, nil
|
||||
}
|
||||
|
||||
contact, err := client.GetContact(ctx, contacts[idx].ContactID)
|
||||
if err != nil {
|
||||
return proton.ContactSettings{}, fmt.Errorf("failed to get contact: %w", err)
|
||||
}
|
||||
|
||||
return contact.GetSettings(userKR, recipient)
|
||||
}
|
||||
|
||||
func getMessageSender(parser *parser.Parser) (string, bool) {
|
||||
address, err := rfc5322.ParseAddressList(parser.Root().Header.Get("From"))
|
||||
if err != nil {
|
||||
return "", false
|
||||
} else if len(address) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return address[0].Address, true
|
||||
}
|
||||
|
||||
func sanitizeEmail(email string) string {
|
||||
splitAt := strings.Split(email, "@")
|
||||
if len(splitAt) != 2 {
|
||||
return email
|
||||
}
|
||||
|
||||
return strings.Split(splitAt[0], "+")[0] + "@" + splitAt[1]
|
||||
}
|
||||
|
||||
func constructEmail(headerEmail string, addressEmail string) string {
|
||||
splitAtHeader := strings.Split(headerEmail, "@")
|
||||
if len(splitAtHeader) != 2 {
|
||||
return addressEmail
|
||||
}
|
||||
|
||||
splitPlus := strings.Split(splitAtHeader[0], "+")
|
||||
if len(splitPlus) != 2 {
|
||||
return addressEmail
|
||||
}
|
||||
|
||||
splitAtAddress := strings.Split(addressEmail, "@")
|
||||
if len(splitAtAddress) != 2 {
|
||||
return addressEmail
|
||||
}
|
||||
|
||||
return splitAtAddress[0] + "+" + splitPlus[1] + "@" + splitAtAddress[1]
|
||||
}
|
||||
@ -1,48 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build build_qa
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func debugDumpToDisk(b []byte) error {
|
||||
if os.Getenv("BRIDGE_SMTP_DEBUG") == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user home dir: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(home, getFileName()), b, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write message file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getFileName() string {
|
||||
return fmt.Sprintf("smtp_debug_%v.eml", time.Now().Unix())
|
||||
}
|
||||
@ -1,24 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build !build_qa
|
||||
|
||||
package user
|
||||
|
||||
func debugDumpToDisk(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
@ -1,86 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func createSendReq(
|
||||
kr *crypto.KeyRing,
|
||||
mimeBody message.MIMEBody,
|
||||
richBody, plainBody message.Body,
|
||||
recipients recipients,
|
||||
attKeys map[string]*crypto.SessionKey,
|
||||
) (proton.SendDraftReq, error) {
|
||||
var req proton.SendDraftReq
|
||||
|
||||
if recs := recipients.scheme(proton.PGPMIMEScheme, proton.ClearMIMEScheme); len(recs) > 0 {
|
||||
if err := req.AddMIMEPackage(kr, string(mimeBody), recs); err != nil {
|
||||
return proton.SendDraftReq{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if recs := recipients.scheme(proton.InternalScheme, proton.ClearScheme, proton.PGPInlineScheme); len(recs) > 0 {
|
||||
if recs := recs.content(rfc822.TextHTML); len(recs) > 0 {
|
||||
if err := req.AddTextPackage(kr, string(richBody), rfc822.TextHTML, recs, attKeys); err != nil {
|
||||
return proton.SendDraftReq{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if recs := recs.content(rfc822.TextPlain); len(recs) > 0 {
|
||||
if err := req.AddTextPackage(kr, string(plainBody), rfc822.TextPlain, recs, attKeys); err != nil {
|
||||
return proton.SendDraftReq{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
type recipients map[string]proton.SendPreferences
|
||||
|
||||
func (r recipients) scheme(scheme ...proton.EncryptionScheme) recipients {
|
||||
res := make(recipients)
|
||||
|
||||
for _, addr := range xslices.Filter(maps.Keys(r), func(addr string) bool {
|
||||
return slices.Contains(scheme, r[addr].EncryptionScheme)
|
||||
}) {
|
||||
res[addr] = r[addr]
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (r recipients) content(mimeType ...rfc822.MIMEType) recipients {
|
||||
res := make(recipients)
|
||||
|
||||
for _, addr := range xslices.Filter(maps.Keys(r), func(addr string) bool {
|
||||
return slices.Contains(mimeType, r[addr].MIMEType)
|
||||
}) {
|
||||
res[addr] = r[addr]
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
@ -1,586 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
pgpInline = "pgp-inline"
|
||||
pgpMIME = "pgp-mime"
|
||||
pmInternal = "internal" // A mix between pgpInline and pgpMime used by PM.
|
||||
)
|
||||
|
||||
type contactSettings struct {
|
||||
Email string
|
||||
Keys []string
|
||||
Scheme string
|
||||
Sign bool
|
||||
SignIsSet bool
|
||||
Encrypt bool
|
||||
MIMEType rfc822.MIMEType
|
||||
}
|
||||
|
||||
// newContactSettings converts the API settings into our local settings.
|
||||
// This is due to the legacy send preferences code.
|
||||
func newContactSettings(settings proton.ContactSettings) *contactSettings {
|
||||
metadata := &contactSettings{}
|
||||
|
||||
if settings.MIMEType != nil {
|
||||
metadata.MIMEType = *settings.MIMEType
|
||||
}
|
||||
|
||||
if settings.Sign != nil {
|
||||
metadata.Sign = *settings.Sign
|
||||
metadata.SignIsSet = true
|
||||
}
|
||||
|
||||
if settings.Encrypt != nil {
|
||||
metadata.Encrypt = *settings.Encrypt
|
||||
}
|
||||
|
||||
if settings.Scheme != nil {
|
||||
switch *settings.Scheme { // nolint:exhaustive
|
||||
case proton.PGPMIMEScheme:
|
||||
metadata.Scheme = pgpMIME
|
||||
|
||||
case proton.PGPInlineScheme:
|
||||
metadata.Scheme = pgpInline
|
||||
|
||||
default:
|
||||
panic("unknown scheme")
|
||||
}
|
||||
}
|
||||
|
||||
if settings.Keys != nil {
|
||||
for _, key := range settings.Keys {
|
||||
b, err := key.Serialize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
metadata.Keys = append(metadata.Keys, string(b))
|
||||
}
|
||||
}
|
||||
|
||||
return metadata
|
||||
}
|
||||
|
||||
func buildSendPrefs(
|
||||
contactSettings proton.ContactSettings,
|
||||
mailSettings proton.MailSettings,
|
||||
pubKeys []proton.PublicKey,
|
||||
mimeType rfc822.MIMEType,
|
||||
isInternal bool,
|
||||
) (proton.SendPreferences, error) {
|
||||
builder := &sendPrefsBuilder{}
|
||||
|
||||
if err := builder.setPGPSettings(newContactSettings(contactSettings), pubKeys, isInternal); err != nil {
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to set PGP settings: %w", err)
|
||||
}
|
||||
|
||||
builder.setEncryptionPreferences(mailSettings)
|
||||
|
||||
builder.setMIMEPreferences(string(mimeType))
|
||||
|
||||
return builder.build(), nil
|
||||
}
|
||||
|
||||
type sendPrefsBuilder struct {
|
||||
internal bool
|
||||
encrypt *bool
|
||||
sign *bool
|
||||
scheme *string
|
||||
mimeType *rfc822.MIMEType
|
||||
publicKey *crypto.KeyRing
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withInternal() {
|
||||
b.internal = true
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) isInternal() bool {
|
||||
return b.internal
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withEncrypt(v bool) {
|
||||
b.encrypt = &v
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withEncryptDefault(v bool) {
|
||||
if b.encrypt == nil {
|
||||
b.encrypt = &v
|
||||
}
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) shouldEncrypt() bool {
|
||||
if b.encrypt != nil {
|
||||
return *b.encrypt
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withSign(sign bool) {
|
||||
b.sign = &sign
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withSignDefault() {
|
||||
v := true
|
||||
if b.sign == nil {
|
||||
b.sign = &v
|
||||
}
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) shouldSign() bool {
|
||||
if b.sign != nil {
|
||||
return *b.sign
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withScheme(v string) {
|
||||
b.scheme = &v
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withSchemeDefault(v string) {
|
||||
if b.scheme == nil {
|
||||
b.scheme = &v
|
||||
}
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) getScheme() string {
|
||||
if b.scheme != nil {
|
||||
return *b.scheme
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withMIMEType(v rfc822.MIMEType) {
|
||||
b.mimeType = &v
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withMIMETypeDefault(v rfc822.MIMEType) {
|
||||
if b.mimeType == nil {
|
||||
b.mimeType = &v
|
||||
}
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) removeMIMEType() {
|
||||
b.mimeType = nil
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) getMIMEType() rfc822.MIMEType {
|
||||
if b.mimeType != nil {
|
||||
return *b.mimeType
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) withPublicKey(v *crypto.KeyRing) {
|
||||
b.publicKey = v
|
||||
}
|
||||
|
||||
// Build converts the PGP scheme with a string value into a number value, and
|
||||
// we may override some of the other encryption preferences with the composer
|
||||
// preferences. Notice that the composer allows to select a sign preference,
|
||||
// an email format preference and an encrypt-to-outside preference. The
|
||||
// object we extract has the following possible value types:
|
||||
//
|
||||
// {
|
||||
// encrypt: true | false,
|
||||
// sign: true | false,
|
||||
// pgpScheme: 1 (protonmail custom scheme)
|
||||
// | 2 (Protonmail scheme for encrypted-to-outside email)
|
||||
// | 4 (no cryptographic scheme)
|
||||
// | 8 (PGP/INLINE)
|
||||
// | 16 (PGP/MIME),
|
||||
// mimeType: 'text/html' | 'text/plain' | 'multipart/mixed',
|
||||
// publicKey: OpenPGPKey | undefined/null
|
||||
// }.
|
||||
func (b *sendPrefsBuilder) build() (p proton.SendPreferences) {
|
||||
p.Encrypt = b.shouldEncrypt()
|
||||
p.MIMEType = b.getMIMEType()
|
||||
p.PubKey = b.publicKey
|
||||
|
||||
if b.shouldSign() {
|
||||
p.SignatureType = proton.DetachedSignature
|
||||
} else {
|
||||
p.SignatureType = proton.NoSignature
|
||||
}
|
||||
|
||||
switch {
|
||||
case b.isInternal():
|
||||
p.EncryptionScheme = proton.InternalScheme
|
||||
|
||||
case b.shouldSign() && b.shouldEncrypt():
|
||||
if b.getScheme() == pgpInline {
|
||||
p.EncryptionScheme = proton.PGPInlineScheme
|
||||
} else {
|
||||
p.EncryptionScheme = proton.PGPMIMEScheme
|
||||
}
|
||||
|
||||
case b.shouldSign() && !b.shouldEncrypt():
|
||||
if b.getScheme() == pgpInline {
|
||||
p.EncryptionScheme = proton.ClearScheme
|
||||
} else {
|
||||
p.EncryptionScheme = proton.ClearMIMEScheme
|
||||
}
|
||||
|
||||
default:
|
||||
p.EncryptionScheme = proton.ClearScheme
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// setPGPSettings returns a SendPreferences with the following possible values:
|
||||
//
|
||||
// {
|
||||
// encrypt: true | false | undefined/null/'',
|
||||
// sign: true | false | undefined/null/'',
|
||||
// pgpScheme: 'pgp-mime' | 'pgp-inline' | undefined/null/'',
|
||||
// mimeType: 'text/html' | 'text/plain' | undefined/null/'',
|
||||
// publicKey: OpenPGPKey | undefined/null
|
||||
// }
|
||||
//
|
||||
// These settings are simply a reflection of the vCard content plus the public
|
||||
// key info retrieved from the API via the GET KEYS route.
|
||||
func (b *sendPrefsBuilder) setPGPSettings(
|
||||
vCardData *contactSettings,
|
||||
apiKeys []proton.PublicKey,
|
||||
isInternal bool,
|
||||
) (err error) {
|
||||
// If there is no contact metadata, we can just use a default constructed one.
|
||||
if vCardData == nil {
|
||||
vCardData = &contactSettings{}
|
||||
}
|
||||
|
||||
// Sending internal.
|
||||
// We are guaranteed to always receive API keys.
|
||||
if isInternal {
|
||||
b.withInternal()
|
||||
return b.setInternalPGPSettings(vCardData, apiKeys)
|
||||
}
|
||||
|
||||
// Sending external but with keys supplied by WKD.
|
||||
// Treated pretty much same as internal.
|
||||
if len(apiKeys) > 0 {
|
||||
return b.setExternalPGPSettingsWithWKDKeys(vCardData, apiKeys)
|
||||
}
|
||||
|
||||
// Sending external without any WKD keys.
|
||||
// If we have a contact saved, we can use its settings.
|
||||
return b.setExternalPGPSettingsWithoutWKDKeys(vCardData)
|
||||
}
|
||||
|
||||
// setInternalPGPSettings returns SendPreferences for internal messages.
|
||||
// An internal address can be either an obvious one: abc@protonmail.com,
|
||||
// abc@protonmail.ch or abc@pm.me, or one belonging to a custom domain
|
||||
// registered with proton.
|
||||
func (b *sendPrefsBuilder) setInternalPGPSettings(
|
||||
vCardData *contactSettings,
|
||||
apiKeys []proton.PublicKey,
|
||||
) error {
|
||||
// We're guaranteed to get at least one valid (i.e. not expired, revoked or
|
||||
// marked as verification-only) public key from the server.
|
||||
if len(apiKeys) == 0 {
|
||||
return errors.New("an API key is necessary but wasn't provided")
|
||||
}
|
||||
|
||||
// We always encrypt and sign internal mail.
|
||||
b.withEncrypt(true)
|
||||
b.withSign(true)
|
||||
|
||||
// We use a custom scheme for internal messages.
|
||||
b.withScheme(pmInternal)
|
||||
|
||||
// If user has overridden the MIMEType for a contact, we use that.
|
||||
// Otherwise, we take the MIMEType from the composer.
|
||||
if vCardData.MIMEType != "" {
|
||||
b.withMIMEType(vCardData.MIMEType)
|
||||
}
|
||||
|
||||
sendingKey, err := pickSendingKey(vCardData, apiKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.withPublicKey(sendingKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickSendingKey tries to determine which key to use to encrypt outgoing mail.
|
||||
// It returns a keyring containing the chosen key or an error.
|
||||
//
|
||||
// 1. If there are pinned keys in the vCard, those should be given preference
|
||||
// (assuming the fingerprint matches one of the keys served by the API).
|
||||
// 2. If there are pinned keys in the vCard but no matching keys were served
|
||||
// by the API, we use one of the API keys but first show a modal to the
|
||||
// user to ask them to confirm that they trust the API key.
|
||||
// (Use case: user doesn't trust server, pins the only keys they trust to
|
||||
// the contact, rogue server sends unknown keys, user should have option
|
||||
// to say they don't recognise these keys and abort the mail send.)
|
||||
// 3. If there are no pinned keys, then the client should encrypt with the
|
||||
// first valid key served by the API (in principle the server already
|
||||
// validates the keys and the first one provided should be valid).
|
||||
func pickSendingKey(vCardData *contactSettings, rawAPIKeys []proton.PublicKey) (*crypto.KeyRing, error) {
|
||||
contactKeys := make([]*crypto.Key, len(vCardData.Keys))
|
||||
apiKeys := make([]*crypto.Key, len(rawAPIKeys))
|
||||
|
||||
for i, key := range vCardData.Keys {
|
||||
var ck *crypto.Key
|
||||
|
||||
// Contact keys are not armored.
|
||||
var err error
|
||||
if ck, err = crypto.NewKey([]byte(key)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contactKeys[i] = ck
|
||||
}
|
||||
|
||||
for i, key := range rawAPIKeys {
|
||||
var ck *crypto.Key
|
||||
|
||||
// API keys are armored.
|
||||
var err error
|
||||
if ck, err = crypto.NewKeyFromArmored(key.PublicKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiKeys[i] = ck
|
||||
}
|
||||
|
||||
matchedKeys := matchFingerprints(contactKeys, apiKeys)
|
||||
|
||||
var sendingKey *crypto.Key
|
||||
|
||||
switch {
|
||||
// Case 1.
|
||||
case len(matchedKeys) > 0:
|
||||
sendingKey = matchedKeys[0]
|
||||
|
||||
// Case 2.
|
||||
case len(matchedKeys) == 0 && len(contactKeys) > 0:
|
||||
// NOTE: Here we should ask for trust confirmation.
|
||||
sendingKey = apiKeys[0]
|
||||
|
||||
// Case 3.
|
||||
default:
|
||||
sendingKey = apiKeys[0]
|
||||
}
|
||||
|
||||
return crypto.NewKeyRing(sendingKey)
|
||||
}
|
||||
|
||||
func matchFingerprints(a, b []*crypto.Key) (res []*crypto.Key) {
|
||||
aMap := make(map[string]*crypto.Key)
|
||||
|
||||
for _, el := range a {
|
||||
aMap[el.GetFingerprint()] = el
|
||||
}
|
||||
|
||||
for _, el := range b {
|
||||
if _, inA := aMap[el.GetFingerprint()]; inA {
|
||||
res = append(res, el)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) setExternalPGPSettingsWithWKDKeys(
|
||||
vCardData *contactSettings,
|
||||
apiKeys []proton.PublicKey,
|
||||
) error {
|
||||
// We're guaranteed to get at least one valid (i.e. not expired, revoked or
|
||||
// marked as verification-only) public key from the server.
|
||||
if len(apiKeys) == 0 {
|
||||
return errors.New("an API key is necessary but wasn't provided")
|
||||
}
|
||||
|
||||
// We always encrypt and sign external mail if WKD keys are present.
|
||||
b.withEncrypt(true)
|
||||
b.withSign(true)
|
||||
|
||||
// If the contact has a specific Scheme preference, we set it (otherwise we
|
||||
// leave it unset to allow it to be filled in with the default value later).
|
||||
if vCardData.Scheme != "" {
|
||||
b.withScheme(vCardData.Scheme)
|
||||
}
|
||||
|
||||
// Because the email is signed, the cryptographic scheme determines the email
|
||||
// format. A PGP/INLINE scheme forces to use plain text. A PGP/MIME scheme
|
||||
// forces the automatic format.
|
||||
switch vCardData.Scheme {
|
||||
case pgpMIME:
|
||||
b.removeMIMEType()
|
||||
case pgpInline:
|
||||
b.withMIMEType("text/plain")
|
||||
}
|
||||
|
||||
sendingKey, err := pickSendingKey(vCardData, apiKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.withPublicKey(sendingKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) setExternalPGPSettingsWithoutWKDKeys(
|
||||
vCardData *contactSettings,
|
||||
) error {
|
||||
b.withEncrypt(vCardData.Encrypt)
|
||||
|
||||
if vCardData.SignIsSet {
|
||||
b.withSign(vCardData.Sign)
|
||||
}
|
||||
|
||||
// Sign must be enabled whenever encrypt is.
|
||||
if vCardData.Encrypt {
|
||||
b.withSign(true)
|
||||
}
|
||||
|
||||
// If the contact has a specific Scheme preference, we set it (otherwise we
|
||||
// leave it unset to allow it to be filled in with the default value later).
|
||||
if vCardData.Scheme != "" {
|
||||
b.withScheme(vCardData.Scheme)
|
||||
}
|
||||
|
||||
// If we are signing the message, the PGP scheme overrides the MIMEType.
|
||||
// Otherwise, we read the MIMEType from the vCard, if set.
|
||||
if vCardData.Sign {
|
||||
switch vCardData.Scheme {
|
||||
case pgpMIME:
|
||||
b.removeMIMEType()
|
||||
case pgpInline:
|
||||
b.withMIMEType("text/plain")
|
||||
}
|
||||
} else if vCardData.MIMEType != "" {
|
||||
b.withMIMEType(vCardData.MIMEType)
|
||||
}
|
||||
|
||||
if len(vCardData.Keys) > 0 {
|
||||
var (
|
||||
key *crypto.Key
|
||||
err error
|
||||
)
|
||||
|
||||
// Contact keys are not armored.
|
||||
if key, err = crypto.NewKey([]byte(vCardData.Keys[0])); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var kr *crypto.KeyRing
|
||||
|
||||
if kr, err = crypto.NewKeyRing(key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.withPublicKey(kr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setEncryptionPreferences sets the undefined values in the SendPreferences
|
||||
// determined thus far using using the (global) user mail settings.
|
||||
// The object we extract has the following possible value types:
|
||||
//
|
||||
// {
|
||||
// encrypt: true | false,
|
||||
// sign: true | false,
|
||||
// pgpScheme: 'pgp-mime' | 'pgp-inline',
|
||||
// mimeType: 'text/html' | 'text/plain',
|
||||
// publicKey: OpenPGPKey | undefined/null
|
||||
// }
|
||||
//
|
||||
// The public key can still be undefined as we do not need it if the outgoing
|
||||
// email is not encrypted.
|
||||
func (b *sendPrefsBuilder) setEncryptionPreferences(mailSettings proton.MailSettings) {
|
||||
// For internal addresses or external ones with WKD keys, this flag should
|
||||
// always be true. For external ones, an undefined flag defaults to false.
|
||||
b.withEncryptDefault(false)
|
||||
|
||||
// For internal addresses or external ones with WKD keys, this flag should
|
||||
// always be true. For external ones, an undefined flag defaults to the user
|
||||
// mail setting "Sign External messages". Otherwise we keep the defined value
|
||||
// unless it conflicts with the encrypt flag (we do not allow to send
|
||||
// encrypted but not signed).
|
||||
if mailSettings.Sign > 0 {
|
||||
b.withSignDefault()
|
||||
}
|
||||
|
||||
if b.shouldEncrypt() {
|
||||
b.withSign(true)
|
||||
}
|
||||
|
||||
// If undefined, default to the user mail setting "Default PGP scheme".
|
||||
// Otherwise keep the defined value.
|
||||
switch mailSettings.PGPScheme {
|
||||
case proton.PGPInlineScheme:
|
||||
b.withSchemeDefault(pgpInline)
|
||||
case proton.PGPMIMEScheme:
|
||||
b.withSchemeDefault(pgpMIME)
|
||||
case proton.ClearMIMEScheme, proton.ClearScheme, proton.EncryptedOutsideScheme, proton.InternalScheme:
|
||||
// nothing to set
|
||||
}
|
||||
|
||||
// Its value is constrained by the sign flag and the PGP scheme:
|
||||
// - Sign flag = true → For a PGP/Inline scheme, the MIME type must be
|
||||
// 'plain/text'. Otherwise we default to the user mail setting "Composer mode"
|
||||
// - Sign flag = false → If undefined, default to the user mail setting
|
||||
// "Composer mode". Otherwise keep the defined value.
|
||||
if b.shouldSign() && b.getScheme() == pgpInline {
|
||||
b.withMIMEType("text/plain")
|
||||
} else {
|
||||
b.withMIMETypeDefault(mailSettings.DraftMIMEType)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *sendPrefsBuilder) setMIMEPreferences(composerMIMEType string) {
|
||||
// If the sign flag (that we just determined above) is true, then the MIME
|
||||
// type is determined by the PGP scheme (also determined above): we should
|
||||
// use 'text/plain' for a PGP/Inline scheme, and 'multipart/mixed' otherwise.
|
||||
// Otherwise we use the MIME type from the encryption preferences, unless
|
||||
// the plain text option has been selecting in the composer, which should
|
||||
// enforce 'text/plain' and override the encryption preference.
|
||||
if !b.isInternal() && b.shouldSign() {
|
||||
switch b.getScheme() {
|
||||
case pgpInline:
|
||||
b.withMIMEType("text/plain")
|
||||
default:
|
||||
b.withMIMEType("multipart/mixed")
|
||||
}
|
||||
} else if composerMIMEType == "text/plain" {
|
||||
b.withMIMEType("text/plain")
|
||||
}
|
||||
}
|
||||
@ -1,445 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPreferencesBuilder(t *testing.T) {
|
||||
testContactKey := loadContactKey(t, testPublicKey)
|
||||
testOtherContactKey := loadContactKey(t, testOtherPublicKey)
|
||||
|
||||
tests := []struct { //nolint:maligned
|
||||
name string
|
||||
|
||||
contactMeta *contactSettings
|
||||
receivedKeys []proton.PublicKey
|
||||
isInternal bool
|
||||
mailSettings proton.MailSettings
|
||||
composerMIMEType string
|
||||
|
||||
wantEncrypt bool
|
||||
wantSign proton.SignatureType
|
||||
wantScheme proton.EncryptionScheme
|
||||
wantMIMEType rfc822.MIMEType
|
||||
wantPublicKey string
|
||||
}{
|
||||
{
|
||||
name: "internal",
|
||||
|
||||
contactMeta: &contactSettings{},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: true,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.InternalScheme,
|
||||
wantMIMEType: "text/html",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "internal with contact-specific email format",
|
||||
|
||||
contactMeta: &contactSettings{MIMEType: "text/plain"},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: true,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.InternalScheme,
|
||||
wantMIMEType: "text/plain",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "internal with pinned contact public key",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testContactKey}},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: true,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.InternalScheme,
|
||||
wantMIMEType: "text/html",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
// NOTE: Need to figured out how to test that this calls the frontend to check for user confirmation.
|
||||
name: "internal with conflicting contact public key",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testOtherContactKey}},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: true,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.InternalScheme,
|
||||
wantMIMEType: "text/html",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "wkd-external",
|
||||
|
||||
contactMeta: &contactSettings{},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "wkd-external with contact-specific email format",
|
||||
|
||||
contactMeta: &contactSettings{MIMEType: "text/plain"},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "wkd-external with global pgp-inline scheme",
|
||||
|
||||
contactMeta: &contactSettings{},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPInlineScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPInlineScheme,
|
||||
wantMIMEType: "text/plain",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "wkd-external with contact-specific pgp-inline scheme overriding global pgp-mime setting",
|
||||
|
||||
contactMeta: &contactSettings{Scheme: pgpInline},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPInlineScheme,
|
||||
wantMIMEType: "text/plain",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "wkd-external with contact-specific pgp-mime scheme overriding global pgp-inline setting",
|
||||
|
||||
contactMeta: &contactSettings{Scheme: pgpMIME},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPInlineScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "wkd-external with additional pinned contact public key",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testContactKey}},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
// NOTE: Need to figured out how to test that this calls the frontend to check for user confirmation.
|
||||
name: "wkd-external with additional conflicting contact public key",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testOtherContactKey}},
|
||||
receivedKeys: []proton.PublicKey{{PublicKey: testPublicKey}},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "external",
|
||||
|
||||
contactMeta: &contactSettings{},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: false,
|
||||
wantSign: proton.NoSignature,
|
||||
wantScheme: proton.ClearScheme,
|
||||
wantMIMEType: "text/html",
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with contact-specific email format",
|
||||
|
||||
contactMeta: &contactSettings{MIMEType: "text/plain"},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: false,
|
||||
wantSign: proton.NoSignature,
|
||||
wantScheme: proton.ClearScheme,
|
||||
wantMIMEType: "text/plain",
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with sign enabled",
|
||||
|
||||
contactMeta: &contactSettings{Sign: true, SignIsSet: true},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: false,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.ClearMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with contact sign enabled and plain text",
|
||||
|
||||
contactMeta: &contactSettings{MIMEType: "text/plain", Scheme: pgpInline, Sign: true, SignIsSet: true},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: false,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.ClearScheme,
|
||||
wantMIMEType: "text/plain",
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with sign enabled, sending plaintext, should still send as ClearMIME",
|
||||
|
||||
contactMeta: &contactSettings{Sign: true, SignIsSet: true},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/plain"},
|
||||
|
||||
wantEncrypt: false,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.ClearMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with pinned contact public key but no intention to encrypt/sign",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testContactKey}},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: false,
|
||||
wantSign: proton.NoSignature,
|
||||
wantScheme: proton.ClearScheme,
|
||||
wantMIMEType: "text/html",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with pinned contact public key, encrypted and signed",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testContactKey}, Encrypt: true, Sign: true, SignIsSet: true},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPMIMEScheme,
|
||||
wantMIMEType: "multipart/mixed",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with pinned contact public key, encrypted and signed using contact-specific pgp-inline",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testContactKey}, Encrypt: true, Sign: true, Scheme: pgpInline, SignIsSet: true},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPMIMEScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPInlineScheme,
|
||||
wantMIMEType: "text/plain",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
|
||||
{
|
||||
name: "external with pinned contact public key, encrypted and signed using global pgp-inline",
|
||||
|
||||
contactMeta: &contactSettings{Keys: []string{testContactKey}, Encrypt: true, Sign: true, SignIsSet: true},
|
||||
receivedKeys: []proton.PublicKey{},
|
||||
isInternal: false,
|
||||
mailSettings: proton.MailSettings{PGPScheme: proton.PGPInlineScheme, DraftMIMEType: "text/html"},
|
||||
|
||||
wantEncrypt: true,
|
||||
wantSign: proton.DetachedSignature,
|
||||
wantScheme: proton.PGPInlineScheme,
|
||||
wantMIMEType: "text/plain",
|
||||
wantPublicKey: testPublicKey,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test // Avoid using range scope test inside function literal.
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := &sendPrefsBuilder{}
|
||||
|
||||
require.NoError(t, b.setPGPSettings(test.contactMeta, test.receivedKeys, test.isInternal))
|
||||
b.setEncryptionPreferences(test.mailSettings)
|
||||
b.setMIMEPreferences(test.composerMIMEType)
|
||||
|
||||
prefs := b.build()
|
||||
|
||||
assert.Equal(t, test.wantEncrypt, prefs.Encrypt)
|
||||
assert.Equal(t, test.wantSign, prefs.SignatureType)
|
||||
assert.Equal(t, test.wantScheme, prefs.EncryptionScheme)
|
||||
assert.Equal(t, test.wantMIMEType, prefs.MIMEType)
|
||||
|
||||
if prefs.PubKey != nil {
|
||||
wantKey, err := crypto.NewKeyFromArmored(test.wantPublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
haveKey, err := prefs.PubKey.GetKey(0)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantKey.GetFingerprint(), haveKey.GetFingerprint())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func loadContactKey(t *testing.T, key string) string {
|
||||
ck, err := crypto.NewKeyFromArmored(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := ck.GetPublicKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
return string(pk)
|
||||
}
|
||||
|
||||
const testPublicKey = `-----BEGIN PGP PUBLIC KEY BLOCK-----
|
||||
|
||||
xsBNBFRJbc0BCAC0mMLZPDBbtSCWvxwmOfXfJkE2+ssM3ux21LhD/bPiWefEWSHl
|
||||
CjJ8PqPHy7snSiUuxuj3f9AvXPvg+mjGLBwu1/QsnSP24sl3qD2onl39vPiLJXUq
|
||||
Zs20ZRgnvX70gjkgEzMFBxINiy2MTIG+4RU8QA7y8KzWev0btqKiMeVa+GLEHhgZ
|
||||
2KPOn4Jv1q4bI9hV0C9NUe2tTXS6/Vv3vbCY7lRR0kbJ65T5c8CmpqJuASIJNrSX
|
||||
M/Q3NnnsY4kBYH0s5d2FgbASQvzrjuC2rngUg0EoPsrbDEVRA2/BCJonw7aASiNC
|
||||
rSP92lkZdtYlax/pcoE/mQ4WSwySFmcFT7yFABEBAAHNBlVzZXJJRMLAcgQQAQgA
|
||||
JgUCVEltzwYLCQgHAwIJED62JZ7fId8kBBUIAgoDFgIBAhsDAh4BAAD0nQf9EtH9
|
||||
TC0JqSs8q194Zo244jjlJFM3EzxOSULq0zbywlLORfyoo/O8jU/HIuGz+LT98JDt
|
||||
nltTqfjWgu6pS3ZL2/L4AGUKEoB7OI6oIdRwzMc61sqI+Qpbzxo7rzufH4CiXZc6
|
||||
cxORUgL550xSCcqnq0q1mds7h5roKDzxMW6WLiEsc1dN8IQKzC7Ec5wA7U4oNGsJ
|
||||
3TyI8jkIs0IhXrRCd26K0TW8Xp6GCsfblWXosR13y89WVNgC+xrrJKTZEisc0tRl
|
||||
neIgjcwEUvwfIg2n9cDUFA/5BsfzTW5IurxqDEziIVP0L44PXjtJrBQaGMPlEbtP
|
||||
5i2oi3OADVX2XbvsRc7ATQRUSW3PAQgAkPnu5fps5zhOB/e618v/iF3KiogxUeRh
|
||||
A68TbvA+xnFfTxCx2Vo14aOL0CnaJ8gO5yRSqfomL2O1kMq07N1MGbqucbmc+aSf
|
||||
oElc+Gd5xBE/w3RcEhKcAaYTi35vG22zlZup4x3ElioyIarOssFEkQgNNyDf5AXZ
|
||||
jdHLA6qVxeqAb/Ff74+y9HUmLPSsRU9NwFzvK3Jv8C/ubHVLzTYdFgYkc4W1Uug9
|
||||
Ou08K+/4NEMrwnPFBbZdJAuUjQz2zW2ZiEKiBggiorH2o5N3mYUnWEmUvqL3EOS8
|
||||
TbWo8UBIW3DDm2JiZR8VrEgvBtc9mVDUj/x+5pR07Fy1D6DjRmAc9wARAQABwsBf
|
||||
BBgBCAATBQJUSW3SCRA+tiWe3yHfJAIbDAAA/iwH/ik9RKZMB9Ir0x5mGpKPuqhu
|
||||
gwrc3d04m1sOdXJm2NtD4ddzSEvzHwaPNvEvUl5v7FVMzf6+6mYGWHyNP4+e7Rtw
|
||||
YLlRpud6smuGyDSsotUYyumiqP6680ZIeWVQ+a1TThNs878mAJy1FhvQFdTmA8XI
|
||||
C616hDFpamQKPlpoO1a0wZnQhrPwT77HDYEEa+hqY4Jr/a7ui40S+7xYRHKL/7ZA
|
||||
S4/grWllhU3dbNrwSzrOKwrA/U0/9t738Ap6JL71YymDeaL4sutcoaahda1pTrMW
|
||||
ePtrCltz6uySwbZs7GXoEzjX3EAH+6qhkUJtzMaE3YEFEoQMGzcDTUEfXCJ3zJw=
|
||||
=yT9U
|
||||
-----END PGP PUBLIC KEY BLOCK-----`
|
||||
|
||||
const testOtherPublicKey = `-----BEGIN PGP PUBLIC KEY BLOCK-----
|
||||
|
||||
mQENBF8Rmj4BCACgXXxRqLsmEUWZGd0f88BteXBfi9zL+9GysOTk4n9EgINLN2PU
|
||||
5rYSmWvVocO8IAfl/z9zpTJQesQjGe5lHbygUWFmjadox2ZeecZw0PWCSRdAjk6w
|
||||
Q4UX0JiCo3IuICZk1t53WWRtGnhA2Q21J4b2DJg4T5ZFKgKDzDhWoGF1ZStbI5X1
|
||||
0rKTGFNHgreV5PqxUjxHVtx3rgT9Mx+13QTffqKR9oaYC6mNs4TNJdhyqfaYxqGw
|
||||
ElxfdS9Wz6ODXrUNuSHETfgvAmo1Qep7GkefrC1isrmXA2+a+mXzFn4L0FCG073w
|
||||
Vi/lEw6R/vKfN6QukHPxwoSguow4wTyhRRmfABEBAAG0GVRlc3RUZXN0IDx0ZXN0
|
||||
dGVzdEBwbS5tZT6JAU4EEwEIADgWIQTsXZU1AxlWCPT02+BKdWAu4Q1jXQUCXxGa
|
||||
PgIbAwULCQgHAgYVCgkICwIEFgIDAQIeAQIXgAAKCRBKdWAu4Q1jXQw+B/0ZudN+
|
||||
W9EqJtL/elm7Qla47zNsFmB+pHObdGoKtp3mNc97CQoW1yQ/i/V0heBFTAioP00g
|
||||
FgEk1ZUJfO++EtI8esNFdDZqY99826/Cl0FlJwubn/XYxi4XyaGTY1nhhyEJ2HWI
|
||||
/mZ+Jfm9ojbHSLwO5/AHiQt5t+LPDsKLXZw1BDJTgf1xD6e36CwAZgrPGWDqCXJ9
|
||||
BjlQn5hje7p0F8vYWBnnfSPkMHwibz9FlFqDh5v3XTgGpFIWDVkPVgAs8erM9AM2
|
||||
TjdpGcdW8xfcymo3j/o2QUBGYGJwPTsGEO5IkFRre9c/3REa7MKIi17Y479ub0A6
|
||||
2J3xgnqgI4sxmgmOuQENBF8Rmj4BCADX3BamNZsjC3I0knVIwjbz//1r8WOfNwGh
|
||||
gg5LsvpfLkrsNUZy+deSwb+hS9Auyr1xsMmtVyiTPGUXTjU4uUzY2zyTYWgYfSEi
|
||||
CojlXmYYLsjyPzR7KhVP6QIYZqYkOQXaCQDRlprRoFIEe4FzTCuqDHatJNwSesGy
|
||||
5pPJrjiAeb47m9KaoEIacoe9D3w1z4FCKN3A8cjiWT8NRfhYTBoE/T34oXVUj8l+
|
||||
jLIgVUQgGoBos160Z1Cnxd2PKWFVh/Br3QtIPTbNVDWhh5T1+N2ypbwsXCawy6fj
|
||||
cbOaTLz/vF9g+RJKC0MtxdL5qUtv3d3Zn07Sg+9H6wjsboAdAvirABEBAAGJATYE
|
||||
GAEIACAWIQTsXZU1AxlWCPT02+BKdWAu4Q1jXQUCXxGaPgIbDAAKCRBKdWAu4Q1j
|
||||
Xc4WB/9+aTGMMTlIdAFs9rf0i7i83pUOOxuLl34YQ0t5WGsjteQ4IK+gfuFvp37W
|
||||
ktv98ShOxAexbfqzGyGcYLLgaCxCbbB85fvSeX0xK/C2UbiH3Gv1z8GTelailCxt
|
||||
vyx642TwpcLXW1obHaHTSIi5L35Tce9gbug9sKCRSlAH76dANYBbMLa2Bl0LSrF8
|
||||
mcie9jJaPRXGOeHOyZmPZwwGhVYgadjptWqXnFz3ua8vxgqG0sefWF23F36iVz2q
|
||||
UjxSE+nKLaPFLlEDLgxG4SwHkcR9fi7zaQVnXg4rEjr0uz5MSUqZC4MNB4rkhU3g
|
||||
/rUMQyZupw+xJ+ayQNVBEtYZd/9u
|
||||
=TNX4
|
||||
-----END PGP PUBLIC KEY BLOCK-----`
|
||||
@ -34,6 +34,7 @@ import (
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
@ -114,7 +115,7 @@ func (user *User) doSync(ctx context.Context) error {
|
||||
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
return safe.RLockRet(func() error {
|
||||
return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
return usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
user.log.Info("Syncing labels")
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
@ -87,7 +88,7 @@ func newMessageCreatedUpdate(
|
||||
return &imap.MessageCreated{
|
||||
Message: toIMAPMessage(message),
|
||||
Literal: literal,
|
||||
MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||
ParsedMessage: parsedMessage,
|
||||
}, nil
|
||||
}
|
||||
@ -106,7 +107,7 @@ func newMessageCreatedFailedUpdate(
|
||||
|
||||
return &imap.MessageCreated{
|
||||
Message: toIMAPMessage(message),
|
||||
MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||
Literal: literal,
|
||||
ParsedMessage: parsedMessage,
|
||||
}
|
||||
|
||||
@ -1,111 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// mapTo converts the slice to the given type.
|
||||
// This is not runtime safe, so make sure the slice is of the correct type!
|
||||
// (This is a workaround for the fact that slices cannot be converted to other types generically).
|
||||
func mapTo[From, To any](from []From) []To {
|
||||
to := make([]To, 0, len(from))
|
||||
|
||||
for _, from := range from {
|
||||
val, ok := reflect.ValueOf(from).Convert(reflect.TypeOf(to).Elem()).Interface().(To)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("cannot convert %T to %T", from, *new(To))) //nolint:gocritic
|
||||
}
|
||||
|
||||
to = append(to, val)
|
||||
}
|
||||
|
||||
return to
|
||||
}
|
||||
|
||||
// groupBy returns a map of the given slice grouped by the given key.
|
||||
// Duplicate keys are overwritten.
|
||||
func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value {
|
||||
groups := make(map[Key]Value)
|
||||
|
||||
for _, item := range items {
|
||||
groups[key(item)] = item
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// getAddrID returns the address ID for the given email address.
|
||||
func getAddrID(apiAddrs map[string]proton.Address, email string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if strings.EqualFold(addr.Email, sanitizeEmail(email)) {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", email)
|
||||
}
|
||||
|
||||
// getAddrIdx returns the address with the given index.
|
||||
func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, error) {
|
||||
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
|
||||
return a.Order < b.Order
|
||||
})
|
||||
|
||||
if idx < 0 || idx >= len(sorted) {
|
||||
return proton.Address{}, fmt.Errorf("address index %d out of range", idx)
|
||||
}
|
||||
|
||||
return sorted[idx], nil
|
||||
}
|
||||
|
||||
func getPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) {
|
||||
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
|
||||
return a.Order < b.Order
|
||||
})
|
||||
|
||||
if len(sorted) == 0 {
|
||||
return proton.Address{}, fmt.Errorf("no addresses available")
|
||||
}
|
||||
|
||||
return sorted[0], nil
|
||||
}
|
||||
|
||||
// sortSlice returns the given slice sorted by the given comparator.
|
||||
func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
|
||||
sorted := make([]Item, len(items))
|
||||
|
||||
copy(sorted, items)
|
||||
|
||||
slices.SortFunc(sorted, less)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
|
||||
return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
|
||||
}
|
||||
@ -1,37 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToType(t *testing.T) {
|
||||
type myString string
|
||||
|
||||
// Slices of different types are not equal.
|
||||
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
|
||||
|
||||
// But converting them to the same type makes them equal.
|
||||
require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"}))
|
||||
|
||||
// The conversion can happen in the other direction too.
|
||||
require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"}))
|
||||
}
|
||||
@ -40,7 +40,10 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
@ -65,7 +68,7 @@ type User struct {
|
||||
vault *vault.User
|
||||
client *proton.Client
|
||||
reporter reporter.Reporter
|
||||
sendHash *sendRecorder
|
||||
sendHash *sendrecorder.SendRecorder
|
||||
|
||||
eventCh *async.QueuedChannel[events.Event]
|
||||
eventLock safe.RWMutex
|
||||
@ -100,6 +103,8 @@ type User struct {
|
||||
telemetryManager telemetry.Availability
|
||||
// goStatusProgress triggers a check/sending if progress is needed.
|
||||
goStatusProgress func()
|
||||
|
||||
smtpService *smtp.Service
|
||||
}
|
||||
|
||||
// New returns a new user.
|
||||
@ -148,7 +153,7 @@ func New(
|
||||
vault: encVault,
|
||||
client: client,
|
||||
reporter: reporter,
|
||||
sendHash: newSendRecorder(sendEntryExpiry),
|
||||
sendHash: sendrecorder.NewSendRecorder(sendrecorder.SendEntryExpiry),
|
||||
|
||||
eventCh: async.NewQueuedChannel[events.Event](0, 0, crashHandler),
|
||||
eventLock: safe.NewRWMutex(),
|
||||
@ -156,10 +161,10 @@ func New(
|
||||
apiUser: apiUser,
|
||||
apiUserLock: safe.NewRWMutex(),
|
||||
|
||||
apiAddrs: groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
|
||||
apiAddrs: usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
|
||||
apiAddrsLock: safe.NewRWMutex(),
|
||||
|
||||
apiLabels: groupBy(apiLabels, func(label proton.Label) string { return label.ID }),
|
||||
apiLabels: usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID }),
|
||||
apiLabelsLock: safe.NewRWMutex(),
|
||||
|
||||
updateCh: make(map[string]*async.QueuedChannel[imap.Update]),
|
||||
@ -178,6 +183,8 @@ func New(
|
||||
telemetryManager: telemetryManager,
|
||||
}
|
||||
|
||||
user.smtpService = smtp.NewService(user, client, user.sendHash, user.panicHandler, user.reporter)
|
||||
|
||||
// Check for status_progress when triggered.
|
||||
user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) {
|
||||
user.SendConfigStatusProgress(ctx)
|
||||
@ -264,6 +271,9 @@ func New(
|
||||
}
|
||||
})
|
||||
|
||||
// Start SMTP Service
|
||||
user.smtpService.Start(user.tasks)
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@ -461,7 +471,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
primAddr, err := usertypes.GetAddrIdx(user.apiAddrs, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
@ -485,10 +495,10 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
if len(to) == 0 {
|
||||
return ErrInvalidRecipient
|
||||
return smtp.ErrInvalidRecipient
|
||||
}
|
||||
|
||||
if err := user.sendMail(authID, from, to, r); err != nil {
|
||||
if err := user.smtpService.SendMail(context.Background(), authID, from, to, r); err != nil {
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) {
|
||||
logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("failed to send message")
|
||||
}
|
||||
@ -664,6 +674,12 @@ func (user *User) SendTelemetry(ctx context.Context, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) WithSMTPData(ctx context.Context, op func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error {
|
||||
return safe.RLockRet(func() error {
|
||||
return op(ctx, user.apiAddrs, user.apiUser, user.vault)
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.eventLock)
|
||||
}
|
||||
|
||||
// initUpdateCh initializes the user's update channels in the given address mode.
|
||||
// It is assumed that user.apiAddrs and user.updateCh are already locked.
|
||||
func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
|
||||
Reference in New Issue
Block a user