From 612fb7ad7b6bae19c656e4cd42f599314d1257dd Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Tue, 27 Sep 2022 12:02:28 +0200 Subject: [PATCH] GODT-1815: Start without internet, load users later --- internal/bridge/bridge.go | 53 ++++++++++++- internal/bridge/bridge_test.go | 104 ++++++++++++------------- internal/bridge/mocks.go | 8 +- internal/bridge/settings_test.go | 49 ++++++------ internal/bridge/users.go | 6 +- internal/bridge/users_test.go | 123 +++++++++++++++--------------- internal/events/connection.go | 10 +-- internal/frontend/cli/frontend.go | 12 +-- internal/frontend/grpc/service.go | 8 +- tests/bdd_test.go | 1 + tests/bridge_test.go | 13 ++-- tests/ctx_bridge_test.go | 2 +- tests/ctx_test.go | 8 +- tests/environment_test.go | 4 +- tests/features/user/login.feature | 8 ++ tests/user_test.go | 9 +++ 16 files changed, 243 insertions(+), 175 deletions(-) diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 533a9c58..8a0f9214 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "sync" + "time" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon" @@ -150,9 +151,13 @@ func New( } api.AddStatusObserver(func(status liteapi.Status) { - bridge.publish(events.ConnStatus{ - Status: status, - }) + switch { + case status == liteapi.StatusUp: + go bridge.onStatusUp() + + case status == liteapi.StatusDown: + go bridge.onStatusDown() + } }) api.AddErrorHandler(liteapi.AppVersionBadCode, func() { @@ -288,6 +293,48 @@ func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) { }) } +func (bridge *Bridge) onStatusUp() { + bridge.publish(events.ConnStatusUp{}) + + for _, userID := range bridge.vault.GetUserIDs() { + if _, ok := bridge.users[userID]; !ok { + if vaultUser, err := bridge.vault.GetUser(userID); err != nil { + logrus.WithError(err).Error("Failed to get user from vault") + } else if err := bridge.loadUser(context.Background(), vaultUser); err != nil { + logrus.WithError(err).Error("Failed to load user") + } + } + } +} + +func (bridge *Bridge) onStatusDown() { + bridge.publish(events.ConnStatusDown{}) + + upCh, done := bridge.GetEvents(events.ConnStatusUp{}) + defer done() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + backoff := time.Second + + for { + select { + case <-upCh: + return + + case <-time.After(backoff): + if err := bridge.api.Ping(ctx); err != nil { + logrus.WithError(err).Debug("Failed to ping API") + } + } + + if backoff < 30*time.Second { + backoff *= 2 + } + } +} + func loadTLSConfig(vault *vault.Vault) (*tls.Config, error) { cert, err := tls.X509KeyPair(vault.GetBridgeTLSCert(), vault.GetBridgeTLSKey()) if err != nil { diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 6e6fd3ef..9c83088c 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -15,7 +15,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" "github.com/stretchr/testify/require" - "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" ) @@ -30,24 +29,24 @@ var ( ) func TestBridge_ConnStatus(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of connection status events. - eventCh, done := bridge.GetEvents(events.ConnStatus{}) + eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{}) defer done() // Simulate network disconnect. - mocks.TLSDialer.SetCanDial(false) + dialer.SetCanDial(false) // Trigger some operation that will fail due to the network disconnect. _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.Error(t, err) // Wait for the event. - require.Equal(t, events.ConnStatus{Status: liteapi.StatusDown}, <-eventCh) + require.Equal(t, events.ConnStatusDown{}, <-eventCh) // Simulate network reconnect. - mocks.TLSDialer.SetCanDial(true) + dialer.SetCanDial(true) // Trigger some operation that will succeed due to the network reconnect. userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) @@ -55,14 +54,14 @@ func TestBridge_ConnStatus(t *testing.T) { require.NotEmpty(t, userID) // Wait for the event. - require.Equal(t, events.ConnStatus{Status: liteapi.StatusUp}, <-eventCh) + require.Equal(t, events.ConnStatusUp{}, <-eventCh) }) }) } func TestBridge_TLSIssue(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of TLS issue events. tlsEventCh, done := bridge.GetEvents(events.TLSIssue{}) defer done() @@ -79,8 +78,8 @@ func TestBridge_TLSIssue(t *testing.T) { } func TestBridge_Focus(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of TLS issue events. raiseCh, done := bridge.GetEvents(events.Raise{}) defer done() @@ -95,14 +94,14 @@ func TestBridge_Focus(t *testing.T) { } func TestBridge_UserAgent(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { var calls []server.Call s.AddCallWatcher(func(call server.Call) { calls = append(calls, call) }) - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Set the platform to something other than the default. bridge.SetCurrentPlatform("platform") @@ -120,7 +119,7 @@ func TestBridge_UserAgent(t *testing.T) { } func TestBridge_Cookies(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { var calls []server.Call s.AddCallWatcher(func(call server.Call) { @@ -130,7 +129,7 @@ func TestBridge_Cookies(t *testing.T) { var sessionID string // Start bridge and add a user so that API assigns us a session ID via cookie. - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) @@ -141,7 +140,7 @@ func TestBridge_Cookies(t *testing.T) { }) // Start bridge again and check that it uses the same session ID. - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id") require.NoError(t, err) @@ -151,8 +150,8 @@ func TestBridge_Cookies(t *testing.T) { } func TestBridge_CheckUpdate(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Disable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(false)) @@ -182,8 +181,8 @@ func TestBridge_CheckUpdate(t *testing.T) { } func TestBridge_AutoUpdate(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Enable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(true)) @@ -208,8 +207,8 @@ func TestBridge_AutoUpdate(t *testing.T) { } func TestBridge_ManualUpdate(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Disable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(false)) @@ -235,8 +234,8 @@ func TestBridge_ManualUpdate(t *testing.T) { } func TestBridge_ForceUpdate(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Get a stream of update events. updateCh, done := bridge.GetEvents(events.UpdateForced{}) defer done() @@ -255,11 +254,11 @@ func TestBridge_ForceUpdate(t *testing.T) { } func TestBridge_BadVaultKey(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { var userID string // Login a user. - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) @@ -267,24 +266,24 @@ func TestBridge_BadVaultKey(t *testing.T) { }) // Start bridge with the correct vault key -- it should load the users correctly. - withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs()) }) // Start bridge with a bad vault key, the vault will be wiped and bridge will show no users. - withBridge(t, s.GetHostURL(), locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) }) // Start bridge with a nil vault key, the vault will be wiped and bridge will show no users. - withBridge(t, s.GetHostURL(), locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) }) }) } // withEnv creates the full test environment and runs the tests. -func withEnv(t *testing.T, tests func(server *server.Server, locator bridge.Locator, vaultKey []byte)) { +func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) { // Create test API. server := server.NewTLS() defer server.Close() @@ -297,17 +296,24 @@ func withEnv(t *testing.T, tests func(server *server.Server, locator bridge.Loca vaultKey, err := crypto.RandomToken(32) require.NoError(t, err) - // Run the tests. - tests(server, locations.New(bridge.NewTestLocationsProvider(t), "config-name"), vaultKey) -} - -// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. -func withBridge(t *testing.T, apiURL string, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) { + // Create a context used for the test. ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Run the tests. + tests( + ctx, + server, + bridge.NewTestDialer(), + locations.New(bridge.NewTestLocationsProvider(t), "config-name"), + vaultKey, + ) +} + +// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. +func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) { // Create the mock objects used in the tests. - mocks := bridge.NewMocks(t, v2_3_0, v2_3_0) + mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0) // Bridge will enable the proxy by default at startup. mocks.ProxyDialer.EXPECT().AllowProxy() @@ -320,25 +326,19 @@ func withBridge(t *testing.T, apiURL string, locator bridge.Locator, vaultKey [] vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey) require.NoError(t, err) + // Let the IMAP and SMTP servers choose random available ports for this test. + require.NoError(t, vault.SetIMAPPort(0)) + require.NoError(t, vault.SetSMTPPort(0)) + // Create a new bridge. - bridge, err := bridge.New( - apiURL, - locator, - vault, - useragent.New(), - mocks.TLSReporter, - mocks.ProxyDialer, - mocks.Autostarter, - mocks.Updater, - v2_3_0, - ) + bridge, err := bridge.New(apiURL, locator, vault, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0) require.NoError(t, err) + // Close the bridge when done. + defer bridge.Close(ctx) + // Use the bridge. tests(bridge, mocks) - - // Close the bridge. - require.NoError(t, bridge.Close(ctx)) } // must is a helper function that panics on error. diff --git a/internal/bridge/mocks.go b/internal/bridge/mocks.go index 67fd3dc4..e0d08ed8 100644 --- a/internal/bridge/mocks.go +++ b/internal/bridge/mocks.go @@ -14,9 +14,7 @@ import ( ) type Mocks struct { - TLSDialer *TestDialer ProxyDialer *mocks.MockProxyDialer - TLSReporter *mocks.MockTLSReporter TLSIssueCh chan struct{} @@ -24,13 +22,11 @@ type Mocks struct { Autostarter *mocks.MockAutostarter } -func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks { +func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Version) *Mocks { ctl := gomock.NewController(tb) mocks := &Mocks{ - TLSDialer: NewTestDialer(), ProxyDialer: mocks.NewMockProxyDialer(ctl), - TLSReporter: mocks.NewMockTLSReporter(ctl), TLSIssueCh: make(chan struct{}), @@ -44,7 +40,7 @@ func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks { gomock.Any(), gomock.Any(), ).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) { - return mocks.TLSDialer.DialTLSContext(ctx, network, address) + return dialer.DialTLSContext(ctx, network, address) }).AnyTimes() // When getting the TLS issue channel, we want to return the test channel. diff --git a/internal/bridge/settings_test.go b/internal/bridge/settings_test.go index 2397b958..ae94597b 100644 --- a/internal/bridge/settings_test.go +++ b/internal/bridge/settings_test.go @@ -11,8 +11,8 @@ import ( ) func TestBridge_Settings_GluonDir(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Create a user. _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) @@ -34,23 +34,25 @@ func TestBridge_Settings_GluonDir(t *testing.T) { } func TestBridge_Settings_IMAPPort(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // By default, the port is 1143. - require.Equal(t, 1143, bridge.GetIMAPPort()) + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + curPort := bridge.GetIMAPPort() // Set the port to 1144. require.NoError(t, bridge.SetIMAPPort(1144)) // Get the new setting. require.Equal(t, 1144, bridge.GetIMAPPort()) + + // Assert that it has changed. + require.NotEqual(t, curPort, bridge.GetIMAPPort()) }) }) } func TestBridge_Settings_IMAPSSL(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, IMAP SSL is disabled. require.False(t, bridge.GetIMAPSSL()) @@ -64,23 +66,26 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) { } func TestBridge_Settings_SMTPPort(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // By default, the port is 1025. - require.Equal(t, 1025, bridge.GetSMTPPort()) + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + curPort := bridge.GetSMTPPort() // Set the port to 1024. require.NoError(t, bridge.SetSMTPPort(1024)) // Get the new setting. require.Equal(t, 1024, bridge.GetSMTPPort()) + + // Assert that it has changed. + require.NotEqual(t, curPort, bridge.GetSMTPPort()) + }) }) } func TestBridge_Settings_SMTPSSL(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, SMTP SSL is disabled. require.False(t, bridge.GetSMTPSSL()) @@ -94,8 +99,8 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) { } func TestBridge_Settings_Proxy(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, proxy is allowed. require.True(t, bridge.GetProxyAllowed()) @@ -110,8 +115,8 @@ func TestBridge_Settings_Proxy(t *testing.T) { } func TestBridge_Settings_Autostart(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, autostart is disabled. require.False(t, bridge.GetAutostart()) @@ -126,8 +131,8 @@ func TestBridge_Settings_Autostart(t *testing.T) { } func TestBridge_Settings_FirstStart(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, first start is true. require.True(t, bridge.GetFirstStart()) @@ -141,8 +146,8 @@ func TestBridge_Settings_FirstStart(t *testing.T) { } func TestBridge_Settings_FirstStartGUI(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // By default, first start is true. require.True(t, bridge.GetFirstStartGUI()) diff --git a/internal/bridge/users.go b/internal/bridge/users.go index 3a2667e4..bf211972 100644 --- a/internal/bridge/users.go +++ b/internal/bridge/users.go @@ -176,8 +176,10 @@ func (bridge *Bridge) loadUsers(ctx context.Context) error { if err := bridge.loadUser(ctx, user); err != nil { logrus.WithError(err).Error("Failed to load connected user") - if err := user.Clear(); err != nil { - logrus.WithError(err).Error("Failed to clear user") + if _, ok := err.(*resty.ResponseError); ok { + if err := user.Clear(); err != nil { + logrus.WithError(err).Error("Failed to clear user") + } } continue diff --git a/internal/bridge/users_test.go b/internal/bridge/users_test.go index 51dd838b..f868fb09 100644 --- a/internal/bridge/users_test.go +++ b/internal/bridge/users_test.go @@ -12,13 +12,13 @@ import ( ) func TestBridge_WithoutUsers(t *testing.T) { - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) }) - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Empty(t, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) }) @@ -26,11 +26,8 @@ func TestBridge_WithoutUsers(t *testing.T) { } func TestBridge_Login(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID, err := bridge.LoginUser(ctx, username, password, nil, nil) require.NoError(t, err) @@ -43,11 +40,8 @@ func TestBridge_Login(t *testing.T) { } func TestBridge_LoginLogoutLogin(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -74,11 +68,8 @@ func TestBridge_LoginLogoutLogin(t *testing.T) { } func TestBridge_LoginDeleteLogin(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -105,11 +96,8 @@ func TestBridge_LoginDeleteLogin(t *testing.T) { } func TestBridge_LoginDeauthLogin(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -142,13 +130,10 @@ func TestBridge_LoginDeauthLogin(t *testing.T) { func TestBridge_LoginExpireLogin(t *testing.T) { const authLife = 2 * time.Second - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { s.SetAuthLife(authLife) - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. Its auth will only be valid for a short time. userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -162,41 +147,64 @@ func TestBridge_LoginExpireLogin(t *testing.T) { } func TestBridge_FailToLoad(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { var userID string - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // Login the user. + // Login the user. + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) }) // Deauth the user while bridge is stopped. require.NoError(t, s.RevokeUser(userID)) - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // The user is disconnected. + // When bridge starts, the user will not be logged in. + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) }) }) } -func TestBridge_LoginRestart(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { +func TestBridge_LoadWithoutInternet(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { var userID string - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Login the user. + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) + }) + + // Simulate loss of internet connection. + dialer.SetCanDial(false) + + // Start bridge without internet. + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Initially, users are not connected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + + // Simulate internet connection. + dialer.SetCanDial(true) + + // The user will eventually be connected. + require.Eventually(t, func() bool { + return len(getConnectedUserIDs(t, bridge)) == 1 && getConnectedUserIDs(t, bridge)[0] == userID + }, 10*time.Second, time.Second) + }) + }) +} + +func TestBridge_LoginRestart(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + var userID string + + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) }) - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // The user is still connected. require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) @@ -205,13 +213,10 @@ func TestBridge_LoginRestart(t *testing.T) { } func TestBridge_LoginLogoutRestart(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { var userID string - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -219,7 +224,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) { require.NoError(t, bridge.LogoutUser(ctx, userID)) }) - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // The user is still disconnected. require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) @@ -228,13 +233,10 @@ func TestBridge_LoginLogoutRestart(t *testing.T) { } func TestBridge_LoginDeleteRestart(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { var userID string - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -242,7 +244,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) { require.NoError(t, bridge.DeleteUser(ctx, userID)) }) - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // The user is still gone. require.Empty(t, bridge.GetUserIDs()) require.Empty(t, getConnectedUserIDs(t, bridge)) @@ -251,13 +253,10 @@ func TestBridge_LoginDeleteRestart(t *testing.T) { } func TestBridge_BridgePass(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { var userID, pass string - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) @@ -274,8 +273,8 @@ func TestBridge_BridgePass(t *testing.T) { require.Equal(t, pass, pass) }) - withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // The bridge should load schizofrenic. + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // The bridge should load the user. require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) diff --git a/internal/events/connection.go b/internal/events/connection.go index c3e456c3..1b8b64ad 100644 --- a/internal/events/connection.go +++ b/internal/events/connection.go @@ -1,13 +1,13 @@ package events -import "gitlab.protontech.ch/go/liteapi" - type TLSIssue struct { eventBase } -type ConnStatus struct { +type ConnStatusUp struct { + eventBase +} + +type ConnStatusDown struct { eventBase - - Status liteapi.Status } diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index 0d4384ab..f45cacc7 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -25,7 +25,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/vault" - "gitlab.protontech.ch/go/liteapi" "github.com/abiosoft/ishell" "github.com/sirupsen/logrus" @@ -283,14 +282,11 @@ func (f *frontendCLI) watchEvents() { for event := range eventCh { switch event := event.(type) { - case events.ConnStatus: - switch event.Status { - case liteapi.StatusUp: - f.notifyInternetOn() + case events.ConnStatusUp: + f.notifyInternetOn() - case liteapi.StatusDown: - f.notifyInternetOff() - } + case events.ConnStatusDown: + f.notifyInternetOff() case events.UserDeauth: user, err := f.bridge.GetUserInfo(event.UserID) diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index e82c09c8..31677375 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -38,7 +38,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/pkg/restarter" "github.com/google/uuid" "github.com/sirupsen/logrus" - "gitlab.protontech.ch/go/liteapi" "google.golang.org/grpc" codes "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -219,8 +218,11 @@ func (s *Service) watchEvents() { for event := range eventCh { switch event := event.(type) { - case events.ConnStatus: - _ = s.SendEvent(NewInternetStatusEvent(event.Status == liteapi.StatusUp)) + case events.ConnStatusUp: + _ = s.SendEvent(NewInternetStatusEvent(true)) + + case events.ConnStatusDown: + _ = s.SendEvent(NewInternetStatusEvent(false)) case events.Raise: _ = s.SendEvent(NewShowMainWindowEvent()) diff --git a/tests/bdd_test.go b/tests/bdd_test.go index f0e4d8d2..673c11ec 100644 --- a/tests/bdd_test.go +++ b/tests/bdd_test.go @@ -109,6 +109,7 @@ func TestFeatures(testingT *testing.T) { ctx.Step(`^user "([^"]*)" is deleted$`, s.userIsDeleted) ctx.Step(`^the auth of user "([^"]*)" is revoked$`, s.theAuthOfUserIsRevoked) ctx.Step(`^user "([^"]*)" is listed and connected$`, s.userIsListedAndConnected) + ctx.Step(`^user "([^"]*)" is eventually listed and connected$`, s.userIsEventuallyListedAndConnected) ctx.Step(`^user "([^"]*)" is listed but not connected$`, s.userIsListedButNotConnected) ctx.Step(`^user "([^"]*)" is not listed$`, s.userIsNotListed) ctx.Step(`^user "([^"]*)" finishes syncing$`, s.userFinishesSyncing) diff --git a/tests/bridge_test.go b/tests/bridge_test.go index 92fe4b12..5065b15b 100644 --- a/tests/bridge_test.go +++ b/tests/bridge_test.go @@ -9,7 +9,6 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/events" - "gitlab.protontech.ch/go/liteapi" ) func (s *scenario) bridgeStarts() error { @@ -85,9 +84,9 @@ func (s *scenario) theUserReportsABug() error { } func (s *scenario) bridgeSendsAConnectionUpEvent() error { - return try(s.t.connStatusCh, 5*time.Second, func(event events.ConnStatus) error { - if event.Status != liteapi.StatusUp { - return fmt.Errorf("expected connection up event, got %v", event.Status) + return try(s.t.connStatusCh, 5*time.Second, func(event events.Event) error { + if event, ok := event.(events.ConnStatusUp); !ok { + return fmt.Errorf("expected connection up event, got %T", event) } return nil @@ -95,9 +94,9 @@ func (s *scenario) bridgeSendsAConnectionUpEvent() error { } func (s *scenario) bridgeSendsAConnectionDownEvent() error { - return try(s.t.connStatusCh, 5*time.Second, func(event events.ConnStatus) error { - if event.Status != liteapi.StatusDown { - return fmt.Errorf("expected connection down event, got %v", event.Status) + return try(s.t.connStatusCh, 5*time.Second, func(event events.Event) error { + if event, ok := event.(events.ConnStatusDown); !ok { + return fmt.Errorf("expected connection down event, got %T", event) } return nil diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index 2c67b5ad..c696eb5f 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -54,7 +54,6 @@ func (t *testCtx) startBridge() error { t.bridge = bridge // Connect the event channels. - t.connStatusCh = chToType[events.Event, events.ConnStatus](bridge.GetEvents(events.ConnStatus{})) t.userLoginCh = chToType[events.Event, events.UserLoggedIn](bridge.GetEvents(events.UserLoggedIn{})) t.userLogoutCh = chToType[events.Event, events.UserLoggedOut](bridge.GetEvents(events.UserLoggedOut{})) t.userDeletedCh = chToType[events.Event, events.UserDeleted](bridge.GetEvents(events.UserDeleted{})) @@ -62,6 +61,7 @@ func (t *testCtx) startBridge() error { t.syncStartedCh = chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{})) t.syncFinishedCh = chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) t.forcedUpdateCh = chToType[events.Event, events.UpdateForced](bridge.GetEvents(events.UpdateForced{})) + t.connStatusCh, _ = bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{}) t.updateCh, _ = bridge.GetEvents(events.UpdateAvailable{}, events.UpdateNotAvailable{}, events.UpdateInstalled{}, events.UpdateForced{}) return nil diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 6a767931..fec89d9c 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -22,6 +22,7 @@ type testCtx struct { // These are the objects supporting the test. dir string api API + dialer *bridge.TestDialer locator *locations.Locations storeKey []byte version *semver.Version @@ -31,7 +32,6 @@ type testCtx struct { bridge *bridge.Bridge // These channels hold events of various types coming from bridge. - connStatusCh <-chan events.ConnStatus userLoginCh <-chan events.UserLoggedIn userLogoutCh <-chan events.UserLoggedOut userDeletedCh <-chan events.UserDeleted @@ -39,6 +39,7 @@ type testCtx struct { syncStartedCh <-chan events.SyncStarted syncFinishedCh <-chan events.SyncFinished forcedUpdateCh <-chan events.UpdateForced + connStatusCh <-chan events.Event updateCh <-chan events.Event // These maps hold expected userIDByName, their primary addresses and bridge passwords. @@ -69,12 +70,15 @@ type smtpClient struct { } func newTestCtx(tb testing.TB) *testCtx { + dialer := bridge.NewTestDialer() + ctx := &testCtx{ dir: tb.TempDir(), api: newFakeAPI(), + dialer: dialer, locator: locations.New(bridge.NewTestLocationsProvider(tb), "config-name"), storeKey: []byte("super-secret-store-key"), - mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion), + mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion), version: defaultVersion, userIDByName: make(map[string]string), diff --git a/tests/environment_test.go b/tests/environment_test.go index aace1ce1..2854bd0d 100644 --- a/tests/environment_test.go +++ b/tests/environment_test.go @@ -35,12 +35,12 @@ func (s *scenario) itFailsWithError(wantErr string) error { } func (s *scenario) internetIsTurnedOff() error { - s.t.mocks.TLSDialer.SetCanDial(false) + s.t.dialer.SetCanDial(false) return nil } func (s *scenario) internetIsTurnedOn() error { - s.t.mocks.TLSDialer.SetCanDial(true) + s.t.dialer.SetCanDial(true) return nil } diff --git a/tests/features/user/login.feature b/tests/features/user/login.feature index f7f46321..9b14b772 100644 --- a/tests/features/user/login.feature +++ b/tests/features/user/login.feature @@ -20,6 +20,14 @@ Feature: A user can login When the user logs in with username "user@pm.me" and password "password" Then user "user@pm.me" is not listed + Scenario: Login to account without internet but the connection is later restored + When the user logs in with username "user@pm.me" and password "password" + And bridge stops + And the internet is turned off + And bridge starts + And the internet is turned on + Then user "user@pm.me" is eventually listed and connected + Scenario: Login to multiple accounts Given there exists an account with username "additional@pm.me" and password "other" When the user logs in with username "user@pm.me" and password "password" diff --git a/tests/user_test.go b/tests/user_test.go index 792bdb40..634a711c 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "time" "github.com/ProtonMail/gluon/rfc822" "github.com/bradenaw/juniper/xslices" @@ -182,6 +183,14 @@ func (s *scenario) userIsListedAndConnected(username string) error { return nil } +func (s *scenario) userIsEventuallyListedAndConnected(username string) error { + return eventually( + func() error { return s.userIsListedAndConnected(username) }, + 5*time.Second, + 100*time.Millisecond, + ) +} + func (s *scenario) userIsListedButNotConnected(username string) error { user, err := s.t.bridge.GetUserInfo(s.t.getUserID(username)) if err != nil {