GODT-1657: More stable sync, with some tests

This commit is contained in:
James Houlahan
2022-10-09 23:05:52 +02:00
parent e7526f2e78
commit 509a767e50
41 changed files with 883 additions and 779 deletions

View File

@ -234,7 +234,7 @@ integration-test-bridge:
${MAKE} -C test test-bridge ${MAKE} -C test test-bridge
mocks: mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyDialer,Autostarter > internal/bridge/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyController,Autostarter > internal/bridge/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint: gofiles lint-golang lint-license lint-dependencies lint-changelog

2
go.mod
View File

@ -38,7 +38,7 @@ require (
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
github.com/urfave/cli/v2 v2.16.3 github.com/urfave/cli/v2 v2.16.3
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7 gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e
golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/exp v0.0.0-20220921164117-439092de6870
golang.org/x/net v0.1.0 golang.org/x/net v0.1.0
golang.org/x/sys v0.1.0 golang.org/x/sys v0.1.0

4
go.sum
View File

@ -397,8 +397,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7 h1:Hef7jPRzcfLOvOUHYoQ6efaI7p7/aT5kpZDqJ29owNI= gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e h1:CTGaREzkbz7u98nKt6+xsca2bWML79lH1XGbodRo+MY=
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk= gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=

View File

@ -81,7 +81,18 @@ func newBridge(locations *locations.Locations, identifier *useragent.UserAgent)
} }
// Create a new bridge. // Create a new bridge.
bridge, err := bridge.New(constants.APIHost, locations, encVault, identifier, pinningDialer, proxyDialer, autostarter, updater, version) bridge, err := bridge.New(
constants.APIHost,
locations,
encVault,
identifier,
pinningDialer,
dialer.CreateTransportWithDialer(proxyDialer),
proxyDialer,
autostarter,
updater,
version,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create bridge: %w", err) return nil, fmt.Errorf("could not create bridge: %w", err)
} }

View File

@ -35,7 +35,7 @@ type Bridge struct {
// api manages user API clients. // api manages user API clients.
api *liteapi.Manager api *liteapi.Manager
cookieJar *cookies.Jar cookieJar *cookies.Jar
proxyDialer ProxyDialer proxyCtl ProxyController
identifier Identifier identifier Identifier
// watchers holds all registered event watchers. // watchers holds all registered event watchers.
@ -81,15 +81,16 @@ func New(
vault *vault.Vault, // the bridge's encrypted data store vault *vault.Vault, // the bridge's encrypted data store
identifier Identifier, // the identifier to keep track of the user agent identifier Identifier, // the identifier to keep track of the user agent
tlsReporter TLSReporter, // the TLS reporter to report TLS errors tlsReporter TLSReporter, // the TLS reporter to report TLS errors
proxyDialer ProxyDialer, // the DoH dialer roundTripper http.RoundTripper, // the round tripper to use for API requests
proxyCtl ProxyController, // the DoH controller
autostarter Autostarter, // the autostarter to manage autostart settings autostarter Autostarter, // the autostarter to manage autostart settings
updater Updater, // the updater to fetch and install updates updater Updater, // the updater to fetch and install updates
curVersion *semver.Version, // the current version of the bridge curVersion *semver.Version, // the current version of the bridge
) (*Bridge, error) { ) (*Bridge, error) {
if vault.GetProxyAllowed() { if vault.GetProxyAllowed() {
proxyDialer.AllowProxy() proxyCtl.AllowProxy()
} else { } else {
proxyDialer.DisallowProxy() proxyCtl.DisallowProxy()
} }
cookieJar, err := cookies.NewCookieJar(vault) cookieJar, err := cookies.NewCookieJar(vault)
@ -101,7 +102,7 @@ func New(
liteapi.WithHostURL(apiURL), liteapi.WithHostURL(apiURL),
liteapi.WithAppVersion(constants.AppVersion), liteapi.WithAppVersion(constants.AppVersion),
liteapi.WithCookieJar(cookieJar), liteapi.WithCookieJar(cookieJar),
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}), liteapi.WithTransport(roundTripper),
) )
tlsConfig, err := loadTLSConfig(vault) tlsConfig, err := loadTLSConfig(vault)
@ -141,7 +142,7 @@ func New(
api: api, api: api,
cookieJar: cookieJar, cookieJar: cookieJar,
proxyDialer: proxyDialer, proxyCtl: proxyCtl,
identifier: identifier, identifier: identifier,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
@ -179,6 +180,10 @@ func New(
return nil return nil
}) })
if err := bridge.loadUsers(); err != nil {
return nil, fmt.Errorf("failed to load users: %w", err)
}
go func() { go func() {
for range tlsReporter.GetTLSIssueCh() { for range tlsReporter.GetTLSIssueCh() {
bridge.publish(events.TLSIssue{}) bridge.publish(events.TLSIssue{})
@ -197,10 +202,6 @@ func New(
} }
}() }()
if err := bridge.loadUsers(context.Background()); err != nil {
return nil, fmt.Errorf("failed to load connected users: %w", err)
}
if err := bridge.serveIMAP(); err != nil { if err := bridge.serveIMAP(); err != nil {
bridge.PushError(ErrServeIMAP) bridge.PushError(ErrServeIMAP)
} }
@ -309,14 +310,8 @@ func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
func (bridge *Bridge) onStatusUp() { func (bridge *Bridge) onStatusUp() {
bridge.publish(events.ConnStatusUp{}) bridge.publish(events.ConnStatusUp{})
for _, userID := range bridge.vault.GetUserIDs() { if err := bridge.loadUsers(); err != nil {
if _, ok := bridge.users[userID]; !ok { logrus.WithError(err).Error("Failed to load users")
if vaultUser, err := bridge.vault.GetUser(userID); err != nil {
logrus.WithError(err).Error("Failed to get user from vault")
} else if err := bridge.loadUser(context.Background(), vaultUser); err != nil {
logrus.WithError(err).Error("Failed to load user")
}
}
} }
} }

View File

@ -2,6 +2,7 @@ package bridge_test
import ( import (
"context" "context"
"crypto/tls"
"net/http" "net/http"
"os" "os"
"testing" "testing"
@ -21,6 +22,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/tests" "github.com/ProtonMail/proton-bridge/v2/tests"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/backend" "gitlab.protontech.ch/go/liteapi/server/backend"
) )
@ -41,14 +43,14 @@ func init() {
} }
func TestBridge_ConnStatus(t *testing.T) { func TestBridge_ConnStatus(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of connection status events. // Get a stream of connection status events.
eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{}) eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{})
defer done() defer done()
// Simulate network disconnect. // Simulate network disconnect.
dialer.SetCanDial(false) netCtl.Disable()
// Trigger some operation that will fail due to the network disconnect. // Trigger some operation that will fail due to the network disconnect.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil) _, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
@ -58,7 +60,7 @@ func TestBridge_ConnStatus(t *testing.T) {
require.Equal(t, events.ConnStatusDown{}, <-eventCh) require.Equal(t, events.ConnStatusDown{}, <-eventCh)
// Simulate network reconnect. // Simulate network reconnect.
dialer.SetCanDial(true) netCtl.Enable()
// Trigger some operation that will succeed due to the network reconnect. // Trigger some operation that will succeed due to the network reconnect.
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
@ -72,8 +74,8 @@ func TestBridge_ConnStatus(t *testing.T) {
} }
func TestBridge_TLSIssue(t *testing.T) { func TestBridge_TLSIssue(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events. // Get a stream of TLS issue events.
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{}) tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
defer done() defer done()
@ -90,8 +92,8 @@ func TestBridge_TLSIssue(t *testing.T) {
} }
func TestBridge_Focus(t *testing.T) { func TestBridge_Focus(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events. // Get a stream of TLS issue events.
raiseCh, done := bridge.GetEvents(events.Raise{}) raiseCh, done := bridge.GetEvents(events.Raise{})
defer done() defer done()
@ -106,14 +108,14 @@ func TestBridge_Focus(t *testing.T) {
} }
func TestBridge_UserAgent(t *testing.T) { func TestBridge_UserAgent(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call var calls []server.Call
s.AddCallWatcher(func(call server.Call) { s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call) calls = append(calls, call)
}) })
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Set the platform to something other than the default. // Set the platform to something other than the default.
bridge.SetCurrentPlatform("platform") bridge.SetCurrentPlatform("platform")
@ -131,7 +133,7 @@ func TestBridge_UserAgent(t *testing.T) {
} }
func TestBridge_Cookies(t *testing.T) { func TestBridge_Cookies(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call var calls []server.Call
s.AddCallWatcher(func(call server.Call) { s.AddCallWatcher(func(call server.Call) {
@ -141,7 +143,7 @@ func TestBridge_Cookies(t *testing.T) {
var sessionID string var sessionID string
// Start bridge and add a user so that API assigns us a session ID via cookie. // Start bridge and add a user so that API assigns us a session ID via cookie.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil) _, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
@ -152,7 +154,7 @@ func TestBridge_Cookies(t *testing.T) {
}) })
// Start bridge again and check that it uses the same session ID. // Start bridge again and check that it uses the same session ID.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id") cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id")
require.NoError(t, err) require.NoError(t, err)
@ -162,8 +164,8 @@ func TestBridge_Cookies(t *testing.T) {
} }
func TestBridge_CheckUpdate(t *testing.T) { func TestBridge_CheckUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test. // Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false)) require.NoError(t, bridge.SetAutoUpdate(false))
@ -201,8 +203,8 @@ func TestBridge_CheckUpdate(t *testing.T) {
} }
func TestBridge_AutoUpdate(t *testing.T) { func TestBridge_AutoUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Enable autoupdate for this test. // Enable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(true)) require.NoError(t, bridge.SetAutoUpdate(true))
@ -229,8 +231,8 @@ func TestBridge_AutoUpdate(t *testing.T) {
} }
func TestBridge_ManualUpdate(t *testing.T) { func TestBridge_ManualUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test. // Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false)) require.NoError(t, bridge.SetAutoUpdate(false))
@ -258,8 +260,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
} }
func TestBridge_ForceUpdate(t *testing.T) { func TestBridge_ForceUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of update events. // Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateForced{}) updateCh, done := bridge.GetEvents(events.UpdateForced{})
defer done() defer done()
@ -278,11 +280,11 @@ func TestBridge_ForceUpdate(t *testing.T) {
} }
func TestBridge_BadVaultKey(t *testing.T) { func TestBridge_BadVaultKey(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var userID string var userID string
// Login a user. // Login a user.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
@ -290,27 +292,27 @@ func TestBridge_BadVaultKey(t *testing.T) {
}) })
// Start bridge with the correct vault key -- it should load the users correctly. // Start bridge with the correct vault key -- it should load the users correctly.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs()) require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
}) })
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users. // Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs()) require.Empty(t, bridge.GetUserIDs())
}) })
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users. // Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs()) require.Empty(t, bridge.GetUserIDs())
}) })
}) })
} }
func TestBridge_MissingGluonDir(t *testing.T) { func TestBridge_MissingGluonDir(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var gluonDir string var gluonDir string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil) _, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
@ -325,20 +327,20 @@ func TestBridge_MissingGluonDir(t *testing.T) {
require.NoError(t, os.RemoveAll(gluonDir)) require.NoError(t, os.RemoveAll(gluonDir))
// Bridge starts but can't find the gluon dir; there should be no error. // Bridge starts but can't find the gluon dir; there should be no error.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// ... // ...
}) })
}) })
} }
// withEnv creates the full test environment and runs the tests. // withTLSEnv creates the full test environment and runs the tests.
func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) { func withTLSEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) {
// Create test API. // Create test API.
server := server.NewTLS() server := server.NewTLS()
defer server.Close() defer server.Close()
// Add test user. // Add test user.
_, _, err := server.CreateUser(username, string(password), username+"@pm.me") _, _, err := server.CreateUser(username, username+"@pm.me", password)
require.NoError(t, err) require.NoError(t, err)
// Generate a random vault key. // Generate a random vault key.
@ -349,23 +351,56 @@ func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
// Create a net controller so we can simulate network connectivity issues.
netCtl := liteapi.NewNetCtl()
// Create a locations object to provide temporary locations for bridge data during the test.
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
// Run the tests. // Run the tests.
tests( tests(ctx, server, netCtl, locations, vaultKey)
ctx, }
server,
bridge.NewTestDialer(), // withEnv creates the full test environment and runs the tests.
locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name"), func withEnv(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) {
vaultKey, // Add test user.
) _, _, err := server.CreateUser(username, username+"@pm.me", password)
require.NoError(t, err)
// Generate a random vault key.
vaultKey, err := crypto.RandomToken(32)
require.NoError(t, err)
// Create a context used for the test.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a net controller so we can simulate network connectivity issues.
netCtl := liteapi.NewNetCtl()
// Create a locations object to provide temporary locations for bridge data during the test.
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
// Run the tests.
tests(ctx, netCtl, locations, vaultKey)
} }
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. // withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) { func withBridge(
t *testing.T,
ctx context.Context,
apiURL string,
netCtl *liteapi.NetCtl,
locator bridge.Locator,
vaultKey []byte,
tests func(*bridge.Bridge, *bridge.Mocks),
) {
// Create the mock objects used in the tests. // Create the mock objects used in the tests.
mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0) mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
defer mocks.Close()
// Bridge will enable the proxy by default at startup. // Bridge will enable the proxy by default at startup.
mocks.ProxyDialer.EXPECT().AllowProxy() mocks.ProxyCtl.EXPECT().AllowProxy()
// Get the path to the vault. // Get the path to the vault.
vaultDir, err := locator.ProvideSettingsPath() vaultDir, err := locator.ProvideSettingsPath()
@ -380,7 +415,18 @@ func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge
require.NoError(t, vault.SetSMTPPort(0)) require.NoError(t, vault.SetSMTPPort(0))
// Create a new bridge. // Create a new bridge.
bridge, err := bridge.New(apiURL, locator, vault, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0) bridge, err := bridge.New(
apiURL,
locator,
vault,
useragent.New(),
mocks.TLSReporter,
liteapi.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
mocks.ProxyCtl,
mocks.Autostarter,
mocks.Updater,
v2_3_0,
)
require.NoError(t, err) require.NoError(t, err)
// Close the bridge when done. // Close the bridge when done.

