From e9672e6bbaaa891b49e79a55066686ee3fa6ac34 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Wed, 28 Sep 2022 11:29:33 +0200 Subject: [PATCH] GODT-1815: Combined/Split mode --- go.mod | 4 +- go.sum | 8 +- internal/bridge/bridge.go | 7 + internal/bridge/bridge_test.go | 36 ++- internal/bridge/configure.go | 5 +- .../bridge/{useragent.go => identifier.go} | 0 internal/bridge/settings.go | 34 +-- internal/bridge/smtp_backend.go | 4 +- internal/bridge/{users.go => user.go} | 184 +++++++----- internal/bridge/user_events.go | 118 ++++++++ .../bridge/{users_test.go => user_test.go} | 28 ++ internal/events/user.go | 26 +- internal/frontend/cli/accounts.go | 15 +- internal/frontend/cli/frontend.go | 4 +- internal/frontend/grpc/service.go | 8 +- internal/frontend/grpc/service_user.go | 14 +- internal/frontend/grpc/utils.go | 3 +- internal/pool/job.go | 41 +++ internal/pool/pool.go | 72 ++--- internal/pool/pool_test.go | 24 +- internal/user/addresses.go | 46 +++ internal/user/builder.go | 36 ++- internal/user/errors.go | 2 +- internal/user/events.go | 243 +++++++++------ internal/user/flusher.go | 76 +++++ internal/user/imap.go | 79 ++--- internal/user/map.go | 89 ++++++ internal/user/map_test.go | 48 +++ internal/user/smtp.go | 58 ++-- internal/user/sync.go | 189 ++++-------- internal/user/types.go | 13 + internal/user/types_test.go | 20 ++ internal/user/user.go | 282 +++++++++++++----- internal/user/user_test.go | 162 ++++++++++ internal/vault/token.go | 13 +- internal/vault/types.go | 16 +- internal/vault/user.go | 59 ++-- internal/vault/user_test.go | 35 ++- internal/vault/vault.go | 26 +- tests/api_test.go | 4 +- tests/bdd_test.go | 26 +- tests/bridge_test.go | 36 ++- tests/ctx_bridge_test.go | 10 +- tests/ctx_test.go | 59 ++-- tests/fast.go | 48 +++ .../imap/{user_agent.feature => id.feature} | 0 tests/features/imap/mailbox/info.feature | 2 +- tests/features/imap/message/copy.feature | 2 +- tests/features/imap/message/delete.feature | 2 +- tests/features/user/addressmode.feature | 180 +++++++++++ tests/features/user/sync.feature | 4 +- tests/imap_test.go | 24 +- tests/init_test.go | 33 +- tests/smtp_test.go | 14 +- tests/user_test.go | 43 ++- 55 files changed, 1909 insertions(+), 705 deletions(-) rename internal/bridge/{useragent.go => identifier.go} (100%) rename internal/bridge/{users.go => user.go} (69%) create mode 100644 internal/bridge/user_events.go rename internal/bridge/{users_test.go => user_test.go} (91%) create mode 100644 internal/pool/job.go create mode 100644 internal/user/addresses.go create mode 100644 internal/user/flusher.go create mode 100644 internal/user/map.go create mode 100644 internal/user/map_test.go create mode 100644 internal/user/types.go create mode 100644 internal/user/types_test.go create mode 100644 internal/user/user_test.go create mode 100644 tests/fast.go rename tests/features/imap/{user_agent.feature => id.feature} (100%) create mode 100644 tests/features/user/addressmode.feature diff --git a/go.mod b/go.mod index afe67a70..f3fe69b6 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557 github.com/Masterminds/semver/v3 v3.1.1 - github.com/ProtonMail/gluon v0.11.1-0.20220922143913-ef3617264557 + github.com/ProtonMail/gluon v0.11.1-0.20221001180052-2e11f5804b8a github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a github.com/ProtonMail/go-rfc5322 v0.11.0 github.com/ProtonMail/gopenpgp/v2 v2.4.10 @@ -37,7 +37,7 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.0 github.com/urfave/cli/v2 v2.16.3 - gitlab.protontech.ch/go/liteapi v0.31.0 + gitlab.protontech.ch/go/liteapi v0.31.1-0.20221001204216-b781c54ca2a6 golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/net v0.1.0 golang.org/x/sys v0.1.0 diff --git a/go.sum b/go.sum index 1e1cc0a4..8fc868a9 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= -github.com/ProtonMail/gluon v0.11.1-0.20220922143913-ef3617264557 h1:uyiHq7jDgn1p2TeMKRPnVCVs2bHoNL9AYs26UzLYr4I= -github.com/ProtonMail/gluon v0.11.1-0.20220922143913-ef3617264557/go.mod h1:9k3URQEASX9XSA+JEcukjIiK3S6aR9GzhLhwccy8AnI= +github.com/ProtonMail/gluon v0.11.1-0.20221001180052-2e11f5804b8a h1:JUjaQ7bUifpYdnLKviBPrVKOPfW6r4Mm8xCL1fdevaA= +github.com/ProtonMail/gluon v0.11.1-0.20221001180052-2e11f5804b8a/go.mod h1:9k3URQEASX9XSA+JEcukjIiK3S6aR9GzhLhwccy8AnI= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4= github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= @@ -463,8 +463,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.31.0 h1:Et3P2EyTySldgBunFJqaa5W5Fap8yLvuOaLUkZX/Kn0= -gitlab.protontech.ch/go/liteapi v0.31.0/go.mod h1:ixp1LUOxOYuB1qf172GdV0ZT8fOomKxVFtIMZeSWg+I= +gitlab.protontech.ch/go/liteapi v0.31.1-0.20221001204216-b781c54ca2a6 h1:N9Wzm4pNhIjR4aBmP9AzVGy+G8XQCDlkLy9GGEONbYM= +gitlab.protontech.ch/go/liteapi v0.31.1-0.20221001204216-b781c54ca2a6/go.mod h1:ixp1LUOxOYuB1qf172GdV0ZT8fOomKxVFtIMZeSWg+I= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index ca2f297f..e4314abc 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -232,6 +232,13 @@ func (bridge *Bridge) GetErrors() []error { } func (bridge *Bridge) Close(ctx context.Context) error { + // Abort any ongoing syncs. + for _, user := range bridge.users { + if err := user.AbortSync(ctx); err != nil { + return fmt.Errorf("failed to abort sync: %w", err) + } + } + // Close the IMAP server. if err := bridge.closeIMAP(ctx); err != nil { logrus.WithError(err).Error("Failed to close IMAP server") diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 9f9c7410..ba6e38b7 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -4,19 +4,24 @@ import ( "context" "os" "testing" + "time" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/bridge" + "github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" "github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/ProtonMail/proton-bridge/v2/internal/updater" + "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/ProtonMail/proton-bridge/v2/tests" "github.com/bradenaw/juniper/xslices" "github.com/stretchr/testify/require" "gitlab.protontech.ch/go/liteapi/server" + "gitlab.protontech.ch/go/liteapi/server/account" ) const ( @@ -29,6 +34,13 @@ var ( v2_4_0 = semver.MustParse("2.4.0") ) +func init() { + user.DefaultEventPeriod = 100 * time.Millisecond + user.DefaultEventJitter = 0 + account.GenerateKey = tests.FastGenerateKey + certs.GenerateCert = tests.FastGenerateCert +} + func TestBridge_ConnStatus(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { @@ -156,19 +168,27 @@ func TestBridge_CheckUpdate(t *testing.T) { // Disable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(false)) - // Get a stream of update events. - updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{}) + // Get a stream of update not available events. + noUpdateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}) defer done() // We are currently on the latest version. bridge.CheckForUpdates() - require.Equal(t, events.UpdateNotAvailable{}, <-updateCh) + + // we should receive an event indicating that no update is available. + require.Equal(t, events.UpdateNotAvailable{}, <-noUpdateCh) // Simulate a new version being available. mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0) + // Get a stream of update available events. + updateCh, done := bridge.GetEvents(events.UpdateAvailable{}) + defer done() + // Check for updates. bridge.CheckForUpdates() + + // We should receive an event indicating that an update is available. require.Equal(t, events.UpdateAvailable{ Version: updater.VersionInfo{ Version: v2_4_0, @@ -188,7 +208,7 @@ func TestBridge_AutoUpdate(t *testing.T) { require.NoError(t, bridge.SetAutoUpdate(true)) // Get a stream of update events. - updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateInstalled{}) + updateCh, done := bridge.GetEvents(events.UpdateInstalled{}) defer done() // Simulate a new version being available. @@ -196,6 +216,8 @@ func TestBridge_AutoUpdate(t *testing.T) { // Check for updates. bridge.CheckForUpdates() + + // We should receive an event indicating that the update was installed. require.Equal(t, events.UpdateInstalled{ Version: updater.VersionInfo{ Version: v2_4_0, @@ -213,8 +235,8 @@ func TestBridge_ManualUpdate(t *testing.T) { // Disable autoupdate for this test. require.NoError(t, bridge.SetAutoUpdate(false)) - // Get a stream of update events. - updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{}) + // Get a stream of update available events. + updateCh, done := bridge.GetEvents(events.UpdateAvailable{}) defer done() // Simulate a new version being available, but it's too new for us. @@ -222,6 +244,8 @@ func TestBridge_ManualUpdate(t *testing.T) { // Check for updates. bridge.CheckForUpdates() + + // We should receive an event indicating an update is available, but we can't install it. require.Equal(t, events.UpdateAvailable{ Version: updater.VersionInfo{ Version: v2_4_0, diff --git a/internal/bridge/configure.go b/internal/bridge/configure.go index 0b239db8..04e888f7 100644 --- a/internal/bridge/configure.go +++ b/internal/bridge/configure.go @@ -14,8 +14,9 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { return ErrNoSuchUser } + // TODO: Handle split mode! if address == "" { - address = user.Addresses()[0] + address = user.Emails()[0] } // If configuring apple mail for Catalina or newer, users should use SSL. @@ -32,7 +33,7 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { bridge.vault.GetIMAPSSL(), bridge.vault.GetSMTPSSL(), address, - strings.Join(user.Addresses(), ","), + strings.Join(user.Emails(), ","), user.BridgePass(), ) } diff --git a/internal/bridge/useragent.go b/internal/bridge/identifier.go similarity index 100% rename from internal/bridge/useragent.go rename to internal/bridge/identifier.go diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 74c1a997..acce5731 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -2,6 +2,7 @@ package bridge import ( "context" + "fmt" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/updater" @@ -96,40 +97,39 @@ func (bridge *Bridge) GetGluonDir() string { func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error { if newGluonDir == bridge.GetGluonDir() { - return nil + return fmt.Errorf("new gluon dir is the same as the old one") } if err := bridge.closeIMAP(context.Background()); err != nil { - return err + return fmt.Errorf("failed to close IMAP: %w", err) } if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil { - return err + return fmt.Errorf("failed to move gluon dir: %w", err) } if err := bridge.vault.SetGluonDir(newGluonDir); err != nil { - return err + return fmt.Errorf("failed to set new gluon dir: %w", err) } imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig) if err != nil { - return err - } - - for _, user := range bridge.users { - imapConn, err := user.NewGluonConnector(ctx) - if err != nil { - return err - } - - if err := imapServer.LoadUser(context.Background(), imapConn, user.GluonID(), user.GluonKey()); err != nil { - return err - } + return fmt.Errorf("failed to create new IMAP server: %w", err) } bridge.imapServer = imapServer - return bridge.serveIMAP() + for _, user := range bridge.users { + if err := bridge.addIMAPUser(ctx, user); err != nil { + return fmt.Errorf("failed to add IMAP user: %w", err) + } + } + + if err := bridge.serveIMAP(); err != nil { + return fmt.Errorf("failed to serve IMAP: %w", err) + } + + return nil } func (bridge *Bridge) GetProxyAllowed() bool { diff --git a/internal/bridge/smtp_backend.go b/internal/bridge/smtp_backend.go index cb08d686..15be3313 100644 --- a/internal/bridge/smtp_backend.go +++ b/internal/bridge/smtp_backend.go @@ -23,8 +23,8 @@ func (backend *smtpBackend) Login(state *smtp.ConnectionState, username string, defer backend.usersLock.RUnlock() for _, user := range backend.users { - if slices.Contains(user.Addresses(), username) && user.BridgePass() == password { - return user.NewSMTPSession(username) + if slices.Contains(user.Emails(), username) && user.BridgePass() == password { + return user.NewSMTPSession(username), nil } } diff --git a/internal/bridge/users.go b/internal/bridge/user.go similarity index 69% rename from internal/bridge/users.go rename to internal/bridge/user.go index 4714af3f..65ca35e2 100644 --- a/internal/bridge/users.go +++ b/internal/bridge/user.go @@ -29,7 +29,7 @@ type UserInfo struct { Addresses []string // AddressMode is the user's address mode. - AddressMode AddressMode + AddressMode vault.AddressMode // BridgePass is the user's bridge password. BridgePass string @@ -41,13 +41,6 @@ type UserInfo struct { MaxSpace int } -type AddressMode int - -const ( - SplitMode AddressMode = iota - CombinedMode -) - // GetUserIDs returns the IDs of all known users (authorized or not). func (bridge *Bridge) GetUserIDs() []string { return bridge.vault.GetUserIDs() @@ -62,7 +55,7 @@ func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) { user, ok := bridge.users[userID] if !ok { - return getUserInfo(vaultUser.UserID(), vaultUser.Username()), nil + return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil } return getConnUserInfo(user), nil @@ -153,12 +146,43 @@ func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { return nil } -func (bridge *Bridge) GetAddressMode(userID string) (AddressMode, error) { - panic("TODO") -} +// SetAddressMode sets the address mode for the given user. +func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error { + user, ok := bridge.users[userID] + if !ok { + return ErrNoSuchUser + } -func (bridge *Bridge) SetAddressMode(userID string, mode AddressMode) error { - panic("TODO") + if user.GetAddressMode() == mode { + return fmt.Errorf("address mode is already %q", mode) + } + + if err := user.AbortSync(ctx); err != nil { + return fmt.Errorf("failed to abort sync: %w", err) + } + + for _, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil { + return fmt.Errorf("failed to remove user from IMAP server: %w", err) + } + } + + if err := user.SetAddressMode(ctx, mode); err != nil { + return fmt.Errorf("failed to set address mode: %w", err) + } + + if err := bridge.addIMAPUser(ctx, user); err != nil { + return fmt.Errorf("failed to add IMAP user: %w", err) + } + + bridge.publish(events.AddressModeChanged{ + UserID: userID, + AddressMode: mode, + }) + + user.DoSync(ctx) + + return nil } // loadUsers loads authorized users from the vault. @@ -177,7 +201,7 @@ func (bridge *Bridge) loadUsers(ctx context.Context) error { logrus.WithError(err).Error("Failed to load connected user") if _, ok := err.(*resty.ResponseError); ok { - if err := user.Clear(); err != nil { + if err := bridge.vault.ClearUser(userID); err != nil { logrus.WithError(err).Error("Failed to clear user") } } @@ -231,33 +255,41 @@ func (bridge *Bridge) addUser( if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) { existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass) if err != nil { - return err + return fmt.Errorf("failed to add existing user: %w", err) } user = existingUser } else { newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass) if err != nil { - return err + return fmt.Errorf("failed to add new user: %w", err) } user = newUser } - go func() { - for event := range user.GetNotifyCh() { - switch event := event.(type) { - case events.UserDeauth: - if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil { - logrus.WithError(err).Error("Failed to logout user") - } - } + // Connects the user's address(es) to gluon. + if err := bridge.addIMAPUser(ctx, user); err != nil { + return fmt.Errorf("failed to add IMAP user: %w", err) + } - bridge.publish(event) + // Handle events coming from the user before forwarding them to the bridge. + // For example, if the user's addresses change, we need to update them in gluon. + go func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for event := range user.GetEventCh() { + if err := bridge.handleUserEvent(ctx, user, event); err != nil { + logrus.WithError(err).Error("Failed to handle user event") + } else { + bridge.publish(event) + } } }() // Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user. + // As such, if we find this ID in the context, we should use it to update our user agent. client.AddPreRequestHook(func(ctx context.Context, req *resty.Request) error { if imapID, ok := imap.GetIMAPIDFromContext(ctx); ok { bridge.identifier.SetClient(imapID.Name, imapID.Version) @@ -266,6 +298,11 @@ func (bridge *Bridge) addUser( return nil }) + // TODO: Replace this with proper sync manager. + if !user.HasSync() { + user.DoSync(ctx) + } + bridge.publish(events.UserLoggedIn{ UserID: user.ID(), }) @@ -293,25 +330,6 @@ func (bridge *Bridge) addNewUser( return nil, err } - gluonKey, err := crypto.RandomToken(32) - if err != nil { - return nil, err - } - - imapConn, err := user.NewGluonConnector(ctx) - if err != nil { - return nil, err - } - - gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, gluonKey) - if err != nil { - return nil, err - } - - if err := vaultUser.SetGluonAuth(gluonID, gluonKey); err != nil { - return nil, err - } - if err := bridge.smtpBackend.addUser(user); err != nil { return nil, err } @@ -349,15 +367,6 @@ func (bridge *Bridge) addExistingUser( return nil, err } - imapConn, err := user.NewGluonConnector(ctx) - if err != nil { - return nil, err - } - - if err := bridge.imapServer.LoadUser(ctx, imapConn, user.GluonID(), user.GluonKey()); err != nil { - return nil, err - } - if err := bridge.smtpBackend.addUser(user); err != nil { return nil, err } @@ -376,31 +385,39 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi return ErrNoSuchUser } - vaultUser, err := bridge.vault.GetUser(userID) - if err != nil { - return err - } - - if err := bridge.imapServer.RemoveUser(ctx, vaultUser.GluonID(), withFiles); err != nil { - return err + // TODO: The sync should be canceled by the sync manager. + if err := user.AbortSync(ctx); err != nil { + return fmt.Errorf("failed to abort user sync: %w", err) } if err := bridge.smtpBackend.removeUser(user); err != nil { - return err + return fmt.Errorf("failed to remove SMTP user: %w", err) + } + + for _, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil { + return fmt.Errorf("failed to remove IMAP user: %w", err) + } } if withAPI { if err := user.Logout(ctx); err != nil { - return err + return fmt.Errorf("failed to logout user: %w", err) } } if err := user.Close(ctx); err != nil { - return err + return fmt.Errorf("failed to close user: %w", err) } - if err := vaultUser.Clear(); err != nil { - return err + if err := bridge.vault.ClearUser(userID); err != nil { + return fmt.Errorf("failed to clear user: %w", err) + } + + if withFiles { + if err := bridge.vault.DeleteUser(userID); err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } } delete(bridge.users, userID) @@ -412,12 +429,39 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi return nil } +// addIMAPUser connects the given user to gluon. +func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { + imapConn, err := user.NewIMAPConnectors() + if err != nil { + return fmt.Errorf("failed to create IMAP connectors: %w", err) + } + + for addrID, imapConn := range imapConn { + if gluonID, ok := user.GetGluonID(addrID); ok { + if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil { + return fmt.Errorf("failed to load IMAP user: %w", err) + } + } else { + gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey()) + if err != nil { + return fmt.Errorf("failed to add IMAP user: %w", err) + } + + if err := user.SetGluonID(addrID, gluonID); err != nil { + return fmt.Errorf("failed to set IMAP user ID: %w", err) + } + } + } + + return nil +} + // getUserInfo returns information about a disconnected user. -func getUserInfo(userID, username string) UserInfo { +func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo { return UserInfo{ UserID: userID, Username: username, - AddressMode: CombinedMode, + AddressMode: addressMode, } } @@ -427,8 +471,8 @@ func getConnUserInfo(user *user.User) UserInfo { Connected: true, UserID: user.ID(), Username: user.Name(), - Addresses: user.Addresses(), - AddressMode: CombinedMode, + Addresses: user.Emails(), + AddressMode: user.GetAddressMode(), BridgePass: user.BridgePass(), UsedSpace: user.UsedSpace(), MaxSpace: user.MaxSpace(), diff --git a/internal/bridge/user_events.go b/internal/bridge/user_events.go new file mode 100644 index 00000000..0458aff7 --- /dev/null +++ b/internal/bridge/user_events.go @@ -0,0 +1,118 @@ +package bridge + +import ( + "context" + "fmt" + + "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/user" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" +) + +func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, event events.Event) error { + switch event := event.(type) { + case events.UserAddressCreated: + if err := bridge.handleUserAddressCreated(ctx, user, event); err != nil { + return fmt.Errorf("failed to handle user address created event: %w", err) + } + + case events.UserAddressUpdated: + if err := bridge.handleUserAddressUpdated(ctx, user, event); err != nil { + return fmt.Errorf("failed to handle user address updated event: %w", err) + } + + case events.UserAddressDeleted: + if err := bridge.handleUserAddressDeleted(ctx, user, event); err != nil { + return fmt.Errorf("failed to handle user address deleted event: %w", err) + } + + case events.UserDeauth: + if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil { + return fmt.Errorf("failed to logout user: %w", err) + } + } + + return nil +} + +func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.User, event events.UserAddressCreated) error { + switch user.GetAddressMode() { + case vault.CombinedMode: + for addrID, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { + return fmt.Errorf("failed to remove user from IMAP server: %w", err) + } + + imapConn, err := user.NewIMAPConnector(addrID) + if err != nil { + return fmt.Errorf("failed to create IMAP connector: %w", err) + } + + if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil { + return fmt.Errorf("failed to add user to IMAP server: %w", err) + } + } + + case vault.SplitMode: + imapConn, err := user.NewIMAPConnector(event.AddressID) + if err != nil { + return fmt.Errorf("failed to create IMAP connector: %w", err) + } + + gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey()) + if err != nil { + return fmt.Errorf("failed to add user to IMAP server: %w", err) + } + + if err := user.SetGluonID(event.AddressID, gluonID); err != nil { + return fmt.Errorf("failed to set gluon ID: %w", err) + } + } + + return nil +} + +// TODO: Handle addresses that have been disabled! +func (bridge *Bridge) handleUserAddressUpdated(ctx context.Context, user *user.User, event events.UserAddressUpdated) error { + switch user.GetAddressMode() { + case vault.CombinedMode: + return fmt.Errorf("not implemented") + + case vault.SplitMode: + return fmt.Errorf("not implemented") + } + + return nil +} + +func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.User, event events.UserAddressDeleted) error { + switch user.GetAddressMode() { + case vault.CombinedMode: + for addrID, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { + return fmt.Errorf("failed to remove user from IMAP server: %w", err) + } + + imapConn, err := user.NewIMAPConnector(addrID) + if err != nil { + return fmt.Errorf("failed to create IMAP connector: %w", err) + } + + if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil { + return fmt.Errorf("failed to add user to IMAP server: %w", err) + } + } + + case vault.SplitMode: + gluonID, ok := user.GetGluonID(event.AddressID) + if !ok { + return fmt.Errorf("gluon ID not found for address %s", event.AddressID) + } + + if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil { + return fmt.Errorf("failed to remove user from IMAP server: %w", err) + } + } + + return nil +} diff --git a/internal/bridge/users_test.go b/internal/bridge/user_test.go similarity index 91% rename from internal/bridge/users_test.go rename to internal/bridge/user_test.go index f868fb09..ccc47f6b 100644 --- a/internal/bridge/users_test.go +++ b/internal/bridge/user_test.go @@ -7,6 +7,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/stretchr/testify/require" "gitlab.protontech.ch/go/liteapi/server" ) @@ -283,3 +284,30 @@ func TestBridge_BridgePass(t *testing.T) { }) }) } + +func TestBridge_AddressMode(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) { + withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Login the user. + userID, err := bridge.LoginUser(ctx, username, password, nil, nil) + require.NoError(t, err) + + // Get the user's info. + info, err := bridge.GetUserInfo(userID) + require.NoError(t, err) + + // The user is in combined mode by default. + require.Equal(t, vault.CombinedMode, info.AddressMode) + + // Put the user in split mode. + require.NoError(t, bridge.SetAddressMode(ctx, userID, vault.SplitMode)) + + // Get the user's info. + info, err = bridge.GetUserInfo(userID) + require.NoError(t, err) + + // The user is in split mode. + require.Equal(t, vault.SplitMode, info.AddressMode) + }) + }) +} diff --git a/internal/events/user.go b/internal/events/user.go index 7b6a525d..09c9accb 100644 --- a/internal/events/user.go +++ b/internal/events/user.go @@ -1,5 +1,7 @@ package events +import "github.com/ProtonMail/proton-bridge/v2/internal/vault" + type UserLoggedIn struct { eventBase @@ -33,20 +35,30 @@ type UserChanged struct { type UserAddressCreated struct { eventBase - UserID string - Address string + UserID string + AddressID string + Email string } -type UserAddressChanged struct { +type UserAddressUpdated struct { eventBase - UserID string - Address string + UserID string + AddressID string + Email string } type UserAddressDeleted struct { eventBase - UserID string - Address string + UserID string + AddressID string + Email string +} + +type AddressModeChanged struct { + eventBase + + UserID string + AddressMode vault.AddressMode } diff --git a/internal/frontend/cli/accounts.go b/internal/frontend/cli/accounts.go index e3ce1f81..69c4ef31 100644 --- a/internal/frontend/cli/accounts.go +++ b/internal/frontend/cli/accounts.go @@ -23,6 +23,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/constants" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/abiosoft/ishell" ) @@ -39,7 +40,7 @@ func (f *frontendCLI) listAccounts(c *ishell.Context) { connected = "connected" } mode := "split" - if user.AddressMode == bridge.CombinedMode { + if user.AddressMode == vault.CombinedMode { mode = "combined" } f.Printf(spacing, idx, user.Username, connected, mode) @@ -58,7 +59,7 @@ func (f *frontendCLI) showAccountInfo(c *ishell.Context) { return } - if user.AddressMode == bridge.CombinedMode { + if user.AddressMode == vault.CombinedMode { f.showAccountAddressInfo(user, user.Addresses[0]) } else { for _, address := range user.Addresses { @@ -225,19 +226,19 @@ func (f *frontendCLI) changeMode(c *ishell.Context) { return } - var targetMode bridge.AddressMode + var targetMode vault.AddressMode - if user.AddressMode == bridge.CombinedMode { - targetMode = bridge.SplitMode + if user.AddressMode == vault.CombinedMode { + targetMode = vault.SplitMode } else { - targetMode = bridge.CombinedMode + targetMode = vault.CombinedMode } if !f.yesNoQuestion("Are you sure you want to change the mode for account " + bold(user.Username) + " to " + bold(targetMode)) { return } - if err := f.bridge.SetAddressMode(user.UserID, targetMode); err != nil { + if err := f.bridge.SetAddressMode(context.Background(), user.UserID, targetMode); err != nil { f.printAndLogError("Cannot switch address mode:", err) } diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index f45cacc7..471c267b 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -296,7 +296,7 @@ func (f *frontendCLI) watchEvents() { f.notifyLogout(user.Username) - case events.UserAddressChanged: + case events.UserAddressUpdated: user, err := f.bridge.GetUserInfo(event.UserID) if err != nil { return @@ -305,7 +305,7 @@ func (f *frontendCLI) watchEvents() { f.Printf("Address changed for %s. You may need to reconfigure your email client.\n", user.Username) case events.UserAddressDeleted: - f.notifyLogout(event.Address) + f.notifyLogout(event.Email) case events.SyncStarted: user, err := f.bridge.GetUserInfo(event.UserID) diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index 31677375..dfe92f99 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -228,13 +228,13 @@ func (s *Service) watchEvents() { _ = s.SendEvent(NewShowMainWindowEvent()) case events.UserAddressCreated: - _ = s.SendEvent(NewMailAddressChangeEvent(event.Address)) + _ = s.SendEvent(NewMailAddressChangeEvent(event.Email)) - case events.UserAddressChanged: - _ = s.SendEvent(NewMailAddressChangeEvent(event.Address)) + case events.UserAddressUpdated: + _ = s.SendEvent(NewMailAddressChangeEvent(event.Email)) case events.UserAddressDeleted: - _ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Address)) + _ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Email)) case events.UserChanged: _ = s.SendEvent(NewUserChangedEvent(event.UserID)) diff --git a/internal/frontend/grpc/service_user.go b/internal/frontend/grpc/service_user.go index d6fadb7a..d588ba73 100644 --- a/internal/frontend/grpc/service_user.go +++ b/internal/frontend/grpc/service_user.go @@ -20,7 +20,7 @@ package grpc import ( "context" - "github.com/ProtonMail/proton-bridge/v2/internal/bridge" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -74,15 +74,15 @@ func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitMode defer s.panicHandler.HandlePanic() defer func() { _ = s.SendEvent(NewUserToggleSplitModeFinishedEvent(splitMode.UserID)) }() - var targetMode bridge.AddressMode + var targetMode vault.AddressMode - if splitMode.Active && user.AddressMode == bridge.CombinedMode { - targetMode = bridge.SplitMode - } else if !splitMode.Active && user.AddressMode == bridge.SplitMode { - targetMode = bridge.CombinedMode + if splitMode.Active && user.AddressMode == vault.CombinedMode { + targetMode = vault.SplitMode + } else if !splitMode.Active && user.AddressMode == vault.SplitMode { + targetMode = vault.CombinedMode } - if err := s.bridge.SetAddressMode(user.UserID, targetMode); err != nil { + if err := s.bridge.SetAddressMode(context.Background(), user.UserID, targetMode); err != nil { logrus.WithError(err).Error("Failed to set address mode") } }() diff --git a/internal/frontend/grpc/utils.go b/internal/frontend/grpc/utils.go index b8dbb580..5626a128 100644 --- a/internal/frontend/grpc/utils.go +++ b/internal/frontend/grpc/utils.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/ProtonMail/proton-bridge/v2/internal/bridge" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/sirupsen/logrus" ) @@ -64,7 +65,7 @@ func grpcUserFromInfo(user bridge.UserInfo) *User { Username: user.Username, AvatarText: getInitials(user.Username), LoggedIn: user.Connected, - SplitMode: user.AddressMode == bridge.SplitMode, + SplitMode: user.AddressMode == vault.SplitMode, SetupGuideSeen: true, // users listed have already seen the setup guide. UsedBytes: int64(user.UsedSpace), TotalBytes: int64(user.MaxSpace), diff --git a/internal/pool/job.go b/internal/pool/job.go new file mode 100644 index 00000000..cfd70f5d --- /dev/null +++ b/internal/pool/job.go @@ -0,0 +1,41 @@ +package pool + +import "context" + +type job[In, Out any] struct { + ctx context.Context + req In + + res chan Out + err chan error + + done chan struct{} +} + +func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] { + return &job[In, Out]{ + ctx: ctx, + req: req, + res: make(chan Out), + err: make(chan error), + done: make(chan struct{}), + } +} + +func (job *job[In, Out]) result() (Out, error) { + return <-job.res, <-job.err +} + +func (job *job[In, Out]) postSuccess(res Out) { + close(job.err) + job.res <- res +} + +func (job *job[In, Out]) postFailure(err error) { + close(job.res) + job.err <- err +} + +func (job *job[In, Out]) waitDone() { + <-job.done +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 5b1155ed..aafb42f9 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -13,16 +13,16 @@ var ErrJobCancelled = errors.New("Job cancelled by surrounding context") // Pool is a worker pool that handles input of type In and returns results of type Out. type Pool[In comparable, Out any] struct { - queue *queue.QueuedChannel[*Job[In, Out]] + queue *queue.QueuedChannel[*job[In, Out]] size int } -// DoneFunc must be called to free up pool resources. -type DoneFunc func() +// doneFunc must be called to free up pool resources. +type doneFunc func() // New returns a new pool. func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] { - queue := queue.NewQueuedChannel[*Job[In, Out]](0, 0) + queue := queue.NewQueuedChannel[*job[In, Out]](0, 0) for i := 0; i < size; i++ { go func() { @@ -51,17 +51,6 @@ func New[In comparable, Out any](size int, work func(context.Context, In) (Out, } } -// NewJob submits a job to the pool. It returns a job handle and a DoneFunc. -// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done, -// which frees up the worker in the pool for reuse. -func (pool *Pool[In, Out]) NewJob(ctx context.Context, req In) (*Job[In, Out], DoneFunc) { - job := newJob[In, Out](ctx, req) - - pool.queue.Enqueue(job) - - return job, func() { close(job.done) } -} - // Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred. func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error { ctx, cancel := context.WithCancel(ctx) @@ -81,10 +70,10 @@ func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, O go func() { defer wg.Done() - job, done := pool.NewJob(ctx, req) + job, done := pool.newJob(ctx, req) defer done() - res, err := job.Result() + res, err := job.result() if err := fn(req, res, err); err != nil { lock.Lock() @@ -134,44 +123,25 @@ func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) (map[In]Ou return data, nil } +// ProcessOne submits one job to the pool and returns the result. +func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) { + job, done := pool.newJob(ctx, req) + defer done() + + return job.result() +} + func (pool *Pool[In, Out]) Done() { pool.queue.Close() } -type Job[In, Out any] struct { - ctx context.Context - req In +// newJob submits a job to the pool. It returns a job handle and a DoneFunc. +// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done, +// which frees up the worker in the pool for reuse. +func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc) { + job := newJob[In, Out](ctx, req) - res chan Out - err chan error + pool.queue.Enqueue(job) - done chan struct{} -} - -func newJob[In, Out any](ctx context.Context, req In) *Job[In, Out] { - return &Job[In, Out]{ - ctx: ctx, - req: req, - res: make(chan Out), - err: make(chan error), - done: make(chan struct{}), - } -} - -func (job *Job[In, Out]) Result() (Out, error) { - return <-job.res, <-job.err -} - -func (job *Job[In, Out]) postSuccess(res Out) { - close(job.err) - job.res <- res -} - -func (job *Job[In, Out]) postFailure(err error) { - close(job.res) - job.err <- err -} - -func (job *Job[In, Out]) waitDone() { - <-job.done + return job, func() { close(job.done) } } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index a59f3941..03bbb953 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -15,16 +15,16 @@ import ( func TestPool_NewJob(t *testing.T) { doubler := newDoubler(runtime.NumCPU()) - job1, done1 := doubler.NewJob(context.Background(), 1) + job1, done1 := doubler.newJob(context.Background(), 1) defer done1() - job2, done2 := doubler.NewJob(context.Background(), 2) + job2, done2 := doubler.newJob(context.Background(), 2) defer done2() - res2, err := job2.Result() + res2, err := job2.result() require.NoError(t, err) - res1, err := job1.Result() + res1, err := job1.result() require.NoError(t, err) assert.Equal(t, 2, res1) @@ -36,31 +36,31 @@ func TestPool_NewJob_Done(t *testing.T) { doubler := newDoubler(2) // Start two jobs. Don't mark the jobs as done yet. - job1, done1 := doubler.NewJob(context.Background(), 1) - job2, done2 := doubler.NewJob(context.Background(), 2) + job1, done1 := doubler.newJob(context.Background(), 1) + job2, done2 := doubler.newJob(context.Background(), 2) // Get the first result. - res1, _ := job1.Result() + res1, _ := job1.result() assert.Equal(t, 2, res1) // Get the first result. - res2, _ := job2.Result() + res2, _ := job2.result() assert.Equal(t, 4, res2) // Additional jobs will wait. - job3, _ := doubler.NewJob(context.Background(), 3) - job4, _ := doubler.NewJob(context.Background(), 4) + job3, _ := doubler.newJob(context.Background(), 3) + job4, _ := doubler.newJob(context.Background(), 4) // Channel to collect results from jobs 3 and 4. resCh := make(chan int, 2) go func() { - res, _ := job3.Result() + res, _ := job3.result() resCh <- res }() go func() { - res, _ := job4.Result() + res, _ := job4.result() resCh <- res }() diff --git a/internal/user/addresses.go b/internal/user/addresses.go new file mode 100644 index 00000000..451770f0 --- /dev/null +++ b/internal/user/addresses.go @@ -0,0 +1,46 @@ +package user + +import "gitlab.protontech.ch/go/liteapi" + +type addrList struct { + apiAddrs ordMap[string, string, liteapi.Address] +} + +func newAddrList(apiAddrs []liteapi.Address) *addrList { + return &addrList{ + apiAddrs: newOrdMap( + func(addr liteapi.Address) string { return addr.ID }, + func(addr liteapi.Address) string { return addr.Email }, + func(a, b liteapi.Address) bool { return a.Order < b.Order }, + apiAddrs..., + ), + } +} + +func (list *addrList) insert(address liteapi.Address) { + list.apiAddrs.insert(address) +} + +func (list *addrList) delete(addrID string) string { + return list.apiAddrs.delete(addrID) +} + +func (list *addrList) primary() string { + return list.apiAddrs.keys()[0] +} + +func (list *addrList) addrIDs() []string { + return list.apiAddrs.keys() +} + +func (list *addrList) emails() []string { + return list.apiAddrs.values() +} + +func (list *addrList) email(addrID string) string { + return list.apiAddrs.get(addrID) +} + +func (list *addrList) addrMap() map[string]string { + return list.apiAddrs.toMap() +} diff --git a/internal/user/builder.go b/internal/user/builder.go index ecaf4140..003cb31d 100644 --- a/internal/user/builder.go +++ b/internal/user/builder.go @@ -2,16 +2,20 @@ package user import ( "context" + "time" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/pool" "github.com/ProtonMail/proton-bridge/v2/pkg/message" + "github.com/bradenaw/juniper/xslices" "gitlab.protontech.ch/go/liteapi" + "golang.org/x/exp/slices" ) type request struct { messageID string + addressID string addrKR *crypto.KeyRing } @@ -54,8 +58,38 @@ func newBuilder(f fetcher, msgWorkers, attWorkers int) *pool.Pool[request, *imap return nil, err } - return getMessageCreatedUpdate(msg, literal) + return newMessageCreatedUpdate(msg, literal) }) return msgPool } + +func newMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) { + parsedMessage, err := imap.NewParsedMessage(literal) + if err != nil { + return nil, err + } + + flags := imap.NewFlagSet() + + if !message.Unread { + flags = flags.Add(imap.FlagSeen) + } + + if slices.Contains(message.LabelIDs, liteapi.StarredLabel) { + flags = flags.Add(imap.FlagFlagged) + } + + imapMessage := imap.Message{ + ID: imap.MessageID(message.ID), + Flags: flags, + Date: time.Unix(message.Time, 0), + } + + return &imap.MessageCreated{ + Message: imapMessage, + Literal: literal, + LabelIDs: mapTo[string, imap.LabelID](xslices.Filter(message.LabelIDs, wantLabelID)), + ParsedMessage: parsedMessage, + }, nil +} diff --git a/internal/user/errors.go b/internal/user/errors.go index 6bfdba3d..0045bcc9 100644 --- a/internal/user/errors.go +++ b/internal/user/errors.go @@ -8,5 +8,5 @@ var ( ErrNotSupported = errors.New("not supported") ErrInvalidReturnPath = errors.New("invalid return path") ErrInvalidRecipient = errors.New("invalid recipient") - ErrMissingAddressKey = errors.New("missing address key") + ErrMissingAddrKey = errors.New("missing address key") ) diff --git a/internal/user/events.go b/internal/user/events.go index b6dcc500..85b4553c 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -2,43 +2,44 @@ package user import ( "context" + "fmt" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" "gitlab.protontech.ch/go/liteapi" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" ) // handleAPIEvent handles the given liteapi.Event. -func (user *User) handleAPIEvent(event liteapi.Event) error { +func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error { if event.User != nil { - if err := user.handleUserEvent(*event.User); err != nil { + if err := user.handleUserEvent(ctx, *event.User); err != nil { return err } } if len(event.Addresses) > 0 { - if err := user.handleAddressEvents(event.Addresses); err != nil { + if err := user.handleAddressEvents(ctx, event.Addresses); err != nil { return err } } if event.MailSettings != nil { - if err := user.handleMailSettingsEvent(*event.MailSettings); err != nil { + if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil { return err } } if len(event.Labels) > 0 { - if err := user.handleLabelEvents(event.Labels); err != nil { + if err := user.handleLabelEvents(ctx, event.Labels); err != nil { return err } } if len(event.Messages) > 0 { - if err := user.handleMessageEvents(event.Messages); err != nil { + if err := user.handleMessageEvents(ctx, event.Messages); err != nil { return err } } @@ -47,7 +48,7 @@ func (user *User) handleAPIEvent(event liteapi.Event) error { } // handleUserEvent handles the given user event. -func (user *User) handleUserEvent(userEvent liteapi.User) error { +func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error { userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil) if err != nil { return err @@ -57,49 +58,31 @@ func (user *User) handleUserEvent(userEvent liteapi.User) error { user.userKR = userKR - user.notifyCh <- events.UserChanged{ + user.eventCh.Enqueue(events.UserChanged{ UserID: user.ID(), - } + }) return nil } // handleAddressEvents handles the given address events. // TODO: If split address mode, need to signal back to bridge to update the addresses! -func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) error { +func (user *User) handleAddressEvents(ctx context.Context, addressEvents []liteapi.AddressEvent) error { for _, event := range addressEvents { switch event.Action { - case liteapi.EventDelete: - address, err := user.deleteAddress(event.ID) - if err != nil { - return err - } - - // TODO: This is not the same as addressChangedLogout event! - // That was only relevant in split mode. This is used differently now. - user.notifyCh <- events.UserAddressDeleted{ - UserID: user.ID(), - Address: address.Email, - } - case liteapi.EventCreate: - if err := user.createAddress(event.Address); err != nil { - return err - } - - user.notifyCh <- events.UserAddressCreated{ - UserID: user.ID(), - Address: event.Address.Email, + if err := user.handleCreateAddressEvent(ctx, event); err != nil { + return fmt.Errorf("failed to handle create address event: %w", err) } case liteapi.EventUpdate: - if err := user.updateAddress(event.Address); err != nil { - return err + if err := user.handleUpdateAddressEvent(ctx, event); err != nil { + return fmt.Errorf("failed to handle update address event: %w", err) } - user.notifyCh <- events.UserAddressChanged{ - UserID: user.ID(), - Address: event.Address.Email, + case liteapi.EventDelete: + if err := user.handleDeleteAddressEvent(ctx, event); err != nil { + return fmt.Errorf("failed to delete address: %w", err) } } } @@ -107,111 +90,189 @@ func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) erro return nil } -// createAddress creates the given address. -func (user *User) createAddress(address liteapi.Address) error { - addrKR, err := address.Keys.Unlock(user.vault.KeyPass(), user.userKR) +func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { + addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR) if err != nil { - return err + return fmt.Errorf("failed to unlock address keys: %w", err) } - if user.imapConn != nil { - user.imapConn.addAddress(address.Email) + user.apiAddrs.insert(event.Address) + + user.addrKRs[event.Address.ID] = addrKR + + if user.vault.AddressMode() == vault.SplitMode { + user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0) + + if err := user.syncLabels(ctx, event.Address.ID); err != nil { + return fmt.Errorf("failed to sync labels to new address: %w", err) + } } - user.addresses = append(user.addresses, address) - - user.addrKRs[address.ID] = addrKR + user.eventCh.Enqueue(events.UserAddressCreated{ + UserID: user.ID(), + AddressID: event.Address.ID, + Email: event.Address.Email, + }) return nil } -// updateAddress updates the given address. -func (user *User) updateAddress(address liteapi.Address) error { - if _, err := user.deleteAddress(address.ID); err != nil { - return err +func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { + addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR) + if err != nil { + return fmt.Errorf("failed to unlock address keys: %w", err) } - return user.createAddress(address) -} + user.apiAddrs.insert(event.Address) -// deleteAddress deletes the given address. -func (user *User) deleteAddress(addressID string) (liteapi.Address, error) { - idx := xslices.IndexFunc(user.addresses, func(address liteapi.Address) bool { - return address.ID == addressID + user.addrKRs[event.Address.ID] = addrKR + + user.eventCh.Enqueue(events.UserAddressUpdated{ + UserID: user.ID(), + AddressID: event.Address.ID, + Email: event.Address.Email, }) - if idx < 0 { - return liteapi.Address{}, ErrNoSuchAddress + return nil +} + +func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { + email := user.apiAddrs.delete(event.ID) + + if user.vault.AddressMode() == vault.SplitMode { + user.updateCh[event.ID].Close() + delete(user.updateCh, event.ID) } - if user.imapConn != nil { - user.imapConn.remAddress(user.addresses[idx].Email) - } + user.eventCh.Enqueue(events.UserAddressDeleted{ + UserID: user.ID(), + AddressID: event.ID, + Email: email, + }) - var address liteapi.Address - - address, user.addresses = user.addresses[idx], append(user.addresses[:idx], user.addresses[idx+1:]...) - - delete(user.addrKRs, addressID) - - return address, nil + return nil } // handleMailSettingsEvent handles the given mail settings event. -func (user *User) handleMailSettingsEvent(mailSettingsEvent liteapi.MailSettings) error { +func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error { user.settings = mailSettingsEvent + return nil } // handleLabelEvents handles the given label events. -func (user *User) handleLabelEvents(labelEvents []liteapi.LabelEvent) error { +func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error { for _, event := range labelEvents { switch event.Action { - case liteapi.EventDelete: - user.updateCh <- imap.NewMailboxDeleted(imap.LabelID(event.ID)) - case liteapi.EventCreate: - user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)) + if err := user.handleCreateLabelEvent(ctx, event); err != nil { + return fmt.Errorf("failed to handle create label event: %w", err) + } case liteapi.EventUpdate, liteapi.EventUpdateFlags: - user.updateCh <- imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)) + if err := user.handleUpdateLabelEvent(ctx, event); err != nil { + return fmt.Errorf("failed to handle update label event: %w", err) + } + + case liteapi.EventDelete: + if err := user.handleDeleteLabelEvent(ctx, event); err != nil { + return fmt.Errorf("failed to handle delete label event: %w", err) + } } } return nil } +func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error { + for _, updateCh := range user.updateCh { + updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label))) + } + + return nil +} + +func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error { + for _, updateCh := range user.updateCh { + updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label))) + } + + return nil +} + +func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error { + for _, updateCh := range user.updateCh { + updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID))) + } + + return nil +} + // handleMessageEvents handles the given message events. -func (user *User) handleMessageEvents(messageEvents []liteapi.MessageEvent) error { - ctx, cancel := context.WithCancel(context.Background()) +func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error { + ctx, cancel := context.WithCancel(ctx) defer cancel() for _, event := range messageEvents { switch event.Action { - case liteapi.EventDelete: - return ErrNotImplemented - case liteapi.EventCreate: - messages, err := user.builder.ProcessAll(ctx, []request{{event.ID, user.addrKRs[event.Message.AddressID]}}) - if err != nil { - return err + if err := user.handleCreateMessageEvent(ctx, event); err != nil { + return fmt.Errorf("failed to handle create message event: %w", err) } - user.updateCh <- imap.NewMessagesCreated(maps.Values(messages)...) - case liteapi.EventUpdate, liteapi.EventUpdateFlags: - user.updateCh <- imap.NewMessageLabelsUpdated( - imap.MessageID(event.ID), - imapLabelIDs(filterLabelIDs(event.Message.LabelIDs)), - bool(!event.Message.Unread), - slices.Contains(event.Message.LabelIDs, liteapi.StarredLabel), - ) + if err := user.handleUpdateMessageEvent(ctx, event); err != nil { + return fmt.Errorf("failed to handle update message event: %w", err) + } + + case liteapi.EventDelete: + return ErrNotImplemented } } return nil } +func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error { + var addressID string + + if user.GetAddressMode() == vault.CombinedMode { + addressID = user.apiAddrs.primary() + } else { + addressID = event.Message.AddressID + } + + message, err := user.builder.ProcessOne(ctx, request{ + messageID: event.ID, + addressID: addressID, + addrKR: user.addrKRs[event.Message.AddressID], + }) + if err != nil { + return err + } + + user.updateCh[addressID].Enqueue(imap.NewMessagesCreated(message)) + + return nil +} + +func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error { + update := imap.NewMessageLabelsUpdated( + imap.MessageID(event.ID), + mapTo[string, imap.LabelID](xslices.Filter(event.Message.LabelIDs, wantLabelID)), + event.Message.Seen(), + event.Message.Starred(), + ) + + if user.GetAddressMode() == vault.CombinedMode { + user.updateCh[user.apiAddrs.primary()].Enqueue(update) + } else { + user.updateCh[event.Message.AddressID].Enqueue(update) + } + + return nil +} + func getMailboxName(label liteapi.Label) []string { var name []string diff --git a/internal/user/flusher.go b/internal/user/flusher.go new file mode 100644 index 00000000..2bab1428 --- /dev/null +++ b/internal/user/flusher.go @@ -0,0 +1,76 @@ +package user + +import ( + "sync" + "time" + + "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/queue" + "github.com/ProtonMail/proton-bridge/v2/internal/events" +) + +type flusher struct { + userID string + updateCh *queue.QueuedChannel[imap.Update] + eventCh *queue.QueuedChannel[events.Event] + + updates []*imap.MessageCreated + maxChunkSize int + curChunkSize int + + count int + total int + start time.Time + + pushLock sync.Mutex +} + +func newFlusher( + userID string, + updateCh *queue.QueuedChannel[imap.Update], + eventCh *queue.QueuedChannel[events.Event], + total, maxChunkSize int, +) *flusher { + return &flusher{ + userID: userID, + updateCh: updateCh, + eventCh: eventCh, + + maxChunkSize: maxChunkSize, + + total: total, + start: time.Now(), + } +} + +func (f *flusher) push(update *imap.MessageCreated) { + f.pushLock.Lock() + defer f.pushLock.Unlock() + + f.updates = append(f.updates, update) + + if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize { + f.flush() + } +} + +func (f *flusher) flush() { + if len(f.updates) == 0 { + return + } + + f.count += len(f.updates) + f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...)) + f.eventCh.Enqueue(newSyncProgress(f.userID, f.count, f.total, f.start)) + f.updates = nil + f.curChunkSize = 0 +} + +func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress { + return events.SyncProgress{ + UserID: userID, + Progress: float64(count) / float64(total), + Elapsed: time.Since(start), + Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count), + } +} diff --git a/internal/user/imap.go b/internal/user/imap.go index 5f068859..4b75e843 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -25,11 +25,12 @@ const ( ) type imapConnector struct { + addrID string client *liteapi.Client updateCh <-chan imap.Update - addresses []string - password string + emails []string + password string flags, permFlags, attrs imap.FlagSet } @@ -37,15 +38,15 @@ type imapConnector struct { func newIMAPConnector( client *liteapi.Client, updateCh <-chan imap.Update, - addresses []string, password string, + emails ...string, ) *imapConnector { return &imapConnector{ client: client, updateCh: updateCh, - addresses: addresses, - password: password, + emails: emails, + password: password, flags: defaultFlags, permFlags: defaultPermanentFlags, @@ -59,7 +60,7 @@ func (conn *imapConnector) Authorize(username string, password string) bool { return false } - return xslices.IndexFunc(conn.addresses, func(address string) bool { + return xslices.IndexFunc(conn.emails, func(address string) bool { return strings.EqualFold(address, username) }) >= 0 } @@ -187,7 +188,7 @@ func (conn *imapConnector) GetMessage(ctx context.Context, messageID imap.Messag ID: imap.MessageID(message.ID), Flags: flags, Date: time.Unix(message.Time, 0), - }, imapLabelIDs(message.LabelIDs), nil + }, mapTo[string, imap.LabelID](message.LabelIDs), nil } // CreateMessage creates a new message on the remote. @@ -204,21 +205,21 @@ func (conn *imapConnector) CreateMessage( // LabelMessages labels the given messages with the given label ID. func (conn *imapConnector) LabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error { - return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelID)) + return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID)) } // UnlabelMessages unlabels the given messages with the given label ID. func (conn *imapConnector) UnlabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error { - return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelID)) + return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID)) } // MoveMessages removes the given messages from one label and adds them to the other label. func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.LabelID, labelToID imap.LabelID) error { - if err := conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelToID)); err != nil { + if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil { return fmt.Errorf("labeling messages: %w", err) } - if err := conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelFromID)); err != nil { + if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil { return fmt.Errorf("unlabeling messages: %w", err) } @@ -228,18 +229,18 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M // MarkMessagesSeen sets the seen value of the given messages. func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error { if seen { - return conn.client.MarkMessagesRead(ctx, strMessageIDs(messageIDs)...) + return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...) } else { - return conn.client.MarkMessagesUnread(ctx, strMessageIDs(messageIDs)...) + return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...) } } // MarkMessagesFlagged sets the flagged value of the given messages. func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error { if flagged { - return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel) + return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel) } else { - return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel) + return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel) } } @@ -249,45 +250,17 @@ func (conn *imapConnector) GetUpdates() <-chan imap.Update { return conn.updateCh } -// Close the connector when it will no longer be used and all resources should be closed/released. -func (conn *imapConnector) Close(ctx context.Context) error { +// GetUIDValidity returns the default UID validity for this user. +func (conn *imapConnector) GetUIDValidity() imap.UID { + return imap.UID(1) +} + +// SetUIDValidity sets the default UID validity for this user. +func (conn *imapConnector) SetUIDValidity(uidValidity imap.UID) error { return nil } -func (conn *imapConnector) addAddress(address string) { - conn.addresses = append(conn.addresses, address) -} - -func (conn *imapConnector) remAddress(address string) { - idx := slices.Index(conn.addresses, address) - - if idx < 0 { - return - } - - conn.addresses = append(conn.addresses[:idx], conn.addresses[idx+1:]...) -} - -func strLabelIDs(imapLabelIDs []imap.LabelID) []string { - return xslices.Map(imapLabelIDs, func(labelID imap.LabelID) string { - return string(labelID) - }) -} - -func imapLabelIDs(labelIDs []string) []imap.LabelID { - return xslices.Map(labelIDs, func(labelID string) imap.LabelID { - return imap.LabelID(labelID) - }) -} - -func strMessageIDs(imapMessageIDs []imap.MessageID) []string { - return xslices.Map(imapMessageIDs, func(messageID imap.MessageID) string { - return string(messageID) - }) -} - -func imapMessageIDs(messageIDs []string) []imap.MessageID { - return xslices.Map(messageIDs, func(messageID string) imap.MessageID { - return imap.MessageID(messageID) - }) +// Close the connector will no longer be used and all resources should be closed/released. +func (conn *imapConnector) Close(ctx context.Context) error { + return nil } diff --git a/internal/user/map.go b/internal/user/map.go new file mode 100644 index 00000000..3f0efb4f --- /dev/null +++ b/internal/user/map.go @@ -0,0 +1,89 @@ +package user + +import ( + "github.com/bradenaw/juniper/xslices" + "golang.org/x/exp/slices" +) + +type ordMap[Key comparable, Val, Data any] struct { + data map[Key]Data + order []Key + + toKey func(Data) Key + toVal func(Data) Val + isLess func(Data, Data) bool +} + +func newOrdMap[Key comparable, Val, Data any]( + key func(Data) Key, + value func(Data) Val, + less func(Data, Data) bool, + data ...Data, +) ordMap[Key, Val, Data] { + m := ordMap[Key, Val, Data]{ + data: make(map[Key]Data), + + toKey: key, + toVal: value, + isLess: less, + } + + for _, d := range data { + m.insert(d) + } + + return m +} + +func (set *ordMap[Key, Val, Data]) insert(data Data) { + if _, ok := set.data[set.toKey(data)]; ok { + set.delete(set.toKey(data)) + } + + set.data[set.toKey(data)] = data + + set.order = append(set.order, set.toKey(data)) + + slices.SortFunc(set.order, func(a, b Key) bool { + return set.isLess(set.data[a], set.data[b]) + }) +} + +func (set *ordMap[Key, Val, Data]) delete(key Key) Val { + data, ok := set.data[key] + if !ok { + return *new(Val) + } + + delete(set.data, key) + + set.order = xslices.Filter(set.order, func(otherKey Key) bool { + return otherKey != key + }) + + return set.toVal(data) +} + +func (set *ordMap[Key, Val, Data]) get(key Key) Val { + return set.toVal(set.data[key]) +} + +func (set *ordMap[Key, Val, Data]) keys() []Key { + return set.order +} + +func (set *ordMap[Key, Val, Data]) values() []Val { + return xslices.Map(set.order, func(key Key) Val { + return set.toVal(set.data[key]) + }) +} + +func (set *ordMap[Key, Val, Data]) toMap() map[Key]Val { + m := make(map[Key]Val) + + for _, key := range set.order { + m[key] = set.toVal(set.data[key]) + } + + return m +} diff --git a/internal/user/map_test.go b/internal/user/map_test.go new file mode 100644 index 00000000..75e6d99e --- /dev/null +++ b/internal/user/map_test.go @@ -0,0 +1,48 @@ +package user + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMap(t *testing.T) { + type Key int + + type Value string + + type Data struct { + key Key + value Value + } + + m := newOrdMap( + func(d Data) Key { return d.key }, + func(d Data) Value { return d.value }, + func(a, b Data) bool { return a.key < b.key }, + Data{key: 1, value: "a"}, + Data{key: 2, value: "b"}, + Data{key: 3, value: "c"}, + ) + + // Insert some new data. + m.insert(Data{key: 4, value: "d"}) + m.insert(Data{key: 5, value: "e"}) + + // Delete some data. + require.Equal(t, Value("c"), m.delete(3)) + require.Equal(t, Value("a"), m.delete(1)) + require.Equal(t, Value("e"), m.delete(5)) + + // Check the remaining keys and values are correct. + require.Equal(t, []Key{2, 4}, m.keys()) + require.Equal(t, []Value{"b", "d"}, m.values()) + + // Overwrite some data. + m.insert(Data{key: 2, value: "two"}) + m.insert(Data{key: 4, value: "four"}) + + // Check the remaining keys and values are correct. + require.Equal(t, []Key{2, 4}, m.keys()) + require.Equal(t, []Value{"two", "four"}, m.values()) +} diff --git a/internal/user/smtp.go b/internal/user/smtp.go index d50fb871..0a92e2fc 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -20,12 +20,14 @@ import ( ) type smtpSession struct { - client *liteapi.Client - username string - addresses []liteapi.Address - userKR *crypto.KeyRing - addrKRs map[string]*crypto.KeyRing - settings liteapi.MailSettings + client *liteapi.Client + + username string + emails map[string]string + settings liteapi.MailSettings + + userKR *crypto.KeyRing + addrKRs map[string]*crypto.KeyRing from string to map[string]struct{} @@ -34,18 +36,20 @@ type smtpSession struct { func newSMTPSession( client *liteapi.Client, username string, - addresses []liteapi.Address, + addresses map[string]string, + settings liteapi.MailSettings, userKR *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing, - settings liteapi.MailSettings, ) *smtpSession { return &smtpSession{ - client: client, - username: username, - addresses: addresses, - userKR: userKR, - addrKRs: addrKRs, - settings: settings, + client: client, + + username: username, + emails: addresses, + settings: settings, + + userKR: userKR, + addrKRs: addrKRs, from: "", to: make(map[string]struct{}), @@ -86,15 +90,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error { return ErrNotImplemented } - idx := xslices.IndexFunc(session.addresses, func(address liteapi.Address) bool { - return strings.EqualFold(address.Email, from) - }) - - if idx < 0 { - return ErrInvalidReturnPath + for addrID, email := range session.emails { + if strings.EqualFold(from, email) { + session.from = addrID + } } - session.from = session.addresses[idx].ID + if session.from == "" { + return ErrInvalidReturnPath + } return nil } @@ -129,10 +133,10 @@ func (session *smtpSession) Data(r io.Reader) error { addrKR, ok := session.addrKRs[session.from] if !ok { - return ErrMissingAddressKey + return ErrMissingAddrKey } - addrKR, err := addrKR.FirstKey() + addrKey, err := addrKR.FirstKey() if err != nil { return fmt.Errorf("failed to get first key: %w", err) } @@ -143,7 +147,7 @@ func (session *smtpSession) Data(r io.Reader) error { } if session.settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { - key, err := addrKR.GetKey(0) + key, err := addrKey.GetKey(0) if err != nil { return fmt.Errorf("failed to get user public key: %w", err) } @@ -153,7 +157,7 @@ func (session *smtpSession) Data(r io.Reader) error { return fmt.Errorf("failed to get user public key: %w", err) } - parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) + parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKey.GetIdentities()[0].Name, key.GetFingerprint()[:8])) } message, err := message.ParseWithParser(parser) @@ -161,7 +165,7 @@ func (session *smtpSession) Data(r io.Reader) error { return fmt.Errorf("failed to parse message: %w", err) } - draft, attKeys, err := session.createDraft(ctx, addrKR, message) + draft, attKeys, err := session.createDraft(ctx, addrKey, message) if err != nil { return fmt.Errorf("failed to create draft: %w", err) } @@ -171,7 +175,7 @@ func (session *smtpSession) Data(r io.Reader) error { return fmt.Errorf("failed to get recipients: %w", err) } - req, err := createSendReq(addrKR, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys) + req, err := createSendReq(addrKey, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys) if err != nil { return fmt.Errorf("failed to create packages: %w", err) } diff --git a/internal/user/sync.go b/internal/user/sync.go index 285b9137..2370956f 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -4,57 +4,34 @@ import ( "context" "fmt" "strings" - "sync" - "time" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" "gitlab.protontech.ch/go/liteapi" - "golang.org/x/exp/slices" ) const chunkSize = 1 << 20 -func (user *User) sync(ctx context.Context) error { - user.notifyCh <- events.SyncStarted{ - UserID: user.ID(), - } - - if err := user.syncLabels(ctx); err != nil { - return fmt.Errorf("failed to sync labels: %w", err) - } - - if err := user.syncMessages(ctx); err != nil { - return fmt.Errorf("failed to sync messages: %w", err) - } - - user.notifyCh <- events.SyncFinished{ - UserID: user.ID(), - } - - if err := user.vault.SetSync(true); err != nil { - return fmt.Errorf("failed to update sync status: %w", err) - } - - return nil -} - -func (user *User) syncLabels(ctx context.Context) error { +func (user *User) syncLabels(ctx context.Context, addrIDs ...string) error { // Sync the system folders. system, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem) if err != nil { return err } - for _, label := range system { - user.updateCh <- newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name) + for _, label := range xslices.Filter(system, func(label liteapi.Label) bool { return wantLabelID(label.ID) }) { + for _, addrID := range addrIDs { + user.updateCh[addrID].Enqueue(newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name)) + } } // Create Folders/Labels mailboxes with a random ID and with the \Noselect attribute. for _, prefix := range []string{folderPrefix, labelPrefix} { - user.updateCh <- newPlaceHolderMailboxCreatedUpdate(prefix) + for _, addrID := range addrIDs { + user.updateCh[addrID].Enqueue(newPlaceHolderMailboxCreatedUpdate(prefix)) + } } // Sync the API folders. @@ -64,7 +41,9 @@ func (user *User) syncLabels(ctx context.Context) error { } for _, folder := range folders { - user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}) + for _, addrID := range addrIDs { + user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path})) + } } // Sync the API labels. @@ -74,7 +53,9 @@ func (user *User) syncLabels(ctx context.Context) error { } for _, label := range labels { - user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}) + for _, addrID := range addrIDs { + user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path})) + } } return nil @@ -84,27 +65,53 @@ func (user *User) syncMessages(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + // Determine which messages to sync. + // TODO: This needs to be done better using the new API route to retrieve just the message IDs. metadata, err := user.client.GetAllMessageMetadata(ctx) if err != nil { return err } + // If in split mode, we need to send each message to a different IMAP connector. + isSplitMode := user.vault.AddressMode() == vault.SplitMode + + // Collect the build requests -- we need: + // - the message ID to build, + // - the keyring to decrypt the message, + // - and the address to send the message to (for split mode). requests := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) request { + var addressID string + + if isSplitMode { + addressID = metadata.AddressID + } else { + addressID = user.apiAddrs.primary() + } + return request{ messageID: metadata.ID, + addressID: addressID, addrKR: user.addrKRs[metadata.AddressID], } }) - flusher := newFlusher(user.ID(), user.updateCh, user.notifyCh, len(metadata), chunkSize) - defer flusher.flush() + // Create the flushers, one per update channel. + flushers := make(map[string]*flusher) + for addrID, updateCh := range user.updateCh { + flusher := newFlusher(user.ID(), updateCh, user.eventCh, len(requests), chunkSize) + defer flusher.flush() + + flushers[addrID] = flusher + } + + // Build the messages and send them to the correct flusher. if err := user.builder.Process(ctx, requests, func(req request, res *imap.MessageCreated, err error) error { if err != nil { return fmt.Errorf("failed to build message %s: %w", req.messageID, err) } - flusher.push(res) + flushers[req.addressID].push(res) return nil }); err != nil { @@ -114,95 +121,15 @@ func (user *User) syncMessages(ctx context.Context) error { return nil } -type flusher struct { - userID string +func (user *User) syncWait() { + for _, updateCh := range user.updateCh { + waiter := imap.NewNoop() + defer waiter.Wait() - updates []*imap.MessageCreated - updateCh chan<- imap.Update - notifyCh chan<- events.Event - maxChunkSize int - curChunkSize int - - count int - total int - start time.Time - - pushLock sync.Mutex -} - -func newFlusher(userID string, updateCh chan<- imap.Update, notifyCh chan<- events.Event, total, maxChunkSize int) *flusher { - return &flusher{ - userID: userID, - updateCh: updateCh, - notifyCh: notifyCh, - maxChunkSize: maxChunkSize, - total: total, - start: time.Now(), + updateCh.Enqueue(waiter) } } -func (f *flusher) push(update *imap.MessageCreated) { - f.pushLock.Lock() - defer f.pushLock.Unlock() - - f.updates = append(f.updates, update) - - if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize { - f.flush() - } -} - -func (f *flusher) flush() { - if len(f.updates) == 0 { - return - } - - f.count += len(f.updates) - f.updateCh <- imap.NewMessagesCreated(f.updates...) - f.notifyCh <- newSyncProgress(f.userID, f.count, f.total, f.start) - f.updates = nil - f.curChunkSize = 0 -} - -func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress { - return events.SyncProgress{ - UserID: userID, - Progress: float64(count) / float64(total), - Elapsed: time.Since(start), - Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count), - } -} - -func getMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) { - parsedMessage, err := imap.NewParsedMessage(literal) - if err != nil { - return nil, err - } - - flags := imap.NewFlagSet() - - if !message.Unread { - flags = flags.Add(imap.FlagSeen) - } - - if slices.Contains(message.LabelIDs, liteapi.StarredLabel) { - flags = flags.Add(imap.FlagFlagged) - } - - imapMessage := imap.Message{ - ID: imap.MessageID(message.ID), - Flags: flags, - Date: time.Unix(message.Time, 0), - } - - return &imap.MessageCreated{ - Message: imapMessage, - Literal: literal, - LabelIDs: imapLabelIDs(filterLabelIDs(message.LabelIDs)), - ParsedMessage: parsedMessage, - }, nil -} - func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated { if strings.EqualFold(labelName, imap.Inbox) { labelName = imap.Inbox @@ -237,18 +164,12 @@ func newMailboxCreatedUpdate(labelID imap.LabelID, labelName []string) *imap.Mai }) } -func filterLabelIDs(labelIDs []string) []string { - var filteredLabelIDs []string +func wantLabelID(labelID string) bool { + switch labelID { + case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel: + return false - for _, labelID := range labelIDs { - switch labelID { - case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel: - // ... skip ... - - default: - filteredLabelIDs = append(filteredLabelIDs, labelID) - } + default: + return true } - - return filteredLabelIDs } diff --git a/internal/user/types.go b/internal/user/types.go new file mode 100644 index 00000000..f65cd6cb --- /dev/null +++ b/internal/user/types.go @@ -0,0 +1,13 @@ +package user + +import "reflect" + +func mapTo[From, To any](from []From) []To { + to := make([]To, 0, len(from)) + + for _, from := range from { + to = append(to, reflect.ValueOf(from).Convert(reflect.TypeOf(to).Elem()).Interface().(To)) + } + + return to +} diff --git a/internal/user/types_test.go b/internal/user/types_test.go new file mode 100644 index 00000000..b9641862 --- /dev/null +++ b/internal/user/types_test.go @@ -0,0 +1,20 @@ +package user + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToType(t *testing.T) { + type myString string + + // Slices of different types are not equal. + require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"}) + + // But converting them to the same type makes them equal. + require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"})) + + // The conversion can happen in the other direction too. + require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"})) +} diff --git a/internal/user/user.go b/internal/user/user.go index 75ea0136..37f7b90b 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -2,19 +2,22 @@ package user import ( "context" + "fmt" "runtime" "time" + "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/pool" "github.com/ProtonMail/proton-bridge/v2/internal/vault" - "github.com/bradenaw/juniper/xslices" "github.com/emersion/go-smtp" "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" + "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) @@ -23,40 +26,38 @@ var ( DefaultEventJitter = 20 * time.Second ) -// TODO: Is it bad to store the key pass in the user? Any worse than storing private keys? type User struct { vault *vault.User client *liteapi.Client builder *pool.Pool[request, *imap.MessageCreated] + eventCh *queue.QueuedChannel[events.Event] - apiUser liteapi.User - addresses []liteapi.Address - settings liteapi.MailSettings - - notifyCh chan events.Event - updateCh chan imap.Update - + apiUser liteapi.User + apiAddrs *addrList userKR *crypto.KeyRing addrKRs map[string]*crypto.KeyRing - imapConn *imapConnector + settings liteapi.MailSettings + + updateCh map[string]*queue.QueuedChannel[imap.Update] + syncWG gluon.WaitGroup } func New( ctx context.Context, - vault *vault.User, + encVault *vault.User, client *liteapi.Client, apiUser liteapi.User, apiAddrs []liteapi.Address, userKR *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing, ) (*User, error) { - if vault.EventID() == "" { + if encVault.EventID() == "" { eventID, err := client.GetLatestEventID(ctx) if err != nil { return nil, err } - if err := vault.SetEventID(eventID); err != nil { + if err := encVault.SetEventID(eventID); err != nil { return nil, err } } @@ -67,19 +68,29 @@ func New( } user := &User{ - apiUser: apiUser, - addresses: apiAddrs, - settings: settings, - - vault: vault, + vault: encVault, client: client, builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()), + eventCh: queue.NewQueuedChannel[events.Event](0, 0), - notifyCh: make(chan events.Event), - updateCh: make(chan imap.Update), + apiUser: apiUser, + apiAddrs: newAddrList(apiAddrs), - userKR: userKR, - addrKRs: addrKRs, + userKR: userKR, + addrKRs: addrKRs, + settings: settings, + + updateCh: make(map[string]*queue.QueuedChannel[imap.Update]), + } + + // Initialize update channels for each of the user's addresses. + for _, addrID := range user.apiAddrs.addrIDs() { + user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) + + // If in combined mode, we only need one update channel. + if encVault.AddressMode() == vault.CombinedMode { + break + } } // When we receive an auth object, we update it in the store. @@ -93,111 +104,234 @@ func New( // When we are deauthorized, we send a deauth event to the notify channel. // Bridge will catch this and log the user out. client.AddDeauthHandler(func() { - user.notifyCh <- events.UserDeauth{ + user.eventCh.Enqueue(events.UserDeauth{ UserID: user.ID(), - } + }) }) - // When we receive an API event, we attempt to handle it. If successful, we send the event to the event channel. + // When we receive an API event, we attempt to handle it. + // If successful, we update the event ID in the vault. go func() { - for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, vault.EventID()).Subscribe() { - if err := user.handleAPIEvent(event); err != nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, encVault.EventID()).Subscribe() { + if err := user.handleAPIEvent(ctx, event); err != nil { logrus.WithError(err).Error("Failed to handle event") - } else { - if err := user.vault.SetEventID(event.EventID); err != nil { - logrus.WithError(err).Error("Failed to update event ID") - } + } else if err := user.vault.SetEventID(event.EventID); err != nil { + logrus.WithError(err).Error("Failed to update event ID") } } }() - // TODO: Use a proper sync manager! (if partial sync, pickup from where we last stopped) - if !vault.HasSync() { - go user.sync(context.Background()) - } - return user, nil } +// ID returns the user's ID. func (user *User) ID() string { return user.apiUser.ID } +// Name returns the user's username. func (user *User) Name() string { return user.apiUser.Name } +// Match matches the given query against the user's username and email addresses. func (user *User) Match(query string) bool { - if query == user.Name() { + if query == user.apiUser.Name { return true } - return slices.Contains(user.Addresses(), query) + return slices.Contains(user.apiAddrs.emails(), query) } -func (user *User) Addresses() []string { - return xslices.Map( - sort(user.addresses, func(a, b liteapi.Address) bool { - return a.Order < b.Order - }), - func(address liteapi.Address) string { - return address.Email - }, - ) +// Emails returns all the user's email addresses. +func (user *User) Emails() []string { + return user.apiAddrs.emails() } -func (user *User) GluonID() string { - return user.vault.GluonID() +// GetAddressMode returns the user's current address mode. +func (user *User) GetAddressMode() vault.AddressMode { + return user.vault.AddressMode() } +// SetAddressMode sets the user's address mode. +func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error { + for _, updateCh := range user.updateCh { + updateCh.Close() + } + + user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update]) + + for _, addrID := range user.apiAddrs.addrIDs() { + user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) + + if mode == vault.CombinedMode { + break + } + } + + if err := user.vault.SetAddressMode(mode); err != nil { + return fmt.Errorf("failed to set address mode: %w", err) + } + + return nil +} + +// GetGluonIDs returns the users gluon IDs. +func (user *User) GetGluonIDs() map[string]string { + return user.vault.GetGluonIDs() +} + +// GetGluonID returns the gluon ID for the given address, if present. +func (user *User) GetGluonID(addrID string) (string, bool) { + gluonID, ok := user.vault.GetGluonIDs()[addrID] + if !ok { + return "", false + } + + return gluonID, true +} + +// SetGluonID sets the gluon ID for the given address. +func (user *User) SetGluonID(addrID, gluonID string) error { + return user.vault.SetGluonID(addrID, gluonID) +} + +// GluonKey returns the user's gluon key from the vault. func (user *User) GluonKey() []byte { return user.vault.GluonKey() } +// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP. func (user *User) BridgePass() string { return user.vault.BridgePass() } +// UsedSpace returns the total space used by the user on the API. func (user *User) UsedSpace() int { return user.apiUser.UsedSpace } +// MaxSpace returns the amount of space the user can use on the API. func (user *User) MaxSpace() int { return user.apiUser.MaxSpace } -// GetNotifyCh returns a channel which notifies of events happening to the user (such as deauth, address change) -func (user *User) GetNotifyCh() <-chan events.Event { - return user.notifyCh +// HasSync returns whether the user has finished syncing. +func (user *User) HasSync() bool { + return user.vault.HasSync() } -func (user *User) NewGluonConnector(ctx context.Context) (connector.Connector, error) { - if user.imapConn != nil { - if err := user.imapConn.Close(ctx); err != nil { - return nil, err +// AbortSync aborts any ongoing sync. +// TODO: This should abort the sync rather than just waiting. +// Should probably be done automatically when one of the user's IMAP connectors is closed. +func (user *User) AbortSync(ctx context.Context) error { + user.syncWG.Wait() + + return nil +} + +// DoSync performs a sync for the user. +func (user *User) DoSync(ctx context.Context) <-chan error { + errCh := queue.NewQueuedChannel[error](0, 0) + + user.syncWG.Go(func() { + defer errCh.Close() + + user.eventCh.Enqueue(events.SyncStarted{ + UserID: user.ID(), + }) + + errCh.Enqueue(func() error { + if err := user.syncLabels(ctx, maps.Keys(user.updateCh)...); err != nil { + return fmt.Errorf("failed to sync labels: %w", err) + } + + if err := user.syncMessages(ctx); err != nil { + return fmt.Errorf("failed to sync messages: %w", err) + } + + user.syncWait() + + if err := user.vault.SetSync(true); err != nil { + return fmt.Errorf("failed to set sync status: %w", err) + } + + return nil + }()) + + user.eventCh.Enqueue(events.SyncFinished{ + UserID: user.ID(), + }) + }) + + return errCh.GetChannel() +} + +// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change) +func (user *User) GetEventCh() <-chan events.Event { + return user.eventCh.GetChannel() +} + +// NewIMAPConnector returns an IMAP connector for the given address. +// If not in split mode, this function returns an error. +func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) { + var emails []string + + switch user.vault.AddressMode() { + case vault.CombinedMode: + if addrID != user.apiAddrs.primary() { + return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode") } + + emails = user.apiAddrs.emails() + + case vault.SplitMode: + emails = []string{user.apiAddrs.email(addrID)} } - user.imapConn = newIMAPConnector(user.client, user.updateCh, user.Addresses(), user.vault.BridgePass()) - - return user.imapConn, nil + return newIMAPConnector( + user.client, + user.updateCh[addrID].GetChannel(), + user.vault.BridgePass(), + emails..., + ), nil } -func (user *User) NewSMTPSession(username string) (smtp.Session, error) { - return newSMTPSession(user.client, username, user.addresses, user.userKR, user.addrKRs, user.settings), nil +// NewIMAPConnectors returns IMAP connectors for each of the user's addresses. +// In combined mode, this is just the user's primary address. +// In split mode, this is all the user's addresses. +func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) { + imapConn := make(map[string]connector.Connector) + + for addrID := range user.updateCh { + conn, err := user.NewIMAPConnector(addrID) + if err != nil { + return nil, fmt.Errorf("failed to create IMAP connector: %w", err) + } + + imapConn[addrID] = conn + } + + return imapConn, nil } +// NewSMTPSession returns an SMTP session for the user. +func (user *User) NewSMTPSession(username string) smtp.Session { + return newSMTPSession(user.client, username, user.apiAddrs.addrMap(), user.settings, user.userKR, user.addrKRs) +} + +// Logout logs the user out from the API. func (user *User) Logout(ctx context.Context) error { return user.client.AuthDelete(ctx) } +// Close closes ongoing connections and cleans up resources. func (user *User) Close(ctx context.Context) error { - // Close the user's IMAP connectors. - if user.imapConn != nil { - if err := user.imapConn.Close(ctx); err != nil { - return err - } - } + // Wait for ongoing syncs to finish. + user.syncWG.Wait() // Close the user's message builder. user.builder.Done() @@ -205,15 +339,13 @@ func (user *User) Close(ctx context.Context) error { // Close the user's API client. user.client.Close() + // Close the user's update channels. + for _, updateCh := range user.updateCh { + updateCh.Close() + } + // Close the user's notify channel. - close(user.notifyCh) + user.eventCh.Close() return nil } - -// sort returns the slice, sorted by the given callback. -func sort[T any](slice []T, less func(a, b T) bool) []T { - slices.SortFunc(slice, less) - - return slice -} diff --git a/internal/user/user_test.go b/internal/user/user_test.go new file mode 100644 index 00000000..cc9e9fc4 --- /dev/null +++ b/internal/user/user_test.go @@ -0,0 +1,162 @@ +package user_test + +import ( + "context" + "testing" + "time" + + "github.com/ProtonMail/proton-bridge/v2/internal/certs" + "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/user" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/ProtonMail/proton-bridge/v2/tests" + "github.com/bradenaw/juniper/iterator" + "github.com/emersion/go-imap" + "github.com/emersion/go-imap/client" + "github.com/stretchr/testify/require" + "gitlab.protontech.ch/go/liteapi" + "gitlab.protontech.ch/go/liteapi/server" + "gitlab.protontech.ch/go/liteapi/server/account" +) + +func init() { + user.DefaultEventPeriod = 100 * time.Millisecond + user.DefaultEventJitter = 0 + account.GenerateKey = tests.FastGenerateKey + certs.GenerateCert = tests.FastGenerateCert +} + +func TestUser_Data(t *testing.T) { + withAPI(t, context.Background(), "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) { + withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) { + // User's ID should be correct. + require.Equal(t, userID, user.ID()) + + // User's name should be correct. + require.Equal(t, "username", user.Name()) + + // User's email should be correct. + require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails()) + + // By default, user should be in combined mode. + require.Equal(t, vault.CombinedMode, user.GetAddressMode()) + + // By default, user should have a non-empty bridge password. + require.NotEmpty(t, user.BridgePass()) + }) + }) +} + +func TestUser_Sync(t *testing.T) { + withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) { + withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) { + // Get the user's IMAP connectors. + imapConn, err := user.NewIMAPConnectors() + require.NoError(t, err) + + // Pretend to be gluon applying all the updates. + go func() { + for _, imapConn := range imapConn { + for update := range imapConn.GetUpdates() { + update.Done() + } + } + }() + + // Trigger a user sync. + errCh := user.DoSync(ctx) + + // User starts a sync at startup. + require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) + + // User finishes a sync at startup. + require.IsType(t, events.SyncFinished{}, <-user.GetEventCh()) + + // The sync completes without error. + require.NoError(t, <-errCh) + }) + }) +} + +func TestUser_Deauth(t *testing.T) { + withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) { + withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) { + eventCh := user.GetEventCh() + + // Revoke the user's auth token. + require.NoError(t, s.RevokeUser(userID)) + + // The user should eventually be logged out. + require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserDeauth); return ok }, 5*time.Second, 100*time.Millisecond) + }) + }) +} + +func withAPI(t *testing.T, ctx context.Context, username, password string, emails []string, fn func(context.Context, *server.Server, string, []string)) { + server := server.New() + defer server.Close() + + var addrIDs []string + + userID, addrID, err := server.AddUser(username, password, emails[0]) + require.NoError(t, err) + + addrIDs = append(addrIDs, addrID) + + for _, email := range emails[1:] { + addrID, err := server.AddAddress(userID, email, password) + require.NoError(t, err) + + addrIDs = append(addrIDs, addrID) + } + + fn(ctx, server, userID, addrIDs) +} + +func withUser(t *testing.T, ctx context.Context, apiURL, username, password string, fn func(*user.User)) { + c, apiAuth, err := liteapi.New(liteapi.WithHostURL(apiURL)).NewClientWithLogin(ctx, username, password) + require.NoError(t, err) + defer func() { require.NoError(t, c.Close()) }() + + apiUser, apiAddrs, userKR, addrKRs, passphrase, err := c.Unlock(ctx, []byte(password)) + require.NoError(t, err) + + vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) + require.NoError(t, err) + require.False(t, corrupt) + + vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, passphrase) + require.NoError(t, err) + + user, err := user.New(ctx, vaultUser, c, apiUser, apiAddrs, userKR, addrKRs) + require.NoError(t, err) + defer func() { require.NoError(t, user.Close(ctx)) }() + + fn(user) +} + +func withIMAPClient(t *testing.T, addr string, fn func(*client.Client)) { + c, err := client.Dial(addr) + require.NoError(t, err) + defer c.Close() + + fn(c) +} + +func fetch(t *testing.T, c *client.Client, seqset string, items ...imap.FetchItem) []*imap.Message { + msgCh := make(chan *imap.Message) + + go func() { + require.NoError(t, c.Fetch(must(imap.ParseSeqSet(seqset)), items, msgCh)) + }() + + return iterator.Collect(iterator.Chan(msgCh)) +} + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } + + return v +} diff --git a/internal/vault/token.go b/internal/vault/token.go index 7bd70445..df55e3e1 100644 --- a/internal/vault/token.go +++ b/internal/vault/token.go @@ -5,9 +5,14 @@ import ( ) // RandomToken is a function that returns a random token. -var RandomToken func(size int) ([]byte, error) - // By default, we use crypto.RandomToken to generate tokens. -func init() { - RandomToken = crypto.RandomToken +var RandomToken = crypto.RandomToken + +func newRandomToken(size int) []byte { + token, err := RandomToken(size) + if err != nil { + panic(err) + } + + return token } diff --git a/internal/vault/types.go b/internal/vault/types.go index a1519dde..e3f6b48a 100644 --- a/internal/vault/types.go +++ b/internal/vault/types.go @@ -4,6 +4,7 @@ import ( "math/rand" "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/proton-bridge/v2/internal/updater" ) @@ -45,15 +46,24 @@ type Settings struct { FirstStartGUI bool } +type AddressMode int + +const ( + CombinedMode AddressMode = iota + SplitMode +) + // UserData holds information about a single bridge user. // The user may or may not be logged in. type UserData struct { UserID string Username string - GluonID string - GluonKey []byte - BridgePass string + GluonKey []byte + GluonIDs map[string]string + UIDValidity map[string]imap.UID + BridgePass []byte + AddressMode AddressMode AuthUID string AuthRef string diff --git a/internal/vault/user.go b/internal/vault/user.go index 5c599f51..76fc1be0 100644 --- a/internal/vault/user.go +++ b/internal/vault/user.go @@ -1,5 +1,11 @@ package vault +import ( + "encoding/hex" + + "github.com/ProtonMail/gluon/imap" +) + type User struct { vault *Vault userID string @@ -13,16 +19,41 @@ func (user *User) Username() string { return user.vault.getUser(user.userID).Username } -func (user *User) GluonID() string { - return user.vault.getUser(user.userID).GluonID +func (user *User) GetGluonIDs() map[string]string { + return user.vault.getUser(user.userID).GluonIDs +} + +func (user *User) SetGluonID(addrID, gluonID string) error { + return user.vault.modUser(user.userID, func(data *UserData) { + data.GluonIDs[addrID] = gluonID + }) +} + +func (user *User) GetUIDValidity(addrID string) (imap.UID, bool) { + validity, ok := user.vault.getUser(user.userID).UIDValidity[addrID] + if !ok { + return imap.UID(0), false + } + + return validity, true +} + +func (user *User) SetUIDValidity(addrID string, validity imap.UID) error { + return user.vault.modUser(user.userID, func(data *UserData) { + data.UIDValidity[addrID] = validity + }) } func (user *User) GluonKey() []byte { return user.vault.getUser(user.userID).GluonKey } +func (user *User) AddressMode() AddressMode { + return user.vault.getUser(user.userID).AddressMode +} + func (user *User) BridgePass() string { - return user.vault.getUser(user.userID).BridgePass + return hex.EncodeToString(user.vault.getUser(user.userID).BridgePass) } func (user *User) AuthUID() string { @@ -51,7 +82,7 @@ func (user *User) SetKeyPass(keyPass []byte) error { }) } -// SetAuth updates the auth secrets for the given user. +// SetAuth sets the auth secrets for the given user. func (user *User) SetAuth(authUID, authRef string) error { return user.vault.modUser(user.userID, func(data *UserData) { data.AuthUID = authUID @@ -59,33 +90,23 @@ func (user *User) SetAuth(authUID, authRef string) error { }) } -// SetGluonAuth updates the gluon ID and key for the given user. -func (user *User) SetGluonAuth(gluonID string, gluonKey []byte) error { +// SetAddressMode sets the address mode for the given user. +func (user *User) SetAddressMode(mode AddressMode) error { return user.vault.modUser(user.userID, func(data *UserData) { - data.GluonID = gluonID - data.GluonKey = gluonKey + data.AddressMode = mode }) } -// SetEventID updates the event ID for the given user. +// SetEventID sets the event ID for the given user. func (user *User) SetEventID(eventID string) error { return user.vault.modUser(user.userID, func(data *UserData) { data.EventID = eventID }) } -// SetSync updates the sync state for the given user. +// SetSync sets the sync state for the given user. func (user *User) SetSync(hasSync bool) error { return user.vault.modUser(user.userID, func(data *UserData) { data.HasSync = hasSync }) } - -// Clear clears the secrets for the given user. -func (user *User) Clear() error { - return user.vault.modUser(user.userID, func(data *UserData) { - data.AuthUID = "" - data.AuthRef = "" - data.KeyPass = nil - }) -} diff --git a/internal/vault/user_test.go b/internal/vault/user_test.go index 682d2c05..8e2e9691 100644 --- a/internal/vault/user_test.go +++ b/internal/vault/user_test.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "testing" + "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/stretchr/testify/require" ) @@ -32,30 +33,48 @@ func TestUser(t *testing.T) { require.NoError(t, user2.SetSync(false)) // Set gluon data for user 1 and 2. - require.NoError(t, user1.SetGluonAuth("gluonID1", []byte("gluonKey1"))) - require.NoError(t, user2.SetGluonAuth("gluonID2", []byte("gluonKey2"))) + require.NoError(t, user1.SetGluonID("addrID1", "gluonID1")) + require.NoError(t, user2.SetGluonID("addrID2", "gluonID2")) + require.NoError(t, user1.SetUIDValidity("addrID1", imap.UID(1))) + require.NoError(t, user2.SetUIDValidity("addrID2", imap.UID(2))) // List available users. require.ElementsMatch(t, []string{"userID1", "userID2"}, s.GetUserIDs()) + // Check gluon information for user 1. + gluonID1, ok := user1.GetGluonIDs()["addrID1"] + require.True(t, ok) + require.Equal(t, "gluonID1", gluonID1) + uidValidity1, ok := user1.GetUIDValidity("addrID1") + require.True(t, ok) + require.Equal(t, imap.UID(1), uidValidity1) + require.NotEmpty(t, user1.GluonKey()) + // Get auth information for user 1. require.Equal(t, "userID1", user1.UserID()) require.Equal(t, "user1", user1.Username()) - require.Equal(t, "gluonID1", user1.GluonID()) - require.Equal(t, []byte("gluonKey1"), user1.GluonKey()) require.Equal(t, hex.EncodeToString([]byte("token")), user1.BridgePass()) + require.Equal(t, vault.CombinedMode, user1.AddressMode()) require.Equal(t, "authUID1", user1.AuthUID()) require.Equal(t, "authRef1", user1.AuthRef()) require.Equal(t, []byte("keyPass1"), user1.KeyPass()) require.Equal(t, "eventID1", user1.EventID()) require.Equal(t, true, user1.HasSync()) + // Check gluon information for user 1. + gluonID2, ok := user2.GetGluonIDs()["addrID2"] + require.True(t, ok) + require.Equal(t, "gluonID2", gluonID2) + uidValidity2, ok := user2.GetUIDValidity("addrID2") + require.True(t, ok) + require.Equal(t, imap.UID(2), uidValidity2) + require.NotEmpty(t, user2.GluonKey()) + // Get auth information for user 2. require.Equal(t, "userID2", user2.UserID()) require.Equal(t, "user2", user2.Username()) - require.Equal(t, "gluonID2", user2.GluonID()) - require.Equal(t, []byte("gluonKey2"), user2.GluonKey()) require.Equal(t, hex.EncodeToString([]byte("token")), user2.BridgePass()) + require.Equal(t, vault.CombinedMode, user2.AddressMode()) require.Equal(t, "authUID2", user2.AuthUID()) require.Equal(t, "authRef2", user2.AuthRef()) require.Equal(t, []byte("keyPass2"), user2.KeyPass()) @@ -63,8 +82,8 @@ func TestUser(t *testing.T) { require.Equal(t, false, user2.HasSync()) // Clear the users. - require.NoError(t, user1.Clear()) - require.NoError(t, user2.Clear()) + require.NoError(t, s.ClearUser("userID1")) + require.NoError(t, s.ClearUser("userID2")) // Their secrets should now be cleared. require.Equal(t, "", user1.AuthUID()) diff --git a/internal/vault/vault.go b/internal/vault/vault.go index b10dd6f0..908021b5 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -4,7 +4,6 @@ import ( "crypto/aes" "crypto/cipher" "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "io/fs" @@ -12,6 +11,7 @@ import ( "os" "path/filepath" + "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/bradenaw/juniper/xslices" ) @@ -99,16 +99,16 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [ return nil, errors.New("user already exists") } - tok, err := RandomToken(16) - if err != nil { - return nil, err - } - if err := vault.mod(func(data *Data) { data.Users = append(data.Users, UserData{ - UserID: userID, - Username: username, - BridgePass: hex.EncodeToString(tok), + UserID: userID, + Username: username, + + GluonKey: newRandomToken(32), + GluonIDs: make(map[string]string), + UIDValidity: make(map[string]imap.UID), + BridgePass: newRandomToken(16), + AddressMode: CombinedMode, AuthUID: authUID, AuthRef: authRef, @@ -121,6 +121,14 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [ return vault.GetUser(userID) } +func (vault *Vault) ClearUser(userID string) error { + return vault.modUser(userID, func(data *UserData) { + data.AuthUID = "" + data.AuthRef = "" + data.KeyPass = nil + }) +} + // DeleteUser removes the given user from the vault. func (vault *Vault) DeleteUser(userID string) error { return vault.mod(func(data *Data) { diff --git a/tests/api_test.go b/tests/api_test.go index 9f0f94e8..6ecbd4b4 100644 --- a/tests/api_test.go +++ b/tests/api_test.go @@ -13,7 +13,9 @@ type API interface { GetHostURL() string AddCallWatcher(func(server.Call), ...string) - AddUser(username, password, address string) (userID, addrID string, err error) + AddUser(username, password, address string) (string, string, error) + AddAddress(userID, address, password string) (string, error) + RemoveAddress(userID, addrID string) error RevokeUser(userID string) error GetLabels(userID string) ([]liteapi.Label, error) diff --git a/tests/bdd_test.go b/tests/bdd_test.go index d06770b5..515fb633 100644 --- a/tests/bdd_test.go +++ b/tests/bdd_test.go @@ -30,8 +30,8 @@ import ( ) func init() { - user.DefaultEventPeriod = time.Second - user.DefaultEventJitter = time.Second + user.DefaultEventPeriod = 100 * time.Millisecond + user.DefaultEventJitter = 0 } type scenario struct { @@ -76,6 +76,16 @@ func TestFeatures(testingT *testing.T) { ctx.Step(`^the user agent is "([^"]*)"$`, s.theUserAgentIs) ctx.Step(`^the value of the "([^"]*)" header in the request to "([^"]*)" is "([^"]*)"$`, s.theValueOfTheHeaderInTheRequestToIs) + // ==== SETUP ==== + ctx.Step(`^there exists an account with username "([^"]*)" and password "([^"]*)"$`, s.thereExistsAnAccountWithUsernameAndPassword) + ctx.Step(`^the account "([^"]*)" has additional address "([^"]*)"$`, s.theAccountHasAdditionalAddress) + ctx.Step(`^the account "([^"]*)" no longer has additional address "([^"]*)"$`, s.theAccountNoLongerHasAdditionalAddress) + ctx.Step(`^the account "([^"]*)" has (\d+) custom folders$`, s.theAccountHasCustomFolders) + ctx.Step(`^the account "([^"]*)" has (\d+) custom labels$`, s.theAccountHasCustomLabels) + ctx.Step(`^the account "([^"]*)" has the following custom mailboxes:$`, s.theAccountHasTheFollowingCustomMailboxes) + ctx.Step(`^the address "([^"]*)" of account "([^"]*)" has the following messages in "([^"]*)":$`, s.theAddressOfAccountHasTheFollowingMessagesInMailbox) + ctx.Step(`^the address "([^"]*)" of account "([^"]*)" has (\d+) messages in "([^"]*)"$`, s.theAddressOfAccountHasMessagesInMailbox) + // ==== BRIDGE ==== ctx.Step(`^bridge starts$`, s.bridgeStarts) ctx.Step(`^bridge restarts$`, s.bridgeRestarts) @@ -85,12 +95,15 @@ func TestFeatures(testingT *testing.T) { ctx.Step(`^the user has disabled automatic updates$`, s.theUserHasDisabledAutomaticUpdates) ctx.Step(`^the user changes the IMAP port to (\d+)$`, s.theUserChangesTheIMAPPortTo) ctx.Step(`^the user changes the SMTP port to (\d+)$`, s.theUserChangesTheSMTPPortTo) + ctx.Step(`^the user sets the address mode of "([^"]*)" to "([^"]*)"$`, s.theUserSetsTheAddressModeOfTo) ctx.Step(`^the user changes the gluon path$`, s.theUserChangesTheGluonPath) ctx.Step(`^the user deletes the gluon files$`, s.theUserDeletesTheGluonFiles) ctx.Step(`^the user reports a bug$`, s.theUserReportsABug) ctx.Step(`^bridge sends a connection up event$`, s.bridgeSendsAConnectionUpEvent) ctx.Step(`^bridge sends a connection down event$`, s.bridgeSendsAConnectionDownEvent) ctx.Step(`^bridge sends a deauth event for user "([^"]*)"$`, s.bridgeSendsADeauthEventForUser) + ctx.Step(`^bridge sends an address created event for user "([^"]*)"$`, s.bridgeSendsAnAddressCreatedEventForUser) + ctx.Step(`^bridge sends an address deleted event for user "([^"]*)"$`, s.bridgeSendsAnAddressDeletedEventForUser) ctx.Step(`^bridge sends sync started and finished events for user "([^"]*)"$`, s.bridgeSendsSyncStartedAndFinishedEventsForUser) ctx.Step(`^bridge sends an update available event for version "([^"]*)"$`, s.bridgeSendsAnUpdateAvailableEventForVersion) ctx.Step(`^bridge sends a manual update event for version "([^"]*)"$`, s.bridgeSendsAManualUpdateEventForVersion) @@ -99,12 +112,6 @@ func TestFeatures(testingT *testing.T) { ctx.Step(`^bridge sends a forced update event$`, s.bridgeSendsAForcedUpdateEvent) // ==== USER ==== - ctx.Step(`^there exists an account with username "([^"]*)" and password "([^"]*)"$`, s.thereExistsAnAccountWithUsernameAndPassword) - ctx.Step(`^the account "([^"]*)" has (\d+) custom folders$`, s.theAccountHasCustomFolders) - ctx.Step(`^the account "([^"]*)" has (\d+) custom labels$`, s.theAccountHasCustomLabels) - ctx.Step(`^the account "([^"]*)" has the following custom mailboxes:$`, s.theAccountHasTheFollowingCustomMailboxes) - ctx.Step(`^the account "([^"]*)" has the following messages in "([^"]*)":$`, s.theAccountHasTheFollowingMessagesInMailbox) - ctx.Step(`^the account "([^"]*)" has (\d+) messages in "([^"]*)"$`, s.theAccountHasMessagesInMailbox) ctx.Step(`^the user logs in with username "([^"]*)" and password "([^"]*)"$`, s.userLogsInWithUsernameAndPassword) ctx.Step(`^user "([^"]*)" logs out$`, s.userLogsOut) ctx.Step(`^user "([^"]*)" is deleted$`, s.userIsDeleted) @@ -119,8 +126,10 @@ func TestFeatures(testingT *testing.T) { ctx.Step(`^user "([^"]*)" connects IMAP client "([^"]*)"$`, s.userConnectsIMAPClient) ctx.Step(`^user "([^"]*)" connects IMAP client "([^"]*)" on port (\d+)$`, s.userConnectsIMAPClientOnPort) ctx.Step(`^user "([^"]*)" connects and authenticates IMAP client "([^"]*)"$`, s.userConnectsAndAuthenticatesIMAPClient) + ctx.Step(`^user "([^"]*)" connects and authenticates IMAP client "([^"]*)" with address "([^"]*)"$`, s.userConnectsAndAuthenticatesIMAPClientWithAddress) ctx.Step(`^IMAP client "([^"]*)" can authenticate$`, s.imapClientCanAuthenticate) ctx.Step(`^IMAP client "([^"]*)" cannot authenticate$`, s.imapClientCannotAuthenticate) + ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with address "([^"]*)"$`, s.imapClientCannotAuthenticateWithAddress) ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with incorrect username$`, s.imapClientCannotAuthenticateWithIncorrectUsername) ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with incorrect password$`, s.imapClientCannotAuthenticateWithIncorrectPassword) ctx.Step(`^IMAP client "([^"]*)" announces its ID with name "([^"]*)" and version "([^"]*)"$`, s.imapClientAnnouncesItsIDWithNameAndVersion) @@ -151,6 +160,7 @@ func TestFeatures(testingT *testing.T) { ctx.Step(`^user "([^"]*)" connects SMTP client "([^"]*)"$`, s.userConnectsSMTPClient) ctx.Step(`^user "([^"]*)" connects SMTP client "([^"]*)" on port (\d+)$`, s.userConnectsSMTPClientOnPort) ctx.Step(`^user "([^"]*)" connects and authenticates SMTP client "([^"]*)"$`, s.userConnectsAndAuthenticatesSMTPClient) + ctx.Step(`^user "([^"]*)" connects and authenticates SMTP client "([^"]*)" with address "([^"]*)"$`, s.userConnectsAndAuthenticatesSMTPClientWithAddress) ctx.Step(`^SMTP client "([^"]*)" can authenticate$`, s.smtpClientCanAuthenticate) ctx.Step(`^SMTP client "([^"]*)" cannot authenticate$`, s.smtpClientCannotAuthenticate) ctx.Step(`^SMTP client "([^"]*)" cannot authenticate with incorrect username$`, s.smtpClientCannotAuthenticateWithIncorrectUsername) diff --git a/tests/bridge_test.go b/tests/bridge_test.go index 063b65fa..ccd94bf3 100644 --- a/tests/bridge_test.go +++ b/tests/bridge_test.go @@ -9,6 +9,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" ) func (s *scenario) bridgeStarts() error { @@ -46,6 +47,19 @@ func (s *scenario) theUserChangesTheSMTPPortTo(port int) error { return s.t.bridge.SetSMTPPort(port) } +func (s *scenario) theUserSetsTheAddressModeOfTo(user, mode string) error { + switch mode { + case "split": + return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserID(user), vault.SplitMode) + + case "combined": + return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserID(user), vault.CombinedMode) + + default: + return fmt.Errorf("unknown address mode %q", mode) + } +} + func (s *scenario) theUserChangesTheGluonPath() error { gluonDir, err := os.MkdirTemp(s.t.dir, "gluon") if err != nil { @@ -113,7 +127,7 @@ func (s *scenario) bridgeSendsAConnectionDownEvent() error { } func (s *scenario) bridgeSendsADeauthEventForUser(username string) error { - return try(s.t.userDeauthCh, 5*time.Second, func(event events.UserDeauth) error { + return try(s.t.deauthCh, 5*time.Second, func(event events.UserDeauth) error { if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { return fmt.Errorf("expected deauth event for user with ID %s, got %s", wantUserID, event.UserID) } @@ -122,6 +136,26 @@ func (s *scenario) bridgeSendsADeauthEventForUser(username string) error { }) } +func (s *scenario) bridgeSendsAnAddressCreatedEventForUser(username string) error { + return try(s.t.addrCreatedCh, 5*time.Second, func(event events.UserAddressCreated) error { + if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { + return fmt.Errorf("expected user address created event for user with ID %s, got %s", wantUserID, event.UserID) + } + + return nil + }) +} + +func (s *scenario) bridgeSendsAnAddressDeletedEventForUser(username string) error { + return try(s.t.addrDeletedCh, 5*time.Second, func(event events.UserAddressDeleted) error { + if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { + return fmt.Errorf("expected user address deleted event for user with ID %s, got %s", wantUserID, event.UserID) + } + + return nil + }) +} + func (s *scenario) bridgeSendsSyncStartedAndFinishedEventsForUser(username string) error { if err := get(s.t.syncStartedCh, func(event events.SyncStarted) error { if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index c696eb5f..f46dc3a5 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -54,10 +54,12 @@ func (t *testCtx) startBridge() error { t.bridge = bridge // Connect the event channels. - t.userLoginCh = chToType[events.Event, events.UserLoggedIn](bridge.GetEvents(events.UserLoggedIn{})) - t.userLogoutCh = chToType[events.Event, events.UserLoggedOut](bridge.GetEvents(events.UserLoggedOut{})) - t.userDeletedCh = chToType[events.Event, events.UserDeleted](bridge.GetEvents(events.UserDeleted{})) - t.userDeauthCh = chToType[events.Event, events.UserDeauth](bridge.GetEvents(events.UserDeauth{})) + t.loginCh = chToType[events.Event, events.UserLoggedIn](bridge.GetEvents(events.UserLoggedIn{})) + t.logoutCh = chToType[events.Event, events.UserLoggedOut](bridge.GetEvents(events.UserLoggedOut{})) + t.deletedCh = chToType[events.Event, events.UserDeleted](bridge.GetEvents(events.UserDeleted{})) + t.deauthCh = chToType[events.Event, events.UserDeauth](bridge.GetEvents(events.UserDeauth{})) + t.addrCreatedCh = chToType[events.Event, events.UserAddressCreated](bridge.GetEvents(events.UserAddressCreated{})) + t.addrDeletedCh = chToType[events.Event, events.UserAddressDeleted](bridge.GetEvents(events.UserAddressDeleted{})) t.syncStartedCh = chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{})) t.syncFinishedCh = chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) t.forcedUpdateCh = chToType[events.Event, events.UpdateForced](bridge.GetEvents(events.UpdateForced{})) diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 24d30747..06ed6a39 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -14,6 +14,7 @@ import ( "github.com/emersion/go-imap/client" "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" + "golang.org/x/exp/maps" ) var defaultVersion = semver.MustParse("1.0.0") @@ -32,10 +33,12 @@ type testCtx struct { bridge *bridge.Bridge // These channels hold events of various types coming from bridge. - userLoginCh <-chan events.UserLoggedIn - userLogoutCh <-chan events.UserLoggedOut - userDeletedCh <-chan events.UserDeleted - userDeauthCh <-chan events.UserDeauth + loginCh <-chan events.UserLoggedIn + logoutCh <-chan events.UserLoggedOut + deletedCh <-chan events.UserDeleted + deauthCh <-chan events.UserDeauth + addrCreatedCh <-chan events.UserAddressCreated + addrDeletedCh <-chan events.UserAddressDeleted syncStartedCh <-chan events.SyncStarted syncFinishedCh <-chan events.SyncFinished forcedUpdateCh <-chan events.UpdateForced @@ -43,10 +46,10 @@ type testCtx struct { updateCh <-chan events.Event // These maps hold expected userIDByName, their primary addresses and bridge passwords. - userIDByName map[string]string - userAddrByID map[string]string - userPassByID map[string]string - addrIDByID map[string]string + userIDByName map[string]string + userAddrByEmail map[string]map[string]string + userPassByID map[string]string + userBridgePassByID map[string]string // These are the IMAP and SMTP clients used to connect to bridge. imapClients map[string]*imapClient @@ -83,10 +86,10 @@ func newTestCtx(tb testing.TB) *testCtx { mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion), version: defaultVersion, - userIDByName: make(map[string]string), - userAddrByID: make(map[string]string), - userPassByID: make(map[string]string), - addrIDByID: make(map[string]string), + userIDByName: make(map[string]string), + userAddrByEmail: make(map[string]map[string]string), + userPassByID: make(map[string]string), + userBridgePassByID: make(map[string]string), imapClients: make(map[string]*imapClient), smtpClients: make(map[string]*smtpClient), @@ -112,12 +115,28 @@ func (t *testCtx) setUserID(username, userID string) { t.userIDByName[username] = userID } -func (t *testCtx) getUserAddr(userID string) string { - return t.userAddrByID[userID] +func (t *testCtx) getUserAddrID(userID, email string) string { + return t.userAddrByEmail[userID][email] } -func (t *testCtx) setUserAddr(userID, addr string) { - t.userAddrByID[userID] = addr +func (t *testCtx) getUserAddrs(userID string) []string { + return maps.Keys(t.userAddrByEmail[userID]) +} + +func (t *testCtx) setUserAddr(userID, addrID, email string) { + if _, ok := t.userAddrByEmail[userID]; !ok { + t.userAddrByEmail[userID] = make(map[string]string) + } + + t.userAddrByEmail[userID][email] = addrID +} + +func (t *testCtx) unsetUserAddr(userID, wantAddrID string) { + for email, addrID := range t.userAddrByEmail[userID] { + if addrID == wantAddrID { + delete(t.userAddrByEmail[userID], email) + } + } } func (t *testCtx) getUserPass(userID string) string { @@ -128,12 +147,12 @@ func (t *testCtx) setUserPass(userID, pass string) { t.userPassByID[userID] = pass } -func (t *testCtx) getAddrID(userID string) string { - return t.addrIDByID[userID] +func (t *testCtx) getUserBridgePass(userID string) string { + return t.userBridgePassByID[userID] } -func (t *testCtx) setAddrID(userID, addrID string) { - t.addrIDByID[userID] = addrID +func (t *testCtx) setUserBridgePass(userID, pass string) { + t.userBridgePassByID[userID] = pass } func (t *testCtx) getMBoxID(userID string, name string) string { diff --git a/tests/fast.go b/tests/fast.go new file mode 100644 index 00000000..73797794 --- /dev/null +++ b/tests/fast.go @@ -0,0 +1,48 @@ +package tests + +import ( + "crypto/x509" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/v2/internal/certs" +) + +var ( + preCompPGPKey *crypto.Key + preCompCertPEM []byte + preCompKeyPEM []byte +) + +func FastGenerateKey(name, email string, passphrase []byte, keyType string, bits int) (string, error) { + encKey, err := preCompPGPKey.Lock(passphrase) + if err != nil { + return "", err + } + + return encKey.Armor() +} + +func FastGenerateCert(template *x509.Certificate) ([]byte, []byte, error) { + return preCompCertPEM, preCompKeyPEM, nil +} + +func init() { + key, err := crypto.GenerateKey("name", "email", "rsa", 1024) + if err != nil { + panic(err) + } + + template, err := certs.NewTLSTemplate() + if err != nil { + panic(err) + } + + certPEM, keyPEM, err := certs.GenerateCert(template) + if err != nil { + panic(err) + } + + preCompPGPKey = key + preCompCertPEM = certPEM + preCompKeyPEM = keyPEM +} diff --git a/tests/features/imap/user_agent.feature b/tests/features/imap/id.feature similarity index 100% rename from tests/features/imap/user_agent.feature rename to tests/features/imap/id.feature diff --git a/tests/features/imap/mailbox/info.feature b/tests/features/imap/mailbox/info.feature index 7f7b8752..fb16caa4 100644 --- a/tests/features/imap/mailbox/info.feature +++ b/tests/features/imap/mailbox/info.feature @@ -4,7 +4,7 @@ Feature: IMAP get mailbox info And the account "user@pm.me" has the following custom mailboxes: | name | type | | one | folder | - And the account "user@pm.me" has the following messages in "one": + And the address "user@pm.me" of account "user@pm.me" has the following messages in "one": | sender | recipient | subject | unread | | a@pm.me | a@pm.me | one | true | | b@pm.me | b@pm.me | two | false | diff --git a/tests/features/imap/message/copy.feature b/tests/features/imap/message/copy.feature index f435501c..9130b8bc 100644 --- a/tests/features/imap/message/copy.feature +++ b/tests/features/imap/message/copy.feature @@ -5,7 +5,7 @@ Feature: IMAP copy messages | name | type | | mbox | folder | | label | label | - And the account "user@pm.me" has the following messages in "Inbox": + And the address "user@pm.me" of account "user@pm.me" has the following messages in "Inbox": | sender | recipient | subject | unread | | john.doe@mail.com | user@pm.me | foo | false | | jane.doe@mail.com | name@pm.me | bar | true | diff --git a/tests/features/imap/message/delete.feature b/tests/features/imap/message/delete.feature index 5ac90ffd..4848d300 100644 --- a/tests/features/imap/message/delete.feature +++ b/tests/features/imap/message/delete.feature @@ -5,7 +5,7 @@ Feature: IMAP remove messages from mailbox | name | type | | mbox | folder | | label | label | - And the account "user@pm.me" has 10 messages in "mbox" + And the address "user@pm.me" of account "user@pm.me" has 10 messages in "mbox" And bridge starts And the user logs in with username "user@pm.me" and password "password" And user "user@pm.me" finishes syncing diff --git a/tests/features/user/addressmode.feature b/tests/features/user/addressmode.feature new file mode 100644 index 00000000..d0659773 --- /dev/null +++ b/tests/features/user/addressmode.feature @@ -0,0 +1,180 @@ +Feature: Address mode + Background: + Given there exists an account with username "user@pm.me" and password "password" + And the account "user@pm.me" has additional address "alias@pm.me" + And the account "user@pm.me" has the following custom mailboxes: + | name | type | + | one | folder | + | two | folder | + And the address "user@pm.me" of account "user@pm.me" has the following messages in "one": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + And the address "alias@pm.me" of account "user@pm.me" has the following messages in "two": + | sender | recipient | subject | unread | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + And bridge starts + And the user logs in with username "user@pm.me" and password "password" + And user "user@pm.me" finishes syncing + + Scenario: The user is in combined mode + When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me" + Then IMAP client "1" sees the following messages in "Folders/one": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + And IMAP client "1" sees the following messages in "Folders/two": + | sender | recipient | subject | unread | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + And IMAP client "1" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me" + Then IMAP client "2" sees the following messages in "Folders/one": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + And IMAP client "2" sees the following messages in "Folders/two": + | sender | recipient | subject | unread | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + And IMAP client "2" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + + Scenario: The user is in split mode + Given the user sets the address mode of "user@pm.me" to "split" + And user "user@pm.me" finishes syncing + When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me" + Then IMAP client "1" sees the following messages in "Folders/one": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + And IMAP client "1" sees 0 messages in "Folders/two" + And IMAP client "1" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me" + Then IMAP client "2" sees 0 messages in "Folders/one" + And IMAP client "2" sees the following messages in "Folders/two": + | sender | recipient | subject | unread | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + And IMAP client "2" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + + Scenario: The user switches from combined to split mode and back + Given the user sets the address mode of "user@pm.me" to "split" + And user "user@pm.me" finishes syncing + And the user sets the address mode of "user@pm.me" to "combined" + And user "user@pm.me" finishes syncing + When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me" + Then IMAP client "1" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me" + Then IMAP client "2" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + + Scenario: The user adds an address while in combined mode + When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me" + Then IMAP client "1" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me" + Then IMAP client "2" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + Given the account "user@pm.me" has additional address "other@pm.me" + And bridge sends an address created event for user "user@pm.me" + When user "user@pm.me" connects and authenticates IMAP client "3" with address "other@pm.me" + Then IMAP client "3" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + + Scenario: The user adds an address while in split mode + Given the user sets the address mode of "user@pm.me" to "split" + And user "user@pm.me" finishes syncing + When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me" + And IMAP client "1" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me" + And IMAP client "2" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + Given the account "user@pm.me" has additional address "other@pm.me" + And bridge sends an address created event for user "user@pm.me" + When user "user@pm.me" connects and authenticates IMAP client "3" with address "other@pm.me" + Then IMAP client "3" eventually sees 0 messages in "All Mail" + + Scenario: The user deletes an address while in combined mode + When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me" + Then IMAP client "1" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me" + Then IMAP client "2" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + Given the account "user@pm.me" no longer has additional address "alias@pm.me" + And bridge sends an address deleted event for user "user@pm.me" + When user "user@pm.me" connects IMAP client "3" + Then IMAP client "3" cannot authenticate with address "alias@pm.me" + + Scenario: The user deletes an address while in split mode + Given the user sets the address mode of "user@pm.me" to "split" + And user "user@pm.me" finishes syncing + When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me" + And IMAP client "1" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | a@pm.me | a@pm.me | one | true | + | b@pm.me | b@pm.me | two | false | + When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me" + And IMAP client "2" sees the following messages in "All Mail": + | sender | recipient | subject | unread | + | c@pm.me | c@pm.me | three | true | + | d@pm.me | d@pm.me | four | false | + Given the account "user@pm.me" no longer has additional address "alias@pm.me" + And bridge sends an address deleted event for user "user@pm.me" + When user "user@pm.me" connects IMAP client "3" + Then IMAP client "3" cannot authenticate with address "alias@pm.me" + + Scenario: The user makes an alias the primary address while in combined mode + + Scenario: The user makes an alias the primary address while in split mode \ No newline at end of file diff --git a/tests/features/user/sync.feature b/tests/features/user/sync.feature index f9567022..3dce33b6 100644 --- a/tests/features/user/sync.feature +++ b/tests/features/user/sync.feature @@ -6,11 +6,11 @@ Feature: Bridge can fully sync an account | one | folder | | two | folder | | three | label | - And the account "user@pm.me" has the following messages in "one": + And the address "user@pm.me" of account "user@pm.me" has the following messages in "one": | sender | recipient | subject | unread | | a@pm.me | a@pm.me | one | true | | b@pm.me | b@pm.me | two | false | - And the account "user@pm.me" has the following messages in "two": + And the address "user@pm.me" of account "user@pm.me" has the following messages in "two": | sender | recipient | subject | unread | | a@pm.me | a@pm.me | one | true | | b@pm.me | b@pm.me | two | false | diff --git a/tests/imap_test.go b/tests/imap_test.go index 7d4a92ad..b1d1db3c 100644 --- a/tests/imap_test.go +++ b/tests/imap_test.go @@ -26,25 +26,39 @@ func (s *scenario) userConnectsIMAPClientOnPort(username, clientID string, port } func (s *scenario) userConnectsAndAuthenticatesIMAPClient(username, clientID string) error { + return s.userConnectsAndAuthenticatesIMAPClientWithAddress(username, clientID, s.t.getUserAddrs(s.t.getUserID(username))[0]) +} + +func (s *scenario) userConnectsAndAuthenticatesIMAPClientWithAddress(username, clientID, address string) error { if err := s.t.newIMAPClient(s.t.getUserID(username), clientID); err != nil { return err } userID, client := s.t.getIMAPClient(clientID) - return client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)) + return client.Login(address, s.t.getUserBridgePass(userID)) } func (s *scenario) imapClientCanAuthenticate(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - return client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)) + return client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)) } func (s *scenario) imapClientCannotAuthenticate(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)); err == nil { + if err := client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)); err == nil { + return fmt.Errorf("expected error, got nil") + } + + return nil +} + +func (s *scenario) imapClientCannotAuthenticateWithAddress(clientID, address string) error { + userID, client := s.t.getIMAPClient(clientID) + + if err := client.Login(address, s.t.getUserBridgePass(userID)); err == nil { return fmt.Errorf("expected error, got nil") } @@ -54,7 +68,7 @@ func (s *scenario) imapClientCannotAuthenticate(clientID string) error { func (s *scenario) imapClientCannotAuthenticateWithIncorrectUsername(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(s.t.getUserAddr(userID)+"bad", s.t.getUserPass(userID)); err == nil { + if err := client.Login(s.t.getUserAddrs(userID)[0]+"bad", s.t.getUserBridgePass(userID)); err == nil { return fmt.Errorf("expected error, got nil") } @@ -64,7 +78,7 @@ func (s *scenario) imapClientCannotAuthenticateWithIncorrectUsername(clientID st func (s *scenario) imapClientCannotAuthenticateWithIncorrectPassword(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)+"bad"); err == nil { + if err := client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)+"bad"); err == nil { return fmt.Errorf("expected error, got nil") } diff --git a/tests/init_test.go b/tests/init_test.go index a80896f7..6f6acd68 100644 --- a/tests/init_test.go +++ b/tests/init_test.go @@ -1,39 +1,14 @@ package tests import ( - "crypto/x509" - - "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/certs" "gitlab.protontech.ch/go/liteapi/server/account" ) func init() { - key, err := crypto.GenerateKey("name", "email", "rsa", 1024) - if err != nil { - panic(err) - } + // Use the fast key generation for tests. + account.GenerateKey = FastGenerateKey - account.GenerateKey = func(name, email string, passphrase []byte, keyType string, bits int) (string, error) { - encKey, err := key.Lock(passphrase) - if err != nil { - return "", err - } - - return encKey.Armor() - } - - template, err := certs.NewTLSTemplate() - if err != nil { - panic(err) - } - - certPEM, keyPEM, err := certs.GenerateCert(template) - if err != nil { - panic(err) - } - - certs.GenerateCert = func(template *x509.Certificate) ([]byte, []byte, error) { - return certPEM, keyPEM, nil - } + // Use the fast cert generation for tests. + certs.GenerateCert = FastGenerateCert } diff --git a/tests/smtp_test.go b/tests/smtp_test.go index 578d4767..81f84f6f 100644 --- a/tests/smtp_test.go +++ b/tests/smtp_test.go @@ -16,13 +16,17 @@ func (s *scenario) userConnectsSMTPClientOnPort(username, clientID string, port } func (s *scenario) userConnectsAndAuthenticatesSMTPClient(username, clientID string) error { + return s.userConnectsAndAuthenticatesSMTPClientWithAddress(username, clientID, s.t.getUserAddrs(s.t.getUserID(username))[0]) +} + +func (s *scenario) userConnectsAndAuthenticatesSMTPClientWithAddress(username, clientID, address string) error { if err := s.t.newSMTPClient(s.t.getUserID(username), clientID); err != nil { return err } userID, client := s.t.getSMTPClient(clientID) - s.t.pushError(client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID), constants.Host))) + s.t.pushError(client.Auth(smtp.PlainAuth("", address, s.t.getUserBridgePass(userID), constants.Host))) return nil } @@ -30,7 +34,7 @@ func (s *scenario) userConnectsAndAuthenticatesSMTPClient(username, clientID str func (s *scenario) smtpClientCanAuthenticate(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID), constants.Host)); err != nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID), constants.Host)); err != nil { return fmt.Errorf("expected no error, got %v", err) } @@ -40,7 +44,7 @@ func (s *scenario) smtpClientCanAuthenticate(clientID string) error { func (s *scenario) smtpClientCannotAuthenticate(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID), constants.Host)); err == nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID), constants.Host)); err == nil { return fmt.Errorf("expected error, got nil") } @@ -50,7 +54,7 @@ func (s *scenario) smtpClientCannotAuthenticate(clientID string) error { func (s *scenario) smtpClientCannotAuthenticateWithIncorrectUsername(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID)+"bad", s.t.getUserPass(userID), constants.Host)); err == nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0]+"bad", s.t.getUserBridgePass(userID), constants.Host)); err == nil { return fmt.Errorf("expected error, got nil") } @@ -60,7 +64,7 @@ func (s *scenario) smtpClientCannotAuthenticateWithIncorrectUsername(clientID st func (s *scenario) smtpClientCannotAuthenticateWithIncorrectPassword(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID)+"bad", constants.Host)); err == nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)+"bad", constants.Host)); err == nil { return fmt.Errorf("expected error, got nil") } diff --git a/tests/user_test.go b/tests/user_test.go index 634a711c..044d659d 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -16,6 +16,7 @@ import ( ) func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error { + // Create the user. userID, addrID, err := s.t.api.AddUser(username, password, username) if err != nil { return err @@ -24,11 +25,37 @@ func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, passwor // Set the ID of this user. s.t.setUserID(username, userID) - // Set the address ID of this user. - s.t.setAddrID(userID, addrID) + // Set the password of this user. + s.t.setUserPass(userID, password) // Set the address of this user (right now just the same as the username, but let's stay flexible). - s.t.setUserAddr(userID, username) + s.t.setUserAddr(userID, addrID, username) + + return nil +} + +func (s *scenario) theAccountHasAdditionalAddress(username, address string) error { + userID := s.t.getUserID(username) + + addrID, err := s.t.api.AddAddress(userID, address, s.t.getUserPass(userID)) + if err != nil { + return err + } + + s.t.setUserAddr(userID, addrID, address) + + return nil +} + +func (s *scenario) theAccountNoLongerHasAdditionalAddress(username, address string) error { + userID := s.t.getUserID(username) + addrID := s.t.getUserAddrID(userID, address) + + if err := s.t.api.RemoveAddress(userID, addrID); err != nil { + return err + } + + s.t.unsetUserAddr(userID, addrID) return nil } @@ -84,9 +111,9 @@ func (s *scenario) theAccountHasTheFollowingCustomMailboxes(username string, tab return nil } -func (s *scenario) theAccountHasTheFollowingMessagesInMailbox(username, mailbox string, table *godog.Table) error { +func (s *scenario) theAddressOfAccountHasTheFollowingMessagesInMailbox(address, username, mailbox string, table *godog.Table) error { userID := s.t.getUserID(username) - addrID := s.t.getAddrID(userID) + addrID := s.t.getUserAddrID(userID, address) mboxID := s.t.getMBoxID(userID, mailbox) for _, wantMessage := range parseMessages(table) { @@ -109,9 +136,9 @@ func (s *scenario) theAccountHasTheFollowingMessagesInMailbox(username, mailbox return nil } -func (s *scenario) theAccountHasMessagesInMailbox(username string, count int, mailbox string) error { +func (s *scenario) theAddressOfAccountHasMessagesInMailbox(address, username string, count int, mailbox string) error { userID := s.t.getUserID(username) - addrID := s.t.getAddrID(userID) + addrID := s.t.getUserAddrID(userID, address) mboxID := s.t.getMBoxID(userID, mailbox) for idx := 0; idx < count; idx++ { @@ -148,7 +175,7 @@ func (s *scenario) userLogsInWithUsernameAndPassword(username, password string) return err } - s.t.setUserPass(userID, info.BridgePass) + s.t.setUserBridgePass(userID, info.BridgePass) } return nil