GODT-1657: More stable sync, with some tests

This commit is contained in:
James Houlahan
2022-10-09 23:05:52 +02:00
parent e7526f2e78
commit 509a767e50
41 changed files with 883 additions and 779 deletions

View File

@ -33,10 +33,10 @@ type Bridge struct {
users map[string]*user.User
// api manages user API clients.
api *liteapi.Manager
cookieJar *cookies.Jar
proxyDialer ProxyDialer
identifier Identifier
api *liteapi.Manager
cookieJar *cookies.Jar
proxyCtl ProxyController
identifier Identifier
// watchers holds all registered event watchers.
watchers []*watcher.Watcher[events.Event]
@ -81,15 +81,16 @@ func New(
vault *vault.Vault, // the bridge's encrypted data store
identifier Identifier, // the identifier to keep track of the user agent
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
proxyDialer ProxyDialer, // the DoH dialer
roundTripper http.RoundTripper, // the round tripper to use for API requests
proxyCtl ProxyController, // the DoH controller
autostarter Autostarter, // the autostarter to manage autostart settings
updater Updater, // the updater to fetch and install updates
curVersion *semver.Version, // the current version of the bridge
) (*Bridge, error) {
if vault.GetProxyAllowed() {
proxyDialer.AllowProxy()
proxyCtl.AllowProxy()
} else {
proxyDialer.DisallowProxy()
proxyCtl.DisallowProxy()
}
cookieJar, err := cookies.NewCookieJar(vault)
@ -101,7 +102,7 @@ func New(
liteapi.WithHostURL(apiURL),
liteapi.WithAppVersion(constants.AppVersion),
liteapi.WithCookieJar(cookieJar),
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}),
liteapi.WithTransport(roundTripper),
)
tlsConfig, err := loadTLSConfig(vault)
@ -139,10 +140,10 @@ func New(
vault: vault,
users: make(map[string]*user.User),
api: api,
cookieJar: cookieJar,
proxyDialer: proxyDialer,
identifier: identifier,
api: api,
cookieJar: cookieJar,
proxyCtl: proxyCtl,
identifier: identifier,
tlsConfig: tlsConfig,
imapServer: imapServer,
@ -179,6 +180,10 @@ func New(
return nil
})
if err := bridge.loadUsers(); err != nil {
return nil, fmt.Errorf("failed to load users: %w", err)
}
go func() {
for range tlsReporter.GetTLSIssueCh() {
bridge.publish(events.TLSIssue{})
@ -197,10 +202,6 @@ func New(
}
}()
if err := bridge.loadUsers(context.Background()); err != nil {
return nil, fmt.Errorf("failed to load connected users: %w", err)
}
if err := bridge.serveIMAP(); err != nil {
bridge.PushError(ErrServeIMAP)
}
@ -309,14 +310,8 @@ func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
func (bridge *Bridge) onStatusUp() {
bridge.publish(events.ConnStatusUp{})
for _, userID := range bridge.vault.GetUserIDs() {
if _, ok := bridge.users[userID]; !ok {
if vaultUser, err := bridge.vault.GetUser(userID); err != nil {
logrus.WithError(err).Error("Failed to get user from vault")
} else if err := bridge.loadUser(context.Background(), vaultUser); err != nil {
logrus.WithError(err).Error("Failed to load user")
}
}
if err := bridge.loadUsers(); err != nil {
logrus.WithError(err).Error("Failed to load users")
}
}

View File

@ -2,6 +2,7 @@ package bridge_test
import (
"context"
"crypto/tls"
"net/http"
"os"
"testing"
@ -21,6 +22,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/tests"
"github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/backend"
)
@ -41,14 +43,14 @@ func init() {
}
func TestBridge_ConnStatus(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of connection status events.
eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{})
defer done()
// Simulate network disconnect.
dialer.SetCanDial(false)
netCtl.Disable()
// Trigger some operation that will fail due to the network disconnect.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
@ -58,7 +60,7 @@ func TestBridge_ConnStatus(t *testing.T) {
require.Equal(t, events.ConnStatusDown{}, <-eventCh)
// Simulate network reconnect.
dialer.SetCanDial(true)
netCtl.Enable()
// Trigger some operation that will succeed due to the network reconnect.
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
@ -72,8 +74,8 @@ func TestBridge_ConnStatus(t *testing.T) {
}
func TestBridge_TLSIssue(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events.
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
defer done()
@ -90,8 +92,8 @@ func TestBridge_TLSIssue(t *testing.T) {
}
func TestBridge_Focus(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events.
raiseCh, done := bridge.GetEvents(events.Raise{})
defer done()
@ -106,14 +108,14 @@ func TestBridge_Focus(t *testing.T) {
}
func TestBridge_UserAgent(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call
s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call)
})
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Set the platform to something other than the default.
bridge.SetCurrentPlatform("platform")
@ -131,7 +133,7 @@ func TestBridge_UserAgent(t *testing.T) {
}
func TestBridge_Cookies(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call
s.AddCallWatcher(func(call server.Call) {
@ -141,7 +143,7 @@ func TestBridge_Cookies(t *testing.T) {
var sessionID string
// Start bridge and add a user so that API assigns us a session ID via cookie.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
@ -152,7 +154,7 @@ func TestBridge_Cookies(t *testing.T) {
})
// Start bridge again and check that it uses the same session ID.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id")
require.NoError(t, err)
@ -162,8 +164,8 @@ func TestBridge_Cookies(t *testing.T) {
}
func TestBridge_CheckUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
@ -201,8 +203,8 @@ func TestBridge_CheckUpdate(t *testing.T) {
}
func TestBridge_AutoUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Enable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(true))
@ -229,8 +231,8 @@ func TestBridge_AutoUpdate(t *testing.T) {
}
func TestBridge_ManualUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
@ -258,8 +260,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
}
func TestBridge_ForceUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateForced{})
defer done()
@ -278,11 +280,11 @@ func TestBridge_ForceUpdate(t *testing.T) {
}
func TestBridge_BadVaultKey(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var userID string
// Login a user.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
@ -290,27 +292,27 @@ func TestBridge_BadVaultKey(t *testing.T) {
})
// Start bridge with the correct vault key -- it should load the users correctly.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
})
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
})
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
})
})
}
func TestBridge_MissingGluonDir(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
var gluonDir string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
@ -325,20 +327,20 @@ func TestBridge_MissingGluonDir(t *testing.T) {
require.NoError(t, os.RemoveAll(gluonDir))
// Bridge starts but can't find the gluon dir; there should be no error.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// ...
})
})
}
// withEnv creates the full test environment and runs the tests.
func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) {
// withTLSEnv creates the full test environment and runs the tests.
func withTLSEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) {
// Create test API.
server := server.NewTLS()
defer server.Close()
// Add test user.
_, _, err := server.CreateUser(username, string(password), username+"@pm.me")
_, _, err := server.CreateUser(username, username+"@pm.me", password)
require.NoError(t, err)
// Generate a random vault key.
@ -349,23 +351,56 @@ func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a net controller so we can simulate network connectivity issues.
netCtl := liteapi.NewNetCtl()
// Create a locations object to provide temporary locations for bridge data during the test.
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
// Run the tests.
tests(
ctx,
server,
bridge.NewTestDialer(),
locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name"),
vaultKey,
)
tests(ctx, server, netCtl, locations, vaultKey)
}
// withEnv creates the full test environment and runs the tests.
func withEnv(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) {
// Add test user.
_, _, err := server.CreateUser(username, username+"@pm.me", password)
require.NoError(t, err)
// Generate a random vault key.
vaultKey, err := crypto.RandomToken(32)
require.NoError(t, err)
// Create a context used for the test.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a net controller so we can simulate network connectivity issues.
netCtl := liteapi.NewNetCtl()
// Create a locations object to provide temporary locations for bridge data during the test.
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
// Run the tests.
tests(ctx, netCtl, locations, vaultKey)
}
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
func withBridge(
t *testing.T,
ctx context.Context,
apiURL string,
netCtl *liteapi.NetCtl,
locator bridge.Locator,
vaultKey []byte,
tests func(*bridge.Bridge, *bridge.Mocks),
) {
// Create the mock objects used in the tests.
mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0)
mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
defer mocks.Close()
// Bridge will enable the proxy by default at startup.
mocks.ProxyDialer.EXPECT().AllowProxy()
mocks.ProxyCtl.EXPECT().AllowProxy()
// Get the path to the vault.
vaultDir, err := locator.ProvideSettingsPath()
@ -380,7 +415,18 @@ func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge
require.NoError(t, vault.SetSMTPPort(0))
// Create a new bridge.
bridge, err := bridge.New(apiURL, locator, vault, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0)
bridge, err := bridge.New(
apiURL,
locator,
vault,
useragent.New(),
mocks.TLSReporter,
liteapi.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
mocks.ProxyCtl,
mocks.Autostarter,
mocks.Updater,
v2_3_0,
)
require.NoError(t, err)
// Close the bridge when done.

