diff --git a/internal/user/send_recorder.go b/internal/services/sendrecorder/recorder.go similarity index 79% rename from internal/user/send_recorder.go rename to internal/services/sendrecorder/recorder.go index 3119aaa1..5b8df1b0 100644 --- a/internal/user/send_recorder.go +++ b/internal/services/sendrecorder/recorder.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package sendrecorder import ( "context" @@ -30,11 +30,11 @@ import ( "golang.org/x/exp/slices" ) -const sendEntryExpiry = 30 * time.Minute +const SendEntryExpiry = 30 * time.Minute -type SendRecorderID uint64 +type ID uint64 -type sendRecorder struct { +type SendRecorder struct { expiry time.Duration entries map[string][]*sendEntry @@ -42,15 +42,15 @@ type sendRecorder struct { cancelIDCounter uint64 } -func newSendRecorder(expiry time.Duration) *sendRecorder { - return &sendRecorder{ +func NewSendRecorder(expiry time.Duration) *SendRecorder { + return &SendRecorder{ expiry: expiry, entries: make(map[string][]*sendEntry), } } type sendEntry struct { - srID SendRecorderID + srID ID msgID string toList []string exp time.Time @@ -65,17 +65,17 @@ func (s *sendEntry) closeWaitChannel() { } } -// tryInsertWait tries to insert the given message into the send recorder. +// TryInsertWait tries to insert the given message into the send recorder. // If an entry already exists but it was not sent yet, it waits. // It returns whether an entry could be inserted and an error if it times out while waiting. -func (h *sendRecorder) tryInsertWait( +func (h *SendRecorder) TryInsertWait( ctx context.Context, hash string, toList []string, deadline time.Time, -) (SendRecorderID, bool, error) { +) (ID, bool, error) { // If we successfully inserted the hash, we can return true. - srID, waitCh, ok := h.tryInsert(hash, toList) + srID, waitCh, ok := h.TryInsert(hash, toList) if ok { return srID, true, nil } @@ -88,16 +88,16 @@ func (h *sendRecorder) tryInsertWait( // If the message failed to send, try to insert it again. if !wasSent { - return h.tryInsertWait(ctx, hash, toList, deadline) + return h.TryInsertWait(ctx, hash, toList, deadline) } return srID, false, nil } -// hasEntryWait returns whether the given message already exists in the send recorder. +// HasEntryWait returns whether the given message already exists in the send recorder. // If it does, it waits for its ID to be known, then returns it and true. // If no entry exists, or it times out while waiting for its ID to be known, it returns false. -func (h *sendRecorder) hasEntryWait(ctx context.Context, +func (h *SendRecorder) HasEntryWait(ctx context.Context, hash string, deadline time.Time, toList []string, @@ -118,10 +118,10 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, return messageID, true, nil } - return h.hasEntryWait(ctx, hash, deadline, toList) + return h.HasEntryWait(ctx, hash, deadline, toList) } -func (h *sendRecorder) removeExpiredUnsafe() { +func (h *SendRecorder) removeExpiredUnsafe() { for hash, entry := range h.entries { remaining := xslices.Filter(entry, func(t *sendEntry) bool { return !t.exp.Before(time.Now()) @@ -135,7 +135,7 @@ func (h *sendRecorder) removeExpiredUnsafe() { } } -func (h *sendRecorder) tryInsert(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) { +func (h *SendRecorder) TryInsert(hash string, toList []string) (ID, <-chan struct{}, bool) { h.entriesLock.Lock() defer h.entriesLock.Unlock() @@ -163,7 +163,7 @@ func (h *sendRecorder) tryInsert(hash string, toList []string) (SendRecorderID, return cancelID, waitCh, true } -func (h *sendRecorder) getEntryWaitInfo(hash string, toList []string) (SendRecorderID, <-chan struct{}, bool) { +func (h *SendRecorder) getEntryWaitInfo(hash string, toList []string) (ID, <-chan struct{}, bool) { h.entriesLock.Lock() defer h.entriesLock.Unlock() @@ -180,8 +180,8 @@ func (h *sendRecorder) getEntryWaitInfo(hash string, toList []string) (SendRecor return 0, nil, false } -// signalMessageSent should be called after a message has been successfully sent. -func (h *sendRecorder) signalMessageSent(hash string, srID SendRecorderID, msgID string) { +// SignalMessageSent should be called after a message has been successfully sent. +func (h *SendRecorder) SignalMessageSent(hash string, srID ID, msgID string) { h.entriesLock.Lock() defer h.entriesLock.Unlock() @@ -199,7 +199,7 @@ func (h *sendRecorder) signalMessageSent(hash string, srID SendRecorderID, msgID logrus.Warn("Cannot add message ID to send hash entry, it may have expired") } -func (h *sendRecorder) removeOnFail(hash string, id SendRecorderID) { +func (h *SendRecorder) RemoveOnFail(hash string, id ID) { h.entriesLock.Lock() defer h.entriesLock.Unlock() @@ -222,11 +222,11 @@ func (h *sendRecorder) removeOnFail(hash string, id SendRecorderID) { } } -func (h *sendRecorder) wait( +func (h *SendRecorder) wait( ctx context.Context, hash string, waitCh <-chan struct{}, - srID SendRecorderID, + srID ID, deadline time.Time, ) (string, bool, error) { ctx, cancel := context.WithDeadline(ctx, deadline) @@ -254,19 +254,19 @@ func (h *sendRecorder) wait( return "", false, nil } -func (h *sendRecorder) newSendRecorderID() SendRecorderID { +func (h *SendRecorder) newSendRecorderID() ID { h.cancelIDCounter++ - return SendRecorderID(h.cancelIDCounter) + return ID(h.cancelIDCounter) } -// getMessageHash returns the hash of the given message. +// GetMessageHash returns the hash of the given message. // This takes into account: // - the Subject header, // - the From/To/Cc headers, // - the Content-Type header of each (leaf) part, // - the Content-Disposition header of each (leaf) part, // - the (decoded) body of each part. -func getMessageHash(b []byte) (string, error) { +func GetMessageHash(b []byte) (string, error) { return rfc822.GetMessageHash(b) } diff --git a/internal/user/send_recorder_test.go b/internal/services/sendrecorder/recorder_test.go similarity index 89% rename from internal/user/send_recorder_test.go rename to internal/services/sendrecorder/recorder_test.go index abc14037..215a3bd5 100644 --- a/internal/user/send_recorder_test.go +++ b/internal/services/sendrecorder/recorder_test.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package sendrecorder import ( "context" @@ -26,7 +26,7 @@ import ( ) func TestSendHasher_Insert(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srdID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) { require.NotEmpty(t, hash1) // Simulate successfully sending the message. - h.signalMessageSent(hash1, srdID1, "abc") + h.SignalMessageSent(hash1, srdID1, "abc") // Inserting a message with the same hash should return false. srdID2, _, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -52,7 +52,7 @@ func TestSendHasher_Insert(t *testing.T) { } func TestSendHasher_Insert_Expired(t *testing.T) { - h := newSendRecorder(time.Second) + h := NewSendRecorder(time.Second) // Insert a message into the hasher. srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -61,7 +61,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) { require.NotEmpty(t, hash1) // Simulate successfully sending the message. - h.signalMessageSent(hash1, srID1, "abc") + h.SignalMessageSent(hash1, srID1, "abc") // Wait for the entry to expire. time.Sleep(time.Second) @@ -79,7 +79,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) { } func TestSendHasher_Insert_DifferentToList(t *testing.T) { - h := newSendRecorder(time.Second) + h := NewSendRecorder(time.Second) // Insert a message into the hasher. srID1, hash1, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), []string{"abc", "def"}...) @@ -101,7 +101,7 @@ func TestSendHasher_Insert_DifferentToList(t *testing.T) { } func TestSendHasher_Wait_SendSuccess(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -112,7 +112,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) { // Simulate successfully sending the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.signalMessageSent(hash, srID1, "abc") + h.SignalMessageSent(hash, srID1, "abc") }() // Inserting a message with the same hash should fail. @@ -123,7 +123,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) { } func TestSendHasher_Wait_SendFail(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -134,7 +134,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) { // Simulate failing to send the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.removeOnFail(hash, srID1) + h.RemoveOnFail(hash, srID1) }() // Inserting a message with the same hash should succeed because the first message failed to send. @@ -148,7 +148,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) { } func TestSendHasher_Wait_Timeout(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. _, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -162,7 +162,7 @@ func TestSendHasher_Wait_Timeout(t *testing.T) { } func TestSendHasher_HasEntry(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -171,7 +171,7 @@ func TestSendHasher_HasEntry(t *testing.T) { require.NotEmpty(t, hash) // Simulate successfully sending the message. - h.signalMessageSent(hash, srID1, "abc") + h.SignalMessageSent(hash, srID1, "abc") // The message was already sent; we should find it in the hasher. messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) @@ -181,7 +181,7 @@ func TestSendHasher_HasEntry(t *testing.T) { } func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -192,7 +192,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { // Simulate successfully sending the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.signalMessageSent(hash, srID1, "abc") + h.SignalMessageSent(hash, srID1, "abc") }() // The message was already sent; we should find it in the hasher. @@ -205,9 +205,9 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) { func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { // There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message // is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be - // inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2, + // inserted as a new entry. The two clients end up sending the message twice and calling the `SignalMessageSent` x2, // resulting in a crash. - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -217,8 +217,8 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { // Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections // to attempt to send the same message. - h.signalMessageSent(hash, srID1, "abc") - h.signalMessageSent(hash, srID1, "abc") + h.SignalMessageSent(hash, srID1, "abc") + h.SignalMessageSent(hash, srID1, "abc") // The message was already sent; we should find it in the hasher. messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second)) @@ -228,7 +228,7 @@ func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) { } func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver ") @@ -249,7 +249,7 @@ func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testin func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSend(t *testing.T) { // Check that if we send the same message twice with different recipients and the second message is somehow // sent before the first, ensure that we check if the message was sent we wait on the correct object. - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. _, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Minute), "Receiver ") @@ -264,16 +264,16 @@ func TestSendHashed_SameMessageWIthDifferentToListShouldWaitSuccessfullyAfterSen require.Equal(t, hash, hash2) // simulate message sent - h.signalMessageSent(hash2, srID2, "newID") + h.SignalMessageSent(hash2, srID2, "newID") // Simulate Wait on message 2 - _, ok, err = h.hasEntryWait(context.Background(), hash2, time.Now().Add(time.Second), []string{"Receiver ", "Receiver2 "}) + _, ok, err = h.HasEntryWait(context.Background(), hash2, time.Now().Add(time.Second), []string{"Receiver ", "Receiver2 "}) require.NoError(t, err) require.True(t, ok) } func TestSendHasher_HasEntry_SendFail(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -284,7 +284,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) { // Simulate failing to send the message after half a second. go func() { time.Sleep(time.Millisecond * 500) - h.removeOnFail(hash, srID1) + h.RemoveOnFail(hash, srID1) }() // The message failed to send; we should not find it in the hasher. @@ -294,7 +294,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) { } func TestSendHasher_HasEntry_Timeout(t *testing.T) { - h := newSendRecorder(sendEntryExpiry) + h := NewSendRecorder(SendEntryExpiry) // Insert a message into the hasher. _, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -309,7 +309,7 @@ func TestSendHasher_HasEntry_Timeout(t *testing.T) { } func TestSendHasher_HasEntry_Expired(t *testing.T) { - h := newSendRecorder(time.Second) + h := NewSendRecorder(time.Second) // Insert a message into the hasher. srID1, hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second)) @@ -318,7 +318,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) { require.NotEmpty(t, hash) // Simulate successfully sending the message. - h.signalMessageSent(hash, srID1, "abc") + h.SignalMessageSent(hash, srID1, "abc") // Wait for the entry to expire. time.Sleep(time.Second) @@ -432,10 +432,10 @@ func TestGetMessageHash(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hash1, err := getMessageHash(tt.lit1) + hash1, err := GetMessageHash(tt.lit1) require.NoError(t, err) - hash2, err := getMessageHash(tt.lit2) + hash2, err := GetMessageHash(tt.lit2) require.NoError(t, err) if tt.wantEqual { @@ -447,13 +447,13 @@ func TestGetMessageHash(t *testing.T) { } } -func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList ...string) (SendRecorderID, string, bool, error) { //nolint:unparam - hash, err := getMessageHash([]byte(literal)) +func testTryInsert(h *SendRecorder, literal string, deadline time.Time, toList ...string) (ID, string, bool, error) { //nolint:unparam + hash, err := GetMessageHash([]byte(literal)) if err != nil { return 0, "", false, err } - srID, ok, err := h.tryInsertWait(context.Background(), hash, toList, deadline) + srID, ok, err := h.TryInsertWait(context.Background(), hash, toList, deadline) if err != nil { return 0, "", false, err } @@ -461,11 +461,11 @@ func testTryInsert(h *sendRecorder, literal string, deadline time.Time, toList . return srID, hash, ok, nil } -func testHasEntry(h *sendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam - hash, err := getMessageHash([]byte(literal)) +func testHasEntry(h *SendRecorder, literal string, deadline time.Time, toList ...string) (string, bool, error) { //nolint:unparam + hash, err := GetMessageHash([]byte(literal)) if err != nil { return "", false, err } - return h.hasEntryWait(context.Background(), hash, deadline, toList) + return h.HasEntryWait(context.Background(), hash, deadline, toList) } diff --git a/internal/services/smtp/errors.go b/internal/services/smtp/errors.go new file mode 100644 index 00000000..eb549ad6 --- /dev/null +++ b/internal/services/smtp/errors.go @@ -0,0 +1,23 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package smtp + +import "errors" + +var ErrInvalidRecipient = errors.New("invalid recipient") +var ErrInvalidReturnPath = errors.New("invalid return path") diff --git a/internal/services/smtp/service.go b/internal/services/smtp/service.go new file mode 100644 index 00000000..431bb502 --- /dev/null +++ b/internal/services/smtp/service.go @@ -0,0 +1,132 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package smtp + +import ( + "context" + "io" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/logging" + "github.com/ProtonMail/gluon/reporter" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" + "github.com/ProtonMail/proton-bridge/v3/internal/vault" + "github.com/ProtonMail/proton-bridge/v3/pkg/cpc" + "github.com/sirupsen/logrus" +) + +// UserInterface is just wrapper to avoid recursive go module imports. To be removed when the identity service is ready. +type UserInterface interface { + ID() string + WithSMTPData(context.Context, func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error +} + +type Service struct { + panicHandler async.PanicHandler + cpc *cpc.CPC + user UserInterface + client *proton.Client + recorder *sendrecorder.SendRecorder + log *logrus.Entry + reporter reporter.Reporter +} + +func NewService( + user UserInterface, + client *proton.Client, + recorder *sendrecorder.SendRecorder, + handler async.PanicHandler, + reporter reporter.Reporter, +) *Service { + return &Service{ + panicHandler: handler, + user: user, + cpc: cpc.NewCPC(), + recorder: recorder, + log: logrus.WithFields(logrus.Fields{ + "user": user.ID(), + "service": "smtp", + }), + reporter: reporter, + client: client, + } +} + +func (s *Service) SendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error { + _, err := s.cpc.Send(ctx, &sendMailReq{ + authID: authID, + from: from, + to: to, + r: r, + }) + + return err +} + +func (s *Service) Start(group *async.Group) { + s.log.Debug("Starting service") + group.Once(func(ctx context.Context) { + logging.DoAnnotated(ctx, func(ctx context.Context) { + s.run(ctx) + }, logging.Labels{ + "user": s.user.ID(), + "service": "smtp", + }) + }) +} + +func (s *Service) run(ctx context.Context) { + s.log.Debug("Starting service main loop") + defer s.log.Debug("Exiting service main loop") + defer s.cpc.Close() + + for { + select { + case <-ctx.Done(): + return + + case request, ok := <-s.cpc.ReceiveCh(): + if !ok { + return + } + + switch r := request.Value().(type) { + case *sendMailReq: + s.log.Debug("Received send mail request") + err := s.sendMail(ctx, r) + request.Reply(ctx, nil, err) + + default: + s.log.Error("Received unknown request") + } + } + } +} + +type sendMailReq struct { + authID string + from string + to []string + r io.Reader +} + +func (s *Service) sendMail(ctx context.Context, req *sendMailReq) error { + defer async.HandlePanic(s.panicHandler) + return s.smtpSendMail(ctx, req.authID, req.from, req.to, req.r) +} diff --git a/internal/user/smtp.go b/internal/services/smtp/smtp.go similarity index 85% rename from internal/user/smtp.go rename to internal/services/smtp/smtp.go index 183df3be..1bbfaa05 100644 --- a/internal/user/smtp.go +++ b/internal/services/smtp/smtp.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package smtp import ( "bytes" @@ -36,7 +36,8 @@ import ( "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/logging" - "github.com/ProtonMail/proton-bridge/v3/internal/safe" + "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/pkg/message" "github.com/ProtonMail/proton-bridge/v3/pkg/message/parser" @@ -47,19 +48,14 @@ import ( "golang.org/x/exp/slices" ) -// sendMail sends an email from the given address to the given recipients. -func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error { - defer async.HandlePanic(user.panicHandler) - - return safe.RLockRet(func() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if _, err := getAddrID(user.apiAddrs, from); err != nil { +// smtpSendMail sends an email from the given address to the given recipients. +func (s *Service) smtpSendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error { + return s.user.WithSMTPData(ctx, func(ctx context.Context, apiAddrs map[string]proton.Address, user proton.User, vault *vault.User) error { + if _, err := usertypes.GetAddrID(apiAddrs, from); err != nil { return ErrInvalidReturnPath } - emails := xslices.Map(maps.Values(user.apiAddrs), func(addr proton.Address) string { + emails := xslices.Map(maps.Values(apiAddrs), func(addr proton.Address) string { return addr.Email }) @@ -71,26 +67,26 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) // If running a QA build, dump to disk. if err := debugDumpToDisk(b); err != nil { - user.log.WithError(err).Warn("Failed to dump message to disk") + s.log.WithError(err).Warn("Failed to dump message to disk") } // Compute the hash of the message (to match it against SMTP messages). - hash, err := getMessageHash(b) + hash, err := sendrecorder.GetMessageHash(b) if err != nil { return err } // Check if we already tried to send this message recently. - srID, ok, err := user.sendHash.tryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second)) + srID, ok, err := s.recorder.TryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second)) if err != nil { return fmt.Errorf("failed to check send hash: %w", err) } else if !ok { - user.log.Warn("A duplicate message was already sent recently, skipping") + s.log.Warn("A duplicate message was already sent recently, skipping") return nil } // If we fail to send this message, we should remove the hash from the send recorder. - defer user.sendHash.removeOnFail(hash, srID) + defer s.recorder.RemoveOnFail(hash, srID) // Create a new message parser from the reader. parser, err := parser.New(bytes.NewReader(b)) @@ -104,17 +100,17 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // Load the user's mail settings. - settings, err := user.client.GetMailSettings(ctx) + settings, err := s.client.GetMailSettings(ctx) if err != nil { return fmt.Errorf("failed to get mail settings: %w", err) } - addrID, err := getAddrID(user.apiAddrs, from) + addrID, err := usertypes.GetAddrID(apiAddrs, from) if err != nil { return err } - return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error { + return usertypes.WithAddrKR(user, apiAddrs[addrID], vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error { // Use the first key for encrypting the message. addrKR, err := addrKR.FirstKey() if err != nil { @@ -147,12 +143,10 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // Send the message using the correct key. - sent, err := user.sendWithKey( + sent, err := s.sendWithKey( ctx, - user.client, - user.reporter, authID, - user.vault.AddressMode(), + vault.AddressMode(), settings, userKR, addrKR, emails, from, to, @@ -163,18 +157,16 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // If the message was successfully sent, we can update the message ID in the record. - user.sendHash.signalMessageSent(hash, srID, sent.ID) + s.recorder.SignalMessageSent(hash, srID, sent.ID) return nil }) - }, user.apiUserLock, user.apiAddrsLock, user.eventLock) + }) } // sendWithKey sends the message with the given address key. -func (user *User) sendWithKey( +func (s *Service) sendWithKey( ctx context.Context, - client *proton.Client, - sentry reporter.Reporter, authAddrID string, addrMode vault.AddressMode, settings proton.MailSettings, @@ -188,16 +180,16 @@ func (user *User) sendWithKey( if message.InReplyTo != "" { references = append(references, message.InReplyTo) } - parentID, err := getParentID(ctx, client, authAddrID, addrMode, references) + parentID, err := getParentID(ctx, s.client, authAddrID, addrMode, references) if err != nil { - if err := sentry.ReportMessageWithContext("Failed to get parent ID", reporter.Context{ + if err := s.reporter.ReportMessageWithContext("Failed to get parent ID", reporter.Context{ "error": err, "references": message.References, }); err != nil { logrus.WithError(err).Error("Failed to report error") } - logrus.WithError(err).Warn("Failed to get parent ID") + s.log.WithError(err).Warn("Failed to get parent ID") } var decBody string @@ -214,7 +206,7 @@ func (user *User) sendWithKey( return proton.Message{}, fmt.Errorf("unsupported MIME type: %v", message.MIMEType) } - draft, err := createDraft(ctx, client, addrKR, emails, from, to, parentID, message.InReplyTo, proton.DraftTemplate{ + draft, err := s.createDraft(ctx, addrKR, emails, from, to, parentID, message.InReplyTo, proton.DraftTemplate{ Subject: message.Subject, Body: decBody, MIMEType: message.MIMEType, @@ -230,12 +222,12 @@ func (user *User) sendWithKey( return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err) } - attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments) + attKeys, err := s.createAttachments(ctx, s.client, addrKR, draft.ID, message.Attachments) if err != nil { return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err) } - recipients, err := user.getRecipients(ctx, client, userKR, settings, draft) + recipients, err := s.getRecipients(ctx, s.client, userKR, settings, draft) if err != nil { return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err) } @@ -245,7 +237,7 @@ func (user *User) sendWithKey( return proton.Message{}, fmt.Errorf("failed to create packages: %w", err) } - res, err := client.SendDraft(ctx, draft.ID, req) + res, err := s.client.SendDraft(ctx, draft.ID, req) if err != nil { return proton.Message{}, fmt.Errorf("failed to send draft: %w", err) } @@ -340,9 +332,8 @@ func getParentID( return parentID, nil } -func createDraft( +func (s *Service) createDraft( ctx context.Context, - client *proton.Client, addrKR *crypto.KeyRing, emails []string, from string, @@ -360,7 +351,7 @@ func createDraft( // Check that the sending address is owned by the user, and if so, sanitize it. if idx := xslices.IndexFunc(emails, func(email string) bool { - return strings.EqualFold(email, sanitizeEmail(template.Sender.Address)) + return strings.EqualFold(email, usertypes.SanitizeEmail(template.Sender.Address)) }); idx < 0 { return proton.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address) } else { //nolint:revive @@ -389,14 +380,14 @@ func createDraft( action = proton.ForwardAction } - return client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{ + return s.client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{ Message: template, ParentID: parentID, Action: action, }) } -func (user *User) createAttachments( +func (s *Service) createAttachments( ctx context.Context, client *proton.Client, addrKR *crypto.KeyRing, @@ -409,9 +400,9 @@ func (user *User) createAttachments( } keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) { - defer async.HandlePanic(user.panicHandler) + defer async.HandlePanic(s.panicHandler) - logrus.WithFields(logrus.Fields{ + s.log.WithFields(logrus.Fields{ "name": logging.Sensitive(att.Name), "contentID": att.ContentID, "disposition": att.Disposition, @@ -480,7 +471,7 @@ func (user *User) createAttachments( return attKeys, nil } -func (user *User) getRecipients( +func (s *Service) getRecipients( ctx context.Context, client *proton.Client, userKR *crypto.KeyRing, @@ -492,7 +483,7 @@ func (user *User) getRecipients( }) prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) { - defer async.HandlePanic(user.panicHandler) + defer async.HandlePanic(s.panicHandler) pubKeys, recType, err := client.GetPublicKeys(ctx, recipient) if err != nil { @@ -557,15 +548,6 @@ func getMessageSender(parser *parser.Parser) (string, bool) { return address[0].Address, true } -func sanitizeEmail(email string) string { - splitAt := strings.Split(email, "@") - if len(splitAt) != 2 { - return email - } - - return strings.Split(splitAt[0], "+")[0] + "@" + splitAt[1] -} - func constructEmail(headerEmail string, addressEmail string) string { splitAtHeader := strings.Split(headerEmail, "@") if len(splitAtHeader) != 2 { diff --git a/internal/user/smtp_debug.go b/internal/services/smtp/smtp_debug.go similarity index 98% rename from internal/user/smtp_debug.go rename to internal/services/smtp/smtp_debug.go index 42c1910e..b9ce856b 100644 --- a/internal/user/smtp_debug.go +++ b/internal/services/smtp/smtp_debug.go @@ -17,7 +17,7 @@ //go:build build_qa -package user +package smtp import ( "fmt" diff --git a/internal/user/smtp_default.go b/internal/services/smtp/smtp_default.go similarity index 98% rename from internal/user/smtp_default.go rename to internal/services/smtp/smtp_default.go index 11bb8ded..be14bd91 100644 --- a/internal/user/smtp_default.go +++ b/internal/services/smtp/smtp_default.go @@ -17,7 +17,7 @@ //go:build !build_qa -package user +package smtp func debugDumpToDisk(_ []byte) error { return nil diff --git a/internal/user/smtp_packages.go b/internal/services/smtp/smtp_packages.go similarity index 99% rename from internal/user/smtp_packages.go rename to internal/services/smtp/smtp_packages.go index e32b4450..a9422378 100644 --- a/internal/user/smtp_packages.go +++ b/internal/services/smtp/smtp_packages.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package smtp import ( "github.com/ProtonMail/gluon/rfc822" diff --git a/internal/user/smtp_prefs.go b/internal/services/smtp/smtp_prefs.go similarity index 99% rename from internal/user/smtp_prefs.go rename to internal/services/smtp/smtp_prefs.go index 3519f0c0..ac7d0bd9 100644 --- a/internal/user/smtp_prefs.go +++ b/internal/services/smtp/smtp_prefs.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package smtp import ( "fmt" diff --git a/internal/user/smtp_prefs_test.go b/internal/services/smtp/smtp_prefs_test.go similarity index 99% rename from internal/user/smtp_prefs_test.go rename to internal/services/smtp/smtp_prefs_test.go index 2a60ee7d..dad56bb0 100644 --- a/internal/user/smtp_prefs_test.go +++ b/internal/services/smtp/smtp_prefs_test.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package smtp import ( "testing" diff --git a/internal/user/debug.go b/internal/user/debug.go index dbc73992..fbbb7c45 100644 --- a/internal/user/debug.go +++ b/internal/user/debug.go @@ -35,6 +35,7 @@ import ( "github.com/ProtonMail/gopenpgp/v2/constants" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/safe" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/bradenaw/juniper/xmaps" "github.com/bradenaw/juniper/xslices" @@ -62,7 +63,7 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A result := make(map[string]AccountMailboxMap) mode := user.GetAddressMode() - primaryAddrID, err := getPrimaryAddr(user.apiAddrs) + primaryAddrID, err := usertypes.GetPrimaryAddr(user.apiAddrs) if err != nil { return nil, fmt.Errorf("failed to get primary addr for user: %w", err) } @@ -178,7 +179,7 @@ func (user *User) DebugDownloadMessages( return fmt.Errorf("failed to create directory '%v':%w", msgDir, err) } - message, err := user.client.GetFullMessage(ctx, msg.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) + message, err := user.client.GetFullMessage(ctx, msg.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { return fmt.Errorf("failed to download message '%v':%w", msg.ID, err) } @@ -187,7 +188,7 @@ func (user *User) DebugDownloadMessages( return err } - if err := withAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { switch { case len(message.Attachments) > 0: return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData) diff --git a/internal/user/errors.go b/internal/user/errors.go index ef61b136..9435388b 100644 --- a/internal/user/errors.go +++ b/internal/user/errors.go @@ -20,8 +20,6 @@ package user import "errors" var ( - ErrNoSuchAddress = errors.New("no such address") - ErrInvalidReturnPath = errors.New("invalid return path") - ErrInvalidRecipient = errors.New("invalid recipient") - ErrMissingAddrKey = errors.New("missing address key") + ErrNoSuchAddress = errors.New("no such address") + ErrMissingAddrKey = errors.New("missing address key") ) diff --git a/internal/user/events.go b/internal/user/events.go index 6876aef6..ee4ed885 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -34,6 +34,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/safe" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" @@ -114,8 +115,8 @@ func (user *User) syncUserAddressesLabelsAndClearSync(ctx context.Context, cance // Update the API info in the user. user.apiUser = apiUser - user.apiAddrs = groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }) - user.apiLabels = groupBy(apiLabels, func(label proton.Label) string { return label.ID }) + user.apiAddrs = usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }) + user.apiLabels = usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID }) // Clear sync status; we want to sync everything again. if err := user.clearSyncStatus(); err != nil { @@ -208,7 +209,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add // If the address is enabled, we need to hook it up to the update channels. switch user.vault.AddressMode() { case vault.CombinedMode: - primAddr, err := getPrimaryAddr(user.apiAddrs) + primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs) if err != nil { return fmt.Errorf("failed to get primary address: %w", err) } @@ -267,7 +268,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled: switch user.vault.AddressMode() { case vault.CombinedMode: - primAddr, err := getPrimaryAddr(user.apiAddrs) + primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs) if err != nil { return fmt.Errorf("failed to get primary address: %w", err) } @@ -585,7 +586,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M "subject": logging.Sensitive(message.Subject), }).Info("Handling message created event") - full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) + full, err := user.client.GetFullMessage(ctx, message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { // If the message is not found, it means that it has been deleted before we could fetch it. if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity { @@ -599,7 +600,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M return safe.RLockRetErr(func() ([]imap.Update, error) { var update imap.Update - if err := withAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer)) if res.err != nil { @@ -652,7 +653,7 @@ func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.Mes update := imap.NewMessageMailboxesUpdated( imap.MessageID(message.ID), - mapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)), + usertypes.MapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)), flags, ) @@ -693,7 +694,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot "isDraft": event.Message.IsDraft(), }).Info("Handling draft or sent updated event") - full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) + full, err := user.client.GetFullMessage(ctx, event.Message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { // If the message is not found, it means that it has been deleted before we could fetch it. if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity { @@ -706,7 +707,7 @@ func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event prot var update imap.Update - if err := withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer)) if res.err != nil { @@ -827,7 +828,7 @@ func safePublishMessageUpdate(user *User, addressID string, update imap.Update) v, ok := user.updateCh[addressID] if !ok { if user.GetAddressMode() == vault.CombinedMode { - primAddr, err := getPrimaryAddr(user.apiAddrs) + primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs) if err != nil { return false, fmt.Errorf("failed to get primary address: %w", err) } diff --git a/internal/user/imap.go b/internal/user/imap.go index e3e2b7a5..6ab3dd98 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -33,6 +33,8 @@ import ( "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/safe" + "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/pkg/message" "github.com/ProtonMail/proton-bridge/v3/pkg/message/parser" @@ -288,26 +290,26 @@ func (conn *imapConnector) CreateMessage( } // Compute the hash of the message (to match it against SMTP messages). - hash, err := getMessageHash(literal) + hash, err := sendrecorder.GetMessageHash(literal) if err != nil { return imap.Message{}, nil, err } // Check if we already tried to send this message recently. - if messageID, ok, err := conn.sendHash.hasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil { + if messageID, ok, err := conn.sendHash.HasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil { return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err) } else if ok { conn.log.WithField("messageID", messageID).Warn("Message already sent") // Query the server-side message. - full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()) + full, err := conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err) } // Build the message as it is on the server. if err := safe.RLockRet(func() error { - return withAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { var err error if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil { @@ -378,14 +380,14 @@ func (conn *imapConnector) CreateMessage( } func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) { - msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()) + msg, err := conn.client.GetFullMessage(ctx, string(id), usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { return nil, err } return safe.RLockRetErr(func() ([]byte, error) { var literal []byte - err := withAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + err := usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { l, buildErr := message.BuildRFC822(addrKR, msg.Message, msg.AttData, defaultJobOpts()) if buildErr != nil { return buildErr @@ -408,7 +410,7 @@ func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs return connector.ErrOperationNotAllowed } - return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(mailboxID)) + return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(mailboxID)) } // RemoveMessagesFromMailbox unlabels the given messages with the given label ID. @@ -419,7 +421,7 @@ func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messag return connector.ErrOperationNotAllowed } - msgIDs := mapTo[imap.MessageID, string](messageIDs) + msgIDs := usertypes.MapTo[imap.MessageID, string](messageIDs) if err := conn.client.UnlabelMessages(ctx, msgIDs, string(mailboxID)); err != nil { return err } @@ -461,12 +463,12 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M return result }() - if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil { + if err := conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil { return false, fmt.Errorf("labeling messages: %w", err) } if shouldExpungeOldLocation { - if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil { + if err := conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil { return false, fmt.Errorf("unlabeling messages: %w", err) } } @@ -479,10 +481,10 @@ func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []im defer conn.goPollAPIEvents(false) if seen { - return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...) + return conn.client.MarkMessagesRead(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...) } - return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...) + return conn.client.MarkMessagesUnread(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...) } // MarkMessagesFlagged sets the flagged value of the given messages. @@ -490,10 +492,10 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [ defer conn.goPollAPIEvents(false) if flagged { - return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel) + return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel) } - return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), proton.StarredLabel) + return conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel) } // GetUpdates returns a stream of updates that the gluon server should apply. @@ -535,7 +537,7 @@ func (conn *imapConnector) importMessage( var full proton.FullMessage if err := safe.RLockRet(func() error { - return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { + return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error { var messageID string if slices.Contains(labelIDs, proton.DraftsLabel) { @@ -571,7 +573,7 @@ func (conn *imapConnector) importMessage( var err error - if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil { + if full, err = conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil { return fmt.Errorf("failed to fetch message: %w", err) } diff --git a/internal/user/keys_test.go b/internal/user/keys_test.go index bd29ff78..4e9af945 100644 --- a/internal/user/keys_test.go +++ b/internal/user/keys_test.go @@ -24,6 +24,7 @@ import ( "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api/server" "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/stretchr/testify/require" ) @@ -36,7 +37,7 @@ func BenchmarkAddrKeyRing(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - require.NoError(b, withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { + require.NoError(b, usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { return nil })) } diff --git a/internal/user/sync.go b/internal/user/sync.go index 2b60cbf4..bc8921e6 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -34,6 +34,7 @@ import ( "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/safe" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/xslices" @@ -114,7 +115,7 @@ func (user *User) doSync(ctx context.Context) error { func (user *User) sync(ctx context.Context) error { return safe.RLockRet(func() error { - return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { + return usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { if !user.vault.SyncStatus().HasLabels { user.log.Info("Syncing labels") diff --git a/internal/user/sync_build.go b/internal/user/sync_build.go index 93309c35..a0e03998 100644 --- a/internal/user/sync_build.go +++ b/internal/user/sync_build.go @@ -25,6 +25,7 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/pkg/algo" "github.com/ProtonMail/proton-bridge/v3/pkg/message" "github.com/bradenaw/juniper/xslices" @@ -87,7 +88,7 @@ func newMessageCreatedUpdate( return &imap.MessageCreated{ Message: toIMAPMessage(message), Literal: literal, - MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)), + MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)), ParsedMessage: parsedMessage, }, nil } @@ -106,7 +107,7 @@ func newMessageCreatedFailedUpdate( return &imap.MessageCreated{ Message: toIMAPMessage(message), - MailboxIDs: mapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)), + MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)), Literal: literal, ParsedMessage: parsedMessage, } diff --git a/internal/user/user.go b/internal/user/user.go index 49d29ab7..23a32602 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -40,7 +40,10 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/safe" + "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" + "github.com/ProtonMail/proton-bridge/v3/internal/services/smtp" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/pkg/algo" "github.com/bradenaw/juniper/xslices" @@ -65,7 +68,7 @@ type User struct { vault *vault.User client *proton.Client reporter reporter.Reporter - sendHash *sendRecorder + sendHash *sendrecorder.SendRecorder eventCh *async.QueuedChannel[events.Event] eventLock safe.RWMutex @@ -100,6 +103,8 @@ type User struct { telemetryManager telemetry.Availability // goStatusProgress triggers a check/sending if progress is needed. goStatusProgress func() + + smtpService *smtp.Service } // New returns a new user. @@ -148,7 +153,7 @@ func New( vault: encVault, client: client, reporter: reporter, - sendHash: newSendRecorder(sendEntryExpiry), + sendHash: sendrecorder.NewSendRecorder(sendrecorder.SendEntryExpiry), eventCh: async.NewQueuedChannel[events.Event](0, 0, crashHandler), eventLock: safe.NewRWMutex(), @@ -156,10 +161,10 @@ func New( apiUser: apiUser, apiUserLock: safe.NewRWMutex(), - apiAddrs: groupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }), + apiAddrs: usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }), apiAddrsLock: safe.NewRWMutex(), - apiLabels: groupBy(apiLabels, func(label proton.Label) string { return label.ID }), + apiLabels: usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID }), apiLabelsLock: safe.NewRWMutex(), updateCh: make(map[string]*async.QueuedChannel[imap.Update]), @@ -178,6 +183,8 @@ func New( telemetryManager: telemetryManager, } + user.smtpService = smtp.NewService(user, client, user.sendHash, user.panicHandler, user.reporter) + // Check for status_progress when triggered. user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) { user.SendConfigStatusProgress(ctx) @@ -264,6 +271,9 @@ func New( } }) + // Start SMTP Service + user.smtpService.Start(user.tasks) + return user, nil } @@ -461,7 +471,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) { switch user.vault.AddressMode() { case vault.CombinedMode: - primAddr, err := getAddrIdx(user.apiAddrs, 0) + primAddr, err := usertypes.GetAddrIdx(user.apiAddrs, 0) if err != nil { return nil, fmt.Errorf("failed to get primary address: %w", err) } @@ -485,10 +495,10 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader) } if len(to) == 0 { - return ErrInvalidRecipient + return smtp.ErrInvalidRecipient } - if err := user.sendMail(authID, from, to, r); err != nil { + if err := user.smtpService.SendMail(context.Background(), authID, from, to, r); err != nil { if apiErr := new(proton.APIError); errors.As(err, &apiErr) { logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("failed to send message") } @@ -664,6 +674,12 @@ func (user *User) SendTelemetry(ctx context.Context, data []byte) error { return nil } +func (user *User) WithSMTPData(ctx context.Context, op func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error { + return safe.RLockRet(func() error { + return op(ctx, user.apiAddrs, user.apiUser, user.vault) + }, user.apiUserLock, user.apiAddrsLock, user.eventLock) +} + // initUpdateCh initializes the user's update channels in the given address mode. // It is assumed that user.apiAddrs and user.updateCh are already locked. func (user *User) initUpdateCh(mode vault.AddressMode) { diff --git a/internal/user/keys.go b/internal/usertypes/keys.go similarity index 93% rename from internal/user/keys.go rename to internal/usertypes/keys.go index 949c873b..d267474f 100644 --- a/internal/user/keys.go +++ b/internal/usertypes/keys.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package usertypes import ( "fmt" @@ -25,7 +25,7 @@ import ( "github.com/sirupsen/logrus" ) -func withAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error { +func WithAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error { userKR, err := apiUser.Keys.Unlock(keyPass, nil) if err != nil { return fmt.Errorf("failed to unlock user keys: %w", err) @@ -41,7 +41,7 @@ func withAddrKR(apiUser proton.User, apiAddr proton.Address, keyPass []byte, fn return fn(userKR, addrKR) } -func withAddrKRs(apiUser proton.User, apiAddr map[string]proton.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error { +func WithAddrKRs(apiUser proton.User, apiAddr map[string]proton.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error { userKR, err := apiUser.Keys.Unlock(keyPass, nil) if err != nil { return fmt.Errorf("failed to unlock user keys: %w", err) diff --git a/internal/user/types.go b/internal/usertypes/types.go similarity index 74% rename from internal/user/types.go rename to internal/usertypes/types.go index 67fbdf04..b30df75b 100644 --- a/internal/user/types.go +++ b/internal/usertypes/types.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package usertypes import ( "fmt" @@ -29,10 +29,10 @@ import ( "golang.org/x/exp/slices" ) -// mapTo converts the slice to the given type. +// MapTo converts the slice to the given type. // This is not runtime safe, so make sure the slice is of the correct type! // (This is a workaround for the fact that slices cannot be converted to other types generically). -func mapTo[From, To any](from []From) []To { +func MapTo[From, To any](from []From) []To { to := make([]To, 0, len(from)) for _, from := range from { @@ -47,9 +47,9 @@ func mapTo[From, To any](from []From) []To { return to } -// groupBy returns a map of the given slice grouped by the given key. +// GroupBy returns a map of the given slice grouped by the given key. // Duplicate keys are overwritten. -func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value { +func GroupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value { groups := make(map[Key]Value) for _, item := range items { @@ -59,10 +59,10 @@ func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[ return groups } -// getAddrID returns the address ID for the given email address. -func getAddrID(apiAddrs map[string]proton.Address, email string) (string, error) { +// GetAddrID returns the address ID for the given email address. +func GetAddrID(apiAddrs map[string]proton.Address, email string) (string, error) { for _, addr := range apiAddrs { - if strings.EqualFold(addr.Email, sanitizeEmail(email)) { + if strings.EqualFold(addr.Email, SanitizeEmail(email)) { return addr.ID, nil } } @@ -70,8 +70,8 @@ func getAddrID(apiAddrs map[string]proton.Address, email string) (string, error) return "", fmt.Errorf("address %s not found", email) } -// getAddrIdx returns the address with the given index. -func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, error) { +// GetAddrIdx returns the address with the given index. +func GetAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, error) { sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool { return a.Order < b.Order }) @@ -83,7 +83,7 @@ func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, er return sorted[idx], nil } -func getPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) { +func GetPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) { sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool { return a.Order < b.Order }) @@ -106,6 +106,15 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item { return sorted } -func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler { +func NewProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler { return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler) } + +func SanitizeEmail(email string) string { + splitAt := strings.Split(email, "@") + if len(splitAt) != 2 { + return email + } + + return strings.Split(splitAt[0], "+")[0] + "@" + splitAt[1] +} diff --git a/internal/user/types_test.go b/internal/usertypes/types_test.go similarity index 88% rename from internal/user/types_test.go rename to internal/usertypes/types_test.go index bfc23157..2c974898 100644 --- a/internal/user/types_test.go +++ b/internal/usertypes/types_test.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user +package usertypes import ( "testing" @@ -30,8 +30,8 @@ func TestToType(t *testing.T) { require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"}) // But converting them to the same type makes them equal. - require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"})) + require.Equal(t, []myString{"a", "b", "c"}, MapTo[string, myString]([]string{"a", "b", "c"})) // The conversion can happen in the other direction too. - require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"})) + require.Equal(t, []string{"a", "b", "c"}, MapTo[myString, string]([]myString{"a", "b", "c"})) }