diff --git a/internal/user/imap.go b/internal/user/imap.go index 33effdb8..f6257f4b 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -180,8 +180,14 @@ func (conn *imapConnector) CreateMessage( flags imap.FlagSet, date time.Time, ) (imap.Message, []byte, error) { + // Compute the hash of the message (to match it against SMTP messages). + hash, err := getMessageHash(literal) + if err != nil { + return imap.Message{}, nil, err + } + // Check if we already tried to send this message recently. - if messageID, ok, err := conn.sendHash.hasEntryWait(ctx, literal, time.Now().Add(90*time.Second)); err != nil { + if messageID, ok, err := conn.sendHash.hasEntryWait(ctx, hash, time.Now().Add(90*time.Second)); err != nil { return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err) } else if ok { message, err := conn.client.GetMessage(ctx, messageID) diff --git a/internal/user/send_recorder.go b/internal/user/send_recorder.go index e214acb6..378ea67e 100644 --- a/internal/user/send_recorder.go +++ b/internal/user/send_recorder.go @@ -34,7 +34,6 @@ import ( const sendEntryExpiry = 30 * time.Minute type sendRecorder struct { - hasher func([]byte) (string, error) expiry time.Duration entries map[string]*sendEntry @@ -43,7 +42,6 @@ type sendRecorder struct { func newSendRecorder(expiry time.Duration) *sendRecorder { return &sendRecorder{ - hasher: getMessageHash, expiry: expiry, entries: make(map[string]*sendEntry), } @@ -58,40 +56,30 @@ type sendEntry struct { // tryInsertWait tries to insert the given message into the send recorder. // If an entry already exists but it was not sent yet, it waits. // It returns whether an entry could be inserted and an error if it times out while waiting. -func (h *sendRecorder) tryInsertWait(ctx context.Context, b []byte, deadline time.Time) (string, bool, error) { - hash, err := h.hasher(b) - if err != nil { - return "", false, fmt.Errorf("failed to hash message: %w", err) - } - +func (h *sendRecorder) tryInsertWait(ctx context.Context, hash string, deadline time.Time) (bool, error) { // If we successfully inserted the hash, we can return true. if h.tryInsert(hash) { - return hash, true, nil + return true, nil } // A message with this hash is already being sent; wait for it. _, wasSent, err := h.wait(ctx, hash, deadline) if err != nil { - return "", false, fmt.Errorf("failed to wait for message to be sent: %w", err) + return false, fmt.Errorf("failed to wait for message to be sent: %w", err) } // If the message failed to send, try to insert it again. if !wasSent { - return h.tryInsertWait(ctx, b, deadline) + return h.tryInsertWait(ctx, hash, deadline) } - return hash, false, nil + return false, nil } // hasEntryWait returns whether the given message already exists in the send recorder. // If it does, it waits for its ID to be known, then returns it and true. // If no entry exists, or it times out while waiting for its ID to be known, it returns false. -func (h *sendRecorder) hasEntryWait(ctx context.Context, b []byte, deadline time.Time) (string, bool, error) { - hash, err := h.hasher(b) - if err != nil { - return "", false, fmt.Errorf("failed to hash message: %w", err) - } - +func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) { if !h.hasEntry(hash) { return "", false, nil } @@ -107,7 +95,7 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, b []byte, deadline time return messageID, true, nil } - return h.hasEntryWait(ctx, b, deadline) + return h.hasEntryWait(ctx, hash, deadline) } func (h *sendRecorder) tryInsert(hash string) bool { diff --git a/internal/user/send_recorder_test.go b/internal/user/send_recorder_test.go index b5f9dea1..ffe291c6 100644 --- a/internal/user/send_recorder_test.go +++ b/internal/user/send_recorder_test.go @@ -29,7 +29,7 @@ func TestSendHasher_Insert(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash1, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash1) @@ -38,12 +38,12 @@ func TestSendHasher_Insert(t *testing.T) { h.addMessageID(hash1, "abc") // Inserting a message with the same hash should return false. - _, ok, err = h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + _, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.False(t, ok) // Inserting a message with a different hash should return true. - hash2, ok, err := h.tryInsertWait(context.Background(), []byte(literal2), time.Now().Add(time.Second)) + hash2, ok, err := testTryInsert(h, literal2, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash2) @@ -53,7 +53,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) { h := newSendRecorder(time.Second) // Insert a message into the hasher. - hash1, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash1) @@ -65,7 +65,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) { time.Sleep(time.Second) // Inserting a message with the same hash should return true because the previous entry has since expired. - hash2, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) @@ -77,7 +77,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) @@ -89,7 +89,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) { }() // Inserting a message with the same hash should fail. - _, ok, err = h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + _, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.False(t, ok) } @@ -98,7 +98,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) @@ -110,7 +110,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) { }() // Inserting a message with the same hash should succeed because the first message failed to send. - hash2, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) @@ -122,13 +122,13 @@ func TestSendHasher_Wait_Timeout(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) // We should fail to insert because the message is not sent within the timeout period. - _, _, err = h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + _, _, err = testTryInsert(h, literal1, time.Now().Add(time.Second)) require.Error(t, err) } @@ -136,7 +136,7 @@ func TestSendHasher_HasEntry(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) @@ -145,7 +145,7 @@ func TestSendHasher_HasEntry(t *testing.T) { h.addMessageID(hash, "abc") // The message was already sent; we should find it in the hasher. - messageID, ok, err := h.hasEntryWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.Equal(t, "abc", messageID) @@ -155,7 +155,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) @@ -167,7 +167,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { }() // The message was already sent; we should find it in the hasher. - messageID, ok, err := h.hasEntryWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.Equal(t, "abc", messageID) @@ -177,7 +177,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) @@ -189,7 +189,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) { }() // The message failed to send; we should not find it in the hasher. - _, ok, err = h.hasEntryWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + _, ok, err = testHasEntry(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.False(t, ok) } @@ -198,13 +198,13 @@ func TestSendHasher_HasEntry_Timeout(t *testing.T) { h := newSendRecorder(sendEntryExpiry) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) // The message is never sent; we should not find it in the hasher. - _, ok, err = h.hasEntryWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + _, ok, err = testHasEntry(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.False(t, ok) } @@ -213,7 +213,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) { h := newSendRecorder(time.Second) // Insert a message into the hasher. - hash, ok, err := h.tryInsertWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.True(t, ok) require.NotEmpty(t, hash) @@ -225,7 +225,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) { time.Sleep(time.Second) // The entry has expired; we should not find it in the hasher. - _, ok, err = h.hasEntryWait(context.Background(), []byte(literal1), time.Now().Add(time.Second)) + _, ok, err = testHasEntry(h, literal1, time.Now().Add(time.Second)) require.NoError(t, err) require.False(t, ok) } @@ -348,3 +348,26 @@ func TestGetMessageHash(t *testing.T) { }) } } + +func testTryInsert(h *sendRecorder, literal string, deadline time.Time) (string, bool, error) { + hash, err := getMessageHash([]byte(literal)) + if err != nil { + return "", false, err + } + + ok, err := h.tryInsertWait(context.Background(), hash, deadline) + if err != nil { + return "", false, err + } + + return hash, ok, nil +} + +func testHasEntry(h *sendRecorder, literal string, deadline time.Time) (string, bool, error) { + hash, err := getMessageHash([]byte(literal)) + if err != nil { + return "", false, err + } + + return h.hasEntryWait(context.Background(), hash, deadline) +} diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 1d27837f..723dfdcf 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -51,14 +51,21 @@ func (user *User) sendMail(authID string, emails []string, from string, to []str return fmt.Errorf("failed to read message: %w", err) } - // Check if we already tried to send this message recently. - hash, ok, err := user.sendHash.tryInsertWait(ctx, b, time.Now().Add(90*time.Second)) + // Compute the hash of the message (to match it against SMTP messages). + hash, err := getMessageHash(b) if err != nil { + return err + } + + // Check if we already tried to send this message recently. + if ok, err := user.sendHash.tryInsertWait(ctx, hash, time.Now().Add(90*time.Second)); err != nil { return fmt.Errorf("failed to check send hash: %w", err) } else if !ok { user.log.Warn("A duplicate message was already sent recently, skipping") return nil } + + // If we fail to send this message, we should remove the hash from the send recorder. defer user.sendHash.removeOnFail(hash) // Create a new message parser from the reader.