View File

@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io/fs"
"net"
"os"
"strconv"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon"
@ -33,6 +35,22 @@ func (bridge *Bridge) serveIMAP() error {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
_, port, err := net.SplitHostPort(imapListener.Addr().String())
if err != nil {
return fmt.Errorf("failed to get IMAP listener address: %w", err)
}
portInt, err := strconv.Atoi(port)
if err != nil {
return fmt.Errorf("failed to convert IMAP listener port to int: %w", err)
}
if portInt != bridge.vault.GetIMAPPort() {
if err := bridge.vault.SetIMAPPort(portInt); err != nil {
return fmt.Errorf("failed to update IMAP port in vault: %w", err)
}
}
go func() {
for err := range bridge.imapServer.GetErrorCh() {
logrus.WithError(err).Error("IMAP server error")

View File

@ -1,10 +1,6 @@
package bridge
import (
"context"
"crypto/tls"
"errors"
"net"
"os"
"testing"
@ -15,7 +11,7 @@ import (
)
type Mocks struct {
ProxyDialer *mocks.MockProxyDialer
ProxyCtl *mocks.MockProxyController
TLSReporter *mocks.MockTLSReporter
TLSIssueCh chan struct{}
@ -23,11 +19,11 @@ type Mocks struct {
Autostarter *mocks.MockAutostarter
}
func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Version) *Mocks {
func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
ctl := gomock.NewController(tb)
mocks := &Mocks{
ProxyDialer: mocks.NewMockProxyDialer(ctl),
ProxyCtl: mocks.NewMockProxyController(ctl),
TLSReporter: mocks.NewMockTLSReporter(ctl),
TLSIssueCh: make(chan struct{}),
@ -35,41 +31,14 @@ func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Versio
Autostarter: mocks.NewMockAutostarter(ctl),
}
// When using the proxy dialer, we want to use the test dialer.
mocks.ProxyDialer.EXPECT().DialTLSContext(
gomock.Any(),
gomock.Any(),
gomock.Any(),
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.DialTLSContext(ctx, network, address)
}).AnyTimes()
// When getting the TLS issue channel, we want to return the test channel.
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
return mocks
}
type TestDialer struct {
canDial bool
}
func NewTestDialer() *TestDialer {
return &TestDialer{
canDial: true,
}
}
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
if !d.canDial {
return nil, errors.New("cannot dial")
}
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
}
func (d *TestDialer) SetCanDial(canDial bool) {
d.canDial = canDial
func (mocks *Mocks) Close() {
close(mocks.TLSIssueCh)
}
type TestLocationsProvider struct {

View File

@ -1,12 +1,10 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter)
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyController,Autostarter)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@ -49,66 +47,51 @@ func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
}
// MockProxyDialer is a mock of ProxyDialer interface.
type MockProxyDialer struct {
// MockProxyController is a mock of ProxyController interface.
type MockProxyController struct {
ctrl *gomock.Controller
recorder *MockProxyDialerMockRecorder
recorder *MockProxyControllerMockRecorder
}
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer.
type MockProxyDialerMockRecorder struct {
mock *MockProxyDialer
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
type MockProxyControllerMockRecorder struct {
mock *MockProxyController
}
// NewMockProxyDialer creates a new mock instance.
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer {
mock := &MockProxyDialer{ctrl: ctrl}
mock.recorder = &MockProxyDialerMockRecorder{mock}
// NewMockProxyController creates a new mock instance.
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
mock := &MockProxyController{ctrl: ctrl}
mock.recorder = &MockProxyControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder {
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method.
func (m *MockProxyDialer) AllowProxy() {
func (m *MockProxyController) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy.
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call {
func (mr *MockProxyControllerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy))
}
// DialTLSContext mocks base method.
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialTLSContext indicates an expected call of DialTLSContext.
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyController)(nil).AllowProxy))
}
// DisallowProxy mocks base method.
func (m *MockProxyDialer) DisallowProxy() {
func (m *MockProxyController) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy.
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call {
func (mr *MockProxyControllerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyController)(nil).DisallowProxy))
}
// MockAutostarter is a mock of Autostarter interface.

