diff --git a/go.mod b/go.mod index 11cdba69..9c01e323 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557 github.com/Masterminds/semver/v3 v3.1.1 - github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb + github.com/ProtonMail/gluon v0.13.1-0.20221024090110-e36d02c8912a github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a github.com/ProtonMail/go-rfc5322 v0.11.0 github.com/ProtonMail/gopenpgp/v2 v2.4.10 diff --git a/go.sum b/go.sum index 19b37729..700cdd22 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= -github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb h1:TGRkkuOdF3mIxbu5QMp62dJ2WvfQGgZ4MUVsNuRD7sc= -github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= +github.com/ProtonMail/gluon v0.13.1-0.20221024090110-e36d02c8912a h1:kC9+LmLImdfwa+j7Y7YiSUoKnLz8K1/9tgtecrtoY5A= +github.com/ProtonMail/gluon v0.13.1-0.20221024090110-e36d02c8912a/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4= github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= diff --git a/internal/async/context.go b/internal/async/context.go index 6faabd34..ffb80565 100644 --- a/internal/async/context.go +++ b/internal/async/context.go @@ -51,3 +51,32 @@ func (a *Abortable) newCancelCtx(ctx context.Context) context.Context { return ctx } + +// RangeContext iterates over the given channel until the context is canceled or the +// channel is closed. +func RangeContext[T any](ctx context.Context, ch <-chan T, fn func(T)) { + for { + select { + case v, ok := <-ch: + if !ok { + return + } + + fn(v) + + case <-ctx.Done(): + return + } + } +} + +// ForwardContext forwards all values from the src channel to the dst channel until the +// context is canceled or the src channel is closed. 
+func ForwardContext[T any](ctx context.Context, dst chan<- T, src <-chan T) { + RangeContext(ctx, src, func(v T) { + select { + case dst <- v: + case <-ctx.Done(): + } + }) +} diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index d0baa2fd..49306d99 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -24,18 +24,22 @@ import ( "fmt" "net" "net/http" + "sync" "time" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon" + imapEvents "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/watcher" + "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" "github.com/ProtonMail/proton-bridge/v2/internal/safe" - "github.com/ProtonMail/proton-bridge/v2/internal/try" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/bradenaw/juniper/xslices" + "github.com/bradenaw/juniper/xsync" "github.com/emersion/go-smtp" "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" @@ -48,8 +52,7 @@ type Bridge struct { // users holds authorized users. users *safe.Map[string, *user.User] - loadCh chan struct{} - loadWG try.Group + goLoad func() // api manages user API clients. api *liteapi.Manager @@ -62,14 +65,16 @@ type Bridge struct { // imapServer is the bridge's IMAP server. imapServer *gluon.Server imapListener net.Listener + imapEventCh chan imapEvents.Event // smtpServer is the bridge's SMTP server. - smtpServer *smtp.Server + smtpServer *smtp.Server + smtpListener net.Listener // updater is the bridge's updater. - updater Updater - curVersion *semver.Version - updateCheckCh chan struct{} + updater Updater + goUpdate func() + curVersion *semver.Version // focusService is used to raise the bridge window when needed. focusService *focus.Service @@ -81,7 +86,8 @@ type Bridge struct { locator Locator // watchers holds all registered event watchers. - watchers *safe.Slice[*watcher.Watcher[events.Event]] + watchers []*watcher.Watcher[events.Event] + watchersLock sync.RWMutex // errors contains errors encountered during startup. errors []error @@ -91,10 +97,8 @@ type Bridge struct { logIMAPServer bool logSMTP bool - // stopCh is used to stop ongoing goroutines when the bridge is closed. - stopCh chan struct{} - - closeEventChFn func() + // tasks manages the bridge's goroutines. + tasks *xsync.Group } // New creates a new bridge. @@ -115,6 +119,7 @@ func New( //nolint:funlen logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity logSMTP bool, // whether to log SMTP activity ) (*Bridge, <-chan events.Event, error) { + // api is the user's API manager. api := liteapi.New( liteapi.WithHostURL(apiURL), liteapi.WithAppVersion(constants.AppVersion(curVersion.Original())), @@ -122,63 +127,63 @@ func New( //nolint:funlen liteapi.WithTransport(roundTripper), ) - tlsConfig, err := loadTLSConfig(vault) - if err != nil { - return nil, nil, fmt.Errorf("failed to load TLS config: %w", err) - } + // tasks holds all the bridge's background tasks. + tasks := xsync.NewGroup(context.Background()) - gluonDir, err := getGluonDir(vault) - if err != nil { - return nil, nil, fmt.Errorf("failed to get Gluon directory: %w", err) - } + // imapEventCh forwards IMAP events from gluon instances to the bridge for processing. 
+ imapEventCh := make(chan imapEvents.Event) - imapServer, err := newIMAPServer(gluonDir, curVersion, tlsConfig, logIMAPClient, logIMAPServer) - if err != nil { - return nil, nil, fmt.Errorf("failed to create IMAP server: %w", err) - } + // users holds all the bridge's users. + users := safe.NewMap[string, *user.User](nil) - focusService, err := focus.NewService(curVersion) - if err != nil { - return nil, nil, fmt.Errorf("failed to create focus service: %w", err) - } + // bridge is the bridge. + bridge, err := newBridge( + users, + tasks, + imapEventCh, - bridge := newBridge( - // App stuff locator, vault, autostarter, updater, curVersion, - // API stuff api, identifier, proxyCtl, - - // Service stuff - tlsConfig, - imapServer, - focusService, - - // Logging stuff - logIMAPClient, - logIMAPServer, - logSMTP, + logIMAPClient, logIMAPServer, logSMTP, ) + if err != nil { + return nil, nil, fmt.Errorf("failed to create bridge: %w", err) + } // Get an event channel for all events (individual events can be subscribed to later). - eventCh, closeFn := bridge.GetEvents() - - bridge.closeEventChFn = closeFn + eventCh, _ := bridge.GetEvents() + // Initialize all of bridge's background tasks and operations. if err := bridge.init(tlsReporter); err != nil { return nil, nil, fmt.Errorf("failed to initialize bridge: %w", err) } + // Start serving IMAP. + if err := bridge.serveIMAP(); err != nil { + bridge.PushError(ErrServeIMAP) + } + + // Start serving SMTP. + if err := bridge.serveSMTP(); err != nil { + bridge.PushError(ErrServeSMTP) + } + return bridge, eventCh, nil } +// nolint:funlen func newBridge( + users *safe.Map[string, *user.User], + tasks *xsync.Group, + imapEventCh chan imapEvents.Event, + locator Locator, vault *vault.Vault, autostarter Autostarter, @@ -189,113 +194,146 @@ func newBridge( identifier Identifier, proxyCtl ProxyController, - tlsConfig *tls.Config, - imapServer *gluon.Server, - focusService *focus.Service, logIMAPClient, logIMAPServer, logSMTP bool, -) *Bridge { - bridge := &Bridge{ - vault: vault, +) (*Bridge, error) { + tlsConfig, err := loadTLSConfig(vault) + if err != nil { + return nil, fmt.Errorf("failed to load TLS config: %w", err) + } - users: safe.NewMap[string, *user.User](nil), - loadCh: make(chan struct{}, 1), + gluonDir, err := getGluonDir(vault) + if err != nil { + return nil, fmt.Errorf("failed to get Gluon directory: %w", err) + } + + imapServer, err := newIMAPServer( + gluonDir, + curVersion, + tlsConfig, + logIMAPClient, + logIMAPServer, + imapEventCh, + tasks, + ) + if err != nil { + return nil, fmt.Errorf("failed to create IMAP server: %w", err) + } + + focusService, err := focus.NewService(curVersion) + if err != nil { + return nil, fmt.Errorf("failed to create focus service: %w", err) + } + + return &Bridge{ + vault: vault, + users: users, api: api, proxyCtl: proxyCtl, identifier: identifier, - tlsConfig: tlsConfig, - imapServer: imapServer, + tlsConfig: tlsConfig, + imapServer: imapServer, + imapEventCh: imapEventCh, + smtpServer: newSMTPServer(users, tlsConfig, logSMTP), - updater: updater, - curVersion: curVersion, - updateCheckCh: make(chan struct{}, 1), + updater: updater, + curVersion: curVersion, focusService: focusService, autostarter: autostarter, locator: locator, - watchers: safe.NewSlice[*watcher.Watcher[events.Event]](), - logIMAPClient: logIMAPClient, logIMAPServer: logIMAPServer, logSMTP: logSMTP, - stopCh: make(chan struct{}), - } - - bridge.smtpServer = newSMTPServer(&smtpBackend{bridge}, tlsConfig, logSMTP) - - return bridge + tasks: 
tasks, + }, nil } +// nolint:funlen func (bridge *Bridge) init(tlsReporter TLSReporter) error { + // Enable or disable the proxy at startup. if bridge.vault.GetProxyAllowed() { bridge.proxyCtl.AllowProxy() } else { bridge.proxyCtl.DisallowProxy() } + // Handle connection up/down events. bridge.api.AddStatusObserver(func(status liteapi.Status) { switch { case status == liteapi.StatusUp: - go bridge.onStatusUp() + bridge.onStatusUp() case status == liteapi.StatusDown: - go bridge.onStatusDown() + bridge.onStatusDown() } }) + // If any call returns a bad version code, we need to update. bridge.api.AddErrorHandler(liteapi.AppVersionBadCode, func() { bridge.publish(events.UpdateForced{}) }) + // Ensure all outgoing headers have the correct user agent. bridge.api.AddPreRequestHook(func(_ *resty.Client, req *resty.Request) error { req.SetHeader("User-Agent", bridge.identifier.GetUserAgent()) return nil }) - go func() { - for range tlsReporter.GetTLSIssueCh() { + // Publish a TLS issue event if a TLS issue is encountered. + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, tlsReporter.GetTLSIssueCh(), func(struct{}) { bridge.publish(events.TLSIssue{}) - } - }() + }) + }) - go func() { - for range bridge.focusService.GetRaiseCh() { + // Publish a raise event if the focus service is called. + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, bridge.focusService.GetRaiseCh(), func(struct{}) { bridge.publish(events.Raise{}) - } - }() + }) + }) - go func() { - for event := range bridge.imapServer.AddWatcher() { + // Handle any IMAP events that are forwarded to the bridge from gluon. + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, bridge.imapEventCh, func(event imapEvents.Event) { bridge.handleIMAPEvent(event) + }) + }) + + // Attempt to lazy load users when triggered. + bridge.goLoad = bridge.tasks.Trigger(func(ctx context.Context) { + if err := bridge.loadUsers(ctx); err != nil { + logrus.WithError(err).Error("Failed to load users") + } else { + bridge.publish(events.AllUsersLoaded{}) } - }() + }) + defer bridge.goLoad() - if err := bridge.serveIMAP(); err != nil { - bridge.PushError(ErrServeIMAP) - } - - if err := bridge.serveSMTP(); err != nil { - bridge.PushError(ErrServeSMTP) - } - - if err := bridge.watchForUpdates(); err != nil { - bridge.PushError(ErrWatchUpdates) - } - - go bridge.loadLoop() + // Check for updates when triggered. + bridge.goUpdate = bridge.tasks.PeriodicOrTrigger(constants.UpdateCheckInterval, 0, func(ctx context.Context) { + version, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel()) + if err != nil { + logrus.WithError(err).Error("Failed to get version info") + } else if err := bridge.handleUpdate(version); err != nil { + logrus.WithError(err).Error("Failed to handle update") + } + }) + defer bridge.goUpdate() return nil } // GetEvents returns a channel of events of the given type. // If no types are supplied, all events are returned. -func (bridge *Bridge) GetEvents(ofType ...events.Event) (<-chan events.Event, func()) { - newWatcher := bridge.addWatcher(ofType...) +func (bridge *Bridge) GetEvents(ofType ...events.Event) (<-chan events.Event, context.CancelFunc) { + watcher := bridge.addWatcher(ofType...) 
- return newWatcher.GetChannel(), func() { bridge.remWatcher(newWatcher) } + return watcher.GetChannel(), func() { bridge.remWatcher(watcher) } } func (bridge *Bridge) PushError(err error) { @@ -307,20 +345,6 @@ func (bridge *Bridge) GetErrors() []error { } func (bridge *Bridge) Close(ctx context.Context) error { - defer func() { - if bridge.closeEventChFn != nil { - bridge.closeEventChFn() - } - - bridge.closeEventChFn = nil - }() - - // Stop ongoing operations such as connectivity checks. - close(bridge.stopCh) - - // Wait for ongoing user load operations to finish. - bridge.loadWG.Wait() - // Close the IMAP server. if err := bridge.closeIMAP(ctx); err != nil { logrus.WithError(err).Error("Failed to close IMAP server") @@ -336,9 +360,22 @@ func (bridge *Bridge) Close(ctx context.Context) error { user.Close() }) + // Stop all ongoing tasks. + bridge.tasks.Wait() + // Close the focus service. bridge.focusService.Close() + // Close the watchers. + bridge.watchersLock.Lock() + defer bridge.watchersLock.Unlock() + + for _, watcher := range bridge.watchers { + watcher.Close() + } + + bridge.watchers = nil + // Save the last version of bridge that was run. if err := bridge.vault.SetLastVersion(bridge.curVersion); err != nil { logrus.WithError(err).Error("Failed to save last version") @@ -348,35 +385,51 @@ func (bridge *Bridge) Close(ctx context.Context) error { } func (bridge *Bridge) publish(event events.Event) { - bridge.watchers.Iter(func(watcher *watcher.Watcher[events.Event]) { + bridge.watchersLock.RLock() + defer bridge.watchersLock.RUnlock() + + for _, watcher := range bridge.watchers { if watcher.IsWatching(event) { if ok := watcher.Send(event); !ok { logrus.WithField("event", event).Warn("Failed to send event to watcher") } } - }) + } } func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] { - newWatcher := watcher.New(ofType...) + bridge.watchersLock.Lock() + defer bridge.watchersLock.Unlock() - bridge.watchers.Append(newWatcher) + watcher := watcher.New(ofType...) - return newWatcher + bridge.watchers = append(bridge.watchers, watcher) + + return watcher } -func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) { - oldWatcher.Close() - bridge.watchers.Delete(oldWatcher) +func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) { + bridge.watchersLock.Lock() + defer bridge.watchersLock.Unlock() + + idx := xslices.Index(bridge.watchers, watcher) + + if idx < 0 { + return + } + + bridge.watchers = append(bridge.watchers[:idx], bridge.watchers[idx+1:]...) 
+ + watcher.Close() } func (bridge *Bridge) onStatusUp() { bridge.publish(events.ConnStatusUp{}) - bridge.loadCh <- struct{}{} + bridge.goLoad() bridge.users.IterValues(func(user *user.User) { - user.OnStatusUp() + go user.OnStatusUp() }) } @@ -384,35 +437,30 @@ func (bridge *Bridge) onStatusDown() { bridge.publish(events.ConnStatusDown{}) bridge.users.IterValues(func(user *user.User) { - user.OnStatusDown() + go user.OnStatusDown() }) - upCh, done := bridge.GetEvents(events.ConnStatusUp{}) - defer done() + bridge.tasks.Once(func(ctx context.Context) { + backoff := time.Second - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + for { + select { + case <-ctx.Done(): + return - backoff := time.Second + case <-time.After(backoff): + if err := bridge.api.Ping(ctx); err != nil { + logrus.WithError(err).Debug("Failed to ping API, will retry") + } else { + return + } + } - for { - select { - case <-upCh: - return - - case <-bridge.stopCh: - return - - case <-time.After(backoff): - if err := bridge.api.Ping(ctx); err != nil { - logrus.WithError(err).Debug("Failed to ping API") + if backoff < 30*time.Second { + backoff *= 2 } } - - if backoff < 30*time.Second { - backoff *= 2 - } - } + }) } func loadTLSConfig(vault *vault.Vault) (*tls.Config, error) { diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index fbda4876..f15cb135 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -44,6 +44,7 @@ import ( "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server/backend" + "go.uber.org/goleak" ) var ( @@ -54,6 +55,10 @@ var ( v2_4_0 = semver.MustParse("2.4.0") ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, goleak.IgnoreCurrent()) +} + func init() { user.EventPeriod = 100 * time.Millisecond user.EventJitter = 0 @@ -356,17 +361,10 @@ func TestBridge_MissingGluonDir(t *testing.T) { } // withEnv creates the full test environment and runs the tests. -func withEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte)) { - server := server.New() +func withEnv(t *testing.T, tests func(context.Context, *server.Server, *liteapi.NetCtl, bridge.Locator, []byte), opts ...server.Option) { + server := server.New(opts...) defer server.Close() - withEnvServer(t, server, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { - tests(ctx, server, netCtl, locator, vaultKey) - }) -} - -// withEnvServer creates the full test environment and runs the tests. -func withEnvServer(t *testing.T, server *server.Server, tests func(context.Context, *liteapi.NetCtl, bridge.Locator, []byte)) { // Add test user. _, _, err := server.CreateUser(username, username+"@pm.me", password) require.NoError(t, err) @@ -386,7 +384,7 @@ func withEnvServer(t *testing.T, server *server.Server, tests func(context.Conte locations := locations.New(bridge.NewTestLocationsProvider(t.TempDir()), "config-name") // Run the tests. - tests(ctx, netCtl, locations, vaultKey) + tests(ctx, server, netCtl, locations, vaultKey) } // withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. @@ -414,10 +412,6 @@ func withBridge( vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey) require.NoError(t, err) - // Let the IMAP and SMTP servers choose random available ports for this test. 
- require.NoError(t, vault.SetIMAPPort(0)) - require.NoError(t, vault.SetSMTPPort(0)) - // Create a new cookie jar. cookieJar, err := cookies.NewCookieJar(bridge.NewTestCookieJar(), vault) require.NoError(t, err) @@ -446,10 +440,15 @@ func withBridge( false, ) require.NoError(t, err) + require.Empty(t, bridge.GetErrors()) // Wait for bridge to finish loading users. waitForEvent(t, eventCh, events.AllUsersLoaded{}) + // Set random IMAP and SMTP ports for the tests. + require.NoError(t, bridge.SetIMAPPort(0)) + require.NoError(t, bridge.SetSMTPPort(0)) + // Close the bridge when done. defer func() { require.NoError(t, bridge.Close(ctx)) }() diff --git a/internal/bridge/imap.go b/internal/bridge/imap.go index aa5bda3b..de76c93f 100644 --- a/internal/bridge/imap.go +++ b/internal/bridge/imap.go @@ -24,17 +24,16 @@ import ( "fmt" "io" "io/fs" - "net" "os" - "strconv" - - "github.com/ProtonMail/proton-bridge/v2/internal/logging" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon" imapEvents "github.com/ProtonMail/gluon/events" + "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/constants" + "github.com/ProtonMail/proton-bridge/v2/internal/logging" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/bradenaw/juniper/xsync" "github.com/sirupsen/logrus" ) @@ -55,34 +54,16 @@ func (bridge *Bridge) serveIMAP() error { return fmt.Errorf("failed to serve IMAP: %w", err) } - _, port, err := net.SplitHostPort(imapListener.Addr().String()) - if err != nil { - return fmt.Errorf("failed to get IMAP listener address: %w", err) + if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil { + return fmt.Errorf("failed to set IMAP port: %w", err) } - portInt, err := strconv.Atoi(port) - if err != nil { - return fmt.Errorf("failed to convert IMAP listener port to int: %w", err) - } - - if portInt != bridge.vault.GetIMAPPort() { - if err := bridge.vault.SetIMAPPort(portInt); err != nil { - return fmt.Errorf("failed to update IMAP port in vault: %w", err) - } - } - - go func() { - for err := range bridge.imapServer.GetErrorCh() { - logrus.WithError(err).Error("IMAP server error") - } - }() - return nil } func (bridge *Bridge) restartIMAP() error { if err := bridge.imapListener.Close(); err != nil { - logrus.WithError(err).Warn("Failed to close IMAP listener") + return fmt.Errorf("failed to close IMAP listener: %w", err) } return bridge.serveIMAP() @@ -103,12 +84,10 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error { func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) { switch event := event.(type) { case imapEvents.SessionAdded: - if bridge.identifier.HasClient() { - return + if !bridge.identifier.HasClient() { + bridge.identifier.SetClient(defaultClientName, defaultClientVersion) } - bridge.identifier.SetClient(defaultClientName, defaultClientVersion) - case imapEvents.IMAPID: bridge.identifier.SetClient(event.IMAPID.Name, event.IMAPID.Version) } @@ -137,11 +116,14 @@ func getGluonDir(encVault *vault.Vault) (string, error) { return encVault.GetGluonDir(), nil } +// nolint:funlen func newIMAPServer( gluonDir string, version *semver.Version, tlsConfig *tls.Config, logClient, logServer bool, + eventCh chan<- imapEvents.Event, + tasks *xsync.Group, ) (*gluon.Server, error) { if logClient || logServer { log := logrus.WithField("protocol", "IMAP") @@ -186,6 +168,16 @@ func newIMAPServer( return nil, err } + tasks.Once(func(ctx context.Context) { + async.ForwardContext(ctx, eventCh, 
imapServer.AddWatcher()) + }) + + tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, imapServer.GetErrorCh(), func(err error) { + logrus.WithError(err).Error("IMAP server error") + }) + }) + return imapServer, nil } diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 3cebf0e0..3331b3e7 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -20,6 +20,7 @@ package bridge import ( "context" "fmt" + "net" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/updater" @@ -131,7 +132,15 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error return fmt.Errorf("failed to set new gluon dir: %w", err) } - imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig, bridge.logIMAPClient, bridge.logIMAPServer) + imapServer, err := newIMAPServer( + bridge.vault.GetGluonDir(), + bridge.curVersion, + bridge.tlsConfig, + bridge.logIMAPClient, + bridge.logIMAPServer, + bridge.imapEventCh, + bridge.tasks, + ) if err != nil { return fmt.Errorf("failed to create new IMAP server: %w", err) } @@ -210,7 +219,7 @@ func (bridge *Bridge) SetAutoUpdate(autoUpdate bool) error { return err } - bridge.updateCheckCh <- struct{}{} + bridge.goUpdate() return nil } @@ -228,7 +237,7 @@ func (bridge *Bridge) SetUpdateChannel(channel updater.Channel) error { return err } - bridge.updateCheckCh <- struct{}{} + bridge.goUpdate() return nil } @@ -276,3 +285,16 @@ func (bridge *Bridge) FactoryReset(ctx context.Context) { logrus.WithError(err).Error("Failed to clear data paths") } } + +func getPort(addr net.Addr) int { + switch addr := addr.(type) { + case *net.TCPAddr: + return addr.Port + + case *net.UDPAddr: + return addr.Port + + default: + return 0 + } +} diff --git a/internal/bridge/smtp.go b/internal/bridge/smtp.go index eb79c681..26936107 100644 --- a/internal/bridge/smtp.go +++ b/internal/bridge/smtp.go @@ -18,12 +18,13 @@ package bridge import ( + "context" "crypto/tls" "fmt" - "net" - "strconv" "github.com/ProtonMail/proton-bridge/v2/internal/logging" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" + "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/emersion/go-smtp" @@ -36,26 +37,16 @@ func (bridge *Bridge) serveSMTP() error { return fmt.Errorf("failed to create SMTP listener: %w", err) } - go func() { + bridge.smtpListener = smtpListener + + bridge.tasks.Once(func(ctx context.Context) { if err := bridge.smtpServer.Serve(smtpListener); err != nil { - logrus.WithError(err).Error("SMTP server stopped") + logrus.WithError(err).Debug("SMTP server stopped") } - }() + }) - _, port, err := net.SplitHostPort(smtpListener.Addr().String()) - if err != nil { - return fmt.Errorf("failed to get SMTP listener address: %w", err) - } - - portInt, err := strconv.Atoi(port) - if err != nil { - return fmt.Errorf("failed to convert SMTP listener port to int: %w", err) - } - - if portInt != bridge.vault.GetSMTPPort() { - if err := bridge.vault.SetSMTPPort(portInt); err != nil { - return fmt.Errorf("failed to update SMTP port in vault: %w", err) - } + if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil { + return fmt.Errorf("failed to set IMAP port: %w", err) } return nil @@ -63,26 +54,32 @@ func (bridge *Bridge) serveSMTP() error { func (bridge *Bridge) restartSMTP() error { if err := bridge.closeSMTP(); err != nil { - return err + return fmt.Errorf("failed to close SMTP: 
%w", err) } - bridge.smtpServer = newSMTPServer(&smtpBackend{bridge}, bridge.tlsConfig, bridge.logSMTP) + bridge.smtpServer = newSMTPServer(bridge.users, bridge.tlsConfig, bridge.logSMTP) return bridge.serveSMTP() } +// We close the listener ourselves even though it's also closed by smtpServer.Close(). +// This is because smtpServer.Serve() is called in a separate goroutine and might be executed +// after we've already closed the server. However, go-smtp has a bug; it blocks on the listener +// even after the server has been closed. So we close the listener ourselves to unblock it. func (bridge *Bridge) closeSMTP() error { - if err := bridge.smtpServer.Close(); err != nil { - logrus.WithError(err).Warn("Failed to close SMTP server") + if err := bridge.smtpListener.Close(); err != nil { + return fmt.Errorf("failed to close SMTP listener: %w", err) } - // Don't close the SMTP listener -- it's closed by the server. + if err := bridge.smtpServer.Close(); err != nil { + logrus.WithError(err).Debug("Failed to close SMTP server") + } return nil } -func newSMTPServer(smtpBackend *smtpBackend, tlsConfig *tls.Config, shouldLog bool) *smtp.Server { - smtpServer := smtp.NewServer(smtpBackend) +func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, shouldLog bool) *smtp.Server { + smtpServer := smtp.NewServer(&smtpBackend{users}) smtpServer.TLSConfig = tlsConfig smtpServer.Domain = constants.Host diff --git a/internal/bridge/smtp_backend.go b/internal/bridge/smtp_backend.go index d768a2a2..cbf870f3 100644 --- a/internal/bridge/smtp_backend.go +++ b/internal/bridge/smtp_backend.go @@ -21,16 +21,17 @@ import ( "fmt" "io" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/emersion/go-smtp" ) type smtpBackend struct { - bridge *Bridge + users *safe.Map[string, *user.User] } type smtpSession struct { - bridge *Bridge + users *safe.Map[string, *user.User] userID string authID string @@ -41,12 +42,12 @@ type smtpSession struct { func (be *smtpBackend) NewSession(_ *smtp.Conn) (smtp.Session, error) { return &smtpSession{ - bridge: be.bridge, + users: be.users, }, nil } func (s *smtpSession) AuthPlain(username, password string) error { - return s.bridge.users.ValuesErr(func(users []*user.User) error { + return s.users.ValuesErr(func(users []*user.User) error { for _, user := range users { addrID, err := user.CheckAuth(username, []byte(password)) if err != nil { @@ -87,7 +88,7 @@ func (s *smtpSession) Rcpt(to string) error { } func (s *smtpSession) Data(r io.Reader) error { - if ok, err := s.bridge.users.GetErr(s.userID, func(user *user.User) error { + if ok, err := s.users.GetErr(s.userID, func(user *user.User) error { return user.SendMail(s.authID, s.from, s.to, r) }); !ok { return fmt.Errorf("no such user %q", s.userID) diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index ac960492..147db32f 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -23,10 +23,9 @@ import ( "os" "path/filepath" "runtime" + "sync/atomic" "testing" - "go.uber.org/goleak" - "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/bradenaw/juniper/iterator" @@ -38,78 +37,33 @@ import ( ) func TestBridge_Sync(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) - - s := server.New() - defer s.Close() - numMsg := 1 << 8 - withEnvServer(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator 
bridge.Locator, storeKey []byte) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password) require.NoError(t, err) labelID, err := s.CreateLabel(userID, "folder", liteapi.LabelTypeFolder) require.NoError(t, err) - literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml")) - require.NoError(t, err) - - c, _, err := liteapi.New( - liteapi.WithHostURL(s.GetHostURL()), - liteapi.WithTransport(liteapi.InsecureTransport()), - ).NewClientWithLogin(ctx, "imap", password) - require.NoError(t, err) - defer c.Close() - - user, err := c.GetUser(ctx) - require.NoError(t, err) - - addr, err := c.GetAddresses(ctx) - require.NoError(t, err) - require.Equal(t, addrID, addr[0].ID) - - salt, err := c.GetSalts(ctx) - require.NoError(t, err) - - keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID) - require.NoError(t, err) - - _, addrKRs, err := liteapi.Unlock(user, addr, keyPass) - require.NoError(t, err) - - require.NoError(t, getErr(stream.Collect(ctx, c.ImportMessages( - ctx, - addrKRs[addr[0].ID], - runtime.NumCPU(), - runtime.NumCPU(), - iterator.Collect(iterator.Map(iterator.Counter(numMsg), func(i int) liteapi.ImportReq { - return liteapi.ImportReq{ - Metadata: liteapi.ImportMetadata{ - AddressID: addr[0].ID, - LabelIDs: []string{labelID}, - Flags: liteapi.MessageFlagReceived, - }, - Message: literal, - } - }))..., - )))) - - var read uint64 - - netCtl.OnRead(func(b []byte) { - read += uint64(len(b)) + withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *liteapi.Client) { + createMessages(ctx, t, c, addrID, labelID, numMsg) }) + var total uint64 + // The initial user should be fully synced. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) defer done() - userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) - require.NoError(t, err) + // Count how many bytes it takes to fully sync the user. + total = countBytesRead(netCtl, func() { + userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) + require.NoError(t, err) - require.Equal(t, userID, (<-syncCh).UserID) + require.Equal(t, userID, (<-syncCh).UserID) + }) }) // If we then connect an IMAP client, it should see all the messages. @@ -134,7 +88,7 @@ func TestBridge_Sync(t *testing.T) { }) // Pretend we can only sync 2/3 of the original messages. - netCtl.SetReadLimit(2 * read / 3) + netCtl.SetReadLimit(2 * total / 3) // Login the user; its sync should fail. 
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { @@ -184,7 +138,69 @@ func TestBridge_Sync(t *testing.T) { require.Equal(t, uint32(numMsg), status.Messages) } }) + }, server.WithTLS(false)) +} + +func withClient(ctx context.Context, t *testing.T, s *server.Server, username string, password []byte, fn func(context.Context, *liteapi.Client)) { + m := liteapi.New( + liteapi.WithHostURL(s.GetHostURL()), + liteapi.WithTransport(liteapi.InsecureTransport()), + ) + + c, _, err := m.NewClientWithLogin(ctx, username, password) + require.NoError(t, err) + defer c.Close() + + fn(ctx, c) +} + +func createMessages(ctx context.Context, t *testing.T, c *liteapi.Client, addrID, labelID string, count int) { + literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml")) + require.NoError(t, err) + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := liteapi.Unlock(user, addr, keyPass) + require.NoError(t, err) + + require.NoError(t, getErr(stream.Collect(ctx, c.ImportMessages( + ctx, + addrKRs[addrID], + runtime.NumCPU(), + runtime.NumCPU(), + iterator.Collect(iterator.Map(iterator.Counter(count), func(i int) liteapi.ImportReq { + return liteapi.ImportReq{ + Metadata: liteapi.ImportMetadata{ + AddressID: addrID, + LabelIDs: []string{labelID}, + Flags: liteapi.MessageFlagReceived, + }, + Message: literal, + } + }))..., + )))) +} + +func countBytesRead(ctl *liteapi.NetCtl, fn func()) uint64 { + var read uint64 + + ctl.OnRead(func(b []byte) { + atomic.AddUint64(&read, uint64(len(b))) }) + + fn() + + return read } func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) { diff --git a/internal/bridge/updates.go b/internal/bridge/updates.go index c63eaa44..4fefb080 100644 --- a/internal/bridge/updates.go +++ b/internal/bridge/updates.go @@ -18,51 +18,12 @@ package bridge import ( - "time" - - "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/updater" ) func (bridge *Bridge) CheckForUpdates() { - bridge.updateCheckCh <- struct{}{} -} - -func (bridge *Bridge) watchForUpdates() error { - if _, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel()); err != nil { - return err - } - - ticker := time.NewTicker(constants.UpdateCheckInterval) - - go func() { - for { - select { - case <-bridge.stopCh: - return - - case <-bridge.updateCheckCh: - // ... - - case <-ticker.C: - // ... 
- } - - version, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel()) - if err != nil { - continue - } - - if err := bridge.handleUpdate(version); err != nil { - continue - } - } - }() - - bridge.updateCheckCh <- struct{}{} - - return nil + bridge.goUpdate() } func (bridge *Bridge) handleUpdate(version updater.VersionInfo) error { diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 2eceaefe..571d5115 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/try" @@ -255,84 +256,43 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut return "", fmt.Errorf("failed to salt key password: %w", err) } - if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass); err != nil { + if err := bridge.addUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass, true); err != nil { return "", fmt.Errorf("failed to add bridge user: %w", err) } return apiUser.ID, nil } -// loadLoop is a loop that, when polled, attempts to load authorized users from the vault. -func (bridge *Bridge) loadLoop() { - for { - bridge.loadWG.GoTry(func(ok bool) { - if !ok { - return - } - - if err := bridge.loadUsers(); err != nil { - logrus.WithError(err).Error("Failed to load users") - } - }) - - select { - case <-bridge.stopCh: - return - - case <-bridge.loadCh: - } - } -} - // loadUsers tries to load each user in the vault that isn't already loaded. -func (bridge *Bridge) loadUsers() error { - if err := bridge.vault.ForUser(func(user *vault.User) error { - if bridge.users.Has(user.UserID()) { +func (bridge *Bridge) loadUsers(ctx context.Context) error { + return bridge.vault.ForUser(func(user *vault.User) error { + if bridge.users.Has(user.UserID()) || user.AuthUID() == "" { return nil } - if user.AuthUID() == "" { - return nil + if err := bridge.loadUser(ctx, user); err != nil { + logrus.WithError(err).Error("Failed to load connected user") + } else { + bridge.publish(events.UserLoaded{ + UserID: user.UserID(), + }) } - if err := bridge.loadUser(user); err != nil { - if _, ok := err.(*resty.ResponseError); ok { - logrus.WithError(err).Error("Failed to load connected user, clearing its secrets from vault") - - if err := user.Clear(); err != nil { - logrus.WithError(err).Error("Failed to clear user") - } - } else { - logrus.WithError(err).Error("Failed to load connected user") - } - - return nil - } - - bridge.publish(events.UserLoaded{ - UserID: user.UserID(), - }) - return nil - }); err != nil { - return fmt.Errorf("failed to iterate over users: %w", err) - } - - bridge.publish(events.AllUsersLoaded{}) - - return nil + }) } // loadUser loads an existing user from the vault. 
-func (bridge *Bridge) loadUser(user *vault.User) error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - +func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error { client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef()) if err != nil { return fmt.Errorf("failed to create API client: %w", err) } + if err := user.SetAuth(auth.UID, auth.RefreshToken); err != nil { + return fmt.Errorf("failed to set auth: %w", err) + } + return try.Catch( func() error { apiUser, err := client.GetUser(ctx) @@ -340,10 +300,11 @@ func (bridge *Bridge) loadUser(user *vault.User) error { return fmt.Errorf("failed to get user: %w", err) } - return bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()) - }, - func() error { - return client.AuthDelete(ctx) + if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil { + return fmt.Errorf("failed to add user: %w", err) + } + + return nil }, ) } @@ -355,15 +316,22 @@ func (bridge *Bridge) addUser( apiUser liteapi.User, authUID, authRef string, saltedKeyPass []byte, + isLogin bool, ) error { - vaultUser, isNew, err := bridge.newVaultUser(client, apiUser, authUID, authRef, saltedKeyPass) + vaultUser, isNew, err := bridge.newVaultUser(apiUser, authUID, authRef, saltedKeyPass) if err != nil { return fmt.Errorf("failed to add vault user: %w", err) } if err := bridge.addUserWithVault(ctx, client, apiUser, vaultUser); err != nil { - if err := vaultUser.Clear(); err != nil { - logrus.WithError(err).Error("Failed to clear vault user") + if _, ok := err.(*resty.ResponseError); ok || isLogin { + logrus.WithError(err).Error("Failed to add user, clearing its secrets from vault") + + if err := vaultUser.Clear(); err != nil { + logrus.WithError(err).Error("Failed to clear user secrets") + } + } else { + logrus.WithError(err).Error("Failed to add user") } if err := vaultUser.Close(); err != nil { @@ -371,6 +339,8 @@ func (bridge *Bridge) addUser( } if isNew { + logrus.Warn("Deleting newly added vault user") + if err := bridge.vault.DeleteUser(apiUser.ID); err != nil { logrus.WithError(err).Error("Failed to delete vault user") } @@ -394,10 +364,9 @@ func (bridge *Bridge) addUserWithVault( return fmt.Errorf("failed to create user: %w", err) } - if bridge.users.Has(apiUser.ID) { + if had := bridge.users.Set(apiUser.ID, user); had { panic("double add") } - bridge.users.Set(apiUser.ID, user) // Connect the user's address(es) to gluon. if err := bridge.addIMAPUser(ctx, user); err != nil { @@ -406,18 +375,15 @@ func (bridge *Bridge) addUserWithVault( // Handle events coming from the user before forwarding them to the bridge. // For example, if the user's addresses change, we need to update them in gluon. - go func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - for event := range user.GetEventCh() { + bridge.tasks.Once(func(ctx context.Context) { + async.RangeContext(ctx, user.GetEventCh(), func(event events.Event) { if err := bridge.handleUserEvent(ctx, user, event); err != nil { logrus.WithError(err).Error("Failed to handle user event") } else { bridge.publish(event) } - } - }() + }) + }) // Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user. // As such, if we find this ID in the context, we should use it to update our user agent. 
@@ -435,7 +401,6 @@ func (bridge *Bridge) addUserWithVault( // newVaultUser creates a new vault user from the given auth information. // If one already exists in the vault, its data will be updated. func (bridge *Bridge) newVaultUser( - _ *liteapi.Client, apiUser liteapi.User, authUID, authRef string, saltedKeyPass []byte, @@ -494,7 +459,7 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { // logoutUser logs the given user out from bridge. func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { - if ok, err := bridge.users.GetDeleteErr(userID, func(user *user.User) error { + if ok := bridge.users.GetDelete(userID, func(user *user.User) { for _, gluonID := range user.GetGluonIDs() { if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { logrus.WithError(err).Error("Failed to remove IMAP user") @@ -506,12 +471,8 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { } user.Close() - - return nil }); !ok { return ErrNoSuchUser - } else if err != nil { - return fmt.Errorf("failed to delete user: %w", err) } return nil diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index e7e80670..cc225973 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -19,6 +19,7 @@ package bridge_test import ( "context" + "fmt" "testing" "time" @@ -272,75 +273,105 @@ func TestBridge_LoginDeleteRestart(t *testing.T) { } func TestBridge_FailLoginRecover(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { - var read uint64 + for i := uint64(1); i < 10; i++ { + t.Run(fmt.Sprintf("read %v%% of the data", 100*i/10), func(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + var userID string - netCtl.OnRead(func(b []byte) { - read += uint64(len(b)) + // Log the user in, wait for it to sync, then log it out. + // (We don't want to count message sync data in the test.) + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() + + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) + require.Equal(t, userID, (<-syncCh).UserID) + require.NoError(t, bridge.LogoutUser(ctx, userID)) + }) + + var total uint64 + + // Now that the user is synced, we can measure exactly how much data is needed during login. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + total = countBytesRead(netCtl, func() { + must(bridge.LoginFull(ctx, username, password, nil, nil)) + }) + + require.NoError(t, bridge.LogoutUser(ctx, userID)) + }) + + // Now simulate failing to login. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Simulate a partial read. + netCtl.SetReadLimit(i * total / 10) + + // We should fail to log the user in because we can't fully read its data. + require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) + + // The user should still be there (but disconnected). + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + + // Simulate the network recovering. 
+ netCtl.SetReadLimit(0) + + // We should now be able to log the user in. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) + + // The user should be there, now connected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + }) + }) }) - - var userID string - - // Log the user in and record how much data was read. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) - require.NoError(t, bridge.LogoutUser(ctx, userID)) - }) - - // Now simulate failing to login. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // Simulate a partial read. - netCtl.SetReadLimit(3 * read / 4) - - // We should fail to log the user in because we can't fully read its data. - require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) - - // The user should still be there (but disconnected). - require.Equal(t, []string{userID}, bridge.GetUserIDs()) - require.Empty(t, getConnectedUserIDs(t, bridge)) - }) - - // Simulate the network recovering. - netCtl.SetReadLimit(0) - - // We should now be able to log the user in. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) - - // The user should be there, now connected. - require.Equal(t, []string{userID}, bridge.GetUserIDs()) - require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) - }) - }) + } } func TestBridge_FailLoadRecover(t *testing.T) { - withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - must(bridge.LoginFull(ctx, username, password, nil, nil)) + for i := uint64(1); i < 10; i++ { + t.Run(fmt.Sprintf("read %v%% of the data", 100*i/10), func(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + var userID string + + // Log the user in and wait for it to sync. + // (We don't want to count message sync data in the test.) + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() + + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) + require.Equal(t, userID, (<-syncCh).UserID) + }) + + // See how much data it takes to load the user at startup. + total := countBytesRead(netCtl, func() { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // ... + }) + }) + + // Simulate a partial read. + netCtl.SetReadLimit(i * total / 10) + + // We should fail to load the user; it should be listed but disconnected. 
+ withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + + // Simulate the network recovering. + netCtl.SetReadLimit(0) + + // We should now be able to load the user. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + }) + }) }) - - var read uint64 - - netCtl.OnRead(func(b []byte) { - read += uint64(len(b)) - }) - - // Start bridge and record how much data was read. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - // ... - }) - - // Simulate a partial read. - netCtl.SetReadLimit(read / 2) - - // We should fail to load the user; it should be disconnected. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userIDs := bridge.GetUserIDs() - - require.False(t, must(bridge.GetUserInfo(userIDs[0])).Connected) - }) - }) + } } func TestBridge_BridgePass(t *testing.T) {
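// --- Illustrative sketch (not part of the patch) ---
// A minimal, self-contained example of how the new async.RangeContext and
// async.ForwardContext helpers behave: both drain a channel until it is closed
// or the context is canceled, which is the pattern the bridge now uses for user
// events, IMAP events, and the TLS/focus notification channels above. It assumes
// it is compiled inside the proton-bridge module so the internal package can be
// imported; names here (src, dst, the producer goroutine) are illustrative only.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/ProtonMail/proton-bridge/v2/internal/async"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	src := make(chan int)
	dst := make(chan int)

	// Producer: closing src ends both helpers cleanly.
	go func() {
		defer close(src)

		for i := 0; i < 3; i++ {
			src <- i
		}
	}()

	// Forward everything from src to dst until src closes or ctx is canceled.
	go func() {
		defer close(dst)

		async.ForwardContext(ctx, dst, src)
	}()

	// Consume dst the same way; fn is called once per received value.
	async.RangeContext(ctx, dst, func(v int) {
		fmt.Println("received", v)
	})
}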
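// --- Illustrative sketch (not part of the patch) ---
// A second minimal sketch of the xsync.Group task pattern this change adopts in
// place of hand-rolled goroutines, stopCh and loadCh: Once runs a task for the
// lifetime of the group, Trigger/PeriodicOrTrigger return a function that kicks
// off a run on demand (PeriodicOrTrigger also runs on an interval), and Wait
// stops and waits for all tasks, as Bridge.Close does above. The exact
// xsync.Group semantics are inferred from how the patch uses the API.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/bradenaw/juniper/xsync"
)

func main() {
	tasks := xsync.NewGroup(context.Background())

	// A long-lived task: runs until the group is stopped.
	tasks.Once(func(ctx context.Context) {
		<-ctx.Done()
		fmt.Println("long-lived task stopped")
	})

	// An on-demand task: each call to load() schedules one run,
	// mirroring bridge.goLoad in the patch.
	load := tasks.Trigger(func(ctx context.Context) {
		fmt.Println("loading users")
	})
	load()

	// A periodic task that can also be triggered explicitly,
	// mirroring the update check (bridge.goUpdate).
	checkUpdates := tasks.PeriodicOrTrigger(time.Hour, 0, func(ctx context.Context) {
		fmt.Println("checking for updates")
	})
	checkUpdates()

	// Give the triggered runs a moment to execute in this toy example.
	time.Sleep(100 * time.Millisecond)

	// Cancel the group's context and wait for the goroutines to exit.
	tasks.Wait()
}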