mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-15 14:56:42 +00:00
GODT-1657: More stable sync, with some tests
This commit is contained in:
@ -33,10 +33,10 @@ type Bridge struct {
|
||||
users map[string]*user.User
|
||||
|
||||
// api manages user API clients.
|
||||
api *liteapi.Manager
|
||||
cookieJar *cookies.Jar
|
||||
proxyDialer ProxyDialer
|
||||
identifier Identifier
|
||||
api *liteapi.Manager
|
||||
cookieJar *cookies.Jar
|
||||
proxyCtl ProxyController
|
||||
identifier Identifier
|
||||
|
||||
// watchers holds all registered event watchers.
|
||||
watchers []*watcher.Watcher[events.Event]
|
||||
@ -81,15 +81,16 @@ func New(
|
||||
vault *vault.Vault, // the bridge's encrypted data store
|
||||
identifier Identifier, // the identifier to keep track of the user agent
|
||||
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
|
||||
proxyDialer ProxyDialer, // the DoH dialer
|
||||
roundTripper http.RoundTripper, // the round tripper to use for API requests
|
||||
proxyCtl ProxyController, // the DoH controller
|
||||
autostarter Autostarter, // the autostarter to manage autostart settings
|
||||
updater Updater, // the updater to fetch and install updates
|
||||
curVersion *semver.Version, // the current version of the bridge
|
||||
) (*Bridge, error) {
|
||||
if vault.GetProxyAllowed() {
|
||||
proxyDialer.AllowProxy()
|
||||
proxyCtl.AllowProxy()
|
||||
} else {
|
||||
proxyDialer.DisallowProxy()
|
||||
proxyCtl.DisallowProxy()
|
||||
}
|
||||
|
||||
cookieJar, err := cookies.NewCookieJar(vault)
|
||||
@ -101,7 +102,7 @@ func New(
|
||||
liteapi.WithHostURL(apiURL),
|
||||
liteapi.WithAppVersion(constants.AppVersion),
|
||||
liteapi.WithCookieJar(cookieJar),
|
||||
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}),
|
||||
liteapi.WithTransport(roundTripper),
|
||||
)
|
||||
|
||||
tlsConfig, err := loadTLSConfig(vault)
|
||||
@ -139,10 +140,10 @@ func New(
|
||||
vault: vault,
|
||||
users: make(map[string]*user.User),
|
||||
|
||||
api: api,
|
||||
cookieJar: cookieJar,
|
||||
proxyDialer: proxyDialer,
|
||||
identifier: identifier,
|
||||
api: api,
|
||||
cookieJar: cookieJar,
|
||||
proxyCtl: proxyCtl,
|
||||
identifier: identifier,
|
||||
|
||||
tlsConfig: tlsConfig,
|
||||
imapServer: imapServer,
|
||||
@ -179,6 +180,10 @@ func New(
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := bridge.loadUsers(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load users: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for range tlsReporter.GetTLSIssueCh() {
|
||||
bridge.publish(events.TLSIssue{})
|
||||
@ -197,10 +202,6 @@ func New(
|
||||
}
|
||||
}()
|
||||
|
||||
if err := bridge.loadUsers(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to load connected users: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.serveIMAP(); err != nil {
|
||||
bridge.PushError(ErrServeIMAP)
|
||||
}
|
||||
@ -309,14 +310,8 @@ func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
|
||||
func (bridge *Bridge) onStatusUp() {
|
||||
bridge.publish(events.ConnStatusUp{})
|
||||
|
||||
for _, userID := range bridge.vault.GetUserIDs() {
|
||||
if _, ok := bridge.users[userID]; !ok {
|
||||
if vaultUser, err := bridge.vault.GetUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to get user from vault")
|
||||
} else if err := bridge.loadUser(context.Background(), vaultUser); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load user")
|
||||
}
|
||||
}
|
||||
if err := bridge.loadUsers(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load users")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
@ -21,6 +22,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/tests"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
"gitlab.protontech.ch/go/liteapi/server/backend"
|
||||
)
|
||||
@ -41,14 +43,14 @@ func init() {
|
||||
}
|
||||
|
||||
func TestBridge_ConnStatus(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of connection status events.
|
||||
eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{})
|
||||
defer done()
|
||||
|
||||
// Simulate network disconnect.
|
||||
dialer.SetCanDial(false)
|
||||
netCtl.Disable()
|
||||
|
||||
// Trigger some operation that will fail due to the network disconnect.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
@ -58,7 +60,7 @@ func TestBridge_ConnStatus(t *testing.T) {
|
||||
require.Equal(t, events.ConnStatusDown{}, <-eventCh)
|
||||
|
||||
// Simulate network reconnect.
|
||||
dialer.SetCanDial(true)
|
||||
netCtl.Enable()
|
||||
|
||||
// Trigger some operation that will succeed due to the network reconnect.
|
||||
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
@ -72,8 +74,8 @@ func TestBridge_ConnStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_TLSIssue(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of TLS issue events.
|
||||
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
|
||||
defer done()
|
||||
@ -90,8 +92,8 @@ func TestBridge_TLSIssue(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Focus(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of TLS issue events.
|
||||
raiseCh, done := bridge.GetEvents(events.Raise{})
|
||||
defer done()
|
||||
@ -106,14 +108,14 @@ func TestBridge_Focus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_UserAgent(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var calls []server.Call
|
||||
|
||||
s.AddCallWatcher(func(call server.Call) {
|
||||
calls = append(calls, call)
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Set the platform to something other than the default.
|
||||
bridge.SetCurrentPlatform("platform")
|
||||
|
||||
@ -131,7 +133,7 @@ func TestBridge_UserAgent(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Cookies(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var calls []server.Call
|
||||
|
||||
s.AddCallWatcher(func(call server.Call) {
|
||||
@ -141,7 +143,7 @@ func TestBridge_Cookies(t *testing.T) {
|
||||
var sessionID string
|
||||
|
||||
// Start bridge and add a user so that API assigns us a session ID via cookie.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -152,7 +154,7 @@ func TestBridge_Cookies(t *testing.T) {
|
||||
})
|
||||
|
||||
// Start bridge again and check that it uses the same session ID.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id")
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -162,8 +164,8 @@ func TestBridge_Cookies(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_CheckUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Disable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||
|
||||
@ -201,8 +203,8 @@ func TestBridge_CheckUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_AutoUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Enable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(true))
|
||||
|
||||
@ -229,8 +231,8 @@ func TestBridge_AutoUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_ManualUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Disable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||
|
||||
@ -258,8 +260,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_ForceUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of update events.
|
||||
updateCh, done := bridge.GetEvents(events.UpdateForced{})
|
||||
defer done()
|
||||
@ -278,11 +280,11 @@ func TestBridge_ForceUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_BadVaultKey(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var userID string
|
||||
|
||||
// Login a user.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -290,27 +292,27 @@ func TestBridge_BadVaultKey(t *testing.T) {
|
||||
})
|
||||
|
||||
// Start bridge with the correct vault key -- it should load the users correctly.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
|
||||
})
|
||||
|
||||
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
|
||||
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_MissingGluonDir(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var gluonDir string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -325,20 +327,20 @@ func TestBridge_MissingGluonDir(t *testing.T) {
|
||||
require.NoError(t, os.RemoveAll(gluonDir))
|
||||
|
||||
// Bridge starts but can't find the gluon dir; there should be no error.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// ...
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// withEnv creates the full test environment and runs the tests.
|
||||
func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) {
|
||||
// withTLSEnv creates the full test environment and runs the tests.
|
||||
func withTLSEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) {
|
||||
// Create test API.
|
||||
server := server.NewTLS()
|
||||
defer server.Close()
|
||||
|
||||
// Add test user.
|
||||
_, _, err := server.CreateUser(username, string(password), username+"@pm.me")
|
||||
_, _, err := server.CreateUser(username, username+"@pm.me", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a random vault key.
|
||||
@ -349,23 +351,56 @@ func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a net controller so we can simulate network connectivity issues.
|
||||
netCtl := liteapi.NewNetCtl()
|
||||
|
||||
// Create a locations object to provide temporary locations for bridge data during the test.
|
||||
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
|
||||
|
||||
// Run the tests.
|
||||
tests(
|
||||
ctx,
|
||||
server,
|
||||
bridge.NewTestDialer(),
|
||||
locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name"),
|
||||
vaultKey,
|
||||
)
|
||||
tests(ctx, server, netCtl, locations, vaultKey)
|
||||
}
|
||||
|
||||
// withEnv creates the full test environment and runs the tests.
|
||||
func withEnv(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) {
|
||||
// Add test user.
|
||||
_, _, err := server.CreateUser(username, username+"@pm.me", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a random vault key.
|
||||
vaultKey, err := crypto.RandomToken(32)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a context used for the test.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a net controller so we can simulate network connectivity issues.
|
||||
netCtl := liteapi.NewNetCtl()
|
||||
|
||||
// Create a locations object to provide temporary locations for bridge data during the test.
|
||||
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
|
||||
|
||||
// Run the tests.
|
||||
tests(ctx, netCtl, locations, vaultKey)
|
||||
}
|
||||
|
||||
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
|
||||
func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
|
||||
func withBridge(
|
||||
t *testing.T,
|
||||
ctx context.Context,
|
||||
apiURL string,
|
||||
netCtl *liteapi.NetCtl,
|
||||
locator bridge.Locator,
|
||||
vaultKey []byte,
|
||||
tests func(*bridge.Bridge, *bridge.Mocks),
|
||||
) {
|
||||
// Create the mock objects used in the tests.
|
||||
mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0)
|
||||
mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
|
||||
defer mocks.Close()
|
||||
|
||||
// Bridge will enable the proxy by default at startup.
|
||||
mocks.ProxyDialer.EXPECT().AllowProxy()
|
||||
mocks.ProxyCtl.EXPECT().AllowProxy()
|
||||
|
||||
// Get the path to the vault.
|
||||
vaultDir, err := locator.ProvideSettingsPath()
|
||||
@ -380,7 +415,18 @@ func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge
|
||||
require.NoError(t, vault.SetSMTPPort(0))
|
||||
|
||||
// Create a new bridge.
|
||||
bridge, err := bridge.New(apiURL, locator, vault, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0)
|
||||
bridge, err := bridge.New(
|
||||
apiURL,
|
||||
locator,
|
||||
vault,
|
||||
useragent.New(),
|
||||
mocks.TLSReporter,
|
||||
liteapi.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
|
||||
mocks.ProxyCtl,
|
||||
mocks.Autostarter,
|
||||
mocks.Updater,
|
||||
v2_3_0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close the bridge when done.
|
||||
|
||||
@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon"
|
||||
@ -33,6 +35,22 @@ func (bridge *Bridge) serveIMAP() error {
|
||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(imapListener.Addr().String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get IMAP listener address: %w", err)
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert IMAP listener port to int: %w", err)
|
||||
}
|
||||
|
||||
if portInt != bridge.vault.GetIMAPPort() {
|
||||
if err := bridge.vault.SetIMAPPort(portInt); err != nil {
|
||||
return fmt.Errorf("failed to update IMAP port in vault: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
for err := range bridge.imapServer.GetErrorCh() {
|
||||
logrus.WithError(err).Error("IMAP server error")
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@ -15,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type Mocks struct {
|
||||
ProxyDialer *mocks.MockProxyDialer
|
||||
ProxyCtl *mocks.MockProxyController
|
||||
TLSReporter *mocks.MockTLSReporter
|
||||
TLSIssueCh chan struct{}
|
||||
|
||||
@ -23,11 +19,11 @@ type Mocks struct {
|
||||
Autostarter *mocks.MockAutostarter
|
||||
}
|
||||
|
||||
func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Version) *Mocks {
|
||||
func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
|
||||
ctl := gomock.NewController(tb)
|
||||
|
||||
mocks := &Mocks{
|
||||
ProxyDialer: mocks.NewMockProxyDialer(ctl),
|
||||
ProxyCtl: mocks.NewMockProxyController(ctl),
|
||||
TLSReporter: mocks.NewMockTLSReporter(ctl),
|
||||
TLSIssueCh: make(chan struct{}),
|
||||
|
||||
@ -35,41 +31,14 @@ func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Versio
|
||||
Autostarter: mocks.NewMockAutostarter(ctl),
|
||||
}
|
||||
|
||||
// When using the proxy dialer, we want to use the test dialer.
|
||||
mocks.ProxyDialer.EXPECT().DialTLSContext(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return dialer.DialTLSContext(ctx, network, address)
|
||||
}).AnyTimes()
|
||||
|
||||
// When getting the TLS issue channel, we want to return the test channel.
|
||||
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
|
||||
|
||||
return mocks
|
||||
}
|
||||
|
||||
type TestDialer struct {
|
||||
canDial bool
|
||||
}
|
||||
|
||||
func NewTestDialer() *TestDialer {
|
||||
return &TestDialer{
|
||||
canDial: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
if !d.canDial {
|
||||
return nil, errors.New("cannot dial")
|
||||
}
|
||||
|
||||
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
func (d *TestDialer) SetCanDial(canDial bool) {
|
||||
d.canDial = canDial
|
||||
func (mocks *Mocks) Close() {
|
||||
close(mocks.TLSIssueCh)
|
||||
}
|
||||
|
||||
type TestLocationsProvider struct {
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter)
|
||||
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyController,Autostarter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@ -49,66 +47,51 @@ func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
|
||||
}
|
||||
|
||||
// MockProxyDialer is a mock of ProxyDialer interface.
|
||||
type MockProxyDialer struct {
|
||||
// MockProxyController is a mock of ProxyController interface.
|
||||
type MockProxyController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProxyDialerMockRecorder
|
||||
recorder *MockProxyControllerMockRecorder
|
||||
}
|
||||
|
||||
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer.
|
||||
type MockProxyDialerMockRecorder struct {
|
||||
mock *MockProxyDialer
|
||||
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
|
||||
type MockProxyControllerMockRecorder struct {
|
||||
mock *MockProxyController
|
||||
}
|
||||
|
||||
// NewMockProxyDialer creates a new mock instance.
|
||||
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer {
|
||||
mock := &MockProxyDialer{ctrl: ctrl}
|
||||
mock.recorder = &MockProxyDialerMockRecorder{mock}
|
||||
// NewMockProxyController creates a new mock instance.
|
||||
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
|
||||
mock := &MockProxyController{ctrl: ctrl}
|
||||
mock.recorder = &MockProxyControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder {
|
||||
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method.
|
||||
func (m *MockProxyDialer) AllowProxy() {
|
||||
func (m *MockProxyController) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy.
|
||||
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call {
|
||||
func (mr *MockProxyControllerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DialTLSContext mocks base method.
|
||||
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(net.Conn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DialTLSContext indicates an expected call of DialTLSContext.
|
||||
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyController)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method.
|
||||
func (m *MockProxyDialer) DisallowProxy() {
|
||||
func (m *MockProxyController) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy.
|
||||
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
func (mr *MockProxyControllerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyController)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// MockAutostarter is a mock of Autostarter interface.
|
||||
|
||||
@ -138,9 +138,9 @@ func (bridge *Bridge) GetProxyAllowed() bool {
|
||||
|
||||
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
|
||||
if allowed {
|
||||
bridge.proxyDialer.AllowProxy()
|
||||
bridge.proxyCtl.AllowProxy()
|
||||
} else {
|
||||
bridge.proxyDialer.DisallowProxy()
|
||||
bridge.proxyCtl.DisallowProxy()
|
||||
}
|
||||
|
||||
return bridge.vault.SetProxyAllowed(allowed)
|
||||
|
||||
@ -7,12 +7,13 @@ import (
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_Settings_GluonDir(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Create a user.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
@ -34,8 +35,8 @@ func TestBridge_Settings_GluonDir(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_IMAPPort(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
curPort := bridge.GetIMAPPort()
|
||||
|
||||
// Set the port to 1144.
|
||||
@ -51,8 +52,8 @@ func TestBridge_Settings_IMAPPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, IMAP SSL is disabled.
|
||||
require.False(t, bridge.GetIMAPSSL())
|
||||
|
||||
@ -66,8 +67,8 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_SMTPPort(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
curPort := bridge.GetSMTPPort()
|
||||
|
||||
// Set the port to 1024.
|
||||
@ -84,8 +85,8 @@ func TestBridge_Settings_SMTPPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, SMTP SSL is disabled.
|
||||
require.False(t, bridge.GetSMTPSSL())
|
||||
|
||||
@ -99,13 +100,13 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_Proxy(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, proxy is allowed.
|
||||
require.True(t, bridge.GetProxyAllowed())
|
||||
|
||||
// Disallow proxy.
|
||||
mocks.ProxyDialer.EXPECT().DisallowProxy()
|
||||
mocks.ProxyCtl.EXPECT().DisallowProxy()
|
||||
require.NoError(t, bridge.SetProxyAllowed(false))
|
||||
|
||||
// Get the new setting.
|
||||
@ -115,8 +116,8 @@ func TestBridge_Settings_Proxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_Autostart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, autostart is disabled.
|
||||
require.False(t, bridge.GetAutostart())
|
||||
|
||||
@ -131,8 +132,8 @@ func TestBridge_Settings_Autostart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_FirstStart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, first start is true.
|
||||
require.True(t, bridge.GetFirstStart())
|
||||
|
||||
@ -146,8 +147,8 @@ func TestBridge_Settings_FirstStart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, first start is true.
|
||||
require.True(t, bridge.GetFirstStartGUI())
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ package bridge
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/emersion/go-sasl"
|
||||
@ -22,6 +24,22 @@ func (bridge *Bridge) serveSMTP() error {
|
||||
}
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(smtpListener.Addr().String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get SMTP listener address: %w", err)
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert SMTP listener port to int: %w", err)
|
||||
}
|
||||
|
||||
if portInt != bridge.vault.GetSMTPPort() {
|
||||
if err := bridge.vault.SetSMTPPort(portInt); err != nil {
|
||||
return fmt.Errorf("failed to update SMTP port in vault: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
130
internal/bridge/sync_test.go
Normal file
130
internal/bridge/sync_test.go
Normal file
@ -0,0 +1,130 @@
|
||||
package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/emersion/go-imap/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_Sync(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
numMsg := 1 << 10
|
||||
|
||||
withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder)
|
||||
require.NoError(t, err)
|
||||
|
||||
literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml"))
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < numMsg; i++ {
|
||||
messageID, err := s.CreateMessage(userID, addrID, literal, liteapi.MessageFlagReceived, false, false)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.LabelMessage(userID, messageID, labelID))
|
||||
}
|
||||
|
||||
var read uint64
|
||||
|
||||
netCtl.OnRead(func(b []byte) {
|
||||
read += uint64(len(b))
|
||||
})
|
||||
|
||||
// The initial user should be fully synced.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
syncCh, done := bridge.GetEvents(events.SyncFinished{})
|
||||
defer done()
|
||||
|
||||
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
|
||||
})
|
||||
|
||||
// If we then connect an IMAP client, it should see all the messages.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||
defer client.Logout()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(numMsg), status.Messages)
|
||||
})
|
||||
|
||||
// Now let's remove the user and simulate a network error.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
})
|
||||
|
||||
// Pretend we can only sync 2/3 of the original messages.
|
||||
netCtl.SetReadLimit(2 * read / 3)
|
||||
|
||||
// Login the user; its sync should fail.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
syncCh, done := bridge.GetEvents(events.SyncFailed{})
|
||||
defer done()
|
||||
|
||||
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, userID, (<-syncCh).(events.SyncFailed).UserID)
|
||||
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||
defer client.Logout()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Less(t, status.Messages, uint32(numMsg))
|
||||
})
|
||||
|
||||
// Remove the network limit, allowing the sync to finish.
|
||||
netCtl.SetReadLimit(0)
|
||||
|
||||
// Login the user; its sync should now finish.
|
||||
// If we then connect an IMAP client, it should eventually see all the messages.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
syncCh, done := bridge.GetEvents(events.SyncFinished{})
|
||||
defer done()
|
||||
|
||||
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
|
||||
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||
defer client.Logout()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(numMsg), status.Messages)
|
||||
})
|
||||
})
|
||||
}
|
||||
6
internal/bridge/testdata/text-plain.eml
vendored
Normal file
6
internal/bridge/testdata/text-plain.eml
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
To: recipient@pm.me
|
||||
From: sender@pm.me
|
||||
Subject: Test
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
|
||||
Test
|
||||
@ -1,9 +1,6 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
)
|
||||
|
||||
@ -21,17 +18,15 @@ type Identifier interface {
|
||||
SetPlatform(platform string)
|
||||
}
|
||||
|
||||
type TLSReporter interface {
|
||||
GetTLSIssueCh() <-chan struct{}
|
||||
}
|
||||
|
||||
type ProxyDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
type ProxyController interface {
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
}
|
||||
|
||||
type TLSReporter interface {
|
||||
GetTLSIssueCh() <-chan struct{}
|
||||
}
|
||||
|
||||
type Autostarter interface {
|
||||
Enable() error
|
||||
Disable() error
|
||||
|
||||
@ -18,6 +18,9 @@ func (bridge *Bridge) watchForUpdates() error {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-bridge.stopCh:
|
||||
return
|
||||
|
||||
case <-bridge.updateCheckCh:
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/go-resty/resty/v2"
|
||||
@ -82,76 +83,76 @@ func (bridge *Bridge) LoginUser(
|
||||
) (string, error) {
|
||||
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to create new API client: %w", err)
|
||||
}
|
||||
|
||||
if _, ok := bridge.users[auth.UserID]; ok {
|
||||
return "", ErrUserAlreadyLoggedIn
|
||||
}
|
||||
userID, err := try.CatchVal(
|
||||
func() (string, error) {
|
||||
if _, ok := bridge.users[auth.UserID]; ok {
|
||||
return "", ErrUserAlreadyLoggedIn
|
||||
}
|
||||
|
||||
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
||||
totp, err := getTOTP()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
||||
totp, err := getTOTP()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get TOTP: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
|
||||
return "", fmt.Errorf("failed to authorize 2FA: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var keyPass []byte
|
||||
var keyPass []byte
|
||||
|
||||
if auth.PasswordMode == liteapi.TwoPasswordMode {
|
||||
pass, err := getKeyPass()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if auth.PasswordMode == liteapi.TwoPasswordMode {
|
||||
userKeyPass, err := getKeyPass()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get key password: %w", err)
|
||||
}
|
||||
|
||||
keyPass = pass
|
||||
} else {
|
||||
keyPass = password
|
||||
}
|
||||
keyPass = userKeyPass
|
||||
} else {
|
||||
keyPass = password
|
||||
}
|
||||
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
return bridge.loginUser(ctx, client, auth.UID, auth.RefreshToken, keyPass)
|
||||
},
|
||||
func() error {
|
||||
return client.AuthDelete(ctx)
|
||||
},
|
||||
func() error {
|
||||
bridge.deleteUser(ctx, auth.UserID)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to login user: %w", err)
|
||||
}
|
||||
|
||||
salts, err := client.GetSalts(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return auth.UserID, nil
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// LogoutUser logs out the given user.
|
||||
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
||||
return bridge.logoutUser(ctx, userID, true, false)
|
||||
if err := bridge.logoutUser(ctx, userID); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes the given user.
|
||||
// If it is authorized, it is logged out first.
|
||||
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
||||
if bridge.users[userID] != nil {
|
||||
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
bridge.deleteUser(ctx, userID)
|
||||
|
||||
bridge.publish(events.UserDeleted{
|
||||
UserID: userID,
|
||||
@ -193,53 +194,91 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadUsers loads authorized users from the vault.
|
||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
||||
for _, userID := range bridge.vault.GetUserIDs() {
|
||||
user, err := bridge.vault.GetUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get API user: %w", err)
|
||||
}
|
||||
|
||||
salts, err := client.GetSalts(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get key salts: %w", err)
|
||||
}
|
||||
|
||||
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to salt key password: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil {
|
||||
return "", fmt.Errorf("failed to add bridge user: %w", err)
|
||||
}
|
||||
|
||||
return apiUser.ID, nil
|
||||
}
|
||||
|
||||
// loadUsers is a loop that, when polled, attempts to load authorized users from the vault.
|
||||
func (bridge *Bridge) loadUsers() error {
|
||||
return bridge.vault.ForUser(func(user *vault.User) error {
|
||||
if _, ok := bridge.users[user.UserID()]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if user.AuthUID() == "" {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.loadUser(ctx, user); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load connected user")
|
||||
|
||||
if err := bridge.loadUser(user); err != nil {
|
||||
if _, ok := err.(*resty.ResponseError); ok {
|
||||
if err := bridge.vault.ClearUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault")
|
||||
|
||||
if err := user.Clear(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to clear user")
|
||||
}
|
||||
} else {
|
||||
logrus.WithError(err).Error("Failed to load connected user")
|
||||
}
|
||||
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
bridge.publish(events.UserLoaded{
|
||||
UserID: user.UserID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
||||
// loadUser loads an existing user from the vault.
|
||||
func (bridge *Bridge) loadUser(user *vault.User) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create API client: %w", err)
|
||||
}
|
||||
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
if err := try.Catch(
|
||||
func() error {
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass())
|
||||
},
|
||||
func() error {
|
||||
return client.AuthDelete(ctx)
|
||||
},
|
||||
func() error {
|
||||
return bridge.logoutUser(ctx, user.UserID())
|
||||
},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to load user: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: user.UserID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -304,10 +343,6 @@ func (bridge *Bridge) addUser(
|
||||
return nil
|
||||
})
|
||||
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -363,54 +398,6 @@ func (bridge *Bridge) addExistingUser(
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// logoutUser closes and removes the user with the given ID.
|
||||
// If withAPI is true, the user will additionally be logged out from API.
|
||||
// If withFiles is true, the user's files will be deleted.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
return fmt.Errorf("failed to remove SMTP user: %w", err)
|
||||
}
|
||||
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if withAPI {
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.vault.ClearUser(userID); err != nil {
|
||||
return fmt.Errorf("failed to clear user: %w", err)
|
||||
}
|
||||
|
||||
if withFiles {
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
delete(bridge.users, userID)
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addIMAPUser connects the given user to gluon.
|
||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
@ -438,6 +425,65 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// logoutUser logs the given user out from bridge.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||
}
|
||||
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
|
||||
delete(bridge.users, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteUser deletes the given user from bridge.
|
||||
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
||||
if user, ok := bridge.users[userID]; ok {
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||
}
|
||||
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to delete user from vault")
|
||||
}
|
||||
|
||||
delete(bridge.users, userID)
|
||||
}
|
||||
|
||||
// getUserInfo returns information about a disconnected user.
|
||||
func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
|
||||
return UserInfo{
|
||||
|
||||
@ -27,7 +27,7 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
|
||||
}
|
||||
|
||||
case events.UserDeauth:
|
||||
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
|
||||
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -9,17 +9,18 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_WithoutUsers(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
@ -27,8 +28,8 @@ func TestBridge_WithoutUsers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Login(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
@ -41,8 +42,8 @@ func TestBridge_Login(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginLogoutLogin(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -69,8 +70,8 @@ func TestBridge_LoginLogoutLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeleteLogin(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -97,8 +98,8 @@ func TestBridge_LoginDeleteLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeauthLogin(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -131,10 +132,10 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
|
||||
func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||
const authLife = 2 * time.Second
|
||||
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
s.SetAuthLife(authLife)
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user. Its auth will only be valid for a short time.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -148,11 +149,11 @@ func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_FailToLoad(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
// Login the user.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
@ -160,7 +161,7 @@ func TestBridge_FailToLoad(t *testing.T) {
|
||||
require.NoError(t, s.RevokeUser(userID))
|
||||
|
||||
// When bridge starts, the user will not be logged in.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
@ -168,25 +169,27 @@ func TestBridge_FailToLoad(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoadWithoutInternet(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
// Login the user.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
// Simulate loss of internet connection.
|
||||
dialer.SetCanDial(false)
|
||||
netCtl.Disable()
|
||||
|
||||
// Start bridge without internet.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Initially, users are not connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Simulate internet connection.
|
||||
dialer.SetCanDial(true)
|
||||
netCtl.Enable()
|
||||
|
||||
// The user will eventually be connected.
|
||||
require.Eventually(t, func() bool {
|
||||
@ -197,16 +200,14 @@ func TestBridge_LoadWithoutInternet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginRestart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still connected.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
@ -214,10 +215,10 @@ func TestBridge_LoginRestart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -225,7 +226,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still disconnected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
@ -234,10 +235,10 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -245,7 +246,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still gone.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
@ -253,13 +254,69 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_FailLoginRecover(t *testing.T) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var read uint64
|
||||
|
||||
netCtl.OnRead(func(b []byte) {
|
||||
read += uint64(len(b))
|
||||
})
|
||||
|
||||
// Log the user in and record how much data was read.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
})
|
||||
|
||||
// Simulate a partial read.
|
||||
netCtl.SetReadLimit(read / 2)
|
||||
|
||||
// We should fail to log the user in because we can't fully read its data.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Error(t, getErr(bridge.LoginUser(ctx, username, password, nil, nil)))
|
||||
|
||||
// There should be no users.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_FailLoadRecover(t *testing.T) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
var read uint64
|
||||
|
||||
netCtl.OnRead(func(b []byte) {
|
||||
read += uint64(len(b))
|
||||
})
|
||||
|
||||
// Start bridge and record how much data was read.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// ...
|
||||
})
|
||||
|
||||
// Simulate a partial read.
|
||||
netCtl.SetReadLimit(read / 2)
|
||||
|
||||
// We should fail to load the user; it should be disconnected.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userIDs := bridge.GetUserIDs()
|
||||
|
||||
require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_BridgePass(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
var pass []byte
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -276,7 +333,7 @@ func TestBridge_BridgePass(t *testing.T) {
|
||||
require.Equal(t, pass, pass)
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The bridge should load the user.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
@ -288,8 +345,8 @@ func TestBridge_BridgePass(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_AddressMode(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
@ -313,3 +370,8 @@ func TestBridge_AddressMode(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// getErr returns the error that was passed to it.
|
||||
func getErr[T any](val T, err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user