1
0

feat(GODT-2799): Integrate SMTP service with User Identity Service

This commit is contained in:
Leander Beernaert
2023-07-24 17:03:54 +02:00
parent 7be1a8ae8a
commit 0b35f41ffd
3 changed files with 219 additions and 118 deletions

View File

@ -53,17 +53,17 @@ func (s *Accounts) CheckAuth(user string, password []byte) (string, string, erro
defer s.accountsLock.RUnlock() defer s.accountsLock.RUnlock()
for id, service := range s.accounts { for id, service := range s.accounts {
addrID, err := service.user.CheckAuth(user, password) addrID, err := service.checkAuth(context.Background(), user, password)
if err != nil { if err != nil {
continue continue
} }
service.user.ReportSMTPAuthSuccess(context.Background()) service.telemetry.ReportSMTPAuthSuccess(context.Background())
return id, addrID, nil return id, addrID, nil
} }
for _, service := range s.accounts { for _, service := range s.accounts {
service.user.ReportSMTPAuthFailed(user) service.telemetry.ReportSMTPAuthFailed(user)
} }
return "", "", ErrNoSuchUser return "", "", ErrNoSuchUser

View File

@ -20,55 +20,94 @@ package smtp
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"io" "io"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
bridgelogging "github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
"github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc" "github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// UserInterface is just wrapper to avoid recursive go module imports. To be removed when the identity service is ready. type Telemetry interface {
type UserInterface interface { useridentity.Telemetry
ID() string
CheckAuth(string, []byte) (string, error)
WithSMTPData(context.Context, func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error
ReportSMTPAuthSuccess(context.Context) ReportSMTPAuthSuccess(context.Context)
ReportSMTPAuthFailed(username string) ReportSMTPAuthFailed(username string)
} }
type AddressMode int
const (
AddressModeCombined AddressMode = iota
AddressModeSplit
)
type Service struct { type Service struct {
userID string
panicHandler async.PanicHandler panicHandler async.PanicHandler
cpc *cpc.CPC cpc *cpc.CPC
user UserInterface
client *proton.Client client *proton.Client
recorder *sendrecorder.SendRecorder recorder *sendrecorder.SendRecorder
log *logrus.Entry log *logrus.Entry
reporter reporter.Reporter reporter reporter.Reporter
bridgePassProvider useridentity.BridgePassProvider
keyPassProvider useridentity.KeyPassProvider
identityState *useridentity.State
telemetry Telemetry
eventService userevents.Subscribable
refreshSubscriber *userevents.RefreshChanneledSubscriber
addressSubscriber *userevents.AddressChanneledSubscriber
userSubscriber *userevents.UserChanneledSubscriber
addressMode AddressMode
} }
func NewService( func NewService(
user UserInterface, userID string,
client *proton.Client, client *proton.Client,
recorder *sendrecorder.SendRecorder, recorder *sendrecorder.SendRecorder,
handler async.PanicHandler, handler async.PanicHandler,
reporter reporter.Reporter, reporter reporter.Reporter,
bridgePassProvider useridentity.BridgePassProvider,
keyPassProvider useridentity.KeyPassProvider,
telemetry Telemetry,
eventService userevents.Subscribable,
mode AddressMode,
identityState *useridentity.State,
) *Service { ) *Service {
subscriberName := fmt.Sprintf("smpt-%v", userID)
return &Service{ return &Service{
panicHandler: handler, panicHandler: handler,
user: user, userID: userID,
cpc: cpc.NewCPC(), cpc: cpc.NewCPC(),
recorder: recorder, recorder: recorder,
log: logrus.WithFields(logrus.Fields{ log: logrus.WithFields(logrus.Fields{
"user": user.ID(), "user": userID,
"service": "smtp", "service": "smtp",
}), }),
reporter: reporter, reporter: reporter,
client: client, client: client,
bridgePassProvider: bridgePassProvider,
keyPassProvider: keyPassProvider,
telemetry: telemetry,
identityState: identityState,
eventService: eventService,
refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName),
userSubscriber: userevents.NewUserSubscriber(subscriberName),
addressSubscriber: userevents.NewAddressSubscriber(subscriberName),
addressMode: mode,
} }
} }
@ -83,20 +122,33 @@ func (s *Service) SendMail(ctx context.Context, authID string, from string, to [
return err return err
} }
func (s *Service) SetAddressMode(ctx context.Context, mode AddressMode) error {
_, err := s.cpc.Send(ctx, &setAddressModeReq{mode: mode})
return err
}
func (s *Service) checkAuth(ctx context.Context, email string, password []byte) (string, error) {
return cpc.SendTyped[string](ctx, s.cpc, &checkAuthReq{
email: email,
password: password,
})
}
func (s *Service) Start(group *async.Group) { func (s *Service) Start(group *async.Group) {
s.log.Debug("Starting service") s.log.Debug("Starting service")
group.Once(func(ctx context.Context) { group.Once(func(ctx context.Context) {
logging.DoAnnotated(ctx, func(ctx context.Context) { logging.DoAnnotated(ctx, func(ctx context.Context) {
s.run(ctx) s.run(ctx)
}, logging.Labels{ }, logging.Labels{
"user": s.user.ID(), "user": s.userID,
"service": "smtp", "service": "smtp",
}) })
}) })
} }
func (s *Service) UserID() string { func (s *Service) UserID() string {
return s.user.ID() return s.userID
} }
func (s *Service) run(ctx context.Context) { func (s *Service) run(ctx context.Context) {
@ -104,6 +156,15 @@ func (s *Service) run(ctx context.Context) {
defer s.log.Debug("Exiting service main loop") defer s.log.Debug("Exiting service main loop")
defer s.cpc.Close() defer s.cpc.Close()
subscription := userevents.Subscription{
User: s.userSubscriber,
Refresh: s.refreshSubscriber,
Address: s.addressSubscriber,
}
s.eventService.Subscribe(subscription)
defer s.eventService.Unsubscribe(subscription)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -120,9 +181,48 @@ func (s *Service) run(ctx context.Context) {
err := s.sendMail(ctx, r) err := s.sendMail(ctx, r)
request.Reply(ctx, nil, err) request.Reply(ctx, nil, err)
case *setAddressModeReq:
s.log.Debugf("Set address mode %v", r.mode)
s.addressMode = r.mode
request.Reply(ctx, nil, nil)
case *checkAuthReq:
s.log.WithField("email", bridgelogging.Sensitive(r.email)).Debug("Checking authentication")
addrID, err := s.identityState.CheckAuth(r.email, r.password, s.bridgePassProvider, s.telemetry)
request.Reply(ctx, addrID, err)
default: default:
s.log.Error("Received unknown request") s.log.Error("Received unknown request")
} }
case e, ok := <-s.userSubscriber.OnEventCh():
if !ok {
continue
}
s.log.Debug("Handling user event")
e.Consume(func(user proton.User) error {
s.identityState.OnUserEvent(user)
return nil
})
case e, ok := <-s.refreshSubscriber.OnEventCh():
if !ok {
continue
}
s.log.Debug("Handling refresh event")
e.Consume(func(_ proton.RefreshFlag) error {
return s.identityState.OnRefreshEvent(ctx)
})
case e, ok := <-s.addressSubscriber.OnEventCh():
if !ok {
continue
}
s.log.Debug("Handling Address Event")
e.Consume(func(evt []proton.AddressEvent) error {
s.identityState.OnAddressEvents(evt)
return nil
})
} }
} }
} }
@ -146,3 +246,12 @@ func (s *Service) sendMail(ctx context.Context, req *sendMailReq) error {
return nil return nil
} }
type setAddressModeReq struct {
mode AddressMode
}
type checkAuthReq struct {
email string
password []byte
}

