mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 12:46:46 +00:00
GODT-1657: More stable sync, with some tests
This commit is contained in:
2
Makefile
2
Makefile
@ -234,7 +234,7 @@ integration-test-bridge:
|
|||||||
${MAKE} -C test test-bridge
|
${MAKE} -C test test-bridge
|
||||||
|
|
||||||
mocks:
|
mocks:
|
||||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyDialer,Autostarter > internal/bridge/mocks/mocks.go
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyController,Autostarter > internal/bridge/mocks/mocks.go
|
||||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
|
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
|
||||||
|
|
||||||
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog
|
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog
|
||||||
|
|||||||
2
go.mod
2
go.mod
@ -38,7 +38,7 @@ require (
|
|||||||
github.com/sirupsen/logrus v1.9.0
|
github.com/sirupsen/logrus v1.9.0
|
||||||
github.com/stretchr/testify v1.8.0
|
github.com/stretchr/testify v1.8.0
|
||||||
github.com/urfave/cli/v2 v2.16.3
|
github.com/urfave/cli/v2 v2.16.3
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e
|
||||||
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
||||||
golang.org/x/net v0.1.0
|
golang.org/x/net v0.1.0
|
||||||
golang.org/x/sys v0.1.0
|
golang.org/x/sys v0.1.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@ -397,8 +397,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
|
|||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
||||||
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7 h1:Hef7jPRzcfLOvOUHYoQ6efaI7p7/aT5kpZDqJ29owNI=
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e h1:CTGaREzkbz7u98nKt6+xsca2bWML79lH1XGbodRo+MY=
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk=
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk=
|
||||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||||
|
|||||||
@ -81,7 +81,18 @@ func newBridge(locations *locations.Locations, identifier *useragent.UserAgent)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new bridge.
|
// Create a new bridge.
|
||||||
bridge, err := bridge.New(constants.APIHost, locations, encVault, identifier, pinningDialer, proxyDialer, autostarter, updater, version)
|
bridge, err := bridge.New(
|
||||||
|
constants.APIHost,
|
||||||
|
locations,
|
||||||
|
encVault,
|
||||||
|
identifier,
|
||||||
|
pinningDialer,
|
||||||
|
dialer.CreateTransportWithDialer(proxyDialer),
|
||||||
|
proxyDialer,
|
||||||
|
autostarter,
|
||||||
|
updater,
|
||||||
|
version,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not create bridge: %w", err)
|
return nil, fmt.Errorf("could not create bridge: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,7 +35,7 @@ type Bridge struct {
|
|||||||
// api manages user API clients.
|
// api manages user API clients.
|
||||||
api *liteapi.Manager
|
api *liteapi.Manager
|
||||||
cookieJar *cookies.Jar
|
cookieJar *cookies.Jar
|
||||||
proxyDialer ProxyDialer
|
proxyCtl ProxyController
|
||||||
identifier Identifier
|
identifier Identifier
|
||||||
|
|
||||||
// watchers holds all registered event watchers.
|
// watchers holds all registered event watchers.
|
||||||
@ -81,15 +81,16 @@ func New(
|
|||||||
vault *vault.Vault, // the bridge's encrypted data store
|
vault *vault.Vault, // the bridge's encrypted data store
|
||||||
identifier Identifier, // the identifier to keep track of the user agent
|
identifier Identifier, // the identifier to keep track of the user agent
|
||||||
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
|
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
|
||||||
proxyDialer ProxyDialer, // the DoH dialer
|
roundTripper http.RoundTripper, // the round tripper to use for API requests
|
||||||
|
proxyCtl ProxyController, // the DoH controller
|
||||||
autostarter Autostarter, // the autostarter to manage autostart settings
|
autostarter Autostarter, // the autostarter to manage autostart settings
|
||||||
updater Updater, // the updater to fetch and install updates
|
updater Updater, // the updater to fetch and install updates
|
||||||
curVersion *semver.Version, // the current version of the bridge
|
curVersion *semver.Version, // the current version of the bridge
|
||||||
) (*Bridge, error) {
|
) (*Bridge, error) {
|
||||||
if vault.GetProxyAllowed() {
|
if vault.GetProxyAllowed() {
|
||||||
proxyDialer.AllowProxy()
|
proxyCtl.AllowProxy()
|
||||||
} else {
|
} else {
|
||||||
proxyDialer.DisallowProxy()
|
proxyCtl.DisallowProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieJar, err := cookies.NewCookieJar(vault)
|
cookieJar, err := cookies.NewCookieJar(vault)
|
||||||
@ -101,7 +102,7 @@ func New(
|
|||||||
liteapi.WithHostURL(apiURL),
|
liteapi.WithHostURL(apiURL),
|
||||||
liteapi.WithAppVersion(constants.AppVersion),
|
liteapi.WithAppVersion(constants.AppVersion),
|
||||||
liteapi.WithCookieJar(cookieJar),
|
liteapi.WithCookieJar(cookieJar),
|
||||||
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}),
|
liteapi.WithTransport(roundTripper),
|
||||||
)
|
)
|
||||||
|
|
||||||
tlsConfig, err := loadTLSConfig(vault)
|
tlsConfig, err := loadTLSConfig(vault)
|
||||||
@ -141,7 +142,7 @@ func New(
|
|||||||
|
|
||||||
api: api,
|
api: api,
|
||||||
cookieJar: cookieJar,
|
cookieJar: cookieJar,
|
||||||
proxyDialer: proxyDialer,
|
proxyCtl: proxyCtl,
|
||||||
identifier: identifier,
|
identifier: identifier,
|
||||||
|
|
||||||
tlsConfig: tlsConfig,
|
tlsConfig: tlsConfig,
|
||||||
@ -179,6 +180,10 @@ func New(
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err := bridge.loadUsers(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load users: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for range tlsReporter.GetTLSIssueCh() {
|
for range tlsReporter.GetTLSIssueCh() {
|
||||||
bridge.publish(events.TLSIssue{})
|
bridge.publish(events.TLSIssue{})
|
||||||
@ -197,10 +202,6 @@ func New(
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := bridge.loadUsers(context.Background()); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load connected users: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.serveIMAP(); err != nil {
|
if err := bridge.serveIMAP(); err != nil {
|
||||||
bridge.PushError(ErrServeIMAP)
|
bridge.PushError(ErrServeIMAP)
|
||||||
}
|
}
|
||||||
@ -309,14 +310,8 @@ func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
|
|||||||
func (bridge *Bridge) onStatusUp() {
|
func (bridge *Bridge) onStatusUp() {
|
||||||
bridge.publish(events.ConnStatusUp{})
|
bridge.publish(events.ConnStatusUp{})
|
||||||
|
|
||||||
for _, userID := range bridge.vault.GetUserIDs() {
|
if err := bridge.loadUsers(); err != nil {
|
||||||
if _, ok := bridge.users[userID]; !ok {
|
logrus.WithError(err).Error("Failed to load users")
|
||||||
if vaultUser, err := bridge.vault.GetUser(userID); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to get user from vault")
|
|
||||||
} else if err := bridge.loadUser(context.Background(), vaultUser); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to load user")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package bridge_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
@ -21,6 +22,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v2/tests"
|
"github.com/ProtonMail/proton-bridge/v2/tests"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
"gitlab.protontech.ch/go/liteapi/server"
|
"gitlab.protontech.ch/go/liteapi/server"
|
||||||
"gitlab.protontech.ch/go/liteapi/server/backend"
|
"gitlab.protontech.ch/go/liteapi/server/backend"
|
||||||
)
|
)
|
||||||
@ -41,14 +43,14 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_ConnStatus(t *testing.T) {
|
func TestBridge_ConnStatus(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Get a stream of connection status events.
|
// Get a stream of connection status events.
|
||||||
eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{})
|
eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{})
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
// Simulate network disconnect.
|
// Simulate network disconnect.
|
||||||
dialer.SetCanDial(false)
|
netCtl.Disable()
|
||||||
|
|
||||||
// Trigger some operation that will fail due to the network disconnect.
|
// Trigger some operation that will fail due to the network disconnect.
|
||||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||||
@ -58,7 +60,7 @@ func TestBridge_ConnStatus(t *testing.T) {
|
|||||||
require.Equal(t, events.ConnStatusDown{}, <-eventCh)
|
require.Equal(t, events.ConnStatusDown{}, <-eventCh)
|
||||||
|
|
||||||
// Simulate network reconnect.
|
// Simulate network reconnect.
|
||||||
dialer.SetCanDial(true)
|
netCtl.Enable()
|
||||||
|
|
||||||
// Trigger some operation that will succeed due to the network reconnect.
|
// Trigger some operation that will succeed due to the network reconnect.
|
||||||
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||||
@ -72,8 +74,8 @@ func TestBridge_ConnStatus(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_TLSIssue(t *testing.T) {
|
func TestBridge_TLSIssue(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Get a stream of TLS issue events.
|
// Get a stream of TLS issue events.
|
||||||
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
|
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
|
||||||
defer done()
|
defer done()
|
||||||
@ -90,8 +92,8 @@ func TestBridge_TLSIssue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Focus(t *testing.T) {
|
func TestBridge_Focus(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Get a stream of TLS issue events.
|
// Get a stream of TLS issue events.
|
||||||
raiseCh, done := bridge.GetEvents(events.Raise{})
|
raiseCh, done := bridge.GetEvents(events.Raise{})
|
||||||
defer done()
|
defer done()
|
||||||
@ -106,14 +108,14 @@ func TestBridge_Focus(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_UserAgent(t *testing.T) {
|
func TestBridge_UserAgent(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
var calls []server.Call
|
var calls []server.Call
|
||||||
|
|
||||||
s.AddCallWatcher(func(call server.Call) {
|
s.AddCallWatcher(func(call server.Call) {
|
||||||
calls = append(calls, call)
|
calls = append(calls, call)
|
||||||
})
|
})
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Set the platform to something other than the default.
|
// Set the platform to something other than the default.
|
||||||
bridge.SetCurrentPlatform("platform")
|
bridge.SetCurrentPlatform("platform")
|
||||||
|
|
||||||
@ -131,7 +133,7 @@ func TestBridge_UserAgent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Cookies(t *testing.T) {
|
func TestBridge_Cookies(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
var calls []server.Call
|
var calls []server.Call
|
||||||
|
|
||||||
s.AddCallWatcher(func(call server.Call) {
|
s.AddCallWatcher(func(call server.Call) {
|
||||||
@ -141,7 +143,7 @@ func TestBridge_Cookies(t *testing.T) {
|
|||||||
var sessionID string
|
var sessionID string
|
||||||
|
|
||||||
// Start bridge and add a user so that API assigns us a session ID via cookie.
|
// Start bridge and add a user so that API assigns us a session ID via cookie.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -152,7 +154,7 @@ func TestBridge_Cookies(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Start bridge again and check that it uses the same session ID.
|
// Start bridge again and check that it uses the same session ID.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id")
|
cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -162,8 +164,8 @@ func TestBridge_Cookies(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_CheckUpdate(t *testing.T) {
|
func TestBridge_CheckUpdate(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Disable autoupdate for this test.
|
// Disable autoupdate for this test.
|
||||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||||
|
|
||||||
@ -201,8 +203,8 @@ func TestBridge_CheckUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_AutoUpdate(t *testing.T) {
|
func TestBridge_AutoUpdate(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Enable autoupdate for this test.
|
// Enable autoupdate for this test.
|
||||||
require.NoError(t, bridge.SetAutoUpdate(true))
|
require.NoError(t, bridge.SetAutoUpdate(true))
|
||||||
|
|
||||||
@ -229,8 +231,8 @@ func TestBridge_AutoUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_ManualUpdate(t *testing.T) {
|
func TestBridge_ManualUpdate(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Disable autoupdate for this test.
|
// Disable autoupdate for this test.
|
||||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||||
|
|
||||||
@ -258,8 +260,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_ForceUpdate(t *testing.T) {
|
func TestBridge_ForceUpdate(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Get a stream of update events.
|
// Get a stream of update events.
|
||||||
updateCh, done := bridge.GetEvents(events.UpdateForced{})
|
updateCh, done := bridge.GetEvents(events.UpdateForced{})
|
||||||
defer done()
|
defer done()
|
||||||
@ -278,11 +280,11 @@ func TestBridge_ForceUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_BadVaultKey(t *testing.T) {
|
func TestBridge_BadVaultKey(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
var userID string
|
var userID string
|
||||||
|
|
||||||
// Login a user.
|
// Login a user.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -290,27 +292,27 @@ func TestBridge_BadVaultKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Start bridge with the correct vault key -- it should load the users correctly.
|
// Start bridge with the correct vault key -- it should load the users correctly.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
|
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
|
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
require.Empty(t, bridge.GetUserIDs())
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
|
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
require.Empty(t, bridge.GetUserIDs())
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_MissingGluonDir(t *testing.T) {
|
func TestBridge_MissingGluonDir(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
var gluonDir string
|
var gluonDir string
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -325,20 +327,20 @@ func TestBridge_MissingGluonDir(t *testing.T) {
|
|||||||
require.NoError(t, os.RemoveAll(gluonDir))
|
require.NoError(t, os.RemoveAll(gluonDir))
|
||||||
|
|
||||||
// Bridge starts but can't find the gluon dir; there should be no error.
|
// Bridge starts but can't find the gluon dir; there should be no error.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// ...
|
// ...
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// withEnv creates the full test environment and runs the tests.
|
// withTLSEnv creates the full test environment and runs the tests.
|
||||||
func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) {
|
func withTLSEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) {
|
||||||
// Create test API.
|
// Create test API.
|
||||||
server := server.NewTLS()
|
server := server.NewTLS()
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
// Add test user.
|
// Add test user.
|
||||||
_, _, err := server.CreateUser(username, string(password), username+"@pm.me")
|
_, _, err := server.CreateUser(username, username+"@pm.me", password)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Generate a random vault key.
|
// Generate a random vault key.
|
||||||
@ -349,23 +351,56 @@ func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// Create a net controller so we can simulate network connectivity issues.
|
||||||
|
netCtl := liteapi.NewNetCtl()
|
||||||
|
|
||||||
|
// Create a locations object to provide temporary locations for bridge data during the test.
|
||||||
|
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
|
||||||
|
|
||||||
// Run the tests.
|
// Run the tests.
|
||||||
tests(
|
tests(ctx, server, netCtl, locations, vaultKey)
|
||||||
ctx,
|
}
|
||||||
server,
|
|
||||||
bridge.NewTestDialer(),
|
// withEnv creates the full test environment and runs the tests.
|
||||||
locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name"),
|
func withEnv(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) {
|
||||||
vaultKey,
|
// Add test user.
|
||||||
)
|
_, _, err := server.CreateUser(username, username+"@pm.me", password)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate a random vault key.
|
||||||
|
vaultKey, err := crypto.RandomToken(32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a context used for the test.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create a net controller so we can simulate network connectivity issues.
|
||||||
|
netCtl := liteapi.NewNetCtl()
|
||||||
|
|
||||||
|
// Create a locations object to provide temporary locations for bridge data during the test.
|
||||||
|
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
|
||||||
|
|
||||||
|
// Run the tests.
|
||||||
|
tests(ctx, netCtl, locations, vaultKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
|
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
|
||||||
func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
|
func withBridge(
|
||||||
|
t *testing.T,
|
||||||
|
ctx context.Context,
|
||||||
|
apiURL string,
|
||||||
|
netCtl *liteapi.NetCtl,
|
||||||
|
locator bridge.Locator,
|
||||||
|
vaultKey []byte,
|
||||||
|
tests func(*bridge.Bridge, *bridge.Mocks),
|
||||||
|
) {
|
||||||
// Create the mock objects used in the tests.
|
// Create the mock objects used in the tests.
|
||||||
mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0)
|
mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
|
||||||
|
defer mocks.Close()
|
||||||
|
|
||||||
// Bridge will enable the proxy by default at startup.
|
// Bridge will enable the proxy by default at startup.
|
||||||
mocks.ProxyDialer.EXPECT().AllowProxy()
|
mocks.ProxyCtl.EXPECT().AllowProxy()
|
||||||
|
|
||||||
// Get the path to the vault.
|
// Get the path to the vault.
|
||||||
vaultDir, err := locator.ProvideSettingsPath()
|
vaultDir, err := locator.ProvideSettingsPath()
|
||||||
@ -380,7 +415,18 @@ func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge
|
|||||||
require.NoError(t, vault.SetSMTPPort(0))
|
require.NoError(t, vault.SetSMTPPort(0))
|
||||||
|
|
||||||
// Create a new bridge.
|
// Create a new bridge.
|
||||||
bridge, err := bridge.New(apiURL, locator, vault, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0)
|
bridge, err := bridge.New(
|
||||||
|
apiURL,
|
||||||
|
locator,
|
||||||
|
vault,
|
||||||
|
useragent.New(),
|
||||||
|
mocks.TLSReporter,
|
||||||
|
liteapi.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
|
||||||
|
mocks.ProxyCtl,
|
||||||
|
mocks.Autostarter,
|
||||||
|
mocks.Updater,
|
||||||
|
v2_3_0,
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Close the bridge when done.
|
// Close the bridge when done.
|
||||||
|
|||||||
@ -6,7 +6,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/Masterminds/semver/v3"
|
"github.com/Masterminds/semver/v3"
|
||||||
"github.com/ProtonMail/gluon"
|
"github.com/ProtonMail/gluon"
|
||||||
@ -33,6 +35,22 @@ func (bridge *Bridge) serveIMAP() error {
|
|||||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(imapListener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get IMAP listener address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
portInt, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to convert IMAP listener port to int: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInt != bridge.vault.GetIMAPPort() {
|
||||||
|
if err := bridge.vault.SetIMAPPort(portInt); err != nil {
|
||||||
|
return fmt.Errorf("failed to update IMAP port in vault: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for err := range bridge.imapServer.GetErrorCh() {
|
for err := range bridge.imapServer.GetErrorCh() {
|
||||||
logrus.WithError(err).Error("IMAP server error")
|
logrus.WithError(err).Error("IMAP server error")
|
||||||
|
|||||||
@ -1,10 +1,6 @@
|
|||||||
package bridge
|
package bridge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -15,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Mocks struct {
|
type Mocks struct {
|
||||||
ProxyDialer *mocks.MockProxyDialer
|
ProxyCtl *mocks.MockProxyController
|
||||||
TLSReporter *mocks.MockTLSReporter
|
TLSReporter *mocks.MockTLSReporter
|
||||||
TLSIssueCh chan struct{}
|
TLSIssueCh chan struct{}
|
||||||
|
|
||||||
@ -23,11 +19,11 @@ type Mocks struct {
|
|||||||
Autostarter *mocks.MockAutostarter
|
Autostarter *mocks.MockAutostarter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Version) *Mocks {
|
func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
|
||||||
ctl := gomock.NewController(tb)
|
ctl := gomock.NewController(tb)
|
||||||
|
|
||||||
mocks := &Mocks{
|
mocks := &Mocks{
|
||||||
ProxyDialer: mocks.NewMockProxyDialer(ctl),
|
ProxyCtl: mocks.NewMockProxyController(ctl),
|
||||||
TLSReporter: mocks.NewMockTLSReporter(ctl),
|
TLSReporter: mocks.NewMockTLSReporter(ctl),
|
||||||
TLSIssueCh: make(chan struct{}),
|
TLSIssueCh: make(chan struct{}),
|
||||||
|
|
||||||
@ -35,41 +31,14 @@ func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Versio
|
|||||||
Autostarter: mocks.NewMockAutostarter(ctl),
|
Autostarter: mocks.NewMockAutostarter(ctl),
|
||||||
}
|
}
|
||||||
|
|
||||||
// When using the proxy dialer, we want to use the test dialer.
|
|
||||||
mocks.ProxyDialer.EXPECT().DialTLSContext(
|
|
||||||
gomock.Any(),
|
|
||||||
gomock.Any(),
|
|
||||||
gomock.Any(),
|
|
||||||
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return dialer.DialTLSContext(ctx, network, address)
|
|
||||||
}).AnyTimes()
|
|
||||||
|
|
||||||
// When getting the TLS issue channel, we want to return the test channel.
|
// When getting the TLS issue channel, we want to return the test channel.
|
||||||
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
|
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
|
||||||
|
|
||||||
return mocks
|
return mocks
|
||||||
}
|
}
|
||||||
|
|
||||||
type TestDialer struct {
|
func (mocks *Mocks) Close() {
|
||||||
canDial bool
|
close(mocks.TLSIssueCh)
|
||||||
}
|
|
||||||
|
|
||||||
func NewTestDialer() *TestDialer {
|
|
||||||
return &TestDialer{
|
|
||||||
canDial: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
|
||||||
if !d.canDial {
|
|
||||||
return nil, errors.New("cannot dial")
|
|
||||||
}
|
|
||||||
|
|
||||||
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *TestDialer) SetCanDial(canDial bool) {
|
|
||||||
d.canDial = canDial
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TestLocationsProvider struct {
|
type TestLocationsProvider struct {
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter)
|
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyController,Autostarter)
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
// Package mocks is a generated GoMock package.
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
|
||||||
net "net"
|
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@ -49,66 +47,51 @@ func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockProxyDialer is a mock of ProxyDialer interface.
|
// MockProxyController is a mock of ProxyController interface.
|
||||||
type MockProxyDialer struct {
|
type MockProxyController struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockProxyDialerMockRecorder
|
recorder *MockProxyControllerMockRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer.
|
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
|
||||||
type MockProxyDialerMockRecorder struct {
|
type MockProxyControllerMockRecorder struct {
|
||||||
mock *MockProxyDialer
|
mock *MockProxyController
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMockProxyDialer creates a new mock instance.
|
// NewMockProxyController creates a new mock instance.
|
||||||
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer {
|
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
|
||||||
mock := &MockProxyDialer{ctrl: ctrl}
|
mock := &MockProxyController{ctrl: ctrl}
|
||||||
mock.recorder = &MockProxyDialerMockRecorder{mock}
|
mock.recorder = &MockProxyControllerMockRecorder{mock}
|
||||||
return mock
|
return mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder {
|
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
|
||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowProxy mocks base method.
|
// AllowProxy mocks base method.
|
||||||
func (m *MockProxyDialer) AllowProxy() {
|
func (m *MockProxyController) AllowProxy() {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "AllowProxy")
|
m.ctrl.Call(m, "AllowProxy")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowProxy indicates an expected call of AllowProxy.
|
// AllowProxy indicates an expected call of AllowProxy.
|
||||||
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call {
|
func (mr *MockProxyControllerMockRecorder) AllowProxy() *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyController)(nil).AllowProxy))
|
||||||
}
|
|
||||||
|
|
||||||
// DialTLSContext mocks base method.
|
|
||||||
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
|
|
||||||
ret0, _ := ret[0].(net.Conn)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialTLSContext indicates an expected call of DialTLSContext.
|
|
||||||
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisallowProxy mocks base method.
|
// DisallowProxy mocks base method.
|
||||||
func (m *MockProxyDialer) DisallowProxy() {
|
func (m *MockProxyController) DisallowProxy() {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "DisallowProxy")
|
m.ctrl.Call(m, "DisallowProxy")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisallowProxy indicates an expected call of DisallowProxy.
|
// DisallowProxy indicates an expected call of DisallowProxy.
|
||||||
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call {
|
func (mr *MockProxyControllerMockRecorder) DisallowProxy() *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyController)(nil).DisallowProxy))
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockAutostarter is a mock of Autostarter interface.
|
// MockAutostarter is a mock of Autostarter interface.
|
||||||
|
|||||||
@ -138,9 +138,9 @@ func (bridge *Bridge) GetProxyAllowed() bool {
|
|||||||
|
|
||||||
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
|
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
|
||||||
if allowed {
|
if allowed {
|
||||||
bridge.proxyDialer.AllowProxy()
|
bridge.proxyCtl.AllowProxy()
|
||||||
} else {
|
} else {
|
||||||
bridge.proxyDialer.DisallowProxy()
|
bridge.proxyCtl.DisallowProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
return bridge.vault.SetProxyAllowed(allowed)
|
return bridge.vault.SetProxyAllowed(allowed)
|
||||||
|
|||||||
@ -7,12 +7,13 @@ import (
|
|||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
"gitlab.protontech.ch/go/liteapi/server"
|
"gitlab.protontech.ch/go/liteapi/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBridge_Settings_GluonDir(t *testing.T) {
|
func TestBridge_Settings_GluonDir(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Create a user.
|
// Create a user.
|
||||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -34,8 +35,8 @@ func TestBridge_Settings_GluonDir(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_IMAPPort(t *testing.T) {
|
func TestBridge_Settings_IMAPPort(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
curPort := bridge.GetIMAPPort()
|
curPort := bridge.GetIMAPPort()
|
||||||
|
|
||||||
// Set the port to 1144.
|
// Set the port to 1144.
|
||||||
@ -51,8 +52,8 @@ func TestBridge_Settings_IMAPPort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// By default, IMAP SSL is disabled.
|
// By default, IMAP SSL is disabled.
|
||||||
require.False(t, bridge.GetIMAPSSL())
|
require.False(t, bridge.GetIMAPSSL())
|
||||||
|
|
||||||
@ -66,8 +67,8 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_SMTPPort(t *testing.T) {
|
func TestBridge_Settings_SMTPPort(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
curPort := bridge.GetSMTPPort()
|
curPort := bridge.GetSMTPPort()
|
||||||
|
|
||||||
// Set the port to 1024.
|
// Set the port to 1024.
|
||||||
@ -84,8 +85,8 @@ func TestBridge_Settings_SMTPPort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// By default, SMTP SSL is disabled.
|
// By default, SMTP SSL is disabled.
|
||||||
require.False(t, bridge.GetSMTPSSL())
|
require.False(t, bridge.GetSMTPSSL())
|
||||||
|
|
||||||
@ -99,13 +100,13 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_Proxy(t *testing.T) {
|
func TestBridge_Settings_Proxy(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// By default, proxy is allowed.
|
// By default, proxy is allowed.
|
||||||
require.True(t, bridge.GetProxyAllowed())
|
require.True(t, bridge.GetProxyAllowed())
|
||||||
|
|
||||||
// Disallow proxy.
|
// Disallow proxy.
|
||||||
mocks.ProxyDialer.EXPECT().DisallowProxy()
|
mocks.ProxyCtl.EXPECT().DisallowProxy()
|
||||||
require.NoError(t, bridge.SetProxyAllowed(false))
|
require.NoError(t, bridge.SetProxyAllowed(false))
|
||||||
|
|
||||||
// Get the new setting.
|
// Get the new setting.
|
||||||
@ -115,8 +116,8 @@ func TestBridge_Settings_Proxy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_Autostart(t *testing.T) {
|
func TestBridge_Settings_Autostart(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// By default, autostart is disabled.
|
// By default, autostart is disabled.
|
||||||
require.False(t, bridge.GetAutostart())
|
require.False(t, bridge.GetAutostart())
|
||||||
|
|
||||||
@ -131,8 +132,8 @@ func TestBridge_Settings_Autostart(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_FirstStart(t *testing.T) {
|
func TestBridge_Settings_FirstStart(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// By default, first start is true.
|
// By default, first start is true.
|
||||||
require.True(t, bridge.GetFirstStart())
|
require.True(t, bridge.GetFirstStart())
|
||||||
|
|
||||||
@ -146,8 +147,8 @@ func TestBridge_Settings_FirstStart(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
|
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// By default, first start is true.
|
// By default, first start is true.
|
||||||
require.True(t, bridge.GetFirstStartGUI())
|
require.True(t, bridge.GetFirstStartGUI())
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,8 @@ package bridge
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
"github.com/emersion/go-sasl"
|
"github.com/emersion/go-sasl"
|
||||||
@ -22,6 +24,22 @@ func (bridge *Bridge) serveSMTP() error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(smtpListener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get SMTP listener address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
portInt, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to convert SMTP listener port to int: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInt != bridge.vault.GetSMTPPort() {
|
||||||
|
if err := bridge.vault.SetSMTPPort(portInt); err != nil {
|
||||||
|
return fmt.Errorf("failed to update SMTP port in vault: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
130
internal/bridge/sync_test.go
Normal file
130
internal/bridge/sync_test.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
package bridge_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
|
"github.com/emersion/go-imap/client"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
|
"gitlab.protontech.ch/go/liteapi/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBridge_Sync(t *testing.T) {
|
||||||
|
s := server.New()
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
numMsg := 1 << 10
|
||||||
|
|
||||||
|
withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < numMsg; i++ {
|
||||||
|
messageID, err := s.CreateMessage(userID, addrID, literal, liteapi.MessageFlagReceived, false, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, s.LabelMessage(userID, messageID, labelID))
|
||||||
|
}
|
||||||
|
|
||||||
|
var read uint64
|
||||||
|
|
||||||
|
netCtl.OnRead(func(b []byte) {
|
||||||
|
read += uint64(len(b))
|
||||||
|
})
|
||||||
|
|
||||||
|
// The initial user should be fully synced.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
syncCh, done := bridge.GetEvents(events.SyncFinished{})
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// If we then connect an IMAP client, it should see all the messages.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
info, err := bridge.GetUserInfo(userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, info.Connected)
|
||||||
|
|
||||||
|
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||||
|
defer client.Logout()
|
||||||
|
|
||||||
|
status, err := client.Select(`Folders/folder`, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, uint32(numMsg), status.Messages)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Now let's remove the user and simulate a network error.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Pretend we can only sync 2/3 of the original messages.
|
||||||
|
netCtl.SetReadLimit(2 * read / 3)
|
||||||
|
|
||||||
|
// Login the user; its sync should fail.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
syncCh, done := bridge.GetEvents(events.SyncFailed{})
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, userID, (<-syncCh).(events.SyncFailed).UserID)
|
||||||
|
|
||||||
|
info, err := bridge.GetUserInfo(userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, info.Connected)
|
||||||
|
|
||||||
|
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||||
|
defer client.Logout()
|
||||||
|
|
||||||
|
status, err := client.Select(`Folders/folder`, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Less(t, status.Messages, uint32(numMsg))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Remove the network limit, allowing the sync to finish.
|
||||||
|
netCtl.SetReadLimit(0)
|
||||||
|
|
||||||
|
// Login the user; its sync should now finish.
|
||||||
|
// If we then connect an IMAP client, it should eventually see all the messages.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
syncCh, done := bridge.GetEvents(events.SyncFinished{})
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
|
||||||
|
|
||||||
|
info, err := bridge.GetUserInfo(userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, info.Connected)
|
||||||
|
|
||||||
|
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||||
|
defer client.Logout()
|
||||||
|
|
||||||
|
status, err := client.Select(`Folders/folder`, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, uint32(numMsg), status.Messages)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
6
internal/bridge/testdata/text-plain.eml
vendored
Normal file
6
internal/bridge/testdata/text-plain.eml
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
To: recipient@pm.me
|
||||||
|
From: sender@pm.me
|
||||||
|
Subject: Test
|
||||||
|
Content-Type: text/plain; charset=utf-8
|
||||||
|
|
||||||
|
Test
|
||||||
@ -1,9 +1,6 @@
|
|||||||
package bridge
|
package bridge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -21,17 +18,15 @@ type Identifier interface {
|
|||||||
SetPlatform(platform string)
|
SetPlatform(platform string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLSReporter interface {
|
type ProxyController interface {
|
||||||
GetTLSIssueCh() <-chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProxyDialer interface {
|
|
||||||
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
|
|
||||||
|
|
||||||
AllowProxy()
|
AllowProxy()
|
||||||
DisallowProxy()
|
DisallowProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TLSReporter interface {
|
||||||
|
GetTLSIssueCh() <-chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
type Autostarter interface {
|
type Autostarter interface {
|
||||||
Enable() error
|
Enable() error
|
||||||
Disable() error
|
Disable() error
|
||||||
|
|||||||
@ -18,6 +18,9 @@ func (bridge *Bridge) watchForUpdates() error {
|
|||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-bridge.stopCh:
|
||||||
|
return
|
||||||
|
|
||||||
case <-bridge.updateCheckCh:
|
case <-bridge.updateCheckCh:
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
@ -82,9 +83,11 @@ func (bridge *Bridge) LoginUser(
|
|||||||
) (string, error) {
|
) (string, error) {
|
||||||
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
|
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to create new API client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userID, err := try.CatchVal(
|
||||||
|
func() (string, error) {
|
||||||
if _, ok := bridge.users[auth.UserID]; ok {
|
if _, ok := bridge.users[auth.UserID]; ok {
|
||||||
return "", ErrUserAlreadyLoggedIn
|
return "", ErrUserAlreadyLoggedIn
|
||||||
}
|
}
|
||||||
@ -92,66 +95,64 @@ func (bridge *Bridge) LoginUser(
|
|||||||
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
||||||
totp, err := getTOTP()
|
totp, err := getTOTP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to get TOTP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
|
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to authorize 2FA: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var keyPass []byte
|
var keyPass []byte
|
||||||
|
|
||||||
if auth.PasswordMode == liteapi.TwoPasswordMode {
|
if auth.PasswordMode == liteapi.TwoPasswordMode {
|
||||||
pass, err := getKeyPass()
|
userKeyPass, err := getKeyPass()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to get key password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
keyPass = pass
|
keyPass = userKeyPass
|
||||||
} else {
|
} else {
|
||||||
keyPass = password
|
keyPass = password
|
||||||
}
|
}
|
||||||
|
|
||||||
apiUser, err := client.GetUser(ctx)
|
return bridge.loginUser(ctx, client, auth.UID, auth.RefreshToken, keyPass)
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
return client.AuthDelete(ctx)
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
bridge.deleteUser(ctx, auth.UserID)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to login user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
salts, err := client.GetSalts(ctx)
|
bridge.publish(events.UserLoggedIn{
|
||||||
if err != nil {
|
UserID: userID,
|
||||||
return "", err
|
})
|
||||||
}
|
|
||||||
|
|
||||||
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
|
return userID, nil
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return auth.UserID, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogoutUser logs out the given user.
|
// LogoutUser logs out the given user.
|
||||||
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
||||||
return bridge.logoutUser(ctx, userID, true, false)
|
if err := bridge.logoutUser(ctx, userID); err != nil {
|
||||||
|
return fmt.Errorf("failed to logout user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bridge.publish(events.UserLoggedOut{
|
||||||
|
UserID: userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser deletes the given user.
|
// DeleteUser deletes the given user.
|
||||||
// If it is authorized, it is logged out first.
|
|
||||||
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
||||||
if bridge.users[userID] != nil {
|
bridge.deleteUser(ctx, userID)
|
||||||
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
bridge.publish(events.UserDeleted{
|
bridge.publish(events.UserDeleted{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
@ -193,53 +194,91 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadUsers loads authorized users from the vault.
|
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
|
||||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
apiUser, err := client.GetUser(ctx)
|
||||||
for _, userID := range bridge.vault.GetUserIDs() {
|
|
||||||
user, err := bridge.vault.GetUser(userID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", fmt.Errorf("failed to get API user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
salts, err := client.GetSalts(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get key salts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to salt key password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to add bridge user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiUser.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadUsers is a loop that, when polled, attempts to load authorized users from the vault.
|
||||||
|
func (bridge *Bridge) loadUsers() error {
|
||||||
|
return bridge.vault.ForUser(func(user *vault.User) error {
|
||||||
|
if _, ok := bridge.users[user.UserID()]; ok {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.AuthUID() == "" {
|
if user.AuthUID() == "" {
|
||||||
continue
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bridge.loadUser(ctx, user); err != nil {
|
if err := bridge.loadUser(user); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to load connected user")
|
|
||||||
|
|
||||||
if _, ok := err.(*resty.ResponseError); ok {
|
if _, ok := err.(*resty.ResponseError); ok {
|
||||||
if err := bridge.vault.ClearUser(userID); err != nil {
|
logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault")
|
||||||
|
|
||||||
|
if err := user.Clear(); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to clear user")
|
logrus.WithError(err).Error("Failed to clear user")
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
logrus.WithError(err).Error("Failed to load connected user")
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bridge.publish(events.UserLoaded{
|
||||||
|
UserID: user.UserID(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
// loadUser loads an existing user from the vault.
|
||||||
|
func (bridge *Bridge) loadUser(user *vault.User) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
|
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create API client: %w", err)
|
return fmt.Errorf("failed to create API client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := try.Catch(
|
||||||
|
func() error {
|
||||||
apiUser, err := client.GetUser(ctx)
|
apiUser, err := client.GetUser(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user: %w", err)
|
return fmt.Errorf("failed to get user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
|
return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass())
|
||||||
return fmt.Errorf("failed to add user: %w", err)
|
},
|
||||||
|
func() error {
|
||||||
|
return client.AuthDelete(ctx)
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
return bridge.logoutUser(ctx, user.UserID())
|
||||||
|
},
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("failed to load user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.publish(events.UserLoggedIn{
|
|
||||||
UserID: user.UserID(),
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -304,10 +343,6 @@ func (bridge *Bridge) addUser(
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
bridge.publish(events.UserLoggedIn{
|
|
||||||
UserID: user.ID(),
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -363,54 +398,6 @@ func (bridge *Bridge) addExistingUser(
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// logoutUser closes and removes the user with the given ID.
|
|
||||||
// If withAPI is true, the user will additionally be logged out from API.
|
|
||||||
// If withFiles is true, the user's files will be deleted.
|
|
||||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
|
|
||||||
user, ok := bridge.users[userID]
|
|
||||||
if !ok {
|
|
||||||
return ErrNoSuchUser
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
|
||||||
return fmt.Errorf("failed to remove SMTP user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, gluonID := range user.GetGluonIDs() {
|
|
||||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
|
|
||||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if withAPI {
|
|
||||||
if err := user.Logout(ctx); err != nil {
|
|
||||||
return fmt.Errorf("failed to logout user: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.Close(); err != nil {
|
|
||||||
return fmt.Errorf("failed to close user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.vault.ClearUser(userID); err != nil {
|
|
||||||
return fmt.Errorf("failed to clear user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if withFiles {
|
|
||||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete user: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(bridge.users, userID)
|
|
||||||
|
|
||||||
bridge.publish(events.UserLoggedOut{
|
|
||||||
UserID: userID,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addIMAPUser connects the given user to gluon.
|
// addIMAPUser connects the given user to gluon.
|
||||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||||
imapConn, err := user.NewIMAPConnectors()
|
imapConn, err := user.NewIMAPConnectors()
|
||||||
@ -438,6 +425,65 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// logoutUser logs the given user out from bridge.
|
||||||
|
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
||||||
|
user, ok := bridge.users[userID]
|
||||||
|
if !ok {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, gluonID := range user.GetGluonIDs() {
|
||||||
|
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.Logout(ctx); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to logout user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to close user")
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(bridge.users, userID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteUser deletes the given user from bridge.
|
||||||
|
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
||||||
|
if user, ok := bridge.users[userID]; ok {
|
||||||
|
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, gluonID := range user.GetGluonIDs() {
|
||||||
|
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.Logout(ctx); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to logout user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to close user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to delete user from vault")
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(bridge.users, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// getUserInfo returns information about a disconnected user.
|
// getUserInfo returns information about a disconnected user.
|
||||||
func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
|
func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
|
||||||
return UserInfo{
|
return UserInfo{
|
||||||
|
|||||||
@ -27,7 +27,7 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
|
|||||||
}
|
}
|
||||||
|
|
||||||
case events.UserDeauth:
|
case events.UserDeauth:
|
||||||
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
|
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
|
||||||
return fmt.Errorf("failed to logout user: %w", err)
|
return fmt.Errorf("failed to logout user: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,17 +9,18 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
"gitlab.protontech.ch/go/liteapi/server"
|
"gitlab.protontech.ch/go/liteapi/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBridge_WithoutUsers(t *testing.T) {
|
func TestBridge_WithoutUsers(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
require.Empty(t, bridge.GetUserIDs())
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
})
|
})
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
require.Empty(t, bridge.GetUserIDs())
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
})
|
})
|
||||||
@ -27,8 +28,8 @@ func TestBridge_WithoutUsers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_Login(t *testing.T) {
|
func TestBridge_Login(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -41,8 +42,8 @@ func TestBridge_Login(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_LoginLogoutLogin(t *testing.T) {
|
func TestBridge_LoginLogoutLogin(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
@ -69,8 +70,8 @@ func TestBridge_LoginLogoutLogin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_LoginDeleteLogin(t *testing.T) {
|
func TestBridge_LoginDeleteLogin(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
@ -97,8 +98,8 @@ func TestBridge_LoginDeleteLogin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_LoginDeauthLogin(t *testing.T) {
|
func TestBridge_LoginDeauthLogin(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
@ -131,10 +132,10 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
|
|||||||
func TestBridge_LoginExpireLogin(t *testing.T) {
|
func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||||
const authLife = 2 * time.Second
|
const authLife = 2 * time.Second
|
||||||
|
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
s.SetAuthLife(authLife)
|
s.SetAuthLife(authLife)
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user. Its auth will only be valid for a short time.
|
// Login the user. Its auth will only be valid for a short time.
|
||||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
@ -148,11 +149,11 @@ func TestBridge_LoginExpireLogin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_FailToLoad(t *testing.T) {
|
func TestBridge_FailToLoad(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
var userID string
|
var userID string
|
||||||
|
|
||||||
// Login the user.
|
// Login the user.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -160,7 +161,7 @@ func TestBridge_FailToLoad(t *testing.T) {
|
|||||||
require.NoError(t, s.RevokeUser(userID))
|
require.NoError(t, s.RevokeUser(userID))
|
||||||
|
|
||||||
// When bridge starts, the user will not be logged in.
|
// When bridge starts, the user will not be logged in.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
})
|
})
|
||||||
@ -168,25 +169,27 @@ func TestBridge_FailToLoad(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_LoadWithoutInternet(t *testing.T) {
|
func TestBridge_LoadWithoutInternet(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
var userID string
|
var userID string
|
||||||
|
|
||||||
// Login the user.
|
// Login the user.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Simulate loss of internet connection.
|
// Simulate loss of internet connection.
|
||||||
dialer.SetCanDial(false)
|
netCtl.Disable()
|
||||||
|
|
||||||
// Start bridge without internet.
|
// Start bridge without internet.
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Initially, users are not connected.
|
// Initially, users are not connected.
|
||||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
// Simulate internet connection.
|
// Simulate internet connection.
|
||||||
dialer.SetCanDial(true)
|
netCtl.Enable()
|
||||||
|
|
||||||
// The user will eventually be connected.
|
// The user will eventually be connected.
|
||||||
require.Eventually(t, func() bool {
|
require.Eventually(t, func() bool {
|
||||||
@ -197,16 +200,14 @@ func TestBridge_LoadWithoutInternet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_LoginRestart(t *testing.T) {
|
func TestBridge_LoginRestart(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
var userID string
|
var userID string
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
|
||||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// The user is still connected.
|
|
||||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||||
})
|
})
|
||||||
@ -214,10 +215,10 @@ func TestBridge_LoginRestart(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_LoginLogoutRestart(t *testing.T) {
|
func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
var userID string
|
var userID string
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
@ -225,7 +226,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
|
|||||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||||
})
|
})
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// The user is still disconnected.
|
// The user is still disconnected.
|
||||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
@ -234,10 +235,10 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_LoginDeleteRestart(t *testing.T) {
|
func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
var userID string
|
var userID string
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
@ -245,7 +246,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
|
|||||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||||
})
|
})
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// The user is still gone.
|
// The user is still gone.
|
||||||
require.Empty(t, bridge.GetUserIDs())
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
@ -253,13 +254,69 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBridge_FailLoginRecover(t *testing.T) {
|
||||||
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
var read uint64
|
||||||
|
|
||||||
|
netCtl.OnRead(func(b []byte) {
|
||||||
|
read += uint64(len(b))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Log the user in and record how much data was read.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate a partial read.
|
||||||
|
netCtl.SetReadLimit(read / 2)
|
||||||
|
|
||||||
|
// We should fail to log the user in because we can't fully read its data.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
require.Error(t, getErr(bridge.LoginUser(ctx, username, password, nil, nil)))
|
||||||
|
|
||||||
|
// There should be no users.
|
||||||
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBridge_FailLoadRecover(t *testing.T) {
|
||||||
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
var read uint64
|
||||||
|
|
||||||
|
netCtl.OnRead(func(b []byte) {
|
||||||
|
read += uint64(len(b))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start bridge and record how much data was read.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
// ...
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate a partial read.
|
||||||
|
netCtl.SetReadLimit(read / 2)
|
||||||
|
|
||||||
|
// We should fail to load the user; it should be disconnected.
|
||||||
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
userIDs := bridge.GetUserIDs()
|
||||||
|
|
||||||
|
require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestBridge_BridgePass(t *testing.T) {
|
func TestBridge_BridgePass(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
var userID string
|
var userID string
|
||||||
|
|
||||||
var pass []byte
|
var pass []byte
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
@ -276,7 +333,7 @@ func TestBridge_BridgePass(t *testing.T) {
|
|||||||
require.Equal(t, pass, pass)
|
require.Equal(t, pass, pass)
|
||||||
})
|
})
|
||||||
|
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// The bridge should load the user.
|
// The bridge should load the user.
|
||||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||||
@ -288,8 +345,8 @@ func TestBridge_BridgePass(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBridge_AddressMode(t *testing.T) {
|
func TestBridge_AddressMode(t *testing.T) {
|
||||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
// Login the user.
|
// Login the user.
|
||||||
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -313,3 +370,8 @@ func TestBridge_AddressMode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getErr returns the error that was passed to it.
|
||||||
|
func getErr[T any](val T, err error) error {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@ -2,6 +2,12 @@ package events
|
|||||||
|
|
||||||
import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
|
|
||||||
|
type UserLoaded struct {
|
||||||
|
eventBase
|
||||||
|
|
||||||
|
UserID string
|
||||||
|
}
|
||||||
|
|
||||||
type UserLoggedIn struct {
|
type UserLoggedIn struct {
|
||||||
eventBase
|
eventBase
|
||||||
|
|
||||||
|
|||||||
@ -64,4 +64,5 @@ func (service *FocusService) GetRaiseCh() <-chan struct{} {
|
|||||||
// Close closes the service.
|
// Close closes the service.
|
||||||
func (service *FocusService) Close() {
|
func (service *FocusService) Close() {
|
||||||
service.server.Stop()
|
service.server.Stop()
|
||||||
|
close(service.raiseCh)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,41 +0,0 @@
|
|||||||
package pool
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
type job[In, Out any] struct {
|
|
||||||
ctx context.Context
|
|
||||||
req In
|
|
||||||
|
|
||||||
res chan Out
|
|
||||||
err chan error
|
|
||||||
|
|
||||||
done chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] {
|
|
||||||
return &job[In, Out]{
|
|
||||||
ctx: ctx,
|
|
||||||
req: req,
|
|
||||||
res: make(chan Out),
|
|
||||||
err: make(chan error),
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (job *job[In, Out]) result() (Out, error) {
|
|
||||||
return <-job.res, <-job.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (job *job[In, Out]) postSuccess(res Out) {
|
|
||||||
close(job.err)
|
|
||||||
job.res <- res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (job *job[In, Out]) postFailure(err error) {
|
|
||||||
close(job.res)
|
|
||||||
job.err <- err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (job *job[In, Out]) waitDone() {
|
|
||||||
<-job.done
|
|
||||||
}
|
|
||||||
@ -1,147 +0,0 @@
|
|||||||
package pool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/queue"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrJobCancelled indicates the job was cancelled.
|
|
||||||
var ErrJobCancelled = errors.New("job cancelled by surrounding context")
|
|
||||||
|
|
||||||
// Pool is a worker pool that handles input of type In and returns results of type Out.
|
|
||||||
type Pool[In comparable, Out any] struct {
|
|
||||||
queue *queue.QueuedChannel[*job[In, Out]]
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
// doneFunc must be called to free up pool resources.
|
|
||||||
type doneFunc func()
|
|
||||||
|
|
||||||
// New returns a new pool.
|
|
||||||
func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
|
|
||||||
queue := queue.NewQueuedChannel[*job[In, Out]](0, 0)
|
|
||||||
|
|
||||||
for i := 0; i < size; i++ {
|
|
||||||
go func() {
|
|
||||||
for job := range queue.GetChannel() {
|
|
||||||
select {
|
|
||||||
case <-job.ctx.Done():
|
|
||||||
job.postFailure(ErrJobCancelled)
|
|
||||||
|
|
||||||
default:
|
|
||||||
res, err := work(job.ctx, job.req)
|
|
||||||
if err != nil {
|
|
||||||
job.postFailure(err)
|
|
||||||
} else {
|
|
||||||
job.postSuccess(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
job.waitDone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Pool[In, Out]{
|
|
||||||
queue: queue,
|
|
||||||
size: size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred.
|
|
||||||
func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var (
|
|
||||||
wg sync.WaitGroup
|
|
||||||
errList []error
|
|
||||||
lock sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, req := range reqs {
|
|
||||||
req := req
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
job, done := pool.newJob(ctx, req)
|
|
||||||
defer done()
|
|
||||||
|
|
||||||
res, err := job.result()
|
|
||||||
|
|
||||||
if err := fn(req, res, err); err != nil {
|
|
||||||
lock.Lock()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
// Cancel ongoing jobs.
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
// Collect the error.
|
|
||||||
errList = append(errList, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// TODO: Join the errors somehow?
|
|
||||||
if len(errList) > 0 {
|
|
||||||
return errList[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessAll submits jobs to the pool. All results are returned once available.
|
|
||||||
func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) (map[In]Out, error) {
|
|
||||||
var (
|
|
||||||
data = make(map[In]Out)
|
|
||||||
lock = sync.Mutex{}
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := pool.Process(ctx, reqs, func(req In, res Out, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
lock.Lock()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
data[req] = res
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessOne submits one job to the pool and returns the result.
|
|
||||||
func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) {
|
|
||||||
job, done := pool.newJob(ctx, req)
|
|
||||||
defer done()
|
|
||||||
|
|
||||||
return job.result()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pool *Pool[In, Out]) Done() {
|
|
||||||
pool.queue.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// newJob submits a job to the pool. It returns a job handle and a DoneFunc.
|
|
||||||
// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
|
|
||||||
// which frees up the worker in the pool for reuse.
|
|
||||||
func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc) {
|
|
||||||
job := newJob[In, Out](ctx, req)
|
|
||||||
|
|
||||||
pool.queue.Enqueue(job)
|
|
||||||
|
|
||||||
return job, func() { close(job.done) }
|
|
||||||
}
|
|
||||||
@ -1,163 +0,0 @@
|
|||||||
package pool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPool_NewJob(t *testing.T) {
|
|
||||||
doubler := newDoubler(runtime.NumCPU())
|
|
||||||
|
|
||||||
job1, done1 := doubler.newJob(context.Background(), 1)
|
|
||||||
defer done1()
|
|
||||||
|
|
||||||
job2, done2 := doubler.newJob(context.Background(), 2)
|
|
||||||
defer done2()
|
|
||||||
|
|
||||||
res2, err := job2.result()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
res1, err := job1.result()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, 2, res1)
|
|
||||||
assert.Equal(t, 4, res2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPool_NewJob_Done(t *testing.T) {
|
|
||||||
// Create a doubler pool with 2 workers.
|
|
||||||
doubler := newDoubler(2)
|
|
||||||
|
|
||||||
// Start two jobs. Don't mark the jobs as done yet.
|
|
||||||
job1, done1 := doubler.newJob(context.Background(), 1)
|
|
||||||
job2, done2 := doubler.newJob(context.Background(), 2)
|
|
||||||
|
|
||||||
// Get the first result.
|
|
||||||
res1, _ := job1.result()
|
|
||||||
assert.Equal(t, 2, res1)
|
|
||||||
|
|
||||||
// Get the first result.
|
|
||||||
res2, _ := job2.result()
|
|
||||||
assert.Equal(t, 4, res2)
|
|
||||||
|
|
||||||
// Additional jobs will wait.
|
|
||||||
job3, _ := doubler.newJob(context.Background(), 3)
|
|
||||||
job4, _ := doubler.newJob(context.Background(), 4)
|
|
||||||
|
|
||||||
// Channel to collect results from jobs 3 and 4.
|
|
||||||
resCh := make(chan int, 2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
res, _ := job3.result()
|
|
||||||
resCh <- res
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
res, _ := job4.result()
|
|
||||||
resCh <- res
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Mark jobs 1 and 2 as done, freeing up the workers.
|
|
||||||
done1()
|
|
||||||
done2()
|
|
||||||
|
|
||||||
assert.ElementsMatch(t, []int{6, 8}, []int{<-resCh, <-resCh})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPool_Process(t *testing.T) {
|
|
||||||
doubler := newDoubler(runtime.NumCPU())
|
|
||||||
|
|
||||||
var (
|
|
||||||
res = make(map[int]int)
|
|
||||||
lock sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(reqVal, resVal int, err error) error {
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
lock.Lock()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
res[reqVal] = resVal
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
|
|
||||||
assert.Equal(t, map[int]int{
|
|
||||||
1: 2,
|
|
||||||
2: 4,
|
|
||||||
3: 6,
|
|
||||||
4: 8,
|
|
||||||
5: 10,
|
|
||||||
}, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPool_Process_Error(t *testing.T) {
|
|
||||||
doubler := newDoublerWithError(runtime.NumCPU())
|
|
||||||
|
|
||||||
assert.Error(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(_ int, _ int, err error) error {
|
|
||||||
return err
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPool_Process_Parallel(t *testing.T) {
|
|
||||||
doubler := newDoubler(runtime.NumCPU(), 100*time.Millisecond)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
for i := 0; i < 8; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4}, func(_ int, _ int, err error) error {
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPool_ProcessAll(t *testing.T) {
|
|
||||||
doubler := newDoubler(runtime.NumCPU())
|
|
||||||
|
|
||||||
res, err := doubler.ProcessAll(context.Background(), []int{1, 2, 3, 4, 5})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, map[int]int{
|
|
||||||
1: 2,
|
|
||||||
2: 4,
|
|
||||||
3: 6,
|
|
||||||
4: 8,
|
|
||||||
5: 10,
|
|
||||||
}, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] {
|
|
||||||
return New(workers, func(ctx context.Context, req int) (int, error) {
|
|
||||||
if len(delay) > 0 {
|
|
||||||
time.Sleep(delay[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
return 2 * req, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDoublerWithError(workers int) *Pool[int, int] {
|
|
||||||
return New(workers, func(ctx context.Context, req int) (int, error) {
|
|
||||||
if req%2 == 0 {
|
|
||||||
return 0, errors.New("oops")
|
|
||||||
}
|
|
||||||
|
|
||||||
return 2 * req, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -23,7 +23,15 @@ func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool {
|
func (m *Map[Key, Val]) Has(key Key) bool {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
|
_, ok := m.data[key]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
@ -37,7 +45,7 @@ func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) GetErr(key Key, fn func(val Val) error) (bool, error) {
|
func (m *Map[Key, Val]) GetErr(key Key, fn func(Val) error) (bool, error) {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
@ -56,6 +64,15 @@ func (m *Map[Key, Val]) Set(key Key, val Val) {
|
|||||||
m.data[key] = val
|
m.data[key] = val
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
|
for key, val := range m.data {
|
||||||
|
fn(key, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Keys(fn func(keys []Key)) {
|
func (m *Map[Key, Val]) Keys(fn func(keys []Key)) {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
@ -70,28 +87,52 @@ func (m *Map[Key, Val]) Values(fn func(vals []Val)) {
|
|||||||
fn(maps.Values(m.data))
|
fn(maps.Values(m.data))
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) Ret) (Ret, bool) {
|
func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret, fallback func() Ret) Ret {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
val, ok := m.data[key]
|
val, ok := m.data[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return *new(Ret), false
|
return fallback()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fn(val), true
|
return fn(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) (Ret, error)) (Ret, bool, error) {
|
func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
val, ok := m.data[key]
|
val, ok := m.data[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return *new(Ret), false, nil
|
return fallback()
|
||||||
}
|
}
|
||||||
|
|
||||||
ret, err := fn(val)
|
return fn(val)
|
||||||
|
}
|
||||||
return ret, true, err
|
|
||||||
|
func FindMap[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) Ret, fallback func() Ret) Ret {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
|
for _, val := range m.data {
|
||||||
|
if cmp(val) {
|
||||||
|
return fn(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallback()
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
|
for _, val := range m.data {
|
||||||
|
if cmp(val) {
|
||||||
|
return fn(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallback()
|
||||||
}
|
}
|
||||||
|
|||||||
49
internal/try/try.go
Normal file
49
internal/try/try.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package try
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Catch tries to execute the `try` function, and if it fails or panics,
|
||||||
|
// it executes the `handlers` functions in order.
|
||||||
|
func Catch(try func() error, handlers ...func() error) error {
|
||||||
|
if _, err := CatchVal(func() (any, error) { return nil, try() }, handlers...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CatchVal tries to execute the `try` function, and if it fails or panics,
|
||||||
|
// it executes the `handlers` functions in order.
|
||||||
|
func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
catch(handlers...)
|
||||||
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if res, err = try(); err != nil {
|
||||||
|
catch(handlers...)
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func catch(handlers ...func() error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logrus.WithField("panic", r).Error("Panic in catch")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, handler := range handlers {
|
||||||
|
if err := handler(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to handle error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
74
internal/try/try_test.go
Normal file
74
internal/try/try_test.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package try
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTry(t *testing.T) {
|
||||||
|
res, err := CatchVal(func() (string, error) {
|
||||||
|
return "foo", nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "foo", res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCatch(t *testing.T) {
|
||||||
|
tryErr := fmt.Errorf("oops")
|
||||||
|
|
||||||
|
res, err := CatchVal(
|
||||||
|
func() (string, error) {
|
||||||
|
return "", tryErr
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.ErrorIs(t, err, tryErr)
|
||||||
|
require.Zero(t, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCatchError(t *testing.T) {
|
||||||
|
tryErr := fmt.Errorf("oops")
|
||||||
|
|
||||||
|
res, err := CatchVal(
|
||||||
|
func() (string, error) {
|
||||||
|
return "", tryErr
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
return fmt.Errorf("catch error")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.ErrorIs(t, err, tryErr)
|
||||||
|
require.Zero(t, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryPanic(t *testing.T) {
|
||||||
|
res, err := CatchVal(
|
||||||
|
func() (string, error) {
|
||||||
|
panic("oops")
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.ErrorContains(t, err, "panic: oops")
|
||||||
|
require.Zero(t, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCatchPanic(t *testing.T) {
|
||||||
|
tryErr := fmt.Errorf("oops")
|
||||||
|
|
||||||
|
res, err := CatchVal(
|
||||||
|
func() (string, error) {
|
||||||
|
return "", tryErr
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
panic("oops")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.ErrorIs(t, err, tryErr)
|
||||||
|
require.Zero(t, res)
|
||||||
|
}
|
||||||
@ -259,7 +259,12 @@ func (user *User) handleMessageEvents(ctx context.Context, messageEvents []litea
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
||||||
buildRes, err := user.buildRFC822(ctx, event.Message)
|
full, err := user.client.GetFullMessage(ctx, event.Message.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get full message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buildRes, err := buildRFC822(ctx, full, user.addrKRs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to build RFC822: %w", err)
|
return fmt.Errorf("failed to build RFC822: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,14 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
"github.com/ProtonMail/gluon/queue"
|
"github.com/ProtonMail/gluon/queue"
|
||||||
"github.com/bradenaw/juniper/iterator"
|
|
||||||
"github.com/bradenaw/juniper/parallel"
|
|
||||||
"github.com/bradenaw/juniper/stream"
|
"github.com/bradenaw/juniper/stream"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -105,34 +102,38 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.
|
|||||||
|
|
||||||
func (user *User) syncMessages(ctx context.Context) error {
|
func (user *User) syncMessages(ctx context.Context) error {
|
||||||
// Determine which messages to sync.
|
// Determine which messages to sync.
|
||||||
metadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get all message metadata: %w", err)
|
return fmt.Errorf("get all message metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If possible, begin syncing from the last synced message.
|
metadata := allMetadata
|
||||||
|
|
||||||
|
// If possible, begin syncing from one beyond the last synced message.
|
||||||
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
|
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
|
||||||
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
|
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
|
||||||
return metadata.ID == beginID
|
return metadata.ID == beginID
|
||||||
}); idx >= 0 {
|
}); idx >= 0 {
|
||||||
metadata = metadata[idx:]
|
metadata = metadata[idx+1:]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the metadata, building the messages.
|
// Process the metadata, building the messages.
|
||||||
buildCh := stream.Chunk(parallel.MapStream(
|
buildCh := stream.Chunk(stream.Map(
|
||||||
ctx,
|
user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
||||||
stream.FromIterator(iterator.Slice(metadata)),
|
return metadata.ID
|
||||||
runtime.NumCPU()*runtime.NumCPU()/2,
|
})...),
|
||||||
runtime.NumCPU()*runtime.NumCPU()/2,
|
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
||||||
user.buildRFC822,
|
return buildRFC822(ctx, full, user.addrKRs)
|
||||||
|
},
|
||||||
), maxBatchSize)
|
), maxBatchSize)
|
||||||
|
defer buildCh.Close()
|
||||||
|
|
||||||
// Create the flushers, one per update channel.
|
// Create the flushers, one per update channel.
|
||||||
flushers := make(map[string]*flusher)
|
flushers := make(map[string]*flusher)
|
||||||
|
|
||||||
for addrID, updateCh := range user.updateCh {
|
for addrID, updateCh := range user.updateCh {
|
||||||
flusher := newFlusher(user.ID(), updateCh, maxUpdateSize)
|
flusher := newFlusher(updateCh, maxUpdateSize)
|
||||||
defer flusher.flush(ctx, true)
|
defer flusher.flush(ctx, true)
|
||||||
|
|
||||||
flushers[addrID] = flusher
|
flushers[addrID] = flusher
|
||||||
@ -142,6 +143,8 @@ func (user *User) syncMessages(ctx context.Context) error {
|
|||||||
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
|
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
|
||||||
defer reporter.done()
|
defer reporter.done()
|
||||||
|
|
||||||
|
var count int
|
||||||
|
|
||||||
// Send each update to the appropriate flusher.
|
// Send each update to the appropriate flusher.
|
||||||
for {
|
for {
|
||||||
batch, err := buildCh.Next(ctx)
|
batch, err := buildCh.Next(ctx)
|
||||||
@ -170,6 +173,8 @@ func (user *User) syncMessages(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
reporter.add(len(batch))
|
reporter.add(len(batch))
|
||||||
|
|
||||||
|
count += len(batch)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"gitlab.protontech.ch/go/liteapi"
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
@ -29,30 +30,20 @@ func defaultJobOpts() message.JobOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) buildRFC822(ctx context.Context, metadata liteapi.MessageMetadata) (*buildRes, error) {
|
func buildRFC822(ctx context.Context, full liteapi.FullMessage, addrKRs map[string]*crypto.KeyRing) (*buildRes, error) {
|
||||||
msg, err := user.client.GetMessage(ctx, metadata.ID)
|
literal, err := message.BuildRFC822(addrKRs[full.AddressID], full.Message, full.AttData, defaultJobOpts())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get message %s: %w", metadata.ID, err)
|
return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
attData, err := user.attPool.ProcessAll(ctx, xslices.Map(msg.Attachments, func(att liteapi.Attachment) string { return att.ID }))
|
update, err := newMessageCreatedUpdate(full.MessageMetadata, literal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get attachments for message %s: %w", metadata.ID, err)
|
return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", full.ID, err)
|
||||||
}
|
|
||||||
|
|
||||||
literal, err := message.BuildRFC822(user.addrKRs[msg.AddressID], msg, attData, defaultJobOpts())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to build message %s: %w", metadata.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
update, err := newMessageCreatedUpdate(metadata, literal)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", metadata.ID, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &buildRes{
|
return &buildRes{
|
||||||
messageID: metadata.ID,
|
messageID: full.ID,
|
||||||
addressID: metadata.AddressID,
|
addressID: full.AddressID,
|
||||||
update: update,
|
update: update,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,21 +9,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type flusher struct {
|
type flusher struct {
|
||||||
userID string
|
|
||||||
updateCh *queue.QueuedChannel[imap.Update]
|
updateCh *queue.QueuedChannel[imap.Update]
|
||||||
|
|
||||||
updates []*imap.MessageCreated
|
updates []*imap.MessageCreated
|
||||||
maxChunkSize int
|
|
||||||
|
maxUpdateSize int
|
||||||
curChunkSize int
|
curChunkSize int
|
||||||
|
|
||||||
pushLock sync.Mutex
|
pushLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFlusher(userID string, updateCh *queue.QueuedChannel[imap.Update], maxChunkSize int) *flusher {
|
func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *flusher {
|
||||||
return &flusher{
|
return &flusher{
|
||||||
userID: userID,
|
|
||||||
updateCh: updateCh,
|
updateCh: updateCh,
|
||||||
maxChunkSize: maxChunkSize,
|
maxUpdateSize: maxUpdateSize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,19 +31,17 @@ func (f *flusher) push(ctx context.Context, update *imap.MessageCreated) {
|
|||||||
|
|
||||||
f.updates = append(f.updates, update)
|
f.updates = append(f.updates, update)
|
||||||
|
|
||||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
|
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxUpdateSize {
|
||||||
f.flush(ctx, false)
|
f.flush(ctx, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *flusher) flush(ctx context.Context, wait bool) {
|
func (f *flusher) flush(ctx context.Context, wait bool) {
|
||||||
if len(f.updates) == 0 {
|
if len(f.updates) > 0 {
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
|
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
|
||||||
f.updates = nil
|
f.updates = nil
|
||||||
f.curChunkSize = 0
|
f.curChunkSize = 0
|
||||||
|
}
|
||||||
|
|
||||||
if wait {
|
if wait {
|
||||||
update := imap.NewNoop()
|
update := imap.NewNoop()
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/connector"
|
"github.com/ProtonMail/gluon/connector"
|
||||||
@ -14,7 +13,6 @@ import (
|
|||||||
"github.com/ProtonMail/gluon/wait"
|
"github.com/ProtonMail/gluon/wait"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
@ -31,7 +29,6 @@ var (
|
|||||||
type User struct {
|
type User struct {
|
||||||
vault *vault.User
|
vault *vault.User
|
||||||
client *liteapi.Client
|
client *liteapi.Client
|
||||||
attPool *pool.Pool[string, []byte]
|
|
||||||
eventCh *queue.QueuedChannel[events.Event]
|
eventCh *queue.QueuedChannel[events.Event]
|
||||||
|
|
||||||
apiUser *safe.Type[liteapi.User]
|
apiUser *safe.Type[liteapi.User]
|
||||||
@ -91,7 +88,6 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
user := &User{
|
user := &User{
|
||||||
vault: encVault,
|
vault: encVault,
|
||||||
client: client,
|
client: client,
|
||||||
attPool: pool.New(runtime.NumCPU(), client.GetAttachment),
|
|
||||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||||
|
|
||||||
apiUser: safe.NewType(apiUser),
|
apiUser: safe.NewType(apiUser),
|
||||||
@ -123,7 +119,7 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
|
|
||||||
// If we haven't synced yet, do it first.
|
// If we haven't synced yet, do it first.
|
||||||
// If it fails, we don't start the event loop.
|
// If it fails, we don't start the event loop.
|
||||||
// Oterwise, begin processing API events, logging any errors that occur.
|
// Otherwise, begin processing API events, logging any errors that occur.
|
||||||
go func() {
|
go func() {
|
||||||
if status := user.vault.SyncStatus(); !status.HasMessages {
|
if status := user.vault.SyncStatus(); !status.HasMessages {
|
||||||
if err := <-user.startSync(); err != nil {
|
if err := <-user.startSync(); err != nil {
|
||||||
@ -336,8 +332,17 @@ func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Logout logs the user out from the API.
|
// Logout logs the user out from the API.
|
||||||
|
// If withVault is true, the user's vault is also cleared.
|
||||||
func (user *User) Logout(ctx context.Context) error {
|
func (user *User) Logout(ctx context.Context) error {
|
||||||
return user.client.AuthDelete(ctx)
|
if err := user.client.AuthDelete(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.vault.Clear(); err != nil {
|
||||||
|
return fmt.Errorf("failed to clear vault: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes ongoing connections and cleans up resources.
|
// Close closes ongoing connections and cleans up resources.
|
||||||
@ -345,9 +350,6 @@ func (user *User) Close() error {
|
|||||||
// Cancel ongoing syncs.
|
// Cancel ongoing syncs.
|
||||||
user.stopSync()
|
user.stopSync()
|
||||||
|
|
||||||
// Close the attachment pool.
|
|
||||||
user.attPool.Done()
|
|
||||||
|
|
||||||
// Close the user's API client.
|
// Close the user's API client.
|
||||||
user.client.Close()
|
user.client.Close()
|
||||||
|
|
||||||
|
|||||||
@ -89,13 +89,13 @@ func withAPI(t *testing.T, ctx context.Context, fn func(context.Context, *server
|
|||||||
func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) {
|
func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) {
|
||||||
var addrIDs []string
|
var addrIDs []string
|
||||||
|
|
||||||
userID, addrID, err := s.CreateUser(username, password, emails[0])
|
userID, addrID, err := s.CreateUser(username, emails[0], []byte(password))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
addrIDs = append(addrIDs, addrID)
|
addrIDs = append(addrIDs, addrID)
|
||||||
|
|
||||||
for _, email := range emails[1:] {
|
for _, email := range emails[1:] {
|
||||||
addrID, err := s.CreateAddress(userID, email, password)
|
addrID, err := s.CreateAddress(userID, email, []byte(password))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
addrIDs = append(addrIDs, addrID)
|
addrIDs = append(addrIDs, addrID)
|
||||||
|
|||||||
@ -138,3 +138,12 @@ func (user *User) SetEventID(eventID string) error {
|
|||||||
data.EventID = eventID
|
data.EventID = eventID
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear clears the user's auth secrets.
|
||||||
|
func (user *User) Clear() error {
|
||||||
|
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||||
|
data.AuthUID = ""
|
||||||
|
data.AuthRef = ""
|
||||||
|
data.KeyPass = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@ -58,7 +58,7 @@ func TestUser_Clear(t *testing.T) {
|
|||||||
require.Equal(t, "keyPass", string(user.KeyPass()))
|
require.Equal(t, "keyPass", string(user.KeyPass()))
|
||||||
|
|
||||||
// Clear the user's auth information.
|
// Clear the user's auth information.
|
||||||
require.NoError(t, s.ClearUser("userID"))
|
require.NoError(t, user.Clear())
|
||||||
|
|
||||||
// Check the user's cleared auth information.
|
// Check the user's cleared auth information.
|
||||||
require.Empty(t, user.AuthUID())
|
require.Empty(t, user.AuthUID())
|
||||||
|
|||||||
@ -107,14 +107,6 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
|
|||||||
return vault.GetUser(userID)
|
return vault.GetUser(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) ClearUser(userID string) error {
|
|
||||||
return vault.modUser(userID, func(data *UserData) {
|
|
||||||
data.AuthUID = ""
|
|
||||||
data.AuthRef = ""
|
|
||||||
data.KeyPass = nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser removes the given user from the vault.
|
// DeleteUser removes the given user from the vault.
|
||||||
func (vault *Vault) DeleteUser(userID string) error {
|
func (vault *Vault) DeleteUser(userID string) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.mod(func(data *Data) {
|
||||||
|
|||||||
@ -12,8 +12,8 @@ type API interface {
|
|||||||
GetHostURL() string
|
GetHostURL() string
|
||||||
AddCallWatcher(func(server.Call), ...string)
|
AddCallWatcher(func(server.Call), ...string)
|
||||||
|
|
||||||
CreateUser(username, password, address string) (string, string, error)
|
CreateUser(username, address string, password []byte) (string, string, error)
|
||||||
CreateAddress(userID, address, password string) (string, error)
|
CreateAddress(userID, address string, password []byte) (string, error)
|
||||||
RemoveAddress(userID, addrID string) error
|
RemoveAddress(userID, addrID string) error
|
||||||
RevokeUser(userID string) error
|
RevokeUser(userID string) error
|
||||||
|
|
||||||
|
|||||||
@ -2,17 +2,19 @@ package tests
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *testCtx) startBridge() error {
|
func (t *testCtx) startBridge() error {
|
||||||
// Bridge will enable the proxy by default at startup.
|
// Bridge will enable the proxy by default at startup.
|
||||||
t.mocks.ProxyDialer.EXPECT().AllowProxy()
|
t.mocks.ProxyCtl.EXPECT().AllowProxy()
|
||||||
|
|
||||||
// Get the path to the vault.
|
// Get the path to the vault.
|
||||||
vaultDir, err := t.locator.ProvideSettingsPath()
|
vaultDir, err := t.locator.ProvideSettingsPath()
|
||||||
@ -41,7 +43,8 @@ func (t *testCtx) startBridge() error {
|
|||||||
vault,
|
vault,
|
||||||
useragent.New(),
|
useragent.New(),
|
||||||
t.mocks.TLSReporter,
|
t.mocks.TLSReporter,
|
||||||
t.mocks.ProxyDialer,
|
liteapi.NewDialer(t.netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
|
||||||
|
t.mocks.ProxyCtl,
|
||||||
t.mocks.Autostarter,
|
t.mocks.Autostarter,
|
||||||
t.mocks.Updater,
|
t.mocks.Updater,
|
||||||
t.version,
|
t.version,
|
||||||
|
|||||||
@ -24,7 +24,7 @@ type testCtx struct {
|
|||||||
// These are the objects supporting the test.
|
// These are the objects supporting the test.
|
||||||
dir string
|
dir string
|
||||||
api API
|
api API
|
||||||
dialer *bridge.TestDialer
|
netCtl *liteapi.NetCtl
|
||||||
locator *locations.Locations
|
locator *locations.Locations
|
||||||
storeKey []byte
|
storeKey []byte
|
||||||
version *semver.Version
|
version *semver.Version
|
||||||
@ -76,15 +76,13 @@ type smtpClient struct {
|
|||||||
func newTestCtx(tb testing.TB) *testCtx {
|
func newTestCtx(tb testing.TB) *testCtx {
|
||||||
dir := tb.TempDir()
|
dir := tb.TempDir()
|
||||||
|
|
||||||
dialer := bridge.NewTestDialer()
|
|
||||||
|
|
||||||
ctx := &testCtx{
|
ctx := &testCtx{
|
||||||
dir: dir,
|
dir: dir,
|
||||||
api: newFakeAPI(),
|
api: newFakeAPI(),
|
||||||
dialer: dialer,
|
netCtl: liteapi.NewNetCtl(),
|
||||||
locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"),
|
locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"),
|
||||||
storeKey: []byte("super-secret-store-key"),
|
storeKey: []byte("super-secret-store-key"),
|
||||||
mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion),
|
mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion),
|
||||||
version: defaultVersion,
|
version: defaultVersion,
|
||||||
|
|
||||||
userIDByName: make(map[string]string),
|
userIDByName: make(map[string]string),
|
||||||
|
|||||||
@ -38,12 +38,12 @@ func (s *scenario) itFailsWithError(wantErr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *scenario) internetIsTurnedOff() error {
|
func (s *scenario) internetIsTurnedOff() error {
|
||||||
s.t.dialer.SetCanDial(false)
|
s.t.netCtl.SetCanDial(false)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scenario) internetIsTurnedOn() error {
|
func (s *scenario) internetIsTurnedOn() error {
|
||||||
s.t.dialer.SetCanDial(true)
|
s.t.netCtl.SetCanDial(true)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error {
|
func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error {
|
||||||
// Create the user.
|
// Create the user.
|
||||||
userID, addrID, err := s.t.api.CreateUser(username, password, username)
|
userID, addrID, err := s.t.api.CreateUser(username, username, []byte(password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -34,7 +34,7 @@ func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, passwor
|
|||||||
func (s *scenario) theAccountHasAdditionalAddress(username, address string) error {
|
func (s *scenario) theAccountHasAdditionalAddress(username, address string) error {
|
||||||
userID := s.t.getUserID(username)
|
userID := s.t.getUserID(username)
|
||||||
|
|
||||||
addrID, err := s.t.api.CreateAddress(userID, address, s.t.getUserPass(userID))
|
addrID, err := s.t.api.CreateAddress(userID, address, []byte(s.t.getUserPass(userID)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user