View File

@ -6,7 +6,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"net"
"os" "os"
"strconv"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon"
@ -33,6 +35,22 @@ func (bridge *Bridge) serveIMAP() error {
return fmt.Errorf("failed to serve IMAP: %w", err) return fmt.Errorf("failed to serve IMAP: %w", err)
} }
_, port, err := net.SplitHostPort(imapListener.Addr().String())
if err != nil {
return fmt.Errorf("failed to get IMAP listener address: %w", err)
}
portInt, err := strconv.Atoi(port)
if err != nil {
return fmt.Errorf("failed to convert IMAP listener port to int: %w", err)
}
if portInt != bridge.vault.GetIMAPPort() {
if err := bridge.vault.SetIMAPPort(portInt); err != nil {
return fmt.Errorf("failed to update IMAP port in vault: %w", err)
}
}
go func() { go func() {
for err := range bridge.imapServer.GetErrorCh() { for err := range bridge.imapServer.GetErrorCh() {
logrus.WithError(err).Error("IMAP server error") logrus.WithError(err).Error("IMAP server error")

View File

@ -1,10 +1,6 @@
package bridge package bridge
import ( import (
"context"
"crypto/tls"
"errors"
"net"
"os" "os"
"testing" "testing"
@ -15,7 +11,7 @@ import (
) )
type Mocks struct { type Mocks struct {
ProxyDialer *mocks.MockProxyDialer ProxyCtl *mocks.MockProxyController
TLSReporter *mocks.MockTLSReporter TLSReporter *mocks.MockTLSReporter
TLSIssueCh chan struct{} TLSIssueCh chan struct{}
@ -23,11 +19,11 @@ type Mocks struct {
Autostarter *mocks.MockAutostarter Autostarter *mocks.MockAutostarter
} }
func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Version) *Mocks { func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
ctl := gomock.NewController(tb) ctl := gomock.NewController(tb)
mocks := &Mocks{ mocks := &Mocks{
ProxyDialer: mocks.NewMockProxyDialer(ctl), ProxyCtl: mocks.NewMockProxyController(ctl),
TLSReporter: mocks.NewMockTLSReporter(ctl), TLSReporter: mocks.NewMockTLSReporter(ctl),
TLSIssueCh: make(chan struct{}), TLSIssueCh: make(chan struct{}),
@ -35,41 +31,14 @@ func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Versio
Autostarter: mocks.NewMockAutostarter(ctl), Autostarter: mocks.NewMockAutostarter(ctl),
} }
// When using the proxy dialer, we want to use the test dialer.
mocks.ProxyDialer.EXPECT().DialTLSContext(
gomock.Any(),
gomock.Any(),
gomock.Any(),
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.DialTLSContext(ctx, network, address)
}).AnyTimes()
// When getting the TLS issue channel, we want to return the test channel. // When getting the TLS issue channel, we want to return the test channel.
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes() mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
return mocks return mocks
} }
type TestDialer struct { func (mocks *Mocks) Close() {
canDial bool close(mocks.TLSIssueCh)
}
func NewTestDialer() *TestDialer {
return &TestDialer{
canDial: true,
}
}
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
if !d.canDial {
return nil, errors.New("cannot dial")
}
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
}
func (d *TestDialer) SetCanDial(canDial bool) {
d.canDial = canDial
} }
type TestLocationsProvider struct { type TestLocationsProvider struct {

View File

@ -1,12 +1,10 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter) // Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyController,Autostarter)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
import ( import (
context "context"
net "net"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -49,66 +47,51 @@ func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
} }
// MockProxyDialer is a mock of ProxyDialer interface. // MockProxyController is a mock of ProxyController interface.
type MockProxyDialer struct { type MockProxyController struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockProxyDialerMockRecorder recorder *MockProxyControllerMockRecorder
} }
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer. // MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
type MockProxyDialerMockRecorder struct { type MockProxyControllerMockRecorder struct {
mock *MockProxyDialer mock *MockProxyController
} }
// NewMockProxyDialer creates a new mock instance. // NewMockProxyController creates a new mock instance.
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer { func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
mock := &MockProxyDialer{ctrl: ctrl} mock := &MockProxyController{ctrl: ctrl}
mock.recorder = &MockProxyDialerMockRecorder{mock} mock.recorder = &MockProxyControllerMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder { func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
return m.recorder return m.recorder
} }
// AllowProxy mocks base method. // AllowProxy mocks base method.
func (m *MockProxyDialer) AllowProxy() { func (m *MockProxyController) AllowProxy() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy") m.ctrl.Call(m, "AllowProxy")
} }
// AllowProxy indicates an expected call of AllowProxy. // AllowProxy indicates an expected call of AllowProxy.
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call { func (mr *MockProxyControllerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyController)(nil).AllowProxy))
}
// DialTLSContext mocks base method.
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialTLSContext indicates an expected call of DialTLSContext.
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
} }
// DisallowProxy mocks base method. // DisallowProxy mocks base method.
func (m *MockProxyDialer) DisallowProxy() { func (m *MockProxyController) DisallowProxy() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy") m.ctrl.Call(m, "DisallowProxy")
} }
// DisallowProxy indicates an expected call of DisallowProxy. // DisallowProxy indicates an expected call of DisallowProxy.
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call { func (mr *MockProxyControllerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyController)(nil).DisallowProxy))
} }
// MockAutostarter is a mock of Autostarter interface. // MockAutostarter is a mock of Autostarter interface.

