diff --git a/internal/services/smtp/accounts.go b/internal/services/smtp/accounts.go index 854721c5..4bff6965 100644 --- a/internal/services/smtp/accounts.go +++ b/internal/services/smtp/accounts.go @@ -53,17 +53,17 @@ func (s *Accounts) CheckAuth(user string, password []byte) (string, string, erro defer s.accountsLock.RUnlock() for id, service := range s.accounts { - addrID, err := service.user.CheckAuth(user, password) + addrID, err := service.checkAuth(context.Background(), user, password) if err != nil { continue } - service.user.ReportSMTPAuthSuccess(context.Background()) + service.telemetry.ReportSMTPAuthSuccess(context.Background()) return id, addrID, nil } for _, service := range s.accounts { - service.user.ReportSMTPAuthFailed(user) + service.telemetry.ReportSMTPAuthFailed(user) } return "", "", ErrNoSuchUser diff --git a/internal/services/smtp/service.go b/internal/services/smtp/service.go index cdc250ea..cff2f67b 100644 --- a/internal/services/smtp/service.go +++ b/internal/services/smtp/service.go @@ -20,55 +20,94 @@ package smtp import ( "context" "errors" + "fmt" "io" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/go-proton-api" + bridgelogging "github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" - "github.com/ProtonMail/proton-bridge/v3/internal/vault" + "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents" + "github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity" "github.com/ProtonMail/proton-bridge/v3/pkg/cpc" "github.com/sirupsen/logrus" ) -// UserInterface is just wrapper to avoid recursive go module imports. To be removed when the identity service is ready. -type UserInterface interface { - ID() string - CheckAuth(string, []byte) (string, error) - WithSMTPData(context.Context, func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error +type Telemetry interface { + useridentity.Telemetry ReportSMTPAuthSuccess(context.Context) ReportSMTPAuthFailed(username string) } +type AddressMode int + +const ( + AddressModeCombined AddressMode = iota + AddressModeSplit +) + type Service struct { + userID string panicHandler async.PanicHandler cpc *cpc.CPC - user UserInterface client *proton.Client recorder *sendrecorder.SendRecorder log *logrus.Entry reporter reporter.Reporter + + bridgePassProvider useridentity.BridgePassProvider + keyPassProvider useridentity.KeyPassProvider + identityState *useridentity.State + telemetry Telemetry + + eventService userevents.Subscribable + refreshSubscriber *userevents.RefreshChanneledSubscriber + addressSubscriber *userevents.AddressChanneledSubscriber + userSubscriber *userevents.UserChanneledSubscriber + + addressMode AddressMode } func NewService( - user UserInterface, + userID string, client *proton.Client, recorder *sendrecorder.SendRecorder, handler async.PanicHandler, reporter reporter.Reporter, + bridgePassProvider useridentity.BridgePassProvider, + keyPassProvider useridentity.KeyPassProvider, + telemetry Telemetry, + eventService userevents.Subscribable, + mode AddressMode, + identityState *useridentity.State, ) *Service { + subscriberName := fmt.Sprintf("smpt-%v", userID) + return &Service{ panicHandler: handler, - user: user, + userID: userID, cpc: cpc.NewCPC(), recorder: recorder, log: logrus.WithFields(logrus.Fields{ - "user": user.ID(), + "user": userID, "service": "smtp", }), reporter: reporter, client: client, + + bridgePassProvider: bridgePassProvider, + keyPassProvider: keyPassProvider, + telemetry: telemetry, + identityState: identityState, + eventService: eventService, + + refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName), + userSubscriber: userevents.NewUserSubscriber(subscriberName), + addressSubscriber: userevents.NewAddressSubscriber(subscriberName), + + addressMode: mode, } } @@ -83,20 +122,33 @@ func (s *Service) SendMail(ctx context.Context, authID string, from string, to [ return err } +func (s *Service) SetAddressMode(ctx context.Context, mode AddressMode) error { + _, err := s.cpc.Send(ctx, &setAddressModeReq{mode: mode}) + + return err +} + +func (s *Service) checkAuth(ctx context.Context, email string, password []byte) (string, error) { + return cpc.SendTyped[string](ctx, s.cpc, &checkAuthReq{ + email: email, + password: password, + }) +} + func (s *Service) Start(group *async.Group) { s.log.Debug("Starting service") group.Once(func(ctx context.Context) { logging.DoAnnotated(ctx, func(ctx context.Context) { s.run(ctx) }, logging.Labels{ - "user": s.user.ID(), + "user": s.userID, "service": "smtp", }) }) } func (s *Service) UserID() string { - return s.user.ID() + return s.userID } func (s *Service) run(ctx context.Context) { @@ -104,6 +156,15 @@ func (s *Service) run(ctx context.Context) { defer s.log.Debug("Exiting service main loop") defer s.cpc.Close() + subscription := userevents.Subscription{ + User: s.userSubscriber, + Refresh: s.refreshSubscriber, + Address: s.addressSubscriber, + } + + s.eventService.Subscribe(subscription) + defer s.eventService.Unsubscribe(subscription) + for { select { case <-ctx.Done(): @@ -120,9 +181,48 @@ func (s *Service) run(ctx context.Context) { err := s.sendMail(ctx, r) request.Reply(ctx, nil, err) + case *setAddressModeReq: + s.log.Debugf("Set address mode %v", r.mode) + s.addressMode = r.mode + request.Reply(ctx, nil, nil) + + case *checkAuthReq: + s.log.WithField("email", bridgelogging.Sensitive(r.email)).Debug("Checking authentication") + addrID, err := s.identityState.CheckAuth(r.email, r.password, s.bridgePassProvider, s.telemetry) + request.Reply(ctx, addrID, err) + default: s.log.Error("Received unknown request") } + case e, ok := <-s.userSubscriber.OnEventCh(): + if !ok { + continue + } + + s.log.Debug("Handling user event") + e.Consume(func(user proton.User) error { + s.identityState.OnUserEvent(user) + return nil + }) + case e, ok := <-s.refreshSubscriber.OnEventCh(): + if !ok { + continue + } + + s.log.Debug("Handling refresh event") + e.Consume(func(_ proton.RefreshFlag) error { + return s.identityState.OnRefreshEvent(ctx) + }) + case e, ok := <-s.addressSubscriber.OnEventCh(): + if !ok { + continue + } + + s.log.Debug("Handling Address Event") + e.Consume(func(evt []proton.AddressEvent) error { + s.identityState.OnAddressEvents(evt) + return nil + }) } } } @@ -146,3 +246,12 @@ func (s *Service) sendMail(ctx context.Context, req *sendMailReq) error { return nil } + +type setAddressModeReq struct { + mode AddressMode +} + +type checkAuthReq struct { + email string + password []byte +} diff --git a/internal/services/smtp/smtp.go b/internal/services/smtp/smtp.go index 1bbfaa05..a59f7c10 100644 --- a/internal/services/smtp/smtp.go +++ b/internal/services/smtp/smtp.go @@ -38,129 +38,121 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" - "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/pkg/message" "github.com/ProtonMail/proton-bridge/v3/pkg/message/parser" "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) // smtpSendMail sends an email from the given address to the given recipients. func (s *Service) smtpSendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error { - return s.user.WithSMTPData(ctx, func(ctx context.Context, apiAddrs map[string]proton.Address, user proton.User, vault *vault.User) error { - if _, err := usertypes.GetAddrID(apiAddrs, from); err != nil { - return ErrInvalidReturnPath - } + fromAddr, err := s.identityState.GetAddr(from) + if err != nil { + return ErrInvalidReturnPath + } - emails := xslices.Map(maps.Values(apiAddrs), func(addr proton.Address) string { - return addr.Email - }) + emails := xslices.Map(s.identityState.AddressesSorted, func(addr proton.Address) string { + return addr.Email + }) - // Read the message to send. - b, err := io.ReadAll(r) + // Read the message to send. + b, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("failed to read message: %w", err) + } + + // If running a QA build, dump to disk. + if err := debugDumpToDisk(b); err != nil { + s.log.WithError(err).Warn("Failed to dump message to disk") + } + + // Compute the hash of the message (to match it against SMTP messages). + hash, err := sendrecorder.GetMessageHash(b) + if err != nil { + return err + } + + // Check if we already tried to send this message recently. + srID, ok, err := s.recorder.TryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second)) + if err != nil { + return fmt.Errorf("failed to check send hash: %w", err) + } else if !ok { + s.log.Warn("A duplicate message was already sent recently, skipping") + return nil + } + + // If we fail to send this message, we should remove the hash from the send recorder. + defer s.recorder.RemoveOnFail(hash, srID) + + // Create a new message parser from the reader. + parser, err := parser.New(bytes.NewReader(b)) + if err != nil { + return fmt.Errorf("failed to create parser: %w", err) + } + + // If the message contains a sender, use it instead of the one from the return path. + if sender, ok := getMessageSender(parser); ok { + from = sender + } + + // Load the user's mail settings. + settings, err := s.client.GetMailSettings(ctx) + if err != nil { + return fmt.Errorf("failed to get mail settings: %w", err) + } + + return usertypes.WithAddrKR(s.identityState.User, fromAddr, s.keyPassProvider.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error { + // Use the first key for encrypting the message. + addrKR, err := addrKR.FirstKey() if err != nil { - return fmt.Errorf("failed to read message: %w", err) + return fmt.Errorf("failed to get first key: %w", err) } - // If running a QA build, dump to disk. - if err := debugDumpToDisk(b); err != nil { - s.log.WithError(err).Warn("Failed to dump message to disk") - } + // Ensure that there is always a text/html or text/plain body part. This is required by the API. If none + // exists and empty text part will be added. + parser.AttachEmptyTextPartIfNoneExists() - // Compute the hash of the message (to match it against SMTP messages). - hash, err := sendrecorder.GetMessageHash(b) - if err != nil { - return err - } - - // Check if we already tried to send this message recently. - srID, ok, err := s.recorder.TryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second)) - if err != nil { - return fmt.Errorf("failed to check send hash: %w", err) - } else if !ok { - s.log.Warn("A duplicate message was already sent recently, skipping") - return nil - } - - // If we fail to send this message, we should remove the hash from the send recorder. - defer s.recorder.RemoveOnFail(hash, srID) - - // Create a new message parser from the reader. - parser, err := parser.New(bytes.NewReader(b)) - if err != nil { - return fmt.Errorf("failed to create parser: %w", err) - } - - // If the message contains a sender, use it instead of the one from the return path. - if sender, ok := getMessageSender(parser); ok { - from = sender - } - - // Load the user's mail settings. - settings, err := s.client.GetMailSettings(ctx) - if err != nil { - return fmt.Errorf("failed to get mail settings: %w", err) - } - - addrID, err := usertypes.GetAddrID(apiAddrs, from) - if err != nil { - return err - } - - return usertypes.WithAddrKR(user, apiAddrs[addrID], vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error { - // Use the first key for encrypting the message. - addrKR, err := addrKR.FirstKey() + // If we have to attach the public key, do it now. + if settings.AttachPublicKey { + key, err := addrKR.GetKey(0) if err != nil { - return fmt.Errorf("failed to get first key: %w", err) + return fmt.Errorf("failed to get sending key: %w", err) } - // Ensure that there is always a text/html or text/plain body part. This is required by the API. If none - // exists and empty text part will be added. - parser.AttachEmptyTextPartIfNoneExists() - - // If we have to attach the public key, do it now. - if settings.AttachPublicKey { - key, err := addrKR.GetKey(0) - if err != nil { - return fmt.Errorf("failed to get sending key: %w", err) - } - - pubKey, err := key.GetArmoredPublicKey() - if err != nil { - return fmt.Errorf("failed to get public key: %w", err) - } - - parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) - } - - // Parse the message we want to send (after we have attached the public key). - message, err := message.ParseWithParser(parser, false) + pubKey, err := key.GetArmoredPublicKey() if err != nil { - return fmt.Errorf("failed to parse message: %w", err) + return fmt.Errorf("failed to get public key: %w", err) } - // Send the message using the correct key. - sent, err := s.sendWithKey( - ctx, - authID, - vault.AddressMode(), - settings, - userKR, addrKR, - emails, from, to, - message, - ) - if err != nil { - return fmt.Errorf("failed to send message: %w", err) - } + parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) + } - // If the message was successfully sent, we can update the message ID in the record. - s.recorder.SignalMessageSent(hash, srID, sent.ID) + // Parse the message we want to send (after we have attached the public key). + message, err := message.ParseWithParser(parser, false) + if err != nil { + return fmt.Errorf("failed to parse message: %w", err) + } - return nil - }) + // Send the message using the correct key. + sent, err := s.sendWithKey( + ctx, + authID, + s.addressMode, + settings, + userKR, addrKR, + emails, from, to, + message, + ) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + // If the message was successfully sent, we can update the message ID in the record. + s.recorder.SignalMessageSent(hash, srID, sent.ID) + + return nil }) } @@ -168,7 +160,7 @@ func (s *Service) smtpSendMail(ctx context.Context, authID string, from string, func (s *Service) sendWithKey( ctx context.Context, authAddrID string, - addrMode vault.AddressMode, + addrMode AddressMode, settings proton.MailSettings, userKR, addrKR *crypto.KeyRing, emails []string, @@ -249,7 +241,7 @@ func getParentID( ctx context.Context, client *proton.Client, authAddrID string, - addrMode vault.AddressMode, + addrMode AddressMode, references []string, ) (string, error) { var ( @@ -271,7 +263,7 @@ func getParentID( for _, internal := range internal { var addrID string - if addrMode == vault.SplitMode { + if addrMode == AddressModeSplit { addrID = authAddrID } @@ -299,7 +291,7 @@ func getParentID( if parentID == "" && len(external) > 0 { var addrID string - if addrMode == vault.SplitMode { + if addrMode == AddressModeSplit { addrID = authAddrID }