feat(GODT-2585): Only Start IMAP/SMTP once one user is loaded

Update ServerManager to follow the new expected behavior. The servers
will only be started when one user is active.

If all users are logged out or removed from the system, the servers will
stop.

If the network goes down, the servers will stop and resume once network
has been restored.
This commit is contained in:
Leander Beernaert
2023-05-10 10:05:47 +02:00
parent fb4a0e77af
commit 4b5edd62d0
8 changed files with 429 additions and 70 deletions

View File

@ -173,7 +173,19 @@ func TestBridge_UserAgent(t *testing.T) {
func TestBridge_UserAgent_Persistence(t *testing.T) { func TestBridge_UserAgent_Persistence(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
otherPassword := []byte("bar")
otherUser := "foo"
_, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(b)
defer imapWaiter.Done()
require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil)))
imapWaiter.Wait()
currentUserAgent := b.GetCurrentUserAgent() currentUserAgent := b.GetCurrentUserAgent()
require.Contains(t, currentUserAgent, vault.DefaultUserAgent) require.Contains(t, currentUserAgent, vault.DefaultUserAgent)
@ -220,7 +232,19 @@ func TestBridge_UserAgentFromIMAPID(t *testing.T) {
calls = append(calls, call) calls = append(calls, call)
}) })
otherPassword := []byte("bar")
otherUser := "foo"
_, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(b)
defer imapWaiter.Done()
require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil)))
imapWaiter.Wait()
imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = imapClient.Logout() }() defer func() { _ = imapClient.Logout() }()
@ -592,9 +616,17 @@ func TestBridge_InitGluonDirectory(t *testing.T) {
func TestBridge_LoginFailed(t *testing.T) { func TestBridge_LoginFailed(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
failCh, done := chToType[events.Event, events.IMAPLoginFailed](bridge.GetEvents(events.IMAPLoginFailed{})) failCh, done := chToType[events.Event, events.IMAPLoginFailed](bridge.GetEvents(events.IMAPLoginFailed{}))
defer done() defer done()
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort()))) imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err) require.NoError(t, err)
@ -622,6 +654,9 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) {
configDir, err := b.GetGluonDataDir() configDir, err := b.GetGluonDataDir()
require.NoError(t, err) require.NoError(t, err)
imapWaiter := waitForIMAPServerReady(b)
defer imapWaiter.Done()
// Login the user. // Login the user.
syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{})) syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{}))
defer done() defer done()
@ -655,6 +690,8 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, info.State == bridge.Connected) require.True(t, info.State == bridge.Connected)
imapWaiter.Wait()
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
@ -778,6 +815,7 @@ func withBridgeNoMocks(
locator bridge.Locator, locator bridge.Locator,
vaultKey []byte, vaultKey []byte,
tests func(*bridge.Bridge), tests func(*bridge.Bridge),
waitOnServers bool,
) { ) {
// Bridge will disable the proxy by default at startup. // Bridge will disable the proxy by default at startup.
mocks.ProxyCtl.EXPECT().DisallowProxy() mocks.ProxyCtl.EXPECT().DisallowProxy()
@ -828,15 +866,18 @@ func withBridgeNoMocks(
// Wait for bridge to finish loading users. // Wait for bridge to finish loading users.
waitForEvent(t, eventCh, events.AllUsersLoaded{}) waitForEvent(t, eventCh, events.AllUsersLoaded{})
// Wait for bridge to start the IMAP server.
waitForEvent(t, eventCh, events.IMAPServerReady{})
// Wait for bridge to start the SMTP server.
waitForEvent(t, eventCh, events.SMTPServerReady{})
// Set random IMAP and SMTP ports for the tests. // Set random IMAP and SMTP ports for the tests.
require.NoError(t, bridge.SetIMAPPort(ctx, 0)) require.NoError(t, bridge.SetIMAPPort(ctx, 0))
require.NoError(t, bridge.SetSMTPPort(ctx, 0)) require.NoError(t, bridge.SetSMTPPort(ctx, 0))
if waitOnServers {
// Wait for bridge to start the IMAP server.
waitForEvent(t, eventCh, events.IMAPServerReady{})
// Wait for bridge to start the SMTP server.
waitForEvent(t, eventCh, events.SMTPServerReady{})
}
// Close the bridge when done. // Close the bridge when done.
defer bridge.Close(ctx) defer bridge.Close(ctx)
@ -857,7 +898,24 @@ func withBridge(
withMocks(t, func(mocks *bridge.Mocks) { withMocks(t, func(mocks *bridge.Mocks) {
withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) { withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) {
tests(bridge, mocks) tests(bridge, mocks)
}) }, false)
})
}
// withBridgeWaitForServers is the same as withBridge, but it will wait until IMAP & SMTP servers are ready.
func withBridgeWaitForServers(
ctx context.Context,
t *testing.T,
apiURL string,
netCtl *proton.NetCtl,
locator bridge.Locator,
vaultKey []byte,
tests func(*bridge.Bridge, *bridge.Mocks),
) {
withMocks(t, func(mocks *bridge.Mocks) {
withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) {
tests(bridge, mocks)
}, true)
}) })
} }
@ -910,3 +968,40 @@ func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) {
return outCh, done return outCh, done
} }
type eventWaiter struct {
evtCh <-chan events.Event
cancel func()
}
func (e *eventWaiter) Done() {
e.cancel()
}
func (e *eventWaiter) Wait() {
<-e.evtCh
}
func waitForSMTPServerReady(b *bridge.Bridge) *eventWaiter {
evtCh, cancel := b.GetEvents(events.SMTPServerReady{})
return &eventWaiter{
evtCh: evtCh,
cancel: cancel,
}
}
func waitForIMAPServerReady(b *bridge.Bridge) *eventWaiter {
evtCh, cancel := b.GetEvents(events.IMAPServerReady{})
return &eventWaiter{
evtCh: evtCh,
cancel: cancel,
}
}
func waitForIMAPServerStopped(b *bridge.Bridge) *eventWaiter {
evtCh, cancel := b.GetEvents(events.IMAPServerStopped{})
return &eventWaiter{
evtCh: evtCh,
cancel: cancel,
}
}