View File

@ -138,9 +138,9 @@ func (bridge *Bridge) GetProxyAllowed() bool {
func (bridge *Bridge) SetProxyAllowed(allowed bool) error { func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
if allowed { if allowed {
bridge.proxyDialer.AllowProxy() bridge.proxyCtl.AllowProxy()
} else { } else {
bridge.proxyDialer.DisallowProxy() bridge.proxyCtl.DisallowProxy()
} }
return bridge.vault.SetProxyAllowed(allowed) return bridge.vault.SetProxyAllowed(allowed)

View File

@ -7,12 +7,13 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server"
) )
func TestBridge_Settings_GluonDir(t *testing.T) { func TestBridge_Settings_GluonDir(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Create a user. // Create a user.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil) _, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
@ -34,8 +35,8 @@ func TestBridge_Settings_GluonDir(t *testing.T) {
} }
func TestBridge_Settings_IMAPPort(t *testing.T) { func TestBridge_Settings_IMAPPort(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
curPort := bridge.GetIMAPPort() curPort := bridge.GetIMAPPort()
// Set the port to 1144. // Set the port to 1144.
@ -51,8 +52,8 @@ func TestBridge_Settings_IMAPPort(t *testing.T) {
} }
func TestBridge_Settings_IMAPSSL(t *testing.T) { func TestBridge_Settings_IMAPSSL(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, IMAP SSL is disabled. // By default, IMAP SSL is disabled.
require.False(t, bridge.GetIMAPSSL()) require.False(t, bridge.GetIMAPSSL())
@ -66,8 +67,8 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) {
} }
func TestBridge_Settings_SMTPPort(t *testing.T) { func TestBridge_Settings_SMTPPort(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
curPort := bridge.GetSMTPPort() curPort := bridge.GetSMTPPort()
// Set the port to 1024. // Set the port to 1024.
@ -84,8 +85,8 @@ func TestBridge_Settings_SMTPPort(t *testing.T) {
} }
func TestBridge_Settings_SMTPSSL(t *testing.T) { func TestBridge_Settings_SMTPSSL(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, SMTP SSL is disabled. // By default, SMTP SSL is disabled.
require.False(t, bridge.GetSMTPSSL()) require.False(t, bridge.GetSMTPSSL())
@ -99,13 +100,13 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) {
} }
func TestBridge_Settings_Proxy(t *testing.T) { func TestBridge_Settings_Proxy(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, proxy is allowed. // By default, proxy is allowed.
require.True(t, bridge.GetProxyAllowed()) require.True(t, bridge.GetProxyAllowed())
// Disallow proxy. // Disallow proxy.
mocks.ProxyDialer.EXPECT().DisallowProxy() mocks.ProxyCtl.EXPECT().DisallowProxy()
require.NoError(t, bridge.SetProxyAllowed(false)) require.NoError(t, bridge.SetProxyAllowed(false))
// Get the new setting. // Get the new setting.
@ -115,8 +116,8 @@ func TestBridge_Settings_Proxy(t *testing.T) {
} }
func TestBridge_Settings_Autostart(t *testing.T) { func TestBridge_Settings_Autostart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, autostart is disabled. // By default, autostart is disabled.
require.False(t, bridge.GetAutostart()) require.False(t, bridge.GetAutostart())
@ -131,8 +132,8 @@ func TestBridge_Settings_Autostart(t *testing.T) {
} }
func TestBridge_Settings_FirstStart(t *testing.T) { func TestBridge_Settings_FirstStart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true. // By default, first start is true.
require.True(t, bridge.GetFirstStart()) require.True(t, bridge.GetFirstStart())
@ -146,8 +147,8 @@ func TestBridge_Settings_FirstStart(t *testing.T) {
} }
func TestBridge_Settings_FirstStartGUI(t *testing.T) { func TestBridge_Settings_FirstStartGUI(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true. // By default, first start is true.
require.True(t, bridge.GetFirstStartGUI()) require.True(t, bridge.GetFirstStartGUI())

View File

@ -3,6 +3,8 @@ package bridge
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net"
"strconv"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/emersion/go-sasl" "github.com/emersion/go-sasl"
@ -22,6 +24,22 @@ func (bridge *Bridge) serveSMTP() error {
} }
}() }()
_, port, err := net.SplitHostPort(smtpListener.Addr().String())
if err != nil {
return fmt.Errorf("failed to get SMTP listener address: %w", err)
}
portInt, err := strconv.Atoi(port)
if err != nil {
return fmt.Errorf("failed to convert SMTP listener port to int: %w", err)
}
if portInt != bridge.vault.GetSMTPPort() {
if err := bridge.vault.SetSMTPPort(portInt); err != nil {
return fmt.Errorf("failed to update SMTP port in vault: %w", err)
}
}
return nil return nil
} }

View File

@ -0,0 +1,130 @@
package bridge_test
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/emersion/go-imap/client"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_Sync(t *testing.T) {
s := server.New()
defer s.Close()
numMsg := 1 << 10
withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
require.NoError(t, err)
labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder)
require.NoError(t, err)
literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml"))
require.NoError(t, err)
for i := 0; i < numMsg; i++ {
messageID, err := s.CreateMessage(userID, addrID, literal, liteapi.MessageFlagReceived, false, false)
require.NoError(t, err)
require.NoError(t, s.LabelMessage(userID, messageID, labelID))
}
var read uint64
netCtl.OnRead(func(b []byte) {
read += uint64(len(b))
})
// The initial user should be fully synced.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
syncCh, done := bridge.GetEvents(events.SyncFinished{})
defer done()
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
require.NoError(t, err)
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
})
// If we then connect an IMAP client, it should see all the messages.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.Connected)
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
defer client.Logout()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Equal(t, uint32(numMsg), status.Messages)
})
// Now let's remove the user and simulate a network error.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.NoError(t, bridge.DeleteUser(ctx, userID))
})
// Pretend we can only sync 2/3 of the original messages.
netCtl.SetReadLimit(2 * read / 3)
// Login the user; its sync should fail.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
syncCh, done := bridge.GetEvents(events.SyncFailed{})
defer done()
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
require.NoError(t, err)
require.Equal(t, userID, (<-syncCh).(events.SyncFailed).UserID)
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.Connected)
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
defer client.Logout()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Less(t, status.Messages, uint32(numMsg))
})
// Remove the network limit, allowing the sync to finish.
netCtl.SetReadLimit(0)
// Login the user; its sync should now finish.
// If we then connect an IMAP client, it should eventually see all the messages.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
syncCh, done := bridge.GetEvents(events.SyncFinished{})
defer done()
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.Connected)
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
defer client.Logout()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Equal(t, uint32(numMsg), status.Messages)
})
})
}

