forked from Silverfish/proton-bridge
feat(GODT-2799): SMTP Service
Refactor code to isolate the SMTP functionality in a dedicated SMTP service for each user as discussed in the Bridge Service Architecture RFC. Some shared types have been moved from `user` to `usertypes` so that they can be shared with Service and User Code. Finally due to lack of recursive imports, the user data SMTP needs access to is hidden behind an interface until the User Identity service is implemented.
This commit is contained in:
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package sendrecorder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -30,11 +30,11 @@ import (
|
|||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sendEntryExpiry = 30 * time.Minute
|
const SendEntryExpiry = 30 * time.Minute
|
||||||
|
|
||||||
type SendRecorderID uint64
|
type ID uint64
|
||||||
|
|
||||||
type sendRecorder struct {
|
type SendRecorder struct {
|
||||||
expiry time.Duration
|
expiry time.Duration
|
||||||
|
|
||||||
entries map[string][]*sendEntry
|
entries map[string][]*sendEntry
|
||||||
@ -42,15 +42,15 @@ type sendRecorder struct {
|
|||||||
cancelIDCounter uint64
|
cancelIDCounter uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSendRecorder(expiry time.Duration) *sendRecorder {
|
func NewSendRecorder(expiry time.Duration) *SendRecorder {
|
||||||
return &sendRecorder{
|
return &SendRecorder{
|
||||||
expiry: expiry,
|
expiry: expiry,
|
||||||
entries: make(map[string][]*sendEntry),
|
entries: make(map[string][]*sendEntry),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type sendEntry struct {
|
type sendEntry struct {
|
||||||
srID SendRecorderID
|
srID ID
|
||||||
msgID string
|
msgID string
|
||||||
toList []string
|
toList []string
|
||||||
exp time.Time
|
exp time.Time
|
||||||
@ -65,17 +65,17 @@ func (s *sendEntry) closeWaitChannel() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tryInsertWait tries to insert the given message into the send recorder.
|
// TryInsertWait tries to insert the given message into the send recorder.
|
||||||
// If an entry already exists but it was not sent yet, it waits.
|
// If an entry already exists but it was not sent yet, it waits.
|
||||||
// It returns whether an entry could be inserted and an error if it times out while waiting.
|
// It returns whether an entry could be inserted and an error if it times out while waiting.
|
||||||
func (h *sendRecorder) tryInsertWait(
|
func (h *SendRecorder) TryInsertWait(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
hash string,
|
hash string,
|
||||||
toList []string,
|
toList []string,
|
||||||
deadline time.Time,
|
deadline time.Time,
|
||||||
) (SendRecorderID, bool, error) {
|
) (ID, bool, error) {
|
||||||
// If we successfully inserted the hash, we can return true.
|
// If we successfully inserted the hash, we can return true.
|
||||||
srID, waitCh, ok := h.tryInsert(hash, toList)
|
srID, waitCh, ok := h.TryInsert(hash, toList)
|
||||||
if ok {
|
if ok {
|
||||||
return srID, true, nil
|
return srID, true, nil
|
||||||
}
|
}
|
||||||
@ -88,16 +88,16 @@ func (h *sendRecorder) tryInsertWait(
|
|||||||
|
|
||||||
// If the message failed to send, try to insert it again.
|
// If the message failed to send, try to insert it again.
|
||||||
if !wasSent {
|
if !wasSent {
|
||||||
return h.tryInsertWait(ctx, hash, toList, deadline)
|
return h.TryInsertWait(ctx, hash, toList, deadline)
|
||||||
}
|
}
|
||||||
|
|
||||||
return srID, false, nil
|
return srID, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasEntryWait returns whether the given message already exists in the send recorder.
|
// HasEntryWait returns whether the given message already exists in the send recorder.
|
||||||
// If it does, it waits for its ID to be known, then returns it and true.
|
// If it does, it waits for its ID to be known, then returns it and true.
|
||||||
// If no entry exists, or it times out while waiting for its ID to be known, it returns false.
|
// If no entry exists, or it times out while waiting for its ID to be known, it returns false.
|
||||||
func (h *sendRecorder) hasEntryWait(ctx context.Context,
|
func (h *SendRecorder) HasEntryWait(ctx context.Context,
|
||||||
hash string,
|
hash string,
|
||||||
deadline time.Time,
|
deadline time.Time,
|
||||||
toList []string,
|
toList []string,
|
||||||
@ -118,10 +118,10 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context,
|
|||||||
return messageID, true, nil
|
return messageID, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.hasEntryWait(ctx, hash, deadline, toList)
|
return h.HasEntryWait(ctx, hash, deadline, toList)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) removeExpiredUnsafe() {
|
func (h *SendRecorder) removeExpiredUnsafe() {
|
||||||
for hash, entry := range h.entries {
|
for hash, entry := range h.entries {
|
||||||
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
|
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
|
||||||
return !t.exp.Before(time.Now())
|
return !t.exp.Before(time.Now())
|
||||||
@ -135,7 +135,7 @@ func (h *sendRecorder) removeExpiredUnsafe() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) tryInsert(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) {
|
func (h *SendRecorder) TryInsert(hash string, toList []string) (ID, <-chan struct{}, bool) {
|
||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ func (h *sendRecorder) tryInsert(hash string, toList []string) (SendRecorderID,
|
|||||||
return cancelID, waitCh, true
|
return cancelID, waitCh, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) getEntryWaitInfo(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) {
|
func (h *SendRecorder) getEntryWaitInfo(hash string, toList []string) (ID, <-chan struct{}, bool) {
|
||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
@ -180,8 +180,8 @@ func (h *sendRecorder) getEntryWaitInfo(hash string, toList []string) (SendRecor
|
|||||||
return 0, nil, false
|
return 0, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// signalMessageSent should be called after a message has been successfully sent.
|
// SignalMessageSent should be called after a message has been successfully sent.
|
||||||
func (h *sendRecorder) signalMessageSent(hash string, srID SendRecorderID, msgID string) {
|
func (h *SendRecorder) SignalMessageSent(hash string, srID ID, msgID string) {
|
||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ func (h *sendRecorder) signalMessageSent(hash string, srID SendRecorderID, msgID
|
|||||||
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) removeOnFail(hash string, id SendRecorderID) {
|
func (h *SendRecorder) RemoveOnFail(hash string, id ID) {
|
||||||
h.entriesLock.Lock()
|
h.entriesLock.Lock()
|
||||||
defer h.entriesLock.Unlock()
|
defer h.entriesLock.Unlock()
|
||||||
|
|
||||||
@ -222,11 +222,11 @@ func (h *sendRecorder) removeOnFail(hash string, id SendRecorderID) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) wait(
|
func (h *SendRecorder) wait(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
hash string,
|
hash string,
|
||||||
waitCh <-chan struct{},
|
waitCh <-chan struct{},
|
||||||
srID SendRecorderID,
|
srID ID,
|
||||||
deadline time.Time,
|
deadline time.Time,
|
||||||
) (string, bool, error) {
|
) (string, bool, error) {
|
||||||
ctx, cancel := context.WithDeadline(ctx, deadline)
|
ctx, cancel := context.WithDeadline(ctx, deadline)
|
||||||
@ -254,19 +254,19 @@ func (h *sendRecorder) wait(
|
|||||||
return "", false, nil
|
return "", false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sendRecorder) newSendRecorderID() SendRecorderID {
|
func (h *SendRecorder) newSendRecorderID() ID {
|
||||||
h.cancelIDCounter++
|
h.cancelIDCounter++
|
||||||
return SendRecorderID(h.cancelIDCounter)
|
return ID(h.cancelIDCounter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMessageHash returns the hash of the given message.
|
// GetMessageHash returns the hash of the given message.
|
||||||
// This takes into account:
|
// This takes into account:
|
||||||
// - the Subject header,
|
// - the Subject header,
|
||||||
// - the From/To/Cc headers,
|
// - the From/To/Cc headers,
|
||||||
// - the Content-Type header of each (leaf) part,
|
// - the Content-Type header of each (leaf) part,
|
||||||
// - the Content-Disposition header of each (leaf) part,
|
// - the Content-Disposition header of each (leaf) part,
|
||||||
// - the (decoded) body of each part.
|
// - the (decoded) body of each part.
|
||||||
func getMessageHash(b []byte) (string, error) {
|
func GetMessageHash(b []byte) (string, error) {
|
||||||
return rfc822.GetMessageHash(b)
|
return rfc822.GetMessageHash(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package sendrecorder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -26,7 +26,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSendHasher_Insert(t *testing.T) {
|
func TestSendHasher_Insert(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srdID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srdID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash1)
|
require.NotEmpty(t, hash1)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.signalMessageSent(hash1, srdID1, "abc")
|
h.SignalMessageSent(hash1, srdID1, "abc")
|
||||||
|
|
||||||
// Inserting a message with the same hash should return false.
|
// Inserting a message with the same hash should return false.
|
||||||
srdID2, _, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srdID2, _, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -52,7 +52,7 @@ func TestSendHasher_Insert(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_Insert_Expired(t *testing.T) {
|
func TestSendHasher_Insert_Expired(t *testing.T) {
|
||||||
h := newSendRecorder(time.Second)
|
h := NewSendRecorder(time.Second)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -61,7 +61,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash1)
|
require.NotEmpty(t, hash1)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.signalMessageSent(hash1, srID1, "abc")
|
h.SignalMessageSent(hash1, srID1, "abc")
|
||||||
|
|
||||||
// Wait for the entry to expire.
|
// Wait for the entry to expire.
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -79,7 +79,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_Insert_DifferentToList(t *testing.T) {
|
func TestSendHasher_Insert_DifferentToList(t *testing.T) {
|
||||||
h := newSendRecorder(time.Second)
|
h := NewSendRecorder(time.Second)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def"}...)
|
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def"}...)
|
||||||
@ -101,7 +101,7 @@ func TestSendHasher_Insert_DifferentToList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_Wait_SendSuccess(t *testing.T) {
|
func TestSendHasher_Wait_SendSuccess(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -112,7 +112,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) {
|
|||||||
// Simulate successfully sending the message after half a second.
|
// Simulate successfully sending the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.signalMessageSent(hash, srID1, "abc")
|
h.SignalMessageSent(hash, srID1, "abc")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Inserting a message with the same hash should fail.
|
// Inserting a message with the same hash should fail.
|
||||||
@ -123,7 +123,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_Wait_SendFail(t *testing.T) {
|
func TestSendHasher_Wait_SendFail(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -134,7 +134,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) {
|
|||||||
// Simulate failing to send the message after half a second.
|
// Simulate failing to send the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.removeOnFail(hash, srID1)
|
h.RemoveOnFail(hash, srID1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Inserting a message with the same hash should succeed because the first message failed to send.
|
// Inserting a message with the same hash should succeed because the first message failed to send.
|
||||||
@ -148,7 +148,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_Wait_Timeout(t *testing.T) {
|
func TestSendHasher_Wait_Timeout(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -162,7 +162,7 @@ func TestSendHasher_Wait_Timeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_HasEntry(t *testing.T) {
|
func TestSendHasher_HasEntry(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -171,7 +171,7 @@ func TestSendHasher_HasEntry(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash)
|
require.NotEmpty(t, hash)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.signalMessageSent(hash, srID1, "abc")
|
h.SignalMessageSent(hash, srID1, "abc")
|
||||||
|
|
||||||
// The message was already sent; we should find it in the hasher.
|
// The message was already sent; we should find it in the hasher.
|
||||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -181,7 +181,7 @@ func TestSendHasher_HasEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -192,7 +192,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
|||||||
// Simulate successfully sending the message after half a second.
|
// Simulate successfully sending the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.signalMessageSent(hash, srID1, "abc")
|
h.SignalMessageSent(hash, srID1, "abc")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// The message was already sent; we should find it in the hasher.
|
// The message was already sent; we should find it in the hasher.
|
||||||
@ -205,9 +205,9 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
|||||||
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
||||||
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
|
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
|
||||||
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
|
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
|
||||||
// inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2,
|
// inserted as a new entry. The two clients end up sending the message twice and calling the `SignalMessageSent` x2,
|
||||||
// resulting in a crash.
|
// resulting in a crash.
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -217,8 +217,8 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
|||||||
|
|
||||||
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
|
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
|
||||||
// to attempt to send the same message.
|
// to attempt to send the same message.
|
||||||
h.signalMessageSent(hash, srID1, "abc")
|
h.SignalMessageSent(hash, srID1, "abc")
|
||||||
h.signalMessageSent(hash, srID1, "abc")
|
h.SignalMessageSent(hash, srID1, "abc")
|
||||||
|
|
||||||
// The message was already sent; we should find it in the hasher.
|
// The message was already sent; we should find it in the hasher.
|
||||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -228,7 +228,7 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) {
|
func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>")
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>")
|
||||||
@ -249,7 +249,7 @@ func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testin
|
|||||||
func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSend(t *testing.T) {
|
func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSend(t *testing.T) {
|
||||||
// Check that if we send the same message twice with different recipients and the second message is somehow
|
// Check that if we send the same message twice with different recipients and the second message is somehow
|
||||||
// sent before the first, ensure that we check if the message was sent we wait on the correct object.
|
// sent before the first, ensure that we check if the message was sent we wait on the correct object.
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Minute), "Receiver <receiver@pm.me>")
|
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Minute), "Receiver <receiver@pm.me>")
|
||||||
@ -264,16 +264,16 @@ func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSen
|
|||||||
require.Equal(t, hash, hash2)
|
require.Equal(t, hash, hash2)
|
||||||
|
|
||||||
// simulate message sent
|
// simulate message sent
|
||||||
h.signalMessageSent(hash2, srID2, "newID")
|
h.SignalMessageSent(hash2, srID2, "newID")
|
||||||
|
|
||||||
// Simulate Wait on message 2
|
// Simulate Wait on message 2
|
||||||
_, ok, err = h.hasEntryWait(context.Background(), hash2, time.Now().Add(time.Second), []string{"Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>"})
|
_, ok, err = h.HasEntryWait(context.Background(), hash2, time.Now().Add(time.Second), []string{"Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -284,7 +284,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
|||||||
// Simulate failing to send the message after half a second.
|
// Simulate failing to send the message after half a second.
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Millisecond * 500)
|
time.Sleep(time.Millisecond * 500)
|
||||||
h.removeOnFail(hash, srID1)
|
h.RemoveOnFail(hash, srID1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// The message failed to send; we should not find it in the hasher.
|
// The message failed to send; we should not find it in the hasher.
|
||||||
@ -294,7 +294,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_HasEntry_Timeout(t *testing.T) {
|
func TestSendHasher_HasEntry_Timeout(t *testing.T) {
|
||||||
h := newSendRecorder(sendEntryExpiry)
|
h := NewSendRecorder(SendEntryExpiry)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -309,7 +309,7 @@ func TestSendHasher_HasEntry_Timeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendHasher_HasEntry_Expired(t *testing.T) {
|
func TestSendHasher_HasEntry_Expired(t *testing.T) {
|
||||||
h := newSendRecorder(time.Second)
|
h := NewSendRecorder(time.Second)
|
||||||
|
|
||||||
// Insert a message into the hasher.
|
// Insert a message into the hasher.
|
||||||
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||||
@ -318,7 +318,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash)
|
require.NotEmpty(t, hash)
|
||||||
|
|
||||||
// Simulate successfully sending the message.
|
// Simulate successfully sending the message.
|
||||||
h.signalMessageSent(hash, srID1, "abc")
|
h.SignalMessageSent(hash, srID1, "abc")
|
||||||
|
|
||||||
// Wait for the entry to expire.
|
// Wait for the entry to expire.
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -432,10 +432,10 @@ func TestGetMessageHash(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
hash1, err := getMessageHash(tt.lit1)
|
hash1, err := GetMessageHash(tt.lit1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
hash2, err := getMessageHash(tt.lit2)
|
hash2, err := GetMessageHash(tt.lit2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if tt.wantEqual {
|
if tt.wantEqual {
|
||||||
@ -447,13 +447,13 @@ func TestGetMessageHash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList ...string) (SendRecorderID, string, bool, error) { //nolint:unparam
|
func testTryInsert(h *SendRecorder, literal string, deadline time.Time, toList ...string) (ID, string, bool, error) { //nolint:unparam
|
||||||
hash, err := getMessageHash([]byte(literal))
|
hash, err := GetMessageHash([]byte(literal))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", false, err
|
return 0, "", false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
srID, ok, err := h.tryInsertWait(context.Background(), hash, toList, deadline)
|
srID, ok, err := h.TryInsertWait(context.Background(), hash, toList, deadline)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", false, err
|
return 0, "", false, err
|
||||||
}
|
}
|
||||||
@ -461,11 +461,11 @@ func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList .
|
|||||||
return srID, hash, ok, nil
|
return srID, hash, ok, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func testHasEntry(h *sendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam
|
func testHasEntry(h *SendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam
|
||||||
hash, err := getMessageHash([]byte(literal))
|
hash, err := GetMessageHash([]byte(literal))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", false, err
|
return "", false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.hasEntryWait(context.Background(), hash, deadline, toList)
|
return h.HasEntryWait(context.Background(), hash, deadline, toList)
|
||||||
}
|
}
|
||||||
23
internal/services/smtp/errors.go
Normal file
23
internal/services/smtp/errors.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
// Copyright (c) 2023 Proton AG
|
||||||
|
//
|
||||||
|
// This file is part of Proton Mail Bridge.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package smtp
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrInvalidRecipient = errors.New("invalid recipient")
|
||||||
|
var ErrInvalidReturnPath = errors.New("invalid return path")
|
||||||
132
internal/services/smtp/service.go
Normal file
132
internal/services/smtp/service.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
// Copyright (c) 2023 Proton AG
|
||||||
|
//
|
||||||
|
// This file is part of Proton Mail Bridge.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package smtp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/gluon/async"
|
||||||
|
"github.com/ProtonMail/gluon/logging"
|
||||||
|
"github.com/ProtonMail/gluon/reporter"
|
||||||
|
"github.com/ProtonMail/go-proton-api"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserInterface is just wrapper to avoid recursive go module imports. To be removed when the identity service is ready.
|
||||||
|
type UserInterface interface {
|
||||||
|
ID() string
|
||||||
|
WithSMTPData(context.Context, func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
panicHandler async.PanicHandler
|
||||||
|
cpc *cpc.CPC
|
||||||
|
user UserInterface
|
||||||
|
client *proton.Client
|
||||||
|
recorder *sendrecorder.SendRecorder
|
||||||
|
log *logrus.Entry
|
||||||
|
reporter reporter.Reporter
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(
|
||||||
|
user UserInterface,
|
||||||
|
client *proton.Client,
|
||||||
|
recorder *sendrecorder.SendRecorder,
|
||||||
|
handler async.PanicHandler,
|
||||||
|
reporter reporter.Reporter,
|
||||||
|
) *Service {
|
||||||
|
return &Service{
|
||||||
|
panicHandler: handler,
|
||||||
|
user: user,
|
||||||
|
cpc: cpc.NewCPC(),
|
||||||
|
recorder: recorder,
|
||||||
|
log: logrus.WithFields(logrus.Fields{
|
||||||
|
"user": user.ID(),
|
||||||
|
"service": "smtp",
|
||||||
|
}),
|
||||||
|
reporter: reporter,
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) SendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error {
|
||||||
|
_, err := s.cpc.Send(ctx, &sendMailReq{
|
||||||
|
authID: authID,
|
||||||
|
from: from,
|
||||||
|
to: to,
|
||||||
|
r: r,
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Start(group *async.Group) {
|
||||||
|
s.log.Debug("Starting service")
|
||||||
|
group.Once(func(ctx context.Context) {
|
||||||
|
logging.DoAnnotated(ctx, func(ctx context.Context) {
|
||||||
|
s.run(ctx)
|
||||||
|
}, logging.Labels{
|
||||||
|
"user": s.user.ID(),
|
||||||
|
"service": "smtp",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) run(ctx context.Context) {
|
||||||
|
s.log.Debug("Starting service main loop")
|
||||||
|
defer s.log.Debug("Exiting service main loop")
|
||||||
|
defer s.cpc.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
|
||||||
|
case request, ok := <-s.cpc.ReceiveCh():
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r := request.Value().(type) {
|
||||||
|
case *sendMailReq:
|
||||||
|
s.log.Debug("Received send mail request")
|
||||||
|
err := s.sendMail(ctx, r)
|
||||||
|
request.Reply(ctx, nil, err)
|
||||||
|
|
||||||
|
default:
|
||||||
|
s.log.Error("Received unknown request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sendMailReq struct {
|
||||||
|
authID string
|
||||||
|
from string
|
||||||
|
to []string
|
||||||
|
r io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) sendMail(ctx context.Context, req *sendMailReq) error {
|
||||||
|
defer async.HandlePanic(s.panicHandler)
|
||||||
|
return s.smtpSendMail(ctx, req.authID, req.from, req.to, req.r)
|
||||||
|
}
|
||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package smtp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -36,7 +36,8 @@ import (
|
|||||||
"github.com/ProtonMail/go-proton-api"
|
"github.com/ProtonMail/go-proton-api"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
|
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
|
||||||
@ -47,19 +48,14 @@ import (
|
|||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
// sendMail sends an email from the given address to the given recipients.
|
// smtpSendMail sends an email from the given address to the given recipients.
|
||||||
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
|
func (s *Service) smtpSendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error {
|
||||||
defer async.HandlePanic(user.panicHandler)
|
return s.user.WithSMTPData(ctx, func(ctx context.Context, apiAddrs map[string]proton.Address, user proton.User, vault *vault.User) error {
|
||||||
|
if _, err := usertypes.GetAddrID(apiAddrs, from); err != nil {
|
||||||
return safe.RLockRet(func() error {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if _, err := getAddrID(user.apiAddrs, from); err != nil {
|
|
||||||
return ErrInvalidReturnPath
|
return ErrInvalidReturnPath
|
||||||
}
|
}
|
||||||
|
|
||||||
emails := xslices.Map(maps.Values(user.apiAddrs), func(addr proton.Address) string {
|
emails := xslices.Map(maps.Values(apiAddrs), func(addr proton.Address) string {
|
||||||
return addr.Email
|
return addr.Email
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -71,26 +67,26 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
|||||||
|
|
||||||
// If running a QA build, dump to disk.
|
// If running a QA build, dump to disk.
|
||||||
if err := debugDumpToDisk(b); err != nil {
|
if err := debugDumpToDisk(b); err != nil {
|
||||||
user.log.WithError(err).Warn("Failed to dump message to disk")
|
s.log.WithError(err).Warn("Failed to dump message to disk")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the hash of the message (to match it against SMTP messages).
|
// Compute the hash of the message (to match it against SMTP messages).
|
||||||
hash, err := getMessageHash(b)
|
hash, err := sendrecorder.GetMessageHash(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we already tried to send this message recently.
|
// Check if we already tried to send this message recently.
|
||||||
srID, ok, err := user.sendHash.tryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second))
|
srID, ok, err := s.recorder.TryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check send hash: %w", err)
|
return fmt.Errorf("failed to check send hash: %w", err)
|
||||||
} else if !ok {
|
} else if !ok {
|
||||||
user.log.Warn("A duplicate message was already sent recently, skipping")
|
s.log.Warn("A duplicate message was already sent recently, skipping")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we fail to send this message, we should remove the hash from the send recorder.
|
// If we fail to send this message, we should remove the hash from the send recorder.
|
||||||
defer user.sendHash.removeOnFail(hash, srID)
|
defer s.recorder.RemoveOnFail(hash, srID)
|
||||||
|
|
||||||
// Create a new message parser from the reader.
|
// Create a new message parser from the reader.
|
||||||
parser, err := parser.New(bytes.NewReader(b))
|
parser, err := parser.New(bytes.NewReader(b))
|
||||||
@ -104,17 +100,17 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load the user's mail settings.
|
// Load the user's mail settings.
|
||||||
settings, err := user.client.GetMailSettings(ctx)
|
settings, err := s.client.GetMailSettings(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get mail settings: %w", err)
|
return fmt.Errorf("failed to get mail settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
addrID, err := getAddrID(user.apiAddrs, from)
|
addrID, err := usertypes.GetAddrID(apiAddrs, from)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
|
return usertypes.WithAddrKR(user, apiAddrs[addrID], vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
|
||||||
// Use the first key for encrypting the message.
|
// Use the first key for encrypting the message.
|
||||||
addrKR, err := addrKR.FirstKey()
|
addrKR, err := addrKR.FirstKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -147,12 +143,10 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send the message using the correct key.
|
// Send the message using the correct key.
|
||||||
sent, err := user.sendWithKey(
|
sent, err := s.sendWithKey(
|
||||||
ctx,
|
ctx,
|
||||||
user.client,
|
|
||||||
user.reporter,
|
|
||||||
authID,
|
authID,
|
||||||
user.vault.AddressMode(),
|
vault.AddressMode(),
|
||||||
settings,
|
settings,
|
||||||
userKR, addrKR,
|
userKR, addrKR,
|
||||||
emails, from, to,
|
emails, from, to,
|
||||||
@ -163,18 +157,16 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the message was successfully sent, we can update the message ID in the record.
|
// If the message was successfully sent, we can update the message ID in the record.
|
||||||
user.sendHash.signalMessageSent(hash, srID, sent.ID)
|
s.recorder.SignalMessageSent(hash, srID, sent.ID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}, user.apiUserLock, user.apiAddrsLock, user.eventLock)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendWithKey sends the message with the given address key.
|
// sendWithKey sends the message with the given address key.
|
||||||
func (user *User) sendWithKey(
|
func (s *Service) sendWithKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *proton.Client,
|
|
||||||
sentry reporter.Reporter,
|
|
||||||
authAddrID string,
|
authAddrID string,
|
||||||
addrMode vault.AddressMode,
|
addrMode vault.AddressMode,
|
||||||
settings proton.MailSettings,
|
settings proton.MailSettings,
|
||||||
@ -188,16 +180,16 @@ func (user *User) sendWithKey(
|
|||||||
if message.InReplyTo != "" {
|
if message.InReplyTo != "" {
|
||||||
references = append(references, message.InReplyTo)
|
references = append(references, message.InReplyTo)
|
||||||
}
|
}
|
||||||
parentID, err := getParentID(ctx, client, authAddrID, addrMode, references)
|
parentID, err := getParentID(ctx, s.client, authAddrID, addrMode, references)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := sentry.ReportMessageWithContext("Failed to get parent ID", reporter.Context{
|
if err := s.reporter.ReportMessageWithContext("Failed to get parent ID", reporter.Context{
|
||||||
"error": err,
|
"error": err,
|
||||||
"references": message.References,
|
"references": message.References,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to report error")
|
logrus.WithError(err).Error("Failed to report error")
|
||||||
}
|
}
|
||||||
|
|
||||||
logrus.WithError(err).Warn("Failed to get parent ID")
|
s.log.WithError(err).Warn("Failed to get parent ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
var decBody string
|
var decBody string
|
||||||
@ -214,7 +206,7 @@ func (user *User) sendWithKey(
|
|||||||
return proton.Message{}, fmt.Errorf("unsupported MIME type: %v", message.MIMEType)
|
return proton.Message{}, fmt.Errorf("unsupported MIME type: %v", message.MIMEType)
|
||||||
}
|
}
|
||||||
|
|
||||||
draft, err := createDraft(ctx, client, addrKR, emails, from, to, parentID, message.InReplyTo, proton.DraftTemplate{
|
draft, err := s.createDraft(ctx, addrKR, emails, from, to, parentID, message.InReplyTo, proton.DraftTemplate{
|
||||||
Subject: message.Subject,
|
Subject: message.Subject,
|
||||||
Body: decBody,
|
Body: decBody,
|
||||||
MIMEType: message.MIMEType,
|
MIMEType: message.MIMEType,
|
||||||
@ -230,12 +222,12 @@ func (user *User) sendWithKey(
|
|||||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
|
attKeys, err := s.createAttachments(ctx, s.client, addrKR, draft.ID, message.Attachments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
recipients, err := user.getRecipients(ctx, client, userKR, settings, draft)
|
recipients, err := s.getRecipients(ctx, s.client, userKR, settings, draft)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
||||||
}
|
}
|
||||||
@ -245,7 +237,7 @@ func (user *User) sendWithKey(
|
|||||||
return proton.Message{}, fmt.Errorf("failed to create packages: %w", err)
|
return proton.Message{}, fmt.Errorf("failed to create packages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := client.SendDraft(ctx, draft.ID, req)
|
res, err := s.client.SendDraft(ctx, draft.ID, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return proton.Message{}, fmt.Errorf("failed to send draft: %w", err)
|
return proton.Message{}, fmt.Errorf("failed to send draft: %w", err)
|
||||||
}
|
}
|
||||||
@ -340,9 +332,8 @@ func getParentID(
|
|||||||
return parentID, nil
|
return parentID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDraft(
|
func (s *Service) createDraft(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *proton.Client,
|
|
||||||
addrKR *crypto.KeyRing,
|
addrKR *crypto.KeyRing,
|
||||||
emails []string,
|
emails []string,
|
||||||
from string,
|
from string,
|
||||||
@ -360,7 +351,7 @@ func createDraft(
|
|||||||
|
|
||||||
// Check that the sending address is owned by the user, and if so, sanitize it.
|
// Check that the sending address is owned by the user, and if so, sanitize it.
|
||||||
if idx := xslices.IndexFunc(emails, func(email string) bool {
|
if idx := xslices.IndexFunc(emails, func(email string) bool {
|
||||||
return strings.EqualFold(email, sanitizeEmail(template.Sender.Address))
|
return strings.EqualFold(email, usertypes.SanitizeEmail(template.Sender.Address))
|
||||||
}); idx < 0 {
|
}); idx < 0 {
|
||||||
return proton.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address)
|
return proton.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address)
|
||||||
} else { //nolint:revive
|
} else { //nolint:revive
|
||||||
@ -389,14 +380,14 @@ func createDraft(
|
|||||||
action = proton.ForwardAction
|
action = proton.ForwardAction
|
||||||
}
|
}
|
||||||
|
|
||||||
return client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{
|
return s.client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{
|
||||||
Message: template,
|
Message: template,
|
||||||
ParentID: parentID,
|
ParentID: parentID,
|
||||||
Action: action,
|
Action: action,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) createAttachments(
|
func (s *Service) createAttachments(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *proton.Client,
|
client *proton.Client,
|
||||||
addrKR *crypto.KeyRing,
|
addrKR *crypto.KeyRing,
|
||||||
@ -409,9 +400,9 @@ func (user *User) createAttachments(
|
|||||||
}
|
}
|
||||||
|
|
||||||
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
|
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
|
||||||
defer async.HandlePanic(user.panicHandler)
|
defer async.HandlePanic(s.panicHandler)
|
||||||
|
|
||||||
logrus.WithFields(logrus.Fields{
|
s.log.WithFields(logrus.Fields{
|
||||||
"name": logging.Sensitive(att.Name),
|
"name": logging.Sensitive(att.Name),
|
||||||
"contentID": att.ContentID,
|
"contentID": att.ContentID,
|
||||||
"disposition": att.Disposition,
|
"disposition": att.Disposition,
|
||||||
@ -480,7 +471,7 @@ func (user *User) createAttachments(
|
|||||||
return attKeys, nil
|
return attKeys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) getRecipients(
|
func (s *Service) getRecipients(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *proton.Client,
|
client *proton.Client,
|
||||||
userKR *crypto.KeyRing,
|
userKR *crypto.KeyRing,
|
||||||
@ -492,7 +483,7 @@ func (user *User) getRecipients(
|
|||||||
})
|
})
|
||||||
|
|
||||||
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
|
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
|
||||||
defer async.HandlePanic(user.panicHandler)
|
defer async.HandlePanic(s.panicHandler)
|
||||||
|
|
||||||
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
|
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -557,15 +548,6 @@ func getMessageSender(parser *parser.Parser) (string, bool) {
|
|||||||
return address[0].Address, true
|
return address[0].Address, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeEmail(email string) string {
|
|
||||||
splitAt := strings.Split(email, "@")
|
|
||||||
if len(splitAt) != 2 {
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Split(splitAt[0], "+")[0] + "@" + splitAt[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
func constructEmail(headerEmail string, addressEmail string) string {
|
func constructEmail(headerEmail string, addressEmail string) string {
|
||||||
splitAtHeader := strings.Split(headerEmail, "@")
|
splitAtHeader := strings.Split(headerEmail, "@")
|
||||||
if len(splitAtHeader) != 2 {
|
if len(splitAtHeader) != 2 {
|
||||||
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
//go:build build_qa
|
//go:build build_qa
|
||||||
|
|
||||||
package user
|
package smtp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
//go:build !build_qa
|
//go:build !build_qa
|
||||||
|
|
||||||
package user
|
package smtp
|
||||||
|
|
||||||
func debugDumpToDisk(_ []byte) error {
|
func debugDumpToDisk(_ []byte) error {
|
||||||
return nil
|
return nil
|
||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package smtp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/ProtonMail/gluon/rfc822"
|
"github.com/ProtonMail/gluon/rfc822"
|
||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package smtp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package smtp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@ -35,6 +35,7 @@ import (
|
|||||||
"github.com/ProtonMail/gopenpgp/v2/constants"
|
"github.com/ProtonMail/gopenpgp/v2/constants"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||||
"github.com/bradenaw/juniper/xmaps"
|
"github.com/bradenaw/juniper/xmaps"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
@ -62,7 +63,7 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A
|
|||||||
result := make(map[string]AccountMailboxMap)
|
result := make(map[string]AccountMailboxMap)
|
||||||
|
|
||||||
mode := user.GetAddressMode()
|
mode := user.GetAddressMode()
|
||||||
primaryAddrID, err := getPrimaryAddr(user.apiAddrs)
|
primaryAddrID, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get primary addr for user: %w", err)
|
return nil, fmt.Errorf("failed to get primary addr for user: %w", err)
|
||||||
}
|
}
|
||||||
@ -178,7 +179,7 @@ func (user *User) DebugDownloadMessages(
|
|||||||
return fmt.Errorf("failed to create directory '%v':%w", msgDir, err)
|
return fmt.Errorf("failed to create directory '%v':%w", msgDir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
message, err := user.client.GetFullMessage(ctx, msg.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
message, err := user.client.GetFullMessage(ctx, msg.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to download message '%v':%w", msg.ID, err)
|
return fmt.Errorf("failed to download message '%v':%w", msg.ID, err)
|
||||||
}
|
}
|
||||||
@ -187,7 +188,7 @@ func (user *User) DebugDownloadMessages(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := withAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||||
switch {
|
switch {
|
||||||
case len(message.Attachments) > 0:
|
case len(message.Attachments) > 0:
|
||||||
return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData)
|
return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData)
|
||||||
|
|||||||
@ -20,8 +20,6 @@ package user
|
|||||||
import "errors"
|
import "errors"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNoSuchAddress = errors.New("no such address")
|
ErrNoSuchAddress = errors.New("no such address")
|
||||||
ErrInvalidReturnPath = errors.New("invalid return path")
|
ErrMissingAddrKey = errors.New("missing address key")
|
||||||
ErrInvalidRecipient = errors.New("invalid recipient")
|
|
||||||
ErrMissingAddrKey = errors.New("missing address key")
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -34,6 +34,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -114,8 +115,8 @@ func (user *User) syncUserAddressesLabelsAndClearSync(ctx context.Context, cance
|
|||||||
|
|
||||||
// Update the API info in the user.
|
// Update the API info in the user.
|
||||||
user.apiUser = apiUser
|
user.apiUser = apiUser
|
||||||
user.apiAddrs = groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
|
user.apiAddrs = usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
|
||||||
user.apiLabels = groupBy(apiLabels, func(label proton.Label) string { return label.ID })
|
user.apiLabels = usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID })
|
||||||
|
|
||||||
// Clear sync status; we want to sync everything again.
|
// Clear sync status; we want to sync everything again.
|
||||||
if err := user.clearSyncStatus(); err != nil {
|
if err := user.clearSyncStatus(); err != nil {
|
||||||
@ -208,7 +209,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
|
|||||||
// If the address is enabled, we need to hook it up to the update channels.
|
// If the address is enabled, we need to hook it up to the update channels.
|
||||||
switch user.vault.AddressMode() {
|
switch user.vault.AddressMode() {
|
||||||
case vault.CombinedMode:
|
case vault.CombinedMode:
|
||||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get primary address: %w", err)
|
return fmt.Errorf("failed to get primary address: %w", err)
|
||||||
}
|
}
|
||||||
@ -267,7 +268,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
|
|||||||
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
|
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
|
||||||
switch user.vault.AddressMode() {
|
switch user.vault.AddressMode() {
|
||||||
case vault.CombinedMode:
|
case vault.CombinedMode:
|
||||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get primary address: %w", err)
|
return fmt.Errorf("failed to get primary address: %w", err)
|
||||||
}
|
}
|
||||||
@ -585,7 +586,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
|||||||
"subject": logging.Sensitive(message.Subject),
|
"subject": logging.Sensitive(message.Subject),
|
||||||
}).Info("Handling message created event")
|
}).Info("Handling message created event")
|
||||||
|
|
||||||
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
full, err := user.client.GetFullMessage(ctx, message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||||
@ -599,7 +600,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
|||||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||||
var update imap.Update
|
var update imap.Update
|
||||||
|
|
||||||
if err := withAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||||
|
|
||||||
if res.err != nil {
|
if res.err != nil {
|
||||||
@ -652,7 +653,7 @@ func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.Mes
|
|||||||
|
|
||||||
update := imap.NewMessageMailboxesUpdated(
|
update := imap.NewMessageMailboxesUpdated(
|
||||||
imap.MessageID(message.ID),
|
imap.MessageID(message.ID),
|
||||||
mapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
|
usertypes.MapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
|
||||||
flags,
|
flags,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -693,7 +694,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot
|
|||||||
"isDraft": event.Message.IsDraft(),
|
"isDraft": event.Message.IsDraft(),
|
||||||
}).Info("Handling draft or sent updated event")
|
}).Info("Handling draft or sent updated event")
|
||||||
|
|
||||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
full, err := user.client.GetFullMessage(ctx, event.Message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||||
@ -706,7 +707,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot
|
|||||||
|
|
||||||
var update imap.Update
|
var update imap.Update
|
||||||
|
|
||||||
if err := withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||||
|
|
||||||
if res.err != nil {
|
if res.err != nil {
|
||||||
@ -827,7 +828,7 @@ func safePublishMessageUpdate(user *User, addressID string, update imap.Update)
|
|||||||
v, ok := user.updateCh[addressID]
|
v, ok := user.updateCh[addressID]
|
||||||
if !ok {
|
if !ok {
|
||||||
if user.GetAddressMode() == vault.CombinedMode {
|
if user.GetAddressMode() == vault.CombinedMode {
|
||||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to get primary address: %w", err)
|
return false, fmt.Errorf("failed to get primary address: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,6 +33,8 @@ import (
|
|||||||
"github.com/ProtonMail/go-proton-api"
|
"github.com/ProtonMail/go-proton-api"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
|
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
|
||||||
@ -288,26 +290,26 @@ func (conn *imapConnector) CreateMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compute the hash of the message (to match it against SMTP messages).
|
// Compute the hash of the message (to match it against SMTP messages).
|
||||||
hash, err := getMessageHash(literal)
|
hash, err := sendrecorder.GetMessageHash(literal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return imap.Message{}, nil, err
|
return imap.Message{}, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we already tried to send this message recently.
|
// Check if we already tried to send this message recently.
|
||||||
if messageID, ok, err := conn.sendHash.hasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
|
if messageID, ok, err := conn.sendHash.HasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
|
||||||
return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err)
|
return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err)
|
||||||
} else if ok {
|
} else if ok {
|
||||||
conn.log.WithField("messageID", messageID).Warn("Message already sent")
|
conn.log.WithField("messageID", messageID).Warn("Message already sent")
|
||||||
|
|
||||||
// Query the server-side message.
|
// Query the server-side message.
|
||||||
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
full, err := conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
|
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the message as it is on the server.
|
// Build the message as it is on the server.
|
||||||
if err := safe.RLockRet(func() error {
|
if err := safe.RLockRet(func() error {
|
||||||
return withAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
|
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
|
||||||
@ -378,14 +380,14 @@ func (conn *imapConnector) CreateMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
|
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
|
||||||
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
msg, err := conn.client.GetFullMessage(ctx, string(id), usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return safe.RLockRetErr(func() ([]byte, error) {
|
return safe.RLockRetErr(func() ([]byte, error) {
|
||||||
var literal []byte
|
var literal []byte
|
||||||
err := withAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
err := usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||||
l, buildErr := message.BuildRFC822(addrKR, msg.Message, msg.AttData, defaultJobOpts())
|
l, buildErr := message.BuildRFC822(addrKR, msg.Message, msg.AttData, defaultJobOpts())
|
||||||
if buildErr != nil {
|
if buildErr != nil {
|
||||||
return buildErr
|
return buildErr
|
||||||
@ -408,7 +410,7 @@ func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs
|
|||||||
return connector.ErrOperationNotAllowed
|
return connector.ErrOperationNotAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(mailboxID))
|
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(mailboxID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveMessagesFromMailbox unlabels the given messages with the given label ID.
|
// RemoveMessagesFromMailbox unlabels the given messages with the given label ID.
|
||||||
@ -419,7 +421,7 @@ func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messag
|
|||||||
return connector.ErrOperationNotAllowed
|
return connector.ErrOperationNotAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
msgIDs := mapTo[imap.MessageID, string](messageIDs)
|
msgIDs := usertypes.MapTo[imap.MessageID, string](messageIDs)
|
||||||
if err := conn.client.UnlabelMessages(ctx, msgIDs, string(mailboxID)); err != nil {
|
if err := conn.client.UnlabelMessages(ctx, msgIDs, string(mailboxID)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -461,12 +463,12 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
|
|||||||
return result
|
return result
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
|
if err := conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
|
||||||
return false, fmt.Errorf("labeling messages: %w", err)
|
return false, fmt.Errorf("labeling messages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldExpungeOldLocation {
|
if shouldExpungeOldLocation {
|
||||||
if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
|
if err := conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
|
||||||
return false, fmt.Errorf("unlabeling messages: %w", err)
|
return false, fmt.Errorf("unlabeling messages: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -479,10 +481,10 @@ func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []im
|
|||||||
defer conn.goPollAPIEvents(false)
|
defer conn.goPollAPIEvents(false)
|
||||||
|
|
||||||
if seen {
|
if seen {
|
||||||
return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...)
|
return conn.client.MarkMessagesRead(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...)
|
return conn.client.MarkMessagesUnread(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkMessagesFlagged sets the flagged value of the given messages.
|
// MarkMessagesFlagged sets the flagged value of the given messages.
|
||||||
@ -490,10 +492,10 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [
|
|||||||
defer conn.goPollAPIEvents(false)
|
defer conn.goPollAPIEvents(false)
|
||||||
|
|
||||||
if flagged {
|
if flagged {
|
||||||
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
return conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUpdates returns a stream of updates that the gluon server should apply.
|
// GetUpdates returns a stream of updates that the gluon server should apply.
|
||||||
@ -535,7 +537,7 @@ func (conn *imapConnector) importMessage(
|
|||||||
var full proton.FullMessage
|
var full proton.FullMessage
|
||||||
|
|
||||||
if err := safe.RLockRet(func() error {
|
if err := safe.RLockRet(func() error {
|
||||||
return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||||
var messageID string
|
var messageID string
|
||||||
|
|
||||||
if slices.Contains(labelIDs, proton.DraftsLabel) {
|
if slices.Contains(labelIDs, proton.DraftsLabel) {
|
||||||
@ -571,7 +573,7 @@ func (conn *imapConnector) importMessage(
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
|
if full, err = conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||||
return fmt.Errorf("failed to fetch message: %w", err)
|
return fmt.Errorf("failed to fetch message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/ProtonMail/go-proton-api"
|
"github.com/ProtonMail/go-proton-api"
|
||||||
"github.com/ProtonMail/go-proton-api/server"
|
"github.com/ProtonMail/go-proton-api/server"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ func BenchmarkAddrKeyRing(b *testing.B) {
|
|||||||
b.StartTimer()
|
b.StartTimer()
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
require.NoError(b, withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
require.NoError(b, usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||||
return nil
|
return nil
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,6 +34,7 @@ import (
|
|||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||||
"github.com/bradenaw/juniper/parallel"
|
"github.com/bradenaw/juniper/parallel"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
@ -114,7 +115,7 @@ func (user *User) doSync(ctx context.Context) error {
|
|||||||
|
|
||||||
func (user *User) sync(ctx context.Context) error {
|
func (user *User) sync(ctx context.Context) error {
|
||||||
return safe.RLockRet(func() error {
|
return safe.RLockRet(func() error {
|
||||||
return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
return usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||||
if !user.vault.SyncStatus().HasLabels {
|
if !user.vault.SyncStatus().HasLabels {
|
||||||
user.log.Info("Syncing labels")
|
user.log.Info("Syncing labels")
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
"github.com/ProtonMail/go-proton-api"
|
"github.com/ProtonMail/go-proton-api"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
@ -87,7 +88,7 @@ func newMessageCreatedUpdate(
|
|||||||
return &imap.MessageCreated{
|
return &imap.MessageCreated{
|
||||||
Message: toIMAPMessage(message),
|
Message: toIMAPMessage(message),
|
||||||
Literal: literal,
|
Literal: literal,
|
||||||
MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||||
ParsedMessage: parsedMessage,
|
ParsedMessage: parsedMessage,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -106,7 +107,7 @@ func newMessageCreatedFailedUpdate(
|
|||||||
|
|
||||||
return &imap.MessageCreated{
|
return &imap.MessageCreated{
|
||||||
Message: toIMAPMessage(message),
|
Message: toIMAPMessage(message),
|
||||||
MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||||
Literal: literal,
|
Literal: literal,
|
||||||
ParsedMessage: parsedMessage,
|
ParsedMessage: parsedMessage,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,7 +40,10 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
|
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||||
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
@ -65,7 +68,7 @@ type User struct {
|
|||||||
vault *vault.User
|
vault *vault.User
|
||||||
client *proton.Client
|
client *proton.Client
|
||||||
reporter reporter.Reporter
|
reporter reporter.Reporter
|
||||||
sendHash *sendRecorder
|
sendHash *sendrecorder.SendRecorder
|
||||||
|
|
||||||
eventCh *async.QueuedChannel[events.Event]
|
eventCh *async.QueuedChannel[events.Event]
|
||||||
eventLock safe.RWMutex
|
eventLock safe.RWMutex
|
||||||
@ -100,6 +103,8 @@ type User struct {
|
|||||||
telemetryManager telemetry.Availability
|
telemetryManager telemetry.Availability
|
||||||
// goStatusProgress triggers a check/sending if progress is needed.
|
// goStatusProgress triggers a check/sending if progress is needed.
|
||||||
goStatusProgress func()
|
goStatusProgress func()
|
||||||
|
|
||||||
|
smtpService *smtp.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new user.
|
// New returns a new user.
|
||||||
@ -148,7 +153,7 @@ func New(
|
|||||||
vault: encVault,
|
vault: encVault,
|
||||||
client: client,
|
client: client,
|
||||||
reporter: reporter,
|
reporter: reporter,
|
||||||
sendHash: newSendRecorder(sendEntryExpiry),
|
sendHash: sendrecorder.NewSendRecorder(sendrecorder.SendEntryExpiry),
|
||||||
|
|
||||||
eventCh: async.NewQueuedChannel[events.Event](0, 0, crashHandler),
|
eventCh: async.NewQueuedChannel[events.Event](0, 0, crashHandler),
|
||||||
eventLock: safe.NewRWMutex(),
|
eventLock: safe.NewRWMutex(),
|
||||||
@ -156,10 +161,10 @@ func New(
|
|||||||
apiUser: apiUser,
|
apiUser: apiUser,
|
||||||
apiUserLock: safe.NewRWMutex(),
|
apiUserLock: safe.NewRWMutex(),
|
||||||
|
|
||||||
apiAddrs: groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
|
apiAddrs: usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
|
||||||
apiAddrsLock: safe.NewRWMutex(),
|
apiAddrsLock: safe.NewRWMutex(),
|
||||||
|
|
||||||
apiLabels: groupBy(apiLabels, func(label proton.Label) string { return label.ID }),
|
apiLabels: usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID }),
|
||||||
apiLabelsLock: safe.NewRWMutex(),
|
apiLabelsLock: safe.NewRWMutex(),
|
||||||
|
|
||||||
updateCh: make(map[string]*async.QueuedChannel[imap.Update]),
|
updateCh: make(map[string]*async.QueuedChannel[imap.Update]),
|
||||||
@ -178,6 +183,8 @@ func New(
|
|||||||
telemetryManager: telemetryManager,
|
telemetryManager: telemetryManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.smtpService = smtp.NewService(user, client, user.sendHash, user.panicHandler, user.reporter)
|
||||||
|
|
||||||
// Check for status_progress when triggered.
|
// Check for status_progress when triggered.
|
||||||
user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) {
|
user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) {
|
||||||
user.SendConfigStatusProgress(ctx)
|
user.SendConfigStatusProgress(ctx)
|
||||||
@ -264,6 +271,9 @@ func New(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Start SMTP Service
|
||||||
|
user.smtpService.Start(user.tasks)
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -461,7 +471,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
|||||||
|
|
||||||
switch user.vault.AddressMode() {
|
switch user.vault.AddressMode() {
|
||||||
case vault.CombinedMode:
|
case vault.CombinedMode:
|
||||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
primAddr, err := usertypes.GetAddrIdx(user.apiAddrs, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get primary address: %w", err)
|
return nil, fmt.Errorf("failed to get primary address: %w", err)
|
||||||
}
|
}
|
||||||
@ -485,10 +495,10 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(to) == 0 {
|
if len(to) == 0 {
|
||||||
return ErrInvalidRecipient
|
return smtp.ErrInvalidRecipient
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.sendMail(authID, from, to, r); err != nil {
|
if err := user.smtpService.SendMail(context.Background(), authID, from, to, r); err != nil {
|
||||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) {
|
if apiErr := new(proton.APIError); errors.As(err, &apiErr) {
|
||||||
logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("failed to send message")
|
logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("failed to send message")
|
||||||
}
|
}
|
||||||
@ -664,6 +674,12 @@ func (user *User) SendTelemetry(ctx context.Context, data []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) WithSMTPData(ctx context.Context, op func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error {
|
||||||
|
return safe.RLockRet(func() error {
|
||||||
|
return op(ctx, user.apiAddrs, user.apiUser, user.vault)
|
||||||
|
}, user.apiUserLock, user.apiAddrsLock, user.eventLock)
|
||||||
|
}
|
||||||
|
|
||||||
// initUpdateCh initializes the user's update channels in the given address mode.
|
// initUpdateCh initializes the user's update channels in the given address mode.
|
||||||
// It is assumed that user.apiAddrs and user.updateCh are already locked.
|
// It is assumed that user.apiAddrs and user.updateCh are already locked.
|
||||||
func (user *User) initUpdateCh(mode vault.AddressMode) {
|
func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||||
|
|||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package usertypes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -25,7 +25,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func withAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error {
|
func WithAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error {
|
||||||
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||||
@ -41,7 +41,7 @@ func withAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn
|
|||||||
return fn(userKR, addrKR)
|
return fn(userKR, addrKR)
|
||||||
}
|
}
|
||||||
|
|
||||||
func withAddrKRs(apiUser proton.User, apiAddr map[string]proton.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
func WithAddrKRs(apiUser proton.User, apiAddr map[string]proton.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||||
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package usertypes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -29,10 +29,10 @@ import (
|
|||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
// mapTo converts the slice to the given type.
|
// MapTo converts the slice to the given type.
|
||||||
// This is not runtime safe, so make sure the slice is of the correct type!
|
// This is not runtime safe, so make sure the slice is of the correct type!
|
||||||
// (This is a workaround for the fact that slices cannot be converted to other types generically).
|
// (This is a workaround for the fact that slices cannot be converted to other types generically).
|
||||||
func mapTo[From, To any](from []From) []To {
|
func MapTo[From, To any](from []From) []To {
|
||||||
to := make([]To, 0, len(from))
|
to := make([]To, 0, len(from))
|
||||||
|
|
||||||
for _, from := range from {
|
for _, from := range from {
|
||||||
@ -47,9 +47,9 @@ func mapTo[From, To any](from []From) []To {
|
|||||||
return to
|
return to
|
||||||
}
|
}
|
||||||
|
|
||||||
// groupBy returns a map of the given slice grouped by the given key.
|
// GroupBy returns a map of the given slice grouped by the given key.
|
||||||
// Duplicate keys are overwritten.
|
// Duplicate keys are overwritten.
|
||||||
func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value {
|
func GroupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value {
|
||||||
groups := make(map[Key]Value)
|
groups := make(map[Key]Value)
|
||||||
|
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
@ -59,10 +59,10 @@ func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[
|
|||||||
return groups
|
return groups
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAddrID returns the address ID for the given email address.
|
// GetAddrID returns the address ID for the given email address.
|
||||||
func getAddrID(apiAddrs map[string]proton.Address, email string) (string, error) {
|
func GetAddrID(apiAddrs map[string]proton.Address, email string) (string, error) {
|
||||||
for _, addr := range apiAddrs {
|
for _, addr := range apiAddrs {
|
||||||
if strings.EqualFold(addr.Email, sanitizeEmail(email)) {
|
if strings.EqualFold(addr.Email, SanitizeEmail(email)) {
|
||||||
return addr.ID, nil
|
return addr.ID, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -70,8 +70,8 @@ func getAddrID(apiAddrs map[string]proton.Address, email string) (string, error)
|
|||||||
return "", fmt.Errorf("address %s not found", email)
|
return "", fmt.Errorf("address %s not found", email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAddrIdx returns the address with the given index.
|
// GetAddrIdx returns the address with the given index.
|
||||||
func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, error) {
|
func GetAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, error) {
|
||||||
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
|
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
|
||||||
return a.Order < b.Order
|
return a.Order < b.Order
|
||||||
})
|
})
|
||||||
@ -83,7 +83,7 @@ func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, er
|
|||||||
return sorted[idx], nil
|
return sorted[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) {
|
func GetPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) {
|
||||||
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
|
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
|
||||||
return a.Order < b.Order
|
return a.Order < b.Order
|
||||||
})
|
})
|
||||||
@ -106,6 +106,15 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
|
|||||||
return sorted
|
return sorted
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
|
func NewProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
|
||||||
return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
|
return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SanitizeEmail(email string) string {
|
||||||
|
splitAt := strings.Split(email, "@")
|
||||||
|
if len(splitAt) != 2 {
|
||||||
|
return email
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Split(splitAt[0], "+")[0] + "@" + splitAt[1]
|
||||||
|
}
|
||||||
@ -15,7 +15,7 @@
|
|||||||
// You should have received a copy of the GNU General Public License
|
// You should have received a copy of the GNU General Public License
|
||||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package user
|
package usertypes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@ -30,8 +30,8 @@ func TestToType(t *testing.T) {
|
|||||||
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
|
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
|
||||||
|
|
||||||
// But converting them to the same type makes them equal.
|
// But converting them to the same type makes them equal.
|
||||||
require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"}))
|
require.Equal(t, []myString{"a", "b", "c"}, MapTo[string, myString]([]string{"a", "b", "c"}))
|
||||||
|
|
||||||
// The conversion can happen in the other direction too.
|
// The conversion can happen in the other direction too.
|
||||||
require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"}))
|
require.Equal(t, []string{"a", "b", "c"}, MapTo[myString, string]([]myString{"a", "b", "c"}))
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user