View File

@ -46,12 +46,17 @@ func TestBridge_Send(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil) senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
recipientUserID, err := bridge.LoginFull(ctx, "recipient", password, nil, nil) recipientUserID, err := bridge.LoginFull(ctx, "recipient", password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
smtpWaiter.Wait()
senderInfo, err := bridge.GetUserInfo(senderUserID) senderInfo, err := bridge.GetUserInfo(senderUserID)
require.NoError(t, err) require.NoError(t, err)
@ -135,7 +140,7 @@ func TestBridge_SendDraftFlags(t *testing.T) {
}) })
// Start the bridge. // Start the bridge.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
// Get the sender user info. // Get the sender user info.
userInfo, err := bridge.QueryUserInfo(username) userInfo, err := bridge.QueryUserInfo(username)
require.NoError(t, err) require.NoError(t, err)
@ -245,7 +250,7 @@ func TestBridge_SendInvite(t *testing.T) {
}) })
// Start the bridge. // Start the bridge.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
// Get the sender user info. // Get the sender user info.
userInfo, err := bridge.QueryUserInfo(username) userInfo, err := bridge.QueryUserInfo(username)
require.NoError(t, err) require.NoError(t, err)
@ -401,6 +406,9 @@ SGVsbG8gd29ybGQK
require.NoError(t, err) require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil) senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
@ -420,6 +428,8 @@ SGVsbG8gd29ybGQK
messageMultipartWithoutTextWithTextAttachment, messageMultipartWithoutTextWithTextAttachment,
} }
smtpWaiter.Wait()
for _, m := range messages { for _, m := range messages {
// Dial the server. // Dial the server.
client, err := smtp.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetSMTPPort()))) client, err := smtp.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetSMTPPort())))

View File