View File

@ -0,0 +1,6 @@
To: recipient@pm.me
From: sender@pm.me
Subject: Test
Content-Type: text/plain; charset=utf-8
Test

View File

@ -1,9 +1,6 @@
package bridge package bridge
import ( import (
"context"
"net"
"github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/updater"
) )
@ -21,17 +18,15 @@ type Identifier interface {
SetPlatform(platform string) SetPlatform(platform string)
} }
type TLSReporter interface { type ProxyController interface {
GetTLSIssueCh() <-chan struct{}
}
type ProxyDialer interface {
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
AllowProxy() AllowProxy()
DisallowProxy() DisallowProxy()
} }
type TLSReporter interface {
GetTLSIssueCh() <-chan struct{}
}
type Autostarter interface { type Autostarter interface {
Enable() error Enable() error
Disable() error Disable() error

View File

@ -18,6 +18,9 @@ func (bridge *Bridge) watchForUpdates() error {
go func() { go func() {
for { for {
select { select {
case <-bridge.stopCh:
return
case <-bridge.updateCheckCh: case <-bridge.updateCheckCh:
case <-ticker.C: case <-ticker.C:
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/try"
"github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
@ -82,9 +83,11 @@ func (bridge *Bridge) LoginUser(
) (string, error) { ) (string, error) {
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password) client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to create new API client: %w", err)
} }
userID, err := try.CatchVal(
func() (string, error) {
if _, ok := bridge.users[auth.UserID]; ok { if _, ok := bridge.users[auth.UserID]; ok {
return "", ErrUserAlreadyLoggedIn return "", ErrUserAlreadyLoggedIn
} }
@ -92,66 +95,64 @@ func (bridge *Bridge) LoginUser(
if auth.TwoFA.Enabled == liteapi.TOTPEnabled { if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
totp, err := getTOTP() totp, err := getTOTP()
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to get TOTP: %w", err)
} }
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil { if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
return "", err return "", fmt.Errorf("failed to authorize 2FA: %w", err)
} }
} }
var keyPass []byte var keyPass []byte
if auth.PasswordMode == liteapi.TwoPasswordMode { if auth.PasswordMode == liteapi.TwoPasswordMode {
pass, err := getKeyPass() userKeyPass, err := getKeyPass()
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to get key password: %w", err)
} }
keyPass = pass keyPass = userKeyPass
} else { } else {
keyPass = password keyPass = password
} }
apiUser, err := client.GetUser(ctx) return bridge.loginUser(ctx, client, auth.UID, auth.RefreshToken, keyPass)
},
func() error {
return client.AuthDelete(ctx)
},
func() error {
bridge.deleteUser(ctx, auth.UserID)
return nil
},
)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to login user: %w", err)
} }
salts, err := client.GetSalts(ctx) bridge.publish(events.UserLoggedIn{
if err != nil { UserID: userID,
return "", err })
}
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID) return userID, nil
if err != nil {
return "", err
}
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
return "", err
}
return auth.UserID, nil
} }
// LogoutUser logs out the given user. // LogoutUser logs out the given user.
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error { func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
return bridge.logoutUser(ctx, userID, true, false) if err := bridge.logoutUser(ctx, userID); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
} }
// DeleteUser deletes the given user. // DeleteUser deletes the given user.
// If it is authorized, it is logged out first.
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
if bridge.users[userID] != nil { bridge.deleteUser(ctx, userID)
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
return err
}
}
if err := bridge.vault.DeleteUser(userID); err != nil {
return err
}
bridge.publish(events.UserDeleted{ bridge.publish(events.UserDeleted{
UserID: userID, UserID: userID,
@ -193,53 +194,91 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
return nil return nil
} }
// loadUsers loads authorized users from the vault. func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
func (bridge *Bridge) loadUsers(ctx context.Context) error { apiUser, err := client.GetUser(ctx)
for _, userID := range bridge.vault.GetUserIDs() {
user, err := bridge.vault.GetUser(userID)
if err != nil { if err != nil {
return err return "", fmt.Errorf("failed to get API user: %w", err)
}
salts, err := client.GetSalts(ctx)
if err != nil {
return "", fmt.Errorf("failed to get key salts: %w", err)
}
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
if err != nil {
return "", fmt.Errorf("failed to salt key password: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil {
return "", fmt.Errorf("failed to add bridge user: %w", err)
}
return apiUser.ID, nil
}
// loadUsers is a loop that, when polled, attempts to load authorized users from the vault.
func (bridge *Bridge) loadUsers() error {
return bridge.vault.ForUser(func(user *vault.User) error {
if _, ok := bridge.users[user.UserID()]; ok {
return nil
} }
if user.AuthUID() == "" { if user.AuthUID() == "" {
continue return nil
} }
if err := bridge.loadUser(ctx, user); err != nil { if err := bridge.loadUser(user); err != nil {
logrus.WithError(err).Error("Failed to load connected user")
if _, ok := err.(*resty.ResponseError); ok { if _, ok := err.(*resty.ResponseError); ok {
if err := bridge.vault.ClearUser(userID); err != nil { logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault")
if err := user.Clear(); err != nil {
logrus.WithError(err).Error("Failed to clear user") logrus.WithError(err).Error("Failed to clear user")
} }
} } else {
logrus.WithError(err).Error("Failed to load connected user")
continue
}
} }
return nil return nil
}
bridge.publish(events.UserLoaded{
UserID: user.UserID(),
})
return nil
})
} }
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error { // loadUser loads an existing user from the vault.
func (bridge *Bridge) loadUser(user *vault.User) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef()) client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
if err != nil { if err != nil {
return fmt.Errorf("failed to create API client: %w", err) return fmt.Errorf("failed to create API client: %w", err)
} }
if err := try.Catch(
func() error {
apiUser, err := client.GetUser(ctx) apiUser, err := client.GetUser(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user: %w", err) return fmt.Errorf("failed to get user: %w", err)
} }
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil { return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass())
return fmt.Errorf("failed to add user: %w", err) },
func() error {
return client.AuthDelete(ctx)
},
func() error {
return bridge.logoutUser(ctx, user.UserID())
},
); err != nil {
return fmt.Errorf("failed to load user: %w", err)
} }
bridge.publish(events.UserLoggedIn{
UserID: user.UserID(),
})
return nil return nil
} }
@ -304,10 +343,6 @@ func (bridge *Bridge) addUser(
return nil return nil
}) })
bridge.publish(events.UserLoggedIn{
UserID: user.ID(),
})
return nil return nil
} }
@ -363,54 +398,6 @@ func (bridge *Bridge) addExistingUser(
return user, nil return user, nil
} }
// logoutUser closes and removes the user with the given ID.
// If withAPI is true, the user will additionally be logged out from API.
// If withFiles is true, the user's files will be deleted.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if err := bridge.smtpBackend.removeUser(user); err != nil {
return fmt.Errorf("failed to remove SMTP user: %w", err)
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
}
if withAPI {
if err := user.Logout(ctx); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
}
if err := user.Close(); err != nil {
return fmt.Errorf("failed to close user: %w", err)
}
if err := bridge.vault.ClearUser(userID); err != nil {
return fmt.Errorf("failed to clear user: %w", err)
}
if withFiles {
if err := bridge.vault.DeleteUser(userID); err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
}
delete(bridge.users, userID)
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
}
// addIMAPUser connects the given user to gluon. // addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
imapConn, err := user.NewIMAPConnectors() imapConn, err := user.NewIMAPConnectors()
@ -438,6 +425,65 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
return nil return nil
} }
// logoutUser logs the given user out from bridge.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if err := bridge.smtpBackend.removeUser(user); err != nil {
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
if err := user.Close(); err != nil {
logrus.WithError(err).Error("Failed to close user")
}
delete(bridge.users, userID)
return nil
}
// deleteUser deletes the given user from bridge.
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
if user, ok := bridge.users[userID]; ok {
if err := bridge.smtpBackend.removeUser(user); err != nil {
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
if err := user.Close(); err != nil {
logrus.WithError(err).Error("Failed to close user")
}
}
if err := bridge.vault.DeleteUser(userID); err != nil {
logrus.WithError(err).Error("Failed to delete user from vault")
}
delete(bridge.users, userID)
}
// getUserInfo returns information about a disconnected user. // getUserInfo returns information about a disconnected user.
func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo { func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
return UserInfo{ return UserInfo{

View File

@ -27,7 +27,7 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
} }
case events.UserDeauth: case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil { if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
return fmt.Errorf("failed to logout user: %w", err) return fmt.Errorf("failed to logout user: %w", err)
} }
} }