View File

@ -138,9 +138,9 @@ func (bridge *Bridge) GetProxyAllowed() bool {
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
if allowed {
bridge.proxyDialer.AllowProxy()
bridge.proxyCtl.AllowProxy()
} else {
bridge.proxyDialer.DisallowProxy()
bridge.proxyCtl.DisallowProxy()
}
return bridge.vault.SetProxyAllowed(allowed)

View File

@ -7,12 +7,13 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_Settings_GluonDir(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Create a user.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
@ -34,8 +35,8 @@ func TestBridge_Settings_GluonDir(t *testing.T) {
}
func TestBridge_Settings_IMAPPort(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
curPort := bridge.GetIMAPPort()
// Set the port to 1144.
@ -51,8 +52,8 @@ func TestBridge_Settings_IMAPPort(t *testing.T) {
}
func TestBridge_Settings_IMAPSSL(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, IMAP SSL is disabled.
require.False(t, bridge.GetIMAPSSL())
@ -66,8 +67,8 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) {
}
func TestBridge_Settings_SMTPPort(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
curPort := bridge.GetSMTPPort()
// Set the port to 1024.
@ -84,8 +85,8 @@ func TestBridge_Settings_SMTPPort(t *testing.T) {
}
func TestBridge_Settings_SMTPSSL(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, SMTP SSL is disabled.
require.False(t, bridge.GetSMTPSSL())
@ -99,13 +100,13 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) {
}
func TestBridge_Settings_Proxy(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, proxy is allowed.
require.True(t, bridge.GetProxyAllowed())
// Disallow proxy.
mocks.ProxyDialer.EXPECT().DisallowProxy()
mocks.ProxyCtl.EXPECT().DisallowProxy()
require.NoError(t, bridge.SetProxyAllowed(false))
// Get the new setting.
@ -115,8 +116,8 @@ func TestBridge_Settings_Proxy(t *testing.T) {
}
func TestBridge_Settings_Autostart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, autostart is disabled.
require.False(t, bridge.GetAutostart())
@ -131,8 +132,8 @@ func TestBridge_Settings_Autostart(t *testing.T) {
}
func TestBridge_Settings_FirstStart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true.
require.True(t, bridge.GetFirstStart())
@ -146,8 +147,8 @@ func TestBridge_Settings_FirstStart(t *testing.T) {
}
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true.
require.True(t, bridge.GetFirstStartGUI())

View File

@ -3,6 +3,8 @@ package bridge
import (
"crypto/tls"
"fmt"
"net"
"strconv"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/emersion/go-sasl"
@ -22,6 +24,22 @@ func (bridge *Bridge) serveSMTP() error {
}
}()
_, port, err := net.SplitHostPort(smtpListener.Addr().String())
if err != nil {
return fmt.Errorf("failed to get SMTP listener address: %w", err)
}
portInt, err := strconv.Atoi(port)
if err != nil {
return fmt.Errorf("failed to convert SMTP listener port to int: %w", err)
}
if portInt != bridge.vault.GetSMTPPort() {
if err := bridge.vault.SetSMTPPort(portInt); err != nil {
return fmt.Errorf("failed to update SMTP port in vault: %w", err)
}
}
return nil
}

View File

@ -0,0 +1,130 @@
package bridge_test
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/emersion/go-imap/client"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_Sync(t *testing.T) {
s := server.New()
defer s.Close()
numMsg := 1 << 10
withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
require.NoError(t, err)
labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder)
require.NoError(t, err)
literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml"))
require.NoError(t, err)
for i := 0; i < numMsg; i++ {
messageID, err := s.CreateMessage(userID, addrID, literal, liteapi.MessageFlagReceived, false, false)
require.NoError(t, err)
require.NoError(t, s.LabelMessage(userID, messageID, labelID))
}
var read uint64
netCtl.OnRead(func(b []byte) {
read += uint64(len(b))
})
// The initial user should be fully synced.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
syncCh, done := bridge.GetEvents(events.SyncFinished{})
defer done()
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
require.NoError(t, err)
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
})
// If we then connect an IMAP client, it should see all the messages.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.Connected)
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
defer client.Logout()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Equal(t, uint32(numMsg), status.Messages)
})
// Now let's remove the user and simulate a network error.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.NoError(t, bridge.DeleteUser(ctx, userID))
})
// Pretend we can only sync 2/3 of the original messages.
netCtl.SetReadLimit(2 * read / 3)
// Login the user; its sync should fail.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
syncCh, done := bridge.GetEvents(events.SyncFailed{})
defer done()
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
require.NoError(t, err)
require.Equal(t, userID, (<-syncCh).(events.SyncFailed).UserID)
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.Connected)
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
defer client.Logout()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Less(t, status.Messages, uint32(numMsg))
})
// Remove the network limit, allowing the sync to finish.
netCtl.SetReadLimit(0)
// Login the user; its sync should now finish.
// If we then connect an IMAP client, it should eventually see all the messages.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
syncCh, done := bridge.GetEvents(events.SyncFinished{})
defer done()
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.Connected)
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
defer client.Logout()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Equal(t, uint32(numMsg), status.Messages)
})
})
}

