mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-21 09:36:51 +00:00
fix(GODT-2627): Properly handle recording of message with Bcc fields
Ensure the SMTP send recorder properly handles the recording of messages which may have the same body hash but have different recipients. E.g.: send the same message twice to 2 different users via Bcc. The send recorder now maintains a list of send requests and waiting for a message to be sent is done one the oldest of the messages.
This commit is contained in:
@ -25,6 +25,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@ -34,14 +35,14 @@ const sendEntryExpiry = 30 * time.Minute
|
||||
type sendRecorder struct {
|
||||
expiry time.Duration
|
||||
|
||||
entries map[string]*sendEntry
|
||||
entries map[string][]*sendEntry
|
||||
entriesLock sync.Mutex
|
||||
}
|
||||
|
||||
func newSendRecorder(expiry time.Duration) *sendRecorder {
|
||||
return &sendRecorder{
|
||||
expiry: expiry,
|
||||
entries: make(map[string]*sendEntry),
|
||||
entries: make(map[string][]*sendEntry),
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,25 +111,40 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline t
|
||||
return h.hasEntryWait(ctx, hash, deadline)
|
||||
}
|
||||
|
||||
func (h *sendRecorder) removeExpiredUnsafe() {
|
||||
for hash, entry := range h.entries {
|
||||
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
|
||||
return !t.exp.Before(time.Now())
|
||||
})
|
||||
|
||||
if len(remaining) == 0 {
|
||||
delete(h.entries, hash)
|
||||
} else {
|
||||
h.entries[hash] = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sendRecorder) tryInsert(hash string, toList []string) bool {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
for hash, entry := range h.entries {
|
||||
if entry.exp.Before(time.Now()) {
|
||||
delete(h.entries, hash)
|
||||
h.removeExpiredUnsafe()
|
||||
|
||||
entries, ok := h.entries[hash]
|
||||
if ok {
|
||||
for _, entry := range entries {
|
||||
if matchToList(entry.toList, toList) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := h.entries[hash]; ok && matchToList(h.entries[hash].toList, toList) {
|
||||
return false
|
||||
}
|
||||
|
||||
h.entries[hash] = &sendEntry{
|
||||
h.entries[hash] = append(entries, &sendEntry{
|
||||
exp: time.Now().Add(h.expiry),
|
||||
toList: toList,
|
||||
waitCh: make(chan struct{}),
|
||||
}
|
||||
})
|
||||
|
||||
return true
|
||||
}
|
||||
@ -137,11 +153,7 @@ func (h *sendRecorder) hasEntry(hash string) bool {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
for hash, entry := range h.entries {
|
||||
if entry.exp.Before(time.Now()) {
|
||||
delete(h.entries, hash)
|
||||
}
|
||||
}
|
||||
h.removeExpiredUnsafe()
|
||||
|
||||
if _, ok := h.entries[hash]; ok {
|
||||
return true
|
||||
@ -150,33 +162,46 @@ func (h *sendRecorder) hasEntry(hash string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// addMessageID should be called after a message has been successfully sent.
|
||||
func (h *sendRecorder) addMessageID(hash, msgID string) {
|
||||
// signalMessageSent should be called after a message has been successfully sent.
|
||||
func (h *sendRecorder) signalMessageSent(hash, msgID string, toList []string) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
entry, ok := h.entries[hash]
|
||||
entries, ok := h.entries[hash]
|
||||
if ok {
|
||||
entry.msgID = msgID
|
||||
} else {
|
||||
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
||||
for _, entry := range entries {
|
||||
if matchToList(entry.toList, toList) {
|
||||
entry.msgID = msgID
|
||||
entry.closeWaitChannel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entry.closeWaitChannel()
|
||||
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
||||
}
|
||||
|
||||
func (h *sendRecorder) removeOnFail(hash string) {
|
||||
func (h *sendRecorder) removeOnFail(hash string, toList []string) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
entry, ok := h.entries[hash]
|
||||
if !ok || entry.msgID != "" {
|
||||
entries, ok := h.entries[hash]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
entry.closeWaitChannel()
|
||||
for idx, entry := range entries {
|
||||
if entry.msgID == "" && matchToList(entry.toList, toList) {
|
||||
entry.closeWaitChannel()
|
||||
|
||||
delete(h.entries, hash)
|
||||
remaining := xslices.Remove(entries, idx, 1)
|
||||
if len(remaining) != 0 {
|
||||
h.entries[hash] = remaining
|
||||
} else {
|
||||
delete(h.entries, hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) {
|
||||
@ -200,7 +225,7 @@ func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
if entry, ok := h.entries[hash]; ok {
|
||||
return entry.msgID, true, nil
|
||||
return entry[0].msgID, true, nil
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
@ -211,7 +236,7 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) {
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
if entry, ok := h.entries[hash]; ok {
|
||||
return entry.waitCh, true
|
||||
return entry[0].waitCh, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
|
||||
Reference in New Issue
Block a user