View File

@ -9,17 +9,18 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server"
) )
func TestBridge_WithoutUsers(t *testing.T) { func TestBridge_WithoutUsers(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs()) require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge)) require.Empty(t, getConnectedUserIDs(t, bridge))
}) })
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs()) require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge)) require.Empty(t, getConnectedUserIDs(t, bridge))
}) })
@ -27,8 +28,8 @@ func TestBridge_WithoutUsers(t *testing.T) {
} }
func TestBridge_Login(t *testing.T) { func TestBridge_Login(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil) userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
@ -41,8 +42,8 @@ func TestBridge_Login(t *testing.T) {
} }
func TestBridge_LoginLogoutLogin(t *testing.T) { func TestBridge_LoginLogoutLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -69,8 +70,8 @@ func TestBridge_LoginLogoutLogin(t *testing.T) {
} }
func TestBridge_LoginDeleteLogin(t *testing.T) { func TestBridge_LoginDeleteLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -97,8 +98,8 @@ func TestBridge_LoginDeleteLogin(t *testing.T) {
} }
func TestBridge_LoginDeauthLogin(t *testing.T) { func TestBridge_LoginDeauthLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -131,10 +132,10 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
func TestBridge_LoginExpireLogin(t *testing.T) { func TestBridge_LoginExpireLogin(t *testing.T) {
const authLife = 2 * time.Second const authLife = 2 * time.Second
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
s.SetAuthLife(authLife) s.SetAuthLife(authLife)
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. Its auth will only be valid for a short time. // Login the user. Its auth will only be valid for a short time.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -148,11 +149,11 @@ func TestBridge_LoginExpireLogin(t *testing.T) {
} }
func TestBridge_FailToLoad(t *testing.T) { func TestBridge_FailToLoad(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string var userID string
// Login the user. // Login the user.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
}) })
@ -160,7 +161,7 @@ func TestBridge_FailToLoad(t *testing.T) {
require.NoError(t, s.RevokeUser(userID)) require.NoError(t, s.RevokeUser(userID))
// When bridge starts, the user will not be logged in. // When bridge starts, the user will not be logged in.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge)) require.Empty(t, getConnectedUserIDs(t, bridge))
}) })
@ -168,25 +169,27 @@ func TestBridge_FailToLoad(t *testing.T) {
} }
func TestBridge_LoadWithoutInternet(t *testing.T) { func TestBridge_LoadWithoutInternet(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string var userID string
// Login the user. // Login the user.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
}) })
// Simulate loss of internet connection. // Simulate loss of internet connection.
dialer.SetCanDial(false) netCtl.Disable()
// Start bridge without internet. // Start bridge without internet.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Initially, users are not connected. // Initially, users are not connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge)) require.Empty(t, getConnectedUserIDs(t, bridge))
time.Sleep(5 * time.Second)
// Simulate internet connection. // Simulate internet connection.
dialer.SetCanDial(true) netCtl.Enable()
// The user will eventually be connected. // The user will eventually be connected.
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
@ -197,16 +200,14 @@ func TestBridge_LoadWithoutInternet(t *testing.T) {
} }
func TestBridge_LoginRestart(t *testing.T) { func TestBridge_LoginRestart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string var userID string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
}) })
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
}) })
@ -214,10 +215,10 @@ func TestBridge_LoginRestart(t *testing.T) {
} }
func TestBridge_LoginLogoutRestart(t *testing.T) { func TestBridge_LoginLogoutRestart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string var userID string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -225,7 +226,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
require.NoError(t, bridge.LogoutUser(ctx, userID)) require.NoError(t, bridge.LogoutUser(ctx, userID))
}) })
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still disconnected. // The user is still disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge)) require.Empty(t, getConnectedUserIDs(t, bridge))
@ -234,10 +235,10 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
} }
func TestBridge_LoginDeleteRestart(t *testing.T) { func TestBridge_LoginDeleteRestart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string var userID string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -245,7 +246,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
require.NoError(t, bridge.DeleteUser(ctx, userID)) require.NoError(t, bridge.DeleteUser(ctx, userID))
}) })
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still gone. // The user is still gone.
require.Empty(t, bridge.GetUserIDs()) require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge)) require.Empty(t, getConnectedUserIDs(t, bridge))
@ -253,13 +254,69 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
}) })
} }
func TestBridge_FailLoginRecover(t *testing.T) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var read uint64
netCtl.OnRead(func(b []byte) {
read += uint64(len(b))
})
// Log the user in and record how much data was read.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
// Simulate a partial read.
netCtl.SetReadLimit(read / 2)
// We should fail to log the user in because we can't fully read its data.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Error(t, getErr(bridge.LoginUser(ctx, username, password, nil, nil)))
// There should be no users.
require.Empty(t, bridge.GetUserIDs())
})
})
}
func TestBridge_FailLoadRecover(t *testing.T) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
must(bridge.LoginUser(ctx, username, password, nil, nil))
})
var read uint64
netCtl.OnRead(func(b []byte) {
read += uint64(len(b))
})
// Start bridge and record how much data was read.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// ...
})
// Simulate a partial read.
netCtl.SetReadLimit(read / 2)
// We should fail to load the user; it should be disconnected.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userIDs := bridge.GetUserIDs()
require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected)
})
})
}
func TestBridge_BridgePass(t *testing.T) { func TestBridge_BridgePass(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string var userID string
var pass []byte var pass []byte
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -276,7 +333,7 @@ func TestBridge_BridgePass(t *testing.T) {
require.Equal(t, pass, pass) require.Equal(t, pass, pass)
}) })
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The bridge should load the user. // The bridge should load the user.
require.Equal(t, []string{userID}, bridge.GetUserIDs()) require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
@ -288,8 +345,8 @@ func TestBridge_BridgePass(t *testing.T) {
} }
func TestBridge_AddressMode(t *testing.T) { func TestBridge_AddressMode(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. // Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil) userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err) require.NoError(t, err)
@ -313,3 +370,8 @@ func TestBridge_AddressMode(t *testing.T) {
}) })
}) })
} }
// getErr returns the error that was passed to it.
func getErr[T any](val T, err error) error {
return err
}

View File

