diff --git a/internal/bridge/smtp.go b/internal/bridge/smtp.go index 09776aa1..7d9eeaf2 100644 --- a/internal/bridge/smtp.go +++ b/internal/bridge/smtp.go @@ -22,23 +22,12 @@ import ( "crypto/tls" "github.com/ProtonMail/proton-bridge/v3/internal/identifier" - "github.com/ProtonMail/proton-bridge/v3/internal/user" ) func (bridge *Bridge) restartSMTP(ctx context.Context) error { return bridge.serverManager.RestartSMTP(ctx) } -// addSMTPUser connects the given user to the smtp server. -func (bridge *Bridge) addSMTPUser(ctx context.Context, user *user.User) error { - return bridge.serverManager.AddSMTPAccount(ctx, user.GetSMTPService()) -} - -// removeSMTPUser disconnects the given user from the smtp server. -func (bridge *Bridge) removeSMTPUser(ctx context.Context, user *user.User) error { - return bridge.serverManager.RemoveSMTPAccount(ctx, user.GetSMTPService()) -} - type bridgeSMTPSettings struct { b *Bridge } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index e6c79046..011b5cd4 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -522,16 +522,13 @@ func (bridge *Bridge) addUserWithVault( statsPath, bridge, bridge.serverManager, + bridge.serverManager, &bridgeEventSubscription{b: bridge}, ) if err != nil { return fmt.Errorf("failed to create user: %w", err) } - if err := bridge.addSMTPUser(ctx, user); err != nil { - return fmt.Errorf("failed to add SMTP user: %w", err) - } - // Handle events coming from the user before forwarding them to the bridge. // For example, if the user's addresses change, we need to update them in gluon. bridge.tasks.Once(func(ctx context.Context) { @@ -593,10 +590,6 @@ func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI, "withData": withData, }).Debug("Logging out user") - if err := bridge.removeSMTPUser(ctx, user); err != nil { - logrus.WithError(err).Error("Failed to remove SMTP user") - } - if err := user.Logout(ctx, withAPI); err != nil { logrus.WithError(err).Error("Failed to logout user") } diff --git a/internal/services/smtp/server_manager.go b/internal/services/smtp/server_manager.go new file mode 100644 index 00000000..8bdcc4cc --- /dev/null +++ b/internal/services/smtp/server_manager.go @@ -0,0 +1,41 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package smtp + +import "context" + +type ServerManager interface { + AddSMTPAccount(ctx context.Context, service *Service) error + RemoveSMTPAccount(ctx context.Context, service *Service) error +} + +type NullServerManager struct{} + +func NewNullServerManager() *NullServerManager { + return &NullServerManager{} +} + +func (n NullServerManager) AddSMTPAccount(_ context.Context, _ *Service) error { + // Does nothing. + return nil +} + +func (n NullServerManager) RemoveSMTPAccount(_ context.Context, _ *Service) error { + // Does nothing. + return nil +} diff --git a/internal/services/smtp/service.go b/internal/services/smtp/service.go index 29a571da..10ba8a8a 100644 --- a/internal/services/smtp/service.go +++ b/internal/services/smtp/service.go @@ -62,7 +62,8 @@ type Service struct { addressSubscriber *userevents.AddressChanneledSubscriber userSubscriber *userevents.UserChanneledSubscriber - addressMode usertypes.AddressMode + addressMode usertypes.AddressMode + serverManager ServerManager } func NewService( @@ -77,6 +78,7 @@ func NewService( eventService userevents.Subscribable, mode usertypes.AddressMode, identityState *useridentity.State, + serverManager ServerManager, ) *Service { subscriberName := fmt.Sprintf("smpt-%v", userID) @@ -102,7 +104,8 @@ func NewService( userSubscriber: userevents.NewUserSubscriber(subscriberName), addressSubscriber: userevents.NewAddressSubscriber(subscriberName), - addressMode: mode, + addressMode: mode, + serverManager: serverManager, } } @@ -129,6 +132,12 @@ func (s *Service) Resync(ctx context.Context) error { return err } +func (s *Service) OnLogout(ctx context.Context) error { + _, err := s.cpc.Send(ctx, &onLogoutReq{}) + + return err +} + func (s *Service) checkAuth(ctx context.Context, email string, password []byte) (string, error) { return cpc.SendTyped[string](ctx, s.cpc, &checkAuthReq{ email: email, @@ -136,8 +145,13 @@ func (s *Service) checkAuth(ctx context.Context, email string, password []byte) }) } -func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGroup) { +func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGroup) error { s.log.Debug("Starting service") + + if err := s.serverManager.AddSMTPAccount(ctx, s); err != nil { + return fmt.Errorf("failed to add SMTP account to server: %w", err) + } + group.Go(ctx, s.userID, "smtp-service", func(ctx context.Context) { logging.DoAnnotated(ctx, func(ctx context.Context) { s.run(ctx) @@ -146,6 +160,8 @@ func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGr "service": "smtp", }) }) + + return nil } func (s *Service) UserID() string { @@ -196,6 +212,10 @@ func (s *Service) run(ctx context.Context) { err := s.identityState.OnRefreshEvent(ctx) request.Reply(ctx, nil, err) + case *onLogoutReq: + err := s.serverManager.RemoveSMTPAccount(ctx, s) + request.Reply(ctx, nil, err) + default: s.log.Error("Received unknown request") } @@ -262,3 +282,5 @@ type checkAuthReq struct { } type resyncReq struct{} + +type onLogoutReq struct{} diff --git a/internal/user/user.go b/internal/user/user.go index a4a5f1fa..8e68057f 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -97,7 +97,8 @@ func New( maxSyncMemory uint64, statsDir string, telemetryManager telemetry.Availability, - serverManager imapservice.IMAPServerManager, + imapServerManager imapservice.IMAPServerManager, + smtpServerManager smtp.ServerManager, eventSubscription events.Subscription, ) (*User, error) { user, err := newImpl( @@ -111,7 +112,8 @@ func New( maxSyncMemory, statsDir, telemetryManager, - serverManager, + imapServerManager, + smtpServerManager, eventSubscription, ) if err != nil { @@ -138,7 +140,8 @@ func newImpl( maxSyncMemory uint64, statsDir string, telemetryManager telemetry.Availability, - serverManager imapservice.IMAPServerManager, + imapServerManager imapservice.IMAPServerManager, + smtpServerManager smtp.ServerManager, eventSubscription events.Subscription, ) (*User, error) { logrus.WithField("userID", apiUser.ID).Info("Creating new user") @@ -223,6 +226,7 @@ func newImpl( user.eventService, addressMode, identityState.Clone(), + smtpServerManager, ) user.imapService = imapservice.NewService( @@ -231,7 +235,7 @@ func newImpl( user, encVault, user.eventService, - serverManager, + imapServerManager, user, encVault, encVault, @@ -282,7 +286,9 @@ func newImpl( user.identityService.Start(ctx, user.serviceGroup) // Start SMTP Service - user.smtpService.Start(ctx, user.serviceGroup) + if err := user.smtpService.Start(ctx, user.serviceGroup); err != nil { + return user, fmt.Errorf("failed to start smtp service: %w", err) + } // Start IMAP Service if err := user.imapService.Start(ctx, user.serviceGroup); err != nil { @@ -548,8 +554,12 @@ func (user *User) Logout(ctx context.Context, withAPI bool) error { user.log.Debug("Canceling ongoing tasks") + if err := user.smtpService.OnLogout(ctx); err != nil { + return fmt.Errorf("failed to remove user from smtp server: %w", err) + } + if err := user.imapService.OnLogout(ctx); err != nil { - return fmt.Errorf("failed to remove user from server: %w", err) + return fmt.Errorf("failed to remove user from imap server: %w", err) } // Stop Services diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 9af73bac..5055b939 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -28,6 +28,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/certs" "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice" + "github.com/ProtonMail/proton-bridge/v3/internal/services/smtp" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry/mocks" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/tests" @@ -155,7 +156,8 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma manager.EXPECT().IsTelemetryAvailable(context.Background()).AnyTimes() nullEventSubscription := events.NewNullSubscription() - nullServerManager := imapservice.NewNullIMAPServerManager() + nullIMAPServerManager := imapservice.NewNullIMAPServerManager() + nullSMTPServerManager := smtp.NewNullServerManager() user, err := New( ctx, @@ -168,7 +170,8 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma vault.DefaultMaxSyncMemory, tb.TempDir(), manager, - nullServerManager, + nullIMAPServerManager, + nullSMTPServerManager, nullEventSubscription, ) require.NoError(tb, err)