From fd63611b411920641bbe831ad61083bc154272ef Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Wed, 12 Oct 2022 00:20:04 +0200 Subject: [PATCH] Other: Safer user types --- go.mod | 2 +- go.sum | 4 +- internal/app/app.go | 12 +- internal/app/bridge.go | 7 +- internal/bridge/bridge.go | 93 +++--- internal/bridge/bridge_test.go | 26 +- internal/bridge/bug_report.go | 7 +- internal/bridge/configure.go | 67 +++-- internal/bridge/settings.go | 9 +- internal/bridge/smtp_backend.go | 10 +- internal/bridge/sync_test.go | 70 ++--- internal/bridge/user.go | 158 ++++++---- internal/bridge/user_events.go | 21 +- internal/events/user.go | 4 + internal/frontend/cli/frontend.go | 9 +- internal/frontend/grpc/service.go | 11 +- internal/safe/map.go | 226 ++++++++++---- internal/safe/map_test.go | 97 ++++++ internal/safe/set.go | 46 --- internal/safe/slice.go | 44 ++- internal/safe/slice_test.go | 34 +++ internal/safe/value.go | 48 ++- internal/safe/value_test.go | 37 +++ internal/try/try.go | 29 ++ internal/user/events.go | 151 +++------- internal/user/imap.go | 45 +-- internal/user/keys.go | 60 ++++ internal/user/smtp.go | 102 ++++--- internal/user/sync.go | 145 ++++----- internal/user/types.go | 83 ++++++ internal/user/user.go | 309 ++++++++++---------- internal/vault/types.go | 4 + internal/vault/vault.go | 32 +- internal/versioner/versioner_remove_test.go | 8 +- tests/ctx_bridge_test.go | 14 +- 35 files changed, 1253 insertions(+), 771 deletions(-) create mode 100644 internal/safe/map_test.go delete mode 100644 internal/safe/set.go create mode 100644 internal/safe/slice_test.go create mode 100644 internal/safe/value_test.go create mode 100644 internal/user/keys.go diff --git a/go.mod b/go.mod index ffaab172..7d95c386 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.0 github.com/urfave/cli/v2 v2.16.3 - gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012095146-bd94443eeb8e + gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/net v0.1.0 golang.org/x/sys v0.1.0 diff --git a/go.sum b/go.sum index b4a4ee56..34c6d438 100644 --- a/go.sum +++ b/go.sum @@ -397,8 +397,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012095146-bd94443eeb8e h1:UBgcmAYZ45ylLlfmc8/0evP40LwVthBHRoMgGqt4YV8= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012095146-bd94443eeb8e/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 h1:TWNT/rPSUGjYsNTwWx5Fd029LipSv+h1XuBwFSd5cAo= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/app/app.go b/internal/app/app.go index 2c97810b..faddc405 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -10,6 +10,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/cookies" "github.com/ProtonMail/proton-bridge/v2/internal/crash" + "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" bridgeCLI "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc" @@ -23,6 +24,7 @@ import ( "github.com/urfave/cli/v2" ) +// Visible flags const ( flagCPUProfile = "cpu-prof" flagCPUProfileShort = "p" @@ -40,8 +42,10 @@ const ( flagLogIMAP = "log-imap" flagLogSMTP = "log-smtp" +) - // Hidden flags +// Hidden flags +const ( flagLauncher = "launcher" flagNoWindow = "no-window" ) @@ -137,7 +141,7 @@ func run(c *cli.Context) error { // Load the cookies from the vault. return withCookieJar(vault, func(cookieJar http.CookieJar) error { // Create a new bridge instance. - return withBridge(c, locations, identifier, reporter, vault, cookieJar, func(b *bridge.Bridge) error { + return withBridge(c, locations, identifier, reporter, vault, cookieJar, func(b *bridge.Bridge, eventCh <-chan events.Event) error { if insecure { logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted") b.PushError(bridge.ErrVaultInsecure) @@ -150,13 +154,13 @@ func run(c *cli.Context) error { switch { case c.Bool(flagCLI): - return bridgeCLI.New(b).Loop() + return bridgeCLI.New(b, eventCh).Loop() case c.Bool(flagNonInteractive): select {} default: - service, err := grpc.NewService(crashHandler, restarter, locations, b, !c.Bool(flagNoWindow)) + service, err := grpc.NewService(crashHandler, restarter, locations, b, eventCh, !c.Bool(flagNoWindow)) if err != nil { return fmt.Errorf("could not create service: %w", err) } diff --git a/internal/app/bridge.go b/internal/app/bridge.go index 7db9af80..d089c75f 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -12,6 +12,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/dialer" + "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/ProtonMail/proton-bridge/v2/internal/sentry" "github.com/ProtonMail/proton-bridge/v2/internal/updater" @@ -32,7 +33,7 @@ func withBridge( reporter *sentry.Reporter, vault *vault.Vault, cookieJar http.CookieJar, - fn func(*bridge.Bridge) error, + fn func(*bridge.Bridge, <-chan events.Event) error, ) error { // Get the current bridge version. version, err := semver.NewVersion(constants.Version) @@ -64,7 +65,7 @@ func withBridge( } // Create a new bridge. - bridge, err := bridge.New( + bridge, eventCh, err := bridge.New( // The app stuff. locations, vault, @@ -96,7 +97,7 @@ func withBridge( } }() - return fn(bridge) + return fn(bridge, eventCh) } func newAutostarter() (*autostart.App, error) { diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 879c8bbc..9f3e5ed2 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "net/http" - "sync" "time" "github.com/Masterminds/semver/v3" @@ -16,9 +15,10 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" + "github.com/ProtonMail/proton-bridge/v2/internal/try" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" - "github.com/bradenaw/juniper/xslices" "github.com/emersion/go-smtp" "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" @@ -30,17 +30,15 @@ type Bridge struct { vault *vault.Vault // users holds authorized users. - users map[string]*user.User + users *safe.Map[string, *user.User] + loadCh chan struct{} + loadWG try.Group // api manages user API clients. api *liteapi.Manager proxyCtl ProxyController identifier Identifier - // watchers holds all registered event watchers. - watchers []*watcher.Watcher[events.Event] - watchersLock sync.RWMutex - // tlsConfig holds the bridge TLS config used by the IMAP and SMTP servers. tlsConfig *tls.Config @@ -66,6 +64,9 @@ type Bridge struct { // locator is the bridge's locator. locator Locator + // watchers holds all registered event watchers. + watchers *safe.Slice[*watcher.Watcher[events.Event]] + // errors contains errors encountered during startup. errors []error @@ -95,7 +96,7 @@ func New( logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity logSMTP bool, // whether to log SMTP activity -) (*Bridge, error) { +) (*Bridge, <-chan events.Event, error) { api := liteapi.New( liteapi.WithHostURL(apiURL), liteapi.WithAppVersion(constants.AppVersion(curVersion.Original())), @@ -105,54 +106,62 @@ func New( tlsConfig, err := loadTLSConfig(vault) if err != nil { - return nil, fmt.Errorf("failed to load TLS config: %w", err) + return nil, nil, fmt.Errorf("failed to load TLS config: %w", err) } gluonDir, err := getGluonDir(vault) if err != nil { - return nil, fmt.Errorf("failed to get Gluon directory: %w", err) + return nil, nil, fmt.Errorf("failed to get Gluon directory: %w", err) } smtpBackend, err := newSMTPBackend() if err != nil { - return nil, fmt.Errorf("failed to create SMTP backend: %w", err) + return nil, nil, fmt.Errorf("failed to create SMTP backend: %w", err) } imapServer, err := newIMAPServer(gluonDir, curVersion, tlsConfig, logIMAPClient, logIMAPServer) if err != nil { - return nil, fmt.Errorf("failed to create IMAP server: %w", err) + return nil, nil, fmt.Errorf("failed to create IMAP server: %w", err) } focusService, err := focus.NewService() if err != nil { - return nil, fmt.Errorf("failed to create focus service: %w", err) + return nil, nil, fmt.Errorf("failed to create focus service: %w", err) } bridge := newBridge( + // App stuff locator, vault, autostarter, updater, curVersion, + // API stuff api, identifier, proxyCtl, + // Service stuff tlsConfig, imapServer, smtpBackend, focusService, + + // Logging stuff logIMAPClient, logIMAPServer, logSMTP, ) + // Get an event channel for all events (individual events can be subscribed to later). + eventCh, _ := bridge.GetEvents() + if err := bridge.init(tlsReporter); err != nil { - return nil, fmt.Errorf("failed to initialize bridge: %w", err) + return nil, nil, fmt.Errorf("failed to initialize bridge: %w", err) } - return bridge, nil + return bridge, eventCh, nil } func newBridge( @@ -174,7 +183,9 @@ func newBridge( ) *Bridge { return &Bridge{ vault: vault, - users: make(map[string]*user.User), + + users: safe.NewMap[string, *user.User](nil), + loadCh: make(chan struct{}, 1), api: api, proxyCtl: proxyCtl, @@ -193,6 +204,8 @@ func newBridge( autostarter: autostarter, locator: locator, + watchers: safe.NewSlice[*watcher.Watcher[events.Event]](), + logIMAPClient: logIMAPClient, logIMAPServer: logIMAPServer, logSMTP: logSMTP, @@ -227,10 +240,6 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error { return nil }) - if err := bridge.loadUsers(); err != nil { - return fmt.Errorf("failed to load users: %w", err) - } - go func() { for range tlsReporter.GetTLSIssueCh() { bridge.publish(events.TLSIssue{}) @@ -261,6 +270,8 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error { bridge.PushError(ErrWatchUpdates) } + go bridge.loadLoop() + return nil } @@ -288,6 +299,9 @@ func (bridge *Bridge) Close(ctx context.Context) error { // Stop ongoing operations such as connectivity checks. close(bridge.stopCh) + // Wait for ongoing user load operations to finish. + bridge.loadWG.Wait() + // Close the IMAP server. if err := bridge.closeIMAP(ctx); err != nil { logrus.WithError(err).Error("Failed to close IMAP server") @@ -299,10 +313,10 @@ func (bridge *Bridge) Close(ctx context.Context) error { } // Close all users. - for _, user := range bridge.users { - if err := user.Close(); err != nil { - logrus.WithError(err).Error("Failed to close user") - } + if err := bridge.users.IterValuesErr(func(user *user.User) error { + return user.Close() + }); err != nil { + logrus.WithError(err).Error("Failed to close users") } // Close the focus service. @@ -317,49 +331,44 @@ func (bridge *Bridge) Close(ctx context.Context) error { } func (bridge *Bridge) publish(event events.Event) { - bridge.watchersLock.RLock() - defer bridge.watchersLock.RUnlock() - - for _, watcher := range bridge.watchers { + bridge.watchers.Iter(func(watcher *watcher.Watcher[events.Event]) { if watcher.IsWatching(event) { if ok := watcher.Send(event); !ok { logrus.WithField("event", event).Warn("Failed to send event to watcher") } } - } + }) } func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] { - bridge.watchersLock.Lock() - defer bridge.watchersLock.Unlock() - newWatcher := watcher.New(ofType...) - bridge.watchers = append(bridge.watchers, newWatcher) + bridge.watchers.Append(newWatcher) return newWatcher } func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) { - bridge.watchersLock.Lock() - defer bridge.watchersLock.Unlock() - - bridge.watchers = xslices.Filter(bridge.watchers, func(other *watcher.Watcher[events.Event]) bool { - return other != oldWatcher - }) + bridge.watchers.Delete(oldWatcher) } func (bridge *Bridge) onStatusUp() { bridge.publish(events.ConnStatusUp{}) - if err := bridge.loadUsers(); err != nil { - logrus.WithError(err).Error("Failed to load users") - } + bridge.loadCh <- struct{}{} + + bridge.users.IterValues(func(user *user.User) { + user.OnStatusUp() + }) } func (bridge *Bridge) onStatusDown() { bridge.publish(events.ConnStatusDown{}) + bridge.users.IterValues(func(user *user.User) { + user.OnStatusDown() + }) + upCh, done := bridge.GetEvents(events.ConnStatusUp{}) defer done() diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index f8ad5517..c07e7572 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -136,7 +136,7 @@ func TestBridge_UserAgent(t *testing.T) { func TestBridge_Cookies(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { - sessionIDs := safe.NewSet[string]() + sessionIDs := safe.NewValue([]string{}) // Save any session IDs we use. s.AddCallWatcher(func(call server.Call) { @@ -145,7 +145,9 @@ func TestBridge_Cookies(t *testing.T) { return } - sessionIDs.Insert(cookie.Value) + sessionIDs.Mod(func(sessionIDs *[]string) { + *sessionIDs = append(*sessionIDs, cookie.Value) + }) }) // Start bridge and add a user so that API assigns us a session ID via cookie. @@ -160,8 +162,8 @@ func TestBridge_Cookies(t *testing.T) { }) // We should have used just one session ID. - sessionIDs.Values(func(sessionIDs []string) { - require.Len(t, sessionIDs, 1) + sessionIDs.Load(func(sessionIDs []string) { + require.Len(t, xslices.Unique(sessionIDs), 1) }) }) } @@ -405,7 +407,7 @@ func withBridge( defer func() { require.NoError(t, cookieJar.PersistCookies()) }() // Create a new bridge. - bridge, err := bridge.New( + bridge, eventCh, err := bridge.New( // The app stuff. locator, vault, @@ -428,6 +430,9 @@ func withBridge( ) require.NoError(t, err) + // Wait for bridge to finish loading users. + waitForEvent(t, eventCh, events.AllUsersLoaded{}) + // Close the bridge when done. defer func() { require.NoError(t, bridge.Close(ctx)) }() @@ -435,6 +440,17 @@ func withBridge( tests(bridge, mocks) } +func waitForEvent[T any](t *testing.T, eventCh <-chan events.Event, wantEvent T) { + t.Helper() + + for event := range eventCh { + switch event.(type) { + case T: + return + } + } +} + // must is a helper function that panics on error. func must[T any](val T, err error) T { if err != nil { diff --git a/internal/bridge/bug_report.go b/internal/bridge/bug_report.go index 69e26b6d..55180ad4 100644 --- a/internal/bridge/bug_report.go +++ b/internal/bridge/bug_report.go @@ -41,7 +41,12 @@ func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, descript if info, err := bridge.QueryUserInfo(username); err == nil { account = info.Username } else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 { - account = bridge.users[userIDs[0]].Name() + user, err := bridge.vault.GetUser(userIDs[0]) + if err != nil { + return err + } + + account = user.Username() } var atts []liteapi.ReportBugAttachment diff --git a/internal/bridge/configure.go b/internal/bridge/configure.go index 7884c346..61534eed 100644 --- a/internal/bridge/configure.go +++ b/internal/bridge/configure.go @@ -1,47 +1,52 @@ package bridge import ( + "fmt" "strings" "github.com/ProtonMail/proton-bridge/v2/internal/clientconfig" "github.com/ProtonMail/proton-bridge/v2/internal/constants" + "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/vault" ) func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { - user, ok := bridge.users[userID] - if !ok { - return ErrNoSuchUser - } - - if address == "" { - address = user.Emails()[0] - } - - username := address - addresses := address - - if user.GetAddressMode() == vault.CombinedMode { - username = user.Emails()[0] - addresses = strings.Join(user.Emails(), ",") - } - - // If configuring apple mail for Catalina or newer, users should use SSL. - if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() { - if err := bridge.SetSMTPSSL(true); err != nil { - return err + if ok, err := bridge.users.GetErr(userID, func(user *user.User) error { + if address == "" { + address = user.Emails()[0] } + + username := address + addresses := address + + if user.GetAddressMode() == vault.CombinedMode { + username = user.Emails()[0] + addresses = strings.Join(user.Emails(), ",") + } + + // If configuring apple mail for Catalina or newer, users should use SSL. + if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() { + if err := bridge.SetSMTPSSL(true); err != nil { + return err + } + } + + return (&clientconfig.AppleMail{}).Configure( + constants.Host, + bridge.vault.GetIMAPPort(), + bridge.vault.GetSMTPPort(), + bridge.vault.GetIMAPSSL(), + bridge.vault.GetSMTPSSL(), + username, + addresses, + user.BridgePass(), + ) + }); !ok { + return ErrNoSuchUser + } else if err != nil { + return fmt.Errorf("failed to configure apple mail: %w", err) } - return (&clientconfig.AppleMail{}).Configure( - constants.Host, - bridge.vault.GetIMAPPort(), - bridge.vault.GetSMTPPort(), - bridge.vault.GetIMAPSSL(), - bridge.vault.GetSMTPSSL(), - username, - addresses, - user.BridgePass(), - ) + return nil } diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 37b2ceb2..67d62c8d 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -6,6 +6,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/updater" + "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" ) @@ -119,10 +120,10 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error bridge.imapServer = imapServer - for _, user := range bridge.users { - if err := bridge.addIMAPUser(ctx, user); err != nil { - return fmt.Errorf("failed to add IMAP user: %w", err) - } + if err := bridge.users.IterValuesErr(func(user *user.User) error { + return bridge.addIMAPUser(ctx, user) + }); err != nil { + return fmt.Errorf("failed to add users to new IMAP server: %w", err) } if err := bridge.serveIMAP(); err != nil { diff --git a/internal/bridge/smtp_backend.go b/internal/bridge/smtp_backend.go index ea408caf..0b5729dd 100644 --- a/internal/bridge/smtp_backend.go +++ b/internal/bridge/smtp_backend.go @@ -1,13 +1,10 @@ package bridge import ( - "crypto/subtle" - "strings" "sync" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/emersion/go-smtp" - "golang.org/x/exp/slices" ) type smtpBackend struct { @@ -26,13 +23,12 @@ func (backend *smtpBackend) Login(state *smtp.ConnectionState, email, password s defer backend.usersLock.RUnlock() for _, user := range backend.users { - if subtle.ConstantTimeCompare(user.BridgePass(), []byte(password)) != 1 { + session, err := user.NewSMTPSession(email, []byte(password)) + if err != nil { continue } - if email := strings.ToLower(email); slices.Contains(user.Emails(), email) { - return user.NewSMTPSession(email) - } + return session, nil } return nil, ErrNoSuchUser diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index 9c764b50..e801a8fc 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -19,7 +19,7 @@ func TestBridge_Sync(t *testing.T) { s := server.New() defer s.Close() - numMsg := 1 << 10 + numMsg := 1 << 8 withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password) @@ -80,51 +80,51 @@ func TestBridge_Sync(t *testing.T) { // Login the user; its sync should fail. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - syncCh, done := chToType[events.Event, events.SyncFailed](bridge.GetEvents(events.SyncFailed{})) - defer done() + { + syncCh, done := chToType[events.Event, events.SyncFailed](bridge.GetEvents(events.SyncFailed{})) + defer done() - userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) - require.NoError(t, err) + userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) + require.NoError(t, err) - require.Equal(t, userID, (<-syncCh).UserID) + require.Equal(t, userID, (<-syncCh).UserID) - info, err := bridge.GetUserInfo(userID) - require.NoError(t, err) - require.True(t, info.Connected) + info, err := bridge.GetUserInfo(userID) + require.NoError(t, err) + require.True(t, info.Connected) - client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort())) - require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) - defer func() { _ = client.Logout() }() + client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort())) + require.NoError(t, err) + require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + defer func() { _ = client.Logout() }() - status, err := client.Select(`Folders/folder`, false) - require.NoError(t, err) - require.Less(t, status.Messages, uint32(numMsg)) - }) + status, err := client.Select(`Folders/folder`, false) + require.NoError(t, err) + require.Less(t, status.Messages, uint32(numMsg)) + } - // Remove the network limit, allowing the sync to finish. - netCtl.SetReadLimit(0) + // Remove the network limit, allowing the sync to finish. + netCtl.SetReadLimit(0) - // Login the user; its sync should now finish. - // If we then connect an IMAP client, it should eventually see all the messages. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) - defer done() + { + syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() - require.Equal(t, userID, (<-syncCh).UserID) + require.Equal(t, userID, (<-syncCh).UserID) - info, err := bridge.GetUserInfo(userID) - require.NoError(t, err) - require.True(t, info.Connected) + info, err := bridge.GetUserInfo(userID) + require.NoError(t, err) + require.True(t, info.Connected) - client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort())) - require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) - defer func() { _ = client.Logout() }() + client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort())) + require.NoError(t, err) + require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + defer func() { _ = client.Logout() }() - status, err := client.Select(`Folders/folder`, false) - require.NoError(t, err) - require.Equal(t, uint32(numMsg), status.Messages) + status, err := client.Select(`Folders/folder`, false) + require.NoError(t, err) + require.Equal(t, uint32(numMsg), status.Messages) + } }) }) } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 19431a3c..d42525b3 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -6,6 +6,7 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/try" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" @@ -53,23 +54,26 @@ func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) { return UserInfo{}, err } - user, ok := bridge.users[userID] - if !ok { - return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil + if info, ok := safe.MapGetRet(bridge.users, userID, func(user *user.User) UserInfo { + return getConnUserInfo(user) + }); ok { + return info, nil } - return getConnUserInfo(user), nil + return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil } // QueryUserInfo queries the user info by username or address. func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) { - for userID, user := range bridge.users { - if user.Match(query) { - return bridge.GetUserInfo(userID) + return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) { + for _, user := range users { + if user.Match(query) { + return getConnUserInfo(user), nil + } } - } - return UserInfo{}, ErrNoSuchUser + return UserInfo{}, ErrNoSuchUser + }) } // LoginAuth begins the login process. It returns an authorized client that might need 2FA. @@ -79,7 +83,7 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [ return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err) } - if _, ok := bridge.users[auth.UserID]; ok { + if bridge.users.Has(auth.UserID) { if err := client.AuthDelete(ctx); err != nil { logrus.WithError(err).Warn("Failed to delete auth") } @@ -187,34 +191,37 @@ func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { // SetAddressMode sets the address mode for the given user. func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error { - user, ok := bridge.users[userID] - if !ok { - return ErrNoSuchUser - } - - if user.GetAddressMode() == mode { - return fmt.Errorf("address mode is already %q", mode) - } - - for _, gluonID := range user.GetGluonIDs() { - if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil { - return fmt.Errorf("failed to remove user from IMAP server: %w", err) + if ok, err := bridge.users.GetErr(userID, func(user *user.User) error { + if user.GetAddressMode() == mode { + return fmt.Errorf("address mode is already %q", mode) } - } - if err := user.SetAddressMode(ctx, mode); err != nil { + for _, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil { + return fmt.Errorf("failed to remove user from IMAP server: %w", err) + } + } + + if err := user.SetAddressMode(ctx, mode); err != nil { + return fmt.Errorf("failed to set address mode: %w", err) + } + + if err := bridge.addIMAPUser(ctx, user); err != nil { + return fmt.Errorf("failed to add IMAP user: %w", err) + } + + bridge.publish(events.AddressModeChanged{ + UserID: userID, + AddressMode: mode, + }) + + return nil + }); !ok { + return ErrNoSuchUser + } else if err != nil { return fmt.Errorf("failed to set address mode: %w", err) } - if err := bridge.addIMAPUser(ctx, user); err != nil { - return fmt.Errorf("failed to add IMAP user: %w", err) - } - - bridge.publish(events.AddressModeChanged{ - UserID: userID, - AddressMode: mode, - }) - return nil } @@ -241,10 +248,30 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut return apiUser.ID, nil } -// loadUsers is a loop that, when polled, attempts to load authorized users from the vault. +// loadLoop is a loop that, when polled, attempts to load authorized users from the vault. +func (bridge *Bridge) loadLoop() { + for { + bridge.loadWG.GoTry(func(ok bool) { + if ok { + if err := bridge.loadUsers(); err != nil { + logrus.WithError(err).Error("Failed to load users") + } + } + }) + + select { + case <-bridge.stopCh: + return + + case <-bridge.loadCh: + // ... + } + } +} + func (bridge *Bridge) loadUsers() error { - return bridge.vault.ForUser(func(user *vault.User) error { - if _, ok := bridge.users[user.UserID()]; ok { + if err := bridge.vault.ForUser(func(user *vault.User) error { + if bridge.users.Has(user.UserID()) { return nil } @@ -271,7 +298,13 @@ func (bridge *Bridge) loadUsers() error { }) return nil - }) + }); err != nil { + return fmt.Errorf("failed to iterate over users: %w", err) + } + + bridge.publish(events.AllUsersLoaded{}) + + return nil } // loadUser loads an existing user from the vault. @@ -387,7 +420,7 @@ func (bridge *Bridge) addNewUser( return nil, err } - bridge.users[apiUser.ID] = user + bridge.users.Set(apiUser.ID, user) return user, nil } @@ -417,7 +450,7 @@ func (bridge *Bridge) addExistingUser( return nil, err } - bridge.users[apiUser.ID] = user + bridge.users.Set(apiUser.ID, user) return user, nil } @@ -451,37 +484,38 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { // logoutUser logs the given user out from bridge. func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { - user, ok := bridge.users[userID] - if !ok { - return ErrNoSuchUser - } - - if err := bridge.smtpBackend.removeUser(user); err != nil { - logrus.WithError(err).Error("Failed to remove user from SMTP backend") - } - - for _, gluonID := range user.GetGluonIDs() { - if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { - logrus.WithError(err).Error("Failed to remove IMAP user") + if ok, err := bridge.users.GetDeleteErr(userID, func(user *user.User) error { + if err := bridge.smtpBackend.removeUser(user); err != nil { + logrus.WithError(err).Error("Failed to remove user from SMTP backend") } - } - if err := user.Logout(ctx); err != nil { - logrus.WithError(err).Error("Failed to logout user") - } + for _, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { + logrus.WithError(err).Error("Failed to remove IMAP user") + } + } - if err := user.Close(); err != nil { - logrus.WithError(err).Error("Failed to close user") - } + if err := user.Logout(ctx); err != nil { + logrus.WithError(err).Error("Failed to logout user") + } - delete(bridge.users, userID) + if err := user.Close(); err != nil { + logrus.WithError(err).Error("Failed to close user") + } + + return nil + }); !ok { + return ErrNoSuchUser + } else if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } return nil } // deleteUser deletes the given user from bridge. func (bridge *Bridge) deleteUser(ctx context.Context, userID string) { - if user, ok := bridge.users[userID]; ok { + if ok := bridge.users.GetDelete(userID, func(user *user.User) { if err := bridge.smtpBackend.removeUser(user); err != nil { logrus.WithError(err).Error("Failed to remove user from SMTP backend") } @@ -499,13 +533,13 @@ func (bridge *Bridge) deleteUser(ctx context.Context, userID string) { if err := user.Close(); err != nil { logrus.WithError(err).Error("Failed to close user") } + }); !ok { + logrus.Debug("The bridge user was not connected") } if err := bridge.vault.DeleteUser(userID); err != nil { logrus.WithError(err).Error("Failed to delete user from vault") } - - delete(bridge.users, userID) } // getUserInfo returns information about a disconnected user. diff --git a/internal/bridge/user_events.go b/internal/bridge/user_events.go index 8ac3435e..56a3dcf1 100644 --- a/internal/bridge/user_events.go +++ b/internal/bridge/user_events.go @@ -43,23 +43,13 @@ func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.U return fmt.Errorf("failed to remove user from IMAP server: %w", err) } - imapConn, err := user.NewIMAPConnector(addrID) - if err != nil { - return fmt.Errorf("failed to create IMAP connector: %w", err) - } - - if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil { + if err := bridge.imapServer.LoadUser(ctx, user.NewIMAPConnector(addrID), gluonID, user.GluonKey()); err != nil { return fmt.Errorf("failed to add user to IMAP server: %w", err) } } case vault.SplitMode: - imapConn, err := user.NewIMAPConnector(event.AddressID) - if err != nil { - return fmt.Errorf("failed to create IMAP connector: %w", err) - } - - gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey()) + gluonID, err := bridge.imapServer.AddUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey()) if err != nil { return fmt.Errorf("failed to add user to IMAP server: %w", err) } @@ -93,12 +83,7 @@ func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.U return fmt.Errorf("failed to remove user from IMAP server: %w", err) } - imapConn, err := user.NewIMAPConnector(addrID) - if err != nil { - return fmt.Errorf("failed to create IMAP connector: %w", err) - } - - if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil { + if err := bridge.imapServer.LoadUser(ctx, user.NewIMAPConnector(addrID), gluonID, user.GluonKey()); err != nil { return fmt.Errorf("failed to add user to IMAP server: %w", err) } } diff --git a/internal/events/user.go b/internal/events/user.go index e4ec9e00..5f758110 100644 --- a/internal/events/user.go +++ b/internal/events/user.go @@ -2,6 +2,10 @@ package events import "github.com/ProtonMail/proton-bridge/v2/internal/vault" +type AllUsersLoaded struct { + eventBase +} + type UserLoaded struct { eventBase diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index a2503188..375ae8aa 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -38,7 +38,7 @@ type frontendCLI struct { } // New returns a new CLI frontend configured with the given options. -func New(bridge *bridge.Bridge) *frontendCLI { +func New(bridge *bridge.Bridge, eventCh <-chan events.Event) *frontendCLI { fe := &frontendCLI{ Shell: ishell.New(), bridge: bridge, @@ -253,15 +253,12 @@ func New(bridge *bridge.Bridge) *frontendCLI { Completer: fe.completeUsernames, }) - go fe.watchEvents() + go fe.watchEvents(eventCh) return fe } -func (f *frontendCLI) watchEvents() { - eventCh, done := f.bridge.GetEvents() - defer done() - +func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // TODO: Better error events. for _, err := range f.bridge.GetErrors() { switch { diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index 63acf0d9..9c675541 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -64,6 +64,7 @@ type Service struct { // nolint:structcheck panicHandler *crash.Handler restarter *restarter.Restarter bridge *bridge.Bridge + eventCh <-chan events.Event newVersionInfo updater.VersionInfo authClient *liteapi.Client @@ -84,6 +85,7 @@ func NewService( restarter *restarter.Restarter, locations *locations.Locations, bridge *bridge.Bridge, + eventCh <-chan events.Event, showOnStartup bool, ) (*Service, error) { tlsConfig, certPEM, err := newTLSConfig() @@ -115,6 +117,7 @@ func NewService( panicHandler: panicHandler, restarter: restarter, bridge: bridge, + eventCh: eventCh, log: logrus.WithField("pkg", "grpc"), initializing: sync.WaitGroup{}, @@ -200,9 +203,6 @@ func (s *Service) WaitUntilFrontendIsReady() { } func (s *Service) watchEvents() { - eventCh, done := s.bridge.GetEvents() - defer done() - // TODO: Better error events. for _, err := range s.bridge.GetErrors() { switch { @@ -220,7 +220,7 @@ func (s *Service) watchEvents() { } } - for event := range eventCh { + for event := range s.eventCh { switch event := event.(type) { case events.ConnStatusUp: _ = s.SendEvent(NewInternetStatusEvent(true)) @@ -243,6 +243,9 @@ func (s *Service) watchEvents() { case events.UserChanged: _ = s.SendEvent(NewUserChangedEvent(event.UserID)) + case events.UserLoaded: + _ = s.SendEvent(NewUserChangedEvent(event.UserID)) + case events.UserLoggedIn: _ = s.SendEvent(NewUserChangedEvent(event.UserID)) diff --git a/internal/safe/map.go b/internal/safe/map.go index f3f08228..b5a9a801 100644 --- a/internal/safe/map.go +++ b/internal/safe/map.go @@ -3,18 +3,26 @@ package safe import ( "sync" - "golang.org/x/exp/maps" + "github.com/bradenaw/juniper/xslices" + "golang.org/x/exp/slices" ) type Map[Key comparable, Val any] struct { - data map[Key]Val - lock sync.RWMutex + data map[Key]Val + order []Key + sort func(a, b Key, data map[Key]Val) bool + lock sync.RWMutex } -func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] { - m := &Map[Key, Val]{ +func NewMap[Key comparable, Val any](sort func(a, b Key, data map[Key]Val) bool) *Map[Key, Val] { + return &Map[Key, Val]{ data: make(map[Key]Val), + sort: sort, } +} + +func NewMapFrom[Key comparable, Val any](from map[Key]Val, sort func(a, b Key, data map[Key]Val) bool) *Map[Key, Val] { + m := NewMap(sort) for key, val := range from { m.Set(key, val) @@ -23,12 +31,36 @@ func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] { return m } -func (m *Map[Key, Val]) Has(key Key) bool { +func (m *Map[Key, Val]) Index(idx int, fn func(Key, Val)) bool { m.lock.RLock() defer m.lock.RUnlock() - _, ok := m.data[key] - return ok + if idx < 0 || idx >= len(m.order) { + return false + } + + fn(m.order[idx], m.data[m.order[idx]]) + + return true +} + +func (m *Map[Key, Val]) Has(key Key) bool { + return m.HasFunc(func(k Key, v Val) bool { + return k == key + }) +} + +func (m *Map[Key, Val]) HasFunc(fn func(key Key, val Val) bool) bool { + m.lock.RLock() + defer m.lock.RUnlock() + + for key, val := range m.data { + if fn(key, val) { + return true + } + } + + return false } func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool { @@ -46,15 +78,45 @@ func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool { } func (m *Map[Key, Val]) GetErr(key Key, fn func(Val) error) (bool, error) { - m.lock.RLock() - defer m.lock.RUnlock() + var err error + + ok := m.Get(key, func(val Val) { + err = fn(val) + }) + + return ok, err +} + +func (m *Map[Key, Val]) GetDelete(key Key, fn func(Val)) bool { + m.lock.Lock() + defer m.lock.Unlock() val, ok := m.data[key] if !ok { - return false, nil + return false } - return true, fn(val) + fn(val) + + delete(m.data, key) + + if idx := xslices.Index(m.order, key); idx >= 0 { + m.order = append(m.order[:idx], m.order[idx+1:]...) + } else { + panic("order and data out of sync") + } + + return true +} + +func (m *Map[Key, Val]) GetDeleteErr(key Key, fn func(Val) error) (bool, error) { + var err error + + ok := m.GetDelete(key, func(val Val) { + err = fn(val) + }) + + return ok, err } func (m *Map[Key, Val]) Set(key Key, val Val) { @@ -62,84 +124,140 @@ func (m *Map[Key, Val]) Set(key Key, val Val) { defer m.lock.Unlock() m.data[key] = val + + m.order = append(m.order, key) + + if m.sort != nil { + slices.SortFunc(m.order, func(a, b Key) bool { + return m.sort(a, b, m.data) + }) + } } -func (m *Map[Key, Val]) Delete(key Key) { +func (m *Map[Key, Val]) SetFrom(key Key, other Key) { m.lock.Lock() defer m.lock.Unlock() - delete(m.data, key) + m.data[key] = m.data[other] + + m.order = append(m.order, key) + + if m.sort != nil { + slices.SortFunc(m.order, func(a, b Key) bool { + return m.sort(a, b, m.data) + }) + } } func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) { m.lock.RLock() defer m.lock.RUnlock() - for key, val := range m.data { - fn(key, val) + for _, key := range m.order { + fn(key, m.data[key]) } } -func (m *Map[Key, Val]) Keys(fn func(keys []Key)) { - m.lock.RLock() - defer m.lock.RUnlock() +func (m *Map[Key, Val]) IterKeys(fn func(Key)) { + m.Iter(func(key Key, _ Val) { + fn(key) + }) +} - fn(maps.Keys(m.data)) +func (m *Map[Key, Val]) IterKeysErr(fn func(Key) error) error { + var err error + + m.IterKeys(func(key Key) { + if err != nil { + return + } + + err = fn(key) + }) + + return err +} + +func (m *Map[Key, Val]) IterValues(fn func(Val)) { + m.Iter(func(_ Key, val Val) { + fn(val) + }) +} + +func (m *Map[Key, Val]) IterValuesErr(fn func(Val) error) error { + var err error + + m.IterValues(func(val Val) { + if err != nil { + return + } + + err = fn(val) + }) + + return err } func (m *Map[Key, Val]) Values(fn func(vals []Val)) { m.lock.RLock() defer m.lock.RUnlock() - fn(maps.Values(m.data)) + vals := make([]Val, len(m.order)) + + for i, key := range m.order { + vals[i] = m.data[key] + } + + fn(vals) } -func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret, fallback func() Ret) Ret { +func (m *Map[Key, Val]) ValuesErr(fn func(vals []Val) error) error { + var err error + + m.Values(func(vals []Val) { + err = fn(vals) + }) + + return err +} + +func (m *Map[Key, Val]) MapErr(fn func(map[Key]Val) error) error { m.lock.RLock() defer m.lock.RUnlock() - val, ok := m.data[key] - if !ok { - return fallback() - } - - return fn(val) + return fn(m.data) } -func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) { - m.lock.RLock() - defer m.lock.RUnlock() +func MapGetRet[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret) (Ret, bool) { + var ret Ret - val, ok := m.data[key] - if !ok { - return fallback() - } + ok := m.Get(key, func(val Val) { + ret = fn(val) + }) - return fn(val) + return ret, ok } -func FindMap[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) Ret, fallback func() Ret) Ret { - m.lock.RLock() - defer m.lock.RUnlock() +func MapValuesRet[Key comparable, Val, Ret any](m *Map[Key, Val], fn func([]Val) Ret) Ret { + var ret Ret - for _, val := range m.data { - if cmp(val) { - return fn(val) - } - } + m.Values(func(vals []Val) { + ret = fn(vals) + }) - return fallback() + return ret } -func FindMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) { - m.lock.RLock() - defer m.lock.RUnlock() +func MapValuesRetErr[Key comparable, Val, Ret any](m *Map[Key, Val], fn func([]Val) (Ret, error)) (Ret, error) { + var ret Ret - for _, val := range m.data { - if cmp(val) { - return fn(val) - } - } + err := m.ValuesErr(func(vals []Val) error { + var err error - return fallback() + ret, err = fn(vals) + + return err + }) + + return ret, err } diff --git a/internal/safe/map_test.go b/internal/safe/map_test.go new file mode 100644 index 00000000..2af2339d --- /dev/null +++ b/internal/safe/map_test.go @@ -0,0 +1,97 @@ +package safe + +import "testing" + +func TestSafe_Map(t *testing.T) { + m := NewMap(func(a, b string, data map[string]string) bool { + return a < b + }) + + m.Set("a", "b") + + if !m.Has("a") { + t.Fatal("expected to have key") + } + + if m.Has("b") { + t.Fatal("expected not to have key") + } + + m.Set("b", "c") + + if !m.Has("b") { + t.Fatal("expected to have key") + } + + if !m.HasFunc(func(key string, val string) bool { + return key == "b" + }) { + t.Fatal("expected to have key") + } + + if !m.Get("b", func(val string) { + if val != "c" { + t.Fatal("expected to have value") + } + }) { + t.Fatal("expected to have key") + } + + if !m.Index(0, func(key string, val string) { + if key != "a" || val != "b" { + t.Fatal("expected to have key and value") + } + }) { + t.Fatal("expected to have index") + } + + if !m.Index(1, func(key string, val string) { + if key != "b" || val != "c" { + t.Fatal("expected to have key and value") + } + }) { + t.Fatal("expected to have index") + } + + if m.Index(2, func(key string, val string) { + t.Fatal("expected not to have index") + }) { + t.Fatal("expected not to have index") + } + + if !m.GetDelete("b", func(val string) { + if val != "c" { + t.Fatal("expected to have value") + } + }) { + t.Fatal("expected to have key") + } + + if m.Has("b") { + t.Fatal("expected not to have key") + } + + if m.GetDelete("b", func(val string) { + t.Fatal("expected not to have value") + }) { + t.Fatal("expected not to have key") + } + + if !m.Index(0, func(key string, val string) { + if key != "a" || val != "b" { + t.Fatal("expected to have key and value") + } + }) { + t.Fatal("expected to have index") + } + + m.Values(func(val []string) { + if len(val) != 1 { + t.Fatal("expected to have values") + } + + if val[0] != "b" { + t.Fatal("expected to have value") + } + }) +} diff --git a/internal/safe/set.go b/internal/safe/set.go deleted file mode 100644 index 7d844c95..00000000 --- a/internal/safe/set.go +++ /dev/null @@ -1,46 +0,0 @@ -package safe - -import "golang.org/x/exp/maps" - -type Set[Val comparable] Map[Val, struct{}] - -func NewSet[Val comparable](vals ...Val) *Set[Val] { - set := (*Set[Val])(NewMap[Val, struct{}](nil)) - - for _, val := range vals { - set.Insert(val) - } - - return set -} - -func (m *Set[Val]) Has(key Val) bool { - m.lock.RLock() - defer m.lock.RUnlock() - - _, ok := m.data[key] - return ok -} - -func (m *Set[Val]) Insert(key Val) { - m.lock.Lock() - defer m.lock.Unlock() - - m.data[key] = struct{}{} -} - -func (m *Set[Val]) Iter(fn func(key Val)) { - m.lock.RLock() - defer m.lock.RUnlock() - - for key := range m.data { - fn(key) - } -} - -func (m *Set[Val]) Values(fn func(vals []Val)) { - m.lock.RLock() - defer m.lock.RUnlock() - - fn(maps.Keys(m.data)) -} diff --git a/internal/safe/slice.go b/internal/safe/slice.go index 623682a1..a2f9e2a8 100644 --- a/internal/safe/slice.go +++ b/internal/safe/slice.go @@ -1,13 +1,17 @@ package safe -import "sync" +import ( + "sync" -type Slice[Val any] struct { + "github.com/bradenaw/juniper/xslices" +) + +type Slice[Val comparable] struct { data []Val lock sync.RWMutex } -func NewSlice[Val any](from []Val) *Slice[Val] { +func NewSlice[Val comparable](from ...Val) *Slice[Val] { s := &Slice[Val]{ data: make([]Val, len(from)), } @@ -17,37 +21,27 @@ func NewSlice[Val any](from []Val) *Slice[Val] { return s } -func (s *Slice[Val]) Get(fn func(data []Val)) { +func (s *Slice[Val]) Iter(fn func(val Val)) { s.lock.RLock() defer s.lock.RUnlock() - fn(s.data) + for _, val := range s.data { + fn(val) + } } -func (s *Slice[Val]) GetErr(fn func(data []Val) error) error { - s.lock.RLock() - defer s.lock.RUnlock() - - return fn(s.data) -} - -func (s *Slice[Val]) Set(data []Val) { +func (s *Slice[Val]) Append(val Val) { s.lock.Lock() defer s.lock.Unlock() - s.data = data + s.data = append(s.data, val) } -func GetSlice[Val, Ret any](s *Slice[Val], fn func(data []Val) Ret) Ret { - s.lock.RLock() - defer s.lock.RUnlock() +func (s *Slice[Val]) Delete(val Val) { + s.lock.Lock() + defer s.lock.Unlock() - return fn(s.data) -} - -func GetSliceErr[Val, Ret any](s *Slice[Val], fn func(data []Val) (Ret, error)) (Ret, error) { - s.lock.RLock() - defer s.lock.RUnlock() - - return fn(s.data) + s.data = xslices.Filter(s.data, func(v Val) bool { + return v != val + }) } diff --git a/internal/safe/slice_test.go b/internal/safe/slice_test.go new file mode 100644 index 00000000..ae4ef0d9 --- /dev/null +++ b/internal/safe/slice_test.go @@ -0,0 +1,34 @@ +package safe + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSlice(t *testing.T) { + s := NewSlice(1, 2, 3, 4, 5) + + { + var have []int + + s.Iter(func(val int) { + have = append(have, val) + }) + + require.Equal(t, []int{1, 2, 3, 4, 5}, have) + } + + s.Append(6) + s.Delete(3) + + { + var have []int + + s.Iter(func(val int) { + have = append(have, val) + }) + + require.Equal(t, []int{1, 2, 4, 5, 6}, have) + } +} diff --git a/internal/safe/value.go b/internal/safe/value.go index 854798fa..aa3ef73b 100644 --- a/internal/safe/value.go +++ b/internal/safe/value.go @@ -13,37 +13,57 @@ func NewValue[T any](data T) *Value[T] { } } -func (s *Value[T]) Get(fn func(data T)) { +func (s *Value[T]) Load(fn func(data T)) { s.lock.RLock() defer s.lock.RUnlock() fn(s.data) } -func (s *Value[T]) GetErr(fn func(data T) error) error { - s.lock.RLock() - defer s.lock.RUnlock() +func (s *Value[T]) LoadErr(fn func(data T) error) error { + var err error - return fn(s.data) + s.Load(func(data T) { + err = fn(data) + }) + + return err } -func (s *Value[T]) Set(data T) { +func (s *Value[T]) Save(data T) { s.lock.Lock() defer s.lock.Unlock() s.data = data } -func GetType[T, Ret any](s *Value[T], fn func(data T) Ret) Ret { - s.lock.RLock() - defer s.lock.RUnlock() +func (s *Value[T]) Mod(fn func(data *T)) { + s.lock.Lock() + defer s.lock.Unlock() - return fn(s.data) + fn(&s.data) } -func GetTypeErr[T, Ret any](s *Value[T], fn func(data T) (Ret, error)) (Ret, error) { - s.lock.RLock() - defer s.lock.RUnlock() +func LoadRet[T, Ret any](s *Value[T], fn func(data T) Ret) Ret { + var ret Ret - return fn(s.data) + s.Load(func(data T) { + ret = fn(data) + }) + + return ret +} + +func LoadRetErr[T, Ret any](s *Value[T], fn func(data T) (Ret, error)) (Ret, error) { + var ret Ret + + err := s.LoadErr(func(data T) error { + var err error + + ret, err = fn(data) + + return err + }) + + return ret, err } diff --git a/internal/safe/value_test.go b/internal/safe/value_test.go new file mode 100644 index 00000000..4ce4acaf --- /dev/null +++ b/internal/safe/value_test.go @@ -0,0 +1,37 @@ +package safe + +import "testing" + +func TestValue(t *testing.T) { + v := NewValue("foo") + + v.Load(func(data string) { + if data != "foo" { + t.Error("expected foo") + } + }) + + v.Save("bar") + + v.Load(func(data string) { + if data != "bar" { + t.Error("expected bar") + } + }) + + v.Mod(func(data *string) { + *data = "baz" + }) + + v.Load(func(data string) { + if data != "baz" { + t.Error("expected baz") + } + }) + + if LoadRet(v, func(data string) string { + return data + }) != "baz" { + t.Error("expected baz") + } +} diff --git a/internal/try/try.go b/internal/try/try.go index f26a1669..4d97d081 100644 --- a/internal/try/try.go +++ b/internal/try/try.go @@ -2,6 +2,7 @@ package try import ( "fmt" + "sync" "github.com/sirupsen/logrus" ) @@ -47,3 +48,31 @@ func catch(handlers ...func() error) { } } } + +type Group struct { + mu sync.Mutex +} + +func (wg *Group) GoTry(fn func(bool)) { + if wg.mu.TryLock() { + go func() { + defer wg.mu.Unlock() + fn(true) + }() + } else { + go fn(false) + } +} + +func (wg *Group) Lock() { + wg.mu.Lock() +} + +func (wg *Group) Unlock() { + wg.mu.Unlock() +} + +func (wg *Group) Wait() { + wg.mu.Lock() + defer wg.mu.Unlock() +} diff --git a/internal/user/events.go b/internal/user/events.go index 761db23f..03c4ca94 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -8,7 +8,6 @@ import ( "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/events" - "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" "gitlab.protontech.ch/go/liteapi" @@ -28,12 +27,6 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error } } - if event.MailSettings != nil { - if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil { - return err - } - } - if len(event.Labels) > 0 { if err := user.handleLabelEvents(ctx, event.Labels); err != nil { return err @@ -51,14 +44,7 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error // handleUserEvent handles the given user event. func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error { - userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil) - if err != nil { - return err - } - - user.apiUser.Set(userEvent) - - user.userKR.Set(userKR) + user.apiUser.Save(userEvent) user.eventCh.Enqueue(events.UserChanged{ UserID: user.ID(), @@ -93,22 +79,18 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea } func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { - addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) { - return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR) - }) - if err != nil { - return fmt.Errorf("failed to unlock address keys: %w", err) + user.apiAddrs.Set(event.Address.ID, event.Address) + + switch user.vault.AddressMode() { + case vault.CombinedMode: + user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) { + user.updateCh.SetFrom(event.Address.ID, addrID) + }) + + case vault.SplitMode: + user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0)) } - apiAddrs, err := user.client.GetAddresses(ctx) - if err != nil { - return fmt.Errorf("failed to get addresses: %w", err) - } - - user.apiAddrs.Set(apiAddrs) - - user.addrKRs.Set(event.Address.ID, addrKR) - user.eventCh.Enqueue(events.UserAddressCreated{ UserID: user.ID(), AddressID: event.Address.ID, @@ -116,9 +98,11 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad }) if user.vault.AddressMode() == vault.SplitMode { - user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0) - - if err := syncLabels(ctx, user.client, user.updateCh[event.Address.ID]); err != nil { + if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error { + return syncLabels(ctx, user.client, updateCh) + }); !ok { + return fmt.Errorf("no such address %q", event.Address.ID) + } else if err != nil { return fmt.Errorf("failed to sync labels to new address: %w", err) } } @@ -127,21 +111,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad } func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { - addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) { - return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR) - }) - if err != nil { - return fmt.Errorf("failed to unlock address keys: %w", err) - } - - apiAddrs, err := user.client.GetAddresses(ctx) - if err != nil { - return fmt.Errorf("failed to get addresses: %w", err) - } - - user.apiAddrs.Set(apiAddrs) - - user.addrKRs.Set(event.Address.ID, addrKR) + user.apiAddrs.Set(event.Address.ID, event.Address) user.eventCh.Enqueue(events.UserAddressUpdated{ UserID: user.ID(), @@ -153,25 +123,20 @@ func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.Ad } func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { - email, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) { - return getAddrEmail(apiAddrs, event.ID) - }) - if err != nil { - return fmt.Errorf("failed to get address email: %w", err) + var email string + + if ok := user.apiAddrs.GetDelete(event.ID, func(apiAddr liteapi.Address) { + email = apiAddr.Email + }); !ok { + return fmt.Errorf("no such address %q", event.ID) } - apiAddrs, err := user.client.GetAddresses(ctx) - if err != nil { - return fmt.Errorf("failed to get addresses: %w", err) - } - - user.apiAddrs.Set(apiAddrs) - - user.addrKRs.Delete(event.ID) - - if len(user.updateCh) > 1 { - user.updateCh[event.ID].Close() - delete(user.updateCh, event.ID) + if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) { + if user.vault.AddressMode() == vault.SplitMode { + updateCh.Close() + } + }); !ok { + return fmt.Errorf("no such address %q", event.ID) } user.eventCh.Enqueue(events.UserAddressDeleted{ @@ -183,13 +148,6 @@ func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.Ad return nil } -// handleMailSettingsEvent handles the given mail settings event. -func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error { - user.settings.Set(mailSettingsEvent) - - return nil -} - // handleLabelEvents handles the given label events. func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error { for _, event := range labelEvents { @@ -215,25 +173,25 @@ func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.L } func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error { - for _, updateCh := range user.updateCh { + user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label))) - } + }) return nil } func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error { - for _, updateCh := range user.updateCh { + user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label))) - } + }) return nil } func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error { - for _, updateCh := range user.updateCh { + user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID))) - } + }) return nil } @@ -269,29 +227,18 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me return fmt.Errorf("failed to get full message: %w", err) } - buildRes, err := safe.GetMapErr( - user.addrKRs, - full.AddressID, - func(addrKR *crypto.KeyRing) (*buildRes, error) { - return buildRFC822(ctx, full, addrKR) - }, - func() (*buildRes, error) { - return nil, fmt.Errorf("address keyring not found") - }, - ) - if err != nil { - return fmt.Errorf("failed to build RFC822: %w", err) - } + return user.withAddrKR(event.Message.AddressID, func(addrKR *crypto.KeyRing) error { + buildRes, err := buildRFC822(ctx, full, addrKR) + if err != nil { + return fmt.Errorf("failed to build RFC822 message: %w", err) + } - if len(user.updateCh) > 1 { - user.updateCh[buildRes.addressID].Enqueue(imap.NewMessagesCreated(buildRes.update)) - } else { - user.apiAddrs.Get(func(apiAddrs []liteapi.Address) { - user.updateCh[apiAddrs[0].ID].Enqueue(imap.NewMessagesCreated(buildRes.update)) + user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) { + updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update)) }) - } - return nil + return nil + }) } func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error { @@ -302,13 +249,9 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.Me event.Message.Starred(), ) - if len(user.updateCh) > 1 { - user.updateCh[event.Message.AddressID].Enqueue(update) - } else { - user.apiAddrs.Get(func(apiAddrs []liteapi.Address) { - user.updateCh[apiAddrs[0].ID].Enqueue(update) - }) - } + user.updateCh.Get(event.Message.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) { + updateCh.Enqueue(update) + }) return nil } diff --git a/internal/user/imap.go b/internal/user/imap.go index 06cc32ab..9f67608b 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -2,13 +2,13 @@ package user import ( "context" - "crypto/subtle" "fmt" - "strings" "time" "github.com/ProtonMail/gluon/imap" - "github.com/bradenaw/juniper/xslices" + "github.com/ProtonMail/gluon/queue" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "gitlab.protontech.ch/go/liteapi" "golang.org/x/exp/slices" ) @@ -25,27 +25,18 @@ const ( ) type imapConnector struct { - client *liteapi.Client - updateCh <-chan imap.Update + *User - emails []string - password []byte + addrID string flags, permFlags, attrs imap.FlagSet } -func newIMAPConnector( - client *liteapi.Client, - updateCh <-chan imap.Update, - password []byte, - emails ...string, -) *imapConnector { +func newIMAPConnector(user *User, addrID string) *imapConnector { return &imapConnector{ - client: client, - updateCh: updateCh, + User: user, - emails: emails, - password: password, + addrID: addrID, flags: defaultFlags, permFlags: defaultPermanentFlags, @@ -55,13 +46,16 @@ func newIMAPConnector( // Authorize returns whether the given username/password combination are valid for this connector. func (conn *imapConnector) Authorize(username string, password []byte) bool { - if subtle.ConstantTimeCompare(conn.password, password) != 1 { + addrID, err := conn.checkAuth(username, password) + if err != nil { return false } - return xslices.IndexFunc(conn.emails, func(address string) bool { - return strings.EqualFold(address, username) - }) >= 0 + if conn.vault.AddressMode() == vault.SplitMode && addrID != conn.addrID { + return false + } + + return true } // GetLabel returns information about the label with the given ID. @@ -246,7 +240,14 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [ // GetUpdates returns a stream of updates that the gluon server should apply. // It is recommended that the returned channel is buffered with at least constants.ChannelBufferCount. func (conn *imapConnector) GetUpdates() <-chan imap.Update { - return conn.updateCh + updateCh, ok := safe.MapGetRet(conn.updateCh, conn.addrID, func(updateCh *queue.QueuedChannel[imap.Update]) <-chan imap.Update { + return updateCh.GetChannel() + }) + if !ok { + panic(fmt.Sprintf("update channel for %q not found", conn.addrID)) + } + + return updateCh } // GetUIDValidity returns the default UID validity for this user. diff --git a/internal/user/keys.go b/internal/user/keys.go new file mode 100644 index 00000000..7bad30dd --- /dev/null +++ b/internal/user/keys.go @@ -0,0 +1,60 @@ +package user + +import ( + "fmt" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "gitlab.protontech.ch/go/liteapi" +) + +func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error { + return user.apiUser.LoadErr(func(apiUser liteapi.User) error { + userKR, err := apiUser.Keys.Unlock(user.vault.KeyPass(), nil) + if err != nil { + return fmt.Errorf("failed to unlock user keys: %w", err) + } + defer userKR.ClearPrivateParams() + + return fn(userKR) + }) +} + +func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing) error) error { + return user.withUserKR(func(userKR *crypto.KeyRing) error { + if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error { + addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR) + if err != nil { + return fmt.Errorf("failed to unlock address keys: %w", err) + } + defer userKR.ClearPrivateParams() + + return fn(addrKR) + }); !ok { + return fmt.Errorf("no such address %q", addrID) + } else if err != nil { + return err + } + + return nil + }) +} + +func (user *User) withAddrKRs(fn func(map[string]*crypto.KeyRing) error) error { + return user.withUserKR(func(userKR *crypto.KeyRing) error { + return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { + addrKRs := make(map[string]*crypto.KeyRing) + + for _, apiAddr := range apiAddrs { + addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR) + if err != nil { + return fmt.Errorf("failed to unlock address keys: %w", err) + } + defer userKR.ClearPrivateParams() + + addrKRs[apiAddr.ID] = addrKR + } + + return fn(addrKRs) + }) + }) +} diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 3a0e8193..60dc5dfa 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -34,22 +34,25 @@ type smtpSession struct { // from is the current sending address (taken from the return path). from string + // fromAddrID is the ID of the curent sending address (taken from the return path). + fromAddrID string + // to holds all to for the current message. to []string } func newSMTPSession(user *User, email string) (*smtpSession, error) { - authID, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) { - return getAddrID(apiAddrs, email) - }) - if err != nil { - return nil, fmt.Errorf("failed to get address ID: %w", err) - } + return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (*smtpSession, error) { + authID, err := getAddrID(apiAddrs, email) + if err != nil { + return nil, fmt.Errorf("failed to get address ID: %w", err) + } - return &smtpSession{ - User: user, - authID: authID, - }, nil + return &smtpSession{ + User: user, + authID: authID, + }, nil + }) } // Discard currently processed message. @@ -58,6 +61,7 @@ func (session *smtpSession) Reset() { // Clear the from and to fields. session.from = "" + session.fromAddrID = "" session.to = nil } @@ -74,7 +78,7 @@ func (session *smtpSession) Logout() error { func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error { logrus.Info("SMTP session mail") - return session.apiAddrs.GetErr(func(apiAddrs []liteapi.Address) error { + return session.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { switch { case opts.RequireTLS: return ErrNotImplemented @@ -93,12 +97,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error { } } - if _, err := getAddrID(apiAddrs, sanitizeEmail(from)); err != nil { + addrID, err := getAddrID(apiAddrs, sanitizeEmail(from)) + if err != nil { return fmt.Errorf("invalid return path: %w", err) } session.from = from + session.fromAddrID = addrID + return nil }) } @@ -138,18 +145,13 @@ func (session *smtpSession) Data(r io.Reader) error { return fmt.Errorf("failed to create parser: %w", err) } - message, err := safe.GetSliceErr(session.apiAddrs, func(apiAddrs []liteapi.Address) (liteapi.Message, error) { - addrID, err := getAddrID(apiAddrs, session.from) - if err != nil { - return liteapi.Message{}, fmt.Errorf("invalid return path: %w", err) - } - - return safe.GetMapErr(session.addrKRs, addrID, func(addrKR *crypto.KeyRing) (liteapi.Message, error) { - return safe.GetTypeErr(session.settings, func(settings liteapi.MailSettings) (liteapi.Message, error) { + return session.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { + return session.withAddrKR(session.fromAddrID, func(addrKR *crypto.KeyRing) error { + return session.withUserKR(func(userKR *crypto.KeyRing) error { // Use the first key for encrypting the message. addrKR, err := addrKR.FirstKey() if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to get first key: %w", err) + return fmt.Errorf("failed to get first key: %w", err) } // If the message contains a sender, use it instead of the one from the return path. @@ -157,51 +159,61 @@ func (session *smtpSession) Data(r io.Reader) error { session.from = sender } + // Load the user's mail settings. + settings, err := session.client.GetMailSettings(ctx) + if err != nil { + return fmt.Errorf("failed to get mail settings: %w", err) + } + // If we have to attach the public key, do it now. if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { key, err := addrKR.GetKey(0) if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to get sending key: %w", err) + return fmt.Errorf("failed to get sending key: %w", err) } pubKey, err := key.GetArmoredPublicKey() if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to get public key: %w", err) + return fmt.Errorf("failed to get public key: %w", err) } parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) } + // Parse the message we want to send (after we have attached the public key). message, err := message.ParseWithParser(parser) if err != nil { - return liteapi.Message{}, fmt.Errorf("failed to parse message: %w", err) + return fmt.Errorf("failed to parse message: %w", err) } - return sendWithKey( + // Collect all the user's emails so we can match them to the outgoing message. + emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string { + return addr.Email + }) + + sent, err := sendWithKey( ctx, session.client, session.authID, session.vault.AddressMode(), - apiAddrs, settings, - session.userKR, + userKR, addrKR, + emails, session.from, session.to, message, ) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + logrus.WithField("messageID", sent.ID).Info("Message sent") + + return nil }) - }, func() (liteapi.Message, error) { - return liteapi.Message{}, ErrMissingAddrKey }) }) - if err != nil { - return fmt.Errorf("failed to send message: %w", err) - } - - logrus.WithField("messageID", message.ID).Info("Message sent") - - return nil } // sendWithKey sends the message with the given address key. @@ -210,10 +222,10 @@ func sendWithKey( client *liteapi.Client, authAddrID string, addrMode vault.AddressMode, - apiAddrs []liteapi.Address, settings liteapi.MailSettings, - userKR *safe.Value[*crypto.KeyRing], + userKR *crypto.KeyRing, addrKR *crypto.KeyRing, + emails []string, from string, to []string, message message.Message, @@ -243,7 +255,7 @@ func sendWithKey( return liteapi.Message{}, fmt.Errorf("failed to get armored message body: %w", err) } - draft, err := createDraft(ctx, client, apiAddrs, from, to, parentID, liteapi.DraftTemplate{ + draft, err := createDraft(ctx, client, emails, from, to, parentID, liteapi.DraftTemplate{ Subject: message.Subject, Body: armBody, MIMEType: message.MIMEType, @@ -264,9 +276,7 @@ func sendWithKey( return liteapi.Message{}, fmt.Errorf("failed to create attachments: %w", err) } - recipients, err := safe.GetTypeErr(userKR, func(userKR *crypto.KeyRing) (recipients, error) { - return getRecipients(ctx, client, userKR, settings, draft) - }) + recipients, err := getRecipients(ctx, client, userKR, settings, draft) if err != nil { return liteapi.Message{}, fmt.Errorf("failed to get recipients: %w", err) } @@ -357,7 +367,7 @@ func getParentID( func createDraft( ctx context.Context, client *liteapi.Client, - apiAddrs []liteapi.Address, + emails []string, from string, to []string, parentID string, @@ -371,12 +381,12 @@ func createDraft( } // Check that the sending address is owned by the user, and if so, sanitize it. - if idx := xslices.IndexFunc(apiAddrs, func(addr liteapi.Address) bool { - return strings.EqualFold(addr.Email, sanitizeEmail(template.Sender.Address)) + if idx := xslices.IndexFunc(emails, func(email string) bool { + return strings.EqualFold(email, sanitizeEmail(template.Sender.Address)) }); idx < 0 { return liteapi.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address) } else { - template.Sender.Address = constructEmail(template.Sender.Address, apiAddrs[idx].Email) + template.Sender.Address = constructEmail(template.Sender.Address, emails[idx]) } // Check ToList: ensure that ToList only contains addresses we actually plan to send to. diff --git a/internal/user/sync.go b/internal/user/sync.go index ace40437..216eff98 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -10,12 +10,13 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gopenpgp/v2/crypto" - "github.com/ProtonMail/proton-bridge/v2/internal/safe" + "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" + "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" - "golang.org/x/exp/maps" ) const ( @@ -24,27 +25,43 @@ const ( ) func (user *User) sync(ctx context.Context) error { - if !user.vault.SyncStatus().HasLabels { - if err := syncLabels(ctx, user.client, maps.Values(user.updateCh)...); err != nil { - return fmt.Errorf("failed to sync labels: %w", err) + return user.withAddrKRs(func(addrKRs map[string]*crypto.KeyRing) error { + logrus.Info("Beginning sync") + + if !user.vault.SyncStatus().HasLabels { + logrus.Info("Syncing labels") + + if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error { + return syncLabels(ctx, user.client, xslices.Unique(updateCh)...) + }); err != nil { + return fmt.Errorf("failed to sync labels: %w", err) + } + + if err := user.vault.SetHasLabels(true); err != nil { + return fmt.Errorf("failed to set has labels: %w", err) + } + } else { + logrus.Info("Labels are already synced, skipping") } - if err := user.vault.SetHasLabels(true); err != nil { - return fmt.Errorf("failed to set has labels: %w", err) - } - } + if !user.vault.SyncStatus().HasMessages { + logrus.Info("Syncing labels") - if !user.vault.SyncStatus().HasMessages { - if err := user.syncMessages(ctx); err != nil { - return fmt.Errorf("failed to sync messages: %w", err) + if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error { + return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh) + }); err != nil { + return fmt.Errorf("failed to sync messages: %w", err) + } + + if err := user.vault.SetHasMessages(true); err != nil { + return fmt.Errorf("failed to set has messages: %w", err) + } + } else { + logrus.Info("Messages are already synced, skipping") } - if err := user.vault.SetHasMessages(true); err != nil { - return fmt.Errorf("failed to set has messages: %w", err) - } - } - - return nil + return nil + }) } func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error { @@ -102,48 +119,44 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue. return nil } -func (user *User) syncMessages(ctx context.Context) error { +func syncMessages( + ctx context.Context, + userID string, + client *liteapi.Client, + vault *vault.User, + addrKRs map[string]*crypto.KeyRing, + updateCh map[string]*queue.QueuedChannel[imap.Update], + eventCh *queue.QueuedChannel[events.Event], +) error { // Determine which messages to sync. - allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil) + metadata, err := client.GetAllMessageMetadata(ctx, nil) if err != nil { return fmt.Errorf("get all message metadata: %w", err) } - metadata := allMetadata + // Get the message IDs to sync. + messageIDs := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string { + return metadata.ID + }) // If possible, begin syncing from one beyond the last synced message. - if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" { - if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool { - return metadata.ID == beginID - }); idx >= 0 { - metadata = metadata[idx+1:] - } + if idx := xslices.Index(messageIDs, vault.SyncStatus().LastMessageID); idx >= 0 { + messageIDs = messageIDs[idx+1:] } - // Process the metadata, building the messages. - buildCh := stream.Chunk(stream.Map( - user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string { - return metadata.ID - })...), + // Fetch and build each message. + buildCh := stream.Map( + client.GetFullMessages(ctx, messageIDs...), func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) { - return safe.GetMapErr( - user.addrKRs, - full.AddressID, - func(addrKR *crypto.KeyRing) (*buildRes, error) { - return buildRFC822(ctx, full, addrKR) - }, - func() (*buildRes, error) { - return nil, fmt.Errorf("address keyring not found") - }, - ) + return buildRFC822(ctx, full, addrKRs[full.AddressID]) }, - ), maxBatchSize) + ) defer buildCh.Close() // Create the flushers, one per update channel. flushers := make(map[string]*flusher) - for addrID, updateCh := range user.updateCh { + for addrID, updateCh := range updateCh { flusher := newFlusher(updateCh, maxUpdateSize) defer flusher.flush(ctx, true) @@ -151,42 +164,27 @@ func (user *User) syncMessages(ctx context.Context) error { } // Create a reporter to report sync progress updates. - reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second) + reporter := newReporter(userID, eventCh, len(messageIDs), time.Second) defer reporter.done() - var count int - // Send each update to the appropriate flusher. - for { - batch, err := buildCh.Next(ctx) - if errors.Is(err, stream.End) { - return nil - } else if err != nil { - return fmt.Errorf("failed to get next sync batch: %w", err) + return forEach(ctx, stream.Chunk(buildCh, maxBatchSize), func(batch []*buildRes) error { + for _, res := range batch { + flushers[res.addressID].push(ctx, res.update) } - user.apiAddrs.Get(func(apiAddrs []liteapi.Address) { - for _, res := range batch { - if len(flushers) > 1 { - flushers[res.addressID].push(ctx, res.update) - } else { - flushers[apiAddrs[0].ID].push(ctx, res.update) - } - } - }) - for _, flusher := range flushers { flusher.flush(ctx, true) } - if err := user.vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil { + if err := vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil { return fmt.Errorf("failed to set last synced message ID: %w", err) } reporter.add(len(batch)) - count += len(batch) - } + return nil + }) } func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated { @@ -232,3 +230,18 @@ func wantLabelID(labelID string) bool { return true } } + +func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error { + for { + res, err := streamer.Next(ctx) + if errors.Is(err, stream.End) { + return nil + } else if err != nil { + return fmt.Errorf("failed to get next stream item: %w", err) + } + + if err := fn(res); err != nil { + return fmt.Errorf("failed to process stream item: %w", err) + } + } +} diff --git a/internal/user/types.go b/internal/user/types.go index 7fca1530..fbef2336 100644 --- a/internal/user/types.go +++ b/internal/user/types.go @@ -1,10 +1,17 @@ package user import ( + "context" + "encoding/hex" "fmt" "reflect" + + "gitlab.protontech.ch/go/liteapi" ) +// mapTo converts the slice to the given type. +// This is not runtime safe, so make sure the slice is of the correct type! +// (This is a workaround for the fact that slices cannot be converted to other types generically). func mapTo[From, To any](from []From) []To { to := make([]To, 0, len(from)) @@ -19,3 +26,79 @@ func mapTo[From, To any](from []From) []To { return to } + +// groupBy returns a map of the given slice grouped by the given key. +// Duplicate keys are overwritten. +func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value { + groups := make(map[Key]Value) + + for _, item := range items { + groups[key(item)] = item + } + + return groups +} + +// sortAddr returns whether the first address should be sorted before the second. +func sortAddr(addrIDA, addrIDB string, apiAddrs map[string]liteapi.Address) bool { + return apiAddrs[addrIDA].Order < apiAddrs[addrIDB].Order +} + +// hexEncode returns the hexadecimal encoding of the given byte slice. +func hexEncode(b []byte) []byte { + enc := make([]byte, hex.EncodedLen(len(b))) + + hex.Encode(enc, b) + + return enc +} + +// hexDecode returns the bytes represented by the hexadecimal encoding of the given byte slice. +func hexDecode(b []byte) ([]byte, error) { + dec := make([]byte, hex.DecodedLen(len(b))) + + if _, err := hex.Decode(dec, b); err != nil { + return nil, err + } + + return dec, nil +} + +// getAddrID returns the address ID for the given email address. +func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) { + for _, addr := range apiAddrs { + if addr.Email == email { + return addr.ID, nil + } + } + + return "", fmt.Errorf("address %s not found", email) +} + +// getAddrEmail returns the email address of the given address ID. +func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) { + for _, addr := range apiAddrs { + if addr.ID == addrID { + return addr.Email, nil + } + } + + return "", fmt.Errorf("address %s not found", addrID) +} + +// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it. +func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + + go func() { + select { + case <-stopCh: + cancel() + + case <-ctx.Done(): + // ... + } + }() + + return ctx, cancel +} diff --git a/internal/user/user.go b/internal/user/user.go index 90514522..29de575a 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -1,19 +1,18 @@ package user import ( - "bytes" "context" - "encoding/hex" + "crypto/subtle" "fmt" + "strings" "time" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/queue" - "github.com/ProtonMail/gluon/wait" - "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/safe" + "github.com/ProtonMail/proton-bridge/v2/internal/try" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" "github.com/emersion/go-smtp" @@ -32,15 +31,11 @@ type User struct { eventCh *queue.QueuedChannel[events.Event] apiUser *safe.Value[liteapi.User] - apiAddrs *safe.Slice[liteapi.Address] - settings *safe.Value[liteapi.MailSettings] + apiAddrs *safe.Map[string, liteapi.Address] + updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] - userKR *safe.Value[*crypto.KeyRing] - addrKRs *safe.Map[string, *crypto.KeyRing] - - updateCh map[string]*queue.QueuedChannel[imap.Update] syncStopCh chan struct{} - syncWG wait.Group + syncLock try.Group } func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User) (*User, error) { @@ -50,9 +45,8 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU return nil, fmt.Errorf("failed to get addresses: %w", err) } - // Unlock the user's keyrings. - userKR, addrKRs, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass()) - if err != nil { + // Check we can unlock the keyrings. + if _, _, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass()); err != nil { return nil, fmt.Errorf("failed to unlock user: %w", err) } @@ -68,20 +62,21 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU } } - // Get the user's mail settings. - settings, err := client.GetMailSettings(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get mail settings: %w", err) - } - - // Create update channels for each of the user's addresses (if in combined mode, just the primary). + // Create update channels for each of the user's addresses. + // In combined mode, the addresses all share the same update channel. updateCh := make(map[string]*queue.QueuedChannel[imap.Update]) - for _, addr := range apiAddrs { - updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0) + switch encVault.AddressMode() { + case vault.CombinedMode: + primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) - if encVault.AddressMode() == vault.CombinedMode { - break + for _, addr := range apiAddrs { + updateCh[addr.ID] = primaryUpdateCh + } + + case vault.SplitMode: + for _, addr := range apiAddrs { + updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0) } } @@ -91,19 +86,15 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU eventCh: queue.NewQueuedChannel[events.Event](0, 0), apiUser: safe.NewValue(apiUser), - apiAddrs: safe.NewSlice(apiAddrs), - settings: safe.NewValue(settings), + apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), + updateCh: safe.NewMapFrom(updateCh, nil), - userKR: safe.NewValue(userKR), - addrKRs: safe.NewMap(addrKRs), - - updateCh: updateCh, syncStopCh: make(chan struct{}), } // When we receive an auth object, we update it in the vault. // This will be used to authorize the user on the next run. - client.AddAuthHandler(func(auth liteapi.Auth) { + user.client.AddAuthHandler(func(auth liteapi.Auth) { if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil { logrus.WithError(err).Error("Failed to update auth in vault") } @@ -111,23 +102,24 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU // When we are deauthorized, we send a deauth event to the event channel. // Bridge will react to this event by logging out the user. - client.AddDeauthHandler(func() { + user.client.AddDeauthHandler(func() { user.eventCh.Enqueue(events.UserDeauth{ UserID: user.ID(), }) }) + // TODO: Don't start the event loop until the initial sync has finished! + eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID()) + // If we haven't synced yet, do it first. // If it fails, we don't start the event loop. // Otherwise, begin processing API events, logging any errors that occur. go func() { - if status := user.vault.SyncStatus(); !status.HasMessages { - if err := <-user.startSync(); err != nil { - return - } + if err := <-user.startSync(); err != nil { + return } - for err := range user.streamEvents() { + for err := range user.streamEvents(eventCh) { logrus.WithError(err).Error("Error while streaming events") } }() @@ -137,40 +129,34 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU // ID returns the user's ID. func (user *User) ID() string { - return safe.GetType(user.apiUser, func(apiUser liteapi.User) string { + return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string { return apiUser.ID }) } // Name returns the user's username. func (user *User) Name() string { - return safe.GetType(user.apiUser, func(apiUser liteapi.User) string { + return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string { return apiUser.Name }) } // Match matches the given query against the user's username and email addresses. func (user *User) Match(query string) bool { - return safe.GetType(user.apiUser, func(apiUser liteapi.User) bool { - return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) bool { - if query == apiUser.Name { - return true - } + return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) bool { + if query == apiUser.Name { + return true + } - for _, addr := range apiAddrs { - if addr.Email == query { - return true - } - } - - return false + return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool { + return addr.Email == query }) }) } -// Emails returns all the user's email addresses. +// Emails returns all the user's email addresses via the callback. func (user *User) Emails() []string { - return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) []string { + return safe.MapValuesRet(user.apiAddrs, func(apiAddrs []liteapi.Address) []string { return xslices.Map(apiAddrs, func(addr liteapi.Address) string { return addr.Email }) @@ -184,28 +170,38 @@ func (user *User) GetAddressMode() vault.AddressMode { // SetAddressMode sets the user's address mode. func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error { - for _, updateCh := range user.updateCh { - updateCh.Close() - } + user.stopSync() + user.lockSync() + defer user.unlockSync() - user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update]) - - user.apiAddrs.Get(func(apiAddrs []liteapi.Address) { - for _, addr := range apiAddrs { - user.updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0) - - if mode == vault.CombinedMode { - break - } + user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { + for _, updateCh := range xslices.Unique(updateCh) { + updateCh.Close() } }) + updateCh := make(map[string]*queue.QueuedChannel[imap.Update]) + + switch mode { + case vault.CombinedMode: + primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) + + user.apiAddrs.IterKeys(func(addrID string) { + updateCh[addrID] = primaryUpdateCh + }) + + case vault.SplitMode: + user.apiAddrs.IterKeys(func(addrID string) { + updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) + }) + } + + user.updateCh = safe.NewMapFrom(updateCh, nil) + if err := user.vault.SetAddressMode(mode); err != nil { return fmt.Errorf("failed to set address mode: %w", err) } - user.stopSync() - if err := user.vault.ClearSyncStatus(); err != nil { return fmt.Errorf("failed to clear sync status: %w", err) } @@ -246,25 +242,19 @@ func (user *User) GluonKey() []byte { // BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP. func (user *User) BridgePass() []byte { - buf := new(bytes.Buffer) - - if _, err := hex.NewEncoder(buf).Write(user.vault.BridgePass()); err != nil { - panic(err) - } - - return buf.Bytes() + return hexEncode(user.vault.BridgePass()) } // UsedSpace returns the total space used by the user on the API. func (user *User) UsedSpace() int { - return safe.GetType(user.apiUser, func(apiUser liteapi.User) int { + return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int { return apiUser.UsedSpace }) } // MaxSpace returns the amount of space the user can use on the API. func (user *User) MaxSpace() int { - return safe.GetType(user.apiUser, func(apiUser liteapi.User) int { + return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int { return apiUser.MaxSpace }) } @@ -275,37 +265,9 @@ func (user *User) GetEventCh() <-chan events.Event { } // NewIMAPConnector returns an IMAP connector for the given address. -// If not in split mode, this function returns an error. -func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) { - return safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (connector.Connector, error) { - var emails []string - - switch user.vault.AddressMode() { - case vault.CombinedMode: - if addrID != apiAddrs[0].ID { - return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode") - } - - emails = xslices.Map(apiAddrs, func(addr liteapi.Address) string { - return addr.Email - }) - - case vault.SplitMode: - email, err := getAddrEmail(apiAddrs, addrID) - if err != nil { - return nil, err - } - - emails = []string{email} - } - - return newIMAPConnector( - user.client, - user.updateCh[addrID].GetChannel(), - user.BridgePass(), - emails..., - ), nil - }) +// If not in split mode, this must be the primary address. +func (user *User) NewIMAPConnector(addrID string) connector.Connector { + return newIMAPConnector(user, addrID) } // NewIMAPConnectors returns IMAP connectors for each of the user's addresses. @@ -314,23 +276,48 @@ func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) { func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) { imapConn := make(map[string]connector.Connector) - for addrID := range user.updateCh { - conn, err := user.NewIMAPConnector(addrID) - if err != nil { - return nil, fmt.Errorf("failed to create IMAP connector: %w", err) - } + switch user.vault.AddressMode() { + case vault.CombinedMode: + user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) { + imapConn[addrID] = newIMAPConnector(user, addrID) + }) - imapConn[addrID] = conn + case vault.SplitMode: + user.apiAddrs.IterKeys(func(addrID string) { + imapConn[addrID] = newIMAPConnector(user, addrID) + }) } return imapConn, nil } // NewSMTPSession returns an SMTP session for the user. -func (user *User) NewSMTPSession(email string) (smtp.Session, error) { +func (user *User) NewSMTPSession(email string, password []byte) (smtp.Session, error) { + if _, err := user.checkAuth(email, password); err != nil { + return nil, err + } + return newSMTPSession(user, email) } +// OnStatusUp is called when the connection goes up. +func (user *User) OnStatusUp() { + go func() { + logrus.Info("Connection up, checking if sync is needed") + + if err := <-user.startSync(); err != nil { + logrus.WithError(err).Error("Failed to sync on status up") + } + }() +} + +// OnStatusDown is called when the connection goes down. +func (user *User) OnStatusDown() { + logrus.Info("Connection down, aborting any ongoing syncs") + + user.stopSync() +} + // Logout logs the user out from the API. // If withVault is true, the user's vault is also cleared. func (user *User) Logout(ctx context.Context) error { @@ -350,13 +337,18 @@ func (user *User) Close() error { // Cancel ongoing syncs. user.stopSync() + // Wait for ongoing syncs to stop. + user.waitSync() + // Close the user's API client. user.client.Close() // Close the user's update channels. - for _, updateCh := range user.updateCh { - updateCh.Close() - } + user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { + for _, updateCh := range xslices.Unique(updateCh) { + updateCh.Close() + } + }) // Close the user's notify channel. user.eventCh.Close() @@ -364,16 +356,37 @@ func (user *User) Close() error { return nil } +func (user *User) checkAuth(email string, password []byte) (string, error) { + dec, err := hexDecode(password) + if err != nil { + return "", fmt.Errorf("failed to decode password: %w", err) + } + + if subtle.ConstantTimeCompare(user.vault.BridgePass(), dec) != 1 { + return "", fmt.Errorf("invalid password") + } + + return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) { + for _, addr := range apiAddrs { + if addr.Email == strings.ToLower(email) { + return addr.ID, nil + } + } + + return "", fmt.Errorf("invalid email") + }) +} + // streamEvents begins streaming API events for the user. // When we receive an API event, we attempt to handle it. // If successful, we update the event ID in the vault. -func (user *User) streamEvents() <-chan error { +func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error { errCh := make(chan error) go func() { defer close(errCh) - for event := range user.client.NewEventStreamer(EventPeriod, EventJitter, user.vault.EventID()).Subscribe() { + for event := range eventCh { if err := user.handleAPIEvent(context.Background(), event); err != nil { errCh <- fmt.Errorf("failed to handle API event: %w", err) } else if err := user.vault.SetEventID(event.EventID); err != nil { @@ -387,11 +400,21 @@ func (user *User) streamEvents() <-chan error { // startSync begins a startSync for the user. func (user *User) startSync() <-chan error { + if user.vault.SyncStatus().IsComplete() { + logrus.Debug("Already synced, skipping") + return nil + } + errCh := make(chan error) - user.syncWG.Go(func() { + user.syncLock.GoTry(func(ok bool) { defer close(errCh) + if !ok { + logrus.Debug("Sync already in progress, skipping") + return + } + ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh) defer cancel() @@ -421,46 +444,24 @@ func (user *User) startSync() <-chan error { func (user *User) stopSync() { select { case user.syncStopCh <- struct{}{}: - user.syncWG.Wait() + logrus.Debug("Sent sync abort signal") default: - // ... + logrus.Debug("No sync to abort") } } -func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) { - for _, addr := range apiAddrs { - if addr.Email == email { - return addr.ID, nil - } - } - - return "", fmt.Errorf("address %s not found", email) +// lockSync prevents a new sync from starting. +func (user *User) lockSync() { + user.syncLock.Lock() } -func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) { - for _, addr := range apiAddrs { - if addr.ID == addrID { - return addr.Email, nil - } - } - - return "", fmt.Errorf("address %s not found", addrID) +// unlockSync allows a new sync to start. +func (user *User) unlockSync() { + user.syncLock.Unlock() } -// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it. -func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(ctx) - - go func() { - select { - case <-stopCh: - cancel() - - case <-ctx.Done(): - // ... - } - }() - - return ctx, cancel +// waitSync waits for any ongoing sync to finish. +func (user *User) waitSync() { + user.syncLock.Wait() } diff --git a/internal/vault/types.go b/internal/vault/types.go index 870498ee..4a7bba26 100644 --- a/internal/vault/types.go +++ b/internal/vault/types.go @@ -103,6 +103,10 @@ type SyncStatus struct { LastMessageID string } +func (status SyncStatus) IsComplete() bool { + return status.HasLabels && status.HasMessages +} + func newDefaultUser(userID, username, authUID, authRef string, keyPass []byte) UserData { return UserData{ UserID: userID, diff --git a/internal/vault/vault.go b/internal/vault/vault.go index 468f9556..78dad6f8 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -10,15 +10,18 @@ import ( "math/rand" "os" "path/filepath" + "sync" "github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/bradenaw/juniper/xslices" ) +// Vault is an encrypted data vault that stores bridge and user data. type Vault struct { path string enc []byte gcm cipher.AEAD + lock sync.RWMutex } // New constructs a new encrypted data vault at the given filepath using the given encryption key. @@ -150,6 +153,9 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) { } func (vault *Vault) get() Data { + vault.lock.RLock() + defer vault.lock.RUnlock() + dec, err := decrypt(vault.gcm, vault.enc) if err != nil { panic(err) @@ -165,20 +171,28 @@ func (vault *Vault) get() Data { } func (vault *Vault) mod(fn func(data *Data)) error { - data := vault.get() + vault.lock.Lock() + defer vault.lock.Unlock() - fn(&data) - - return vault.set(data) -} - -func (vault *Vault) set(data Data) error { - dec, err := json.Marshal(data) + dec, err := decrypt(vault.gcm, vault.enc) if err != nil { return err } - enc, err := encrypt(vault.gcm, dec) + var data Data + + if err := json.Unmarshal(dec, &data); err != nil { + return err + } + + fn(&data) + + mod, err := json.Marshal(data) + if err != nil { + return err + } + + enc, err := encrypt(vault.gcm, mod) if err != nil { return err } diff --git a/internal/versioner/versioner_remove_test.go b/internal/versioner/versioner_remove_test.go index dcc22f1c..bc48ec8b 100644 --- a/internal/versioner/versioner_remove_test.go +++ b/internal/versioner/versioner_remove_test.go @@ -21,7 +21,6 @@ package versioner import ( - "os" "path/filepath" "testing" @@ -33,10 +32,9 @@ import ( // RemoveOldVersions is a noop on darwin; we don't test it there. func TestRemoveOldVersions(t *testing.T) { - updates, err := os.MkdirTemp(t.TempDir(), "updates") - require.NoError(t, err) + tempDir := t.TempDir() - v := newTestVersioner(t, "myCoolApp", updates, "2.3.4-beta", "2.3.4", "2.3.5", "2.4.0") + v := newTestVersioner(t, "myCoolApp", tempDir, "2.3.4-beta", "2.3.4", "2.3.5", "2.4.0") allVersions, err := v.ListVersions() require.NoError(t, err) @@ -49,5 +47,5 @@ func TestRemoveOldVersions(t *testing.T) { assert.Len(t, cleanedVersions, 1) assert.Equal(t, semver.MustParse("2.4.0"), cleanedVersions[0].version) - assert.Equal(t, filepath.Join(updates, "2.4.0"), cleanedVersions[0].path) + assert.Equal(t, filepath.Join(tempDir, "2.4.0"), cleanedVersions[0].path) } diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index 23fd6dca..1e7d2c3d 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -51,7 +51,7 @@ func (t *testCtx) startBridge() error { } // Create the bridge. - bridge, err := bridge.New( + bridge, eventCh, err := bridge.New( t.locator, vault, t.mocks.Autostarter, @@ -73,6 +73,9 @@ func (t *testCtx) startBridge() error { return err } + // Wait for the users to be loaded. + waitForEvent(eventCh, events.AllUsersLoaded{}) + // Save the bridge t. t.bridge = bridge @@ -101,3 +104,12 @@ func (t *testCtx) stopBridge() error { return nil } + +func waitForEvent[T any](eventCh <-chan events.Event, wantEvent T) { + for event := range eventCh { + switch event.(type) { + case T: + return + } + } +}