@ -2,6 +2,12 @@ package events
import "github.com/ProtonMail/proton-bridge/v2/internal/vault" import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
type UserLoaded struct {
eventBase
UserID string
}
type UserLoggedIn struct { type UserLoggedIn struct {
eventBase eventBase

View File

@ -64,4 +64,5 @@ func (service *FocusService) GetRaiseCh() <-chan struct{} {
// Close closes the service. // Close closes the service.
func (service *FocusService) Close() { func (service *FocusService) Close() {
service.server.Stop() service.server.Stop()
close(service.raiseCh)
} }

View File

@ -1,41 +0,0 @@
package pool
import "context"
type job[In, Out any] struct {
ctx context.Context
req In
res chan Out
err chan error
done chan struct{}
}
func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] {
return &job[In, Out]{
ctx: ctx,
req: req,
res: make(chan Out),
err: make(chan error),
done: make(chan struct{}),
}
}
func (job *job[In, Out]) result() (Out, error) {
return <-job.res, <-job.err
}
func (job *job[In, Out]) postSuccess(res Out) {
close(job.err)
job.res <- res
}
func (job *job[In, Out]) postFailure(err error) {
close(job.res)
job.err <- err
}
func (job *job[In, Out]) waitDone() {
<-job.done
}

View File

@ -1,147 +0,0 @@
package pool
import (
"context"
"errors"
"sync"
"github.com/ProtonMail/gluon/queue"
)
// ErrJobCancelled indicates the job was cancelled.
var ErrJobCancelled = errors.New("job cancelled by surrounding context")
// Pool is a worker pool that handles input of type In and returns results of type Out.
type Pool[In comparable, Out any] struct {
queue *queue.QueuedChannel[*job[In, Out]]
size int
}
// doneFunc must be called to free up pool resources.
type doneFunc func()
// New returns a new pool.
func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
queue := queue.NewQueuedChannel[*job[In, Out]](0, 0)
for i := 0; i < size; i++ {
go func() {
for job := range queue.GetChannel() {
select {
case <-job.ctx.Done():
job.postFailure(ErrJobCancelled)
default:
res, err := work(job.ctx, job.req)
if err != nil {
job.postFailure(err)
} else {
job.postSuccess(res)
}
job.waitDone()
}
}
}()
}
return &Pool[In, Out]{
queue: queue,
size: size,
}
}
// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred.
func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var (
wg sync.WaitGroup
errList []error
lock sync.Mutex
)
for _, req := range reqs {
req := req
wg.Add(1)
go func() {
defer wg.Done()
job, done := pool.newJob(ctx, req)
defer done()
res, err := job.result()
if err := fn(req, res, err); err != nil {
lock.Lock()
defer lock.Unlock()
// Cancel ongoing jobs.
cancel()
// Collect the error.
errList = append(errList, err)
}
}()
}
wg.Wait()
// TODO: Join the errors somehow?
if len(errList) > 0 {
return errList[0]
}
return nil
}
// ProcessAll submits jobs to the pool. All results are returned once available.
func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) (map[In]Out, error) {
var (
data = make(map[In]Out)
lock = sync.Mutex{}
)
if err := pool.Process(ctx, reqs, func(req In, res Out, err error) error {
if err != nil {
return err
}
lock.Lock()
defer lock.Unlock()
data[req] = res
return nil
}); err != nil {
return nil, err
}
return data, nil
}
// ProcessOne submits one job to the pool and returns the result.
func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) {
job, done := pool.newJob(ctx, req)
defer done()
return job.result()
}
func (pool *Pool[In, Out]) Done() {
pool.queue.Close()
}
// newJob submits a job to the pool. It returns a job handle and a DoneFunc.
// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
// which frees up the worker in the pool for reuse.
func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc) {
job := newJob[In, Out](ctx, req)
pool.queue.Enqueue(job)
return job, func() { close(job.done) }
}

View File

@ -1,163 +0,0 @@
package pool
import (
"context"
"errors"
"runtime"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPool_NewJob(t *testing.T) {
doubler := newDoubler(runtime.NumCPU())
job1, done1 := doubler.newJob(context.Background(), 1)
defer done1()
job2, done2 := doubler.newJob(context.Background(), 2)
defer done2()
res2, err := job2.result()
require.NoError(t, err)
res1, err := job1.result()
require.NoError(t, err)
assert.Equal(t, 2, res1)
assert.Equal(t, 4, res2)
}
func TestPool_NewJob_Done(t *testing.T) {
// Create a doubler pool with 2 workers.
doubler := newDoubler(2)
// Start two jobs. Don't mark the jobs as done yet.
job1, done1 := doubler.newJob(context.Background(), 1)
job2, done2 := doubler.newJob(context.Background(), 2)
// Get the first result.
res1, _ := job1.result()
assert.Equal(t, 2, res1)
// Get the first result.
res2, _ := job2.result()
assert.Equal(t, 4, res2)
// Additional jobs will wait.
job3, _ := doubler.newJob(context.Background(), 3)
job4, _ := doubler.newJob(context.Background(), 4)
// Channel to collect results from jobs 3 and 4.
resCh := make(chan int, 2)
go func() {
res, _ := job3.result()
resCh <- res
}()
go func() {
res, _ := job4.result()
resCh <- res
}()
// Mark jobs 1 and 2 as done, freeing up the workers.
done1()
done2()
assert.ElementsMatch(t, []int{6, 8}, []int{<-resCh, <-resCh})
}
func TestPool_Process(t *testing.T) {
doubler := newDoubler(runtime.NumCPU())
var (
res = make(map[int]int)
lock sync.Mutex
)
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(reqVal, resVal int, err error) error {
require.NoError(t, err)
lock.Lock()
defer lock.Unlock()
res[reqVal] = resVal
return nil
}))
assert.Equal(t, map[int]int{
1: 2,
2: 4,
3: 6,
4: 8,
5: 10,
}, res)
}
func TestPool_Process_Error(t *testing.T) {
doubler := newDoublerWithError(runtime.NumCPU())
assert.Error(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(_ int, _ int, err error) error {
return err
}))
}
func TestPool_Process_Parallel(t *testing.T) {
doubler := newDoubler(runtime.NumCPU(), 100*time.Millisecond)
var wg sync.WaitGroup
for i := 0; i < 8; i++ {
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4}, func(_ int, _ int, err error) error {
return nil
}))
}()
}
wg.Wait()
}
func TestPool_ProcessAll(t *testing.T) {
doubler := newDoubler(runtime.NumCPU())
res, err := doubler.ProcessAll(context.Background(), []int{1, 2, 3, 4, 5})
require.NoError(t, err)
assert.Equal(t, map[int]int{
1: 2,
2: 4,
3: 6,
4: 8,
5: 10,
}, res)
}
func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] {
return New(workers, func(ctx context.Context, req int) (int, error) {
if len(delay) > 0 {
time.Sleep(delay[0])
}
return 2 * req, nil
})
}
func newDoublerWithError(workers int) *Pool[int, int] {
return New(workers, func(ctx context.Context, req int) (int, error) {
if req%2 == 0 {
return 0, errors.New("oops")
}
return 2 * req, nil
})
}

View File

@ -23,7 +23,15 @@ func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] {
return m return m
} }
func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool { func (m *Map[Key, Val]) Has(key Key) bool {
m.lock.RLock()
defer m.lock.RUnlock()
_, ok := m.data[key]
return ok
}
func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
@ -37,7 +45,7 @@ func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool {
return true return true
} }
func (m *Map[Key, Val]) GetErr(key Key, fn func(val Val) error) (bool, error) { func (m *Map[Key, Val]) GetErr(key Key, fn func(Val) error) (bool, error) {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
@ -56,6 +64,15 @@ func (m *Map[Key, Val]) Set(key Key, val Val) {
m.data[key] = val m.data[key] = val
} }
func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) {
m.lock.RLock()
defer m.lock.RUnlock()
for key, val := range m.data {
fn(key, val)
}
}
func (m *Map[Key, Val]) Keys(fn func(keys []Key)) { func (m *Map[Key, Val]) Keys(fn func(keys []Key)) {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
@ -70,28 +87,52 @@ func (m *Map[Key, Val]) Values(fn func(vals []Val)) {
fn(maps.Values(m.data)) fn(maps.Values(m.data))
} }
func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) Ret) (Ret, bool) { func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret, fallback func() Ret) Ret {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
val, ok := m.data[key] val, ok := m.data[key]
if !ok { if !ok {
return *new(Ret), false return fallback()
} }
return fn(val), true return fn(val)
} }
func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) (Ret, error)) (Ret, bool, error) { func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
val, ok := m.data[key] val, ok := m.data[key]
if !ok { if !ok {
return *new(Ret), false, nil return fallback()
} }
ret, err := fn(val) return fn(val)
}
return ret, true, err
func FindMap[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) Ret, fallback func() Ret) Ret {
m.lock.RLock()
defer m.lock.RUnlock()
for _, val := range m.data {
if cmp(val) {
return fn(val)
}
}
return fallback()
}
func FindMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
m.lock.RLock()
defer m.lock.RUnlock()
for _, val := range m.data {
if cmp(val) {
return fn(val)
}
}
return fallback()
} }

49
internal/try/try.go Normal file
View File