View File

@ -0,0 +1,6 @@
To: recipient@pm.me
From: sender@pm.me
Subject: Test
Content-Type: text/plain; charset=utf-8
Test

View File

@ -1,9 +1,6 @@
package bridge
import (
"context"
"net"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
)
@ -21,17 +18,15 @@ type Identifier interface {
SetPlatform(platform string)
}
type TLSReporter interface {
GetTLSIssueCh() <-chan struct{}
}
type ProxyDialer interface {
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
type ProxyController interface {
AllowProxy()
DisallowProxy()
}
type TLSReporter interface {
GetTLSIssueCh() <-chan struct{}
}
type Autostarter interface {
Enable() error
Disable() error

View File

@ -18,6 +18,9 @@ func (bridge *Bridge) watchForUpdates() error {
go func() {
for {
select {
case <-bridge.stopCh:
return
case <-bridge.updateCheckCh:
case <-ticker.C:
}

View File

@ -6,6 +6,7 @@ import (
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/try"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/go-resty/resty/v2"
@ -82,76 +83,76 @@ func (bridge *Bridge) LoginUser(
) (string, error) {
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
if err != nil {
return "", err
return "", fmt.Errorf("failed to create new API client: %w", err)
}
if _, ok := bridge.users[auth.UserID]; ok {
return "", ErrUserAlreadyLoggedIn
}
userID, err := try.CatchVal(
func() (string, error) {
if _, ok := bridge.users[auth.UserID]; ok {
return "", ErrUserAlreadyLoggedIn
}
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
totp, err := getTOTP()
if err != nil {
return "", err
}
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
totp, err := getTOTP()
if err != nil {
return "", fmt.Errorf("failed to get TOTP: %w", err)
}
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
return "", err
}
}
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
return "", fmt.Errorf("failed to authorize 2FA: %w", err)
}
}
var keyPass []byte
var keyPass []byte
if auth.PasswordMode == liteapi.TwoPasswordMode {
pass, err := getKeyPass()
if err != nil {
return "", err
}
if auth.PasswordMode == liteapi.TwoPasswordMode {
userKeyPass, err := getKeyPass()
if err != nil {
return "", fmt.Errorf("failed to get key password: %w", err)
}
keyPass = pass
} else {
keyPass = password
}
keyPass = userKeyPass
} else {
keyPass = password
}
apiUser, err := client.GetUser(ctx)
return bridge.loginUser(ctx, client, auth.UID, auth.RefreshToken, keyPass)
},
func() error {
return client.AuthDelete(ctx)
},
func() error {
bridge.deleteUser(ctx, auth.UserID)
return nil
},
)
if err != nil {
return "", err
return "", fmt.Errorf("failed to login user: %w", err)
}
salts, err := client.GetSalts(ctx)
if err != nil {
return "", err
}
bridge.publish(events.UserLoggedIn{
UserID: userID,
})
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
if err != nil {
return "", err
}
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
return "", err
}
return auth.UserID, nil
return userID, nil
}
// LogoutUser logs out the given user.
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
return bridge.logoutUser(ctx, userID, true, false)
if err := bridge.logoutUser(ctx, userID); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
}
// DeleteUser deletes the given user.
// If it is authorized, it is logged out first.
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
if bridge.users[userID] != nil {
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
return err
}
}
if err := bridge.vault.DeleteUser(userID); err != nil {
return err
}
bridge.deleteUser(ctx, userID)
bridge.publish(events.UserDeleted{
UserID: userID,
@ -193,53 +194,91 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
return nil
}
// loadUsers loads authorized users from the vault.
func (bridge *Bridge) loadUsers(ctx context.Context) error {
for _, userID := range bridge.vault.GetUserIDs() {
user, err := bridge.vault.GetUser(userID)
if err != nil {
return err
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
apiUser, err := client.GetUser(ctx)
if err != nil {
return "", fmt.Errorf("failed to get API user: %w", err)
}
salts, err := client.GetSalts(ctx)
if err != nil {
return "", fmt.Errorf("failed to get key salts: %w", err)
}
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
if err != nil {
return "", fmt.Errorf("failed to salt key password: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil {
return "", fmt.Errorf("failed to add bridge user: %w", err)
}
return apiUser.ID, nil
}
// loadUsers is a loop that, when polled, attempts to load authorized users from the vault.
func (bridge *Bridge) loadUsers() error {
return bridge.vault.ForUser(func(user *vault.User) error {
if _, ok := bridge.users[user.UserID()]; ok {
return nil
}
if user.AuthUID() == "" {
continue
return nil
}
if err := bridge.loadUser(ctx, user); err != nil {
logrus.WithError(err).Error("Failed to load connected user")
if err := bridge.loadUser(user); err != nil {
if _, ok := err.(*resty.ResponseError); ok {
if err := bridge.vault.ClearUser(userID); err != nil {
logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault")
if err := user.Clear(); err != nil {
logrus.WithError(err).Error("Failed to clear user")
}
} else {
logrus.WithError(err).Error("Failed to load connected user")
}
continue
return nil
}
}
return nil
bridge.publish(events.UserLoaded{
UserID: user.UserID(),
})
return nil
})
}
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
// loadUser loads an existing user from the vault.
func (bridge *Bridge) loadUser(user *vault.User) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
if err != nil {
return fmt.Errorf("failed to create API client: %w", err)
}
apiUser, err := client.GetUser(ctx)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if err := try.Catch(
func() error {
apiUser, err := client.GetUser(ctx)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
return fmt.Errorf("failed to add user: %w", err)
return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass())
},
func() error {
return client.AuthDelete(ctx)
},
func() error {
return bridge.logoutUser(ctx, user.UserID())
},
); err != nil {
return fmt.Errorf("failed to load user: %w", err)
}
bridge.publish(events.UserLoggedIn{
UserID: user.UserID(),
})
return nil
}
@ -304,10 +343,6 @@ func (bridge *Bridge) addUser(
return nil
})
bridge.publish(events.UserLoggedIn{
UserID: user.ID(),
})
return nil
}
@ -363,54 +398,6 @@ func (bridge *Bridge) addExistingUser(
return user, nil
}
// logoutUser closes and removes the user with the given ID.
// If withAPI is true, the user will additionally be logged out from API.
// If withFiles is true, the user's files will be deleted.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if err := bridge.smtpBackend.removeUser(user); err != nil {
return fmt.Errorf("failed to remove SMTP user: %w", err)
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
}
if withAPI {
if err := user.Logout(ctx); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
}
if err := user.Close(); err != nil {
return fmt.Errorf("failed to close user: %w", err)
}
if err := bridge.vault.ClearUser(userID); err != nil {
return fmt.Errorf("failed to clear user: %w", err)
}
if withFiles {
if err := bridge.vault.DeleteUser(userID); err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
}
delete(bridge.users, userID)
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
}
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
imapConn, err := user.NewIMAPConnectors()
@ -438,6 +425,65 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
return nil
}
// logoutUser logs the given user out from bridge.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if err := bridge.smtpBackend.removeUser(user); err != nil {
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
if err := user.Close(); err != nil {
logrus.WithError(err).Error("Failed to close user")
}
delete(bridge.users, userID)
return nil
}
// deleteUser deletes the given user from bridge.
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
if user, ok := bridge.users[userID]; ok {
if err := bridge.smtpBackend.removeUser(user); err != nil {
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
if err := user.Close(); err != nil {
logrus.WithError(err).Error("Failed to close user")
}
}
if err := bridge.vault.DeleteUser(userID); err != nil {
logrus.WithError(err).Error("Failed to delete user from vault")
}
delete(bridge.users, userID)
}
// getUserInfo returns information about a disconnected user.
func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
return UserInfo{

View File

@ -27,7 +27,7 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
}
case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
}

