diff --git a/internal/user/send_recorder.go b/internal/user/send_recorder.go index 228bbfe6..306f1384 100644 --- a/internal/user/send_recorder.go +++ b/internal/user/send_recorder.go @@ -25,6 +25,7 @@ import ( "time" "github.com/ProtonMail/gluon/rfc822" + "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" "golang.org/x/exp/slices" ) @@ -34,14 +35,14 @@ const sendEntryExpiry = 30 * time.Minute type sendRecorder struct { expiry time.Duration - entries map[string]*sendEntry + entries map[string][]*sendEntry entriesLock sync.Mutex } func newSendRecorder(expiry time.Duration) *sendRecorder { return &sendRecorder{ expiry: expiry, - entries: make(map[string]*sendEntry), + entries: make(map[string][]*sendEntry), } } @@ -110,25 +111,40 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline t return h.hasEntryWait(ctx, hash, deadline) } +func (h *sendRecorder) removeExpiredUnsafe() { + for hash, entry := range h.entries { + remaining := xslices.Filter(entry, func(t *sendEntry) bool { + return !t.exp.Before(time.Now()) + }) + + if len(remaining) == 0 { + delete(h.entries, hash) + } else { + h.entries[hash] = remaining + } + } +} + func (h *sendRecorder) tryInsert(hash string, toList []string) bool { h.entriesLock.Lock() defer h.entriesLock.Unlock() - for hash, entry := range h.entries { - if entry.exp.Before(time.Now()) { - delete(h.entries, hash) + h.removeExpiredUnsafe() + + entries, ok := h.entries[hash] + if ok { + for _, entry := range entries { + if matchToList(entry.toList, toList) { + return false + } } } - if _, ok := h.entries[hash]; ok && matchToList(h.entries[hash].toList, toList) { - return false - } - - h.entries[hash] = &sendEntry{ + h.entries[hash] = append(entries, &sendEntry{ exp: time.Now().Add(h.expiry), toList: toList, waitCh: make(chan struct{}), - } + }) return true } @@ -137,11 +153,7 @@ func (h *sendRecorder) hasEntry(hash string) bool { h.entriesLock.Lock() defer h.entriesLock.Unlock() - for hash, entry := range h.entries { - if entry.exp.Before(time.Now()) { - delete(h.entries, hash) - } - } + h.removeExpiredUnsafe() if _, ok := h.entries[hash]; ok { return true @@ -150,33 +162,46 @@ func (h *sendRecorder) hasEntry(hash string) bool { return false } -// addMessageID should be called after a message has been successfully sent. -func (h *sendRecorder) addMessageID(hash, msgID string) { +// signalMessageSent should be called after a message has been successfully sent. +func (h *sendRecorder) signalMessageSent(hash, msgID string, toList []string) { h.entriesLock.Lock() defer h.entriesLock.Unlock() - entry, ok := h.entries[hash] + entries, ok := h.entries[hash] if ok { - entry.msgID = msgID - } else { - logrus.Warn("Cannot add message ID to send hash entry, it may have expired") + for _, entry := range entries { + if matchToList(entry.toList, toList) { + entry.msgID = msgID + entry.closeWaitChannel() + return + } + } } - entry.closeWaitChannel() + logrus.Warn("Cannot add message ID to send hash entry, it may have expired") } -func (h *sendRecorder) removeOnFail(hash string) { +func (h *sendRecorder) removeOnFail(hash string, toList []string) { h.entriesLock.Lock() defer h.entriesLock.Unlock() - entry, ok := h.entries[hash] - if !ok || entry.msgID != "" { + entries, ok := h.entries[hash] + if !ok { return } - entry.closeWaitChannel() + for idx, entry := range entries { + if entry.msgID == "" && matchToList(entry.toList, toList) { + entry.closeWaitChannel() - delete(h.entries, hash) + remaining := xslices.Remove(entries, idx, 1) + if len(remaining) != 0 { + h.entries[hash] = remaining + } else { + delete(h.entries, hash) + } + } + } } func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) { @@ -200,7 +225,7 @@ func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time defer h.entriesLock.Unlock() if entry, ok := h.entries[hash]; ok { - return entry.msgID, true, nil + return entry[0].msgID, true, nil } return "", false, nil @@ -211,7 +236,7 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) { defer h.entriesLock.Unlock() if entry, ok := h.entries[hash]; ok { - return entry.waitCh, true + return entry[0].waitCh, true } return nil, false diff --git a/internal/user/send_recorder_test.go b/internal/user/send_recorder_test.go index 0aabe90b..942df28e 100644 --- a/internal/user/send_recorder_test.go +++ b/internal/user/send_recorder_test.go @@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) { require.NotEmpty(t, hash1) // Simulate successfully sending the message. - h.addMessageID(hash1, "abc") + h.signalMessageSent(hash1, "abc", nil) // Inserting a message with the same hash should return false. _, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -59,7 +59,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) { require.NotEmpty(t, hash1) // Simulate successfully sending the message. - h.addMessageID(hash1, "abc") + h.signalMessageSent(hash1, "abc", nil) // Wait for the entry to expire. time.Sleep(time.Second) @@ -106,7 +106,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) { // Simulate successfully sending the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) }() // Inserting a message with the same hash should fail. @@ -127,7 +127,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) { // Simulate failing to send the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.removeOnFail(hash) + h.removeOnFail(hash, nil) }() // Inserting a message with the same hash should succeed because the first message failed to send. @@ -163,7 +163,7 @@ func TestSendHasher_HasEntry(t *testing.T) { require.NotEmpty(t, hash) // Simulate successfully sending the message. - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) // The message was already sent; we should find it in the hasher. messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) @@ -184,7 +184,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { // Simulate successfully sending the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) }() // The message was already sent; we should find it in the hasher. @@ -197,7 +197,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { // There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message // is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be - // inserted as a new entry. The two clients end up sending the message twice and calling the `addMessageID` x2, + // inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2, // resulting in a crash. h := newSendRecorder(sendEntryExpiry) @@ -209,8 +209,8 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { // Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections // to attempt to send the same message. - h.addMessageID(hash, "abc") - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) + h.signalMessageSent(hash, "abc", nil) // The message was already sent; we should find it in the hasher. messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) @@ -219,6 +219,22 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { require.Equal(t, "abc", messageID) } +func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) { + h := newSendRecorder(sendEntryExpiry) + + // Insert a message into the hasher. + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver ") + require.NoError(t, err) + require.True(t, ok) + require.NotEmpty(t, hash) + + hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver ", "Receiver2 ") + require.NoError(t, err) + require.True(t, ok) + require.NotEmpty(t, hash2) + require.Equal(t, hash, hash2) +} + func TestSendHasher_HasEntry_SendFail(t *testing.T) { h := newSendRecorder(sendEntryExpiry) @@ -231,7 +247,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) { // Simulate failing to send the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.removeOnFail(hash) + h.removeOnFail(hash, nil) }() // The message failed to send; we should not find it in the hasher. @@ -265,7 +281,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) { require.NotEmpty(t, hash) // Simulate successfully sending the message. - h.addMessageID(hash, "abc") + h.signalMessageSent(hash, "abc", nil) // Wait for the entry to expire. time.Sleep(time.Second) diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 7351c2ca..58148e72 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -89,7 +89,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // If we fail to send this message, we should remove the hash from the send recorder. - defer user.sendHash.removeOnFail(hash) + defer user.sendHash.removeOnFail(hash, to) // Create a new message parser from the reader. parser, err := parser.New(bytes.NewReader(b)) @@ -162,7 +162,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // If the message was successfully sent, we can update the message ID in the record. - user.sendHash.addMessageID(hash, sent.ID) + user.sendHash.signalMessageSent(hash, sent.ID, to) return nil })