@ -0,0 +1,49 @@
package try
import (
"fmt"
"github.com/sirupsen/logrus"
)
// Catch tries to execute the `try` function, and if it fails or panics,
// it executes the `handlers` functions in order.
func Catch(try func() error, handlers ...func() error) error {
if _, err := CatchVal(func() (any, error) { return nil, try() }, handlers...); err != nil {
return err
}
return nil
}
// CatchVal tries to execute the `try` function, and if it fails or panics,
// it executes the `handlers` functions in order.
func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, err error) {
defer func() {
if r := recover(); r != nil {
catch(handlers...)
err = fmt.Errorf("panic: %v", r)
}
}()
if res, err = try(); err != nil {
catch(handlers...)
return res, err
}
return res, nil
}
func catch(handlers ...func() error) {
defer func() {
if r := recover(); r != nil {
logrus.WithField("panic", r).Error("Panic in catch")
}
}()
for _, handler := range handlers {
if err := handler(); err != nil {
logrus.WithError(err).Error("Failed to handle error")
}
}
}

74
internal/try/try_test.go Normal file
View File

@ -0,0 +1,74 @@
package try
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestTry(t *testing.T) {
res, err := CatchVal(func() (string, error) {
return "foo", nil
})
require.NoError(t, err)
require.Equal(t, "foo", res)
}
func TestTryCatch(t *testing.T) {
tryErr := fmt.Errorf("oops")
res, err := CatchVal(
func() (string, error) {
return "", tryErr
},
func() error {
return nil
},
)
require.ErrorIs(t, err, tryErr)
require.Zero(t, res)
}
func TestTryCatchError(t *testing.T) {
tryErr := fmt.Errorf("oops")
res, err := CatchVal(
func() (string, error) {
return "", tryErr
},
func() error {
return fmt.Errorf("catch error")
},
)
require.ErrorIs(t, err, tryErr)
require.Zero(t, res)
}
func TestTryPanic(t *testing.T) {
res, err := CatchVal(
func() (string, error) {
panic("oops")
},
func() error {
return nil
},
)
require.ErrorContains(t, err, "panic: oops")
require.Zero(t, res)
}
func TestTryCatchPanic(t *testing.T) {
tryErr := fmt.Errorf("oops")
res, err := CatchVal(
func() (string, error) {
return "", tryErr
},
func() error {
panic("oops")
},
)
require.ErrorIs(t, err, tryErr)
require.Zero(t, res)
}

View File

@ -259,7 +259,12 @@ func (user *User) handleMessageEvents(ctx context.Context, messageEvents []litea
} }
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error { func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
buildRes, err := user.buildRFC822(ctx, event.Message) full, err := user.client.GetFullMessage(ctx, event.Message.ID)
if err != nil {
return fmt.Errorf("failed to get full message: %w", err)
}
buildRes, err := buildRFC822(ctx, full, user.addrKRs)
if err != nil { if err != nil {
return fmt.Errorf("failed to build RFC822: %w", err) return fmt.Errorf("failed to build RFC822: %w", err)
} }

View File

@ -4,14 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"runtime"
"strings" "strings"
"time" "time"
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/queue"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/google/uuid" "github.com/google/uuid"
@ -105,34 +102,38 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.
func (user *User) syncMessages(ctx context.Context) error { func (user *User) syncMessages(ctx context.Context) error {
// Determine which messages to sync. // Determine which messages to sync.
metadata, err := user.client.GetAllMessageMetadata(ctx, nil) allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil)
if err != nil { if err != nil {
return fmt.Errorf("get all message metadata: %w", err) return fmt.Errorf("get all message metadata: %w", err)
} }
// If possible, begin syncing from the last synced message. metadata := allMetadata
// If possible, begin syncing from one beyond the last synced message.
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" { if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool { if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
return metadata.ID == beginID return metadata.ID == beginID
}); idx >= 0 { }); idx >= 0 {
metadata = metadata[idx:] metadata = metadata[idx+1:]
} }
} }
// Process the metadata, building the messages. // Process the metadata, building the messages.
buildCh := stream.Chunk(parallel.MapStream( buildCh := stream.Chunk(stream.Map(
ctx, user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
stream.FromIterator(iterator.Slice(metadata)), return metadata.ID
runtime.NumCPU()*runtime.NumCPU()/2, })...),
runtime.NumCPU()*runtime.NumCPU()/2, func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
user.buildRFC822, return buildRFC822(ctx, full, user.addrKRs)
},
), maxBatchSize) ), maxBatchSize)
defer buildCh.Close()
// Create the flushers, one per update channel. // Create the flushers, one per update channel.
flushers := make(map[string]*flusher) flushers := make(map[string]*flusher)
for addrID, updateCh := range user.updateCh { for addrID, updateCh := range user.updateCh {
flusher := newFlusher(user.ID(), updateCh, maxUpdateSize) flusher := newFlusher(updateCh, maxUpdateSize)
defer flusher.flush(ctx, true) defer flusher.flush(ctx, true)
flushers[addrID] = flusher flushers[addrID] = flusher
@ -142,6 +143,8 @@ func (user *User) syncMessages(ctx context.Context) error {
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second) reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
defer reporter.done() defer reporter.done()
var count int
// Send each update to the appropriate flusher. // Send each update to the appropriate flusher.
for { for {
batch, err := buildCh.Next(ctx) batch, err := buildCh.Next(ctx)
@ -170,6 +173,8 @@ func (user *User) syncMessages(ctx context.Context) error {
} }
reporter.add(len(batch)) reporter.add(len(batch))
count += len(batch)
} }
} }

View File

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/pkg/message" "github.com/ProtonMail/proton-bridge/v2/pkg/message"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
@ -29,30 +30,20 @@ func defaultJobOpts() message.JobOptions {
} }
} }
func (user *User) buildRFC822(ctx context.Context, metadata liteapi.MessageMetadata) (*buildRes, error) { func buildRFC822(ctx context.Context, full liteapi.FullMessage, addrKRs map[string]*crypto.KeyRing) (*buildRes, error) {
msg, err := user.client.GetMessage(ctx, metadata.ID) literal, err := message.BuildRFC822(addrKRs[full.AddressID], full.Message, full.AttData, defaultJobOpts())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get message %s: %w", metadata.ID, err) return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err)
} }
attData, err := user.attPool.ProcessAll(ctx, xslices.Map(msg.Attachments, func(att liteapi.Attachment) string { return att.ID })) update, err := newMessageCreatedUpdate(full.MessageMetadata, literal)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get attachments for message %s: %w", metadata.ID, err) return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", full.ID, err)
}
literal, err := message.BuildRFC822(user.addrKRs[msg.AddressID], msg, attData, defaultJobOpts())
if err != nil {
return nil, fmt.Errorf("failed to build message %s: %w", metadata.ID, err)
}
update, err := newMessageCreatedUpdate(metadata, literal)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", metadata.ID, err)
} }
return &buildRes{ return &buildRes{
messageID: metadata.ID, messageID: full.ID,
addressID: metadata.AddressID, addressID: full.AddressID,
update: update, update: update,
}, nil }, nil
} }

View File