View File

@ -9,17 +9,18 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_WithoutUsers(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
@ -27,8 +28,8 @@ func TestBridge_WithoutUsers(t *testing.T) {
}
func TestBridge_Login(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err)
@ -41,8 +42,8 @@ func TestBridge_Login(t *testing.T) {
}
func TestBridge_LoginLogoutLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -69,8 +70,8 @@ func TestBridge_LoginLogoutLogin(t *testing.T) {
}
func TestBridge_LoginDeleteLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -97,8 +98,8 @@ func TestBridge_LoginDeleteLogin(t *testing.T) {
}
func TestBridge_LoginDeauthLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -131,10 +132,10 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
func TestBridge_LoginExpireLogin(t *testing.T) {
const authLife = 2 * time.Second
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
s.SetAuthLife(authLife)
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. Its auth will only be valid for a short time.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -148,11 +149,11 @@ func TestBridge_LoginExpireLogin(t *testing.T) {
}
func TestBridge_FailToLoad(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
// Login the user.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
})
@ -160,7 +161,7 @@ func TestBridge_FailToLoad(t *testing.T) {
require.NoError(t, s.RevokeUser(userID))
// When bridge starts, the user will not be logged in.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
@ -168,25 +169,27 @@ func TestBridge_FailToLoad(t *testing.T) {
}
func TestBridge_LoadWithoutInternet(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
// Login the user.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
})
// Simulate loss of internet connection.
dialer.SetCanDial(false)
netCtl.Disable()
// Start bridge without internet.
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Initially, users are not connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
time.Sleep(5 * time.Second)
// Simulate internet connection.
dialer.SetCanDial(true)
netCtl.Enable()
// The user will eventually be connected.
require.Eventually(t, func() bool {
@ -197,16 +200,14 @@ func TestBridge_LoadWithoutInternet(t *testing.T) {
}
func TestBridge_LoginRestart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
})
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still connected.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
@ -214,10 +215,10 @@ func TestBridge_LoginRestart(t *testing.T) {
}
func TestBridge_LoginLogoutRestart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -225,7 +226,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
@ -234,10 +235,10 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
}
func TestBridge_LoginDeleteRestart(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -245,7 +246,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
require.NoError(t, bridge.DeleteUser(ctx, userID))
})
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still gone.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
@ -253,13 +254,69 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
})
}
func TestBridge_FailLoginRecover(t *testing.T) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var read uint64
netCtl.OnRead(func(b []byte) {
read += uint64(len(b))
})
// Log the user in and record how much data was read.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
// Simulate a partial read.
netCtl.SetReadLimit(read / 2)
// We should fail to log the user in because we can't fully read its data.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Error(t, getErr(bridge.LoginUser(ctx, username, password, nil, nil)))
// There should be no users.
require.Empty(t, bridge.GetUserIDs())
})
})
}
func TestBridge_FailLoadRecover(t *testing.T) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
must(bridge.LoginUser(ctx, username, password, nil, nil))
})
var read uint64
netCtl.OnRead(func(b []byte) {
read += uint64(len(b))
})
// Start bridge and record how much data was read.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// ...
})
// Simulate a partial read.
netCtl.SetReadLimit(read / 2)
// We should fail to load the user; it should be disconnected.
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userIDs := bridge.GetUserIDs()
require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected)
})
})
}
func TestBridge_BridgePass(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
var pass []byte
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
@ -276,7 +333,7 @@ func TestBridge_BridgePass(t *testing.T) {
require.Equal(t, pass, pass)
})
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The bridge should load the user.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
@ -288,8 +345,8 @@ func TestBridge_BridgePass(t *testing.T) {
}
func TestBridge_AddressMode(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err)
@ -313,3 +370,8 @@ func TestBridge_AddressMode(t *testing.T) {
})
})
}
// getErr returns the error that was passed to it.
func getErr[T any](val T, err error) error {
return err
}