From 5fee2f707be2d26f92fffb3de3c016aa96431c23 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Tue, 16 May 2023 17:37:25 +0200 Subject: [PATCH] fix(GODT-2627): Properly handle recording of message with Bcc fields Ensure the SMTP send recorder properly handles the recording of messages which may have the same body hash but have different recipients. E.g.: send the same message twice to 2 different users via Bcc. The send recorder now maintains a list of send requests and waiting for a message to be sent is done one the oldest of the messages. --- internal/user/send_recorder.go | 85 +++++++++++++++++++---------- internal/user/send_recorder_test.go | 38 +++++++++---- internal/user/smtp.go | 4 +- 3 files changed, 84 insertions(+), 43 deletions(-) diff --git a/internal/user/send_recorder.go b/internal/user/send_recorder.go index 228bbfe6..306f1384 100644 --- a/internal/user/send_recorder.go +++ b/internal/user/send_recorder.go @@ -25,6 +25,7 @@ import ( "time" "github.com/ProtonMail/gluon/rfc822" + "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" "golang.org/x/exp/slices" ) @@ -34,14 +35,14 @@ const sendEntryExpiry = 30 * time.Minute type sendRecorder struct { expiry time.Duration - entries map[string]*sendEntry + entries map[string][]*sendEntry entriesLock sync.Mutex } func newSendRecorder(expiry time.Duration) *sendRecorder { return &sendRecorder{ expiry: expiry, - entries: make(map[string]*sendEntry), + entries: make(map[string][]*sendEntry), } } @@ -110,25 +111,40 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline t return h.hasEntryWait(ctx, hash, deadline) } +func (h *sendRecorder) removeExpiredUnsafe() { + for hash, entry := range h.entries { + remaining := xslices.Filter(entry, func(t *sendEntry) bool { + return !t.exp.Before(time.Now()) + }) + + if len(remaining) == 0 { + delete(h.entries, hash) + } else { + h.entries[hash] = remaining + } + } +} + func (h *sendRecorder) tryInsert(hash string, toList []string) bool { h.entriesLock.Lock() defer h.entriesLock.Unlock() - for hash, entry := range h.entries { - if entry.exp.Before(time.Now()) { - delete(h.entries, hash) + h.removeExpiredUnsafe() + + entries, ok := h.entries[hash] + if ok { + for _, entry := range entries { + if matchToList(entry.toList, toList) { + return false + } } } - if _, ok := h.entries[hash]; ok && matchToList(h.entries[hash].toList, toList) { - return false - } - - h.entries[hash] = &sendEntry{ + h.entries[hash] = append(entries, &sendEntry{ exp: time.Now().Add(h.expiry), toList: toList, waitCh: make(chan struct{}), - } + }) return true } @@ -137,11 +153,7 @@ func (h *sendRecorder) hasEntry(hash string) bool { h.entriesLock.Lock() defer h.entriesLock.Unlock() - for hash, entry := range h.entries { - if entry.exp.Before(time.Now()) { - delete(h.entries, hash) - } - } + h.removeExpiredUnsafe() if _, ok := h.entries[hash]; ok { return true @@ -150,33 +162,46 @@ func (h *sendRecorder) hasEntry(hash string) bool { return false } -// addMessageID should be called after a message has been successfully sent. -func (h *sendRecorder) addMessageID(hash, msgID string) { +// signalMessageSent should be called after a message has been successfully sent. +func (h *sendRecorder) signalMessageSent(hash, msgID string, toList []string) { h.entriesLock.Lock() defer h.entriesLock.Unlock() - entry, ok := h.entries[hash] + entries, ok := h.entries[hash] if ok { - entry.msgID = msgID - } else { - logrus.Warn("Cannot add message ID to send hash entry, it may have expired") + for _, entry := range entries { + if matchToList(entry.toList, toList) { + entry.msgID = msgID + entry.closeWaitChannel() + return + } + } } - entry.closeWaitChannel() + logrus.Warn("Cannot add message ID to send hash entry, it may have expired") } -func (h *sendRecorder) removeOnFail(hash string) { +func (h *sendRecorder) removeOnFail(hash string, toList []string) { h.entriesLock.Lock() defer h.entriesLock.Unlock() - entry, ok := h.entries[hash] - if !ok || entry.msgID != "" { + entries, ok := h.entries[hash] + if !ok { return } - entry.closeWaitChannel() + for idx, entry := range entries { + if entry.msgID == "" && matchToList(entry.toList, toList) { + entry.closeWaitChannel() - delete(h.entries, hash) + remaining := xslices.Remove(entries, idx, 1) + if len(remaining) != 0 { + h.entries[hash] = remaining + } else { + delete(h.entries, hash) + } + } + } } func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) { @@ -200,7 +225,7 @@ func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time defer h.entriesLock.Unlock() if entry, ok := h.entries[hash]; ok { - return entry.msgID, true, nil + return entry[0].msgID, true, nil } return "", false, nil @@ -211,7 +236,7 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) { defer h.entriesLock.Unlock() if entry, ok := h.entries[hash]; ok { - return entry.waitCh, true + return entry[0].waitCh, true } return nil, false diff --git a/internal/user/send_recorder_test.go b/internal/user/send_recorder_test.go index 0aabe90b..942df28e 100644 --- a/internal/user/send_recorder_test.go +++ b/internal/user/send_recorder_test.go @@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) { require.NotEmpty(t, hash1) // Simulate successfully sending the message. - h.addMessageID(hash1, "abc") + h.signalMessageSent(hash1, "abc", nil) // Inserting a message with the same hash should return false. _, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -59,7 +59,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) { require.NotEmpty(t, hash1) // Simulate successfully sending the message. - h.addMessageID(hash1, "abc") + h.signalMessageSent(hash1, "abc", nil) // Wait for the entry to expire. time.Sleep(time.Second) @@ -106,7 +106,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) { // Simulate successfully sending the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) }() // Inserting a message with the same hash should fail. @@ -127,7 +127,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) { // Simulate failing to send the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.removeOnFail(hash) + h.removeOnFail(hash, nil) }() // Inserting a message with the same hash should succeed because the first message failed to send. @@ -163,7 +163,7 @@ func TestSendHasher_HasEntry(t *testing.T) { require.NotEmpty(t, hash) // Simulate successfully sending the message. - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) // The message was already sent; we should find it in the hasher. messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) @@ -184,7 +184,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { // Simulate successfully sending the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) }() // The message was already sent; we should find it in the hasher. @@ -197,7 +197,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { // There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message // is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be - // inserted as a new entry. The two clients end up sending the message twice and calling the `addMessageID` x2, + // inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2, // resulting in a crash. h := newSendRecorder(sendEntryExpiry) @@ -209,8 +209,8 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { // Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections // to attempt to send the same message. - h.addMessageID(hash, "abc") - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) + h.signalMessageSent(hash, "abc", nil) // The message was already sent; we should find it in the hasher. messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) @@ -219,6 +219,22 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { require.Equal(t, "abc", messageID) } +func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) { + h := newSendRecorder(sendEntryExpiry) + + // Insert a message into the hasher. + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver ") + require.NoError(t, err) + require.True(t, ok) + require.NotEmpty(t, hash) + + hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver ", "Receiver2 ") + require.NoError(t, err) + require.True(t, ok) + require.NotEmpty(t, hash2) + require.Equal(t, hash, hash2) +} + func TestSendHasher_HasEntry_SendFail(t *testing.T) { h := newSendRecorder(sendEntryExpiry) @@ -231,7 +247,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) { // Simulate failing to send the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.removeOnFail(hash) + h.removeOnFail(hash, nil) }() // The message failed to send; we should not find it in the hasher. @@ -265,7 +281,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) { require.NotEmpty(t, hash) // Simulate successfully sending the message. - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) // Wait for the entry to expire. time.Sleep(time.Second) diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 7351c2ca..58148e72 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -89,7 +89,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // If we fail to send this message, we should remove the hash from the send recorder. - defer user.sendHash.removeOnFail(hash) + defer user.sendHash.removeOnFail(hash, to) // Create a new message parser from the reader. parser, err := parser.New(bytes.NewReader(b)) @@ -162,7 +162,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // If the message was successfully sent, we can update the message ID in the record. - user.sendHash.addMessageID(hash, sent.ID) + user.sendHash.signalMessageSent(hash, sent.ID, to) return nil })