mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-18 08:06:59 +00:00
fix(GODT-2627): Properly handle recording of message with Bcc fields
Ensure the SMTP send recorder properly handles the recording of messages which may have the same body hash but have different recipients. E.g.: send the same message twice to 2 different users via Bcc. The send recorder now maintains a list of send requests and waiting for a message to be sent is done one the oldest of the messages.
This commit is contained in:
@ -25,6 +25,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/rfc822"
|
"github.com/ProtonMail/gluon/rfc822"
|
||||||
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
@ -34,14 +35,14 @@ const sendEntryExpiry = 30 * time.Minute
|
|||||||
type sendRecorder struct {
|
type sendRecorder struct {
|
||||||
expiry time.Duration
|
expiry time.Duration
|
||||||
|
|
||||||
entries map[string]*sendEntry
|
entries map[string][]*sendEntry
|
||||||
entriesLock sync.Mutex
|
entriesLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSendRecorder(expiry time.Duration) *sendRecorder {
|
func newSendRecorder(expiry time.Duration) *sendRecorder {
|
||||||
return &sendRecorder{
|
return &sendRecorder{
|
||||||
expiry: expiry,
|
expiry: expiry,
|
||||||
entries: make(map[string]*sendEntry),
|
entries: make(map[string][]*sendEntry),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,25 +111,40 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline t
|
|||||||
return h.hasEntryWait(ctx, hash, deadline)
|
return h.hasEntryWait(ctx, hash, deadline)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *sendRecorder) removeExpiredUnsafe() {
|
||||||
|
for hash, entry := range h.entries {
|
||||||
|
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
|
||||||
|
return !t.exp.Before(time.Now())
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(remaining) == 0 {
|
||||||
|
delete(h.entries, hash)
|
||||||
|
} else {
|
||||||
|
h.entries[hash] = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) tryInsert(hash string, toList []string) bool {
|
func (h *sendRecorder) tryInsert(hash string, toList []string) bool {
|
||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
for hash, entry := range h.entries {
|
h.removeExpiredUnsafe()
|
||||||
if entry.exp.Before(time.Now()) {
|
|
||||||
delete(h.entries, hash)
|
entries, ok := h.entries[hash]
|
||||||
|
if ok {
|
||||||
|
for _, entry := range entries {
|
||||||
|
if matchToList(entry.toList, toList) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := h.entries[hash]; ok && matchToList(h.entries[hash].toList, toList) {
|
h.entries[hash] = append(entries, &sendEntry{
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
h.entries[hash] = &sendEntry{
|
|
||||||
exp: time.Now().Add(h.expiry),
|
exp: time.Now().Add(h.expiry),
|
||||||
toList: toList,
|
toList: toList,
|
||||||
waitCh: make(chan struct{}),
|
waitCh: make(chan struct{}),
|
||||||
}
|
})
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -137,11 +153,7 @@ func (h *sendRecorder) hasEntry(hash string) bool {
|
|||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
for hash, entry := range h.entries {
|
h.removeExpiredUnsafe()
|
||||||
if entry.exp.Before(time.Now()) {
|
|
||||||
delete(h.entries, hash)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := h.entries[hash]; ok {
|
if _, ok := h.entries[hash]; ok {
|
||||||
return true
|
return true
|
||||||
@ -150,33 +162,46 @@ func (h *sendRecorder) hasEntry(hash string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// addMessageID should be called after a message has been successfully sent.
|
// signalMessageSent should be called after a message has been successfully sent.
|
||||||
func (h *sendRecorder) addMessageID(hash, msgID string) {
|
func (h *sendRecorder) signalMessageSent(hash, msgID string, toList []string) {
|
||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
entry, ok := h.entries[hash]
|
entries, ok := h.entries[hash]
|
||||||
if ok {
|
if ok {
|
||||||
entry.msgID = msgID
|
for _, entry := range entries {
|
||||||
} else {
|
if matchToList(entry.toList, toList) {
|
||||||
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
entry.msgID = msgID
|
||||||
|
entry.closeWaitChannel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
entry.closeWaitChannel()
|
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) removeOnFail(hash string) {
|
func (h *sendRecorder) removeOnFail(hash string, toList []string) {
|
||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
entry, ok := h.entries[hash]
|
entries, ok := h.entries[hash]
|
||||||
if !ok || entry.msgID != "" {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entry.closeWaitChannel()
|
for idx, entry := range entries {
|
||||||
|
if entry.msgID == "" && matchToList(entry.toList, toList) {
|
||||||
|
entry.closeWaitChannel()
|
||||||
|
|
||||||
delete(h.entries, hash)
|
remaining := xslices.Remove(entries, idx, 1)
|
||||||
|
if len(remaining) != 0 {
|
||||||
|
h.entries[hash] = remaining
|
||||||
|
} else {
|
||||||
|
delete(h.entries, hash)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) {
|
func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) {
|
||||||
@ -200,7 +225,7 @@ func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time
|
|||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
if entry, ok := h.entries[hash]; ok {
|
if entry, ok := h.entries[hash]; ok {
|
||||||
return entry.msgID, true, nil
|
return entry[0].msgID, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", false, nil
|
return "", false, nil
|
||||||
@ -211,7 +236,7 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) {
|
|||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
if entry, ok := h.entries[hash]; ok {
|
if entry, ok := h.entries[hash]; ok {
|
||||||
return entry.waitCh, true
|
return entry[0].waitCh, true
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
|
|||||||
@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash1)
|
require.NotEmpty(t, hash1)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.addMessageID(hash1, "abc")
|
h.signalMessageSent(hash1, "abc", nil)
|
||||||
|
|
||||||
// Inserting a message with the same hash should return false.
|
// Inserting a message with the same hash should return false.
|
||||||
_, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second))
|
_, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -59,7 +59,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash1)
|
require.NotEmpty(t, hash1)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.addMessageID(hash1, "abc")
|
h.signalMessageSent(hash1, "abc", nil)
|
||||||
|
|
||||||
// Wait for the entry to expire.
|
// Wait for the entry to expire.
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -106,7 +106,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) {
|
|||||||
// Simulate successfully sending the message after half a second.
|
// Simulate successfully sending the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.addMessageID(hash, "abc")
|
h.signalMessageSent(hash, "abc", nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Inserting a message with the same hash should fail.
|
// Inserting a message with the same hash should fail.
|
||||||
@ -127,7 +127,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) {
|
|||||||
// Simulate failing to send the message after half a second.
|
// Simulate failing to send the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.removeOnFail(hash)
|
h.removeOnFail(hash, nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Inserting a message with the same hash should succeed because the first message failed to send.
|
// Inserting a message with the same hash should succeed because the first message failed to send.
|
||||||
@ -163,7 +163,7 @@ func TestSendHasher_HasEntry(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash)
|
require.NotEmpty(t, hash)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.addMessageID(hash, "abc")
|
h.signalMessageSent(hash, "abc", nil)
|
||||||
|
|
||||||
// The message was already sent; we should find it in the hasher.
|
// The message was already sent; we should find it in the hasher.
|
||||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -184,7 +184,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
|||||||
// Simulate successfully sending the message after half a second.
|
// Simulate successfully sending the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.addMessageID(hash, "abc")
|
h.signalMessageSent(hash, "abc", nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// The message was already sent; we should find it in the hasher.
|
// The message was already sent; we should find it in the hasher.
|
||||||
@ -197,7 +197,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
|||||||
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
||||||
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
|
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
|
||||||
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
|
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
|
||||||
// inserted as a new entry. The two clients end up sending the message twice and calling the `addMessageID` x2,
|
// inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2,
|
||||||
// resulting in a crash.
|
// resulting in a crash.
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := newSendRecorder(sendEntryExpiry)
|
||||||
|
|
||||||
@ -209,8 +209,8 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
|||||||
|
|
||||||
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
|
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
|
||||||
// to attempt to send the same message.
|
// to attempt to send the same message.
|
||||||
h.addMessageID(hash, "abc")
|
h.signalMessageSent(hash, "abc", nil)
|
||||||
h.addMessageID(hash, "abc")
|
h.signalMessageSent(hash, "abc", nil)
|
||||||
|
|
||||||
// The message was already sent; we should find it in the hasher.
|
// The message was already sent; we should find it in the hasher.
|
||||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -219,6 +219,22 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
|||||||
require.Equal(t, "abc", messageID)
|
require.Equal(t, "abc", messageID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) {
|
||||||
|
h := newSendRecorder(sendEntryExpiry)
|
||||||
|
|
||||||
|
// Insert a message into the hasher.
|
||||||
|
hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, hash)
|
||||||
|
|
||||||
|
hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, hash2)
|
||||||
|
require.Equal(t, hash, hash2)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := newSendRecorder(sendEntryExpiry)
|
||||||
|
|
||||||
@ -231,7 +247,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
|||||||
// Simulate failing to send the message after half a second.
|
// Simulate failing to send the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.removeOnFail(hash)
|
h.removeOnFail(hash, nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// The message failed to send; we should not find it in the hasher.
|
// The message failed to send; we should not find it in the hasher.
|
||||||
@ -265,7 +281,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash)
|
require.NotEmpty(t, hash)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.addMessageID(hash, "abc")
|
h.signalMessageSent(hash, "abc", nil)
|
||||||
|
|
||||||
// Wait for the entry to expire.
|
// Wait for the entry to expire.
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|||||||
@ -89,7 +89,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we fail to send this message, we should remove the hash from the send recorder.
|
// If we fail to send this message, we should remove the hash from the send recorder.
|
||||||
defer user.sendHash.removeOnFail(hash)
|
defer user.sendHash.removeOnFail(hash, to)
|
||||||
|
|
||||||
// Create a new message parser from the reader.
|
// Create a new message parser from the reader.
|
||||||
parser, err := parser.New(bytes.NewReader(b))
|
parser, err := parser.New(bytes.NewReader(b))
|
||||||
@ -162,7 +162,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the message was successfully sent, we can update the message ID in the record.
|
// If the message was successfully sent, we can update the message ID in the record.
|
||||||
user.sendHash.addMessageID(hash, sent.ID)
|
user.sendHash.signalMessageSent(hash, sent.ID, to)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user