diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 43cd7087..634181d7 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -188,18 +188,6 @@ func New( //nolint:funlen return nil, nil, fmt.Errorf("failed to initialize bridge: %w", err) } - // Start serving IMAP. - if err := bridge.serveIMAP(); err != nil { - logrus.WithError(err).Error("IMAP error") - bridge.PushError(ErrServeIMAP) - } - - // Start serving SMTP. - if err := bridge.serveSMTP(); err != nil { - logrus.WithError(err).Error("SMTP error") - bridge.PushError(ErrServeSMTP) - } - return bridge, eventCh, nil } @@ -383,15 +371,33 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error { }) }) - // Attempt to lazy load users when triggered. + // Attempt to lazy load users when triggered. Only load once. + var loaded bool bridge.goLoad = bridge.tasks.Trigger(func(ctx context.Context) { + if loaded { + logrus.Debug("All users are already loaded, skipping") + return + } + logrus.Info("Loading users") if err := bridge.loadUsers(ctx); err != nil { logrus.WithError(err).Error("Failed to load users") - } else { - bridge.publish(events.AllUsersLoaded{}) + return } + + bridge.publish(events.AllUsersLoaded{}) + + // Once all users have been loaded, start the bridge's IMAP and SMTP servers. + if err := bridge.serveIMAP(); err != nil { + logrus.WithError(err).Error("Failed to start IMAP server") + } + + if err := bridge.serveSMTP(); err != nil { + logrus.WithError(err).Error("Failed to start SMTP server") + } + + loaded = true }) defer bridge.goLoad() diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 20d5a602..77e6f4a1 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -671,6 +671,10 @@ func withBridgeNoMocks( // Wait for bridge to finish loading users. waitForEvent(t, eventCh, events.AllUsersLoaded{}) + // Wait for bridge to start the IMAP server. + waitForEvent(t, eventCh, events.IMAPServerReady{}) + // Wait for bridge to start the SMTP server. + waitForEvent(t, eventCh, events.SMTPServerReady{}) // Set random IMAP and SMTP ports for the tests. require.NoError(t, bridge.SetIMAPPort(0)) diff --git a/internal/bridge/imap.go b/internal/bridge/imap.go index cd3d25ad..358da57f 100644 --- a/internal/bridge/imap.go +++ b/internal/bridge/imap.go @@ -46,28 +46,42 @@ const ( ) func (bridge *Bridge) serveIMAP() error { - if bridge.imapServer == nil { - return fmt.Errorf("no imap server instance running") + if port, err := func() (int, error) { + if bridge.imapServer == nil { + return 0, fmt.Errorf("no IMAP server instance running") + } + + logrus.Info("Starting IMAP server") + + imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig) + if err != nil { + return 0, fmt.Errorf("failed to create IMAP listener: %w", err) + } + + bridge.imapListener = imapListener + + if err := bridge.imapServer.Serve(context.Background(), bridge.imapListener); err != nil { + return 0, fmt.Errorf("failed to serve IMAP: %w", err) + } + + if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil { + return 0, fmt.Errorf("failed to store IMAP port in vault: %w", err) + } + + return getPort(imapListener.Addr()), nil + }(); err != nil { + bridge.publish(events.IMAPServerError{ + Error: err, + }) + + return err + } else { + bridge.publish(events.IMAPServerReady{ + Port: port, + }) + + return nil } - - logrus.Info("Starting IMAP server") - - imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig) - if err != nil { - return fmt.Errorf("failed to create IMAP listener: %w", err) - } - - bridge.imapListener = imapListener - - if err := bridge.imapServer.Serve(context.Background(), bridge.imapListener); err != nil { - return fmt.Errorf("failed to serve IMAP: %w", err) - } - - if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil { - return fmt.Errorf("failed to store IMAP port in vault: %w", err) - } - - return nil } func (bridge *Bridge) restartIMAP() error { @@ -77,6 +91,8 @@ func (bridge *Bridge) restartIMAP() error { if err := bridge.imapListener.Close(); err != nil { return fmt.Errorf("failed to close IMAP listener: %w", err) } + + bridge.publish(events.IMAPServerStopped{}) } return bridge.serveIMAP() @@ -89,6 +105,7 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error { if err := bridge.imapServer.Close(ctx); err != nil { return fmt.Errorf("failed to close IMAP server: %w", err) } + bridge.imapServer = nil } @@ -98,6 +115,8 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error { } } + bridge.publish(events.IMAPServerStopped{}) + return nil } diff --git a/internal/bridge/smtp.go b/internal/bridge/smtp.go index 35642e6f..0ecf269d 100644 --- a/internal/bridge/smtp.go +++ b/internal/bridge/smtp.go @@ -22,6 +22,7 @@ import ( "crypto/tls" "fmt" + "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/constants" @@ -31,26 +32,40 @@ import ( ) func (bridge *Bridge) serveSMTP() error { - logrus.Info("Starting SMTP server") + if port, err := func() (int, error) { + logrus.Info("Starting SMTP server") - smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig) - if err != nil { - return fmt.Errorf("failed to create SMTP listener: %w", err) - } - - bridge.smtpListener = smtpListener - - bridge.tasks.Once(func(context.Context) { - if err := bridge.smtpServer.Serve(smtpListener); err != nil { - logrus.WithError(err).Info("SMTP server stopped") + smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig) + if err != nil { + return 0, fmt.Errorf("failed to create SMTP listener: %w", err) } - }) - if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil { - return fmt.Errorf("failed to store SMTP port in vault: %w", err) + bridge.smtpListener = smtpListener + + bridge.tasks.Once(func(context.Context) { + if err := bridge.smtpServer.Serve(smtpListener); err != nil { + logrus.WithError(err).Info("SMTP server stopped") + } + }) + + if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil { + return 0, fmt.Errorf("failed to store SMTP port in vault: %w", err) + } + + return getPort(smtpListener.Addr()), nil + }(); err != nil { + bridge.publish(events.SMTPServerError{ + Error: err, + }) + + return err + } else { + bridge.publish(events.SMTPServerReady{ + Port: port, + }) + + return nil } - - return nil } func (bridge *Bridge) restartSMTP() error { @@ -60,6 +75,8 @@ func (bridge *Bridge) restartSMTP() error { return fmt.Errorf("failed to close SMTP: %w", err) } + bridge.publish(events.SMTPServerStopped{}) + bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP) return bridge.serveSMTP() @@ -82,6 +99,8 @@ func (bridge *Bridge) closeSMTP() error { logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)") } + bridge.publish(events.SMTPServerStopped{}) + return nil } diff --git a/internal/events/serve.go b/internal/events/serve.go new file mode 100644 index 00000000..ca6c3500 --- /dev/null +++ b/internal/events/serve.go @@ -0,0 +1,59 @@ +package events + +import "fmt" + +type IMAPServerReady struct { + eventBase + + Port int +} + +func (event IMAPServerReady) String() string { + return fmt.Sprintf("IMAPServerReady: Port %d", event.Port) +} + +type IMAPServerStopped struct { + eventBase +} + +func (event IMAPServerStopped) String() string { + return "IMAPServerStopped" +} + +type IMAPServerError struct { + eventBase + + Error error +} + +func (event IMAPServerError) String() string { + return fmt.Sprintf("IMAPServerError: %v", event.Error) +} + +type SMTPServerReady struct { + eventBase + + Port int +} + +func (event SMTPServerReady) String() string { + return fmt.Sprintf("SMTPServerReady: Port %d", event.Port) +} + +type SMTPServerStopped struct { + eventBase +} + +func (event SMTPServerStopped) String() string { + return "SMTPServerStopped" +} + +type SMTPServerError struct { + eventBase + + Error error +} + +func (event SMTPServerError) String() string { + return fmt.Sprintf("SMTPServerError: %v", event.Error) +}