@ -43,6 +43,8 @@ type ServerManager struct {
smtpServer *smtp.Server smtpServer *smtp.Server
smtpListener net.Listener smtpListener net.Listener
loadedUserCount int
} }
func newServerManager() *ServerManager { func newServerManager() *ServerManager {
@ -145,19 +147,17 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
case evt := <-eventCh: case evt := <-eventCh:
switch evt.(type) { switch evt.(type) {
case events.ConnStatusDown: case events.ConnStatusDown:
// Handle connect down. logrus.Info("Server Manager, network down stopping listeners")
if err := sm.closeSMTPServer(bridge); err != nil {
logrus.WithError(err).Error("Failed to close SMTP server")
}
if err := sm.stopIMAPListener(bridge); err != nil {
logrus.WithError(err)
}
case events.ConnStatusUp: case events.ConnStatusUp:
// Handle connect up. logrus.Info("Server Manager, network up starting listeners")
sm.handleLoadedUserCountChange(ctx, bridge)
case events.AllUsersLoaded:
if err := sm.serveIMAP(ctx, bridge); err != nil {
logrus.WithError(err).Error("Failed to start IMAP server")
}
if err := sm.serveSMTP(bridge); err != nil {
logrus.WithError(err).Error("Failed to start SMTP server")
}
} }
case request, ok := <-sm.requests.ReceiveCh(): case request, ok := <-sm.requests.ReceiveCh():
@ -182,10 +182,18 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
case *smRequestAddIMAPUser: case *smRequestAddIMAPUser:
err := sm.handleAddIMAPUser(ctx, r.user) err := sm.handleAddIMAPUser(ctx, r.user)
request.SendReply(ctx, nil, err) request.SendReply(ctx, nil, err)
if err == nil {
sm.loadedUserCount++
sm.handleLoadedUserCountChange(ctx, bridge)
}
case *smRequestRemoveIMAPUser: case *smRequestRemoveIMAPUser:
err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData) err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData)
request.SendReply(ctx, nil, err) request.SendReply(ctx, nil, err)
if err == nil {
sm.loadedUserCount--
sm.handleLoadedUserCountChange(ctx, bridge)
}
case *smRequestSetGluonDir: case *smRequestSetGluonDir:
err := sm.handleSetGluonDir(ctx, bridge, r.dir) err := sm.handleSetGluonDir(ctx, bridge, r.dir)
@ -203,6 +211,35 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
} }
} }
func (sm *ServerManager) handleLoadedUserCountChange(ctx context.Context, bridge *Bridge) {
logrus.Infof("Validating Listener State %v", sm.loadedUserCount)
if sm.shouldStartServers() {
if sm.imapListener == nil {
if err := sm.serveIMAP(ctx, bridge); err != nil {
logrus.WithError(err).Error("Failed to start IMAP server")
}
}
if sm.smtpListener == nil {
if err := sm.restartSMTP(bridge); err != nil {
logrus.WithError(err).Error("Failed to start SMTP server")
}
}
} else {
if sm.imapListener != nil {
if err := sm.stopIMAPListener(bridge); err != nil {
logrus.WithError(err).Error("Failed to stop IMAP server")
}
}
if sm.smtpListener != nil {
if err := sm.closeSMTPServer(bridge); err != nil {
logrus.WithError(err).Error("Failed to stop SMTP server")
}
}
}
}
func (sm *ServerManager) handleClose(ctx context.Context, bridge *Bridge) { func (sm *ServerManager) handleClose(ctx context.Context, bridge *Bridge) {
// Close the IMAP server. // Close the IMAP server.
if err := sm.closeIMAPServer(ctx, bridge); err != nil { if err := sm.closeIMAPServer(ctx, bridge); err != nil {
@ -358,27 +395,45 @@ func (sm *ServerManager) closeSMTPServer(bridge *Bridge) error {
// This is because smtpServer.Serve() is called in a separate goroutine and might be executed // This is because smtpServer.Serve() is called in a separate goroutine and might be executed
// after we've already closed the server. However, go-smtp has a bug; it blocks on the listener // after we've already closed the server. However, go-smtp has a bug; it blocks on the listener
// even after the server has been closed. So we close the listener ourselves to unblock it. // even after the server has been closed. So we close the listener ourselves to unblock it.
logrus.Info("Closing SMTP server")
if sm.smtpListener != nil { if sm.smtpListener != nil {
logrus.Info("Closing SMTP Listener")
if err := sm.smtpListener.Close(); err != nil { if err := sm.smtpListener.Close(); err != nil {
return fmt.Errorf("failed to close SMTP listener: %w", err) return fmt.Errorf("failed to close SMTP listener: %w", err)
} }
sm.smtpListener = nil
} }
if err := sm.smtpServer.Close(); err != nil { if sm.smtpServer != nil {
logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)") logrus.Info("Closing SMTP server")
} if err := sm.smtpServer.Close(); err != nil {
logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)")
}
bridge.publish(events.SMTPServerStopped{}) sm.smtpServer = nil
bridge.publish(events.SMTPServerStopped{})
}
return nil return nil
} }
func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) error { func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) error {
logrus.Info("Closing IMAP server") if sm.imapListener != nil {
logrus.Info("Closing IMAP Listener")
if err := sm.imapListener.Close(); err != nil {
return fmt.Errorf("failed to close IMAP listener: %w", err)
}
sm.imapListener = nil
bridge.publish(events.IMAPServerStopped{})
}
if sm.imapServer != nil { if sm.imapServer != nil {
logrus.Info("Closing IMAP server")
if err := sm.imapServer.Close(ctx); err != nil { if err := sm.imapServer.Close(ctx); err != nil {
return fmt.Errorf("failed to close IMAP server: %w", err) return fmt.Errorf("failed to close IMAP server: %w", err)
} }
@ -386,16 +441,6 @@ func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) er
sm.imapServer = nil sm.imapServer = nil
} }
if sm.imapListener != nil {
if err := sm.imapListener.Close(); err != nil {
return fmt.Errorf("failed to close IMAP listener: %w", err)
}
sm.imapListener = nil
}
bridge.publish(events.IMAPServerStopped{})
return nil return nil
} }
@ -412,7 +457,11 @@ func (sm *ServerManager) restartIMAP(ctx context.Context, bridge *Bridge) error
bridge.publish(events.IMAPServerStopped{}) bridge.publish(events.IMAPServerStopped{})
} }
return sm.serveIMAP(ctx, bridge) if sm.shouldStartServers() {
return sm.serveIMAP(ctx, bridge)
}
return nil
} }
func (sm *ServerManager) restartSMTP(bridge *Bridge) error { func (sm *ServerManager) restartSMTP(bridge *Bridge) error {
@ -426,7 +475,11 @@ func (sm *ServerManager) restartSMTP(bridge *Bridge) error {
sm.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP) sm.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
return sm.serveSMTP(bridge) if sm.shouldStartServers() {
return sm.serveSMTP(bridge)
}
return nil
} }
func (sm *ServerManager) serveSMTP(bridge *Bridge) error { func (sm *ServerManager) serveSMTP(bridge *Bridge) error {
@ -515,6 +568,21 @@ func (sm *ServerManager) serveIMAP(ctx context.Context, bridge *Bridge) error {
return nil return nil
} }
func (sm *ServerManager) stopIMAPListener(bridge *Bridge) error {
logrus.Info("Stopping IMAP listener")
if sm.imapListener != nil {
if err := sm.imapListener.Close(); err != nil {
return err
}
sm.imapListener = nil
bridge.publish(events.IMAPServerStopped{})
}
return nil
}
func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge, newGluonDir string) error { func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge, newGluonDir string) error {
return safe.RLockRet(func() error { return safe.RLockRet(func() error {
currentGluonDir := bridge.GetGluonCacheDir() currentGluonDir := bridge.GetGluonCacheDir()
@ -527,6 +595,8 @@ func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge,
return fmt.Errorf("failed to close IMAP: %w", err) return fmt.Errorf("failed to close IMAP: %w", err)
} }
sm.loadedUserCount = 0
if err := bridge.moveGluonCacheDir(currentGluonDir, newGluonDir); err != nil { if err := bridge.moveGluonCacheDir(currentGluonDir, newGluonDir); err != nil {
logrus.WithError(err).Error("failed to move GluonCacheDir") logrus.WithError(err).Error("failed to move GluonCacheDir")
@ -560,15 +630,17 @@ func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge,
} }
sm.imapServer = imapServer sm.imapServer = imapServer
for _, bridgeUser := range bridge.users { for _, bridgeUser := range bridge.users {
if err := sm.handleAddIMAPUser(ctx, bridgeUser); err != nil { if err := sm.handleAddIMAPUser(ctx, bridgeUser); err != nil {
return fmt.Errorf("failed to add users to new IMAP server: %w", err) return fmt.Errorf("failed to add users to new IMAP server: %w", err)
} }
sm.loadedUserCount++
} }
if err := sm.serveIMAP(ctx, bridge); err != nil { if sm.shouldStartServers() {
return fmt.Errorf("failed to serve IMAP: %w", err) if err := sm.serveIMAP(ctx, bridge); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
} }
return nil return nil
@ -591,6 +663,10 @@ func (sm *ServerManager) handleRemoveGluonUser(ctx context.Context, userID strin
return sm.imapServer.RemoveUser(ctx, userID, true) return sm.imapServer.RemoveUser(ctx, userID, true)
} }
func (sm *ServerManager) shouldStartServers() bool {
return sm.loadedUserCount >= 1
}
type smRequestClose struct{} type smRequestClose struct{}
type smRequestRestartIMAP struct{} type smRequestRestartIMAP struct{}

View File

@ -0,0 +1,171 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge_test
import (
"context"
"fmt"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/emersion/go-imap/client"
"github.com/stretchr/testify/require"
)
func TestServerManager_NoLoadedUsersNoServers(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.Error(t, err)
})
})
}
func TestServerManager_ServersStartAfterFirstConnectedUser(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
})
})
}
func TestServerManager_ServersStopsAfterUserLogsOut(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
imapWaiterStopped := waitForIMAPServerStopped(bridge)
defer imapWaiterStopped.Done()
require.NoError(t, bridge.LogoutUser(ctx, userID))
imapWaiterStopped.Wait()
})
})
}
func TestServerManager_ServersDoNotStopWhenThereIsStillOneActiveUser(t *testing.T) {
otherPassword := []byte("bar")
otherUser := "foo"
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
_, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
userIDOther, err := bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
evtCh, cancel := bridge.GetEvents(events.UserDeauth{})
defer cancel()
require.NoError(t, s.RevokeUser(userIDOther))
waitForEvent(t, evtCh, events.UserDeauth{})
imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, imapClient.Logout())
})
})
}
func TestServerManager_ServersStartIfAtLeastOneUserIsLoggedIn(t *testing.T) {
otherPassword := []byte("bar")
otherUser := "foo"
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
userIDOther, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
_, err = bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil)
require.NoError(t, err)
})
require.NoError(t, s.RevokeUser(userIDOther))
withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, imapClient.Logout())
})
})
}
func TestServerManager_NetworkLossStopsServers(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
imapWaiterStop := waitForIMAPServerStopped(bridge)
defer imapWaiterStop.Done()
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
netCtl.Disable()
imapWaiterStop.Wait()
netCtl.Enable()
imapWaiter.Wait()
})
})
}