View File

@ -38,129 +38,121 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder" "github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/message" "github.com/ProtonMail/proton-bridge/v3/pkg/message"
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser" "github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
"github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
// smtpSendMail sends an email from the given address to the given recipients. // smtpSendMail sends an email from the given address to the given recipients.
func (s *Service) smtpSendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error { func (s *Service) smtpSendMail(ctx context.Context, authID string, from string, to []string, r io.Reader) error {
return s.user.WithSMTPData(ctx, func(ctx context.Context, apiAddrs map[string]proton.Address, user proton.User, vault *vault.User) error { fromAddr, err := s.identityState.GetAddr(from)
if _, err := usertypes.GetAddrID(apiAddrs, from); err != nil { if err != nil {
return ErrInvalidReturnPath return ErrInvalidReturnPath
} }
emails := xslices.Map(maps.Values(apiAddrs), func(addr proton.Address) string { emails := xslices.Map(s.identityState.AddressesSorted, func(addr proton.Address) string {
return addr.Email return addr.Email
}) })
// Read the message to send. // Read the message to send.
b, err := io.ReadAll(r) b, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
// If running a QA build, dump to disk.
if err := debugDumpToDisk(b); err != nil {
s.log.WithError(err).Warn("Failed to dump message to disk")
}
// Compute the hash of the message (to match it against SMTP messages).
hash, err := sendrecorder.GetMessageHash(b)
if err != nil {
return err
}
// Check if we already tried to send this message recently.
srID, ok, err := s.recorder.TryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second))
if err != nil {
return fmt.Errorf("failed to check send hash: %w", err)
} else if !ok {
s.log.Warn("A duplicate message was already sent recently, skipping")
return nil
}
// If we fail to send this message, we should remove the hash from the send recorder.
defer s.recorder.RemoveOnFail(hash, srID)
// Create a new message parser from the reader.
parser, err := parser.New(bytes.NewReader(b))
if err != nil {
return fmt.Errorf("failed to create parser: %w", err)
}
// If the message contains a sender, use it instead of the one from the return path.
if sender, ok := getMessageSender(parser); ok {
from = sender
}
// Load the user's mail settings.
settings, err := s.client.GetMailSettings(ctx)
if err != nil {
return fmt.Errorf("failed to get mail settings: %w", err)
}
return usertypes.WithAddrKR(s.identityState.User, fromAddr, s.keyPassProvider.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
// Use the first key for encrypting the message.
addrKR, err := addrKR.FirstKey()
if err != nil { if err != nil {
return fmt.Errorf("failed to read message: %w", err) return fmt.Errorf("failed to get first key: %w", err)
} }
// If running a QA build, dump to disk. // Ensure that there is always a text/html or text/plain body part. This is required by the API. If none
if err := debugDumpToDisk(b); err != nil { // exists and empty text part will be added.
s.log.WithError(err).Warn("Failed to dump message to disk") parser.AttachEmptyTextPartIfNoneExists()
}
// Compute the hash of the message (to match it against SMTP messages). // If we have to attach the public key, do it now.
hash, err := sendrecorder.GetMessageHash(b) if settings.AttachPublicKey {
if err != nil { key, err := addrKR.GetKey(0)
return err
}
// Check if we already tried to send this message recently.
srID, ok, err := s.recorder.TryInsertWait(ctx, hash, to, time.Now().Add(90*time.Second))
if err != nil {
return fmt.Errorf("failed to check send hash: %w", err)
} else if !ok {
s.log.Warn("A duplicate message was already sent recently, skipping")
return nil
}
// If we fail to send this message, we should remove the hash from the send recorder.
defer s.recorder.RemoveOnFail(hash, srID)
// Create a new message parser from the reader.
parser, err := parser.New(bytes.NewReader(b))
if err != nil {
return fmt.Errorf("failed to create parser: %w", err)
}
// If the message contains a sender, use it instead of the one from the return path.
if sender, ok := getMessageSender(parser); ok {
from = sender
}
// Load the user's mail settings.
settings, err := s.client.GetMailSettings(ctx)
if err != nil {
return fmt.Errorf("failed to get mail settings: %w", err)
}
addrID, err := usertypes.GetAddrID(apiAddrs, from)
if err != nil {
return err
}
return usertypes.WithAddrKR(user, apiAddrs[addrID], vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
// Use the first key for encrypting the message.
addrKR, err := addrKR.FirstKey()
if err != nil { if err != nil {
return fmt.Errorf("failed to get first key: %w", err) return fmt.Errorf("failed to get sending key: %w", err)
} }
// Ensure that there is always a text/html or text/plain body part. This is required by the API. If none pubKey, err := key.GetArmoredPublicKey()
// exists and empty text part will be added.
parser.AttachEmptyTextPartIfNoneExists()
// If we have to attach the public key, do it now.
if settings.AttachPublicKey {
key, err := addrKR.GetKey(0)
if err != nil {
return fmt.Errorf("failed to get sending key: %w", err)
}
pubKey, err := key.GetArmoredPublicKey()
if err != nil {
return fmt.Errorf("failed to get public key: %w", err)
}
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
}
// Parse the message we want to send (after we have attached the public key).
message, err := message.ParseWithParser(parser, false)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse message: %w", err) return fmt.Errorf("failed to get public key: %w", err)
} }
// Send the message using the correct key. parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
sent, err := s.sendWithKey( }
ctx,
authID,
vault.AddressMode(),
settings,
userKR, addrKR,
emails, from, to,
message,
)
if err != nil {
return fmt.Errorf("failed to send message: %w", err)
}
// If the message was successfully sent, we can update the message ID in the record. // Parse the message we want to send (after we have attached the public key).
s.recorder.SignalMessageSent(hash, srID, sent.ID) message, err := message.ParseWithParser(parser, false)
if err != nil {
return fmt.Errorf("failed to parse message: %w", err)
}
return nil // Send the message using the correct key.
}) sent, err := s.sendWithKey(
ctx,
authID,
s.addressMode,
settings,
userKR, addrKR,
emails, from, to,
message,
)
if err != nil {
return fmt.Errorf("failed to send message: %w", err)
}
// If the message was successfully sent, we can update the message ID in the record.
s.recorder.SignalMessageSent(hash, srID, sent.ID)
return nil
}) })
} }
@ -168,7 +160,7 @@ func (s *Service) smtpSendMail(ctx context.Context, authID string, from string,
func (s *Service) sendWithKey( func (s *Service) sendWithKey(
ctx context.Context, ctx context.Context,
authAddrID string, authAddrID string,
addrMode vault.AddressMode, addrMode AddressMode,
settings proton.MailSettings, settings proton.MailSettings,
userKR, addrKR *crypto.KeyRing, userKR, addrKR *crypto.KeyRing,
emails []string, emails []string,
@ -249,7 +241,7 @@ func getParentID(
ctx context.Context, ctx context.Context,
client *proton.Client, client *proton.Client,
authAddrID string, authAddrID string,
addrMode vault.AddressMode, addrMode AddressMode,
references []string, references []string,
) (string, error) { ) (string, error) {
var ( var (
@ -271,7 +263,7 @@ func getParentID(
for _, internal := range internal { for _, internal := range internal {
var addrID string var addrID string
if addrMode == vault.SplitMode { if addrMode == AddressModeSplit {
addrID = authAddrID addrID = authAddrID
} }
@ -299,7 +291,7 @@ func getParentID(
if parentID == "" && len(external) > 0 { if parentID == "" && len(external) > 0 {
var addrID string var addrID string
if addrMode == vault.SplitMode { if addrMode == AddressModeSplit {
addrID = authAddrID addrID = authAddrID
} }