Files
proton-bridge/internal/bridge/bridge_test.go
2022-11-16 12:26:08 +01:00

389 lines
13 KiB
Go

package bridge_test
import (
	"context"
	"os"
	"sync"
	"testing"

	"github.com/Masterminds/semver/v3"
	"github.com/ProtonMail/gopenpgp/v2/crypto"
	"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
	"github.com/ProtonMail/proton-bridge/v2/internal/events"
	"github.com/ProtonMail/proton-bridge/v2/internal/focus"
	"github.com/ProtonMail/proton-bridge/v2/internal/locations"
	"github.com/ProtonMail/proton-bridge/v2/internal/updater"
	"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
	"github.com/ProtonMail/proton-bridge/v2/internal/vault"
	"github.com/bradenaw/juniper/xslices"
	"github.com/stretchr/testify/require"
	"gitlab.protontech.ch/go/liteapi/server"
)
// Credentials of the test user: withEnv registers them on the test API and
// the tests pass them to LoginUser.
const (
// username is the test user's login name; withEnv also builds the user's
// address from it (username + "@pm.me").
username = "username"
// password is the test user's password.
password = "password"
)
// Versions shared by the tests in this file.
var (
// v2_3_0 is the version withBridge passes to bridge.NewMocks and bridge.New;
// the update tests also use it as a MinAuto value.
v2_3_0 = semver.MustParse("2.3.0")
// v2_4_0 is the newer version the tests hand to Updater.SetLatestVersion and
// to the test API's SetMinAppVersion.
v2_4_0 = semver.MustParse("2.4.0")
)
// TestBridge_ConnStatus checks that bridge publishes ConnStatusDown after a
// login attempt made while the test dialer is not allowed to dial, and
// ConnStatusUp after a login made once dialing is allowed again.
func TestBridge_ConnStatus(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			// Subscribe to both connection status events up front.
			statusCh, unsubscribe := b.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{})
			defer unsubscribe()

			// Simulate the network going away.
			dialer.SetCanDial(false)

			// A login attempt is expected to fail while disconnected.
			_, offlineErr := b.LoginUser(context.Background(), username, password, nil, nil)
			require.Error(t, offlineErr)

			// The failure should be reported as a "down" event.
			down := <-statusCh
			require.Equal(t, events.ConnStatusDown{}, down)

			// Simulate the network coming back.
			dialer.SetCanDial(true)

			// The same login is now expected to succeed and yield a user ID.
			userID, onlineErr := b.LoginUser(context.Background(), username, password, nil, nil)
			require.NoError(t, onlineErr)
			require.NotEmpty(t, userID)

			// The recovery should be reported as an "up" event.
			up := <-statusCh
			require.Equal(t, events.ConnStatusUp{}, up)
		})
	})
}
// TestBridge_TLSIssue checks that a signal sent on the mocks' TLSIssueCh is
// published by bridge as a TLSIssue event.
func TestBridge_TLSIssue(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, m *bridge.Mocks) {
			// Subscribe to TLS issue events.
			issueCh, unsubscribe := b.GetEvents(events.TLSIssue{})
			defer unsubscribe()

			// Signal a TLS issue from a separate goroutine, leaving the test
			// goroutine free to wait for the resulting event.
			go func() { m.TLSIssueCh <- struct{}{} }()

			// Wait for the event and check its type.
			got := <-issueCh
			require.IsType(t, events.TLSIssue{}, got)
		})
	})
}
// TestBridge_Focus checks that calling focus.TryRaise while bridge is running
// results in a Raise event.
func TestBridge_Focus(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			// Subscribe to raise events.
			raiseCh, unsubscribe := b.GetEvents(events.Raise{})
			defer unsubscribe()

			// Simulate a focus request.
			focus.TryRaise()

			// Wait for the event and check its type.
			got := <-raiseCh
			require.IsType(t, events.Raise{}, got)
		})
	})
}
// TestBridge_UserAgent checks that the platform set on bridge shows up in its
// user agent, and that this user agent is sent to the API on login.
func TestBridge_UserAgent(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		// The call watcher is invoked by the test API while the test goroutine
		// reads the recorded calls, so access to the slice is guarded by a mutex.
		var (
			mu    sync.Mutex
			calls []server.Call
		)

		s.AddCallWatcher(func(call server.Call) {
			mu.Lock()
			defer mu.Unlock()

			calls = append(calls, call)
		})

		// lastCall returns the most recently recorded API call. It fails the
		// test (instead of panicking on an out-of-range index) if no call has
		// been recorded yet.
		lastCall := func() server.Call {
			mu.Lock()
			defer mu.Unlock()

			require.NotEmpty(t, calls)

			return calls[len(calls)-1]
		}

		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			// Set the platform to something other than the default.
			b.SetCurrentPlatform("platform")

			// Assert that the user agent then contains the platform.
			require.Contains(t, b.GetCurrentUserAgent(), "platform")

			// Login the user.
			_, err := b.LoginUser(context.Background(), username, password, nil, nil)
			require.NoError(t, err)

			// Assert that the user agent was sent to the API.
			sentUserAgent := lastCall().Request.Header.Get("User-Agent")
			require.Contains(t, sentUserAgent, b.GetCurrentUserAgent())
		})
	})
}
// TestBridge_Cookies checks that the "Session-Id" cookie seen on the last API
// call after a login is seen again, with the same value, on the last API call
// recorded once bridge has been started a second time.
func TestBridge_Cookies(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		// The call watcher is invoked by the test API while the test goroutine
		// reads the recorded calls, so access to the slice is guarded by a mutex.
		var (
			mu    sync.Mutex
			calls []server.Call
		)

		s.AddCallWatcher(func(call server.Call) {
			mu.Lock()
			defer mu.Unlock()

			calls = append(calls, call)
		})

		// lastSessionID returns the value of the "Session-Id" cookie on the
		// most recently recorded API call. It fails the test if no call has
		// been recorded or if that call carries no such cookie.
		lastSessionID := func() string {
			mu.Lock()
			defer mu.Unlock()

			require.NotEmpty(t, calls)

			cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id")
			require.NoError(t, err)

			return cookie.Value
		}

		var sessionID string

		// Start bridge and add a user so that API assigns us a session ID via cookie.
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			_, err := b.LoginUser(context.Background(), username, password, nil, nil)
			require.NoError(t, err)

			sessionID = lastSessionID()
		})

		// Start bridge again and check that it uses the same session ID.
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(_ *bridge.Bridge, _ *bridge.Mocks) {
			require.Equal(t, sessionID, lastSessionID())
		})
	})
}
// TestBridge_CheckUpdate checks, with autoupdate disabled, that an update
// check first reports UpdateNotAvailable and, once the mock updater advertises
// 2.4.0 (MinAuto 2.3.0), reports an installable UpdateAvailable.
func TestBridge_CheckUpdate(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, m *bridge.Mocks) {
			// Autoupdate must be off for this test.
			require.NoError(t, b.SetAutoUpdate(false))

			// Subscribe to the two possible outcomes of an update check.
			updateCh, unsubscribe := b.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
			defer unsubscribe()

			// Nothing newer has been advertised yet, so the first check finds no update.
			b.CheckForUpdates()
			first := <-updateCh
			require.Equal(t, events.UpdateNotAvailable{}, first)

			// Advertise a new version.
			m.Updater.SetLatestVersion(v2_4_0, v2_3_0)

			// The next check should announce it as installable.
			b.CheckForUpdates()
			want := events.UpdateAvailable{
				Version: updater.VersionInfo{
					Version:           v2_4_0,
					MinAuto:           v2_3_0,
					RolloutProportion: 1.0,
				},
				CanInstall: true,
			}
			require.Equal(t, want, <-updateCh)
		})
	})
}
// TestBridge_AutoUpdate checks, with autoupdate enabled, that an update check
// made after the mock updater advertises 2.4.0 (MinAuto 2.3.0) reports
// UpdateInstalled for that version.
func TestBridge_AutoUpdate(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, m *bridge.Mocks) {
			// Autoupdate must be on for this test.
			require.NoError(t, b.SetAutoUpdate(true))

			// Subscribe to update events.
			updateCh, unsubscribe := b.GetEvents(events.UpdateNotAvailable{}, events.UpdateInstalled{})
			defer unsubscribe()

			// Advertise a new version.
			m.Updater.SetLatestVersion(v2_4_0, v2_3_0)

			// The check should report the new version as installed.
			b.CheckForUpdates()
			want := events.UpdateInstalled{
				Version: updater.VersionInfo{
					Version:           v2_4_0,
					MinAuto:           v2_3_0,
					RolloutProportion: 1.0,
				},
			}
			require.Equal(t, want, <-updateCh)
		})
	})
}
// TestBridge_ManualUpdate checks, with autoupdate disabled, that a new version
// whose MinAuto equals the new version itself (2.4.0) is reported as
// UpdateAvailable with CanInstall set to false.
func TestBridge_ManualUpdate(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, m *bridge.Mocks) {
			// Autoupdate must be off for this test.
			require.NoError(t, b.SetAutoUpdate(false))

			// Subscribe to update events.
			updateCh, unsubscribe := b.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
			defer unsubscribe()

			// Advertise a new version, but one that is too new for us.
			m.Updater.SetLatestVersion(v2_4_0, v2_4_0)

			// The check should announce the version as available but not installable.
			b.CheckForUpdates()
			want := events.UpdateAvailable{
				Version: updater.VersionInfo{
					Version:           v2_4_0,
					MinAuto:           v2_4_0,
					RolloutProportion: 1.0,
				},
				CanInstall: false,
			}
			require.Equal(t, want, <-updateCh)
		})
	})
}
// TestBridge_ForceUpdate checks that a login rejected because the API's
// minimum app version is 2.4.0 results in an UpdateForced event.
func TestBridge_ForceUpdate(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			// Subscribe to forced-update events.
			forcedCh, unsubscribe := b.GetEvents(events.UpdateForced{})
			defer unsubscribe()

			// Make the API require a version newer than the one bridge runs as.
			s.SetMinAppVersion(v2_4_0)

			// Logging in is expected to fail because bridge is too old.
			_, loginErr := b.LoginUser(context.Background(), username, password, nil, nil)
			require.Error(t, loginErr)

			// Bridge should announce that an update is required.
			got := <-forcedCh
			require.Equal(t, events.UpdateForced{}, got)
		})
	})
}
// TestBridge_BadVaultKey checks that a user logged in with the correct vault
// key is listed again when bridge restarts with that key, and that no users
// are listed when bridge is started with a wrong or a nil vault key.
func TestBridge_BadVaultKey(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		apiURL := s.GetHostURL()

		var userID string

		// First run: log a user in and remember the ID.
		withBridge(t, ctx, apiURL, dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			id, err := b.LoginUser(context.Background(), username, password, nil, nil)
			require.NoError(t, err)

			userID = id
		})

		// Restart with the correct key: the user should be loaded.
		withBridge(t, ctx, apiURL, dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			require.ElementsMatch(t, []string{userID}, b.GetUserIDs())
		})

		// Restart with a bad key: the vault will be wiped and bridge will show no users.
		withBridge(t, ctx, apiURL, dialer, locator, []byte("bad"), func(b *bridge.Bridge, _ *bridge.Mocks) {
			require.Empty(t, b.GetUserIDs())
		})

		// Restart with a nil key: the vault will be wiped and bridge will show no users.
		withBridge(t, ctx, apiURL, dialer, locator, nil, func(b *bridge.Bridge, _ *bridge.Mocks) {
			require.Empty(t, b.GetUserIDs())
		})
	})
}
// TestBridge_MissingGluonDir checks that bridge can be started again after its
// gluon directory has been removed while it was not running.
func TestBridge_MissingGluonDir(t *testing.T) {
	withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
		var removedDir string

		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(b *bridge.Bridge, _ *bridge.Mocks) {
			_, loginErr := b.LoginUser(context.Background(), username, password, nil, nil)
			require.NoError(t, loginErr)

			// Move the gluon dir to a fresh temporary directory.
			// NOTE(review): any result of SetGluonDir is not checked here — confirm whether it returns an error.
			b.SetGluonDir(ctx, t.TempDir())

			// Remember where the gluon dir now lives.
			removedDir = b.GetGluonDir()
		})

		// The user removes the gluon dir while bridge is not running.
		require.NoError(t, os.RemoveAll(removedDir))

		// Bridge starts but can't find the gluon dir; there should be no error.
		// The start-up assertions live in withBridge, so the callback is empty.
		withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(*bridge.Bridge, *bridge.Mocks) {})
	})
}
// withEnv builds the environment shared by the tests and runs tests inside it:
// a TLS test API with the test user registered, a test dialer, a locator
// rooted in a temporary directory, a random 32-byte vault key and a context
// that is cancelled when withEnv returns. The test API is closed on return.
func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) {
	// Start the test API and make sure it is shut down afterwards.
	api := server.NewTLS()
	defer api.Close()

	// Register the test user.
	_, _, err := api.AddUser(username, password, username+"@pm.me")
	require.NoError(t, err)

	// Generate a random key for the vault.
	key, err := crypto.RandomToken(32)
	require.NoError(t, err)

	// Context handed to the tests; cancelled once they are done.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Remaining pieces of the environment.
	dialer := bridge.NewTestDialer()
	locator := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")

	// Run the tests.
	tests(ctx, api, dialer, locator, key)
}
// withBridge creates a bridge that talks to the API at apiURL, backed by a
// vault opened with vaultKey in the locator's settings path, passes it to
// tests together with the mocks it was built from, and closes it afterwards.
func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
	// Mock collaborators for this bridge instance.
	mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0)

	// Bridge will enable the proxy by default at startup.
	mocks.ProxyDialer.EXPECT().AllowProxy()

	// The vault lives in the settings path provided by the locator.
	settingsDir, err := locator.ProvideSettingsPath()
	require.NoError(t, err)

	// Open the vault; its second return value is not used here.
	store, _, err := vault.New(settingsDir, t.TempDir(), vaultKey)
	require.NoError(t, err)

	// Port 0 lets the IMAP and SMTP servers choose random available ports.
	require.NoError(t, store.SetIMAPPort(0))
	require.NoError(t, store.SetSMTPPort(0))

	// Build the bridge under test.
	b, err := bridge.New(apiURL, locator, store, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0)
	require.NoError(t, err)

	// Make sure it is closed once the tests have run.
	defer b.Close(ctx)

	// Hand the bridge over to the tests.
	tests(b, mocks)
}
// must unwraps a (value, error) pair: it returns val when err is nil and
// panics with err otherwise.
func must[T any](val T, err error) T {
	if err == nil {
		return val
	}

	panic(err)
}
// getConnectedUserIDs returns those of bridge's user IDs whose user info has
// Connected set. It fails the test if the info of any user cannot be fetched.
func getConnectedUserIDs(t *testing.T, bridge *bridge.Bridge) []string {
	t.Helper()

	// isConnected reports whether the user with the given ID is connected.
	isConnected := func(id string) bool {
		info, err := bridge.GetUserInfo(id)
		require.NoError(t, err)

		return info.Connected
	}

	return xslices.Filter(bridge.GetUserIDs(), isConnected)
}