View File

@ -112,15 +112,6 @@ func TestBridge_Sync(t *testing.T) {
info, err := b.GetUserInfo(userID) info, err := b.GetUserInfo(userID)
require.NoError(t, err) require.NoError(t, err)
require.True(t, info.State == bridge.Connected) require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Less(t, status.Messages, uint32(numMsg))
} }
// Remove the network limit, allowing the sync to finish. // Remove the network limit, allowing the sync to finish.
@ -273,15 +264,6 @@ func TestBridge_SyncWithOngoingEvents(t *testing.T) {
info, err := b.GetUserInfo(userID) info, err := b.GetUserInfo(userID)
require.NoError(t, err) require.NoError(t, err)
require.True(t, info.State == bridge.Connected) require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Less(t, status.Messages, uint32(numMsg))
} }
// Create a new mailbox and move that last 1/3 of the messages into it to simulate user // Create a new mailbox and move that last 1/3 of the messages into it to simulate user

View File

@ -141,6 +141,9 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex
}) })
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
userLoginAndSync(ctx, t, bridge, "user", password) userLoginAndSync(ctx, t, bridge, "user", password)
var messageIDs []string var messageIDs []string
@ -176,6 +179,8 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex
userFeedback(t, ctx, bridge, badUserID) userFeedback(t, ctx, bridge, badUserID)
smtpWaiter.Wait()
userContinueEventProcess(ctx, t, s, bridge) userContinueEventProcess(ctx, t, s, bridge)
}) })
}) })
@ -194,6 +199,9 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) {
}) })
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
userLoginAndSync(ctx, t, bridge, "user", password) userLoginAndSync(ctx, t, bridge, "user", password)
var messageIDs []string var messageIDs []string
@ -217,6 +225,7 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) {
require.NoError(t, c.DeleteMessage(ctx, messageIDs...)) require.NoError(t, c.DeleteMessage(ctx, messageIDs...))
}) })
smtpWaiter.Wait()
userContinueEventProcess(ctx, t, s, bridge) userContinueEventProcess(ctx, t, s, bridge)
}) })
}) })
@ -412,6 +421,17 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) {
}) })
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
var count int32
// The first 10 times bridge attempts to sync any of the messages, drop the connection.
s.AddStatusHook(func(req *http.Request) (int, bool) {
if strings.Contains(req.URL.Path, "/mail/v4/messages") {
if atomic.AddInt32(&count, 1) < 10 {
dropListener.DropAll()
}
}
return 0, false
})
userLoginAndSync(ctx, t, bridge, "user", password) userLoginAndSync(ctx, t, bridge, "user", password)
mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).AnyTimes() mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).AnyTimes()
@ -421,19 +441,6 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) {
createNumMessages(ctx, t, c, addrID, proton.InboxLabel, 10) createNumMessages(ctx, t, c, addrID, proton.InboxLabel, 10)
}) })
var count int
// The first 10 times bridge attempts to sync any of the messages, drop the connection.
s.AddStatusHook(func(req *http.Request) (int, bool) {
if strings.Contains(req.URL.Path, "/mail/v4/messages") {
if count++; count < 10 {
dropListener.DropAll()
}
}
return 0, false
})
info, err := bridge.QueryUserInfo("user") info, err := bridge.QueryUserInfo("user")
require.NoError(t, err) require.NoError(t, err)
@ -771,11 +778,16 @@ func TestBridge_User_CreateDisabledAddress(t *testing.T) {
func TestBridge_User_HandleParentLabelRename(t *testing.T) { func TestBridge_User_HandleParentLabelRename(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
info, err := bridge.QueryUserInfo(username) info, err := bridge.QueryUserInfo(username)
require.NoError(t, err) require.NoError(t, err)
imapWaiter.Wait()
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort())) client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))

