From 35fa43f47cc83bc52af55fc8985cea36d21d786e Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Tue, 25 Oct 2022 16:14:56 +0200 Subject: [PATCH] Other: Properly handle SMTP to list in send recorder Checking the BCC header is unreliable; it is usually omitted from messages. Instead, we can use the SMTP "to" list for deduplication. --- internal/user/send_recorder.go | 44 ++++++++++++++++++++++------- internal/user/send_recorder_test.go | 25 ++++++++++++++-- internal/user/smtp.go | 2 +- 3 files changed, 58 insertions(+), 13 deletions(-) diff --git a/internal/user/send_recorder.go b/internal/user/send_recorder.go index 378ea67e..969a97fb 100644 --- a/internal/user/send_recorder.go +++ b/internal/user/send_recorder.go @@ -29,6 +29,7 @@ import ( "github.com/ProtonMail/gluon/rfc822" "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" ) const sendEntryExpiry = 30 * time.Minute @@ -49,6 +50,7 @@ func newSendRecorder(expiry time.Duration) *sendRecorder { type sendEntry struct { msgID string + toList []string exp time.Time waitCh chan struct{} } @@ -56,9 +58,14 @@ type sendEntry struct { // tryInsertWait tries to insert the given message into the send recorder. // If an entry already exists but it was not sent yet, it waits. // It returns whether an entry could be inserted and an error if it times out while waiting. -func (h *sendRecorder) tryInsertWait(ctx context.Context, hash string, deadline time.Time) (bool, error) { +func (h *sendRecorder) tryInsertWait( + ctx context.Context, + hash string, + toList []string, + deadline time.Time, +) (bool, error) { // If we successfully inserted the hash, we can return true. - if h.tryInsert(hash) { + if h.tryInsert(hash, toList) { return true, nil } @@ -70,7 +77,7 @@ func (h *sendRecorder) tryInsertWait(ctx context.Context, hash string, deadline // If the message failed to send, try to insert it again. if !wasSent { - return h.tryInsertWait(ctx, hash, deadline) + return h.tryInsertWait(ctx, hash, toList, deadline) } return false, nil @@ -98,7 +105,7 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline t return h.hasEntryWait(ctx, hash, deadline) } -func (h *sendRecorder) tryInsert(hash string) bool { +func (h *sendRecorder) tryInsert(hash string, toList []string) bool { h.entriesLock.Lock() defer h.entriesLock.Unlock() @@ -108,12 +115,13 @@ func (h *sendRecorder) tryInsert(hash string) bool { } } - if _, ok := h.entries[hash]; ok { + if _, ok := h.entries[hash]; ok && matchToList(h.entries[hash].toList, toList) { return false } h.entries[hash] = &sendEntry{ exp: time.Now().Add(h.expiry), + toList: toList, waitCh: make(chan struct{}), } @@ -206,7 +214,7 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) { // getMessageHash returns the hash of the given message. // This takes into account: // - the Subject header, -// - the From/To/Cc/Bcc headers, +// - the From/To/Cc headers, // - the Content-Type header of each (leaf) part, // - the Content-Disposition header of each (leaf) part, // - the (decoded) body of each part. @@ -238,10 +246,6 @@ func getMessageHash(b []byte) (string, error) { return "", err } - if _, err := h.Write([]byte(header.Get("Bcc"))); err != nil { - return "", err - } - if _, err := h.Write([]byte(header.Get("Reply-To"))); err != nil { return "", err } @@ -287,3 +291,23 @@ func getMessageHash(b []byte) (string, error) { return base64.StdEncoding.EncodeToString(h.Sum(nil)), nil } + +func matchToList(a, b []string) bool { + if len(a) != len(b) { + return false + } + + for i := range a { + if !slices.Contains(b, a[i]) { + return false + } + } + + for i := range b { + if !slices.Contains(a, b[i]) { + return false + } + } + + return true +} diff --git a/internal/user/send_recorder_test.go b/internal/user/send_recorder_test.go index d6d561fc..e445b071 100644 --- a/internal/user/send_recorder_test.go +++ b/internal/user/send_recorder_test.go @@ -73,6 +73,27 @@ func TestSendHasher_Insert_Expired(t *testing.T) { require.Equal(t, hash1, hash2) } +func TestSendHasher_Insert_DifferentToList(t *testing.T) { + h := newSendRecorder(time.Second) + + // Insert a message into the hasher. + hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def"}...) + require.NoError(t, err) + require.True(t, ok) + require.NotEmpty(t, hash1) + + // Insert the same message into the hasher but with a different to list. + hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def", "ghi"}...) + require.NoError(t, err) + require.True(t, ok) + require.NotEmpty(t, hash2) + + // Insert the same message into the hasher but with the same to list. + _, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def", "ghi"}...) + require.Error(t, err) + require.False(t, ok) +} + func TestSendHasher_Wait_SendSuccess(t *testing.T) { h := newSendRecorder(sendEntryExpiry) @@ -349,13 +370,13 @@ func TestGetMessageHash(t *testing.T) { } } -func testTryInsert(h *sendRecorder, literal string, deadline time.Time) (string, bool, error) { //nolint:unparam +func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam hash, err := getMessageHash([]byte(literal)) if err != nil { return "", false, err } - ok, err := h.tryInsertWait(context.Background(), hash, deadline) + ok, err := h.tryInsertWait(context.Background(), hash, toList, deadline) if err != nil { return "", false, err } diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 723dfdcf..78d0d334 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -58,7 +58,7 @@ func (user *User) sendMail(authID string, emails []string, from string, to []str } // Check if we already tried to send this message recently. - if ok, err := user.sendHash.tryInsertWait(ctx, hash, time.Now().Add(90*time.Second)); err != nil { + if ok, err := user.sendHash.tryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second)); err != nil { return fmt.Errorf("failed to check send hash: %w", err) } else if !ok { user.log.Warn("A duplicate message was already sent recently, skipping")