mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
GODT-1657: More stable sync, with some tests
This commit is contained in:
2
Makefile
2
Makefile
@ -234,7 +234,7 @@ integration-test-bridge:
|
||||
${MAKE} -C test test-bridge
|
||||
|
||||
mocks:
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyDialer,Autostarter > internal/bridge/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/bridge TLSReporter,ProxyController,Autostarter > internal/bridge/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/v2/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
|
||||
|
||||
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog
|
||||
|
||||
2
go.mod
2
go.mod
@ -38,7 +38,7 @@ require (
|
||||
github.com/sirupsen/logrus v1.9.0
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/urfave/cli/v2 v2.16.3
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e
|
||||
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
||||
golang.org/x/net v0.1.0
|
||||
golang.org/x/sys v0.1.0
|
||||
|
||||
4
go.sum
4
go.sum
@ -397,8 +397,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
||||
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7 h1:Hef7jPRzcfLOvOUHYoQ6efaI7p7/aT5kpZDqJ29owNI=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221007210933-605ca74449b7/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e h1:CTGaREzkbz7u98nKt6+xsca2bWML79lH1XGbodRo+MY=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221010190235-49df4dcc853e/go.mod h1:9nsslyEJn7Utbielp4c+hc7qT6hqIJ52aGFR/tX+tYk=
|
||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||
|
||||
@ -81,7 +81,18 @@ func newBridge(locations *locations.Locations, identifier *useragent.UserAgent)
|
||||
}
|
||||
|
||||
// Create a new bridge.
|
||||
bridge, err := bridge.New(constants.APIHost, locations, encVault, identifier, pinningDialer, proxyDialer, autostarter, updater, version)
|
||||
bridge, err := bridge.New(
|
||||
constants.APIHost,
|
||||
locations,
|
||||
encVault,
|
||||
identifier,
|
||||
pinningDialer,
|
||||
dialer.CreateTransportWithDialer(proxyDialer),
|
||||
proxyDialer,
|
||||
autostarter,
|
||||
updater,
|
||||
version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create bridge: %w", err)
|
||||
}
|
||||
|
||||
@ -33,10 +33,10 @@ type Bridge struct {
|
||||
users map[string]*user.User
|
||||
|
||||
// api manages user API clients.
|
||||
api *liteapi.Manager
|
||||
cookieJar *cookies.Jar
|
||||
proxyDialer ProxyDialer
|
||||
identifier Identifier
|
||||
api *liteapi.Manager
|
||||
cookieJar *cookies.Jar
|
||||
proxyCtl ProxyController
|
||||
identifier Identifier
|
||||
|
||||
// watchers holds all registered event watchers.
|
||||
watchers []*watcher.Watcher[events.Event]
|
||||
@ -81,15 +81,16 @@ func New(
|
||||
vault *vault.Vault, // the bridge's encrypted data store
|
||||
identifier Identifier, // the identifier to keep track of the user agent
|
||||
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
|
||||
proxyDialer ProxyDialer, // the DoH dialer
|
||||
roundTripper http.RoundTripper, // the round tripper to use for API requests
|
||||
proxyCtl ProxyController, // the DoH controller
|
||||
autostarter Autostarter, // the autostarter to manage autostart settings
|
||||
updater Updater, // the updater to fetch and install updates
|
||||
curVersion *semver.Version, // the current version of the bridge
|
||||
) (*Bridge, error) {
|
||||
if vault.GetProxyAllowed() {
|
||||
proxyDialer.AllowProxy()
|
||||
proxyCtl.AllowProxy()
|
||||
} else {
|
||||
proxyDialer.DisallowProxy()
|
||||
proxyCtl.DisallowProxy()
|
||||
}
|
||||
|
||||
cookieJar, err := cookies.NewCookieJar(vault)
|
||||
@ -101,7 +102,7 @@ func New(
|
||||
liteapi.WithHostURL(apiURL),
|
||||
liteapi.WithAppVersion(constants.AppVersion),
|
||||
liteapi.WithCookieJar(cookieJar),
|
||||
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}),
|
||||
liteapi.WithTransport(roundTripper),
|
||||
)
|
||||
|
||||
tlsConfig, err := loadTLSConfig(vault)
|
||||
@ -139,10 +140,10 @@ func New(
|
||||
vault: vault,
|
||||
users: make(map[string]*user.User),
|
||||
|
||||
api: api,
|
||||
cookieJar: cookieJar,
|
||||
proxyDialer: proxyDialer,
|
||||
identifier: identifier,
|
||||
api: api,
|
||||
cookieJar: cookieJar,
|
||||
proxyCtl: proxyCtl,
|
||||
identifier: identifier,
|
||||
|
||||
tlsConfig: tlsConfig,
|
||||
imapServer: imapServer,
|
||||
@ -179,6 +180,10 @@ func New(
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := bridge.loadUsers(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load users: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for range tlsReporter.GetTLSIssueCh() {
|
||||
bridge.publish(events.TLSIssue{})
|
||||
@ -197,10 +202,6 @@ func New(
|
||||
}
|
||||
}()
|
||||
|
||||
if err := bridge.loadUsers(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to load connected users: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.serveIMAP(); err != nil {
|
||||
bridge.PushError(ErrServeIMAP)
|
||||
}
|
||||
@ -309,14 +310,8 @@ func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
|
||||
func (bridge *Bridge) onStatusUp() {
|
||||
bridge.publish(events.ConnStatusUp{})
|
||||
|
||||
for _, userID := range bridge.vault.GetUserIDs() {
|
||||
if _, ok := bridge.users[userID]; !ok {
|
||||
if vaultUser, err := bridge.vault.GetUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to get user from vault")
|
||||
} else if err := bridge.loadUser(context.Background(), vaultUser); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load user")
|
||||
}
|
||||
}
|
||||
if err := bridge.loadUsers(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load users")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
@ -21,6 +22,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/tests"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
"gitlab.protontech.ch/go/liteapi/server/backend"
|
||||
)
|
||||
@ -41,14 +43,14 @@ func init() {
|
||||
}
|
||||
|
||||
func TestBridge_ConnStatus(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of connection status events.
|
||||
eventCh, done := bridge.GetEvents(events.ConnStatusUp{}, events.ConnStatusDown{})
|
||||
defer done()
|
||||
|
||||
// Simulate network disconnect.
|
||||
dialer.SetCanDial(false)
|
||||
netCtl.Disable()
|
||||
|
||||
// Trigger some operation that will fail due to the network disconnect.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
@ -58,7 +60,7 @@ func TestBridge_ConnStatus(t *testing.T) {
|
||||
require.Equal(t, events.ConnStatusDown{}, <-eventCh)
|
||||
|
||||
// Simulate network reconnect.
|
||||
dialer.SetCanDial(true)
|
||||
netCtl.Enable()
|
||||
|
||||
// Trigger some operation that will succeed due to the network reconnect.
|
||||
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
@ -72,8 +74,8 @@ func TestBridge_ConnStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_TLSIssue(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of TLS issue events.
|
||||
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
|
||||
defer done()
|
||||
@ -90,8 +92,8 @@ func TestBridge_TLSIssue(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Focus(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of TLS issue events.
|
||||
raiseCh, done := bridge.GetEvents(events.Raise{})
|
||||
defer done()
|
||||
@ -106,14 +108,14 @@ func TestBridge_Focus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_UserAgent(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var calls []server.Call
|
||||
|
||||
s.AddCallWatcher(func(call server.Call) {
|
||||
calls = append(calls, call)
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Set the platform to something other than the default.
|
||||
bridge.SetCurrentPlatform("platform")
|
||||
|
||||
@ -131,7 +133,7 @@ func TestBridge_UserAgent(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Cookies(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var calls []server.Call
|
||||
|
||||
s.AddCallWatcher(func(call server.Call) {
|
||||
@ -141,7 +143,7 @@ func TestBridge_Cookies(t *testing.T) {
|
||||
var sessionID string
|
||||
|
||||
// Start bridge and add a user so that API assigns us a session ID via cookie.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -152,7 +154,7 @@ func TestBridge_Cookies(t *testing.T) {
|
||||
})
|
||||
|
||||
// Start bridge again and check that it uses the same session ID.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id")
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -162,8 +164,8 @@ func TestBridge_Cookies(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_CheckUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Disable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||
|
||||
@ -201,8 +203,8 @@ func TestBridge_CheckUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_AutoUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Enable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(true))
|
||||
|
||||
@ -229,8 +231,8 @@ func TestBridge_AutoUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_ManualUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Disable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||
|
||||
@ -258,8 +260,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_ForceUpdate(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of update events.
|
||||
updateCh, done := bridge.GetEvents(events.UpdateForced{})
|
||||
defer done()
|
||||
@ -278,11 +280,11 @@ func TestBridge_ForceUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_BadVaultKey(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var userID string
|
||||
|
||||
// Login a user.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -290,27 +292,27 @@ func TestBridge_BadVaultKey(t *testing.T) {
|
||||
})
|
||||
|
||||
// Start bridge with the correct vault key -- it should load the users correctly.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
|
||||
})
|
||||
|
||||
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
|
||||
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_MissingGluonDir(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
var gluonDir string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -325,20 +327,20 @@ func TestBridge_MissingGluonDir(t *testing.T) {
|
||||
require.NoError(t, os.RemoveAll(gluonDir))
|
||||
|
||||
// Bridge starts but can't find the gluon dir; there should be no error.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// ...
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// withEnv creates the full test environment and runs the tests.
|
||||
func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte)) {
|
||||
// withTLSEnv creates the full test environment and runs the tests.
|
||||
func withTLSEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) {
|
||||
// Create test API.
|
||||
server := server.NewTLS()
|
||||
defer server.Close()
|
||||
|
||||
// Add test user.
|
||||
_, _, err := server.CreateUser(username, string(password), username+"@pm.me")
|
||||
_, _, err := server.CreateUser(username, username+"@pm.me", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a random vault key.
|
||||
@ -349,23 +351,56 @@ func withEnv(t *testing.T, tests func(ctx context.Context, server *server.Server
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a net controller so we can simulate network connectivity issues.
|
||||
netCtl := liteapi.NewNetCtl()
|
||||
|
||||
// Create a locations object to provide temporary locations for bridge data during the test.
|
||||
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
|
||||
|
||||
// Run the tests.
|
||||
tests(
|
||||
ctx,
|
||||
server,
|
||||
bridge.NewTestDialer(),
|
||||
locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name"),
|
||||
vaultKey,
|
||||
)
|
||||
tests(ctx, server, netCtl, locations, vaultKey)
|
||||
}
|
||||
|
||||
// withEnv creates the full test environment and runs the tests.
|
||||
func withEnv(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) {
|
||||
// Add test user.
|
||||
_, _, err := server.CreateUser(username, username+"@pm.me", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a random vault key.
|
||||
vaultKey, err := crypto.RandomToken(32)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a context used for the test.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a net controller so we can simulate network connectivity issues.
|
||||
netCtl := liteapi.NewNetCtl()
|
||||
|
||||
// Create a locations object to provide temporary locations for bridge data during the test.
|
||||
locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name")
|
||||
|
||||
// Run the tests.
|
||||
tests(ctx, netCtl, locations, vaultKey)
|
||||
}
|
||||
|
||||
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
|
||||
func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
|
||||
func withBridge(
|
||||
t *testing.T,
|
||||
ctx context.Context,
|
||||
apiURL string,
|
||||
netCtl *liteapi.NetCtl,
|
||||
locator bridge.Locator,
|
||||
vaultKey []byte,
|
||||
tests func(*bridge.Bridge, *bridge.Mocks),
|
||||
) {
|
||||
// Create the mock objects used in the tests.
|
||||
mocks := bridge.NewMocks(t, dialer, v2_3_0, v2_3_0)
|
||||
mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
|
||||
defer mocks.Close()
|
||||
|
||||
// Bridge will enable the proxy by default at startup.
|
||||
mocks.ProxyDialer.EXPECT().AllowProxy()
|
||||
mocks.ProxyCtl.EXPECT().AllowProxy()
|
||||
|
||||
// Get the path to the vault.
|
||||
vaultDir, err := locator.ProvideSettingsPath()
|
||||
@ -380,7 +415,18 @@ func withBridge(t *testing.T, ctx context.Context, apiURL string, dialer *bridge
|
||||
require.NoError(t, vault.SetSMTPPort(0))
|
||||
|
||||
// Create a new bridge.
|
||||
bridge, err := bridge.New(apiURL, locator, vault, useragent.New(), mocks.TLSReporter, mocks.ProxyDialer, mocks.Autostarter, mocks.Updater, v2_3_0)
|
||||
bridge, err := bridge.New(
|
||||
apiURL,
|
||||
locator,
|
||||
vault,
|
||||
useragent.New(),
|
||||
mocks.TLSReporter,
|
||||
liteapi.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
|
||||
mocks.ProxyCtl,
|
||||
mocks.Autostarter,
|
||||
mocks.Updater,
|
||||
v2_3_0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close the bridge when done.
|
||||
|
||||
@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon"
|
||||
@ -33,6 +35,22 @@ func (bridge *Bridge) serveIMAP() error {
|
||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(imapListener.Addr().String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get IMAP listener address: %w", err)
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert IMAP listener port to int: %w", err)
|
||||
}
|
||||
|
||||
if portInt != bridge.vault.GetIMAPPort() {
|
||||
if err := bridge.vault.SetIMAPPort(portInt); err != nil {
|
||||
return fmt.Errorf("failed to update IMAP port in vault: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
for err := range bridge.imapServer.GetErrorCh() {
|
||||
logrus.WithError(err).Error("IMAP server error")
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@ -15,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type Mocks struct {
|
||||
ProxyDialer *mocks.MockProxyDialer
|
||||
ProxyCtl *mocks.MockProxyController
|
||||
TLSReporter *mocks.MockTLSReporter
|
||||
TLSIssueCh chan struct{}
|
||||
|
||||
@ -23,11 +19,11 @@ type Mocks struct {
|
||||
Autostarter *mocks.MockAutostarter
|
||||
}
|
||||
|
||||
func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Version) *Mocks {
|
||||
func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
|
||||
ctl := gomock.NewController(tb)
|
||||
|
||||
mocks := &Mocks{
|
||||
ProxyDialer: mocks.NewMockProxyDialer(ctl),
|
||||
ProxyCtl: mocks.NewMockProxyController(ctl),
|
||||
TLSReporter: mocks.NewMockTLSReporter(ctl),
|
||||
TLSIssueCh: make(chan struct{}),
|
||||
|
||||
@ -35,41 +31,14 @@ func NewMocks(tb testing.TB, dialer *TestDialer, version, minAuto *semver.Versio
|
||||
Autostarter: mocks.NewMockAutostarter(ctl),
|
||||
}
|
||||
|
||||
// When using the proxy dialer, we want to use the test dialer.
|
||||
mocks.ProxyDialer.EXPECT().DialTLSContext(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return dialer.DialTLSContext(ctx, network, address)
|
||||
}).AnyTimes()
|
||||
|
||||
// When getting the TLS issue channel, we want to return the test channel.
|
||||
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
|
||||
|
||||
return mocks
|
||||
}
|
||||
|
||||
type TestDialer struct {
|
||||
canDial bool
|
||||
}
|
||||
|
||||
func NewTestDialer() *TestDialer {
|
||||
return &TestDialer{
|
||||
canDial: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
if !d.canDial {
|
||||
return nil, errors.New("cannot dial")
|
||||
}
|
||||
|
||||
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
func (d *TestDialer) SetCanDial(canDial bool) {
|
||||
d.canDial = canDial
|
||||
func (mocks *Mocks) Close() {
|
||||
close(mocks.TLSIssueCh)
|
||||
}
|
||||
|
||||
type TestLocationsProvider struct {
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter)
|
||||
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyController,Autostarter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@ -49,66 +47,51 @@ func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
|
||||
}
|
||||
|
||||
// MockProxyDialer is a mock of ProxyDialer interface.
|
||||
type MockProxyDialer struct {
|
||||
// MockProxyController is a mock of ProxyController interface.
|
||||
type MockProxyController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProxyDialerMockRecorder
|
||||
recorder *MockProxyControllerMockRecorder
|
||||
}
|
||||
|
||||
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer.
|
||||
type MockProxyDialerMockRecorder struct {
|
||||
mock *MockProxyDialer
|
||||
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
|
||||
type MockProxyControllerMockRecorder struct {
|
||||
mock *MockProxyController
|
||||
}
|
||||
|
||||
// NewMockProxyDialer creates a new mock instance.
|
||||
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer {
|
||||
mock := &MockProxyDialer{ctrl: ctrl}
|
||||
mock.recorder = &MockProxyDialerMockRecorder{mock}
|
||||
// NewMockProxyController creates a new mock instance.
|
||||
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
|
||||
mock := &MockProxyController{ctrl: ctrl}
|
||||
mock.recorder = &MockProxyControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder {
|
||||
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method.
|
||||
func (m *MockProxyDialer) AllowProxy() {
|
||||
func (m *MockProxyController) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy.
|
||||
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call {
|
||||
func (mr *MockProxyControllerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DialTLSContext mocks base method.
|
||||
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(net.Conn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DialTLSContext indicates an expected call of DialTLSContext.
|
||||
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyController)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method.
|
||||
func (m *MockProxyDialer) DisallowProxy() {
|
||||
func (m *MockProxyController) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy.
|
||||
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
func (mr *MockProxyControllerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyController)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// MockAutostarter is a mock of Autostarter interface.
|
||||
|
||||
@ -138,9 +138,9 @@ func (bridge *Bridge) GetProxyAllowed() bool {
|
||||
|
||||
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
|
||||
if allowed {
|
||||
bridge.proxyDialer.AllowProxy()
|
||||
bridge.proxyCtl.AllowProxy()
|
||||
} else {
|
||||
bridge.proxyDialer.DisallowProxy()
|
||||
bridge.proxyCtl.DisallowProxy()
|
||||
}
|
||||
|
||||
return bridge.vault.SetProxyAllowed(allowed)
|
||||
|
||||
@ -7,12 +7,13 @@ import (
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_Settings_GluonDir(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Create a user.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
@ -34,8 +35,8 @@ func TestBridge_Settings_GluonDir(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_IMAPPort(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
curPort := bridge.GetIMAPPort()
|
||||
|
||||
// Set the port to 1144.
|
||||
@ -51,8 +52,8 @@ func TestBridge_Settings_IMAPPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, IMAP SSL is disabled.
|
||||
require.False(t, bridge.GetIMAPSSL())
|
||||
|
||||
@ -66,8 +67,8 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_SMTPPort(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
curPort := bridge.GetSMTPPort()
|
||||
|
||||
// Set the port to 1024.
|
||||
@ -84,8 +85,8 @@ func TestBridge_Settings_SMTPPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, SMTP SSL is disabled.
|
||||
require.False(t, bridge.GetSMTPSSL())
|
||||
|
||||
@ -99,13 +100,13 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_Proxy(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, proxy is allowed.
|
||||
require.True(t, bridge.GetProxyAllowed())
|
||||
|
||||
// Disallow proxy.
|
||||
mocks.ProxyDialer.EXPECT().DisallowProxy()
|
||||
mocks.ProxyCtl.EXPECT().DisallowProxy()
|
||||
require.NoError(t, bridge.SetProxyAllowed(false))
|
||||
|
||||
// Get the new setting.
|
||||
@ -115,8 +116,8 @@ func TestBridge_Settings_Proxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_Autostart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, autostart is disabled.
|
||||
require.False(t, bridge.GetAutostart())
|
||||
|
||||
@ -131,8 +132,8 @@ func TestBridge_Settings_Autostart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_FirstStart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, first start is true.
|
||||
require.True(t, bridge.GetFirstStart())
|
||||
|
||||
@ -146,8 +147,8 @@ func TestBridge_Settings_FirstStart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, first start is true.
|
||||
require.True(t, bridge.GetFirstStartGUI())
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ package bridge
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/emersion/go-sasl"
|
||||
@ -22,6 +24,22 @@ func (bridge *Bridge) serveSMTP() error {
|
||||
}
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(smtpListener.Addr().String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get SMTP listener address: %w", err)
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert SMTP listener port to int: %w", err)
|
||||
}
|
||||
|
||||
if portInt != bridge.vault.GetSMTPPort() {
|
||||
if err := bridge.vault.SetSMTPPort(portInt); err != nil {
|
||||
return fmt.Errorf("failed to update SMTP port in vault: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
130
internal/bridge/sync_test.go
Normal file
130
internal/bridge/sync_test.go
Normal file
@ -0,0 +1,130 @@
|
||||
package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/emersion/go-imap/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_Sync(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
numMsg := 1 << 10
|
||||
|
||||
withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder)
|
||||
require.NoError(t, err)
|
||||
|
||||
literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml"))
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < numMsg; i++ {
|
||||
messageID, err := s.CreateMessage(userID, addrID, literal, liteapi.MessageFlagReceived, false, false)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.LabelMessage(userID, messageID, labelID))
|
||||
}
|
||||
|
||||
var read uint64
|
||||
|
||||
netCtl.OnRead(func(b []byte) {
|
||||
read += uint64(len(b))
|
||||
})
|
||||
|
||||
// The initial user should be fully synced.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
syncCh, done := bridge.GetEvents(events.SyncFinished{})
|
||||
defer done()
|
||||
|
||||
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
|
||||
})
|
||||
|
||||
// If we then connect an IMAP client, it should see all the messages.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||
defer client.Logout()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(numMsg), status.Messages)
|
||||
})
|
||||
|
||||
// Now let's remove the user and simulate a network error.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
})
|
||||
|
||||
// Pretend we can only sync 2/3 of the original messages.
|
||||
netCtl.SetReadLimit(2 * read / 3)
|
||||
|
||||
// Login the user; its sync should fail.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
syncCh, done := bridge.GetEvents(events.SyncFailed{})
|
||||
defer done()
|
||||
|
||||
userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, userID, (<-syncCh).(events.SyncFailed).UserID)
|
||||
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||
defer client.Logout()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Less(t, status.Messages, uint32(numMsg))
|
||||
})
|
||||
|
||||
// Remove the network limit, allowing the sync to finish.
|
||||
netCtl.SetReadLimit(0)
|
||||
|
||||
// Login the user; its sync should now finish.
|
||||
// If we then connect an IMAP client, it should eventually see all the messages.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
syncCh, done := bridge.GetEvents(events.SyncFinished{})
|
||||
defer done()
|
||||
|
||||
require.Equal(t, userID, (<-syncCh).(events.SyncFinished).UserID)
|
||||
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||
defer client.Logout()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(numMsg), status.Messages)
|
||||
})
|
||||
})
|
||||
}
|
||||
6
internal/bridge/testdata/text-plain.eml
vendored
Normal file
6
internal/bridge/testdata/text-plain.eml
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
To: recipient@pm.me
|
||||
From: sender@pm.me
|
||||
Subject: Test
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
|
||||
Test
|
||||
@ -1,9 +1,6 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
)
|
||||
|
||||
@ -21,17 +18,15 @@ type Identifier interface {
|
||||
SetPlatform(platform string)
|
||||
}
|
||||
|
||||
type TLSReporter interface {
|
||||
GetTLSIssueCh() <-chan struct{}
|
||||
}
|
||||
|
||||
type ProxyDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
type ProxyController interface {
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
}
|
||||
|
||||
type TLSReporter interface {
|
||||
GetTLSIssueCh() <-chan struct{}
|
||||
}
|
||||
|
||||
type Autostarter interface {
|
||||
Enable() error
|
||||
Disable() error
|
||||
|
||||
@ -18,6 +18,9 @@ func (bridge *Bridge) watchForUpdates() error {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-bridge.stopCh:
|
||||
return
|
||||
|
||||
case <-bridge.updateCheckCh:
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/go-resty/resty/v2"
|
||||
@ -82,76 +83,76 @@ func (bridge *Bridge) LoginUser(
|
||||
) (string, error) {
|
||||
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to create new API client: %w", err)
|
||||
}
|
||||
|
||||
if _, ok := bridge.users[auth.UserID]; ok {
|
||||
return "", ErrUserAlreadyLoggedIn
|
||||
}
|
||||
userID, err := try.CatchVal(
|
||||
func() (string, error) {
|
||||
if _, ok := bridge.users[auth.UserID]; ok {
|
||||
return "", ErrUserAlreadyLoggedIn
|
||||
}
|
||||
|
||||
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
||||
totp, err := getTOTP()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
||||
totp, err := getTOTP()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get TOTP: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
|
||||
return "", fmt.Errorf("failed to authorize 2FA: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var keyPass []byte
|
||||
var keyPass []byte
|
||||
|
||||
if auth.PasswordMode == liteapi.TwoPasswordMode {
|
||||
pass, err := getKeyPass()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if auth.PasswordMode == liteapi.TwoPasswordMode {
|
||||
userKeyPass, err := getKeyPass()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get key password: %w", err)
|
||||
}
|
||||
|
||||
keyPass = pass
|
||||
} else {
|
||||
keyPass = password
|
||||
}
|
||||
keyPass = userKeyPass
|
||||
} else {
|
||||
keyPass = password
|
||||
}
|
||||
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
return bridge.loginUser(ctx, client, auth.UID, auth.RefreshToken, keyPass)
|
||||
},
|
||||
func() error {
|
||||
return client.AuthDelete(ctx)
|
||||
},
|
||||
func() error {
|
||||
bridge.deleteUser(ctx, auth.UserID)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to login user: %w", err)
|
||||
}
|
||||
|
||||
salts, err := client.GetSalts(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return auth.UserID, nil
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// LogoutUser logs out the given user.
|
||||
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
||||
return bridge.logoutUser(ctx, userID, true, false)
|
||||
if err := bridge.logoutUser(ctx, userID); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes the given user.
|
||||
// If it is authorized, it is logged out first.
|
||||
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
||||
if bridge.users[userID] != nil {
|
||||
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
bridge.deleteUser(ctx, userID)
|
||||
|
||||
bridge.publish(events.UserDeleted{
|
||||
UserID: userID,
|
||||
@ -193,53 +194,91 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadUsers loads authorized users from the vault.
|
||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
||||
for _, userID := range bridge.vault.GetUserIDs() {
|
||||
user, err := bridge.vault.GetUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get API user: %w", err)
|
||||
}
|
||||
|
||||
salts, err := client.GetSalts(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get key salts: %w", err)
|
||||
}
|
||||
|
||||
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to salt key password: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil {
|
||||
return "", fmt.Errorf("failed to add bridge user: %w", err)
|
||||
}
|
||||
|
||||
return apiUser.ID, nil
|
||||
}
|
||||
|
||||
// loadUsers is a loop that, when polled, attempts to load authorized users from the vault.
|
||||
func (bridge *Bridge) loadUsers() error {
|
||||
return bridge.vault.ForUser(func(user *vault.User) error {
|
||||
if _, ok := bridge.users[user.UserID()]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if user.AuthUID() == "" {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.loadUser(ctx, user); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load connected user")
|
||||
|
||||
if err := bridge.loadUser(user); err != nil {
|
||||
if _, ok := err.(*resty.ResponseError); ok {
|
||||
if err := bridge.vault.ClearUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault")
|
||||
|
||||
if err := user.Clear(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to clear user")
|
||||
}
|
||||
} else {
|
||||
logrus.WithError(err).Error("Failed to load connected user")
|
||||
}
|
||||
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
bridge.publish(events.UserLoaded{
|
||||
UserID: user.UserID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
||||
// loadUser loads an existing user from the vault.
|
||||
func (bridge *Bridge) loadUser(user *vault.User) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create API client: %w", err)
|
||||
}
|
||||
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
if err := try.Catch(
|
||||
func() error {
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass())
|
||||
},
|
||||
func() error {
|
||||
return client.AuthDelete(ctx)
|
||||
},
|
||||
func() error {
|
||||
return bridge.logoutUser(ctx, user.UserID())
|
||||
},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to load user: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: user.UserID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -304,10 +343,6 @@ func (bridge *Bridge) addUser(
|
||||
return nil
|
||||
})
|
||||
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -363,54 +398,6 @@ func (bridge *Bridge) addExistingUser(
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// logoutUser closes and removes the user with the given ID.
|
||||
// If withAPI is true, the user will additionally be logged out from API.
|
||||
// If withFiles is true, the user's files will be deleted.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
return fmt.Errorf("failed to remove SMTP user: %w", err)
|
||||
}
|
||||
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if withAPI {
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.vault.ClearUser(userID); err != nil {
|
||||
return fmt.Errorf("failed to clear user: %w", err)
|
||||
}
|
||||
|
||||
if withFiles {
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
delete(bridge.users, userID)
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addIMAPUser connects the given user to gluon.
|
||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
@ -438,6 +425,65 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// logoutUser logs the given user out from bridge.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||
}
|
||||
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
|
||||
delete(bridge.users, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteUser deletes the given user from bridge.
|
||||
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
||||
if user, ok := bridge.users[userID]; ok {
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||
}
|
||||
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to delete user from vault")
|
||||
}
|
||||
|
||||
delete(bridge.users, userID)
|
||||
}
|
||||
|
||||
// getUserInfo returns information about a disconnected user.
|
||||
func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
|
||||
return UserInfo{
|
||||
|
||||
@ -27,7 +27,7 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
|
||||
}
|
||||
|
||||
case events.UserDeauth:
|
||||
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
|
||||
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -9,17 +9,18 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_WithoutUsers(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
@ -27,8 +28,8 @@ func TestBridge_WithoutUsers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_Login(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
@ -41,8 +42,8 @@ func TestBridge_Login(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginLogoutLogin(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -69,8 +70,8 @@ func TestBridge_LoginLogoutLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeleteLogin(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -97,8 +98,8 @@ func TestBridge_LoginDeleteLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeauthLogin(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -131,10 +132,10 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
|
||||
func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||
const authLife = 2 * time.Second
|
||||
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
s.SetAuthLife(authLife)
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user. Its auth will only be valid for a short time.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -148,11 +149,11 @@ func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_FailToLoad(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
// Login the user.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
@ -160,7 +161,7 @@ func TestBridge_FailToLoad(t *testing.T) {
|
||||
require.NoError(t, s.RevokeUser(userID))
|
||||
|
||||
// When bridge starts, the user will not be logged in.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
@ -168,25 +169,27 @@ func TestBridge_FailToLoad(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoadWithoutInternet(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
// Login the user.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
// Simulate loss of internet connection.
|
||||
dialer.SetCanDial(false)
|
||||
netCtl.Disable()
|
||||
|
||||
// Start bridge without internet.
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Initially, users are not connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Simulate internet connection.
|
||||
dialer.SetCanDial(true)
|
||||
netCtl.Enable()
|
||||
|
||||
// The user will eventually be connected.
|
||||
require.Eventually(t, func() bool {
|
||||
@ -197,16 +200,14 @@ func TestBridge_LoadWithoutInternet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginRestart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still connected.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
@ -214,10 +215,10 @@ func TestBridge_LoginRestart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -225,7 +226,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still disconnected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
@ -234,10 +235,10 @@ func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -245,7 +246,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still gone.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
@ -253,13 +254,69 @@ func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_FailLoginRecover(t *testing.T) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var read uint64
|
||||
|
||||
netCtl.OnRead(func(b []byte) {
|
||||
read += uint64(len(b))
|
||||
})
|
||||
|
||||
// Log the user in and record how much data was read.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
})
|
||||
|
||||
// Simulate a partial read.
|
||||
netCtl.SetReadLimit(read / 2)
|
||||
|
||||
// We should fail to log the user in because we can't fully read its data.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Error(t, getErr(bridge.LoginUser(ctx, username, password, nil, nil)))
|
||||
|
||||
// There should be no users.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_FailLoadRecover(t *testing.T) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
var read uint64
|
||||
|
||||
netCtl.OnRead(func(b []byte) {
|
||||
read += uint64(len(b))
|
||||
})
|
||||
|
||||
// Start bridge and record how much data was read.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// ...
|
||||
})
|
||||
|
||||
// Simulate a partial read.
|
||||
netCtl.SetReadLimit(read / 2)
|
||||
|
||||
// We should fail to load the user; it should be disconnected.
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userIDs := bridge.GetUserIDs()
|
||||
|
||||
require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_BridgePass(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
var pass []byte
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
@ -276,7 +333,7 @@ func TestBridge_BridgePass(t *testing.T) {
|
||||
require.Equal(t, pass, pass)
|
||||
})
|
||||
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The bridge should load the user.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
@ -288,8 +345,8 @@ func TestBridge_BridgePass(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBridge_AddressMode(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, ctx, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
@ -313,3 +370,8 @@ func TestBridge_AddressMode(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// getErr returns the error that was passed to it.
|
||||
func getErr[T any](val T, err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -2,6 +2,12 @@ package events
|
||||
|
||||
import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
|
||||
type UserLoaded struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type UserLoggedIn struct {
|
||||
eventBase
|
||||
|
||||
|
||||
@ -64,4 +64,5 @@ func (service *FocusService) GetRaiseCh() <-chan struct{} {
|
||||
// Close closes the service.
|
||||
func (service *FocusService) Close() {
|
||||
service.server.Stop()
|
||||
close(service.raiseCh)
|
||||
}
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
package pool
|
||||
|
||||
import "context"
|
||||
|
||||
type job[In, Out any] struct {
|
||||
ctx context.Context
|
||||
req In
|
||||
|
||||
res chan Out
|
||||
err chan error
|
||||
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] {
|
||||
return &job[In, Out]{
|
||||
ctx: ctx,
|
||||
req: req,
|
||||
res: make(chan Out),
|
||||
err: make(chan error),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (job *job[In, Out]) result() (Out, error) {
|
||||
return <-job.res, <-job.err
|
||||
}
|
||||
|
||||
func (job *job[In, Out]) postSuccess(res Out) {
|
||||
close(job.err)
|
||||
job.res <- res
|
||||
}
|
||||
|
||||
func (job *job[In, Out]) postFailure(err error) {
|
||||
close(job.res)
|
||||
job.err <- err
|
||||
}
|
||||
|
||||
func (job *job[In, Out]) waitDone() {
|
||||
<-job.done
|
||||
}
|
||||
@ -1,147 +0,0 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
)
|
||||
|
||||
// ErrJobCancelled indicates the job was cancelled.
|
||||
var ErrJobCancelled = errors.New("job cancelled by surrounding context")
|
||||
|
||||
// Pool is a worker pool that handles input of type In and returns results of type Out.
|
||||
type Pool[In comparable, Out any] struct {
|
||||
queue *queue.QueuedChannel[*job[In, Out]]
|
||||
size int
|
||||
}
|
||||
|
||||
// doneFunc must be called to free up pool resources.
|
||||
type doneFunc func()
|
||||
|
||||
// New returns a new pool.
|
||||
func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
|
||||
queue := queue.NewQueuedChannel[*job[In, Out]](0, 0)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
go func() {
|
||||
for job := range queue.GetChannel() {
|
||||
select {
|
||||
case <-job.ctx.Done():
|
||||
job.postFailure(ErrJobCancelled)
|
||||
|
||||
default:
|
||||
res, err := work(job.ctx, job.req)
|
||||
if err != nil {
|
||||
job.postFailure(err)
|
||||
} else {
|
||||
job.postSuccess(res)
|
||||
}
|
||||
|
||||
job.waitDone()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return &Pool[In, Out]{
|
||||
queue: queue,
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred.
|
||||
func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
errList []error
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
for _, req := range reqs {
|
||||
req := req
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
job, done := pool.newJob(ctx, req)
|
||||
defer done()
|
||||
|
||||
res, err := job.result()
|
||||
|
||||
if err := fn(req, res, err); err != nil {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// Cancel ongoing jobs.
|
||||
cancel()
|
||||
|
||||
// Collect the error.
|
||||
errList = append(errList, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// TODO: Join the errors somehow?
|
||||
if len(errList) > 0 {
|
||||
return errList[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessAll submits jobs to the pool. All results are returned once available.
|
||||
func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) (map[In]Out, error) {
|
||||
var (
|
||||
data = make(map[In]Out)
|
||||
lock = sync.Mutex{}
|
||||
)
|
||||
|
||||
if err := pool.Process(ctx, reqs, func(req In, res Out, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
data[req] = res
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ProcessOne submits one job to the pool and returns the result.
|
||||
func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) {
|
||||
job, done := pool.newJob(ctx, req)
|
||||
defer done()
|
||||
|
||||
return job.result()
|
||||
}
|
||||
|
||||
func (pool *Pool[In, Out]) Done() {
|
||||
pool.queue.Close()
|
||||
}
|
||||
|
||||
// newJob submits a job to the pool. It returns a job handle and a DoneFunc.
|
||||
// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
|
||||
// which frees up the worker in the pool for reuse.
|
||||
func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc) {
|
||||
job := newJob[In, Out](ctx, req)
|
||||
|
||||
pool.queue.Enqueue(job)
|
||||
|
||||
return job, func() { close(job.done) }
|
||||
}
|
||||
@ -1,163 +0,0 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPool_NewJob(t *testing.T) {
|
||||
doubler := newDoubler(runtime.NumCPU())
|
||||
|
||||
job1, done1 := doubler.newJob(context.Background(), 1)
|
||||
defer done1()
|
||||
|
||||
job2, done2 := doubler.newJob(context.Background(), 2)
|
||||
defer done2()
|
||||
|
||||
res2, err := job2.result()
|
||||
require.NoError(t, err)
|
||||
|
||||
res1, err := job1.result()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, res1)
|
||||
assert.Equal(t, 4, res2)
|
||||
}
|
||||
|
||||
func TestPool_NewJob_Done(t *testing.T) {
|
||||
// Create a doubler pool with 2 workers.
|
||||
doubler := newDoubler(2)
|
||||
|
||||
// Start two jobs. Don't mark the jobs as done yet.
|
||||
job1, done1 := doubler.newJob(context.Background(), 1)
|
||||
job2, done2 := doubler.newJob(context.Background(), 2)
|
||||
|
||||
// Get the first result.
|
||||
res1, _ := job1.result()
|
||||
assert.Equal(t, 2, res1)
|
||||
|
||||
// Get the first result.
|
||||
res2, _ := job2.result()
|
||||
assert.Equal(t, 4, res2)
|
||||
|
||||
// Additional jobs will wait.
|
||||
job3, _ := doubler.newJob(context.Background(), 3)
|
||||
job4, _ := doubler.newJob(context.Background(), 4)
|
||||
|
||||
// Channel to collect results from jobs 3 and 4.
|
||||
resCh := make(chan int, 2)
|
||||
|
||||
go func() {
|
||||
res, _ := job3.result()
|
||||
resCh <- res
|
||||
}()
|
||||
|
||||
go func() {
|
||||
res, _ := job4.result()
|
||||
resCh <- res
|
||||
}()
|
||||
|
||||
// Mark jobs 1 and 2 as done, freeing up the workers.
|
||||
done1()
|
||||
done2()
|
||||
|
||||
assert.ElementsMatch(t, []int{6, 8}, []int{<-resCh, <-resCh})
|
||||
}
|
||||
|
||||
func TestPool_Process(t *testing.T) {
|
||||
doubler := newDoubler(runtime.NumCPU())
|
||||
|
||||
var (
|
||||
res = make(map[int]int)
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(reqVal, resVal int, err error) error {
|
||||
require.NoError(t, err)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
res[reqVal] = resVal
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
assert.Equal(t, map[int]int{
|
||||
1: 2,
|
||||
2: 4,
|
||||
3: 6,
|
||||
4: 8,
|
||||
5: 10,
|
||||
}, res)
|
||||
}
|
||||
|
||||
func TestPool_Process_Error(t *testing.T) {
|
||||
doubler := newDoublerWithError(runtime.NumCPU())
|
||||
|
||||
assert.Error(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(_ int, _ int, err error) error {
|
||||
return err
|
||||
}))
|
||||
}
|
||||
|
||||
func TestPool_Process_Parallel(t *testing.T) {
|
||||
doubler := newDoubler(runtime.NumCPU(), 100*time.Millisecond)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4}, func(_ int, _ int, err error) error {
|
||||
return nil
|
||||
}))
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPool_ProcessAll(t *testing.T) {
|
||||
doubler := newDoubler(runtime.NumCPU())
|
||||
|
||||
res, err := doubler.ProcessAll(context.Background(), []int{1, 2, 3, 4, 5})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, map[int]int{
|
||||
1: 2,
|
||||
2: 4,
|
||||
3: 6,
|
||||
4: 8,
|
||||
5: 10,
|
||||
}, res)
|
||||
}
|
||||
|
||||
func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] {
|
||||
return New(workers, func(ctx context.Context, req int) (int, error) {
|
||||
if len(delay) > 0 {
|
||||
time.Sleep(delay[0])
|
||||
}
|
||||
|
||||
return 2 * req, nil
|
||||
})
|
||||
}
|
||||
|
||||
func newDoublerWithError(workers int) *Pool[int, int] {
|
||||
return New(workers, func(ctx context.Context, req int) (int, error) {
|
||||
if req%2 == 0 {
|
||||
return 0, errors.New("oops")
|
||||
}
|
||||
|
||||
return 2 * req, nil
|
||||
})
|
||||
}
|
||||
@ -23,7 +23,15 @@ func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool {
|
||||
func (m *Map[Key, Val]) Has(key Key) bool {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
_, ok := m.data[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
@ -37,7 +45,7 @@ func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) GetErr(key Key, fn func(val Val) error) (bool, error) {
|
||||
func (m *Map[Key, Val]) GetErr(key Key, fn func(Val) error) (bool, error) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
@ -56,6 +64,15 @@ func (m *Map[Key, Val]) Set(key Key, val Val) {
|
||||
m.data[key] = val
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
for key, val := range m.data {
|
||||
fn(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Keys(fn func(keys []Key)) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
@ -70,28 +87,52 @@ func (m *Map[Key, Val]) Values(fn func(vals []Val)) {
|
||||
fn(maps.Values(m.data))
|
||||
}
|
||||
|
||||
func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) Ret) (Ret, bool) {
|
||||
func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret, fallback func() Ret) Ret {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
val, ok := m.data[key]
|
||||
if !ok {
|
||||
return *new(Ret), false
|
||||
return fallback()
|
||||
}
|
||||
|
||||
return fn(val), true
|
||||
return fn(val)
|
||||
}
|
||||
|
||||
func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) (Ret, error)) (Ret, bool, error) {
|
||||
func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
val, ok := m.data[key]
|
||||
if !ok {
|
||||
return *new(Ret), false, nil
|
||||
return fallback()
|
||||
}
|
||||
|
||||
ret, err := fn(val)
|
||||
|
||||
return ret, true, err
|
||||
return fn(val)
|
||||
}
|
||||
|
||||
func FindMap[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) Ret, fallback func() Ret) Ret {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
for _, val := range m.data {
|
||||
if cmp(val) {
|
||||
return fn(val)
|
||||
}
|
||||
}
|
||||
|
||||
return fallback()
|
||||
}
|
||||
|
||||
func FindMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
for _, val := range m.data {
|
||||
if cmp(val) {
|
||||
return fn(val)
|
||||
}
|
||||
}
|
||||
|
||||
return fallback()
|
||||
}
|
||||
|
||||
49
internal/try/try.go
Normal file
49
internal/try/try.go
Normal file
@ -0,0 +1,49 @@
|
||||
package try
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Catch tries to execute the `try` function, and if it fails or panics,
|
||||
// it executes the `handlers` functions in order.
|
||||
func Catch(try func() error, handlers ...func() error) error {
|
||||
if _, err := CatchVal(func() (any, error) { return nil, try() }, handlers...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CatchVal tries to execute the `try` function, and if it fails or panics,
|
||||
// it executes the `handlers` functions in order.
|
||||
func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
catch(handlers...)
|
||||
err = fmt.Errorf("panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if res, err = try(); err != nil {
|
||||
catch(handlers...)
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func catch(handlers ...func() error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logrus.WithField("panic", r).Error("Panic in catch")
|
||||
}
|
||||
}()
|
||||
|
||||
for _, handler := range handlers {
|
||||
if err := handler(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to handle error")
|
||||
}
|
||||
}
|
||||
}
|
||||
74
internal/try/try_test.go
Normal file
74
internal/try/try_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package try
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTry(t *testing.T) {
|
||||
res, err := CatchVal(func() (string, error) {
|
||||
return "foo", nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foo", res)
|
||||
}
|
||||
|
||||
func TestTryCatch(t *testing.T) {
|
||||
tryErr := fmt.Errorf("oops")
|
||||
|
||||
res, err := CatchVal(
|
||||
func() (string, error) {
|
||||
return "", tryErr
|
||||
},
|
||||
func() error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.ErrorIs(t, err, tryErr)
|
||||
require.Zero(t, res)
|
||||
}
|
||||
|
||||
func TestTryCatchError(t *testing.T) {
|
||||
tryErr := fmt.Errorf("oops")
|
||||
|
||||
res, err := CatchVal(
|
||||
func() (string, error) {
|
||||
return "", tryErr
|
||||
},
|
||||
func() error {
|
||||
return fmt.Errorf("catch error")
|
||||
},
|
||||
)
|
||||
require.ErrorIs(t, err, tryErr)
|
||||
require.Zero(t, res)
|
||||
}
|
||||
|
||||
func TestTryPanic(t *testing.T) {
|
||||
res, err := CatchVal(
|
||||
func() (string, error) {
|
||||
panic("oops")
|
||||
},
|
||||
func() error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.ErrorContains(t, err, "panic: oops")
|
||||
require.Zero(t, res)
|
||||
}
|
||||
|
||||
func TestTryCatchPanic(t *testing.T) {
|
||||
tryErr := fmt.Errorf("oops")
|
||||
|
||||
res, err := CatchVal(
|
||||
func() (string, error) {
|
||||
return "", tryErr
|
||||
},
|
||||
func() error {
|
||||
panic("oops")
|
||||
},
|
||||
)
|
||||
require.ErrorIs(t, err, tryErr)
|
||||
require.Zero(t, res)
|
||||
}
|
||||
@ -259,7 +259,12 @@ func (user *User) handleMessageEvents(ctx context.Context, messageEvents []litea
|
||||
}
|
||||
|
||||
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
||||
buildRes, err := user.buildRFC822(ctx, event.Message)
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get full message: %w", err)
|
||||
}
|
||||
|
||||
buildRes, err := buildRFC822(ctx, full, user.addrKRs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build RFC822: %w", err)
|
||||
}
|
||||
|
||||
@ -4,14 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/bradenaw/juniper/iterator"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
@ -105,34 +102,38 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.
|
||||
|
||||
func (user *User) syncMessages(ctx context.Context) error {
|
||||
// Determine which messages to sync.
|
||||
metadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
||||
allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get all message metadata: %w", err)
|
||||
}
|
||||
|
||||
// If possible, begin syncing from the last synced message.
|
||||
metadata := allMetadata
|
||||
|
||||
// If possible, begin syncing from one beyond the last synced message.
|
||||
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
|
||||
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
|
||||
return metadata.ID == beginID
|
||||
}); idx >= 0 {
|
||||
metadata = metadata[idx:]
|
||||
metadata = metadata[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Process the metadata, building the messages.
|
||||
buildCh := stream.Chunk(parallel.MapStream(
|
||||
ctx,
|
||||
stream.FromIterator(iterator.Slice(metadata)),
|
||||
runtime.NumCPU()*runtime.NumCPU()/2,
|
||||
runtime.NumCPU()*runtime.NumCPU()/2,
|
||||
user.buildRFC822,
|
||||
buildCh := stream.Chunk(stream.Map(
|
||||
user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
||||
return metadata.ID
|
||||
})...),
|
||||
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
||||
return buildRFC822(ctx, full, user.addrKRs)
|
||||
},
|
||||
), maxBatchSize)
|
||||
defer buildCh.Close()
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher)
|
||||
|
||||
for addrID, updateCh := range user.updateCh {
|
||||
flusher := newFlusher(user.ID(), updateCh, maxUpdateSize)
|
||||
flusher := newFlusher(updateCh, maxUpdateSize)
|
||||
defer flusher.flush(ctx, true)
|
||||
|
||||
flushers[addrID] = flusher
|
||||
@ -142,6 +143,8 @@ func (user *User) syncMessages(ctx context.Context) error {
|
||||
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
|
||||
defer reporter.done()
|
||||
|
||||
var count int
|
||||
|
||||
// Send each update to the appropriate flusher.
|
||||
for {
|
||||
batch, err := buildCh.Next(ctx)
|
||||
@ -170,6 +173,8 @@ func (user *User) syncMessages(ctx context.Context) error {
|
||||
}
|
||||
|
||||
reporter.add(len(batch))
|
||||
|
||||
count += len(batch)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
@ -29,30 +30,20 @@ func defaultJobOpts() message.JobOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) buildRFC822(ctx context.Context, metadata liteapi.MessageMetadata) (*buildRes, error) {
|
||||
msg, err := user.client.GetMessage(ctx, metadata.ID)
|
||||
func buildRFC822(ctx context.Context, full liteapi.FullMessage, addrKRs map[string]*crypto.KeyRing) (*buildRes, error) {
|
||||
literal, err := message.BuildRFC822(addrKRs[full.AddressID], full.Message, full.AttData, defaultJobOpts())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get message %s: %w", metadata.ID, err)
|
||||
return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err)
|
||||
}
|
||||
|
||||
attData, err := user.attPool.ProcessAll(ctx, xslices.Map(msg.Attachments, func(att liteapi.Attachment) string { return att.ID }))
|
||||
update, err := newMessageCreatedUpdate(full.MessageMetadata, literal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attachments for message %s: %w", metadata.ID, err)
|
||||
}
|
||||
|
||||
literal, err := message.BuildRFC822(user.addrKRs[msg.AddressID], msg, attData, defaultJobOpts())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build message %s: %w", metadata.ID, err)
|
||||
}
|
||||
|
||||
update, err := newMessageCreatedUpdate(metadata, literal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", metadata.ID, err)
|
||||
return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", full.ID, err)
|
||||
}
|
||||
|
||||
return &buildRes{
|
||||
messageID: metadata.ID,
|
||||
addressID: metadata.AddressID,
|
||||
messageID: full.ID,
|
||||
addressID: full.AddressID,
|
||||
update: update,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -9,21 +9,19 @@ import (
|
||||
)
|
||||
|
||||
type flusher struct {
|
||||
userID string
|
||||
updateCh *queue.QueuedChannel[imap.Update]
|
||||
updates []*imap.MessageCreated
|
||||
|
||||
updates []*imap.MessageCreated
|
||||
maxChunkSize int
|
||||
curChunkSize int
|
||||
maxUpdateSize int
|
||||
curChunkSize int
|
||||
|
||||
pushLock sync.Mutex
|
||||
}
|
||||
|
||||
func newFlusher(userID string, updateCh *queue.QueuedChannel[imap.Update], maxChunkSize int) *flusher {
|
||||
func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *flusher {
|
||||
return &flusher{
|
||||
userID: userID,
|
||||
updateCh: updateCh,
|
||||
maxChunkSize: maxChunkSize,
|
||||
updateCh: updateCh,
|
||||
maxUpdateSize: maxUpdateSize,
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,20 +31,18 @@ func (f *flusher) push(ctx context.Context, update *imap.MessageCreated) {
|
||||
|
||||
f.updates = append(f.updates, update)
|
||||
|
||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
|
||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxUpdateSize {
|
||||
f.flush(ctx, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) flush(ctx context.Context, wait bool) {
|
||||
if len(f.updates) == 0 {
|
||||
return
|
||||
if len(f.updates) > 0 {
|
||||
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
|
||||
f.updates = nil
|
||||
f.curChunkSize = 0
|
||||
}
|
||||
|
||||
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
|
||||
f.updates = nil
|
||||
f.curChunkSize = 0
|
||||
|
||||
if wait {
|
||||
update := imap.NewNoop()
|
||||
defer update.WaitContext(ctx)
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
@ -14,7 +13,6 @@ import (
|
||||
"github.com/ProtonMail/gluon/wait"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
@ -31,7 +29,6 @@ var (
|
||||
type User struct {
|
||||
vault *vault.User
|
||||
client *liteapi.Client
|
||||
attPool *pool.Pool[string, []byte]
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
apiUser *safe.Type[liteapi.User]
|
||||
@ -91,7 +88,6 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
user := &User{
|
||||
vault: encVault,
|
||||
client: client,
|
||||
attPool: pool.New(runtime.NumCPU(), client.GetAttachment),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
|
||||
apiUser: safe.NewType(apiUser),
|
||||
@ -123,7 +119,7 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
|
||||
// If we haven't synced yet, do it first.
|
||||
// If it fails, we don't start the event loop.
|
||||
// Oterwise, begin processing API events, logging any errors that occur.
|
||||
// Otherwise, begin processing API events, logging any errors that occur.
|
||||
go func() {
|
||||
if status := user.vault.SyncStatus(); !status.HasMessages {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
@ -336,8 +332,17 @@ func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
// If withVault is true, the user's vault is also cleared.
|
||||
func (user *User) Logout(ctx context.Context) error {
|
||||
return user.client.AuthDelete(ctx)
|
||||
if err := user.client.AuthDelete(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete auth: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.Clear(); err != nil {
|
||||
return fmt.Errorf("failed to clear vault: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes ongoing connections and cleans up resources.
|
||||
@ -345,9 +350,6 @@ func (user *User) Close() error {
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
|
||||
// Close the attachment pool.
|
||||
user.attPool.Done()
|
||||
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
|
||||
|
||||
@ -89,13 +89,13 @@ func withAPI(t *testing.T, ctx context.Context, fn func(context.Context, *server
|
||||
func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) {
|
||||
var addrIDs []string
|
||||
|
||||
userID, addrID, err := s.CreateUser(username, password, emails[0])
|
||||
userID, addrID, err := s.CreateUser(username, emails[0], []byte(password))
|
||||
require.NoError(t, err)
|
||||
|
||||
addrIDs = append(addrIDs, addrID)
|
||||
|
||||
for _, email := range emails[1:] {
|
||||
addrID, err := s.CreateAddress(userID, email, password)
|
||||
addrID, err := s.CreateAddress(userID, email, []byte(password))
|
||||
require.NoError(t, err)
|
||||
|
||||
addrIDs = append(addrIDs, addrID)
|
||||
|
||||
@ -138,3 +138,12 @@ func (user *User) SetEventID(eventID string) error {
|
||||
data.EventID = eventID
|
||||
})
|
||||
}
|
||||
|
||||
// Clear clears the user's auth secrets.
|
||||
func (user *User) Clear() error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.AuthUID = ""
|
||||
data.AuthRef = ""
|
||||
data.KeyPass = nil
|
||||
})
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ func TestUser_Clear(t *testing.T) {
|
||||
require.Equal(t, "keyPass", string(user.KeyPass()))
|
||||
|
||||
// Clear the user's auth information.
|
||||
require.NoError(t, s.ClearUser("userID"))
|
||||
require.NoError(t, user.Clear())
|
||||
|
||||
// Check the user's cleared auth information.
|
||||
require.Empty(t, user.AuthUID())
|
||||
|
||||
@ -107,14 +107,6 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
|
||||
return vault.GetUser(userID)
|
||||
}
|
||||
|
||||
func (vault *Vault) ClearUser(userID string) error {
|
||||
return vault.modUser(userID, func(data *UserData) {
|
||||
data.AuthUID = ""
|
||||
data.AuthRef = ""
|
||||
data.KeyPass = nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteUser removes the given user from the vault.
|
||||
func (vault *Vault) DeleteUser(userID string) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
|
||||
@ -12,8 +12,8 @@ type API interface {
|
||||
GetHostURL() string
|
||||
AddCallWatcher(func(server.Call), ...string)
|
||||
|
||||
CreateUser(username, password, address string) (string, string, error)
|
||||
CreateAddress(userID, address, password string) (string, error)
|
||||
CreateUser(username, address string, password []byte) (string, string, error)
|
||||
CreateAddress(userID, address string, password []byte) (string, error)
|
||||
RemoveAddress(userID, addrID string) error
|
||||
RevokeUser(userID string) error
|
||||
|
||||
|
||||
@ -2,17 +2,19 @@ package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
func (t *testCtx) startBridge() error {
|
||||
// Bridge will enable the proxy by default at startup.
|
||||
t.mocks.ProxyDialer.EXPECT().AllowProxy()
|
||||
t.mocks.ProxyCtl.EXPECT().AllowProxy()
|
||||
|
||||
// Get the path to the vault.
|
||||
vaultDir, err := t.locator.ProvideSettingsPath()
|
||||
@ -41,7 +43,8 @@ func (t *testCtx) startBridge() error {
|
||||
vault,
|
||||
useragent.New(),
|
||||
t.mocks.TLSReporter,
|
||||
t.mocks.ProxyDialer,
|
||||
liteapi.NewDialer(t.netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
|
||||
t.mocks.ProxyCtl,
|
||||
t.mocks.Autostarter,
|
||||
t.mocks.Updater,
|
||||
t.version,
|
||||
|
||||
@ -24,7 +24,7 @@ type testCtx struct {
|
||||
// These are the objects supporting the test.
|
||||
dir string
|
||||
api API
|
||||
dialer *bridge.TestDialer
|
||||
netCtl *liteapi.NetCtl
|
||||
locator *locations.Locations
|
||||
storeKey []byte
|
||||
version *semver.Version
|
||||
@ -76,15 +76,13 @@ type smtpClient struct {
|
||||
func newTestCtx(tb testing.TB) *testCtx {
|
||||
dir := tb.TempDir()
|
||||
|
||||
dialer := bridge.NewTestDialer()
|
||||
|
||||
ctx := &testCtx{
|
||||
dir: dir,
|
||||
api: newFakeAPI(),
|
||||
dialer: dialer,
|
||||
netCtl: liteapi.NewNetCtl(),
|
||||
locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"),
|
||||
storeKey: []byte("super-secret-store-key"),
|
||||
mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion),
|
||||
mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion),
|
||||
version: defaultVersion,
|
||||
|
||||
userIDByName: make(map[string]string),
|
||||
|
||||
@ -38,12 +38,12 @@ func (s *scenario) itFailsWithError(wantErr string) error {
|
||||
}
|
||||
|
||||
func (s *scenario) internetIsTurnedOff() error {
|
||||
s.t.dialer.SetCanDial(false)
|
||||
s.t.netCtl.SetCanDial(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) internetIsTurnedOn() error {
|
||||
s.t.dialer.SetCanDial(true)
|
||||
s.t.netCtl.SetCanDial(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
|
||||
func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error {
|
||||
// Create the user.
|
||||
userID, addrID, err := s.t.api.CreateUser(username, password, username)
|
||||
userID, addrID, err := s.t.api.CreateUser(username, username, []byte(password))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -34,7 +34,7 @@ func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, passwor
|
||||
func (s *scenario) theAccountHasAdditionalAddress(username, address string) error {
|
||||
userID := s.t.getUserID(username)
|
||||
|
||||
addrID, err := s.t.api.CreateAddress(userID, address, s.t.getUserPass(userID))
|
||||
addrID, err := s.t.api.CreateAddress(userID, address, []byte(s.t.getUserPass(userID)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user