From 6fbf6d90dc373b03bfbcabc62b04a0c37f1992ac Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Mon, 24 Oct 2022 13:25:11 +0200 Subject: [PATCH] Other: Fix IMAP/SMTP/Login leaks/race conditions Depending on the timing of bridge closure, it was possible for the IMAP/SMTP servers to not have started serving yet. By grouping this in a cancelable goroutine group (*xsync.Group), we mitigate this issue. Further, depending on internet disconnection timing during user login, it was possible for a user to be improperly logged in. This change fixes this and adds test coverage for it. Lastly, depending on timing, certain background tasks (updates check, connectivity ping) could be improperly started or closed. This change groups them in the *xsync.Group as well to be closed properly. --- go.mod | 2 +- go.sum | 4 +- internal/async/context.go | 29 +++ internal/bridge/bridge.go | 328 ++++++++++++++++++-------------- internal/bridge/bridge_test.go | 27 ++- internal/bridge/imap.go | 50 ++--- internal/bridge/settings.go | 28 ++- internal/bridge/smtp.go | 49 +++-- internal/bridge/smtp_backend.go | 11 +- internal/bridge/sync_test.go | 134 +++++++------ internal/bridge/updates.go | 41 +--- internal/bridge/user.go | 119 ++++-------- internal/bridge/user_test.go | 157 +++++++++------ 13 files changed, 518 insertions(+), 461 deletions(-) diff --git a/go.mod b/go.mod index 11cdba69..9c01e323 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557 github.com/Masterminds/semver/v3 v3.1.1 - github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb + github.com/ProtonMail/gluon v0.13.1-0.20221024090110-e36d02c8912a github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a github.com/ProtonMail/go-rfc5322 v0.11.0 github.com/ProtonMail/gopenpgp/v2 v2.4.10 diff --git a/go.sum b/go.sum index 19b37729..700cdd22 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= -github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb h1:TGRkkuOdF3mIxbu5QMp62dJ2WvfQGgZ4MUVsNuRD7sc= -github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= +github.com/ProtonMail/gluon v0.13.1-0.20221024090110-e36d02c8912a h1:kC9+LmLImdfwa+j7Y7YiSUoKnLz8K1/9tgtecrtoY5A= +github.com/ProtonMail/gluon v0.13.1-0.20221024090110-e36d02c8912a/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4= github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= diff --git a/internal/async/context.go b/internal/async/context.go index 6faabd34..ffb80565 100644 --- a/internal/async/context.go +++ b/internal/async/context.go @@ -51,3 +51,32 @@ func (a *Abortable) newCancelCtx(ctx context.Context) context.Context { return ctx } + +// RangeContext iterates over the given channel until the context is canceled or the +// channel is closed. +func RangeContext[T any](ctx context.Context, ch <-chan T, fn func(T)) { + for { + select { + case v, ok := <-ch: + if !ok { + return + } + + fn(v) + + case <-ctx.Done(): + return + } + } +} + +// ForwardContext forwards all values from the src channel to the dst channel until the +// context is canceled or the src channel is closed. +func ForwardContext[T any](ctx context.Context, dst chan<- T, src <-chan T) { + RangeContext(ctx, src, func(v T) { + select { + case dst <- v: + case <-ctx.Done(): + } + }) +} diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index d0baa2fd..49306d99 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -24,18 +24,22 @@ import ( "fmt" "net" "net/http" + "sync" "time" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon" + imapEvents "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/watcher" + "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" "github.com/ProtonMail/proton-bridge/v2/internal/safe" - "github.com/ProtonMail/proton-bridge/v2/internal/try" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/bradenaw/juniper/xslices" + "github.com/bradenaw/juniper/xsync" "github.com/emersion/go-smtp" "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" @@ -48,8 +52,7 @@ type Bridge struct { // users holds authorized users. users *safe.Map[string, *user.User] - loadCh chan struct{} - loadWG try.Group + goLoad func() // api manages user API clients. api *liteapi.Manager @@ -62,14 +65,16 @@ type Bridge struct { // imapServer is the bridge's IMAP server. imapServer *gluon.Server imapListener net.Listener + imapEventCh chan imapEvents.Event // smtpServer is the bridge's SMTP server. - smtpServer *smtp.Server + smtpServer *smtp.Server + smtpListener net.Listener // updater is the bridge's updater. - updater Updater - curVersion *semver.Version - updateCheckCh chan struct{} + updater Updater + goUpdate func() + curVersion *semver.Version // focusService is used to raise the bridge window when needed. focusService *focus.Service @@ -81,7 +86,8 @@ type Bridge struct { locator Locator // watchers holds all registered event watchers. - watchers *safe.Slice[*watcher.Watcher[events.Event]] + watchers []*watcher.Watcher[events.Event] + watchersLock sync.RWMutex // errors contains errors encountered during startup. errors []error @@ -91,10 +97,8 @@ type Bridge struct { logIMAPServer bool logSMTP bool - // stopCh is used to stop ongoing goroutines when the bridge is closed. - stopCh chan struct{} - - closeEventChFn func() + // tasks manages the bridge's goroutines. + tasks *xsync.Group } // New creates a new bridge. @@ -115,6 +119,7 @@ func New( //nolint:funlen logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity logSMTP bool, // whether to log SMTP activity ) (*Bridge, <-chan events.Event, error) { + // api is the user's API manager. api := liteapi.New( liteapi.WithHostURL(apiURL), liteapi.WithAppVersion(constants.AppVersion(curVersion.Original())), @@ -122,63 +127,63 @@ func New( //nolint:funlen liteapi.WithTransport(roundTripper), ) - tlsConfig, err := loadTLSConfig(vault) - if err != nil { - return nil, nil, fmt.Errorf("failed to load TLS config: %w", err) - } + // tasks holds all the bridge's background tasks. + tasks := xsync.NewGroup(context.Background()) - gluonDir, err := getGluonDir(vault) - if err != nil { - return nil, nil, fmt.Errorf("failed to get Gluon directory: %w", err) - } + // imapEventCh forwards IMAP events from gluon instances to the bridge for processing. + imapEventCh := make(chan imapEvents.Event) - imapServer, err := newIMAPServer(gluonDir, curVersion, tlsConfig, logIMAPClient, logIMAPServer) - if err != nil { - return nil, nil, fmt.Errorf("failed to create IMAP server: %w", err) - } + // users holds all the bridge's users. + users := safe.NewMap[string, *user.User](nil) - focusService, err := focus.NewService(curVersion) - if err != nil { - return nil, nil, fmt.Errorf("failed to create focus service: %w", err) - } + // bridge is the bridge. + bridge, err := newBridge( + users, + tasks, + imapEventCh, - bridge := newBridge( - // App stuff locator, vault, autostarter, updater, curVersion, - // API stuff api, identifier, proxyCtl, - - // Service stuff - tlsConfig, - imapServer, - focusService, - - // Logging stuff - logIMAPClient, - logIMAPServer, - logSMTP, + logIMAPClient, logIMAPServer, logSMTP, ) + if err != nil { + return nil, nil, fmt.Errorf("failed to create bridge: %w", err) + } // Get an event channel for all events (individual events can be subscribed to later). - eventCh, closeFn := bridge.GetEvents() - - bridge.closeEventChFn = closeFn + eventCh, _ := bridge.GetEvents() + // Initialize all of bridge's background tasks and operations. if err := bridge.init(tlsReporter); err != nil { return nil, nil, fmt.Errorf("failed to initialize bridge: %w", err) } + // Start serving IMAP. + if err := bridge.serveIMAP(); err != nil { + bridge.PushError(ErrServeIMAP) + } + + // Start serving SMTP. + if err := bridge.serveSMTP(); err != nil { + bridge.PushError(ErrServeSMTP) + } + return bridge, eventCh, nil } +// nolint:funlen func newBridge( + users *safe.Map[string, *user.User], + tasks *xsync.Group, + imapEventCh chan imapEvents.Event, + locator Locator, vault *vault.Vault, autostarter Autostarter, @@ -189,113 +194,146 @@ func newBridge( identifier Identifier, proxyCtl ProxyController, - tlsConfig *tls.Config, - imapServer *gluon.Server, - focusService *focus.Service, logIMAPClient, logIMAPServer, logSMTP bool, -) *Bridge { - bridge := &Bridge{ - vault: vault, +) (*Bridge, error) { + tlsConfig, err := loadTLSConfig(vault) + if err != nil { + return nil, fmt.Errorf("failed to load TLS config: %w", err) + } - users: safe.NewMap[string, *user.User](nil), - loadCh: make(chan struct{}, 1), + gluonDir, err := getGluonDir(vault) + if err != nil { + return nil, fmt.Errorf("failed to get Gluon directory: %w", err) + } + + imapServer, err := newIMAPServer( + gluonDir, + curVersion, + tlsConfig, + logIMAPClient, + logIMAPServer, + imapEventCh, + tasks, + ) + if err != nil { + return nil, fmt.Errorf("failed to create IMAP server: %w", err) + } + + focusService, err := focus.NewService(curVersion) + if err != nil { + return nil, fmt.Errorf("failed to create focus service: %w", err) + } + + return &Bridge{ + vault: vault, + users: users, api: api, proxyCtl: proxyCtl, identifier: identifier, - tlsConfig: tlsConfig, - imapServer: imapServer, + tlsConfig: tlsConfig, + imapServer: imapServer, + imapEventCh: imapEventCh, + smtpServer: newSMTPServer(users, tlsConfig, logSMTP), - updater: updater, - curVersion: curVersion, - updateCheckCh: make(chan struct{}, 1), + updater: updater, + curVersion: curVersion, focusService: focusService, autostarter: autostarter, locator: locator, - watchers: safe.NewSlice[*watcher.Watcher[events.Event]](), - logIMAPClient: logIMAPClient, logIMAPServer: logIMAPServer, logSMTP: logSMTP, - stopCh: make(chan struct{}), - } - - bridge.smtpServer = newSMTPServer(&smtpBackend{bridge}, tlsConfig, logSMTP) - - return bridge + tasks: tasks, + }, nil } +// nolint:funlen func (bridge *Bridge) init(tlsReporter TLSReporter) error { + // Enable or disable the proxy at startup. if bridge.vault.GetProxyAllowed() { bridge.proxyCtl.AllowProxy() } else { bridge.proxyCtl.DisallowProxy() } + // Handle connection up/down events. bridge.api.AddStatusObserver(func(status liteapi.Status) { switch { case status == liteapi.StatusUp: - go bridge.onStatusUp() + bridge.onStatusUp() case status == liteapi.StatusDown: - go bridge.onStatusDown() + bridge.onStatusDown() } }) + // If any call returns a bad version code, we need to update. bridge.api.AddErrorHandler(liteapi.AppVersionBadCode, func() { bridge.publish(events.UpdateForced{}) }) + // Ensure all outgoing headers have the correct user agent. bridge.api.AddPreRequestHook(func(_ *resty.Client, req *resty.Request) error { req.SetHeader("User-Agent", bridge.identifier.GetUserAgent()) return nil }) - go func() { - for range tlsReporter.GetTLSIssueCh() { + // Publish a TLS issue event if a TLS issue is encountered. + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, tlsReporter.GetTLSIssueCh(), func(struct{}) { bridge.publish(events.TLSIssue{}) - } - }() + }) + }) - go func() { - for range bridge.focusService.GetRaiseCh() { + // Publish a raise event if the focus service is called. + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, bridge.focusService.GetRaiseCh(), func(struct{}) { bridge.publish(events.Raise{}) - } - }() + }) + }) - go func() { - for event := range bridge.imapServer.AddWatcher() { + // Handle any IMAP events that are forwarded to the bridge from gluon. + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, bridge.imapEventCh, func(event imapEvents.Event) { bridge.handleIMAPEvent(event) + }) + }) + + // Attempt to lazy load users when triggered. + bridge.goLoad = bridge.tasks.Trigger(func(ctx context.Context) { + if err := bridge.loadUsers(ctx); err != nil { + logrus.WithError(err).Error("Failed to load users") + } else { + bridge.publish(events.AllUsersLoaded{}) } - }() + }) + defer bridge.goLoad() - if err := bridge.serveIMAP(); err != nil { - bridge.PushError(ErrServeIMAP) - } - - if err := bridge.serveSMTP(); err != nil { - bridge.PushError(ErrServeSMTP) - } - - if err := bridge.watchForUpdates(); err != nil { - bridge.PushError(ErrWatchUpdates) - } - - go bridge.loadLoop() + // Check for updates when triggered. + bridge.goUpdate = bridge.tasks.PeriodicOrTrigger(constants.UpdateCheckInterval, 0, func(ctx context.Context) { + version, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel()) + if err != nil { + logrus.WithError(err).Error("Failed to get version info") + } else if err := bridge.handleUpdate(version); err != nil { + logrus.WithError(err).Error("Failed to handle update") + } + }) + defer bridge.goUpdate() return nil } // GetEvents returns a channel of events of the given type. // If no types are supplied, all events are returned. -func (bridge *Bridge) GetEvents(ofType ...events.Event) (<-chan events.Event, func()) { - newWatcher := bridge.addWatcher(ofType...) +func (bridge *Bridge) GetEvents(ofType ...events.Event) (<-chan events.Event, context.CancelFunc) { + watcher := bridge.addWatcher(ofType...) - return newWatcher.GetChannel(), func() { bridge.remWatcher(newWatcher) } + return watcher.GetChannel(), func() { bridge.remWatcher(watcher) } } func (bridge *Bridge) PushError(err error) { @@ -307,20 +345,6 @@ func (bridge *Bridge) GetErrors() []error { } func (bridge *Bridge) Close(ctx context.Context) error { - defer func() { - if bridge.closeEventChFn != nil { - bridge.closeEventChFn() - } - - bridge.closeEventChFn = nil - }() - - // Stop ongoing operations such as connectivity checks. - close(bridge.stopCh) - - // Wait for ongoing user load operations to finish. - bridge.loadWG.Wait() - // Close the IMAP server. if err := bridge.closeIMAP(ctx); err != nil { logrus.WithError(err).Error("Failed to close IMAP server") @@ -336,9 +360,22 @@ func (bridge *Bridge) Close(ctx context.Context) error { user.Close() }) + // Stop all ongoing tasks. + bridge.tasks.Wait() + // Close the focus service. bridge.focusService.Close() + // Close the watchers. + bridge.watchersLock.Lock() + defer bridge.watchersLock.Unlock() + + for _, watcher := range bridge.watchers { + watcher.Close() + } + + bridge.watchers = nil + // Save the last version of bridge that was run. if err := bridge.vault.SetLastVersion(bridge.curVersion); err != nil { logrus.WithError(err).Error("Failed to save last version") @@ -348,35 +385,51 @@ func (bridge *Bridge) Close(ctx context.Context) error { } func (bridge *Bridge) publish(event events.Event) { - bridge.watchers.Iter(func(watcher *watcher.Watcher[events.Event]) { + bridge.watchersLock.RLock() + defer bridge.watchersLock.RUnlock() + + for _, watcher := range bridge.watchers { if watcher.IsWatching(event) { if ok := watcher.Send(event); !ok { logrus.WithField("event", event).Warn("Failed to send event to watcher") } } - }) + } } func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] { - newWatcher := watcher.New(ofType...) + bridge.watchersLock.Lock() + defer bridge.watchersLock.Unlock() - bridge.watchers.Append(newWatcher) + watcher := watcher.New(ofType...) - return newWatcher + bridge.watchers = append(bridge.watchers, watcher) + + return watcher } -func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) { - oldWatcher.Close() - bridge.watchers.Delete(oldWatcher) +func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) { + bridge.watchersLock.Lock() + defer bridge.watchersLock.Unlock() + + idx := xslices.Index(bridge.watchers, watcher) + + if idx < 0 { + return + } + + bridge.watchers = append(bridge.watchers[:idx], bridge.watchers[idx+1:]...) + + watcher.Close() } func (bridge *Bridge) onStatusUp() { bridge.publish(events.ConnStatusUp{}) - bridge.loadCh <- struct{}{} + bridge.goLoad() bridge.users.IterValues(func(user *user.User) { - user.OnStatusUp() + go user.OnStatusUp() }) } @@ -384,35 +437,30 @@ func (bridge *Bridge) onStatusDown() { bridge.publish(events.ConnStatusDown{}) bridge.users.IterValues(func(user *user.User) { - user.OnStatusDown() + go user.OnStatusDown() }) - upCh, done := bridge.GetEvents(events.ConnStatusUp{}) - defer done() + bridge.tasks.Once(func(ctx context.Context) { + backoff := time.Second - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + for { + select { + case <-ctx.Done(): + return - backoff := time.Second + case <-time.After(backoff): + if err := bridge.api.Ping(ctx); err != nil { + logrus.WithError(err).Debug("Failed to ping API, will retry") + } else { + return + } + } - for { - select { - case <-upCh: - return - - case <-bridge.stopCh: - return - - case <-time.After(backoff): - if err := bridge.api.Ping(ctx); err != nil { - logrus.WithError(err).Debug("Failed to ping API") + if backoff < 30*time.Second { + backoff *= 2 } } - - if backoff < 30*time.Second { - backoff *= 2 - } - } + }) } func loadTLSConfig(vault *vault.Vault) (*tls.Config, error) { diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index fbda4876..f15cb135 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -44,6 +44,7 @@ import ( "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server/backend" + "go.uber.org/goleak" ) var ( @@ -54,6 +55,10 @@ var ( v2_4_0 = semver.MustParse("2.4.0") ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, goleak.IgnoreCurrent()) +} + func init() { user.EventPeriod = 100 * time.Millisecond user.EventJitter = 0 @@ -356,17 +361,10 @@ func TestBridge_MissingGluonDir(t *testing.T) { } // withEnv creates the full test environment and runs the tests. -func withEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) { - server := server.New() +func withEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte), opts ...server.Option) { + server := server.New(opts...) defer server.Close() - withEnvServer(t, server, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { - tests(ctx, server, netCtl, locator, vaultKey) - }) -} - -// withEnvServer creates the full test environment and runs the tests. -func withEnvServer(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) { // Add test user. _, _, err := server.CreateUser(username, username+"@pm.me", password) require.NoError(t, err) @@ -386,7 +384,7 @@ func withEnvServer(t *testing.T, server *server.Server, tests func(context.Conte locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name") // Run the tests. - tests(ctx, netCtl, locations, vaultKey) + tests(ctx, server, netCtl, locations, vaultKey) } // withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. @@ -414,10 +412,6 @@ func withBridge( vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey) require.NoError(t, err) - // Let the IMAP and SMTP servers choose random available ports for this test. - require.NoError(t, vault.SetIMAPPort(0)) - require.NoError(t, vault.SetSMTPPort(0)) - // Create a new cookie jar. cookieJar, err := cookies.NewCookieJar(bridge.NewTestCookieJar(), vault) require.NoError(t, err) @@ -446,10 +440,15 @@ func withBridge( false, ) require.NoError(t, err) + require.Empty(t, bridge.GetErrors()) // Wait for bridge to finish loading users. waitForEvent(t, eventCh, events.AllUsersLoaded{}) + // Set random IMAP and SMTP ports for the tests. + require.NoError(t, bridge.SetIMAPPort(0)) + require.NoError(t, bridge.SetSMTPPort(0)) + // Close the bridge when done. defer func() { require.NoError(t, bridge.Close(ctx)) }() diff --git a/internal/bridge/imap.go b/internal/bridge/imap.go index aa5bda3b..de76c93f 100644 --- a/internal/bridge/imap.go +++ b/internal/bridge/imap.go @@ -24,17 +24,16 @@ import ( "fmt" "io" "io/fs" - "net" "os" - "strconv" - - "github.com/ProtonMail/proton-bridge/v2/internal/logging" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon" imapEvents "github.com/ProtonMail/gluon/events" + "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/constants" + "github.com/ProtonMail/proton-bridge/v2/internal/logging" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/bradenaw/juniper/xsync" "github.com/sirupsen/logrus" ) @@ -55,34 +54,16 @@ func (bridge *Bridge) serveIMAP() error { return fmt.Errorf("failed to serve IMAP: %w", err) } - _, port, err := net.SplitHostPort(imapListener.Addr().String()) - if err != nil { - return fmt.Errorf("failed to get IMAP listener address: %w", err) + if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil { + return fmt.Errorf("failed to set IMAP port: %w", err) } - portInt, err := strconv.Atoi(port) - if err != nil { - return fmt.Errorf("failed to convert IMAP listener port to int: %w", err) - } - - if portInt != bridge.vault.GetIMAPPort() { - if err := bridge.vault.SetIMAPPort(portInt); err != nil { - return fmt.Errorf("failed to update IMAP port in vault: %w", err) - } - } - - go func() { - for err := range bridge.imapServer.GetErrorCh() { - logrus.WithError(err).Error("IMAP server error") - } - }() - return nil } func (bridge *Bridge) restartIMAP() error { if err := bridge.imapListener.Close(); err != nil { - logrus.WithError(err).Warn("Failed to close IMAP listener") + return fmt.Errorf("failed to close IMAP listener: %w", err) } return bridge.serveIMAP() @@ -103,12 +84,10 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error { func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) { switch event := event.(type) { case imapEvents.SessionAdded: - if bridge.identifier.HasClient() { - return + if !bridge.identifier.HasClient() { + bridge.identifier.SetClient(defaultClientName, defaultClientVersion) } - bridge.identifier.SetClient(defaultClientName, defaultClientVersion) - case imapEvents.IMAPID: bridge.identifier.SetClient(event.IMAPID.Name, event.IMAPID.Version) } @@ -137,11 +116,14 @@ func getGluonDir(encVault *vault.Vault) (string, error) { return encVault.GetGluonDir(), nil } +// nolint:funlen func newIMAPServer( gluonDir string, version *semver.Version, tlsConfig *tls.Config, logClient, logServer bool, + eventCh chan<- imapEvents.Event, + tasks *xsync.Group, ) (*gluon.Server, error) { if logClient || logServer { log := logrus.WithField("protocol", "IMAP") @@ -186,6 +168,16 @@ func newIMAPServer( return nil, err } + tasks.Once(func(ctx context.Context) { + async.ForwardContext(ctx, eventCh, imapServer.AddWatcher()) + }) + + tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, imapServer.GetErrorCh(), func(err error) { + logrus.WithError(err).Error("IMAP server error") + }) + }) + return imapServer, nil } diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 3cebf0e0..3331b3e7 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -20,6 +20,7 @@ package bridge import ( "context" "fmt" + "net" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/updater" @@ -131,7 +132,15 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error return fmt.Errorf("failed to set new gluon dir: %w", err) } - imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig, bridge.logIMAPClient, bridge.logIMAPServer) + imapServer, err := newIMAPServer( + bridge.vault.GetGluonDir(), + bridge.curVersion, + bridge.tlsConfig, + bridge.logIMAPClient, + bridge.logIMAPServer, + bridge.imapEventCh, + bridge.tasks, + ) if err != nil { return fmt.Errorf("failed to create new IMAP server: %w", err) } @@ -210,7 +219,7 @@ func (bridge *Bridge) SetAutoUpdate(autoUpdate bool) error { return err } - bridge.updateCheckCh <- struct{}{} + bridge.goUpdate() return nil } @@ -228,7 +237,7 @@ func (bridge *Bridge) SetUpdateChannel(channel updater.Channel) error { return err } - bridge.updateCheckCh <- struct{}{} + bridge.goUpdate() return nil } @@ -276,3 +285,16 @@ func (bridge *Bridge) FactoryReset(ctx context.Context) { logrus.WithError(err).Error("Failed to clear data paths") } } + +func getPort(addr net.Addr) int { + switch addr := addr.(type) { + case *net.TCPAddr: + return addr.Port + + case *net.UDPAddr: + return addr.Port + + default: + return 0 + } +} diff --git a/internal/bridge/smtp.go b/internal/bridge/smtp.go index eb79c681..26936107 100644 --- a/internal/bridge/smtp.go +++ b/internal/bridge/smtp.go @@ -18,12 +18,13 @@ package bridge import ( + "context" "crypto/tls" "fmt" - "net" - "strconv" "github.com/ProtonMail/proton-bridge/v2/internal/logging" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" + "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/emersion/go-smtp" @@ -36,26 +37,16 @@ func (bridge *Bridge) serveSMTP() error { return fmt.Errorf("failed to create SMTP listener: %w", err) } - go func() { + bridge.smtpListener = smtpListener + + bridge.tasks.Once(func(ctx context.Context) { if err := bridge.smtpServer.Serve(smtpListener); err != nil { - logrus.WithError(err).Error("SMTP server stopped") + logrus.WithError(err).Debug("SMTP server stopped") } - }() + }) - _, port, err := net.SplitHostPort(smtpListener.Addr().String()) - if err != nil { - return fmt.Errorf("failed to get SMTP listener address: %w", err) - } - - portInt, err := strconv.Atoi(port) - if err != nil { - return fmt.Errorf("failed to convert SMTP listener port to int: %w", err) - } - - if portInt != bridge.vault.GetSMTPPort() { - if err := bridge.vault.SetSMTPPort(portInt); err != nil { - return fmt.Errorf("failed to update SMTP port in vault: %w", err) - } + if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil { + return fmt.Errorf("failed to set IMAP port: %w", err) } return nil @@ -63,26 +54,32 @@ func (bridge *Bridge) serveSMTP() error { func (bridge *Bridge) restartSMTP() error { if err := bridge.closeSMTP(); err != nil { - return err + return fmt.Errorf("failed to close SMTP: %w", err) } - bridge.smtpServer = newSMTPServer(&smtpBackend{bridge}, bridge.tlsConfig, bridge.logSMTP) + bridge.smtpServer = newSMTPServer(bridge.users, bridge.tlsConfig, bridge.logSMTP) return bridge.serveSMTP() } +// We close the listener ourselves even though it's also closed by smtpServer.Close(). +// This is because smtpServer.Serve() is called in a separate goroutine and might be executed +// after we've already closed the server. However, go-smtp has a bug; it blocks on the listener +// even after the server has been closed. So we close the listener ourselves to unblock it. func (bridge *Bridge) closeSMTP() error { - if err := bridge.smtpServer.Close(); err != nil { - logrus.WithError(err).Warn("Failed to close SMTP server") + if err := bridge.smtpListener.Close(); err != nil { + return fmt.Errorf("failed to close SMTP listener: %w", err) } - // Don't close the SMTP listener -- it's closed by the server. + if err := bridge.smtpServer.Close(); err != nil { + logrus.WithError(err).Debug("Failed to close SMTP server") + } return nil } -func newSMTPServer(smtpBackend *smtpBackend, tlsConfig *tls.Config, shouldLog bool) *smtp.Server { - smtpServer := smtp.NewServer(smtpBackend) +func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, shouldLog bool) *smtp.Server { + smtpServer := smtp.NewServer(&smtpBackend{users}) smtpServer.TLSConfig = tlsConfig smtpServer.Domain = constants.Host diff --git a/internal/bridge/smtp_backend.go b/internal/bridge/smtp_backend.go index d768a2a2..cbf870f3 100644 --- a/internal/bridge/smtp_backend.go +++ b/internal/bridge/smtp_backend.go @@ -21,16 +21,17 @@ import ( "fmt" "io" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/emersion/go-smtp" ) type smtpBackend struct { - bridge *Bridge + users *safe.Map[string, *user.User] } type smtpSession struct { - bridge *Bridge + users *safe.Map[string, *user.User] userID string authID string @@ -41,12 +42,12 @@ type smtpSession struct { func (be *smtpBackend) NewSession(_ *smtp.Conn) (smtp.Session, error) { return &smtpSession{ - bridge: be.bridge, + users: be.users, }, nil } func (s *smtpSession) AuthPlain(username, password string) error { - return s.bridge.users.ValuesErr(func(users []*user.User) error { + return s.users.ValuesErr(func(users []*user.User) error { for _, user := range users { addrID, err := user.CheckAuth(username, []byte(password)) if err != nil { @@ -87,7 +88,7 @@ func (s *smtpSession) Rcpt(to string) error { } func (s *smtpSession) Data(r io.Reader) error { - if ok, err := s.bridge.users.GetErr(s.userID, func(user *user.User) error { + if ok, err := s.users.GetErr(s.userID, func(user *user.User) error { return user.SendMail(s.authID, s.from, s.to, r) }); !ok { return fmt.Errorf("no such user %q", s.userID) diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index ac960492..147db32f 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -23,10 +23,9 @@ import ( "os" "path/filepath" "runtime" + "sync/atomic" "testing" - "go.uber.org/goleak" - "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/bradenaw/juniper/iterator" @@ -38,78 +37,33 @@ import ( ) func TestBridge_Sync(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) - - s := server.New() - defer s.Close() - numMsg := 1 << 8 - withEnvServer(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password) require.NoError(t, err) labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder) require.NoError(t, err) - literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml")) - require.NoError(t, err) - - c, _, err := liteapi.New( - liteapi.WithHostURL(s.GetHostURL()), - liteapi.WithTransport(liteapi.InsecureTransport()), - ).NewClientWithLogin(ctx, "imap", password) - require.NoError(t, err) - defer c.Close() - - user, err := c.GetUser(ctx) - require.NoError(t, err) - - addr, err := c.GetAddresses(ctx) - require.NoError(t, err) - require.Equal(t, addrID, addr[0].ID) - - salt, err := c.GetSalts(ctx) - require.NoError(t, err) - - keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID) - require.NoError(t, err) - - _, addrKRs, err := liteapi.Unlock(user, addr, keyPass) - require.NoError(t, err) - - require.NoError(t, getErr(stream.Collect(ctx, c.ImportMessages( - ctx, - addrKRs[addr[0].ID], - runtime.NumCPU(), - runtime.NumCPU(), - iterator.Collect(iterator.Map(iterator.Counter(numMsg), func(i int) liteapi.ImportReq { - return liteapi.ImportReq{ - Metadata: liteapi.ImportMetadata{ - AddressID: addr[0].ID, - LabelIDs: []string{labelID}, - Flags: liteapi.MessageFlagReceived, - }, - Message: literal, - } - }))..., - )))) - - var read uint64 - - netCtl.OnRead(func(b []byte) { - read += uint64(len(b)) + withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *liteapi.Client) { + createMessages(ctx, t, c, addrID, labelID, numMsg) }) + var total uint64 + // The initial user should be fully synced. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) defer done() - userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) - require.NoError(t, err) + // Count how many bytes it takes to fully sync the user. + total = countBytesRead(netCtl, func() { + userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) + require.NoError(t, err) - require.Equal(t, userID, (<-syncCh).UserID) + require.Equal(t, userID, (<-syncCh).UserID) + }) }) // If we then connect an IMAP client, it should see all the messages. @@ -134,7 +88,7 @@ func TestBridge_Sync(t *testing.T) { }) // Pretend we can only sync 2/3 of the original messages. - netCtl.SetReadLimit(2 * read / 3) + netCtl.SetReadLimit(2 * total / 3) // Login the user; its sync should fail. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { @@ -184,7 +138,69 @@ func TestBridge_Sync(t *testing.T) { require.Equal(t, uint32(numMsg), status.Messages) } }) + }, server.WithTLS(false)) +} + +func withClient(ctx context.Context, t *testing.T, s *server.Server, username string, password []byte, fn func(context.Context, *liteapi.Client)) { + m := liteapi.New( + liteapi.WithHostURL(s.GetHostURL()), + liteapi.WithTransport(liteapi.InsecureTransport()), + ) + + c, _, err := m.NewClientWithLogin(ctx, username, password) + require.NoError(t, err) + defer c.Close() + + fn(ctx, c) +} + +func createMessages(ctx context.Context, t *testing.T, c *liteapi.Client, addrID, labelID string, count int) { + literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml")) + require.NoError(t, err) + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := liteapi.Unlock(user, addr, keyPass) + require.NoError(t, err) + + require.NoError(t, getErr(stream.Collect(ctx, c.ImportMessages( + ctx, + addrKRs[addrID], + runtime.NumCPU(), + runtime.NumCPU(), + iterator.Collect(iterator.Map(iterator.Counter(count), func(i int) liteapi.ImportReq { + return liteapi.ImportReq{ + Metadata: liteapi.ImportMetadata{ + AddressID: addrID, + LabelIDs: []string{labelID}, + Flags: liteapi.MessageFlagReceived, + }, + Message: literal, + } + }))..., + )))) +} + +func countBytesRead(ctl *liteapi.NetCtl, fn func()) uint64 { + var read uint64 + + ctl.OnRead(func(b []byte) { + atomic.AddUint64(&read, uint64(len(b))) }) + + fn() + + return read } func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) { diff --git a/internal/bridge/updates.go b/internal/bridge/updates.go index c63eaa44..4fefb080 100644 --- a/internal/bridge/updates.go +++ b/internal/bridge/updates.go @@ -18,51 +18,12 @@ package bridge import ( - "time" - - "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/updater" ) func (bridge *Bridge) CheckForUpdates() { - bridge.updateCheckCh <- struct{}{} -} - -func (bridge *Bridge) watchForUpdates() error { - if _, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel()); err != nil { - return err - } - - ticker := time.NewTicker(constants.UpdateCheckInterval) - - go func() { - for { - select { - case <-bridge.stopCh: - return - - case <-bridge.updateCheckCh: - // ... - - case <-ticker.C: - // ... - } - - version, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel()) - if err != nil { - continue - } - - if err := bridge.handleUpdate(version); err != nil { - continue - } - } - }() - - bridge.updateCheckCh <- struct{}{} - - return nil + bridge.goUpdate() } func (bridge *Bridge) handleUpdate(version updater.VersionInfo) error { diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 2eceaefe..571d5115 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/try" @@ -255,84 +256,43 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut return "", fmt.Errorf("failed to salt key password: %w", err) } - if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil { + if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass, true); err != nil { return "", fmt.Errorf("failed to add bridge user: %w", err) } return apiUser.ID, nil } -// loadLoop is a loop that, when polled, attempts to load authorized users from the vault. -func (bridge *Bridge) loadLoop() { - for { - bridge.loadWG.GoTry(func(ok bool) { - if !ok { - return - } - - if err := bridge.loadUsers(); err != nil { - logrus.WithError(err).Error("Failed to load users") - } - }) - - select { - case <-bridge.stopCh: - return - - case <-bridge.loadCh: - } - } -} - // loadUsers tries to load each user in the vault that isn't already loaded. -func (bridge *Bridge) loadUsers() error { - if err := bridge.vault.ForUser(func(user *vault.User) error { - if bridge.users.Has(user.UserID()) { +func (bridge *Bridge) loadUsers(ctx context.Context) error { + return bridge.vault.ForUser(func(user *vault.User) error { + if bridge.users.Has(user.UserID()) || user.AuthUID() == "" { return nil } - if user.AuthUID() == "" { - return nil + if err := bridge.loadUser(ctx, user); err != nil { + logrus.WithError(err).Error("Failed to load connected user") + } else { + bridge.publish(events.UserLoaded{ + UserID: user.UserID(), + }) } - if err := bridge.loadUser(user); err != nil { - if _, ok := err.(*resty.ResponseError); ok { - logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault") - - if err := user.Clear(); err != nil { - logrus.WithError(err).Error("Failed to clear user") - } - } else { - logrus.WithError(err).Error("Failed to load connected user") - } - - return nil - } - - bridge.publish(events.UserLoaded{ - UserID: user.UserID(), - }) - return nil - }); err != nil { - return fmt.Errorf("failed to iterate over users: %w", err) - } - - bridge.publish(events.AllUsersLoaded{}) - - return nil + }) } // loadUser loads an existing user from the vault. -func (bridge *Bridge) loadUser(user *vault.User) error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - +func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error { client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef()) if err != nil { return fmt.Errorf("failed to create API client: %w", err) } + if err := user.SetAuth(auth.UID, auth.RefreshToken); err != nil { + return fmt.Errorf("failed to set auth: %w", err) + } + return try.Catch( func() error { apiUser, err := client.GetUser(ctx) @@ -340,10 +300,11 @@ func (bridge *Bridge) loadUser(user *vault.User) error { return fmt.Errorf("failed to get user: %w", err) } - return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()) - }, - func() error { - return client.AuthDelete(ctx) + if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil { + return fmt.Errorf("failed to add user: %w", err) + } + + return nil }, ) } @@ -355,15 +316,22 @@ func (bridge *Bridge) addUser( apiUser liteapi.User, authUID, authRef string, saltedKeyPass []byte, + isLogin bool, ) error { - vaultUser, isNew, err := bridge.newVaultUser(client, apiUser, authUID, authRef, saltedKeyPass) + vaultUser, isNew, err := bridge.newVaultUser(apiUser, authUID, authRef, saltedKeyPass) if err != nil { return fmt.Errorf("failed to add vault user: %w", err) } if err := bridge.addUserWithVault(ctx, client, apiUser, vaultUser); err != nil { - if err := vaultUser.Clear(); err != nil { - logrus.WithError(err).Error("Failed to clear vault user") + if _, ok := err.(*resty.ResponseError); ok || isLogin { + logrus.WithError(err).Error("Failed to add user, clearing its secrets from vault") + + if err := vaultUser.Clear(); err != nil { + logrus.WithError(err).Error("Failed to clear user secrets") + } + } else { + logrus.WithError(err).Error("Failed to add user") } if err := vaultUser.Close(); err != nil { @@ -371,6 +339,8 @@ func (bridge *Bridge) addUser( } if isNew { + logrus.Warn("Deleting newly added vault user") + if err := bridge.vault.DeleteUser(apiUser.ID); err != nil { logrus.WithError(err).Error("Failed to delete vault user") } @@ -394,10 +364,9 @@ func (bridge *Bridge) addUserWithVault( return fmt.Errorf("failed to create user: %w", err) } - if bridge.users.Has(apiUser.ID) { + if had := bridge.users.Set(apiUser.ID, user); had { panic("double add") } - bridge.users.Set(apiUser.ID, user) // Connect the user's address(es) to gluon. if err := bridge.addIMAPUser(ctx, user); err != nil { @@ -406,18 +375,15 @@ func (bridge *Bridge) addUserWithVault( // Handle events coming from the user before forwarding them to the bridge. // For example, if the user's addresses change, we need to update them in gluon. - go func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - for event := range user.GetEventCh() { + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, user.GetEventCh(), func(event events.Event) { if err := bridge.handleUserEvent(ctx, user, event); err != nil { logrus.WithError(err).Error("Failed to handle user event") } else { bridge.publish(event) } - } - }() + }) + }) // Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user. // As such, if we find this ID in the context, we should use it to update our user agent. @@ -435,7 +401,6 @@ func (bridge *Bridge) addUserWithVault( // newVaultUser creates a new vault user from the given auth information. // If one already exists in the vault, its data will be updated. func (bridge *Bridge) newVaultUser( - _ *liteapi.Client, apiUser liteapi.User, authUID, authRef string, saltedKeyPass []byte, @@ -494,7 +459,7 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { // logoutUser logs the given user out from bridge. func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { - if ok, err := bridge.users.GetDeleteErr(userID, func(user *user.User) error { + if ok := bridge.users.GetDelete(userID, func(user *user.User) { for _, gluonID := range user.GetGluonIDs() { if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { logrus.WithError(err).Error("Failed to remove IMAP user") @@ -506,12 +471,8 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { } user.Close() - - return nil }); !ok { return ErrNoSuchUser - } else if err != nil { - return fmt.Errorf("failed to delete user: %w", err) } return nil diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index e7e80670..cc225973 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -19,6 +19,7 @@ package bridge_test import ( "context" + "fmt" "testing" "time" @@ -272,75 +273,105 @@ func TestBridge_LoginDeleteRestart(t *testing.T) { } func TestBridge_FailLoginRecover(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { - var read uint64 + for i := uint64(1); i < 10; i++ { + t.Run(fmt.Sprintf("read %v%% of the data", 100*i/10), func(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + var userID string - netCtl.OnRead(func(b []byte) { - read += uint64(len(b)) + // Log the user in, wait for it to sync, then log it out. + // (We don't want to count message sync data in the test.) + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() + + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) + require.Equal(t, userID, (<-syncCh).UserID) + require.NoError(t, bridge.LogoutUser(ctx, userID)) + }) + + var total uint64 + + // Now that the user is synced, we can measure exactly how much data is needed during login. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + total = countBytesRead(netCtl, func() { + must(bridge.LoginFull(ctx, username, password, nil, nil)) + }) + + require.NoError(t, bridge.LogoutUser(ctx, userID)) + }) + + // Now simulate failing to login. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Simulate a partial read. + netCtl.SetReadLimit(i * total / 10) + + // We should fail to log the user in because we can't fully read its data. + require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) + + // The user should still be there (but disconnected). + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + + // Simulate the network recovering. + netCtl.SetReadLimit(0) + + // We should now be able to log the user in. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) + + // The user should be there, now connected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + }) + }) }) - - var userID string - - // Log the user in and record how much data was read. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) - require.NoError(t, bridge.LogoutUser(ctx, userID)) - }) - - // Now simulate failing to login. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // Simulate a partial read. - netCtl.SetReadLimit(3 * read / 4) - - // We should fail to log the user in because we can't fully read its data. - require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) - - // The user should still be there (but disconnected). - require.Equal(t, []string{userID}, bridge.GetUserIDs()) - require.Empty(t, getConnectedUserIDs(t, bridge)) - }) - - // Simulate the network recovering. - netCtl.SetReadLimit(0) - - // We should now be able to log the user in. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) - - // The user should be there, now connected. - require.Equal(t, []string{userID}, bridge.GetUserIDs()) - require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) - }) - }) + } } func TestBridge_FailLoadRecover(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - must(bridge.LoginFull(ctx, username, password, nil, nil)) + for i := uint64(1); i < 10; i++ { + t.Run(fmt.Sprintf("read %v%% of the data", 100*i/10), func(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + var userID string + + // Log the user in and wait for it to sync. + // (We don't want to count message sync data in the test.) + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() + + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) + require.Equal(t, userID, (<-syncCh).UserID) + }) + + // See how much data it takes to load the user at startup. + total := countBytesRead(netCtl, func() { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // ... + }) + }) + + // Simulate a partial read. + netCtl.SetReadLimit(i * total / 10) + + // We should fail to load the user; it should be listed but disconnected. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + + // Simulate the network recovering. + netCtl.SetReadLimit(0) + + // We should now be able to load the user. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + }) + }) }) - - var read uint64 - - netCtl.OnRead(func(b []byte) { - read += uint64(len(b)) - }) - - // Start bridge and record how much data was read. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // ... - }) - - // Simulate a partial read. - netCtl.SetReadLimit(read / 2) - - // We should fail to load the user; it should be disconnected. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userIDs := bridge.GetUserIDs() - - require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected) - }) - }) + } } func TestBridge_BridgePass(t *testing.T) {