From 509a767e50ce93c65f4bd26216e584514e2b1870 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Sun, 9 Oct 2022 23:05:52 +0200 Subject: [PATCH] GODT-1657: More stable sync, with some tests --- Makefile | 2 +- go.mod | 2 +- go.sum | 4 +- internal/app/bridge.go | 13 +- internal/bridge/bridge.go | 43 ++-- internal/bridge/bridge_test.go | 132 +++++++---- internal/bridge/imap.go | 18 ++ internal/bridge/mocks.go | 41 +--- internal/bridge/mocks/mocks.go | 53 ++--- internal/bridge/settings.go | 4 +- internal/bridge/settings_test.go | 39 +-- internal/bridge/smtp.go | 18 ++ internal/bridge/sync_test.go | 130 ++++++++++ internal/bridge/testdata/text-plain.eml | 6 + internal/bridge/types.go | 15 +- internal/bridge/updates.go | 3 + internal/bridge/user.go | 300 ++++++++++++++---------- internal/bridge/user_events.go | 2 +- internal/bridge/user_test.go | 136 ++++++++--- internal/events/user.go | 6 + internal/focus/service.go | 1 + internal/pool/job.go | 41 ---- internal/pool/pool.go | 147 ------------ internal/pool/pool_test.go | 163 ------------- internal/safe/map.go | 61 ++++- internal/try/try.go | 49 ++++ internal/try/try_test.go | 74 ++++++ internal/user/events.go | 7 +- internal/user/sync.go | 31 ++- internal/user/sync_build.go | 25 +- internal/user/sync_flusher.go | 26 +- internal/user/user.go | 20 +- internal/user/user_test.go | 4 +- internal/vault/user.go | 9 + internal/vault/user_test.go | 2 +- internal/vault/vault.go | 8 - tests/api_test.go | 4 +- tests/ctx_bridge_test.go | 7 +- tests/ctx_test.go | 8 +- tests/environment_test.go | 4 +- tests/user_test.go | 4 +- 41 files changed, 883 insertions(+), 779 deletions(-) create mode 100644 internal/bridge/sync_test.go create mode 100644 internal/bridge/testdata/text-plain.eml delete mode 100644 internal/pool/job.go delete mode 100644 internal/pool/pool.go delete mode 100644 internal/pool/pool_test.go create mode 100644 internal/try/try.go create mode 100644 internal/try/try_test.go diff --git a/Makefile b/Makefile index 592094da..f7c2b035 100644 --- a/Makefile +++ b/Makefile @@ -234,7 +234,7 @@ integration-test-bridge: ${MAKE} -C test test-bridge mocks: - mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyDialer,Autostarter > internal/bridge/mocks/mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyController,Autostarter > internal/bridge/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go lint: gofiles lint-golang lint-license lint-dependencies lint-changelog diff --git a/go.mod b/go.mod index 1fca1111..bb7eb8f6 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.0 github.com/urfave/cli/v2 v2.16.3 - gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7 + gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/net v0.1.0 golang.org/x/sys v0.1.0 diff --git a/go.sum b/go.sum index 4df27693..9c210b8d 100644 --- a/go.sum +++ b/go.sum @@ -397,8 +397,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7 h1:Hef7jPRzcfLOvOUHYoQ6efaI7p7/aT5kpZDqJ29owNI= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e h1:CTGaREzkbz7u98nKt6+xsca2bWML79lH1XGbodRo+MY= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/app/bridge.go b/internal/app/bridge.go index 4363f272..f9ae7e68 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -81,7 +81,18 @@ func newBridge(locations *locations.Locations, identifier *useragent.UserAgent) } // Create a new bridge. - bridge, err := bridge.New(constants.APIHost, locations, encVault, identifier, pinningDialer, proxyDialer, autostarter, updater, version) + bridge, err := bridge.New( + constants.APIHost, + locations, + encVault, + identifier, + pinningDialer, + dialer.CreateTransportWithDialer(proxyDialer), + proxyDialer, + autostarter, + updater, + version, + ) if err != nil { return nil, fmt.Errorf("could not create bridge: %w", err) } diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index c6e94d28..4a4cda29 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -33,10 +33,10 @@ type Bridge struct { users map[string]*user.User // api manages user API clients. - api *liteapi.Manager - cookieJar *cookies.Jar - proxyDialer ProxyDialer - identifier Identifier + api *liteapi.Manager + cookieJar *cookies.Jar + proxyCtl ProxyController + identifier Identifier // watchers holds all registered event watchers. watchers []*watcher.Watcher[events.Event] @@ -81,15 +81,16 @@ func New( vault *vault.Vault, // the bridge's encrypted data store identifier Identifier, // the identifier to keep track of the user agent tlsReporter TLSReporter, // the TLS reporter to report TLS errors - proxyDialer ProxyDialer, // the DoH dialer + roundTripper http.RoundTripper, // the round tripper to use for API requests + proxyCtl ProxyController, // the DoH controller autostarter Autostarter, // the autostarter to manage autostart settings updater Updater, // the updater to fetch and install updates curVersion *semver.Version, // the current version of the bridge ) (*Bridge, error) { if vault.GetProxyAllowed() { - proxyDialer.AllowProxy() + proxyCtl.AllowProxy() } else { - proxyDialer.DisallowProxy() + proxyCtl.DisallowProxy() } cookieJar, err := cookies.NewCookieJar(vault) @@ -101,7 +102,7 @@ func New( liteapi.WithHostURL(apiURL), liteapi.WithAppVersion(constants.AppVersion), liteapi.WithCookieJar(cookieJar), - liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}), + liteapi.WithTransport(roundTripper), ) tlsConfig, err := loadTLSConfig(vault) @@ -139,10 +140,10 @@ func New( vault: vault, users: make(map[string]*user.User), - api: api, - cookieJar: cookieJar, - proxyDialer: proxyDialer, - identifier: identifier, + api: api, + cookieJar: cookieJar, + proxyCtl: proxyCtl, + identifier: identifier, tlsConfig: tlsConfig, imapServer: imapServer, @@ -179,6 +180,10 @@ func New( return nil }) + if err := bridge.loadUsers(); err != nil { + return nil, fmt.Errorf("failed to load users: %w", err) + } + go func() { for range tlsReporter.GetTLSIssueCh() { bridge.publish(events.TLSIssue{}) @@ -197,10 +202,6 @@ func New( } }() - if err := bridge.loadUsers(context.Background()); err != nil { - return nil, fmt.Errorf("failed to load connected users: %w", err) - } - if err := bridge.serveIMAP(); err != nil { bridge.PushError(ErrServeIMAP) } @@ -309,14 +310,8 @@ func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) { func (bridge *Bridge) onStatusUp() { bridge.publish(events.ConnStatusUp{}) - for _, userID := range bridge.vault.GetUserIDs() { - if _, ok := bridge.users[userID]; !ok { - if vaultUser, err := bridge.vault.GetUser(userID); err != nil { - logrus.WithError(err).Error("Failed to get user from vault") - } else if err := bridge.loadUser(context.Background(), vaultUser); err != nil { - logrus.WithError(err).Error("Failed to load user") - } - } + if err := bridge.loadUsers(); err != nil { + logrus.WithError(err).Error("Failed to load users") } } diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 52c8f1f1..b39f8d1c 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -2,6 +2,7 @@ package bridge_test import ( "context" + "crypto/tls" "net/http" "os" "testing" @@ -21,6 +22,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/tests" "github.com/bradenaw/juniper/xslices" "github.com/stretchr/testify/require" + "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server/backend" ) @@ -41,14 +43,14 @@ func init() { } func TestBridge_ConnStatus(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of connection status events. eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{}) defer done() // Simulate network disconnect. - dialer.SetCanDial(false) + netCtl.Disable() // Trigger some operation that will fail due to the network disconnect. _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) @@ -58,7 +60,7 @@ func TestBridge_ConnStatus(t *testing.T) { require.Equal(t, events.ConnStatusDown{}, <-eventCh) // Simulate network reconnect. - dialer.SetCanDial(true) + netCtl.Enable() // Trigger some operation that will succeed due to the network reconnect. userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) @@ -72,8 +74,8 @@ func TestBridge_ConnStatus(t *testing.T) { } func TestBridge_TLSIssue(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of TLS issue events. tlsEventCh, done := bridge.GetEvents(events.TLSIssue{}) defer done() @@ -90,8 +92,8 @@ func TestBridge_TLSIssue(t *testing.T) { } func TestBridge_Focus(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of TLS issue events. raiseCh, done := bridge.GetEvents(events.Raise{}) defer done() @@ -106,14 +108,14 @@ func TestBridge_Focus(t *testing.T) { } func TestBridge_UserAgent(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { var calls []server.Call s.AddCallWatcher(func(call server.Call) { calls = append(calls, call) }) - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Set the platform to something other than the default. bridge.SetCurrentPlatform("platform") @@ -131,7 +133,7 @@ func TestBridge_UserAgent(t *testing.T) { } func TestBridge_Cookies(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { var calls []server.Call s.AddCallWatcher(func(call server.Call) { @@ -141,7 +143,7 @@ func TestBridge_Cookies(t *testing.T) { var sessionID string // Start bridge and add a user so that API assigns us a session ID via cookie. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) @@ -152,7 +154,7 @@ func TestBridge_Cookies(t *testing.T) { }) // Start bridge again and check that it uses the same session ID. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id") require.NoError(t, err) @@ -162,8 +164,8 @@ func TestBridge_Cookies(t *testing.T) { } func TestBridge_CheckUpdate(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Disable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(false)) @@ -201,8 +203,8 @@ func TestBridge_CheckUpdate(t *testing.T) { } func TestBridge_AutoUpdate(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Enable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(true)) @@ -229,8 +231,8 @@ func TestBridge_AutoUpdate(t *testing.T) { } func TestBridge_ManualUpdate(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Disable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(false)) @@ -258,8 +260,8 @@ func TestBridge_ManualUpdate(t *testing.T) { } func TestBridge_ForceUpdate(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of update events. updateCh, done := bridge.GetEvents(events.UpdateForced{}) defer done() @@ -278,11 +280,11 @@ func TestBridge_ForceUpdate(t *testing.T) { } func TestBridge_BadVaultKey(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { var userID string // Login a user. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) @@ -290,27 +292,27 @@ func TestBridge_BadVaultKey(t *testing.T) { }) // Start bridge with the correct vault key -- it should load the users correctly. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs()) }) // Start bridge with a bad vault key, the vault will be wiped and bridge will show no users. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) }) // Start bridge with a nil vault key, the vault will be wiped and bridge will show no users. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) }) }) } func TestBridge_MissingGluonDir(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { var gluonDir string - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) @@ -325,20 +327,20 @@ func TestBridge_MissingGluonDir(t *testing.T) { require.NoError(t, os.RemoveAll(gluonDir)) // Bridge starts but can't find the gluon dir; there should be no error. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // ... }) }) } -// withEnv creates the full test environment and runs the tests. -func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) { +// withTLSEnv creates the full test environment and runs the tests. +func withTLSEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) { // Create test API. server := server.NewTLS() defer server.Close() // Add test user. - _, _, err := server.CreateUser(username, string(password), username+"@pm.me") + _, _, err := server.CreateUser(username, username+"@pm.me", password) require.NoError(t, err) // Generate a random vault key. @@ -349,23 +351,56 @@ func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Create a net controller so we can simulate network connectivity issues. + netCtl := liteapi.NewNetCtl() + + // Create a locations object to provide temporary locations for bridge data during the test. + locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name") + // Run the tests. - tests( - ctx, - server, - bridge.NewTestDialer(), - locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name"), - vaultKey, - ) + tests(ctx, server, netCtl, locations, vaultKey) +} + +// withEnv creates the full test environment and runs the tests. +func withEnv(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) { + // Add test user. + _, _, err := server.CreateUser(username, username+"@pm.me", password) + require.NoError(t, err) + + // Generate a random vault key. + vaultKey, err := crypto.RandomToken(32) + require.NoError(t, err) + + // Create a context used for the test. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create a net controller so we can simulate network connectivity issues. + netCtl := liteapi.NewNetCtl() + + // Create a locations object to provide temporary locations for bridge data during the test. + locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name") + + // Run the tests. + tests(ctx, netCtl, locations, vaultKey) } // withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. -func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) { +func withBridge( + t *testing.T, + ctx context.Context, + apiURL string, + netCtl *liteapi.NetCtl, + locator bridge.Locator, + vaultKey []byte, + tests func(*bridge.Bridge, *bridge.Mocks), +) { // Create the mock objects used in the tests. - mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0) + mocks := bridge.NewMocks(t, v2_3_0, v2_3_0) + defer mocks.Close() // Bridge will enable the proxy by default at startup. - mocks.ProxyDialer.EXPECT().AllowProxy() + mocks.ProxyCtl.EXPECT().AllowProxy() // Get the path to the vault. vaultDir, err := locator.ProvideSettingsPath() @@ -380,7 +415,18 @@ func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge require.NoError(t, vault.SetSMTPPort(0)) // Create a new bridge. - bridge, err := bridge.New(apiURL, locator, vault, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0) + bridge, err := bridge.New( + apiURL, + locator, + vault, + useragent.New(), + mocks.TLSReporter, + liteapi.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + mocks.ProxyCtl, + mocks.Autostarter, + mocks.Updater, + v2_3_0, + ) require.NoError(t, err) // Close the bridge when done. diff --git a/internal/bridge/imap.go b/internal/bridge/imap.go index 7dabf692..acb8ab7d 100644 --- a/internal/bridge/imap.go +++ b/internal/bridge/imap.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "io/fs" + "net" "os" + "strconv" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon" @@ -33,6 +35,22 @@ func (bridge *Bridge) serveIMAP() error { return fmt.Errorf("failed to serve IMAP: %w", err) } + _, port, err := net.SplitHostPort(imapListener.Addr().String()) + if err != nil { + return fmt.Errorf("failed to get IMAP listener address: %w", err) + } + + portInt, err := strconv.Atoi(port) + if err != nil { + return fmt.Errorf("failed to convert IMAP listener port to int: %w", err) + } + + if portInt != bridge.vault.GetIMAPPort() { + if err := bridge.vault.SetIMAPPort(portInt); err != nil { + return fmt.Errorf("failed to update IMAP port in vault: %w", err) + } + } + go func() { for err := range bridge.imapServer.GetErrorCh() { logrus.WithError(err).Error("IMAP server error") diff --git a/internal/bridge/mocks.go b/internal/bridge/mocks.go index 17673c90..7c34fe7b 100644 --- a/internal/bridge/mocks.go +++ b/internal/bridge/mocks.go @@ -1,10 +1,6 @@ package bridge import ( - "context" - "crypto/tls" - "errors" - "net" "os" "testing" @@ -15,7 +11,7 @@ import ( ) type Mocks struct { - ProxyDialer *mocks.MockProxyDialer + ProxyCtl *mocks.MockProxyController TLSReporter *mocks.MockTLSReporter TLSIssueCh chan struct{} @@ -23,11 +19,11 @@ type Mocks struct { Autostarter *mocks.MockAutostarter } -func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Version) *Mocks { +func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks { ctl := gomock.NewController(tb) mocks := &Mocks{ - ProxyDialer: mocks.NewMockProxyDialer(ctl), + ProxyCtl: mocks.NewMockProxyController(ctl), TLSReporter: mocks.NewMockTLSReporter(ctl), TLSIssueCh: make(chan struct{}), @@ -35,41 +31,14 @@ func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Versio Autostarter: mocks.NewMockAutostarter(ctl), } - // When using the proxy dialer, we want to use the test dialer. - mocks.ProxyDialer.EXPECT().DialTLSContext( - gomock.Any(), - gomock.Any(), - gomock.Any(), - ).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) { - return dialer.DialTLSContext(ctx, network, address) - }).AnyTimes() - // When getting the TLS issue channel, we want to return the test channel. mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes() return mocks } -type TestDialer struct { - canDial bool -} - -func NewTestDialer() *TestDialer { - return &TestDialer{ - canDial: true, - } -} - -func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) { - if !d.canDial { - return nil, errors.New("cannot dial") - } - - return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address) -} - -func (d *TestDialer) SetCanDial(canDial bool) { - d.canDial = canDial +func (mocks *Mocks) Close() { + close(mocks.TLSIssueCh) } type TestLocationsProvider struct { diff --git a/internal/bridge/mocks/mocks.go b/internal/bridge/mocks/mocks.go index 547a9bd2..d519ee43 100644 --- a/internal/bridge/mocks/mocks.go +++ b/internal/bridge/mocks/mocks.go @@ -1,12 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter) +// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyController,Autostarter) // Package mocks is a generated GoMock package. package mocks import ( - context "context" - net "net" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -49,66 +47,51 @@ func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh)) } -// MockProxyDialer is a mock of ProxyDialer interface. -type MockProxyDialer struct { +// MockProxyController is a mock of ProxyController interface. +type MockProxyController struct { ctrl *gomock.Controller - recorder *MockProxyDialerMockRecorder + recorder *MockProxyControllerMockRecorder } -// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer. -type MockProxyDialerMockRecorder struct { - mock *MockProxyDialer +// MockProxyControllerMockRecorder is the mock recorder for MockProxyController. +type MockProxyControllerMockRecorder struct { + mock *MockProxyController } -// NewMockProxyDialer creates a new mock instance. -func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer { - mock := &MockProxyDialer{ctrl: ctrl} - mock.recorder = &MockProxyDialerMockRecorder{mock} +// NewMockProxyController creates a new mock instance. +func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController { + mock := &MockProxyController{ctrl: ctrl} + mock.recorder = &MockProxyControllerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder { +func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder { return m.recorder } // AllowProxy mocks base method. -func (m *MockProxyDialer) AllowProxy() { +func (m *MockProxyController) AllowProxy() { m.ctrl.T.Helper() m.ctrl.Call(m, "AllowProxy") } // AllowProxy indicates an expected call of AllowProxy. -func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call { +func (mr *MockProxyControllerMockRecorder) AllowProxy() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy)) -} - -// DialTLSContext mocks base method. -func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2) - ret0, _ := ret[0].(net.Conn) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DialTLSContext indicates an expected call of DialTLSContext. -func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyController)(nil).AllowProxy)) } // DisallowProxy mocks base method. -func (m *MockProxyDialer) DisallowProxy() { +func (m *MockProxyController) DisallowProxy() { m.ctrl.T.Helper() m.ctrl.Call(m, "DisallowProxy") } // DisallowProxy indicates an expected call of DisallowProxy. -func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call { +func (mr *MockProxyControllerMockRecorder) DisallowProxy() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyController)(nil).DisallowProxy)) } // MockAutostarter is a mock of Autostarter interface. diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index acce5731..a83d0ca7 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -138,9 +138,9 @@ func (bridge *Bridge) GetProxyAllowed() bool { func (bridge *Bridge) SetProxyAllowed(allowed bool) error { if allowed { - bridge.proxyDialer.AllowProxy() + bridge.proxyCtl.AllowProxy() } else { - bridge.proxyDialer.DisallowProxy() + bridge.proxyCtl.DisallowProxy() } return bridge.vault.SetProxyAllowed(allowed) diff --git a/internal/bridge/settings_test.go b/internal/bridge/settings_test.go index ae94597b..174a9e4d 100644 --- a/internal/bridge/settings_test.go +++ b/internal/bridge/settings_test.go @@ -7,12 +7,13 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/stretchr/testify/require" + "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" ) func TestBridge_Settings_GluonDir(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Create a user. _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) @@ -34,8 +35,8 @@ func TestBridge_Settings_GluonDir(t *testing.T) { } func TestBridge_Settings_IMAPPort(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { curPort := bridge.GetIMAPPort() // Set the port to 1144. @@ -51,8 +52,8 @@ func TestBridge_Settings_IMAPPort(t *testing.T) { } func TestBridge_Settings_IMAPSSL(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, IMAP SSL is disabled. require.False(t, bridge.GetIMAPSSL()) @@ -66,8 +67,8 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) { } func TestBridge_Settings_SMTPPort(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { curPort := bridge.GetSMTPPort() // Set the port to 1024. @@ -84,8 +85,8 @@ func TestBridge_Settings_SMTPPort(t *testing.T) { } func TestBridge_Settings_SMTPSSL(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, SMTP SSL is disabled. require.False(t, bridge.GetSMTPSSL()) @@ -99,13 +100,13 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) { } func TestBridge_Settings_Proxy(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, proxy is allowed. require.True(t, bridge.GetProxyAllowed()) // Disallow proxy. - mocks.ProxyDialer.EXPECT().DisallowProxy() + mocks.ProxyCtl.EXPECT().DisallowProxy() require.NoError(t, bridge.SetProxyAllowed(false)) // Get the new setting. @@ -115,8 +116,8 @@ func TestBridge_Settings_Proxy(t *testing.T) { } func TestBridge_Settings_Autostart(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, autostart is disabled. require.False(t, bridge.GetAutostart()) @@ -131,8 +132,8 @@ func TestBridge_Settings_Autostart(t *testing.T) { } func TestBridge_Settings_FirstStart(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, first start is true. require.True(t, bridge.GetFirstStart()) @@ -146,8 +147,8 @@ func TestBridge_Settings_FirstStart(t *testing.T) { } func TestBridge_Settings_FirstStartGUI(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, first start is true. require.True(t, bridge.GetFirstStartGUI()) diff --git a/internal/bridge/smtp.go b/internal/bridge/smtp.go index 07978646..818eb4bd 100644 --- a/internal/bridge/smtp.go +++ b/internal/bridge/smtp.go @@ -3,6 +3,8 @@ package bridge import ( "crypto/tls" "fmt" + "net" + "strconv" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/emersion/go-sasl" @@ -22,6 +24,22 @@ func (bridge *Bridge) serveSMTP() error { } }() + _, port, err := net.SplitHostPort(smtpListener.Addr().String()) + if err != nil { + return fmt.Errorf("failed to get SMTP listener address: %w", err) + } + + portInt, err := strconv.Atoi(port) + if err != nil { + return fmt.Errorf("failed to convert SMTP listener port to int: %w", err) + } + + if portInt != bridge.vault.GetSMTPPort() { + if err := bridge.vault.SetSMTPPort(portInt); err != nil { + return fmt.Errorf("failed to update SMTP port in vault: %w", err) + } + } + return nil } diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go new file mode 100644 index 00000000..a61f6f15 --- /dev/null +++ b/internal/bridge/sync_test.go @@ -0,0 +1,130 @@ +package bridge_test + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/ProtonMail/proton-bridge/v2/internal/bridge" + "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/emersion/go-imap/client" + "github.com/stretchr/testify/require" + "gitlab.protontech.ch/go/liteapi" + "gitlab.protontech.ch/go/liteapi/server" +) + +func TestBridge_Sync(t *testing.T) { + s := server.New() + defer s.Close() + + numMsg := 1 << 10 + + withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password) + require.NoError(t, err) + + labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder) + require.NoError(t, err) + + literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml")) + require.NoError(t, err) + + for i := 0; i < numMsg; i++ { + messageID, err := s.CreateMessage(userID, addrID, literal, liteapi.MessageFlagReceived, false, false) + require.NoError(t, err) + require.NoError(t, s.LabelMessage(userID, messageID, labelID)) + } + + var read uint64 + + netCtl.OnRead(func(b []byte) { + read += uint64(len(b)) + }) + + // The initial user should be fully synced. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := bridge.GetEvents(events.SyncFinished{}) + defer done() + + userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil) + require.NoError(t, err) + + require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID) + }) + + // If we then connect an IMAP client, it should see all the messages. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + info, err := bridge.GetUserInfo(userID) + require.NoError(t, err) + require.True(t, info.Connected) + + client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort())) + require.NoError(t, err) + require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + defer client.Logout() + + status, err := client.Select(`Folders/folder`, false) + require.NoError(t, err) + require.Equal(t, uint32(numMsg), status.Messages) + }) + + // Now let's remove the user and simulate a network error. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.NoError(t, bridge.DeleteUser(ctx, userID)) + }) + + // Pretend we can only sync 2/3 of the original messages. + netCtl.SetReadLimit(2 * read / 3) + + // Login the user; its sync should fail. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := bridge.GetEvents(events.SyncFailed{}) + defer done() + + userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil) + require.NoError(t, err) + + require.Equal(t, userID, (<-syncCh).(events.SyncFailed).UserID) + + info, err := bridge.GetUserInfo(userID) + require.NoError(t, err) + require.True(t, info.Connected) + + client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort())) + require.NoError(t, err) + require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + defer client.Logout() + + status, err := client.Select(`Folders/folder`, false) + require.NoError(t, err) + require.Less(t, status.Messages, uint32(numMsg)) + }) + + // Remove the network limit, allowing the sync to finish. + netCtl.SetReadLimit(0) + + // Login the user; its sync should now finish. + // If we then connect an IMAP client, it should eventually see all the messages. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := bridge.GetEvents(events.SyncFinished{}) + defer done() + + require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID) + + info, err := bridge.GetUserInfo(userID) + require.NoError(t, err) + require.True(t, info.Connected) + + client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort())) + require.NoError(t, err) + require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + defer client.Logout() + + status, err := client.Select(`Folders/folder`, false) + require.NoError(t, err) + require.Equal(t, uint32(numMsg), status.Messages) + }) + }) +} diff --git a/internal/bridge/testdata/text-plain.eml b/internal/bridge/testdata/text-plain.eml new file mode 100644 index 00000000..bfcdb6e1 --- /dev/null +++ b/internal/bridge/testdata/text-plain.eml @@ -0,0 +1,6 @@ +To: recipient@pm.me +From: sender@pm.me +Subject: Test +Content-Type: text/plain; charset=utf-8 + +Test \ No newline at end of file diff --git a/internal/bridge/types.go b/internal/bridge/types.go index 55f1dd42..32b08135 100644 --- a/internal/bridge/types.go +++ b/internal/bridge/types.go @@ -1,9 +1,6 @@ package bridge import ( - "context" - "net" - "github.com/ProtonMail/proton-bridge/v2/internal/updater" ) @@ -21,17 +18,15 @@ type Identifier interface { SetPlatform(platform string) } -type TLSReporter interface { - GetTLSIssueCh() <-chan struct{} -} - -type ProxyDialer interface { - DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) - +type ProxyController interface { AllowProxy() DisallowProxy() } +type TLSReporter interface { + GetTLSIssueCh() <-chan struct{} +} + type Autostarter interface { Enable() error Disable() error diff --git a/internal/bridge/updates.go b/internal/bridge/updates.go index 85e43ca1..de1524a4 100644 --- a/internal/bridge/updates.go +++ b/internal/bridge/updates.go @@ -18,6 +18,9 @@ func (bridge *Bridge) watchForUpdates() error { go func() { for { select { + case <-bridge.stopCh: + return + case <-bridge.updateCheckCh: case <-ticker.C: } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index e960a4cf..4481364d 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -6,6 +6,7 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/try" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/go-resty/resty/v2" @@ -82,76 +83,76 @@ func (bridge *Bridge) LoginUser( ) (string, error) { client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password) if err != nil { - return "", err + return "", fmt.Errorf("failed to create new API client: %w", err) } - if _, ok := bridge.users[auth.UserID]; ok { - return "", ErrUserAlreadyLoggedIn - } + userID, err := try.CatchVal( + func() (string, error) { + if _, ok := bridge.users[auth.UserID]; ok { + return "", ErrUserAlreadyLoggedIn + } - if auth.TwoFA.Enabled == liteapi.TOTPEnabled { - totp, err := getTOTP() - if err != nil { - return "", err - } + if auth.TwoFA.Enabled == liteapi.TOTPEnabled { + totp, err := getTOTP() + if err != nil { + return "", fmt.Errorf("failed to get TOTP: %w", err) + } - if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil { - return "", err - } - } + if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil { + return "", fmt.Errorf("failed to authorize 2FA: %w", err) + } + } - var keyPass []byte + var keyPass []byte - if auth.PasswordMode == liteapi.TwoPasswordMode { - pass, err := getKeyPass() - if err != nil { - return "", err - } + if auth.PasswordMode == liteapi.TwoPasswordMode { + userKeyPass, err := getKeyPass() + if err != nil { + return "", fmt.Errorf("failed to get key password: %w", err) + } - keyPass = pass - } else { - keyPass = password - } + keyPass = userKeyPass + } else { + keyPass = password + } - apiUser, err := client.GetUser(ctx) + return bridge.loginUser(ctx, client, auth.UID, auth.RefreshToken, keyPass) + }, + func() error { + return client.AuthDelete(ctx) + }, + func() error { + bridge.deleteUser(ctx, auth.UserID) + return nil + }, + ) if err != nil { - return "", err + return "", fmt.Errorf("failed to login user: %w", err) } - salts, err := client.GetSalts(ctx) - if err != nil { - return "", err - } + bridge.publish(events.UserLoggedIn{ + UserID: userID, + }) - saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID) - if err != nil { - return "", err - } - - if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil { - return "", err - } - - return auth.UserID, nil + return userID, nil } // LogoutUser logs out the given user. func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error { - return bridge.logoutUser(ctx, userID, true, false) + if err := bridge.logoutUser(ctx, userID); err != nil { + return fmt.Errorf("failed to logout user: %w", err) + } + + bridge.publish(events.UserLoggedOut{ + UserID: userID, + }) + + return nil } // DeleteUser deletes the given user. -// If it is authorized, it is logged out first. func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { - if bridge.users[userID] != nil { - if err := bridge.logoutUser(ctx, userID, true, true); err != nil { - return err - } - } - - if err := bridge.vault.DeleteUser(userID); err != nil { - return err - } + bridge.deleteUser(ctx, userID) bridge.publish(events.UserDeleted{ UserID: userID, @@ -193,53 +194,91 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va return nil } -// loadUsers loads authorized users from the vault. -func (bridge *Bridge) loadUsers(ctx context.Context) error { - for _, userID := range bridge.vault.GetUserIDs() { - user, err := bridge.vault.GetUser(userID) - if err != nil { - return err +func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) { + apiUser, err := client.GetUser(ctx) + if err != nil { + return "", fmt.Errorf("failed to get API user: %w", err) + } + + salts, err := client.GetSalts(ctx) + if err != nil { + return "", fmt.Errorf("failed to get key salts: %w", err) + } + + saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID) + if err != nil { + return "", fmt.Errorf("failed to salt key password: %w", err) + } + + if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil { + return "", fmt.Errorf("failed to add bridge user: %w", err) + } + + return apiUser.ID, nil +} + +// loadUsers is a loop that, when polled, attempts to load authorized users from the vault. +func (bridge *Bridge) loadUsers() error { + return bridge.vault.ForUser(func(user *vault.User) error { + if _, ok := bridge.users[user.UserID()]; ok { + return nil } if user.AuthUID() == "" { - continue + return nil } - if err := bridge.loadUser(ctx, user); err != nil { - logrus.WithError(err).Error("Failed to load connected user") - + if err := bridge.loadUser(user); err != nil { if _, ok := err.(*resty.ResponseError); ok { - if err := bridge.vault.ClearUser(userID); err != nil { + logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault") + + if err := user.Clear(); err != nil { logrus.WithError(err).Error("Failed to clear user") } + } else { + logrus.WithError(err).Error("Failed to load connected user") } - continue + return nil } - } - return nil + bridge.publish(events.UserLoaded{ + UserID: user.UserID(), + }) + + return nil + }) } -func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error { +// loadUser loads an existing user from the vault. +func (bridge *Bridge) loadUser(user *vault.User) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef()) if err != nil { return fmt.Errorf("failed to create API client: %w", err) } - apiUser, err := client.GetUser(ctx) - if err != nil { - return fmt.Errorf("failed to get user: %w", err) - } + if err := try.Catch( + func() error { + apiUser, err := client.GetUser(ctx) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } - if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil { - return fmt.Errorf("failed to add user: %w", err) + return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()) + }, + func() error { + return client.AuthDelete(ctx) + }, + func() error { + return bridge.logoutUser(ctx, user.UserID()) + }, + ); err != nil { + return fmt.Errorf("failed to load user: %w", err) } - bridge.publish(events.UserLoggedIn{ - UserID: user.UserID(), - }) - return nil } @@ -304,10 +343,6 @@ func (bridge *Bridge) addUser( return nil }) - bridge.publish(events.UserLoggedIn{ - UserID: user.ID(), - }) - return nil } @@ -363,54 +398,6 @@ func (bridge *Bridge) addExistingUser( return user, nil } -// logoutUser closes and removes the user with the given ID. -// If withAPI is true, the user will additionally be logged out from API. -// If withFiles is true, the user's files will be deleted. -func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error { - user, ok := bridge.users[userID] - if !ok { - return ErrNoSuchUser - } - - if err := bridge.smtpBackend.removeUser(user); err != nil { - return fmt.Errorf("failed to remove SMTP user: %w", err) - } - - for _, gluonID := range user.GetGluonIDs() { - if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil { - return fmt.Errorf("failed to remove IMAP user: %w", err) - } - } - - if withAPI { - if err := user.Logout(ctx); err != nil { - return fmt.Errorf("failed to logout user: %w", err) - } - } - - if err := user.Close(); err != nil { - return fmt.Errorf("failed to close user: %w", err) - } - - if err := bridge.vault.ClearUser(userID); err != nil { - return fmt.Errorf("failed to clear user: %w", err) - } - - if withFiles { - if err := bridge.vault.DeleteUser(userID); err != nil { - return fmt.Errorf("failed to delete user: %w", err) - } - } - - delete(bridge.users, userID) - - bridge.publish(events.UserLoggedOut{ - UserID: userID, - }) - - return nil -} - // addIMAPUser connects the given user to gluon. func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { imapConn, err := user.NewIMAPConnectors() @@ -438,6 +425,65 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { return nil } +// logoutUser logs the given user out from bridge. +func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { + user, ok := bridge.users[userID] + if !ok { + return ErrNoSuchUser + } + + if err := bridge.smtpBackend.removeUser(user); err != nil { + logrus.WithError(err).Error("Failed to remove user from SMTP backend") + } + + for _, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { + logrus.WithError(err).Error("Failed to remove IMAP user") + } + } + + if err := user.Logout(ctx); err != nil { + logrus.WithError(err).Error("Failed to logout user") + } + + if err := user.Close(); err != nil { + logrus.WithError(err).Error("Failed to close user") + } + + delete(bridge.users, userID) + + return nil +} + +// deleteUser deletes the given user from bridge. +func (bridge *Bridge) deleteUser(ctx context.Context, userID string) { + if user, ok := bridge.users[userID]; ok { + if err := bridge.smtpBackend.removeUser(user); err != nil { + logrus.WithError(err).Error("Failed to remove user from SMTP backend") + } + + for _, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil { + logrus.WithError(err).Error("Failed to remove IMAP user") + } + } + + if err := user.Logout(ctx); err != nil { + logrus.WithError(err).Error("Failed to logout user") + } + + if err := user.Close(); err != nil { + logrus.WithError(err).Error("Failed to close user") + } + } + + if err := bridge.vault.DeleteUser(userID); err != nil { + logrus.WithError(err).Error("Failed to delete user from vault") + } + + delete(bridge.users, userID) +} + // getUserInfo returns information about a disconnected user. func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo { return UserInfo{ diff --git a/internal/bridge/user_events.go b/internal/bridge/user_events.go index 0458aff7..80bb5bbc 100644 --- a/internal/bridge/user_events.go +++ b/internal/bridge/user_events.go @@ -27,7 +27,7 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even } case events.UserDeauth: - if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil { + if err := bridge.logoutUser(context.Background(), event.UserID); err != nil { return fmt.Errorf("failed to logout user: %w", err) } } diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index be928433..931035a1 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -9,17 +9,18 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/stretchr/testify/require" + "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" ) func TestBridge_WithoutUsers(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) }) - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) }) @@ -27,8 +28,8 @@ func TestBridge_WithoutUsers(t *testing.T) { } func TestBridge_Login(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID, err := bridge.LoginUser(ctx, username, password, nil, nil) require.NoError(t, err) @@ -41,8 +42,8 @@ func TestBridge_Login(t *testing.T) { } func TestBridge_LoginLogoutLogin(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -69,8 +70,8 @@ func TestBridge_LoginLogoutLogin(t *testing.T) { } func TestBridge_LoginDeleteLogin(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -97,8 +98,8 @@ func TestBridge_LoginDeleteLogin(t *testing.T) { } func TestBridge_LoginDeauthLogin(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -131,10 +132,10 @@ func TestBridge_LoginDeauthLogin(t *testing.T) { func TestBridge_LoginExpireLogin(t *testing.T) { const authLife = 2 * time.Second - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { s.SetAuthLife(authLife) - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. Its auth will only be valid for a short time. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -148,11 +149,11 @@ func TestBridge_LoginExpireLogin(t *testing.T) { } func TestBridge_FailToLoad(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { var userID string // Login the user. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) }) @@ -160,7 +161,7 @@ func TestBridge_FailToLoad(t *testing.T) { require.NoError(t, s.RevokeUser(userID)) // When bridge starts, the user will not be logged in. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) }) @@ -168,25 +169,27 @@ func TestBridge_FailToLoad(t *testing.T) { } func TestBridge_LoadWithoutInternet(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { var userID string // Login the user. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) }) // Simulate loss of internet connection. - dialer.SetCanDial(false) + netCtl.Disable() // Start bridge without internet. - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Initially, users are not connected. require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) + time.Sleep(5 * time.Second) + // Simulate internet connection. - dialer.SetCanDial(true) + netCtl.Enable() // The user will eventually be connected. require.Eventually(t, func() bool { @@ -197,16 +200,14 @@ func TestBridge_LoadWithoutInternet(t *testing.T) { } func TestBridge_LoginRestart(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { var userID string - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // Login the user. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) }) - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // The user is still connected. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) }) @@ -214,10 +215,10 @@ func TestBridge_LoginRestart(t *testing.T) { } func TestBridge_LoginLogoutRestart(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { var userID string - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -225,7 +226,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) { require.NoError(t, bridge.LogoutUser(ctx, userID)) }) - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // The user is still disconnected. require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) @@ -234,10 +235,10 @@ func TestBridge_LoginLogoutRestart(t *testing.T) { } func TestBridge_LoginDeleteRestart(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { var userID string - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -245,7 +246,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) { require.NoError(t, bridge.DeleteUser(ctx, userID)) }) - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // The user is still gone. require.Empty(t, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) @@ -253,13 +254,69 @@ func TestBridge_LoginDeleteRestart(t *testing.T) { }) } +func TestBridge_FailLoginRecover(t *testing.T) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + var read uint64 + + netCtl.OnRead(func(b []byte) { + read += uint64(len(b)) + }) + + // Log the user in and record how much data was read. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + require.NoError(t, bridge.LogoutUser(ctx, userID)) + }) + + // Simulate a partial read. + netCtl.SetReadLimit(read / 2) + + // We should fail to log the user in because we can't fully read its data. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.Error(t, getErr(bridge.LoginUser(ctx, username, password, nil, nil))) + + // There should be no users. + require.Empty(t, bridge.GetUserIDs()) + }) + }) +} + +func TestBridge_FailLoadRecover(t *testing.T) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + must(bridge.LoginUser(ctx, username, password, nil, nil)) + }) + + var read uint64 + + netCtl.OnRead(func(b []byte) { + read += uint64(len(b)) + }) + + // Start bridge and record how much data was read. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // ... + }) + + // Simulate a partial read. + netCtl.SetReadLimit(read / 2) + + // We should fail to load the user; it should be disconnected. + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + userIDs := bridge.GetUserIDs() + + require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected) + }) + }) +} + func TestBridge_BridgePass(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { var userID string var pass []byte - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -276,7 +333,7 @@ func TestBridge_BridgePass(t *testing.T) { require.Equal(t, pass, pass) }) - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // The bridge should load the user. require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) @@ -288,8 +345,8 @@ func TestBridge_BridgePass(t *testing.T) { } func TestBridge_AddressMode(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { - withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID, err := bridge.LoginUser(ctx, username, password, nil, nil) require.NoError(t, err) @@ -313,3 +370,8 @@ func TestBridge_AddressMode(t *testing.T) { }) }) } + +// getErr returns the error that was passed to it. +func getErr[T any](val T, err error) error { + return err +} diff --git a/internal/events/user.go b/internal/events/user.go index 09c9accb..e4ec9e00 100644 --- a/internal/events/user.go +++ b/internal/events/user.go @@ -2,6 +2,12 @@ package events import "github.com/ProtonMail/proton-bridge/v2/internal/vault" +type UserLoaded struct { + eventBase + + UserID string +} + type UserLoggedIn struct { eventBase diff --git a/internal/focus/service.go b/internal/focus/service.go index 0ce39cec..14f5479d 100644 --- a/internal/focus/service.go +++ b/internal/focus/service.go @@ -64,4 +64,5 @@ func (service *FocusService) GetRaiseCh() <-chan struct{} { // Close closes the service. func (service *FocusService) Close() { service.server.Stop() + close(service.raiseCh) } diff --git a/internal/pool/job.go b/internal/pool/job.go deleted file mode 100644 index cfd70f5d..00000000 --- a/internal/pool/job.go +++ /dev/null @@ -1,41 +0,0 @@ -package pool - -import "context" - -type job[In, Out any] struct { - ctx context.Context - req In - - res chan Out - err chan error - - done chan struct{} -} - -func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] { - return &job[In, Out]{ - ctx: ctx, - req: req, - res: make(chan Out), - err: make(chan error), - done: make(chan struct{}), - } -} - -func (job *job[In, Out]) result() (Out, error) { - return <-job.res, <-job.err -} - -func (job *job[In, Out]) postSuccess(res Out) { - close(job.err) - job.res <- res -} - -func (job *job[In, Out]) postFailure(err error) { - close(job.res) - job.err <- err -} - -func (job *job[In, Out]) waitDone() { - <-job.done -} diff --git a/internal/pool/pool.go b/internal/pool/pool.go deleted file mode 100644 index 6b1b33ae..00000000 --- a/internal/pool/pool.go +++ /dev/null @@ -1,147 +0,0 @@ -package pool - -import ( - "context" - "errors" - "sync" - - "github.com/ProtonMail/gluon/queue" -) - -// ErrJobCancelled indicates the job was cancelled. -var ErrJobCancelled = errors.New("job cancelled by surrounding context") - -// Pool is a worker pool that handles input of type In and returns results of type Out. -type Pool[In comparable, Out any] struct { - queue *queue.QueuedChannel[*job[In, Out]] - size int -} - -// doneFunc must be called to free up pool resources. -type doneFunc func() - -// New returns a new pool. -func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] { - queue := queue.NewQueuedChannel[*job[In, Out]](0, 0) - - for i := 0; i < size; i++ { - go func() { - for job := range queue.GetChannel() { - select { - case <-job.ctx.Done(): - job.postFailure(ErrJobCancelled) - - default: - res, err := work(job.ctx, job.req) - if err != nil { - job.postFailure(err) - } else { - job.postSuccess(res) - } - - job.waitDone() - } - } - }() - } - - return &Pool[In, Out]{ - queue: queue, - size: size, - } -} - -// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred. -func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - var ( - wg sync.WaitGroup - errList []error - lock sync.Mutex - ) - - for _, req := range reqs { - req := req - - wg.Add(1) - - go func() { - defer wg.Done() - - job, done := pool.newJob(ctx, req) - defer done() - - res, err := job.result() - - if err := fn(req, res, err); err != nil { - lock.Lock() - defer lock.Unlock() - - // Cancel ongoing jobs. - cancel() - - // Collect the error. - errList = append(errList, err) - } - }() - } - - wg.Wait() - - // TODO: Join the errors somehow? - if len(errList) > 0 { - return errList[0] - } - - return nil -} - -// ProcessAll submits jobs to the pool. All results are returned once available. -func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) (map[In]Out, error) { - var ( - data = make(map[In]Out) - lock = sync.Mutex{} - ) - - if err := pool.Process(ctx, reqs, func(req In, res Out, err error) error { - if err != nil { - return err - } - - lock.Lock() - defer lock.Unlock() - - data[req] = res - - return nil - }); err != nil { - return nil, err - } - - return data, nil -} - -// ProcessOne submits one job to the pool and returns the result. -func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) { - job, done := pool.newJob(ctx, req) - defer done() - - return job.result() -} - -func (pool *Pool[In, Out]) Done() { - pool.queue.Close() -} - -// newJob submits a job to the pool. It returns a job handle and a DoneFunc. -// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done, -// which frees up the worker in the pool for reuse. -func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc) { - job := newJob[In, Out](ctx, req) - - pool.queue.Enqueue(job) - - return job, func() { close(job.done) } -} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go deleted file mode 100644 index 03bbb953..00000000 --- a/internal/pool/pool_test.go +++ /dev/null @@ -1,163 +0,0 @@ -package pool - -import ( - "context" - "errors" - "runtime" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestPool_NewJob(t *testing.T) { - doubler := newDoubler(runtime.NumCPU()) - - job1, done1 := doubler.newJob(context.Background(), 1) - defer done1() - - job2, done2 := doubler.newJob(context.Background(), 2) - defer done2() - - res2, err := job2.result() - require.NoError(t, err) - - res1, err := job1.result() - require.NoError(t, err) - - assert.Equal(t, 2, res1) - assert.Equal(t, 4, res2) -} - -func TestPool_NewJob_Done(t *testing.T) { - // Create a doubler pool with 2 workers. - doubler := newDoubler(2) - - // Start two jobs. Don't mark the jobs as done yet. - job1, done1 := doubler.newJob(context.Background(), 1) - job2, done2 := doubler.newJob(context.Background(), 2) - - // Get the first result. - res1, _ := job1.result() - assert.Equal(t, 2, res1) - - // Get the first result. - res2, _ := job2.result() - assert.Equal(t, 4, res2) - - // Additional jobs will wait. - job3, _ := doubler.newJob(context.Background(), 3) - job4, _ := doubler.newJob(context.Background(), 4) - - // Channel to collect results from jobs 3 and 4. - resCh := make(chan int, 2) - - go func() { - res, _ := job3.result() - resCh <- res - }() - - go func() { - res, _ := job4.result() - resCh <- res - }() - - // Mark jobs 1 and 2 as done, freeing up the workers. - done1() - done2() - - assert.ElementsMatch(t, []int{6, 8}, []int{<-resCh, <-resCh}) -} - -func TestPool_Process(t *testing.T) { - doubler := newDoubler(runtime.NumCPU()) - - var ( - res = make(map[int]int) - lock sync.Mutex - ) - - require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(reqVal, resVal int, err error) error { - require.NoError(t, err) - - lock.Lock() - defer lock.Unlock() - - res[reqVal] = resVal - - return nil - })) - - assert.Equal(t, map[int]int{ - 1: 2, - 2: 4, - 3: 6, - 4: 8, - 5: 10, - }, res) -} - -func TestPool_Process_Error(t *testing.T) { - doubler := newDoublerWithError(runtime.NumCPU()) - - assert.Error(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(_ int, _ int, err error) error { - return err - })) -} - -func TestPool_Process_Parallel(t *testing.T) { - doubler := newDoubler(runtime.NumCPU(), 100*time.Millisecond) - - var wg sync.WaitGroup - - for i := 0; i < 8; i++ { - wg.Add(1) - - go func() { - defer wg.Done() - - require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4}, func(_ int, _ int, err error) error { - return nil - })) - }() - } - - wg.Wait() -} - -func TestPool_ProcessAll(t *testing.T) { - doubler := newDoubler(runtime.NumCPU()) - - res, err := doubler.ProcessAll(context.Background(), []int{1, 2, 3, 4, 5}) - require.NoError(t, err) - - assert.Equal(t, map[int]int{ - 1: 2, - 2: 4, - 3: 6, - 4: 8, - 5: 10, - }, res) -} - -func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] { - return New(workers, func(ctx context.Context, req int) (int, error) { - if len(delay) > 0 { - time.Sleep(delay[0]) - } - - return 2 * req, nil - }) -} - -func newDoublerWithError(workers int) *Pool[int, int] { - return New(workers, func(ctx context.Context, req int) (int, error) { - if req%2 == 0 { - return 0, errors.New("oops") - } - - return 2 * req, nil - }) -} diff --git a/internal/safe/map.go b/internal/safe/map.go index fee77515..39a2d58f 100644 --- a/internal/safe/map.go +++ b/internal/safe/map.go @@ -23,7 +23,15 @@ func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] { return m } -func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool { +func (m *Map[Key, Val]) Has(key Key) bool { + m.lock.RLock() + defer m.lock.RUnlock() + + _, ok := m.data[key] + return ok +} + +func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool { m.lock.RLock() defer m.lock.RUnlock() @@ -37,7 +45,7 @@ func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool { return true } -func (m *Map[Key, Val]) GetErr(key Key, fn func(val Val) error) (bool, error) { +func (m *Map[Key, Val]) GetErr(key Key, fn func(Val) error) (bool, error) { m.lock.RLock() defer m.lock.RUnlock() @@ -56,6 +64,15 @@ func (m *Map[Key, Val]) Set(key Key, val Val) { m.data[key] = val } +func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) { + m.lock.RLock() + defer m.lock.RUnlock() + + for key, val := range m.data { + fn(key, val) + } +} + func (m *Map[Key, Val]) Keys(fn func(keys []Key)) { m.lock.RLock() defer m.lock.RUnlock() @@ -70,28 +87,52 @@ func (m *Map[Key, Val]) Values(fn func(vals []Val)) { fn(maps.Values(m.data)) } -func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) Ret) (Ret, bool) { +func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret, fallback func() Ret) Ret { m.lock.RLock() defer m.lock.RUnlock() val, ok := m.data[key] if !ok { - return *new(Ret), false + return fallback() } - return fn(val), true + return fn(val) } -func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) (Ret, error)) (Ret, bool, error) { +func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) { m.lock.RLock() defer m.lock.RUnlock() val, ok := m.data[key] if !ok { - return *new(Ret), false, nil + return fallback() } - ret, err := fn(val) - - return ret, true, err + return fn(val) +} + +func FindMap[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) Ret, fallback func() Ret) Ret { + m.lock.RLock() + defer m.lock.RUnlock() + + for _, val := range m.data { + if cmp(val) { + return fn(val) + } + } + + return fallback() +} + +func FindMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + for _, val := range m.data { + if cmp(val) { + return fn(val) + } + } + + return fallback() } diff --git a/internal/try/try.go b/internal/try/try.go new file mode 100644 index 00000000..f26a1669 --- /dev/null +++ b/internal/try/try.go @@ -0,0 +1,49 @@ +package try + +import ( + "fmt" + + "github.com/sirupsen/logrus" +) + +// Catch tries to execute the `try` function, and if it fails or panics, +// it executes the `handlers` functions in order. +func Catch(try func() error, handlers ...func() error) error { + if _, err := CatchVal(func() (any, error) { return nil, try() }, handlers...); err != nil { + return err + } + + return nil +} + +// CatchVal tries to execute the `try` function, and if it fails or panics, +// it executes the `handlers` functions in order. +func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, err error) { + defer func() { + if r := recover(); r != nil { + catch(handlers...) + err = fmt.Errorf("panic: %v", r) + } + }() + + if res, err = try(); err != nil { + catch(handlers...) + return res, err + } + + return res, nil +} + +func catch(handlers ...func() error) { + defer func() { + if r := recover(); r != nil { + logrus.WithField("panic", r).Error("Panic in catch") + } + }() + + for _, handler := range handlers { + if err := handler(); err != nil { + logrus.WithError(err).Error("Failed to handle error") + } + } +} diff --git a/internal/try/try_test.go b/internal/try/try_test.go new file mode 100644 index 00000000..6c55646a --- /dev/null +++ b/internal/try/try_test.go @@ -0,0 +1,74 @@ +package try + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTry(t *testing.T) { + res, err := CatchVal(func() (string, error) { + return "foo", nil + }) + require.NoError(t, err) + require.Equal(t, "foo", res) +} + +func TestTryCatch(t *testing.T) { + tryErr := fmt.Errorf("oops") + + res, err := CatchVal( + func() (string, error) { + return "", tryErr + }, + func() error { + return nil + }, + ) + require.ErrorIs(t, err, tryErr) + require.Zero(t, res) +} + +func TestTryCatchError(t *testing.T) { + tryErr := fmt.Errorf("oops") + + res, err := CatchVal( + func() (string, error) { + return "", tryErr + }, + func() error { + return fmt.Errorf("catch error") + }, + ) + require.ErrorIs(t, err, tryErr) + require.Zero(t, res) +} + +func TestTryPanic(t *testing.T) { + res, err := CatchVal( + func() (string, error) { + panic("oops") + }, + func() error { + return nil + }, + ) + require.ErrorContains(t, err, "panic: oops") + require.Zero(t, res) +} + +func TestTryCatchPanic(t *testing.T) { + tryErr := fmt.Errorf("oops") + + res, err := CatchVal( + func() (string, error) { + return "", tryErr + }, + func() error { + panic("oops") + }, + ) + require.ErrorIs(t, err, tryErr) + require.Zero(t, res) +} diff --git a/internal/user/events.go b/internal/user/events.go index 17695966..0c6ca097 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -259,7 +259,12 @@ func (user *User) handleMessageEvents(ctx context.Context, messageEvents []litea } func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error { - buildRes, err := user.buildRFC822(ctx, event.Message) + full, err := user.client.GetFullMessage(ctx, event.Message.ID) + if err != nil { + return fmt.Errorf("failed to get full message: %w", err) + } + + buildRes, err := buildRFC822(ctx, full, user.addrKRs) if err != nil { return fmt.Errorf("failed to build RFC822: %w", err) } diff --git a/internal/user/sync.go b/internal/user/sync.go index 754e25db..ce7c818d 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -4,14 +4,11 @@ import ( "context" "errors" "fmt" - "runtime" "strings" "time" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/queue" - "github.com/bradenaw/juniper/iterator" - "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" @@ -105,34 +102,38 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue. func (user *User) syncMessages(ctx context.Context) error { // Determine which messages to sync. - metadata, err := user.client.GetAllMessageMetadata(ctx, nil) + allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil) if err != nil { return fmt.Errorf("get all message metadata: %w", err) } - // If possible, begin syncing from the last synced message. + metadata := allMetadata + + // If possible, begin syncing from one beyond the last synced message. if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" { if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool { return metadata.ID == beginID }); idx >= 0 { - metadata = metadata[idx:] + metadata = metadata[idx+1:] } } // Process the metadata, building the messages. - buildCh := stream.Chunk(parallel.MapStream( - ctx, - stream.FromIterator(iterator.Slice(metadata)), - runtime.NumCPU()*runtime.NumCPU()/2, - runtime.NumCPU()*runtime.NumCPU()/2, - user.buildRFC822, + buildCh := stream.Chunk(stream.Map( + user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string { + return metadata.ID + })...), + func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) { + return buildRFC822(ctx, full, user.addrKRs) + }, ), maxBatchSize) + defer buildCh.Close() // Create the flushers, one per update channel. flushers := make(map[string]*flusher) for addrID, updateCh := range user.updateCh { - flusher := newFlusher(user.ID(), updateCh, maxUpdateSize) + flusher := newFlusher(updateCh, maxUpdateSize) defer flusher.flush(ctx, true) flushers[addrID] = flusher @@ -142,6 +143,8 @@ func (user *User) syncMessages(ctx context.Context) error { reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second) defer reporter.done() + var count int + // Send each update to the appropriate flusher. for { batch, err := buildCh.Next(ctx) @@ -170,6 +173,8 @@ func (user *User) syncMessages(ctx context.Context) error { } reporter.add(len(batch)) + + count += len(batch) } } diff --git a/internal/user/sync_build.go b/internal/user/sync_build.go index d917944a..73544e55 100644 --- a/internal/user/sync_build.go +++ b/internal/user/sync_build.go @@ -6,6 +6,7 @@ import ( "time" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/pkg/message" "github.com/bradenaw/juniper/xslices" "gitlab.protontech.ch/go/liteapi" @@ -29,30 +30,20 @@ func defaultJobOpts() message.JobOptions { } } -func (user *User) buildRFC822(ctx context.Context, metadata liteapi.MessageMetadata) (*buildRes, error) { - msg, err := user.client.GetMessage(ctx, metadata.ID) +func buildRFC822(ctx context.Context, full liteapi.FullMessage, addrKRs map[string]*crypto.KeyRing) (*buildRes, error) { + literal, err := message.BuildRFC822(addrKRs[full.AddressID], full.Message, full.AttData, defaultJobOpts()) if err != nil { - return nil, fmt.Errorf("failed to get message %s: %w", metadata.ID, err) + return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err) } - attData, err := user.attPool.ProcessAll(ctx, xslices.Map(msg.Attachments, func(att liteapi.Attachment) string { return att.ID })) + update, err := newMessageCreatedUpdate(full.MessageMetadata, literal) if err != nil { - return nil, fmt.Errorf("failed to get attachments for message %s: %w", metadata.ID, err) - } - - literal, err := message.BuildRFC822(user.addrKRs[msg.AddressID], msg, attData, defaultJobOpts()) - if err != nil { - return nil, fmt.Errorf("failed to build message %s: %w", metadata.ID, err) - } - - update, err := newMessageCreatedUpdate(metadata, literal) - if err != nil { - return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", metadata.ID, err) + return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", full.ID, err) } return &buildRes{ - messageID: metadata.ID, - addressID: metadata.AddressID, + messageID: full.ID, + addressID: full.AddressID, update: update, }, nil } diff --git a/internal/user/sync_flusher.go b/internal/user/sync_flusher.go index 0050fb58..2d873bc3 100644 --- a/internal/user/sync_flusher.go +++ b/internal/user/sync_flusher.go @@ -9,21 +9,19 @@ import ( ) type flusher struct { - userID string updateCh *queue.QueuedChannel[imap.Update] + updates []*imap.MessageCreated - updates []*imap.MessageCreated - maxChunkSize int - curChunkSize int + maxUpdateSize int + curChunkSize int pushLock sync.Mutex } -func newFlusher(userID string, updateCh *queue.QueuedChannel[imap.Update], maxChunkSize int) *flusher { +func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *flusher { return &flusher{ - userID: userID, - updateCh: updateCh, - maxChunkSize: maxChunkSize, + updateCh: updateCh, + maxUpdateSize: maxUpdateSize, } } @@ -33,20 +31,18 @@ func (f *flusher) push(ctx context.Context, update *imap.MessageCreated) { f.updates = append(f.updates, update) - if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize { + if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxUpdateSize { f.flush(ctx, false) } } func (f *flusher) flush(ctx context.Context, wait bool) { - if len(f.updates) == 0 { - return + if len(f.updates) > 0 { + f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...)) + f.updates = nil + f.curChunkSize = 0 } - f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...)) - f.updates = nil - f.curChunkSize = 0 - if wait { update := imap.NewNoop() defer update.WaitContext(ctx) diff --git a/internal/user/user.go b/internal/user/user.go index ffa5dfca..7fcaeb9c 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -5,7 +5,6 @@ import ( "context" "encoding/hex" "fmt" - "runtime" "time" "github.com/ProtonMail/gluon/connector" @@ -14,7 +13,6 @@ import ( "github.com/ProtonMail/gluon/wait" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/events" - "github.com/ProtonMail/proton-bridge/v2/internal/pool" "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" @@ -31,7 +29,6 @@ var ( type User struct { vault *vault.User client *liteapi.Client - attPool *pool.Pool[string, []byte] eventCh *queue.QueuedChannel[events.Event] apiUser *safe.Type[liteapi.User] @@ -91,7 +88,6 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU user := &User{ vault: encVault, client: client, - attPool: pool.New(runtime.NumCPU(), client.GetAttachment), eventCh: queue.NewQueuedChannel[events.Event](0, 0), apiUser: safe.NewType(apiUser), @@ -123,7 +119,7 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU // If we haven't synced yet, do it first. // If it fails, we don't start the event loop. - // Oterwise, begin processing API events, logging any errors that occur. + // Otherwise, begin processing API events, logging any errors that occur. go func() { if status := user.vault.SyncStatus(); !status.HasMessages { if err := <-user.startSync(); err != nil { @@ -336,8 +332,17 @@ func (user *User) NewSMTPSession(email string) (smtp.Session, error) { } // Logout logs the user out from the API. +// If withVault is true, the user's vault is also cleared. func (user *User) Logout(ctx context.Context) error { - return user.client.AuthDelete(ctx) + if err := user.client.AuthDelete(ctx); err != nil { + return fmt.Errorf("failed to delete auth: %w", err) + } + + if err := user.vault.Clear(); err != nil { + return fmt.Errorf("failed to clear vault: %w", err) + } + + return nil } // Close closes ongoing connections and cleans up resources. @@ -345,9 +350,6 @@ func (user *User) Close() error { // Cancel ongoing syncs. user.stopSync() - // Close the attachment pool. - user.attPool.Done() - // Close the user's API client. user.client.Close() diff --git a/internal/user/user_test.go b/internal/user/user_test.go index a108b2b2..76c1a9ce 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -89,13 +89,13 @@ func withAPI(t *testing.T, ctx context.Context, fn func(context.Context, *server func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) { var addrIDs []string - userID, addrID, err := s.CreateUser(username, password, emails[0]) + userID, addrID, err := s.CreateUser(username, emails[0], []byte(password)) require.NoError(t, err) addrIDs = append(addrIDs, addrID) for _, email := range emails[1:] { - addrID, err := s.CreateAddress(userID, email, password) + addrID, err := s.CreateAddress(userID, email, []byte(password)) require.NoError(t, err) addrIDs = append(addrIDs, addrID) diff --git a/internal/vault/user.go b/internal/vault/user.go index 2eb207f7..4727f599 100644 --- a/internal/vault/user.go +++ b/internal/vault/user.go @@ -138,3 +138,12 @@ func (user *User) SetEventID(eventID string) error { data.EventID = eventID }) } + +// Clear clears the user's auth secrets. +func (user *User) Clear() error { + return user.vault.modUser(user.userID, func(data *UserData) { + data.AuthUID = "" + data.AuthRef = "" + data.KeyPass = nil + }) +} diff --git a/internal/vault/user_test.go b/internal/vault/user_test.go index 76fd8c70..4e4c322e 100644 --- a/internal/vault/user_test.go +++ b/internal/vault/user_test.go @@ -58,7 +58,7 @@ func TestUser_Clear(t *testing.T) { require.Equal(t, "keyPass", string(user.KeyPass())) // Clear the user's auth information. - require.NoError(t, s.ClearUser("userID")) + require.NoError(t, user.Clear()) // Check the user's cleared auth information. require.Empty(t, user.AuthUID()) diff --git a/internal/vault/vault.go b/internal/vault/vault.go index 52857830..85f438ca 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -107,14 +107,6 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [ return vault.GetUser(userID) } -func (vault *Vault) ClearUser(userID string) error { - return vault.modUser(userID, func(data *UserData) { - data.AuthUID = "" - data.AuthRef = "" - data.KeyPass = nil - }) -} - // DeleteUser removes the given user from the vault. func (vault *Vault) DeleteUser(userID string) error { return vault.mod(func(data *Data) { diff --git a/tests/api_test.go b/tests/api_test.go index 9c64e195..e6803a85 100644 --- a/tests/api_test.go +++ b/tests/api_test.go @@ -12,8 +12,8 @@ type API interface { GetHostURL() string AddCallWatcher(func(server.Call), ...string) - CreateUser(username, password, address string) (string, string, error) - CreateAddress(userID, address, password string) (string, error) + CreateUser(username, address string, password []byte) (string, string, error) + CreateAddress(userID, address string, password []byte) (string, error) RemoveAddress(userID, addrID string) error RevokeUser(userID string) error diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index f46dc3a5..6f25e7d5 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -2,17 +2,19 @@ package tests import ( "context" + "crypto/tls" "fmt" "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "gitlab.protontech.ch/go/liteapi" ) func (t *testCtx) startBridge() error { // Bridge will enable the proxy by default at startup. - t.mocks.ProxyDialer.EXPECT().AllowProxy() + t.mocks.ProxyCtl.EXPECT().AllowProxy() // Get the path to the vault. vaultDir, err := t.locator.ProvideSettingsPath() @@ -41,7 +43,8 @@ func (t *testCtx) startBridge() error { vault, useragent.New(), t.mocks.TLSReporter, - t.mocks.ProxyDialer, + liteapi.NewDialer(t.netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + t.mocks.ProxyCtl, t.mocks.Autostarter, t.mocks.Updater, t.version, diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 5838f3b6..ffd648af 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -24,7 +24,7 @@ type testCtx struct { // These are the objects supporting the test. dir string api API - dialer *bridge.TestDialer + netCtl *liteapi.NetCtl locator *locations.Locations storeKey []byte version *semver.Version @@ -76,15 +76,13 @@ type smtpClient struct { func newTestCtx(tb testing.TB) *testCtx { dir := tb.TempDir() - dialer := bridge.NewTestDialer() - ctx := &testCtx{ dir: dir, api: newFakeAPI(), - dialer: dialer, + netCtl: liteapi.NewNetCtl(), locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"), storeKey: []byte("super-secret-store-key"), - mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion), + mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion), version: defaultVersion, userIDByName: make(map[string]string), diff --git a/tests/environment_test.go b/tests/environment_test.go index 955f2540..e57f72db 100644 --- a/tests/environment_test.go +++ b/tests/environment_test.go @@ -38,12 +38,12 @@ func (s *scenario) itFailsWithError(wantErr string) error { } func (s *scenario) internetIsTurnedOff() error { - s.t.dialer.SetCanDial(false) + s.t.netCtl.SetCanDial(false) return nil } func (s *scenario) internetIsTurnedOn() error { - s.t.dialer.SetCanDial(true) + s.t.netCtl.SetCanDial(true) return nil } diff --git a/tests/user_test.go b/tests/user_test.go index 2e7d88bd..cc80f414 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -14,7 +14,7 @@ import ( func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error { // Create the user. - userID, addrID, err := s.t.api.CreateUser(username, password, username) + userID, addrID, err := s.t.api.CreateUser(username, username, []byte(password)) if err != nil { return err } @@ -34,7 +34,7 @@ func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, passwor func (s *scenario) theAccountHasAdditionalAddress(username, address string) error { userID := s.t.getUserID(username) - addrID, err := s.t.api.CreateAddress(userID, address, s.t.getUserPass(userID)) + addrID, err := s.t.api.CreateAddress(userID, address, []byte(s.t.getUserPass(userID))) if err != nil { return err }