diff --git a/internal/user/send_recorder.go b/internal/services/sendrecorder/recorder.go
similarity index 79%
rename from internal/user/send_recorder.go
rename to internal/services/sendrecorder/recorder.go
index 3119aaa1..5b8df1b0 100644
--- a/internal/user/send_recorder.go
+++ b/internal/services/sendrecorder/recorder.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package sendrecorder
import (
"context"
@@ -30,11 +30,11 @@ import (
"golang.org/x/exp/slices"
)
-const sendEntryExpiry = 30 * time.Minute
+const SendEntryExpiry = 30 * time.Minute
-type SendRecorderID uint64
+type ID uint64
-type sendRecorder struct {
+type SendRecorder struct {
expiry time.Duration
entries map[string][]*sendEntry
@@ -42,15 +42,15 @@ type sendRecorder struct {
cancelIDCounter uint64
}
-func newSendRecorder(expiry time.Duration) *sendRecorder {
- return &sendRecorder{
+func NewSendRecorder(expiry time.Duration) *SendRecorder {
+ return &SendRecorder{
expiry: expiry,
entries: make(map[string][]*sendEntry),
}
}
type sendEntry struct {
- srID SendRecorderID
+ srID ID
msgID string
toList []string
exp time.Time
@@ -65,17 +65,17 @@ func (s *sendEntry) closeWaitChannel() {
}
}
-// tryInsertWait tries to insert the given message into the send recorder.
+// TryInsertWait tries to insert the given message into the send recorder.
// If an entry already exists but it was not sent yet, it waits.
// It returns whether an entry could be inserted and an error if it times out while waiting.
-func (h *sendRecorder) tryInsertWait(
+func (h *SendRecorder) TryInsertWait(
ctx context.Context,
hash string,
toList []string,
deadline time.Time,
-) (SendRecorderID, bool, error) {
+) (ID, bool, error) {
// If we successfully inserted the hash, we can return true.
- srID, waitCh, ok := h.tryInsert(hash, toList)
+ srID, waitCh, ok := h.TryInsert(hash, toList)
if ok {
return srID, true, nil
}
@@ -88,16 +88,16 @@ func (h *sendRecorder) tryInsertWait(
// If the message failed to send, try to insert it again.
if !wasSent {
- return h.tryInsertWait(ctx, hash, toList, deadline)
+ return h.TryInsertWait(ctx, hash, toList, deadline)
}
return srID, false, nil
}
-// hasEntryWait returns whether the given message already exists in the send recorder.
+// HasEntryWait returns whether the given message already exists in the send recorder.
// If it does, it waits for its ID to be known, then returns it and true.
// If no entry exists, or it times out while waiting for its ID to be known, it returns false.
-func (h *sendRecorder) hasEntryWait(ctx context.Context,
+func (h *SendRecorder) HasEntryWait(ctx context.Context,
hash string,
deadline time.Time,
toList []string,
@@ -118,10 +118,10 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context,
return messageID, true, nil
}
- return h.hasEntryWait(ctx, hash, deadline, toList)
+ return h.HasEntryWait(ctx, hash, deadline, toList)
}
-func (h *sendRecorder) removeExpiredUnsafe() {
+func (h *SendRecorder) removeExpiredUnsafe() {
for hash, entry := range h.entries {
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
return !t.exp.Before(time.Now())
@@ -135,7 +135,7 @@ func (h *sendRecorder) removeExpiredUnsafe() {
}
}
-func (h *sendRecorder) tryInsert(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) {
+func (h *SendRecorder) TryInsert(hash string, toList []string) (ID, <-chan struct{}, bool) {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
@@ -163,7 +163,7 @@ func (h *sendRecorder) tryInsert(hash string, toList []string) (SendRecorderID,
return cancelID, waitCh, true
}
-func (h *sendRecorder) getEntryWaitInfo(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) {
+func (h *SendRecorder) getEntryWaitInfo(hash string, toList []string) (ID, <-chan struct{}, bool) {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
@@ -180,8 +180,8 @@ func (h *sendRecorder) getEntryWaitInfo(hash string, toList []string) (SendRecor
return 0, nil, false
}
-// signalMessageSent should be called after a message has been successfully sent.
-func (h *sendRecorder) signalMessageSent(hash string, srID SendRecorderID, msgID string) {
+// SignalMessageSent should be called after a message has been successfully sent.
+func (h *SendRecorder) SignalMessageSent(hash string, srID ID, msgID string) {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
@@ -199,7 +199,7 @@ func (h *sendRecorder) signalMessageSent(hash string, srID SendRecorderID, msgID
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
}
-func (h *sendRecorder) removeOnFail(hash string, id SendRecorderID) {
+func (h *SendRecorder) RemoveOnFail(hash string, id ID) {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
@@ -222,11 +222,11 @@ func (h *sendRecorder) removeOnFail(hash string, id SendRecorderID) {
}
}
-func (h *sendRecorder) wait(
+func (h *SendRecorder) wait(
ctx context.Context,
hash string,
waitCh <-chan struct{},
- srID SendRecorderID,
+ srID ID,
deadline time.Time,
) (string, bool, error) {
ctx, cancel := context.WithDeadline(ctx, deadline)
@@ -254,19 +254,19 @@ func (h *sendRecorder) wait(
return "", false, nil
}
-func (h *sendRecorder) newSendRecorderID() SendRecorderID {
+func (h *SendRecorder) newSendRecorderID() ID {
h.cancelIDCounter++
- return SendRecorderID(h.cancelIDCounter)
+ return ID(h.cancelIDCounter)
}
-// getMessageHash returns the hash of the given message.
+// GetMessageHash returns the hash of the given message.
// This takes into account:
// - the Subject header,
// - the From/To/Cc headers,
// - the Content-Type header of each (leaf) part,
// - the Content-Disposition header of each (leaf) part,
// - the (decoded) body of each part.
-func getMessageHash(b []byte) (string, error) {
+func GetMessageHash(b []byte) (string, error) {
return rfc822.GetMessageHash(b)
}
diff --git a/internal/user/send_recorder_test.go b/internal/services/sendrecorder/recorder_test.go
similarity index 89%
rename from internal/user/send_recorder_test.go
rename to internal/services/sendrecorder/recorder_test.go
index abc14037..215a3bd5 100644
--- a/internal/user/send_recorder_test.go
+++ b/internal/services/sendrecorder/recorder_test.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package sendrecorder
import (
"context"
@@ -26,7 +26,7 @@ import (
)
func TestSendHasher_Insert(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srdID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) {
require.NotEmpty(t, hash1)
// Simulate successfully sending the message.
- h.signalMessageSent(hash1, srdID1, "abc")
+ h.SignalMessageSent(hash1, srdID1, "abc")
// Inserting a message with the same hash should return false.
srdID2, _, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -52,7 +52,7 @@ func TestSendHasher_Insert(t *testing.T) {
}
func TestSendHasher_Insert_Expired(t *testing.T) {
- h := newSendRecorder(time.Second)
+ h := NewSendRecorder(time.Second)
// Insert a message into the hasher.
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -61,7 +61,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) {
require.NotEmpty(t, hash1)
// Simulate successfully sending the message.
- h.signalMessageSent(hash1, srID1, "abc")
+ h.SignalMessageSent(hash1, srID1, "abc")
// Wait for the entry to expire.
time.Sleep(time.Second)
@@ -79,7 +79,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) {
}
func TestSendHasher_Insert_DifferentToList(t *testing.T) {
- h := newSendRecorder(time.Second)
+ h := NewSendRecorder(time.Second)
// Insert a message into the hasher.
srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def"}...)
@@ -101,7 +101,7 @@ func TestSendHasher_Insert_DifferentToList(t *testing.T) {
}
func TestSendHasher_Wait_SendSuccess(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -112,7 +112,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) {
// Simulate successfully sending the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
- h.signalMessageSent(hash, srID1, "abc")
+ h.SignalMessageSent(hash, srID1, "abc")
}()
// Inserting a message with the same hash should fail.
@@ -123,7 +123,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) {
}
func TestSendHasher_Wait_SendFail(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -134,7 +134,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) {
// Simulate failing to send the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
- h.removeOnFail(hash, srID1)
+ h.RemoveOnFail(hash, srID1)
}()
// Inserting a message with the same hash should succeed because the first message failed to send.
@@ -148,7 +148,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) {
}
func TestSendHasher_Wait_Timeout(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -162,7 +162,7 @@ func TestSendHasher_Wait_Timeout(t *testing.T) {
}
func TestSendHasher_HasEntry(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -171,7 +171,7 @@ func TestSendHasher_HasEntry(t *testing.T) {
require.NotEmpty(t, hash)
// Simulate successfully sending the message.
- h.signalMessageSent(hash, srID1, "abc")
+ h.SignalMessageSent(hash, srID1, "abc")
// The message was already sent; we should find it in the hasher.
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
@@ -181,7 +181,7 @@ func TestSendHasher_HasEntry(t *testing.T) {
}
func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -192,7 +192,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
// Simulate successfully sending the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
- h.signalMessageSent(hash, srID1, "abc")
+ h.SignalMessageSent(hash, srID1, "abc")
}()
// The message was already sent; we should find it in the hasher.
@@ -205,9 +205,9 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
- // inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2,
+ // inserted as a new entry. The two clients end up sending the message twice and calling the `SignalMessageSent` x2,
// resulting in a crash.
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -217,8 +217,8 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
// to attempt to send the same message.
- h.signalMessageSent(hash, srID1, "abc")
- h.signalMessageSent(hash, srID1, "abc")
+ h.SignalMessageSent(hash, srID1, "abc")
+ h.SignalMessageSent(hash, srID1, "abc")
// The message was already sent; we should find it in the hasher.
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
@@ -228,7 +228,7 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
}
func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver ")
@@ -249,7 +249,7 @@ func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testin
func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSend(t *testing.T) {
// Check that if we send the same message twice with different recipients and the second message is somehow
// sent before the first, ensure that we check if the message was sent we wait on the correct object.
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Minute), "Receiver ")
@@ -264,16 +264,16 @@ func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSen
require.Equal(t, hash, hash2)
// simulate message sent
- h.signalMessageSent(hash2, srID2, "newID")
+ h.SignalMessageSent(hash2, srID2, "newID")
// Simulate Wait on message 2
- _, ok, err = h.hasEntryWait(context.Background(), hash2, time.Now().Add(time.Second), []string{"Receiver ", "Receiver2 "})
+ _, ok, err = h.HasEntryWait(context.Background(), hash2, time.Now().Add(time.Second), []string{"Receiver ", "Receiver2 "})
require.NoError(t, err)
require.True(t, ok)
}
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -284,7 +284,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) {
// Simulate failing to send the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
- h.removeOnFail(hash, srID1)
+ h.RemoveOnFail(hash, srID1)
}()
// The message failed to send; we should not find it in the hasher.
@@ -294,7 +294,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) {
}
func TestSendHasher_HasEntry_Timeout(t *testing.T) {
- h := newSendRecorder(sendEntryExpiry)
+ h := NewSendRecorder(SendEntryExpiry)
// Insert a message into the hasher.
_, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -309,7 +309,7 @@ func TestSendHasher_HasEntry_Timeout(t *testing.T) {
}
func TestSendHasher_HasEntry_Expired(t *testing.T) {
- h := newSendRecorder(time.Second)
+ h := NewSendRecorder(time.Second)
// Insert a message into the hasher.
srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
@@ -318,7 +318,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) {
require.NotEmpty(t, hash)
// Simulate successfully sending the message.
- h.signalMessageSent(hash, srID1, "abc")
+ h.SignalMessageSent(hash, srID1, "abc")
// Wait for the entry to expire.
time.Sleep(time.Second)
@@ -432,10 +432,10 @@ func TestGetMessageHash(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- hash1, err := getMessageHash(tt.lit1)
+ hash1, err := GetMessageHash(tt.lit1)
require.NoError(t, err)
- hash2, err := getMessageHash(tt.lit2)
+ hash2, err := GetMessageHash(tt.lit2)
require.NoError(t, err)
if tt.wantEqual {
@@ -447,13 +447,13 @@ func TestGetMessageHash(t *testing.T) {
}
}
-func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList ...string) (SendRecorderID, string, bool, error) { //nolint:unparam
- hash, err := getMessageHash([]byte(literal))
+func testTryInsert(h *SendRecorder, literal string, deadline time.Time, toList ...string) (ID, string, bool, error) { //nolint:unparam
+ hash, err := GetMessageHash([]byte(literal))
if err != nil {
return 0, "", false, err
}
- srID, ok, err := h.tryInsertWait(context.Background(), hash, toList, deadline)
+ srID, ok, err := h.TryInsertWait(context.Background(), hash, toList, deadline)
if err != nil {
return 0, "", false, err
}
@@ -461,11 +461,11 @@ func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList .
return srID, hash, ok, nil
}
-func testHasEntry(h *sendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam
- hash, err := getMessageHash([]byte(literal))
+func testHasEntry(h *SendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam
+ hash, err := GetMessageHash([]byte(literal))
if err != nil {
return "", false, err
}
- return h.hasEntryWait(context.Background(), hash, deadline, toList)
+ return h.HasEntryWait(context.Background(), hash, deadline, toList)
}
diff --git a/internal/services/smtp/errors.go b/internal/services/smtp/errors.go
new file mode 100644
index 00000000..eb549ad6
--- /dev/null
+++ b/internal/services/smtp/errors.go
@@ -0,0 +1,23 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package smtp
+
+import "errors"
+
+var ErrInvalidRecipient = errors.New("invalid recipient")
+var ErrInvalidReturnPath = errors.New("invalid return path")
diff --git a/internal/services/smtp/service.go b/internal/services/smtp/service.go
new file mode 100644
index 00000000..431bb502
--- /dev/null
+++ b/internal/services/smtp/service.go
@@ -0,0 +1,132 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package smtp
+
+import (
+ "context"
+ "io"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/gluon/logging"
+ "github.com/ProtonMail/gluon/reporter"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
+ "github.com/ProtonMail/proton-bridge/v3/internal/vault"
+ "github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
+ "github.com/sirupsen/logrus"
+)
+
+// UserInterface is just wrapper to avoid recursive go module imports. To be removed when the identity service is ready.
+type UserInterface interface {
+ ID() string
+ WithSMTPData(context.Context, func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error
+}
+
+type Service struct {
+ panicHandler async.PanicHandler
+ cpc *cpc.CPC
+ user UserInterface
+ client *proton.Client
+ recorder *sendrecorder.SendRecorder
+ log *logrus.Entry
+ reporter reporter.Reporter
+}
+
+func NewService(
+ user UserInterface,
+ client *proton.Client,
+ recorder *sendrecorder.SendRecorder,
+ handler async.PanicHandler,
+ reporter reporter.Reporter,
+) *Service {
+ return &Service{
+ panicHandler: handler,
+ user: user,
+ cpc: cpc.NewCPC(),
+ recorder: recorder,
+ log: logrus.WithFields(logrus.Fields{
+ "user": user.ID(),
+ "service": "smtp",
+ }),
+ reporter: reporter,
+ client: client,
+ }
+}
+
+func (s *Service) SendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error {
+ _, err := s.cpc.Send(ctx, &sendMailReq{
+ authID: authID,
+ from: from,
+ to: to,
+ r: r,
+ })
+
+ return err
+}
+
+func (s *Service) Start(group *async.Group) {
+ s.log.Debug("Starting service")
+ group.Once(func(ctx context.Context) {
+ logging.DoAnnotated(ctx, func(ctx context.Context) {
+ s.run(ctx)
+ }, logging.Labels{
+ "user": s.user.ID(),
+ "service": "smtp",
+ })
+ })
+}
+
+func (s *Service) run(ctx context.Context) {
+ s.log.Debug("Starting service main loop")
+ defer s.log.Debug("Exiting service main loop")
+ defer s.cpc.Close()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+
+ case request, ok := <-s.cpc.ReceiveCh():
+ if !ok {
+ return
+ }
+
+ switch r := request.Value().(type) {
+ case *sendMailReq:
+ s.log.Debug("Received send mail request")
+ err := s.sendMail(ctx, r)
+ request.Reply(ctx, nil, err)
+
+ default:
+ s.log.Error("Received unknown request")
+ }
+ }
+ }
+}
+
+type sendMailReq struct {
+ authID string
+ from string
+ to []string
+ r io.Reader
+}
+
+func (s *Service) sendMail(ctx context.Context, req *sendMailReq) error {
+ defer async.HandlePanic(s.panicHandler)
+ return s.smtpSendMail(ctx, req.authID, req.from, req.to, req.r)
+}
diff --git a/internal/user/smtp.go b/internal/services/smtp/smtp.go
similarity index 85%
rename from internal/user/smtp.go
rename to internal/services/smtp/smtp.go
index 183df3be..1bbfaa05 100644
--- a/internal/user/smtp.go
+++ b/internal/services/smtp/smtp.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package smtp
import (
"bytes"
@@ -36,7 +36,8 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
- "github.com/ProtonMail/proton-bridge/v3/internal/safe"
+ "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
@@ -47,19 +48,14 @@ import (
"golang.org/x/exp/slices"
)
-// sendMail sends an email from the given address to the given recipients.
-func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
- defer async.HandlePanic(user.panicHandler)
-
- return safe.RLockRet(func() error {
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
- if _, err := getAddrID(user.apiAddrs, from); err != nil {
+// smtpSendMail sends an email from the given address to the given recipients.
+func (s *Service) smtpSendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error {
+ return s.user.WithSMTPData(ctx, func(ctx context.Context, apiAddrs map[string]proton.Address, user proton.User, vault *vault.User) error {
+ if _, err := usertypes.GetAddrID(apiAddrs, from); err != nil {
return ErrInvalidReturnPath
}
- emails := xslices.Map(maps.Values(user.apiAddrs), func(addr proton.Address) string {
+ emails := xslices.Map(maps.Values(apiAddrs), func(addr proton.Address) string {
return addr.Email
})
@@ -71,26 +67,26 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
// If running a QA build, dump to disk.
if err := debugDumpToDisk(b); err != nil {
- user.log.WithError(err).Warn("Failed to dump message to disk")
+ s.log.WithError(err).Warn("Failed to dump message to disk")
}
// Compute the hash of the message (to match it against SMTP messages).
- hash, err := getMessageHash(b)
+ hash, err := sendrecorder.GetMessageHash(b)
if err != nil {
return err
}
// Check if we already tried to send this message recently.
- srID, ok, err := user.sendHash.tryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second))
+ srID, ok, err := s.recorder.TryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second))
if err != nil {
return fmt.Errorf("failed to check send hash: %w", err)
} else if !ok {
- user.log.Warn("A duplicate message was already sent recently, skipping")
+ s.log.Warn("A duplicate message was already sent recently, skipping")
return nil
}
// If we fail to send this message, we should remove the hash from the send recorder.
- defer user.sendHash.removeOnFail(hash, srID)
+ defer s.recorder.RemoveOnFail(hash, srID)
// Create a new message parser from the reader.
parser, err := parser.New(bytes.NewReader(b))
@@ -104,17 +100,17 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// Load the user's mail settings.
- settings, err := user.client.GetMailSettings(ctx)
+ settings, err := s.client.GetMailSettings(ctx)
if err != nil {
return fmt.Errorf("failed to get mail settings: %w", err)
}
- addrID, err := getAddrID(user.apiAddrs, from)
+ addrID, err := usertypes.GetAddrID(apiAddrs, from)
if err != nil {
return err
}
- return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
+ return usertypes.WithAddrKR(user, apiAddrs[addrID], vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
// Use the first key for encrypting the message.
addrKR, err := addrKR.FirstKey()
if err != nil {
@@ -147,12 +143,10 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// Send the message using the correct key.
- sent, err := user.sendWithKey(
+ sent, err := s.sendWithKey(
ctx,
- user.client,
- user.reporter,
authID,
- user.vault.AddressMode(),
+ vault.AddressMode(),
settings,
userKR, addrKR,
emails, from, to,
@@ -163,18 +157,16 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// If the message was successfully sent, we can update the message ID in the record.
- user.sendHash.signalMessageSent(hash, srID, sent.ID)
+ s.recorder.SignalMessageSent(hash, srID, sent.ID)
return nil
})
- }, user.apiUserLock, user.apiAddrsLock, user.eventLock)
+ })
}
// sendWithKey sends the message with the given address key.
-func (user *User) sendWithKey(
+func (s *Service) sendWithKey(
ctx context.Context,
- client *proton.Client,
- sentry reporter.Reporter,
authAddrID string,
addrMode vault.AddressMode,
settings proton.MailSettings,
@@ -188,16 +180,16 @@ func (user *User) sendWithKey(
if message.InReplyTo != "" {
references = append(references, message.InReplyTo)
}
- parentID, err := getParentID(ctx, client, authAddrID, addrMode, references)
+ parentID, err := getParentID(ctx, s.client, authAddrID, addrMode, references)
if err != nil {
- if err := sentry.ReportMessageWithContext("Failed to get parent ID", reporter.Context{
+ if err := s.reporter.ReportMessageWithContext("Failed to get parent ID", reporter.Context{
"error": err,
"references": message.References,
}); err != nil {
logrus.WithError(err).Error("Failed to report error")
}
- logrus.WithError(err).Warn("Failed to get parent ID")
+ s.log.WithError(err).Warn("Failed to get parent ID")
}
var decBody string
@@ -214,7 +206,7 @@ func (user *User) sendWithKey(
return proton.Message{}, fmt.Errorf("unsupported MIME type: %v", message.MIMEType)
}
- draft, err := createDraft(ctx, client, addrKR, emails, from, to, parentID, message.InReplyTo, proton.DraftTemplate{
+ draft, err := s.createDraft(ctx, addrKR, emails, from, to, parentID, message.InReplyTo, proton.DraftTemplate{
Subject: message.Subject,
Body: decBody,
MIMEType: message.MIMEType,
@@ -230,12 +222,12 @@ func (user *User) sendWithKey(
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
}
- attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
+ attKeys, err := s.createAttachments(ctx, s.client, addrKR, draft.ID, message.Attachments)
if err != nil {
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
}
- recipients, err := user.getRecipients(ctx, client, userKR, settings, draft)
+ recipients, err := s.getRecipients(ctx, s.client, userKR, settings, draft)
if err != nil {
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
}
@@ -245,7 +237,7 @@ func (user *User) sendWithKey(
return proton.Message{}, fmt.Errorf("failed to create packages: %w", err)
}
- res, err := client.SendDraft(ctx, draft.ID, req)
+ res, err := s.client.SendDraft(ctx, draft.ID, req)
if err != nil {
return proton.Message{}, fmt.Errorf("failed to send draft: %w", err)
}
@@ -340,9 +332,8 @@ func getParentID(
return parentID, nil
}
-func createDraft(
+func (s *Service) createDraft(
ctx context.Context,
- client *proton.Client,
addrKR *crypto.KeyRing,
emails []string,
from string,
@@ -360,7 +351,7 @@ func createDraft(
// Check that the sending address is owned by the user, and if so, sanitize it.
if idx := xslices.IndexFunc(emails, func(email string) bool {
- return strings.EqualFold(email, sanitizeEmail(template.Sender.Address))
+ return strings.EqualFold(email, usertypes.SanitizeEmail(template.Sender.Address))
}); idx < 0 {
return proton.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address)
} else { //nolint:revive
@@ -389,14 +380,14 @@ func createDraft(
action = proton.ForwardAction
}
- return client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{
+ return s.client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{
Message: template,
ParentID: parentID,
Action: action,
})
}
-func (user *User) createAttachments(
+func (s *Service) createAttachments(
ctx context.Context,
client *proton.Client,
addrKR *crypto.KeyRing,
@@ -409,9 +400,9 @@ func (user *User) createAttachments(
}
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
- defer async.HandlePanic(user.panicHandler)
+ defer async.HandlePanic(s.panicHandler)
- logrus.WithFields(logrus.Fields{
+ s.log.WithFields(logrus.Fields{
"name": logging.Sensitive(att.Name),
"contentID": att.ContentID,
"disposition": att.Disposition,
@@ -480,7 +471,7 @@ func (user *User) createAttachments(
return attKeys, nil
}
-func (user *User) getRecipients(
+func (s *Service) getRecipients(
ctx context.Context,
client *proton.Client,
userKR *crypto.KeyRing,
@@ -492,7 +483,7 @@ func (user *User) getRecipients(
})
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
- defer async.HandlePanic(user.panicHandler)
+ defer async.HandlePanic(s.panicHandler)
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
if err != nil {
@@ -557,15 +548,6 @@ func getMessageSender(parser *parser.Parser) (string, bool) {
return address[0].Address, true
}
-func sanitizeEmail(email string) string {
- splitAt := strings.Split(email, "@")
- if len(splitAt) != 2 {
- return email
- }
-
- return strings.Split(splitAt[0], "+")[0] + "@" + splitAt[1]
-}
-
func constructEmail(headerEmail string, addressEmail string) string {
splitAtHeader := strings.Split(headerEmail, "@")
if len(splitAtHeader) != 2 {
diff --git a/internal/user/smtp_debug.go b/internal/services/smtp/smtp_debug.go
similarity index 98%
rename from internal/user/smtp_debug.go
rename to internal/services/smtp/smtp_debug.go
index 42c1910e..b9ce856b 100644
--- a/internal/user/smtp_debug.go
+++ b/internal/services/smtp/smtp_debug.go
@@ -17,7 +17,7 @@
//go:build build_qa
-package user
+package smtp
import (
"fmt"
diff --git a/internal/user/smtp_default.go b/internal/services/smtp/smtp_default.go
similarity index 98%
rename from internal/user/smtp_default.go
rename to internal/services/smtp/smtp_default.go
index 11bb8ded..be14bd91 100644
--- a/internal/user/smtp_default.go
+++ b/internal/services/smtp/smtp_default.go
@@ -17,7 +17,7 @@
//go:build !build_qa
-package user
+package smtp
func debugDumpToDisk(_ []byte) error {
return nil
diff --git a/internal/user/smtp_packages.go b/internal/services/smtp/smtp_packages.go
similarity index 99%
rename from internal/user/smtp_packages.go
rename to internal/services/smtp/smtp_packages.go
index e32b4450..a9422378 100644
--- a/internal/user/smtp_packages.go
+++ b/internal/services/smtp/smtp_packages.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package smtp
import (
"github.com/ProtonMail/gluon/rfc822"
diff --git a/internal/user/smtp_prefs.go b/internal/services/smtp/smtp_prefs.go
similarity index 99%
rename from internal/user/smtp_prefs.go
rename to internal/services/smtp/smtp_prefs.go
index 3519f0c0..ac7d0bd9 100644
--- a/internal/user/smtp_prefs.go
+++ b/internal/services/smtp/smtp_prefs.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package smtp
import (
"fmt"
diff --git a/internal/user/smtp_prefs_test.go b/internal/services/smtp/smtp_prefs_test.go
similarity index 99%
rename from internal/user/smtp_prefs_test.go
rename to internal/services/smtp/smtp_prefs_test.go
index 2a60ee7d..dad56bb0 100644
--- a/internal/user/smtp_prefs_test.go
+++ b/internal/services/smtp/smtp_prefs_test.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package smtp
import (
"testing"
diff --git a/internal/user/debug.go b/internal/user/debug.go
index dbc73992..fbbb7c45 100644
--- a/internal/user/debug.go
+++ b/internal/user/debug.go
@@ -35,6 +35,7 @@ import (
"github.com/ProtonMail/gopenpgp/v2/constants"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/xmaps"
"github.com/bradenaw/juniper/xslices"
@@ -62,7 +63,7 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A
result := make(map[string]AccountMailboxMap)
mode := user.GetAddressMode()
- primaryAddrID, err := getPrimaryAddr(user.apiAddrs)
+ primaryAddrID, err := usertypes.GetPrimaryAddr(user.apiAddrs)
if err != nil {
return nil, fmt.Errorf("failed to get primary addr for user: %w", err)
}
@@ -178,7 +179,7 @@ func (user *User) DebugDownloadMessages(
return fmt.Errorf("failed to create directory '%v':%w", msgDir, err)
}
- message, err := user.client.GetFullMessage(ctx, msg.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
+ message, err := user.client.GetFullMessage(ctx, msg.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
return fmt.Errorf("failed to download message '%v':%w", msg.ID, err)
}
@@ -187,7 +188,7 @@ func (user *User) DebugDownloadMessages(
return err
}
- if err := withAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
+ if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
switch {
case len(message.Attachments) > 0:
return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData)
diff --git a/internal/user/errors.go b/internal/user/errors.go
index ef61b136..9435388b 100644
--- a/internal/user/errors.go
+++ b/internal/user/errors.go
@@ -20,8 +20,6 @@ package user
import "errors"
var (
- ErrNoSuchAddress = errors.New("no such address")
- ErrInvalidReturnPath = errors.New("invalid return path")
- ErrInvalidRecipient = errors.New("invalid recipient")
- ErrMissingAddrKey = errors.New("missing address key")
+ ErrNoSuchAddress = errors.New("no such address")
+ ErrMissingAddrKey = errors.New("missing address key")
)
diff --git a/internal/user/events.go b/internal/user/events.go
index 6876aef6..ee4ed885 100644
--- a/internal/user/events.go
+++ b/internal/user/events.go
@@ -34,6 +34,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
@@ -114,8 +115,8 @@ func (user *User) syncUserAddressesLabelsAndClearSync(ctx context.Context, cance
// Update the API info in the user.
user.apiUser = apiUser
- user.apiAddrs = groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
- user.apiLabels = groupBy(apiLabels, func(label proton.Label) string { return label.ID })
+ user.apiAddrs = usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
+ user.apiLabels = usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID })
// Clear sync status; we want to sync everything again.
if err := user.clearSyncStatus(); err != nil {
@@ -208,7 +209,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
// If the address is enabled, we need to hook it up to the update channels.
switch user.vault.AddressMode() {
case vault.CombinedMode:
- primAddr, err := getPrimaryAddr(user.apiAddrs)
+ primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
if err != nil {
return fmt.Errorf("failed to get primary address: %w", err)
}
@@ -267,7 +268,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
switch user.vault.AddressMode() {
case vault.CombinedMode:
- primAddr, err := getPrimaryAddr(user.apiAddrs)
+ primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
if err != nil {
return fmt.Errorf("failed to get primary address: %w", err)
}
@@ -585,7 +586,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
"subject": logging.Sensitive(message.Subject),
}).Info("Handling message created event")
- full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
+ full, err := user.client.GetFullMessage(ctx, message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
@@ -599,7 +600,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
return safe.RLockRetErr(func() ([]imap.Update, error) {
var update imap.Update
- if err := withAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
+ if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
if res.err != nil {
@@ -652,7 +653,7 @@ func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.Mes
update := imap.NewMessageMailboxesUpdated(
imap.MessageID(message.ID),
- mapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
+ usertypes.MapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
flags,
)
@@ -693,7 +694,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot
"isDraft": event.Message.IsDraft(),
}).Info("Handling draft or sent updated event")
- full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
+ full, err := user.client.GetFullMessage(ctx, event.Message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
@@ -706,7 +707,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot
var update imap.Update
- if err := withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
+ if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
if res.err != nil {
@@ -827,7 +828,7 @@ func safePublishMessageUpdate(user *User, addressID string, update imap.Update)
v, ok := user.updateCh[addressID]
if !ok {
if user.GetAddressMode() == vault.CombinedMode {
- primAddr, err := getPrimaryAddr(user.apiAddrs)
+ primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
if err != nil {
return false, fmt.Errorf("failed to get primary address: %w", err)
}
diff --git a/internal/user/imap.go b/internal/user/imap.go
index e3e2b7a5..6ab3dd98 100644
--- a/internal/user/imap.go
+++ b/internal/user/imap.go
@@ -33,6 +33,8 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
+ "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
@@ -288,26 +290,26 @@ func (conn *imapConnector) CreateMessage(
}
// Compute the hash of the message (to match it against SMTP messages).
- hash, err := getMessageHash(literal)
+ hash, err := sendrecorder.GetMessageHash(literal)
if err != nil {
return imap.Message{}, nil, err
}
// Check if we already tried to send this message recently.
- if messageID, ok, err := conn.sendHash.hasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
+ if messageID, ok, err := conn.sendHash.HasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err)
} else if ok {
conn.log.WithField("messageID", messageID).Warn("Message already sent")
// Query the server-side message.
- full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
+ full, err := conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
}
// Build the message as it is on the server.
if err := safe.RLockRet(func() error {
- return withAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
+ return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
var err error
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
@@ -378,14 +380,14 @@ func (conn *imapConnector) CreateMessage(
}
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
- msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
+ msg, err := conn.client.GetFullMessage(ctx, string(id), usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
return nil, err
}
return safe.RLockRetErr(func() ([]byte, error) {
var literal []byte
- err := withAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
+ err := usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
l, buildErr := message.BuildRFC822(addrKR, msg.Message, msg.AttData, defaultJobOpts())
if buildErr != nil {
return buildErr
@@ -408,7 +410,7 @@ func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs
return connector.ErrOperationNotAllowed
}
- return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(mailboxID))
+ return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(mailboxID))
}
// RemoveMessagesFromMailbox unlabels the given messages with the given label ID.
@@ -419,7 +421,7 @@ func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messag
return connector.ErrOperationNotAllowed
}
- msgIDs := mapTo[imap.MessageID, string](messageIDs)
+ msgIDs := usertypes.MapTo[imap.MessageID, string](messageIDs)
if err := conn.client.UnlabelMessages(ctx, msgIDs, string(mailboxID)); err != nil {
return err
}
@@ -461,12 +463,12 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
return result
}()
- if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
+ if err := conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
return false, fmt.Errorf("labeling messages: %w", err)
}
if shouldExpungeOldLocation {
- if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
+ if err := conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
return false, fmt.Errorf("unlabeling messages: %w", err)
}
}
@@ -479,10 +481,10 @@ func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []im
defer conn.goPollAPIEvents(false)
if seen {
- return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...)
+ return conn.client.MarkMessagesRead(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
}
- return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...)
+ return conn.client.MarkMessagesUnread(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
}
// MarkMessagesFlagged sets the flagged value of the given messages.
@@ -490,10 +492,10 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [
defer conn.goPollAPIEvents(false)
if flagged {
- return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
+ return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
}
- return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
+ return conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
}
// GetUpdates returns a stream of updates that the gluon server should apply.
@@ -535,7 +537,7 @@ func (conn *imapConnector) importMessage(
var full proton.FullMessage
if err := safe.RLockRet(func() error {
- return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
+ return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
var messageID string
if slices.Contains(labelIDs, proton.DraftsLabel) {
@@ -571,7 +573,7 @@ func (conn *imapConnector) importMessage(
var err error
- if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
+ if full, err = conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
return fmt.Errorf("failed to fetch message: %w", err)
}
diff --git a/internal/user/keys_test.go b/internal/user/keys_test.go
index bd29ff78..4e9af945 100644
--- a/internal/user/keys_test.go
+++ b/internal/user/keys_test.go
@@ -24,6 +24,7 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/ProtonMail/gopenpgp/v2/crypto"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/stretchr/testify/require"
)
@@ -36,7 +37,7 @@ func BenchmarkAddrKeyRing(b *testing.B) {
b.StartTimer()
for i := 0; i < b.N; i++ {
- require.NoError(b, withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
+ require.NoError(b, usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
return nil
}))
}
diff --git a/internal/user/sync.go b/internal/user/sync.go
index 2b60cbf4..bc8921e6 100644
--- a/internal/user/sync.go
+++ b/internal/user/sync.go
@@ -34,6 +34,7 @@ import (
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices"
@@ -114,7 +115,7 @@ func (user *User) doSync(ctx context.Context) error {
func (user *User) sync(ctx context.Context) error {
return safe.RLockRet(func() error {
- return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
+ return usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
if !user.vault.SyncStatus().HasLabels {
user.log.Info("Syncing labels")
diff --git a/internal/user/sync_build.go b/internal/user/sync_build.go
index 93309c35..a0e03998 100644
--- a/internal/user/sync_build.go
+++ b/internal/user/sync_build.go
@@ -25,6 +25,7 @@ import (
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
"github.com/bradenaw/juniper/xslices"
@@ -87,7 +88,7 @@ func newMessageCreatedUpdate(
return &imap.MessageCreated{
Message: toIMAPMessage(message),
Literal: literal,
- MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
+ MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
ParsedMessage: parsedMessage,
}, nil
}
@@ -106,7 +107,7 @@ func newMessageCreatedFailedUpdate(
return &imap.MessageCreated{
Message: toIMAPMessage(message),
- MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
+ MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
Literal: literal,
ParsedMessage: parsedMessage,
}
diff --git a/internal/user/user.go b/internal/user/user.go
index 49d29ab7..23a32602 100644
--- a/internal/user/user.go
+++ b/internal/user/user.go
@@ -40,7 +40,10 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
+ "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
+ "github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
+ "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"github.com/bradenaw/juniper/xslices"
@@ -65,7 +68,7 @@ type User struct {
vault *vault.User
client *proton.Client
reporter reporter.Reporter
- sendHash *sendRecorder
+ sendHash *sendrecorder.SendRecorder
eventCh *async.QueuedChannel[events.Event]
eventLock safe.RWMutex
@@ -100,6 +103,8 @@ type User struct {
telemetryManager telemetry.Availability
// goStatusProgress triggers a check/sending if progress is needed.
goStatusProgress func()
+
+ smtpService *smtp.Service
}
// New returns a new user.
@@ -148,7 +153,7 @@ func New(
vault: encVault,
client: client,
reporter: reporter,
- sendHash: newSendRecorder(sendEntryExpiry),
+ sendHash: sendrecorder.NewSendRecorder(sendrecorder.SendEntryExpiry),
eventCh: async.NewQueuedChannel[events.Event](0, 0, crashHandler),
eventLock: safe.NewRWMutex(),
@@ -156,10 +161,10 @@ func New(
apiUser: apiUser,
apiUserLock: safe.NewRWMutex(),
- apiAddrs: groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
+ apiAddrs: usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
apiAddrsLock: safe.NewRWMutex(),
- apiLabels: groupBy(apiLabels, func(label proton.Label) string { return label.ID }),
+ apiLabels: usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID }),
apiLabelsLock: safe.NewRWMutex(),
updateCh: make(map[string]*async.QueuedChannel[imap.Update]),
@@ -178,6 +183,8 @@ func New(
telemetryManager: telemetryManager,
}
+ user.smtpService = smtp.NewService(user, client, user.sendHash, user.panicHandler, user.reporter)
+
// Check for status_progress when triggered.
user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) {
user.SendConfigStatusProgress(ctx)
@@ -264,6 +271,9 @@ func New(
}
})
+ // Start SMTP Service
+ user.smtpService.Start(user.tasks)
+
return user, nil
}
@@ -461,7 +471,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
switch user.vault.AddressMode() {
case vault.CombinedMode:
- primAddr, err := getAddrIdx(user.apiAddrs, 0)
+ primAddr, err := usertypes.GetAddrIdx(user.apiAddrs, 0)
if err != nil {
return nil, fmt.Errorf("failed to get primary address: %w", err)
}
@@ -485,10 +495,10 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader)
}
if len(to) == 0 {
- return ErrInvalidRecipient
+ return smtp.ErrInvalidRecipient
}
- if err := user.sendMail(authID, from, to, r); err != nil {
+ if err := user.smtpService.SendMail(context.Background(), authID, from, to, r); err != nil {
if apiErr := new(proton.APIError); errors.As(err, &apiErr) {
logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("failed to send message")
}
@@ -664,6 +674,12 @@ func (user *User) SendTelemetry(ctx context.Context, data []byte) error {
return nil
}
+func (user *User) WithSMTPData(ctx context.Context, op func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error {
+ return safe.RLockRet(func() error {
+ return op(ctx, user.apiAddrs, user.apiUser, user.vault)
+ }, user.apiUserLock, user.apiAddrsLock, user.eventLock)
+}
+
// initUpdateCh initializes the user's update channels in the given address mode.
// It is assumed that user.apiAddrs and user.updateCh are already locked.
func (user *User) initUpdateCh(mode vault.AddressMode) {
diff --git a/internal/user/keys.go b/internal/usertypes/keys.go
similarity index 93%
rename from internal/user/keys.go
rename to internal/usertypes/keys.go
index 949c873b..d267474f 100644
--- a/internal/user/keys.go
+++ b/internal/usertypes/keys.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package usertypes
import (
"fmt"
@@ -25,7 +25,7 @@ import (
"github.com/sirupsen/logrus"
)
-func withAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error {
+func WithAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error {
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
if err != nil {
return fmt.Errorf("failed to unlock user keys: %w", err)
@@ -41,7 +41,7 @@ func withAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn
return fn(userKR, addrKR)
}
-func withAddrKRs(apiUser proton.User, apiAddr map[string]proton.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
+func WithAddrKRs(apiUser proton.User, apiAddr map[string]proton.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
if err != nil {
return fmt.Errorf("failed to unlock user keys: %w", err)
diff --git a/internal/user/types.go b/internal/usertypes/types.go
similarity index 74%
rename from internal/user/types.go
rename to internal/usertypes/types.go
index 67fbdf04..b30df75b 100644
--- a/internal/user/types.go
+++ b/internal/usertypes/types.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package usertypes
import (
"fmt"
@@ -29,10 +29,10 @@ import (
"golang.org/x/exp/slices"
)
-// mapTo converts the slice to the given type.
+// MapTo converts the slice to the given type.
// This is not runtime safe, so make sure the slice is of the correct type!
// (This is a workaround for the fact that slices cannot be converted to other types generically).
-func mapTo[From, To any](from []From) []To {
+func MapTo[From, To any](from []From) []To {
to := make([]To, 0, len(from))
for _, from := range from {
@@ -47,9 +47,9 @@ func mapTo[From, To any](from []From) []To {
return to
}
-// groupBy returns a map of the given slice grouped by the given key.
+// GroupBy returns a map of the given slice grouped by the given key.
// Duplicate keys are overwritten.
-func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value {
+func GroupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value {
groups := make(map[Key]Value)
for _, item := range items {
@@ -59,10 +59,10 @@ func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[
return groups
}
-// getAddrID returns the address ID for the given email address.
-func getAddrID(apiAddrs map[string]proton.Address, email string) (string, error) {
+// GetAddrID returns the address ID for the given email address.
+func GetAddrID(apiAddrs map[string]proton.Address, email string) (string, error) {
for _, addr := range apiAddrs {
- if strings.EqualFold(addr.Email, sanitizeEmail(email)) {
+ if strings.EqualFold(addr.Email, SanitizeEmail(email)) {
return addr.ID, nil
}
}
@@ -70,8 +70,8 @@ func getAddrID(apiAddrs map[string]proton.Address, email string) (string, error)
return "", fmt.Errorf("address %s not found", email)
}
-// getAddrIdx returns the address with the given index.
-func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, error) {
+// GetAddrIdx returns the address with the given index.
+func GetAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, error) {
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
return a.Order < b.Order
})
@@ -83,7 +83,7 @@ func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, er
return sorted[idx], nil
}
-func getPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) {
+func GetPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) {
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
return a.Order < b.Order
})
@@ -106,6 +106,15 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
return sorted
}
-func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
+func NewProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
}
+
+func SanitizeEmail(email string) string {
+ splitAt := strings.Split(email, "@")
+ if len(splitAt) != 2 {
+ return email
+ }
+
+ return strings.Split(splitAt[0], "+")[0] + "@" + splitAt[1]
+}
diff --git a/internal/user/types_test.go b/internal/usertypes/types_test.go
similarity index 88%
rename from internal/user/types_test.go
rename to internal/usertypes/types_test.go
index bfc23157..2c974898 100644
--- a/internal/user/types_test.go
+++ b/internal/usertypes/types_test.go
@@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see .
-package user
+package usertypes
import (
"testing"
@@ -30,8 +30,8 @@ func TestToType(t *testing.T) {
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
// But converting them to the same type makes them equal.
- require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"}))
+ require.Equal(t, []myString{"a", "b", "c"}, MapTo[string, myString]([]string{"a", "b", "c"}))
// The conversion can happen in the other direction too.
- require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"}))
+ require.Equal(t, []string{"a", "b", "c"}, MapTo[myString, string]([]myString{"a", "b", "c"}))
}