@ -9,21 +9,19 @@ import (
) )
type flusher struct { type flusher struct {
userID string
updateCh *queue.QueuedChannel[imap.Update] updateCh *queue.QueuedChannel[imap.Update]
updates []*imap.MessageCreated updates []*imap.MessageCreated
maxChunkSize int
maxUpdateSize int
curChunkSize int curChunkSize int
pushLock sync.Mutex pushLock sync.Mutex
} }
func newFlusher(userID string, updateCh *queue.QueuedChannel[imap.Update], maxChunkSize int) *flusher { func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *flusher {
return &flusher{ return &flusher{
userID: userID,
updateCh: updateCh, updateCh: updateCh,
maxChunkSize: maxChunkSize, maxUpdateSize: maxUpdateSize,
} }
} }
@ -33,19 +31,17 @@ func (f *flusher) push(ctx context.Context, update *imap.MessageCreated) {
f.updates = append(f.updates, update) f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize { if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxUpdateSize {
f.flush(ctx, false) f.flush(ctx, false)
} }
} }
func (f *flusher) flush(ctx context.Context, wait bool) { func (f *flusher) flush(ctx context.Context, wait bool) {
if len(f.updates) == 0 { if len(f.updates) > 0 {
return
}
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...)) f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
f.updates = nil f.updates = nil
f.curChunkSize = 0 f.curChunkSize = 0
}
if wait { if wait {
update := imap.NewNoop() update := imap.NewNoop()

View File

@ -5,7 +5,6 @@ import (
"context" "context"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"runtime"
"time" "time"
"github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/connector"
@ -14,7 +13,6 @@ import (
"github.com/ProtonMail/gluon/wait" "github.com/ProtonMail/gluon/wait"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
"github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
@ -31,7 +29,6 @@ var (
type User struct { type User struct {
vault *vault.User vault *vault.User
client *liteapi.Client client *liteapi.Client
attPool *pool.Pool[string, []byte]
eventCh *queue.QueuedChannel[events.Event] eventCh *queue.QueuedChannel[events.Event]
apiUser *safe.Type[liteapi.User] apiUser *safe.Type[liteapi.User]
@ -91,7 +88,6 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
user := &User{ user := &User{
vault: encVault, vault: encVault,
client: client, client: client,
attPool: pool.New(runtime.NumCPU(), client.GetAttachment),
eventCh: queue.NewQueuedChannel[events.Event](0, 0), eventCh: queue.NewQueuedChannel[events.Event](0, 0),
apiUser: safe.NewType(apiUser), apiUser: safe.NewType(apiUser),
@ -123,7 +119,7 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
// If we haven't synced yet, do it first. // If we haven't synced yet, do it first.
// If it fails, we don't start the event loop. // If it fails, we don't start the event loop.
// Oterwise, begin processing API events, logging any errors that occur. // Otherwise, begin processing API events, logging any errors that occur.
go func() { go func() {
if status := user.vault.SyncStatus(); !status.HasMessages { if status := user.vault.SyncStatus(); !status.HasMessages {
if err := <-user.startSync(); err != nil { if err := <-user.startSync(); err != nil {
@ -336,8 +332,17 @@ func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
} }
// Logout logs the user out from the API. // Logout logs the user out from the API.
// If withVault is true, the user's vault is also cleared.
func (user *User) Logout(ctx context.Context) error { func (user *User) Logout(ctx context.Context) error {
return user.client.AuthDelete(ctx) if err := user.client.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
}
if err := user.vault.Clear(); err != nil {
return fmt.Errorf("failed to clear vault: %w", err)
}
return nil
} }
// Close closes ongoing connections and cleans up resources. // Close closes ongoing connections and cleans up resources.
@ -345,9 +350,6 @@ func (user *User) Close() error {
// Cancel ongoing syncs. // Cancel ongoing syncs.
user.stopSync() user.stopSync()
// Close the attachment pool.
user.attPool.Done()
// Close the user's API client. // Close the user's API client.
user.client.Close() user.client.Close()

View File

@ -89,13 +89,13 @@ func withAPI(t *testing.T, ctx context.Context, fn func(context.Context, *server
func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) { func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) {
var addrIDs []string var addrIDs []string
userID, addrID, err := s.CreateUser(username, password, emails[0]) userID, addrID, err := s.CreateUser(username, emails[0], []byte(password))
require.NoError(t, err) require.NoError(t, err)
addrIDs = append(addrIDs, addrID) addrIDs = append(addrIDs, addrID)
for _, email := range emails[1:] { for _, email := range emails[1:] {
addrID, err := s.CreateAddress(userID, email, password) addrID, err := s.CreateAddress(userID, email, []byte(password))
require.NoError(t, err) require.NoError(t, err)
addrIDs = append(addrIDs, addrID) addrIDs = append(addrIDs, addrID)

View File

@ -138,3 +138,12 @@ func (user *User) SetEventID(eventID string) error {
data.EventID = eventID data.EventID = eventID
}) })
} }
// Clear clears the user's auth secrets.
func (user *User) Clear() error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.AuthUID = ""
data.AuthRef = ""
data.KeyPass = nil
})
}

View File

@ -58,7 +58,7 @@ func TestUser_Clear(t *testing.T) {
require.Equal(t, "keyPass", string(user.KeyPass())) require.Equal(t, "keyPass", string(user.KeyPass()))
// Clear the user's auth information. // Clear the user's auth information.
require.NoError(t, s.ClearUser("userID")) require.NoError(t, user.Clear())
// Check the user's cleared auth information. // Check the user's cleared auth information.
require.Empty(t, user.AuthUID()) require.Empty(t, user.AuthUID())

View File

@ -107,14 +107,6 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
return vault.GetUser(userID) return vault.GetUser(userID)
} }
func (vault *Vault) ClearUser(userID string) error {
return vault.modUser(userID, func(data *UserData) {
data.AuthUID = ""
data.AuthRef = ""
data.KeyPass = nil
})
}
// DeleteUser removes the given user from the vault. // DeleteUser removes the given user from the vault.
func (vault *Vault) DeleteUser(userID string) error { func (vault *Vault) DeleteUser(userID string) error {
return vault.mod(func(data *Data) { return vault.mod(func(data *Data) {

View File

@ -12,8 +12,8 @@ type API interface {
GetHostURL() string GetHostURL() string
AddCallWatcher(func(server.Call), ...string) AddCallWatcher(func(server.Call), ...string)
CreateUser(username, password, address string) (string, string, error) CreateUser(username, address string, password []byte) (string, string, error)
CreateAddress(userID, address, password string) (string, error) CreateAddress(userID, address string, password []byte) (string, error)
RemoveAddress(userID, addrID string) error RemoveAddress(userID, addrID string) error
RevokeUser(userID string) error RevokeUser(userID string) error

View File

@ -2,17 +2,19 @@ package tests
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"gitlab.protontech.ch/go/liteapi"
) )
func (t *testCtx) startBridge() error { func (t *testCtx) startBridge() error {
// Bridge will enable the proxy by default at startup. // Bridge will enable the proxy by default at startup.
t.mocks.ProxyDialer.EXPECT().AllowProxy() t.mocks.ProxyCtl.EXPECT().AllowProxy()
// Get the path to the vault. // Get the path to the vault.
vaultDir, err := t.locator.ProvideSettingsPath() vaultDir, err := t.locator.ProvideSettingsPath()
@ -41,7 +43,8 @@ func (t *testCtx) startBridge() error {
vault, vault,
useragent.New(), useragent.New(),
t.mocks.TLSReporter, t.mocks.TLSReporter,
t.mocks.ProxyDialer, liteapi.NewDialer(t.netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
t.mocks.ProxyCtl,
t.mocks.Autostarter, t.mocks.Autostarter,
t.mocks.Updater, t.mocks.Updater,
t.version, t.version,

View File

@ -24,7 +24,7 @@ type testCtx struct {
// These are the objects supporting the test. // These are the objects supporting the test.
dir string dir string
api API api API
dialer *bridge.TestDialer netCtl *liteapi.NetCtl
locator *locations.Locations locator *locations.Locations
storeKey []byte storeKey []byte
version *semver.Version version *semver.Version
@ -76,15 +76,13 @@ type smtpClient struct {
func newTestCtx(tb testing.TB) *testCtx { func newTestCtx(tb testing.TB) *testCtx {
dir := tb.TempDir() dir := tb.TempDir()
dialer := bridge.NewTestDialer()
ctx := &testCtx{ ctx := &testCtx{
dir: dir, dir: dir,
api: newFakeAPI(), api: newFakeAPI(),
dialer: dialer, netCtl: liteapi.NewNetCtl(),
locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"), locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"),
storeKey: []byte("super-secret-store-key"), storeKey: []byte("super-secret-store-key"),
mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion), mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion),
version: defaultVersion, version: defaultVersion,
userIDByName: make(map[string]string), userIDByName: make(map[string]string),

View File

@ -38,12 +38,12 @@ func (s *scenario) itFailsWithError(wantErr string) error {
} }
func (s *scenario) internetIsTurnedOff() error { func (s *scenario) internetIsTurnedOff() error {
s.t.dialer.SetCanDial(false) s.t.netCtl.SetCanDial(false)
return nil return nil
} }
func (s *scenario) internetIsTurnedOn() error { func (s *scenario) internetIsTurnedOn() error {
s.t.dialer.SetCanDial(true) s.t.netCtl.SetCanDial(true)
return nil return nil
} }

View File

@ -14,7 +14,7 @@ import (
func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error { func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error {
// Create the user. // Create the user.
userID, addrID, err := s.t.api.CreateUser(username, password, username) userID, addrID, err := s.t.api.CreateUser(username, username, []byte(password))
if err != nil { if err != nil {
return err return err
} }
@ -34,7 +34,7 @@ func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, passwor
func (s *scenario) theAccountHasAdditionalAddress(username, address string) error { func (s *scenario) theAccountHasAdditionalAddress(username, address string) error {
userID := s.t.getUserID(username) userID := s.t.getUserID(username)
addrID, err := s.t.api.CreateAddress(userID, address, s.t.getUserPass(userID)) addrID, err := s.t.api.CreateAddress(userID, address, []byte(s.t.getUserPass(userID)))
if err != nil { if err != nil {
return err return err
} }