View File

@ -167,6 +167,8 @@ type testCtx struct {
// This slice contains the dummy listeners that are intended to block network ports. // This slice contains the dummy listeners that are intended to block network ports.
dummyListeners []net.Listener dummyListeners []net.Listener
imapServerStarted bool
} }
type imapClient struct { type imapClient struct {

View File

@ -28,6 +28,7 @@ import (
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo" "github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"github.com/bradenaw/juniper/iterator" "github.com/bradenaw/juniper/iterator"
@ -331,10 +332,20 @@ func (s *scenario) drafAtIndexWasMovedToTrashForAddressOfAccount(draftIndex int,
} }
func (s *scenario) userLogsInWithUsernameAndPassword(username, password string) error { func (s *scenario) userLogsInWithUsernameAndPassword(username, password string) error {
evtCh, cancel := s.t.bridge.GetEvents(events.SMTPServerReady{})
defer cancel()
userID, err := s.t.bridge.LoginFull(context.Background(), username, []byte(password), nil, nil) userID, err := s.t.bridge.LoginFull(context.Background(), username, []byte(password), nil, nil)
if err != nil { if err != nil {
s.t.pushError(err) s.t.pushError(err)
} else { } else {
// We need to wait for server to be up or we won't be able to connect. It should only happen once to avoid
// blocking on multiple Logins.
if !s.t.imapServerStarted {
<-evtCh
s.t.imapServerStarted = true
}
if userID != s.t.getUserByName(username).getUserID() { if userID != s.t.getUserByName(username).getUserID() {
return errors.New("user ID mismatch") return errors.New("user ID mismatch")
} }