diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 1ac94815..c9529e10 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -173,7 +173,19 @@ func TestBridge_UserAgent(t *testing.T) { func TestBridge_UserAgent_Persistence(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) { + otherPassword := []byte("bar") + otherUser := "foo" + _, _, err := s.CreateUser(otherUser, otherPassword) + require.NoError(t, err) + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(b) + defer imapWaiter.Done() + + require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil))) + + imapWaiter.Wait() + currentUserAgent := b.GetCurrentUserAgent() require.Contains(t, currentUserAgent, vault.DefaultUserAgent) @@ -220,7 +232,19 @@ func TestBridge_UserAgentFromIMAPID(t *testing.T) { calls = append(calls, call) }) + otherPassword := []byte("bar") + otherUser := "foo" + _, _, err := s.CreateUser(otherUser, otherPassword) + require.NoError(t, err) + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(b) + defer imapWaiter.Done() + + require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil))) + + imapWaiter.Wait() + imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) defer func() { _ = imapClient.Logout() }() @@ -592,9 +616,17 @@ func TestBridge_InitGluonDirectory(t *testing.T) { func TestBridge_LoginFailed(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(bridge) + defer imapWaiter.Done() + failCh, done := chToType[events.Event, events.IMAPLoginFailed](bridge.GetEvents(events.IMAPLoginFailed{})) defer done() + _, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + imapWaiter.Wait() + imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort()))) require.NoError(t, err) @@ -622,6 +654,9 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) { configDir, err := b.GetGluonDataDir() require.NoError(t, err) + imapWaiter := waitForIMAPServerReady(b) + defer imapWaiter.Done() + // Login the user. syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{})) defer done() @@ -655,6 +690,8 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) { require.NoError(t, err) require.True(t, info.State == bridge.Connected) + imapWaiter.Wait() + client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) @@ -778,6 +815,7 @@ func withBridgeNoMocks( locator bridge.Locator, vaultKey []byte, tests func(*bridge.Bridge), + waitOnServers bool, ) { // Bridge will disable the proxy by default at startup. mocks.ProxyCtl.EXPECT().DisallowProxy() @@ -828,15 +866,18 @@ func withBridgeNoMocks( // Wait for bridge to finish loading users. waitForEvent(t, eventCh, events.AllUsersLoaded{}) - // Wait for bridge to start the IMAP server. - waitForEvent(t, eventCh, events.IMAPServerReady{}) - // Wait for bridge to start the SMTP server. - waitForEvent(t, eventCh, events.SMTPServerReady{}) // Set random IMAP and SMTP ports for the tests. require.NoError(t, bridge.SetIMAPPort(ctx, 0)) require.NoError(t, bridge.SetSMTPPort(ctx, 0)) + if waitOnServers { + // Wait for bridge to start the IMAP server. + waitForEvent(t, eventCh, events.IMAPServerReady{}) + // Wait for bridge to start the SMTP server. + waitForEvent(t, eventCh, events.SMTPServerReady{}) + } + // Close the bridge when done. defer bridge.Close(ctx) @@ -857,7 +898,24 @@ func withBridge( withMocks(t, func(mocks *bridge.Mocks) { withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) { tests(bridge, mocks) - }) + }, false) + }) +} + +// withBridgeWaitForServers is the same as withBridge, but it will wait until IMAP & SMTP servers are ready. +func withBridgeWaitForServers( + ctx context.Context, + t *testing.T, + apiURL string, + netCtl *proton.NetCtl, + locator bridge.Locator, + vaultKey []byte, + tests func(*bridge.Bridge, *bridge.Mocks), +) { + withMocks(t, func(mocks *bridge.Mocks) { + withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) { + tests(bridge, mocks) + }, true) }) } @@ -910,3 +968,40 @@ func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) { return outCh, done } + +type eventWaiter struct { + evtCh <-chan events.Event + cancel func() +} + +func (e *eventWaiter) Done() { + e.cancel() +} + +func (e *eventWaiter) Wait() { + <-e.evtCh +} + +func waitForSMTPServerReady(b *bridge.Bridge) *eventWaiter { + evtCh, cancel := b.GetEvents(events.SMTPServerReady{}) + return &eventWaiter{ + evtCh: evtCh, + cancel: cancel, + } +} + +func waitForIMAPServerReady(b *bridge.Bridge) *eventWaiter { + evtCh, cancel := b.GetEvents(events.IMAPServerReady{}) + return &eventWaiter{ + evtCh: evtCh, + cancel: cancel, + } +} + +func waitForIMAPServerStopped(b *bridge.Bridge) *eventWaiter { + evtCh, cancel := b.GetEvents(events.IMAPServerStopped{}) + return &eventWaiter{ + evtCh: evtCh, + cancel: cancel, + } +} diff --git a/internal/bridge/send_test.go b/internal/bridge/send_test.go index c3b5e1ca..8015e4ea 100644 --- a/internal/bridge/send_test.go +++ b/internal/bridge/send_test.go @@ -46,12 +46,17 @@ func TestBridge_Send(t *testing.T) { require.NoError(t, err) withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { + smtpWaiter := waitForSMTPServerReady(bridge) + defer smtpWaiter.Done() + senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil) require.NoError(t, err) recipientUserID, err := bridge.LoginFull(ctx, "recipient", password, nil, nil) require.NoError(t, err) + smtpWaiter.Wait() + senderInfo, err := bridge.GetUserInfo(senderUserID) require.NoError(t, err) @@ -135,7 +140,7 @@ func TestBridge_SendDraftFlags(t *testing.T) { }) // Start the bridge. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { + withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { // Get the sender user info. userInfo, err := bridge.QueryUserInfo(username) require.NoError(t, err) @@ -245,7 +250,7 @@ func TestBridge_SendInvite(t *testing.T) { }) // Start the bridge. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { + withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { // Get the sender user info. userInfo, err := bridge.QueryUserInfo(username) require.NoError(t, err) @@ -401,6 +406,9 @@ SGVsbG8gd29ybGQK require.NoError(t, err) withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { + smtpWaiter := waitForSMTPServerReady(bridge) + defer smtpWaiter.Done() + senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil) require.NoError(t, err) @@ -420,6 +428,8 @@ SGVsbG8gd29ybGQK messageMultipartWithoutTextWithTextAttachment, } + smtpWaiter.Wait() + for _, m := range messages { // Dial the server. client, err := smtp.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetSMTPPort()))) diff --git a/internal/bridge/server_manager.go b/internal/bridge/server_manager.go index f1aecf4a..acca55c8 100644 --- a/internal/bridge/server_manager.go +++ b/internal/bridge/server_manager.go @@ -43,6 +43,8 @@ type ServerManager struct { smtpServer *smtp.Server smtpListener net.Listener + + loadedUserCount int } func newServerManager() *ServerManager { @@ -145,19 +147,17 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) { case evt := <-eventCh: switch evt.(type) { case events.ConnStatusDown: - // Handle connect down. + logrus.Info("Server Manager, network down stopping listeners") + if err := sm.closeSMTPServer(bridge); err != nil { + logrus.WithError(err).Error("Failed to close SMTP server") + } + if err := sm.stopIMAPListener(bridge); err != nil { + logrus.WithError(err) + } case events.ConnStatusUp: - // Handle connect up. - - case events.AllUsersLoaded: - if err := sm.serveIMAP(ctx, bridge); err != nil { - logrus.WithError(err).Error("Failed to start IMAP server") - } - - if err := sm.serveSMTP(bridge); err != nil { - logrus.WithError(err).Error("Failed to start SMTP server") - } + logrus.Info("Server Manager, network up starting listeners") + sm.handleLoadedUserCountChange(ctx, bridge) } case request, ok := <-sm.requests.ReceiveCh(): @@ -182,10 +182,18 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) { case *smRequestAddIMAPUser: err := sm.handleAddIMAPUser(ctx, r.user) request.SendReply(ctx, nil, err) + if err == nil { + sm.loadedUserCount++ + sm.handleLoadedUserCountChange(ctx, bridge) + } case *smRequestRemoveIMAPUser: err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData) request.SendReply(ctx, nil, err) + if err == nil { + sm.loadedUserCount-- + sm.handleLoadedUserCountChange(ctx, bridge) + } case *smRequestSetGluonDir: err := sm.handleSetGluonDir(ctx, bridge, r.dir) @@ -203,6 +211,35 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) { } } +func (sm *ServerManager) handleLoadedUserCountChange(ctx context.Context, bridge *Bridge) { + logrus.Infof("Validating Listener State %v", sm.loadedUserCount) + if sm.shouldStartServers() { + if sm.imapListener == nil { + if err := sm.serveIMAP(ctx, bridge); err != nil { + logrus.WithError(err).Error("Failed to start IMAP server") + } + } + + if sm.smtpListener == nil { + if err := sm.restartSMTP(bridge); err != nil { + logrus.WithError(err).Error("Failed to start SMTP server") + } + } + } else { + if sm.imapListener != nil { + if err := sm.stopIMAPListener(bridge); err != nil { + logrus.WithError(err).Error("Failed to stop IMAP server") + } + } + + if sm.smtpListener != nil { + if err := sm.closeSMTPServer(bridge); err != nil { + logrus.WithError(err).Error("Failed to stop SMTP server") + } + } + } +} + func (sm *ServerManager) handleClose(ctx context.Context, bridge *Bridge) { // Close the IMAP server. if err := sm.closeIMAPServer(ctx, bridge); err != nil { @@ -358,27 +395,45 @@ func (sm *ServerManager) closeSMTPServer(bridge *Bridge) error { // This is because smtpServer.Serve() is called in a separate goroutine and might be executed // after we've already closed the server. However, go-smtp has a bug; it blocks on the listener // even after the server has been closed. So we close the listener ourselves to unblock it. - logrus.Info("Closing SMTP server") if sm.smtpListener != nil { + logrus.Info("Closing SMTP Listener") if err := sm.smtpListener.Close(); err != nil { return fmt.Errorf("failed to close SMTP listener: %w", err) } + + sm.smtpListener = nil } - if err := sm.smtpServer.Close(); err != nil { - logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)") - } + if sm.smtpServer != nil { + logrus.Info("Closing SMTP server") + if err := sm.smtpServer.Close(); err != nil { + logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)") + } - bridge.publish(events.SMTPServerStopped{}) + sm.smtpServer = nil + + bridge.publish(events.SMTPServerStopped{}) + } return nil } func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) error { - logrus.Info("Closing IMAP server") + if sm.imapListener != nil { + logrus.Info("Closing IMAP Listener") + + if err := sm.imapListener.Close(); err != nil { + return fmt.Errorf("failed to close IMAP listener: %w", err) + } + + sm.imapListener = nil + + bridge.publish(events.IMAPServerStopped{}) + } if sm.imapServer != nil { + logrus.Info("Closing IMAP server") if err := sm.imapServer.Close(ctx); err != nil { return fmt.Errorf("failed to close IMAP server: %w", err) } @@ -386,16 +441,6 @@ func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) er sm.imapServer = nil } - if sm.imapListener != nil { - if err := sm.imapListener.Close(); err != nil { - return fmt.Errorf("failed to close IMAP listener: %w", err) - } - - sm.imapListener = nil - } - - bridge.publish(events.IMAPServerStopped{}) - return nil } @@ -412,7 +457,11 @@ func (sm *ServerManager) restartIMAP(ctx context.Context, bridge *Bridge) error bridge.publish(events.IMAPServerStopped{}) } - return sm.serveIMAP(ctx, bridge) + if sm.shouldStartServers() { + return sm.serveIMAP(ctx, bridge) + } + + return nil } func (sm *ServerManager) restartSMTP(bridge *Bridge) error { @@ -426,7 +475,11 @@ func (sm *ServerManager) restartSMTP(bridge *Bridge) error { sm.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP) - return sm.serveSMTP(bridge) + if sm.shouldStartServers() { + return sm.serveSMTP(bridge) + } + + return nil } func (sm *ServerManager) serveSMTP(bridge *Bridge) error { @@ -515,6 +568,21 @@ func (sm *ServerManager) serveIMAP(ctx context.Context, bridge *Bridge) error { return nil } +func (sm *ServerManager) stopIMAPListener(bridge *Bridge) error { + logrus.Info("Stopping IMAP listener") + if sm.imapListener != nil { + if err := sm.imapListener.Close(); err != nil { + return err + } + + sm.imapListener = nil + + bridge.publish(events.IMAPServerStopped{}) + } + + return nil +} + func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge, newGluonDir string) error { return safe.RLockRet(func() error { currentGluonDir := bridge.GetGluonCacheDir() @@ -527,6 +595,8 @@ func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge, return fmt.Errorf("failed to close IMAP: %w", err) } + sm.loadedUserCount = 0 + if err := bridge.moveGluonCacheDir(currentGluonDir, newGluonDir); err != nil { logrus.WithError(err).Error("failed to move GluonCacheDir") @@ -560,15 +630,17 @@ func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge, } sm.imapServer = imapServer - for _, bridgeUser := range bridge.users { if err := sm.handleAddIMAPUser(ctx, bridgeUser); err != nil { return fmt.Errorf("failed to add users to new IMAP server: %w", err) } + sm.loadedUserCount++ } - if err := sm.serveIMAP(ctx, bridge); err != nil { - return fmt.Errorf("failed to serve IMAP: %w", err) + if sm.shouldStartServers() { + if err := sm.serveIMAP(ctx, bridge); err != nil { + return fmt.Errorf("failed to serve IMAP: %w", err) + } } return nil @@ -591,6 +663,10 @@ func (sm *ServerManager) handleRemoveGluonUser(ctx context.Context, userID strin return sm.imapServer.RemoveUser(ctx, userID, true) } +func (sm *ServerManager) shouldStartServers() bool { + return sm.loadedUserCount >= 1 +} + type smRequestClose struct{} type smRequestRestartIMAP struct{} diff --git a/internal/bridge/server_manager_test.go b/internal/bridge/server_manager_test.go new file mode 100644 index 00000000..23a94a1d --- /dev/null +++ b/internal/bridge/server_manager_test.go @@ -0,0 +1,171 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package bridge_test + +import ( + "context" + "fmt" + "testing" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server" + "github.com/ProtonMail/proton-bridge/v3/internal/bridge" + "github.com/ProtonMail/proton-bridge/v3/internal/constants" + "github.com/ProtonMail/proton-bridge/v3/internal/events" + "github.com/emersion/go-imap/client" + "github.com/stretchr/testify/require" +) + +func TestServerManager_NoLoadedUsersNoServers(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + _, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort())) + require.Error(t, err) + }) + }) +} + +func TestServerManager_ServersStartAfterFirstConnectedUser(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(bridge) + defer imapWaiter.Done() + + smtpWaiter := waitForSMTPServerReady(bridge) + defer smtpWaiter.Done() + + _, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + imapWaiter.Wait() + smtpWaiter.Wait() + }) + }) +} + +func TestServerManager_ServersStopsAfterUserLogsOut(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(bridge) + defer imapWaiter.Done() + + smtpWaiter := waitForSMTPServerReady(bridge) + defer smtpWaiter.Done() + + userID, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + imapWaiter.Wait() + smtpWaiter.Wait() + + imapWaiterStopped := waitForIMAPServerStopped(bridge) + defer imapWaiterStopped.Done() + + require.NoError(t, bridge.LogoutUser(ctx, userID)) + + imapWaiterStopped.Wait() + }) + }) +} + +func TestServerManager_ServersDoNotStopWhenThereIsStillOneActiveUser(t *testing.T) { + otherPassword := []byte("bar") + otherUser := "foo" + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + _, _, err := s.CreateUser(otherUser, otherPassword) + require.NoError(t, err) + + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(bridge) + defer imapWaiter.Done() + + smtpWaiter := waitForSMTPServerReady(bridge) + defer smtpWaiter.Done() + + _, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + userIDOther, err := bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil) + require.NoError(t, err) + + imapWaiter.Wait() + smtpWaiter.Wait() + + evtCh, cancel := bridge.GetEvents(events.UserDeauth{}) + defer cancel() + + require.NoError(t, s.RevokeUser(userIDOther)) + + waitForEvent(t, evtCh, events.UserDeauth{}) + + imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort())) + require.NoError(t, err) + require.NoError(t, imapClient.Logout()) + }) + }) +} + +func TestServerManager_ServersStartIfAtLeastOneUserIsLoggedIn(t *testing.T) { + otherPassword := []byte("bar") + otherUser := "foo" + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + userIDOther, _, err := s.CreateUser(otherUser, otherPassword) + require.NoError(t, err) + + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + _, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + _, err = bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil) + require.NoError(t, err) + }) + + require.NoError(t, s.RevokeUser(userIDOther)) + + withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort())) + require.NoError(t, err) + require.NoError(t, imapClient.Logout()) + }) + }) +} + +func TestServerManager_NetworkLossStopsServers(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(bridge) + defer imapWaiter.Done() + + imapWaiterStop := waitForIMAPServerStopped(bridge) + defer imapWaiterStop.Done() + + _, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + imapWaiter.Wait() + + netCtl.Disable() + + imapWaiterStop.Wait() + + netCtl.Enable() + + imapWaiter.Wait() + }) + }) +} diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index a83dac42..c880f7b3 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -112,15 +112,6 @@ func TestBridge_Sync(t *testing.T) { info, err := b.GetUserInfo(userID) require.NoError(t, err) require.True(t, info.State == bridge.Connected) - - client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) - require.NoError(t, err) - require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) - defer func() { _ = client.Logout() }() - - status, err := client.Select(`Folders/folder`, false) - require.NoError(t, err) - require.Less(t, status.Messages, uint32(numMsg)) } // Remove the network limit, allowing the sync to finish. @@ -273,15 +264,6 @@ func TestBridge_SyncWithOngoingEvents(t *testing.T) { info, err := b.GetUserInfo(userID) require.NoError(t, err) require.True(t, info.State == bridge.Connected) - - client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) - require.NoError(t, err) - require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) - defer func() { _ = client.Logout() }() - - status, err := client.Select(`Folders/folder`, false) - require.NoError(t, err) - require.Less(t, status.Messages, uint32(numMsg)) } // Create a new mailbox and move that last 1/3 of the messages into it to simulate user diff --git a/internal/bridge/user_event_test.go b/internal/bridge/user_event_test.go index c5356a07..caae0134 100644 --- a/internal/bridge/user_event_test.go +++ b/internal/bridge/user_event_test.go @@ -141,6 +141,9 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex }) withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + smtpWaiter := waitForSMTPServerReady(bridge) + defer smtpWaiter.Done() + userLoginAndSync(ctx, t, bridge, "user", password) var messageIDs []string @@ -176,6 +179,8 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex userFeedback(t, ctx, bridge, badUserID) + smtpWaiter.Wait() + userContinueEventProcess(ctx, t, s, bridge) }) }) @@ -194,6 +199,9 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) { }) withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + smtpWaiter := waitForSMTPServerReady(bridge) + defer smtpWaiter.Done() + userLoginAndSync(ctx, t, bridge, "user", password) var messageIDs []string @@ -217,6 +225,7 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) { require.NoError(t, c.DeleteMessage(ctx, messageIDs...)) }) + smtpWaiter.Wait() userContinueEventProcess(ctx, t, s, bridge) }) }) @@ -412,6 +421,17 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) { }) withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + var count int32 + // The first 10 times bridge attempts to sync any of the messages, drop the connection. + s.AddStatusHook(func(req *http.Request) (int, bool) { + if strings.Contains(req.URL.Path, "/mail/v4/messages") { + if atomic.AddInt32(&count, 1) < 10 { + dropListener.DropAll() + } + } + + return 0, false + }) userLoginAndSync(ctx, t, bridge, "user", password) mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).AnyTimes() @@ -421,19 +441,6 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) { createNumMessages(ctx, t, c, addrID, proton.InboxLabel, 10) }) - var count int - - // The first 10 times bridge attempts to sync any of the messages, drop the connection. - s.AddStatusHook(func(req *http.Request) (int, bool) { - if strings.Contains(req.URL.Path, "/mail/v4/messages") { - if count++; count < 10 { - dropListener.DropAll() - } - } - - return 0, false - }) - info, err := bridge.QueryUserInfo("user") require.NoError(t, err) @@ -771,11 +778,16 @@ func TestBridge_User_CreateDisabledAddress(t *testing.T) { func TestBridge_User_HandleParentLabelRename(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + imapWaiter := waitForIMAPServerReady(bridge) + defer imapWaiter.Done() + require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) info, err := bridge.QueryUserInfo(username) require.NoError(t, err) + imapWaiter.Wait() + client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort())) require.NoError(t, err) require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 5fad8fee..b37a0a99 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -167,6 +167,8 @@ type testCtx struct { // This slice contains the dummy listeners that are intended to block network ports. dummyListeners []net.Listener + + imapServerStarted bool } type imapClient struct { diff --git a/tests/user_test.go b/tests/user_test.go index c06f0bd7..f1dffda8 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -28,6 +28,7 @@ import ( "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/bridge" + "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/pkg/algo" "github.com/bradenaw/juniper/iterator" @@ -331,10 +332,20 @@ func (s *scenario) drafAtIndexWasMovedToTrashForAddressOfAccount(draftIndex int, } func (s *scenario) userLogsInWithUsernameAndPassword(username, password string) error { + evtCh, cancel := s.t.bridge.GetEvents(events.SMTPServerReady{}) + defer cancel() + userID, err := s.t.bridge.LoginFull(context.Background(), username, []byte(password), nil, nil) if err != nil { s.t.pushError(err) } else { + // We need to wait for server to be up or we won't be able to connect. It should only happen once to avoid + // blocking on multiple Logins. + if !s.t.imapServerStarted { + <-evtCh + s.t.imapServerStarted = true + } + if userID != s.t.getUserByName(username).getUserID() { return errors.New("user ID mismatch") }