1
0

GODT-1815: Combined/Split mode

This commit is contained in:
James Houlahan
2022-09-28 11:29:33 +02:00
parent 9670e29d9f
commit e9672e6bba
55 changed files with 1909 additions and 705 deletions

4
go.mod
View File

@ -5,7 +5,7 @@ go 1.18
require ( require (
github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557 github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557
github.com/Masterminds/semver/v3 v3.1.1 github.com/Masterminds/semver/v3 v3.1.1
github.com/ProtonMail/gluon v0.11.1-0.20220922143913-ef3617264557 github.com/ProtonMail/gluon v0.11.1-0.20221001180052-2e11f5804b8a
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a
github.com/ProtonMail/go-rfc5322 v0.11.0 github.com/ProtonMail/go-rfc5322 v0.11.0
github.com/ProtonMail/gopenpgp/v2 v2.4.10 github.com/ProtonMail/gopenpgp/v2 v2.4.10
@ -37,7 +37,7 @@ require (
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
github.com/urfave/cli/v2 v2.16.3 github.com/urfave/cli/v2 v2.16.3
gitlab.protontech.ch/go/liteapi v0.31.0 gitlab.protontech.ch/go/liteapi v0.31.1-0.20221001204216-b781c54ca2a6
golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/exp v0.0.0-20220921164117-439092de6870
golang.org/x/net v0.1.0 golang.org/x/net v0.1.0
golang.org/x/sys v0.1.0 golang.org/x/sys v0.1.0

8
go.sum
View File

@ -29,8 +29,8 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk=
github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g=
github.com/ProtonMail/gluon v0.11.1-0.20220922143913-ef3617264557 h1:uyiHq7jDgn1p2TeMKRPnVCVs2bHoNL9AYs26UzLYr4I= github.com/ProtonMail/gluon v0.11.1-0.20221001180052-2e11f5804b8a h1:JUjaQ7bUifpYdnLKviBPrVKOPfW6r4Mm8xCL1fdevaA=
github.com/ProtonMail/gluon v0.11.1-0.20220922143913-ef3617264557/go.mod h1:9k3URQEASX9XSA+JEcukjIiK3S6aR9GzhLhwccy8AnI= github.com/ProtonMail/gluon v0.11.1-0.20221001180052-2e11f5804b8a/go.mod h1:9k3URQEASX9XSA+JEcukjIiK3S6aR9GzhLhwccy8AnI=
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4=
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4=
github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo=
@ -463,8 +463,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
gitlab.protontech.ch/go/liteapi v0.31.0 h1:Et3P2EyTySldgBunFJqaa5W5Fap8yLvuOaLUkZX/Kn0= gitlab.protontech.ch/go/liteapi v0.31.1-0.20221001204216-b781c54ca2a6 h1:N9Wzm4pNhIjR4aBmP9AzVGy+G8XQCDlkLy9GGEONbYM=
gitlab.protontech.ch/go/liteapi v0.31.0/go.mod h1:ixp1LUOxOYuB1qf172GdV0ZT8fOomKxVFtIMZeSWg+I= gitlab.protontech.ch/go/liteapi v0.31.1-0.20221001204216-b781c54ca2a6/go.mod h1:ixp1LUOxOYuB1qf172GdV0ZT8fOomKxVFtIMZeSWg+I=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=

View File

@ -232,6 +232,13 @@ func (bridge *Bridge) GetErrors() []error {
} }
func (bridge *Bridge) Close(ctx context.Context) error { func (bridge *Bridge) Close(ctx context.Context) error {
// Abort any ongoing syncs.
for _, user := range bridge.users {
if err := user.AbortSync(ctx); err != nil {
return fmt.Errorf("failed to abort sync: %w", err)
}
}
// Close the IMAP server. // Close the IMAP server.
if err := bridge.closeIMAP(ctx); err != nil { if err := bridge.closeIMAP(ctx); err != nil {
logrus.WithError(err).Error("Failed to close IMAP server") logrus.WithError(err).Error("Failed to close IMAP server")

View File

@ -4,19 +4,24 @@ import (
"context" "context"
"os" "os"
"testing" "testing"
"time"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/focus" "github.com/ProtonMail/proton-bridge/v2/internal/focus"
"github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/tests"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/account"
) )
const ( const (
@ -29,6 +34,13 @@ var (
v2_4_0 = semver.MustParse("2.4.0") v2_4_0 = semver.MustParse("2.4.0")
) )
func init() {
user.DefaultEventPeriod = 100 * time.Millisecond
user.DefaultEventJitter = 0
account.GenerateKey = tests.FastGenerateKey
certs.GenerateCert = tests.FastGenerateCert
}
func TestBridge_ConnStatus(t *testing.T) { func TestBridge_ConnStatus(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) { withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
@ -156,19 +168,27 @@ func TestBridge_CheckUpdate(t *testing.T) {
// Disable autoupdate for this test. // Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false)) require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events. // Get a stream of update not available events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{}) noUpdateCh, done := bridge.GetEvents(events.UpdateNotAvailable{})
defer done() defer done()
// We are currently on the latest version. // We are currently on the latest version.
bridge.CheckForUpdates() bridge.CheckForUpdates()
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
// we should receive an event indicating that no update is available.
require.Equal(t, events.UpdateNotAvailable{}, <-noUpdateCh)
// Simulate a new version being available. // Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0) mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
// Get a stream of update available events.
updateCh, done := bridge.GetEvents(events.UpdateAvailable{})
defer done()
// Check for updates. // Check for updates.
bridge.CheckForUpdates() bridge.CheckForUpdates()
// We should receive an event indicating that an update is available.
require.Equal(t, events.UpdateAvailable{ require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{ Version: updater.VersionInfo{
Version: v2_4_0, Version: v2_4_0,
@ -188,7 +208,7 @@ func TestBridge_AutoUpdate(t *testing.T) {
require.NoError(t, bridge.SetAutoUpdate(true)) require.NoError(t, bridge.SetAutoUpdate(true))
// Get a stream of update events. // Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateInstalled{}) updateCh, done := bridge.GetEvents(events.UpdateInstalled{})
defer done() defer done()
// Simulate a new version being available. // Simulate a new version being available.
@ -196,6 +216,8 @@ func TestBridge_AutoUpdate(t *testing.T) {
// Check for updates. // Check for updates.
bridge.CheckForUpdates() bridge.CheckForUpdates()
// We should receive an event indicating that the update was installed.
require.Equal(t, events.UpdateInstalled{ require.Equal(t, events.UpdateInstalled{
Version: updater.VersionInfo{ Version: updater.VersionInfo{
Version: v2_4_0, Version: v2_4_0,
@ -213,8 +235,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
// Disable autoupdate for this test. // Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false)) require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events. // Get a stream of update available events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{}) updateCh, done := bridge.GetEvents(events.UpdateAvailable{})
defer done() defer done()
// Simulate a new version being available, but it's too new for us. // Simulate a new version being available, but it's too new for us.
@ -222,6 +244,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
// Check for updates. // Check for updates.
bridge.CheckForUpdates() bridge.CheckForUpdates()
// We should receive an event indicating an update is available, but we can't install it.
require.Equal(t, events.UpdateAvailable{ require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{ Version: updater.VersionInfo{
Version: v2_4_0, Version: v2_4_0,

View File

@ -14,8 +14,9 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
return ErrNoSuchUser return ErrNoSuchUser
} }
// TODO: Handle split mode!
if address == "" { if address == "" {
address = user.Addresses()[0] address = user.Emails()[0]
} }
// If configuring apple mail for Catalina or newer, users should use SSL. // If configuring apple mail for Catalina or newer, users should use SSL.
@ -32,7 +33,7 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
bridge.vault.GetIMAPSSL(), bridge.vault.GetIMAPSSL(),
bridge.vault.GetSMTPSSL(), bridge.vault.GetSMTPSSL(),
address, address,
strings.Join(user.Addresses(), ","), strings.Join(user.Emails(), ","),
user.BridgePass(), user.BridgePass(),
) )
} }

View File

@ -2,6 +2,7 @@ package bridge
import ( import (
"context" "context"
"fmt"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/updater"
@ -96,40 +97,39 @@ func (bridge *Bridge) GetGluonDir() string {
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error { func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
if newGluonDir == bridge.GetGluonDir() { if newGluonDir == bridge.GetGluonDir() {
return nil return fmt.Errorf("new gluon dir is the same as the old one")
} }
if err := bridge.closeIMAP(context.Background()); err != nil { if err := bridge.closeIMAP(context.Background()); err != nil {
return err return fmt.Errorf("failed to close IMAP: %w", err)
} }
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil { if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
return err return fmt.Errorf("failed to move gluon dir: %w", err)
} }
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil { if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
return err return fmt.Errorf("failed to set new gluon dir: %w", err)
} }
imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig) imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to create new IMAP server: %w", err)
}
for _, user := range bridge.users {
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return err
}
if err := imapServer.LoadUser(context.Background(), imapConn, user.GluonID(), user.GluonKey()); err != nil {
return err
}
} }
bridge.imapServer = imapServer bridge.imapServer = imapServer
return bridge.serveIMAP() for _, user := range bridge.users {
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
}
if err := bridge.serveIMAP(); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
return nil
} }
func (bridge *Bridge) GetProxyAllowed() bool { func (bridge *Bridge) GetProxyAllowed() bool {

View File

@ -23,8 +23,8 @@ func (backend *smtpBackend) Login(state *smtp.ConnectionState, username string,
defer backend.usersLock.RUnlock() defer backend.usersLock.RUnlock()
for _, user := range backend.users { for _, user := range backend.users {
if slices.Contains(user.Addresses(), username) && user.BridgePass() == password { if slices.Contains(user.Emails(), username) && user.BridgePass() == password {
return user.NewSMTPSession(username) return user.NewSMTPSession(username), nil
} }
} }

View File

@ -29,7 +29,7 @@ type UserInfo struct {
Addresses []string Addresses []string
// AddressMode is the user's address mode. // AddressMode is the user's address mode.
AddressMode AddressMode AddressMode vault.AddressMode
// BridgePass is the user's bridge password. // BridgePass is the user's bridge password.
BridgePass string BridgePass string
@ -41,13 +41,6 @@ type UserInfo struct {
MaxSpace int MaxSpace int
} }
type AddressMode int
const (
SplitMode AddressMode = iota
CombinedMode
)
// GetUserIDs returns the IDs of all known users (authorized or not). // GetUserIDs returns the IDs of all known users (authorized or not).
func (bridge *Bridge) GetUserIDs() []string { func (bridge *Bridge) GetUserIDs() []string {
return bridge.vault.GetUserIDs() return bridge.vault.GetUserIDs()
@ -62,7 +55,7 @@ func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
user, ok := bridge.users[userID] user, ok := bridge.users[userID]
if !ok { if !ok {
return getUserInfo(vaultUser.UserID(), vaultUser.Username()), nil return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil
} }
return getConnUserInfo(user), nil return getConnUserInfo(user), nil
@ -153,12 +146,43 @@ func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
return nil return nil
} }
func (bridge *Bridge) GetAddressMode(userID string) (AddressMode, error) { // SetAddressMode sets the address mode for the given user.
panic("TODO") func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
} user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
func (bridge *Bridge) SetAddressMode(userID string, mode AddressMode) error { if user.GetAddressMode() == mode {
panic("TODO") return fmt.Errorf("address mode is already %q", mode)
}
if err := user.AbortSync(ctx); err != nil {
return fmt.Errorf("failed to abort sync: %w", err)
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
}
if err := user.SetAddressMode(ctx, mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
bridge.publish(events.AddressModeChanged{
UserID: userID,
AddressMode: mode,
})
user.DoSync(ctx)
return nil
} }
// loadUsers loads authorized users from the vault. // loadUsers loads authorized users from the vault.
@ -177,7 +201,7 @@ func (bridge *Bridge) loadUsers(ctx context.Context) error {
logrus.WithError(err).Error("Failed to load connected user") logrus.WithError(err).Error("Failed to load connected user")
if _, ok := err.(*resty.ResponseError); ok { if _, ok := err.(*resty.ResponseError); ok {
if err := user.Clear(); err != nil { if err := bridge.vault.ClearUser(userID); err != nil {
logrus.WithError(err).Error("Failed to clear user") logrus.WithError(err).Error("Failed to clear user")
} }
} }
@ -231,33 +255,41 @@ func (bridge *Bridge) addUser(
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) { if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass) existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to add existing user: %w", err)
} }
user = existingUser user = existingUser
} else { } else {
newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass) newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to add new user: %w", err)
} }
user = newUser user = newUser
} }
go func() { // Connects the user's address(es) to gluon.
for event := range user.GetNotifyCh() { if err := bridge.addIMAPUser(ctx, user); err != nil {
switch event := event.(type) { return fmt.Errorf("failed to add IMAP user: %w", err)
case events.UserDeauth: }
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
}
bridge.publish(event) // Handle events coming from the user before forwarding them to the bridge.
// For example, if the user's addresses change, we need to update them in gluon.
go func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for event := range user.GetEventCh() {
if err := bridge.handleUserEvent(ctx, user, event); err != nil {
logrus.WithError(err).Error("Failed to handle user event")
} else {
bridge.publish(event)
}
} }
}() }()
// Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user. // Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user.
// As such, if we find this ID in the context, we should use it to update our user agent.
client.AddPreRequestHook(func(ctx context.Context, req *resty.Request) error { client.AddPreRequestHook(func(ctx context.Context, req *resty.Request) error {
if imapID, ok := imap.GetIMAPIDFromContext(ctx); ok { if imapID, ok := imap.GetIMAPIDFromContext(ctx); ok {
bridge.identifier.SetClient(imapID.Name, imapID.Version) bridge.identifier.SetClient(imapID.Name, imapID.Version)
@ -266,6 +298,11 @@ func (bridge *Bridge) addUser(
return nil return nil
}) })
// TODO: Replace this with proper sync manager.
if !user.HasSync() {
user.DoSync(ctx)
}
bridge.publish(events.UserLoggedIn{ bridge.publish(events.UserLoggedIn{
UserID: user.ID(), UserID: user.ID(),
}) })
@ -293,25 +330,6 @@ func (bridge *Bridge) addNewUser(
return nil, err return nil, err
} }
gluonKey, err := crypto.RandomToken(32)
if err != nil {
return nil, err
}
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, gluonKey)
if err != nil {
return nil, err
}
if err := vaultUser.SetGluonAuth(gluonID, gluonKey); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil { if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err return nil, err
} }
@ -349,15 +367,6 @@ func (bridge *Bridge) addExistingUser(
return nil, err return nil, err
} }
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, user.GluonID(), user.GluonKey()); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil { if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err return nil, err
} }
@ -376,31 +385,39 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi
return ErrNoSuchUser return ErrNoSuchUser
} }
vaultUser, err := bridge.vault.GetUser(userID) // TODO: The sync should be canceled by the sync manager.
if err != nil { if err := user.AbortSync(ctx); err != nil {
return err return fmt.Errorf("failed to abort user sync: %w", err)
}
if err := bridge.imapServer.RemoveUser(ctx, vaultUser.GluonID(), withFiles); err != nil {
return err
} }
if err := bridge.smtpBackend.removeUser(user); err != nil { if err := bridge.smtpBackend.removeUser(user); err != nil {
return err return fmt.Errorf("failed to remove SMTP user: %w", err)
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
} }
if withAPI { if withAPI {
if err := user.Logout(ctx); err != nil { if err := user.Logout(ctx); err != nil {
return err return fmt.Errorf("failed to logout user: %w", err)
} }
} }
if err := user.Close(ctx); err != nil { if err := user.Close(ctx); err != nil {
return err return fmt.Errorf("failed to close user: %w", err)
} }
if err := vaultUser.Clear(); err != nil { if err := bridge.vault.ClearUser(userID); err != nil {
return err return fmt.Errorf("failed to clear user: %w", err)
}
if withFiles {
if err := bridge.vault.DeleteUser(userID); err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
} }
delete(bridge.users, userID) delete(bridge.users, userID)
@ -412,12 +429,39 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi
return nil return nil
} }
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
}
for addrID, imapConn := range imapConn {
if gluonID, ok := user.GetGluonID(addrID); ok {
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
} else {
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
}
}
return nil
}
// getUserInfo returns information about a disconnected user. // getUserInfo returns information about a disconnected user.
func getUserInfo(userID, username string) UserInfo { func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
return UserInfo{ return UserInfo{
UserID: userID, UserID: userID,
Username: username, Username: username,
AddressMode: CombinedMode, AddressMode: addressMode,
} }
} }
@ -427,8 +471,8 @@ func getConnUserInfo(user *user.User) UserInfo {
Connected: true, Connected: true,
UserID: user.ID(), UserID: user.ID(),
Username: user.Name(), Username: user.Name(),
Addresses: user.Addresses(), Addresses: user.Emails(),
AddressMode: CombinedMode, AddressMode: user.GetAddressMode(),
BridgePass: user.BridgePass(), BridgePass: user.BridgePass(),
UsedSpace: user.UsedSpace(), UsedSpace: user.UsedSpace(),
MaxSpace: user.MaxSpace(), MaxSpace: user.MaxSpace(),

View File

@ -0,0 +1,118 @@
package bridge
import (
"context"
"fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
)
func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, event events.Event) error {
switch event := event.(type) {
case events.UserAddressCreated:
if err := bridge.handleUserAddressCreated(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address created event: %w", err)
}
case events.UserAddressUpdated:
if err := bridge.handleUserAddressUpdated(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address updated event: %w", err)
}
case events.UserAddressDeleted:
if err := bridge.handleUserAddressDeleted(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address deleted event: %w", err)
}
case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
}
return nil
}
func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.User, event events.UserAddressCreated) error {
switch user.GetAddressMode() {
case vault.CombinedMode:
for addrID, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
imapConn, err := user.NewIMAPConnector(addrID)
if err != nil {
return fmt.Errorf("failed to create IMAP connector: %w", err)
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
}
case vault.SplitMode:
imapConn, err := user.NewIMAPConnector(event.AddressID)
if err != nil {
return fmt.Errorf("failed to create IMAP connector: %w", err)
}
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
if err := user.SetGluonID(event.AddressID, gluonID); err != nil {
return fmt.Errorf("failed to set gluon ID: %w", err)
}
}
return nil
}
// TODO: Handle addresses that have been disabled!
func (bridge *Bridge) handleUserAddressUpdated(ctx context.Context, user *user.User, event events.UserAddressUpdated) error {
switch user.GetAddressMode() {
case vault.CombinedMode:
return fmt.Errorf("not implemented")
case vault.SplitMode:
return fmt.Errorf("not implemented")
}
return nil
}
func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.User, event events.UserAddressDeleted) error {
switch user.GetAddressMode() {
case vault.CombinedMode:
for addrID, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
imapConn, err := user.NewIMAPConnector(addrID)
if err != nil {
return fmt.Errorf("failed to create IMAP connector: %w", err)
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
}
case vault.SplitMode:
gluonID, ok := user.GetGluonID(event.AddressID)
if !ok {
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
}
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
}
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server"
) )
@ -283,3 +284,30 @@ func TestBridge_BridgePass(t *testing.T) {
}) })
}) })
} }
func TestBridge_AddressMode(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err)
// Get the user's info.
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
// The user is in combined mode by default.
require.Equal(t, vault.CombinedMode, info.AddressMode)
// Put the user in split mode.
require.NoError(t, bridge.SetAddressMode(ctx, userID, vault.SplitMode))
// Get the user's info.
info, err = bridge.GetUserInfo(userID)
require.NoError(t, err)
// The user is in split mode.
require.Equal(t, vault.SplitMode, info.AddressMode)
})
})
}

View File

@ -1,5 +1,7 @@
package events package events
import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
type UserLoggedIn struct { type UserLoggedIn struct {
eventBase eventBase
@ -33,20 +35,30 @@ type UserChanged struct {
type UserAddressCreated struct { type UserAddressCreated struct {
eventBase eventBase
UserID string UserID string
Address string AddressID string
Email string
} }
type UserAddressChanged struct { type UserAddressUpdated struct {
eventBase eventBase
UserID string UserID string
Address string AddressID string
Email string
} }
type UserAddressDeleted struct { type UserAddressDeleted struct {
eventBase eventBase
UserID string UserID string
Address string AddressID string
Email string
}
type AddressModeChanged struct {
eventBase
UserID string
AddressMode vault.AddressMode
} }

View File

@ -23,6 +23,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/abiosoft/ishell" "github.com/abiosoft/ishell"
) )
@ -39,7 +40,7 @@ func (f *frontendCLI) listAccounts(c *ishell.Context) {
connected = "connected" connected = "connected"
} }
mode := "split" mode := "split"
if user.AddressMode == bridge.CombinedMode { if user.AddressMode == vault.CombinedMode {
mode = "combined" mode = "combined"
} }
f.Printf(spacing, idx, user.Username, connected, mode) f.Printf(spacing, idx, user.Username, connected, mode)
@ -58,7 +59,7 @@ func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
return return
} }
if user.AddressMode == bridge.CombinedMode { if user.AddressMode == vault.CombinedMode {
f.showAccountAddressInfo(user, user.Addresses[0]) f.showAccountAddressInfo(user, user.Addresses[0])
} else { } else {
for _, address := range user.Addresses { for _, address := range user.Addresses {
@ -225,19 +226,19 @@ func (f *frontendCLI) changeMode(c *ishell.Context) {
return return
} }
var targetMode bridge.AddressMode var targetMode vault.AddressMode
if user.AddressMode == bridge.CombinedMode { if user.AddressMode == vault.CombinedMode {
targetMode = bridge.SplitMode targetMode = vault.SplitMode
} else { } else {
targetMode = bridge.CombinedMode targetMode = vault.CombinedMode
} }
if !f.yesNoQuestion("Are you sure you want to change the mode for account " + bold(user.Username) + " to " + bold(targetMode)) { if !f.yesNoQuestion("Are you sure you want to change the mode for account " + bold(user.Username) + " to " + bold(targetMode)) {
return return
} }
if err := f.bridge.SetAddressMode(user.UserID, targetMode); err != nil { if err := f.bridge.SetAddressMode(context.Background(), user.UserID, targetMode); err != nil {
f.printAndLogError("Cannot switch address mode:", err) f.printAndLogError("Cannot switch address mode:", err)
} }

View File

@ -296,7 +296,7 @@ func (f *frontendCLI) watchEvents() {
f.notifyLogout(user.Username) f.notifyLogout(user.Username)
case events.UserAddressChanged: case events.UserAddressUpdated:
user, err := f.bridge.GetUserInfo(event.UserID) user, err := f.bridge.GetUserInfo(event.UserID)
if err != nil { if err != nil {
return return
@ -305,7 +305,7 @@ func (f *frontendCLI) watchEvents() {
f.Printf("Address changed for %s. You may need to reconfigure your email client.\n", user.Username) f.Printf("Address changed for %s. You may need to reconfigure your email client.\n", user.Username)
case events.UserAddressDeleted: case events.UserAddressDeleted:
f.notifyLogout(event.Address) f.notifyLogout(event.Email)
case events.SyncStarted: case events.SyncStarted:
user, err := f.bridge.GetUserInfo(event.UserID) user, err := f.bridge.GetUserInfo(event.UserID)

View File

@ -228,13 +228,13 @@ func (s *Service) watchEvents() {
_ = s.SendEvent(NewShowMainWindowEvent()) _ = s.SendEvent(NewShowMainWindowEvent())
case events.UserAddressCreated: case events.UserAddressCreated:
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address)) _ = s.SendEvent(NewMailAddressChangeEvent(event.Email))
case events.UserAddressChanged: case events.UserAddressUpdated:
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address)) _ = s.SendEvent(NewMailAddressChangeEvent(event.Email))
case events.UserAddressDeleted: case events.UserAddressDeleted:
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Address)) _ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Email))
case events.UserChanged: case events.UserChanged:
_ = s.SendEvent(NewUserChangedEvent(event.UserID)) _ = s.SendEvent(NewUserChangedEvent(event.UserID))

View File

@ -20,7 +20,7 @@ package grpc
import ( import (
"context" "context"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -74,15 +74,15 @@ func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitMode
defer s.panicHandler.HandlePanic() defer s.panicHandler.HandlePanic()
defer func() { _ = s.SendEvent(NewUserToggleSplitModeFinishedEvent(splitMode.UserID)) }() defer func() { _ = s.SendEvent(NewUserToggleSplitModeFinishedEvent(splitMode.UserID)) }()
var targetMode bridge.AddressMode var targetMode vault.AddressMode
if splitMode.Active && user.AddressMode == bridge.CombinedMode { if splitMode.Active && user.AddressMode == vault.CombinedMode {
targetMode = bridge.SplitMode targetMode = vault.SplitMode
} else if !splitMode.Active && user.AddressMode == bridge.SplitMode { } else if !splitMode.Active && user.AddressMode == vault.SplitMode {
targetMode = bridge.CombinedMode targetMode = vault.CombinedMode
} }
if err := s.bridge.SetAddressMode(user.UserID, targetMode); err != nil { if err := s.bridge.SetAddressMode(context.Background(), user.UserID, targetMode); err != nil {
logrus.WithError(err).Error("Failed to set address mode") logrus.WithError(err).Error("Failed to set address mode")
} }
}() }()

View File

@ -22,6 +22,7 @@ import (
"strings" "strings"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -64,7 +65,7 @@ func grpcUserFromInfo(user bridge.UserInfo) *User {
Username: user.Username, Username: user.Username,
AvatarText: getInitials(user.Username), AvatarText: getInitials(user.Username),
LoggedIn: user.Connected, LoggedIn: user.Connected,
SplitMode: user.AddressMode == bridge.SplitMode, SplitMode: user.AddressMode == vault.SplitMode,
SetupGuideSeen: true, // users listed have already seen the setup guide. SetupGuideSeen: true, // users listed have already seen the setup guide.
UsedBytes: int64(user.UsedSpace), UsedBytes: int64(user.UsedSpace),
TotalBytes: int64(user.MaxSpace), TotalBytes: int64(user.MaxSpace),

41
internal/pool/job.go Normal file
View File

@ -0,0 +1,41 @@
package pool
import "context"
type job[In, Out any] struct {
ctx context.Context
req In
res chan Out
err chan error
done chan struct{}
}
func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] {
return &job[In, Out]{
ctx: ctx,
req: req,
res: make(chan Out),
err: make(chan error),
done: make(chan struct{}),
}
}
func (job *job[In, Out]) result() (Out, error) {
return <-job.res, <-job.err
}
func (job *job[In, Out]) postSuccess(res Out) {
close(job.err)
job.res <- res
}
func (job *job[In, Out]) postFailure(err error) {
close(job.res)
job.err <- err
}
func (job *job[In, Out]) waitDone() {
<-job.done
}

View File

@ -13,16 +13,16 @@ var ErrJobCancelled = errors.New("Job cancelled by surrounding context")
// Pool is a worker pool that handles input of type In and returns results of type Out. // Pool is a worker pool that handles input of type In and returns results of type Out.
type Pool[In comparable, Out any] struct { type Pool[In comparable, Out any] struct {
queue *queue.QueuedChannel[*Job[In, Out]] queue *queue.QueuedChannel[*job[In, Out]]
size int size int
} }
// DoneFunc must be called to free up pool resources. // doneFunc must be called to free up pool resources.
type DoneFunc func() type doneFunc func()
// New returns a new pool. // New returns a new pool.
func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] { func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
queue := queue.NewQueuedChannel[*Job[In, Out]](0, 0) queue := queue.NewQueuedChannel[*job[In, Out]](0, 0)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func() { go func() {
@ -51,17 +51,6 @@ func New[In comparable, Out any](size int, work func(context.Context, In) (Out,
} }
} }
// NewJob submits a job to the pool. It returns a job handle and a DoneFunc.
// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
// which frees up the worker in the pool for reuse.
func (pool *Pool[In, Out]) NewJob(ctx context.Context, req In) (*Job[In, Out], DoneFunc) {
job := newJob[In, Out](ctx, req)
pool.queue.Enqueue(job)
return job, func() { close(job.done) }
}
// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred. // Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred.
func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error { func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@ -81,10 +70,10 @@ func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, O
go func() { go func() {
defer wg.Done() defer wg.Done()
job, done := pool.NewJob(ctx, req) job, done := pool.newJob(ctx, req)
defer done() defer done()
res, err := job.Result() res, err := job.result()
if err := fn(req, res, err); err != nil { if err := fn(req, res, err); err != nil {
lock.Lock() lock.Lock()
@ -134,44 +123,25 @@ func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) (map[In]Ou
return data, nil return data, nil
} }
// ProcessOne submits one job to the pool and returns the result.
func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) {
job, done := pool.newJob(ctx, req)
defer done()
return job.result()
}
func (pool *Pool[In, Out]) Done() { func (pool *Pool[In, Out]) Done() {
pool.queue.Close() pool.queue.Close()
} }
type Job[In, Out any] struct { // newJob submits a job to the pool. It returns a job handle and a DoneFunc.
ctx context.Context // The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
req In // which frees up the worker in the pool for reuse.
func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc) {
job := newJob[In, Out](ctx, req)
res chan Out pool.queue.Enqueue(job)
err chan error
done chan struct{} return job, func() { close(job.done) }
}
func newJob[In, Out any](ctx context.Context, req In) *Job[In, Out] {
return &Job[In, Out]{
ctx: ctx,
req: req,
res: make(chan Out),
err: make(chan error),
done: make(chan struct{}),
}
}
func (job *Job[In, Out]) Result() (Out, error) {
return <-job.res, <-job.err
}
func (job *Job[In, Out]) postSuccess(res Out) {
close(job.err)
job.res <- res
}
func (job *Job[In, Out]) postFailure(err error) {
close(job.res)
job.err <- err
}
func (job *Job[In, Out]) waitDone() {
<-job.done
} }

View File

@ -15,16 +15,16 @@ import (
func TestPool_NewJob(t *testing.T) { func TestPool_NewJob(t *testing.T) {
doubler := newDoubler(runtime.NumCPU()) doubler := newDoubler(runtime.NumCPU())
job1, done1 := doubler.NewJob(context.Background(), 1) job1, done1 := doubler.newJob(context.Background(), 1)
defer done1() defer done1()
job2, done2 := doubler.NewJob(context.Background(), 2) job2, done2 := doubler.newJob(context.Background(), 2)
defer done2() defer done2()
res2, err := job2.Result() res2, err := job2.result()
require.NoError(t, err) require.NoError(t, err)
res1, err := job1.Result() res1, err := job1.result()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, res1) assert.Equal(t, 2, res1)
@ -36,31 +36,31 @@ func TestPool_NewJob_Done(t *testing.T) {
doubler := newDoubler(2) doubler := newDoubler(2)
// Start two jobs. Don't mark the jobs as done yet. // Start two jobs. Don't mark the jobs as done yet.
job1, done1 := doubler.NewJob(context.Background(), 1) job1, done1 := doubler.newJob(context.Background(), 1)
job2, done2 := doubler.NewJob(context.Background(), 2) job2, done2 := doubler.newJob(context.Background(), 2)
// Get the first result. // Get the first result.
res1, _ := job1.Result() res1, _ := job1.result()
assert.Equal(t, 2, res1) assert.Equal(t, 2, res1)
// Get the first result. // Get the first result.
res2, _ := job2.Result() res2, _ := job2.result()
assert.Equal(t, 4, res2) assert.Equal(t, 4, res2)
// Additional jobs will wait. // Additional jobs will wait.
job3, _ := doubler.NewJob(context.Background(), 3) job3, _ := doubler.newJob(context.Background(), 3)
job4, _ := doubler.NewJob(context.Background(), 4) job4, _ := doubler.newJob(context.Background(), 4)
// Channel to collect results from jobs 3 and 4. // Channel to collect results from jobs 3 and 4.
resCh := make(chan int, 2) resCh := make(chan int, 2)
go func() { go func() {
res, _ := job3.Result() res, _ := job3.result()
resCh <- res resCh <- res
}() }()
go func() { go func() {
res, _ := job4.Result() res, _ := job4.result()
resCh <- res resCh <- res
}() }()

View File

@ -0,0 +1,46 @@
package user
import "gitlab.protontech.ch/go/liteapi"
type addrList struct {
apiAddrs ordMap[string, string, liteapi.Address]
}
func newAddrList(apiAddrs []liteapi.Address) *addrList {
return &addrList{
apiAddrs: newOrdMap(
func(addr liteapi.Address) string { return addr.ID },
func(addr liteapi.Address) string { return addr.Email },
func(a, b liteapi.Address) bool { return a.Order < b.Order },
apiAddrs...,
),
}
}
func (list *addrList) insert(address liteapi.Address) {
list.apiAddrs.insert(address)
}
func (list *addrList) delete(addrID string) string {
return list.apiAddrs.delete(addrID)
}
func (list *addrList) primary() string {
return list.apiAddrs.keys()[0]
}
func (list *addrList) addrIDs() []string {
return list.apiAddrs.keys()
}
func (list *addrList) emails() []string {
return list.apiAddrs.values()
}
func (list *addrList) email(addrID string) string {
return list.apiAddrs.get(addrID)
}
func (list *addrList) addrMap() map[string]string {
return list.apiAddrs.toMap()
}

View File

@ -2,16 +2,20 @@ package user
import ( import (
"context" "context"
"time"
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/pool" "github.com/ProtonMail/proton-bridge/v2/internal/pool"
"github.com/ProtonMail/proton-bridge/v2/pkg/message" "github.com/ProtonMail/proton-bridge/v2/pkg/message"
"github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
) )
type request struct { type request struct {
messageID string messageID string
addressID string
addrKR *crypto.KeyRing addrKR *crypto.KeyRing
} }
@ -54,8 +58,38 @@ func newBuilder(f fetcher, msgWorkers, attWorkers int) *pool.Pool[request, *imap
return nil, err return nil, err
} }
return getMessageCreatedUpdate(msg, literal) return newMessageCreatedUpdate(msg, literal)
}) })
return msgPool return msgPool
} }
func newMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
return nil, err
}
flags := imap.NewFlagSet()
if !message.Unread {
flags = flags.Add(imap.FlagSeen)
}
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
flags = flags.Add(imap.FlagFlagged)
}
imapMessage := imap.Message{
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}
return &imap.MessageCreated{
Message: imapMessage,
Literal: literal,
LabelIDs: mapTo[string, imap.LabelID](xslices.Filter(message.LabelIDs, wantLabelID)),
ParsedMessage: parsedMessage,
}, nil
}

View File

@ -8,5 +8,5 @@ var (
ErrNotSupported = errors.New("not supported") ErrNotSupported = errors.New("not supported")
ErrInvalidReturnPath = errors.New("invalid return path") ErrInvalidReturnPath = errors.New("invalid return path")
ErrInvalidRecipient = errors.New("invalid recipient") ErrInvalidRecipient = errors.New("invalid recipient")
ErrMissingAddressKey = errors.New("missing address key") ErrMissingAddrKey = errors.New("missing address key")
) )

View File

@ -2,43 +2,44 @@ package user
import ( import (
"context" "context"
"fmt"
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
) )
// handleAPIEvent handles the given liteapi.Event. // handleAPIEvent handles the given liteapi.Event.
func (user *User) handleAPIEvent(event liteapi.Event) error { func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error {
if event.User != nil { if event.User != nil {
if err := user.handleUserEvent(*event.User); err != nil { if err := user.handleUserEvent(ctx, *event.User); err != nil {
return err return err
} }
} }
if len(event.Addresses) > 0 { if len(event.Addresses) > 0 {
if err := user.handleAddressEvents(event.Addresses); err != nil { if err := user.handleAddressEvents(ctx, event.Addresses); err != nil {
return err return err
} }
} }
if event.MailSettings != nil { if event.MailSettings != nil {
if err := user.handleMailSettingsEvent(*event.MailSettings); err != nil { if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil {
return err return err
} }
} }
if len(event.Labels) > 0 { if len(event.Labels) > 0 {
if err := user.handleLabelEvents(event.Labels); err != nil { if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
return err return err
} }
} }
if len(event.Messages) > 0 { if len(event.Messages) > 0 {
if err := user.handleMessageEvents(event.Messages); err != nil { if err := user.handleMessageEvents(ctx, event.Messages); err != nil {
return err return err
} }
} }
@ -47,7 +48,7 @@ func (user *User) handleAPIEvent(event liteapi.Event) error {
} }
// handleUserEvent handles the given user event. // handleUserEvent handles the given user event.
func (user *User) handleUserEvent(userEvent liteapi.User) error { func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error {
userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil) userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil)
if err != nil { if err != nil {
return err return err
@ -57,49 +58,31 @@ func (user *User) handleUserEvent(userEvent liteapi.User) error {
user.userKR = userKR user.userKR = userKR
user.notifyCh <- events.UserChanged{ user.eventCh.Enqueue(events.UserChanged{
UserID: user.ID(), UserID: user.ID(),
} })
return nil return nil
} }
// handleAddressEvents handles the given address events. // handleAddressEvents handles the given address events.
// TODO: If split address mode, need to signal back to bridge to update the addresses! // TODO: If split address mode, need to signal back to bridge to update the addresses!
func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) error { func (user *User) handleAddressEvents(ctx context.Context, addressEvents []liteapi.AddressEvent) error {
for _, event := range addressEvents { for _, event := range addressEvents {
switch event.Action { switch event.Action {
case liteapi.EventDelete:
address, err := user.deleteAddress(event.ID)
if err != nil {
return err
}
// TODO: This is not the same as addressChangedLogout event!
// That was only relevant in split mode. This is used differently now.
user.notifyCh <- events.UserAddressDeleted{
UserID: user.ID(),
Address: address.Email,
}
case liteapi.EventCreate: case liteapi.EventCreate:
if err := user.createAddress(event.Address); err != nil { if err := user.handleCreateAddressEvent(ctx, event); err != nil {
return err return fmt.Errorf("failed to handle create address event: %w", err)
}
user.notifyCh <- events.UserAddressCreated{
UserID: user.ID(),
Address: event.Address.Email,
} }
case liteapi.EventUpdate: case liteapi.EventUpdate:
if err := user.updateAddress(event.Address); err != nil { if err := user.handleUpdateAddressEvent(ctx, event); err != nil {
return err return fmt.Errorf("failed to handle update address event: %w", err)
} }
user.notifyCh <- events.UserAddressChanged{ case liteapi.EventDelete:
UserID: user.ID(), if err := user.handleDeleteAddressEvent(ctx, event); err != nil {
Address: event.Address.Email, return fmt.Errorf("failed to delete address: %w", err)
} }
} }
} }
@ -107,111 +90,189 @@ func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) erro
return nil return nil
} }
// createAddress creates the given address. func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
func (user *User) createAddress(address liteapi.Address) error { addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
addrKR, err := address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to unlock address keys: %w", err)
} }
if user.imapConn != nil { user.apiAddrs.insert(event.Address)
user.imapConn.addAddress(address.Email)
user.addrKRs[event.Address.ID] = addrKR
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
if err := user.syncLabels(ctx, event.Address.ID); err != nil {
return fmt.Errorf("failed to sync labels to new address: %w", err)
}
} }
user.addresses = append(user.addresses, address) user.eventCh.Enqueue(events.UserAddressCreated{
UserID: user.ID(),
user.addrKRs[address.ID] = addrKR AddressID: event.Address.ID,
Email: event.Address.Email,
})
return nil return nil
} }
// updateAddress updates the given address. func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
func (user *User) updateAddress(address liteapi.Address) error { addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
if _, err := user.deleteAddress(address.ID); err != nil { if err != nil {
return err return fmt.Errorf("failed to unlock address keys: %w", err)
} }
return user.createAddress(address) user.apiAddrs.insert(event.Address)
}
// deleteAddress deletes the given address. user.addrKRs[event.Address.ID] = addrKR
func (user *User) deleteAddress(addressID string) (liteapi.Address, error) {
idx := xslices.IndexFunc(user.addresses, func(address liteapi.Address) bool { user.eventCh.Enqueue(events.UserAddressUpdated{
return address.ID == addressID UserID: user.ID(),
AddressID: event.Address.ID,
Email: event.Address.Email,
}) })
if idx < 0 { return nil
return liteapi.Address{}, ErrNoSuchAddress }
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
email := user.apiAddrs.delete(event.ID)
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].Close()
delete(user.updateCh, event.ID)
} }
if user.imapConn != nil { user.eventCh.Enqueue(events.UserAddressDeleted{
user.imapConn.remAddress(user.addresses[idx].Email) UserID: user.ID(),
} AddressID: event.ID,
Email: email,
})
var address liteapi.Address return nil
address, user.addresses = user.addresses[idx], append(user.addresses[:idx], user.addresses[idx+1:]...)
delete(user.addrKRs, addressID)
return address, nil
} }
// handleMailSettingsEvent handles the given mail settings event. // handleMailSettingsEvent handles the given mail settings event.
func (user *User) handleMailSettingsEvent(mailSettingsEvent liteapi.MailSettings) error { func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error {
user.settings = mailSettingsEvent user.settings = mailSettingsEvent
return nil return nil
} }
// handleLabelEvents handles the given label events. // handleLabelEvents handles the given label events.
func (user *User) handleLabelEvents(labelEvents []liteapi.LabelEvent) error { func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error {
for _, event := range labelEvents { for _, event := range labelEvents {
switch event.Action { switch event.Action {
case liteapi.EventDelete:
user.updateCh <- imap.NewMailboxDeleted(imap.LabelID(event.ID))
case liteapi.EventCreate: case liteapi.EventCreate:
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)) if err := user.handleCreateLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle create label event: %w", err)
}
case liteapi.EventUpdate, liteapi.EventUpdateFlags: case liteapi.EventUpdate, liteapi.EventUpdateFlags:
user.updateCh <- imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)) if err := user.handleUpdateLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle update label event: %w", err)
}
case liteapi.EventDelete:
if err := user.handleDeleteLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle delete label event: %w", err)
}
} }
} }
return nil return nil
} }
func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)))
}
return nil
}
func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)))
}
return nil
}
func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID)))
}
return nil
}
// handleMessageEvents handles the given message events. // handleMessageEvents handles the given message events.
func (user *User) handleMessageEvents(messageEvents []liteapi.MessageEvent) error { func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
for _, event := range messageEvents { for _, event := range messageEvents {
switch event.Action { switch event.Action {
case liteapi.EventDelete:
return ErrNotImplemented
case liteapi.EventCreate: case liteapi.EventCreate:
messages, err := user.builder.ProcessAll(ctx, []request{{event.ID, user.addrKRs[event.Message.AddressID]}}) if err := user.handleCreateMessageEvent(ctx, event); err != nil {
if err != nil { return fmt.Errorf("failed to handle create message event: %w", err)
return err
} }
user.updateCh <- imap.NewMessagesCreated(maps.Values(messages)...)
case liteapi.EventUpdate, liteapi.EventUpdateFlags: case liteapi.EventUpdate, liteapi.EventUpdateFlags:
user.updateCh <- imap.NewMessageLabelsUpdated( if err := user.handleUpdateMessageEvent(ctx, event); err != nil {
imap.MessageID(event.ID), return fmt.Errorf("failed to handle update message event: %w", err)
imapLabelIDs(filterLabelIDs(event.Message.LabelIDs)), }
bool(!event.Message.Unread),
slices.Contains(event.Message.LabelIDs, liteapi.StarredLabel), case liteapi.EventDelete:
) return ErrNotImplemented
} }
} }
return nil return nil
} }
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
var addressID string
if user.GetAddressMode() == vault.CombinedMode {
addressID = user.apiAddrs.primary()
} else {
addressID = event.Message.AddressID
}
message, err := user.builder.ProcessOne(ctx, request{
messageID: event.ID,
addressID: addressID,
addrKR: user.addrKRs[event.Message.AddressID],
})
if err != nil {
return err
}
user.updateCh[addressID].Enqueue(imap.NewMessagesCreated(message))
return nil
}
func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
update := imap.NewMessageLabelsUpdated(
imap.MessageID(event.ID),
mapTo[string, imap.LabelID](xslices.Filter(event.Message.LabelIDs, wantLabelID)),
event.Message.Seen(),
event.Message.Starred(),
)
if user.GetAddressMode() == vault.CombinedMode {
user.updateCh[user.apiAddrs.primary()].Enqueue(update)
} else {
user.updateCh[event.Message.AddressID].Enqueue(update)
}
return nil
}
func getMailboxName(label liteapi.Label) []string { func getMailboxName(label liteapi.Label) []string {
var name []string var name []string

76
internal/user/flusher.go Normal file
View File

@ -0,0 +1,76 @@
package user
import (
"sync"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
)
type flusher struct {
userID string
updateCh *queue.QueuedChannel[imap.Update]
eventCh *queue.QueuedChannel[events.Event]
updates []*imap.MessageCreated
maxChunkSize int
curChunkSize int
count int
total int
start time.Time
pushLock sync.Mutex
}
func newFlusher(
userID string,
updateCh *queue.QueuedChannel[imap.Update],
eventCh *queue.QueuedChannel[events.Event],
total, maxChunkSize int,
) *flusher {
return &flusher{
userID: userID,
updateCh: updateCh,
eventCh: eventCh,
maxChunkSize: maxChunkSize,
total: total,
start: time.Now(),
}
}
func (f *flusher) push(update *imap.MessageCreated) {
f.pushLock.Lock()
defer f.pushLock.Unlock()
f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
f.flush()
}
}
func (f *flusher) flush() {
if len(f.updates) == 0 {
return
}
f.count += len(f.updates)
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
f.eventCh.Enqueue(newSyncProgress(f.userID, f.count, f.total, f.start))
f.updates = nil
f.curChunkSize = 0
}
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
return events.SyncProgress{
UserID: userID,
Progress: float64(count) / float64(total),
Elapsed: time.Since(start),
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
}
}

View File

@ -25,11 +25,12 @@ const (
) )
type imapConnector struct { type imapConnector struct {
addrID string
client *liteapi.Client client *liteapi.Client
updateCh <-chan imap.Update updateCh <-chan imap.Update
addresses []string emails []string
password string password string
flags, permFlags, attrs imap.FlagSet flags, permFlags, attrs imap.FlagSet
} }
@ -37,15 +38,15 @@ type imapConnector struct {
func newIMAPConnector( func newIMAPConnector(
client *liteapi.Client, client *liteapi.Client,
updateCh <-chan imap.Update, updateCh <-chan imap.Update,
addresses []string,
password string, password string,
emails ...string,
) *imapConnector { ) *imapConnector {
return &imapConnector{ return &imapConnector{
client: client, client: client,
updateCh: updateCh, updateCh: updateCh,
addresses: addresses, emails: emails,
password: password, password: password,
flags: defaultFlags, flags: defaultFlags,
permFlags: defaultPermanentFlags, permFlags: defaultPermanentFlags,
@ -59,7 +60,7 @@ func (conn *imapConnector) Authorize(username string, password string) bool {
return false return false
} }
return xslices.IndexFunc(conn.addresses, func(address string) bool { return xslices.IndexFunc(conn.emails, func(address string) bool {
return strings.EqualFold(address, username) return strings.EqualFold(address, username)
}) >= 0 }) >= 0
} }
@ -187,7 +188,7 @@ func (conn *imapConnector) GetMessage(ctx context.Context, messageID imap.Messag
ID: imap.MessageID(message.ID), ID: imap.MessageID(message.ID),
Flags: flags, Flags: flags,
Date: time.Unix(message.Time, 0), Date: time.Unix(message.Time, 0),
}, imapLabelIDs(message.LabelIDs), nil }, mapTo[string, imap.LabelID](message.LabelIDs), nil
} }
// CreateMessage creates a new message on the remote. // CreateMessage creates a new message on the remote.
@ -204,21 +205,21 @@ func (conn *imapConnector) CreateMessage(
// LabelMessages labels the given messages with the given label ID. // LabelMessages labels the given messages with the given label ID.
func (conn *imapConnector) LabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error { func (conn *imapConnector) LabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelID)) return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
} }
// UnlabelMessages unlabels the given messages with the given label ID. // UnlabelMessages unlabels the given messages with the given label ID.
func (conn *imapConnector) UnlabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error { func (conn *imapConnector) UnlabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelID)) return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
} }
// MoveMessages removes the given messages from one label and adds them to the other label. // MoveMessages removes the given messages from one label and adds them to the other label.
func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.LabelID, labelToID imap.LabelID) error { func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.LabelID, labelToID imap.LabelID) error {
if err := conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelToID)); err != nil { if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
return fmt.Errorf("labeling messages: %w", err) return fmt.Errorf("labeling messages: %w", err)
} }
if err := conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelFromID)); err != nil { if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
return fmt.Errorf("unlabeling messages: %w", err) return fmt.Errorf("unlabeling messages: %w", err)
} }
@ -228,18 +229,18 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
// MarkMessagesSeen sets the seen value of the given messages. // MarkMessagesSeen sets the seen value of the given messages.
func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error { func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error {
if seen { if seen {
return conn.client.MarkMessagesRead(ctx, strMessageIDs(messageIDs)...) return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...)
} else { } else {
return conn.client.MarkMessagesUnread(ctx, strMessageIDs(messageIDs)...) return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...)
} }
} }
// MarkMessagesFlagged sets the flagged value of the given messages. // MarkMessagesFlagged sets the flagged value of the given messages.
func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error { func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error {
if flagged { if flagged {
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel) return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
} else { } else {
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel) return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
} }
} }
@ -249,45 +250,17 @@ func (conn *imapConnector) GetUpdates() <-chan imap.Update {
return conn.updateCh return conn.updateCh
} }
// Close the connector when it will no longer be used and all resources should be closed/released. // GetUIDValidity returns the default UID validity for this user.
func (conn *imapConnector) Close(ctx context.Context) error { func (conn *imapConnector) GetUIDValidity() imap.UID {
return imap.UID(1)
}
// SetUIDValidity sets the default UID validity for this user.
func (conn *imapConnector) SetUIDValidity(uidValidity imap.UID) error {
return nil return nil
} }
func (conn *imapConnector) addAddress(address string) { // Close the connector will no longer be used and all resources should be closed/released.
conn.addresses = append(conn.addresses, address) func (conn *imapConnector) Close(ctx context.Context) error {
} return nil
func (conn *imapConnector) remAddress(address string) {
idx := slices.Index(conn.addresses, address)
if idx < 0 {
return
}
conn.addresses = append(conn.addresses[:idx], conn.addresses[idx+1:]...)
}
func strLabelIDs(imapLabelIDs []imap.LabelID) []string {
return xslices.Map(imapLabelIDs, func(labelID imap.LabelID) string {
return string(labelID)
})
}
func imapLabelIDs(labelIDs []string) []imap.LabelID {
return xslices.Map(labelIDs, func(labelID string) imap.LabelID {
return imap.LabelID(labelID)
})
}
func strMessageIDs(imapMessageIDs []imap.MessageID) []string {
return xslices.Map(imapMessageIDs, func(messageID imap.MessageID) string {
return string(messageID)
})
}
func imapMessageIDs(messageIDs []string) []imap.MessageID {
return xslices.Map(messageIDs, func(messageID string) imap.MessageID {
return imap.MessageID(messageID)
})
} }

89
internal/user/map.go Normal file
View File

@ -0,0 +1,89 @@
package user
import (
"github.com/bradenaw/juniper/xslices"
"golang.org/x/exp/slices"
)
type ordMap[Key comparable, Val, Data any] struct {
data map[Key]Data
order []Key
toKey func(Data) Key
toVal func(Data) Val
isLess func(Data, Data) bool
}
func newOrdMap[Key comparable, Val, Data any](
key func(Data) Key,
value func(Data) Val,
less func(Data, Data) bool,
data ...Data,
) ordMap[Key, Val, Data] {
m := ordMap[Key, Val, Data]{
data: make(map[Key]Data),
toKey: key,
toVal: value,
isLess: less,
}
for _, d := range data {
m.insert(d)
}
return m
}
func (set *ordMap[Key, Val, Data]) insert(data Data) {
if _, ok := set.data[set.toKey(data)]; ok {
set.delete(set.toKey(data))
}
set.data[set.toKey(data)] = data
set.order = append(set.order, set.toKey(data))
slices.SortFunc(set.order, func(a, b Key) bool {
return set.isLess(set.data[a], set.data[b])
})
}
func (set *ordMap[Key, Val, Data]) delete(key Key) Val {
data, ok := set.data[key]
if !ok {
return *new(Val)
}
delete(set.data, key)
set.order = xslices.Filter(set.order, func(otherKey Key) bool {
return otherKey != key
})
return set.toVal(data)
}
func (set *ordMap[Key, Val, Data]) get(key Key) Val {
return set.toVal(set.data[key])
}
func (set *ordMap[Key, Val, Data]) keys() []Key {
return set.order
}
func (set *ordMap[Key, Val, Data]) values() []Val {
return xslices.Map(set.order, func(key Key) Val {
return set.toVal(set.data[key])
})
}
func (set *ordMap[Key, Val, Data]) toMap() map[Key]Val {
m := make(map[Key]Val)
for _, key := range set.order {
m[key] = set.toVal(set.data[key])
}
return m
}

48
internal/user/map_test.go Normal file
View File

@ -0,0 +1,48 @@
package user
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMap(t *testing.T) {
type Key int
type Value string
type Data struct {
key Key
value Value
}
m := newOrdMap(
func(d Data) Key { return d.key },
func(d Data) Value { return d.value },
func(a, b Data) bool { return a.key < b.key },
Data{key: 1, value: "a"},
Data{key: 2, value: "b"},
Data{key: 3, value: "c"},
)
// Insert some new data.
m.insert(Data{key: 4, value: "d"})
m.insert(Data{key: 5, value: "e"})
// Delete some data.
require.Equal(t, Value("c"), m.delete(3))
require.Equal(t, Value("a"), m.delete(1))
require.Equal(t, Value("e"), m.delete(5))
// Check the remaining keys and values are correct.
require.Equal(t, []Key{2, 4}, m.keys())
require.Equal(t, []Value{"b", "d"}, m.values())
// Overwrite some data.
m.insert(Data{key: 2, value: "two"})
m.insert(Data{key: 4, value: "four"})
// Check the remaining keys and values are correct.
require.Equal(t, []Key{2, 4}, m.keys())
require.Equal(t, []Value{"two", "four"}, m.values())
}

View File

@ -20,12 +20,14 @@ import (
) )
type smtpSession struct { type smtpSession struct {
client *liteapi.Client client *liteapi.Client
username string
addresses []liteapi.Address username string
userKR *crypto.KeyRing emails map[string]string
addrKRs map[string]*crypto.KeyRing settings liteapi.MailSettings
settings liteapi.MailSettings
userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing
from string from string
to map[string]struct{} to map[string]struct{}
@ -34,18 +36,20 @@ type smtpSession struct {
func newSMTPSession( func newSMTPSession(
client *liteapi.Client, client *liteapi.Client,
username string, username string,
addresses []liteapi.Address, addresses map[string]string,
settings liteapi.MailSettings,
userKR *crypto.KeyRing, userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing, addrKRs map[string]*crypto.KeyRing,
settings liteapi.MailSettings,
) *smtpSession { ) *smtpSession {
return &smtpSession{ return &smtpSession{
client: client, client: client,
username: username,
addresses: addresses, username: username,
userKR: userKR, emails: addresses,
addrKRs: addrKRs, settings: settings,
settings: settings,
userKR: userKR,
addrKRs: addrKRs,
from: "", from: "",
to: make(map[string]struct{}), to: make(map[string]struct{}),
@ -86,15 +90,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
return ErrNotImplemented return ErrNotImplemented
} }
idx := xslices.IndexFunc(session.addresses, func(address liteapi.Address) bool { for addrID, email := range session.emails {
return strings.EqualFold(address.Email, from) if strings.EqualFold(from, email) {
}) session.from = addrID
}
if idx < 0 {
return ErrInvalidReturnPath
} }
session.from = session.addresses[idx].ID if session.from == "" {
return ErrInvalidReturnPath
}
return nil return nil
} }
@ -129,10 +133,10 @@ func (session *smtpSession) Data(r io.Reader) error {
addrKR, ok := session.addrKRs[session.from] addrKR, ok := session.addrKRs[session.from]
if !ok { if !ok {
return ErrMissingAddressKey return ErrMissingAddrKey
} }
addrKR, err := addrKR.FirstKey() addrKey, err := addrKR.FirstKey()
if err != nil { if err != nil {
return fmt.Errorf("failed to get first key: %w", err) return fmt.Errorf("failed to get first key: %w", err)
} }
@ -143,7 +147,7 @@ func (session *smtpSession) Data(r io.Reader) error {
} }
if session.settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { if session.settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
key, err := addrKR.GetKey(0) key, err := addrKey.GetKey(0)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user public key: %w", err) return fmt.Errorf("failed to get user public key: %w", err)
} }
@ -153,7 +157,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to get user public key: %w", err) return fmt.Errorf("failed to get user public key: %w", err)
} }
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKey.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
} }
message, err := message.ParseWithParser(parser) message, err := message.ParseWithParser(parser)
@ -161,7 +165,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to parse message: %w", err) return fmt.Errorf("failed to parse message: %w", err)
} }
draft, attKeys, err := session.createDraft(ctx, addrKR, message) draft, attKeys, err := session.createDraft(ctx, addrKey, message)
if err != nil { if err != nil {
return fmt.Errorf("failed to create draft: %w", err) return fmt.Errorf("failed to create draft: %w", err)
} }
@ -171,7 +175,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to get recipients: %w", err) return fmt.Errorf("failed to get recipients: %w", err)
} }
req, err := createSendReq(addrKR, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys) req, err := createSendReq(addrKey, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
if err != nil { if err != nil {
return fmt.Errorf("failed to create packages: %w", err) return fmt.Errorf("failed to create packages: %w", err)
} }

View File

@ -4,57 +4,34 @@ import (
"context" "context"
"fmt" "fmt"
"strings" "strings"
"sync"
"time"
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/google/uuid" "github.com/google/uuid"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
) )
const chunkSize = 1 << 20 const chunkSize = 1 << 20
func (user *User) sync(ctx context.Context) error { func (user *User) syncLabels(ctx context.Context, addrIDs ...string) error {
user.notifyCh <- events.SyncStarted{
UserID: user.ID(),
}
if err := user.syncLabels(ctx); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := user.syncMessages(ctx); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
user.notifyCh <- events.SyncFinished{
UserID: user.ID(),
}
if err := user.vault.SetSync(true); err != nil {
return fmt.Errorf("failed to update sync status: %w", err)
}
return nil
}
func (user *User) syncLabels(ctx context.Context) error {
// Sync the system folders. // Sync the system folders.
system, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem) system, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem)
if err != nil { if err != nil {
return err return err
} }
for _, label := range system { for _, label := range xslices.Filter(system, func(label liteapi.Label) bool { return wantLabelID(label.ID) }) {
user.updateCh <- newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name) for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name))
}
} }
// Create Folders/Labels mailboxes with a random ID and with the \Noselect attribute. // Create Folders/Labels mailboxes with a random ID and with the \Noselect attribute.
for _, prefix := range []string{folderPrefix, labelPrefix} { for _, prefix := range []string{folderPrefix, labelPrefix} {
user.updateCh <- newPlaceHolderMailboxCreatedUpdate(prefix) for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newPlaceHolderMailboxCreatedUpdate(prefix))
}
} }
// Sync the API folders. // Sync the API folders.
@ -64,7 +41,9 @@ func (user *User) syncLabels(ctx context.Context) error {
} }
for _, folder := range folders { for _, folder := range folders {
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}) for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}))
}
} }
// Sync the API labels. // Sync the API labels.
@ -74,7 +53,9 @@ func (user *User) syncLabels(ctx context.Context) error {
} }
for _, label := range labels { for _, label := range labels {
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}) for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}))
}
} }
return nil return nil
@ -84,27 +65,53 @@ func (user *User) syncMessages(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
// Determine which messages to sync.
// TODO: This needs to be done better using the new API route to retrieve just the message IDs.
metadata, err := user.client.GetAllMessageMetadata(ctx) metadata, err := user.client.GetAllMessageMetadata(ctx)
if err != nil { if err != nil {
return err return err
} }
// If in split mode, we need to send each message to a different IMAP connector.
isSplitMode := user.vault.AddressMode() == vault.SplitMode
// Collect the build requests -- we need:
// - the message ID to build,
// - the keyring to decrypt the message,
// - and the address to send the message to (for split mode).
requests := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) request { requests := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) request {
var addressID string
if isSplitMode {
addressID = metadata.AddressID
} else {
addressID = user.apiAddrs.primary()
}
return request{ return request{
messageID: metadata.ID, messageID: metadata.ID,
addressID: addressID,
addrKR: user.addrKRs[metadata.AddressID], addrKR: user.addrKRs[metadata.AddressID],
} }
}) })
flusher := newFlusher(user.ID(), user.updateCh, user.notifyCh, len(metadata), chunkSize) // Create the flushers, one per update channel.
defer flusher.flush() flushers := make(map[string]*flusher)
for addrID, updateCh := range user.updateCh {
flusher := newFlusher(user.ID(), updateCh, user.eventCh, len(requests), chunkSize)
defer flusher.flush()
flushers[addrID] = flusher
}
// Build the messages and send them to the correct flusher.
if err := user.builder.Process(ctx, requests, func(req request, res *imap.MessageCreated, err error) error { if err := user.builder.Process(ctx, requests, func(req request, res *imap.MessageCreated, err error) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to build message %s: %w", req.messageID, err) return fmt.Errorf("failed to build message %s: %w", req.messageID, err)
} }
flusher.push(res) flushers[req.addressID].push(res)
return nil return nil
}); err != nil { }); err != nil {
@ -114,95 +121,15 @@ func (user *User) syncMessages(ctx context.Context) error {
return nil return nil
} }
type flusher struct { func (user *User) syncWait() {
userID string for _, updateCh := range user.updateCh {
waiter := imap.NewNoop()
defer waiter.Wait()
updates []*imap.MessageCreated updateCh.Enqueue(waiter)
updateCh chan<- imap.Update
notifyCh chan<- events.Event
maxChunkSize int
curChunkSize int
count int
total int
start time.Time
pushLock sync.Mutex
}
func newFlusher(userID string, updateCh chan<- imap.Update, notifyCh chan<- events.Event, total, maxChunkSize int) *flusher {
return &flusher{
userID: userID,
updateCh: updateCh,
notifyCh: notifyCh,
maxChunkSize: maxChunkSize,
total: total,
start: time.Now(),
} }
} }
func (f *flusher) push(update *imap.MessageCreated) {
f.pushLock.Lock()
defer f.pushLock.Unlock()
f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
f.flush()
}
}
func (f *flusher) flush() {
if len(f.updates) == 0 {
return
}
f.count += len(f.updates)
f.updateCh <- imap.NewMessagesCreated(f.updates...)
f.notifyCh <- newSyncProgress(f.userID, f.count, f.total, f.start)
f.updates = nil
f.curChunkSize = 0
}
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
return events.SyncProgress{
UserID: userID,
Progress: float64(count) / float64(total),
Elapsed: time.Since(start),
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
}
}
func getMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
return nil, err
}
flags := imap.NewFlagSet()
if !message.Unread {
flags = flags.Add(imap.FlagSeen)
}
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
flags = flags.Add(imap.FlagFlagged)
}
imapMessage := imap.Message{
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}
return &imap.MessageCreated{
Message: imapMessage,
Literal: literal,
LabelIDs: imapLabelIDs(filterLabelIDs(message.LabelIDs)),
ParsedMessage: parsedMessage,
}, nil
}
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated { func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
if strings.EqualFold(labelName, imap.Inbox) { if strings.EqualFold(labelName, imap.Inbox) {
labelName = imap.Inbox labelName = imap.Inbox
@ -237,18 +164,12 @@ func newMailboxCreatedUpdate(labelID imap.LabelID, labelName []string) *imap.Mai
}) })
} }
func filterLabelIDs(labelIDs []string) []string { func wantLabelID(labelID string) bool {
var filteredLabelIDs []string switch labelID {
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
return false
for _, labelID := range labelIDs { default:
switch labelID { return true
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
// ... skip ...
default:
filteredLabelIDs = append(filteredLabelIDs, labelID)
}
} }
return filteredLabelIDs
} }

13
internal/user/types.go Normal file
View File

@ -0,0 +1,13 @@
package user
import "reflect"
func mapTo[From, To any](from []From) []To {
to := make([]To, 0, len(from))
for _, from := range from {
to = append(to, reflect.ValueOf(from).Convert(reflect.TypeOf(to).Elem()).Interface().(To))
}
return to
}

View File

@ -0,0 +1,20 @@
package user
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestToType(t *testing.T) {
type myString string
// Slices of different types are not equal.
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
// But converting them to the same type makes them equal.
require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"}))
// The conversion can happen in the other direction too.
require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"}))
}

View File

@ -2,19 +2,22 @@ package user
import ( import (
"context" "context"
"fmt"
"runtime" "runtime"
"time" "time"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/pool" "github.com/ProtonMail/proton-bridge/v2/internal/pool"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@ -23,40 +26,38 @@ var (
DefaultEventJitter = 20 * time.Second DefaultEventJitter = 20 * time.Second
) )
// TODO: Is it bad to store the key pass in the user? Any worse than storing private keys?
type User struct { type User struct {
vault *vault.User vault *vault.User
client *liteapi.Client client *liteapi.Client
builder *pool.Pool[request, *imap.MessageCreated] builder *pool.Pool[request, *imap.MessageCreated]
eventCh *queue.QueuedChannel[events.Event]
apiUser liteapi.User apiUser liteapi.User
addresses []liteapi.Address apiAddrs *addrList
settings liteapi.MailSettings
notifyCh chan events.Event
updateCh chan imap.Update
userKR *crypto.KeyRing userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing addrKRs map[string]*crypto.KeyRing
imapConn *imapConnector settings liteapi.MailSettings
updateCh map[string]*queue.QueuedChannel[imap.Update]
syncWG gluon.WaitGroup
} }
func New( func New(
ctx context.Context, ctx context.Context,
vault *vault.User, encVault *vault.User,
client *liteapi.Client, client *liteapi.Client,
apiUser liteapi.User, apiUser liteapi.User,
apiAddrs []liteapi.Address, apiAddrs []liteapi.Address,
userKR *crypto.KeyRing, userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing, addrKRs map[string]*crypto.KeyRing,
) (*User, error) { ) (*User, error) {
if vault.EventID() == "" { if encVault.EventID() == "" {
eventID, err := client.GetLatestEventID(ctx) eventID, err := client.GetLatestEventID(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := vault.SetEventID(eventID); err != nil { if err := encVault.SetEventID(eventID); err != nil {
return nil, err return nil, err
} }
} }
@ -67,19 +68,29 @@ func New(
} }
user := &User{ user := &User{
apiUser: apiUser, vault: encVault,
addresses: apiAddrs,
settings: settings,
vault: vault,
client: client, client: client,
builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()), builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()),
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
notifyCh: make(chan events.Event), apiUser: apiUser,
updateCh: make(chan imap.Update), apiAddrs: newAddrList(apiAddrs),
userKR: userKR, userKR: userKR,
addrKRs: addrKRs, addrKRs: addrKRs,
settings: settings,
updateCh: make(map[string]*queue.QueuedChannel[imap.Update]),
}
// Initialize update channels for each of the user's addresses.
for _, addrID := range user.apiAddrs.addrIDs() {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
// If in combined mode, we only need one update channel.
if encVault.AddressMode() == vault.CombinedMode {
break
}
} }
// When we receive an auth object, we update it in the store. // When we receive an auth object, we update it in the store.
@ -93,111 +104,234 @@ func New(
// When we are deauthorized, we send a deauth event to the notify channel. // When we are deauthorized, we send a deauth event to the notify channel.
// Bridge will catch this and log the user out. // Bridge will catch this and log the user out.
client.AddDeauthHandler(func() { client.AddDeauthHandler(func() {
user.notifyCh <- events.UserDeauth{ user.eventCh.Enqueue(events.UserDeauth{
UserID: user.ID(), UserID: user.ID(),
} })
}) })
// When we receive an API event, we attempt to handle it. If successful, we send the event to the event channel. // When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
go func() { go func() {
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, vault.EventID()).Subscribe() { ctx, cancel := context.WithCancel(context.Background())
if err := user.handleAPIEvent(event); err != nil { defer cancel()
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, encVault.EventID()).Subscribe() {
if err := user.handleAPIEvent(ctx, event); err != nil {
logrus.WithError(err).Error("Failed to handle event") logrus.WithError(err).Error("Failed to handle event")
} else { } else if err := user.vault.SetEventID(event.EventID); err != nil {
if err := user.vault.SetEventID(event.EventID); err != nil { logrus.WithError(err).Error("Failed to update event ID")
logrus.WithError(err).Error("Failed to update event ID")
}
} }
} }
}() }()
// TODO: Use a proper sync manager! (if partial sync, pickup from where we last stopped)
if !vault.HasSync() {
go user.sync(context.Background())
}
return user, nil return user, nil
} }
// ID returns the user's ID.
func (user *User) ID() string { func (user *User) ID() string {
return user.apiUser.ID return user.apiUser.ID
} }
// Name returns the user's username.
func (user *User) Name() string { func (user *User) Name() string {
return user.apiUser.Name return user.apiUser.Name
} }
// Match matches the given query against the user's username and email addresses.
func (user *User) Match(query string) bool { func (user *User) Match(query string) bool {
if query == user.Name() { if query == user.apiUser.Name {
return true return true
} }
return slices.Contains(user.Addresses(), query) return slices.Contains(user.apiAddrs.emails(), query)
} }
func (user *User) Addresses() []string { // Emails returns all the user's email addresses.
return xslices.Map( func (user *User) Emails() []string {
sort(user.addresses, func(a, b liteapi.Address) bool { return user.apiAddrs.emails()
return a.Order < b.Order
}),
func(address liteapi.Address) string {
return address.Email
},
)
} }
func (user *User) GluonID() string { // GetAddressMode returns the user's current address mode.
return user.vault.GluonID() func (user *User) GetAddressMode() vault.AddressMode {
return user.vault.AddressMode()
} }
// SetAddressMode sets the user's address mode.
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
for _, updateCh := range user.updateCh {
updateCh.Close()
}
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
for _, addrID := range user.apiAddrs.addrIDs() {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
if mode == vault.CombinedMode {
break
}
}
if err := user.vault.SetAddressMode(mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
return nil
}
// GetGluonIDs returns the users gluon IDs.
func (user *User) GetGluonIDs() map[string]string {
return user.vault.GetGluonIDs()
}
// GetGluonID returns the gluon ID for the given address, if present.
func (user *User) GetGluonID(addrID string) (string, bool) {
gluonID, ok := user.vault.GetGluonIDs()[addrID]
if !ok {
return "", false
}
return gluonID, true
}
// SetGluonID sets the gluon ID for the given address.
func (user *User) SetGluonID(addrID, gluonID string) error {
return user.vault.SetGluonID(addrID, gluonID)
}
// GluonKey returns the user's gluon key from the vault.
func (user *User) GluonKey() []byte { func (user *User) GluonKey() []byte {
return user.vault.GluonKey() return user.vault.GluonKey()
} }
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
func (user *User) BridgePass() string { func (user *User) BridgePass() string {
return user.vault.BridgePass() return user.vault.BridgePass()
} }
// UsedSpace returns the total space used by the user on the API.
func (user *User) UsedSpace() int { func (user *User) UsedSpace() int {
return user.apiUser.UsedSpace return user.apiUser.UsedSpace
} }
// MaxSpace returns the amount of space the user can use on the API.
func (user *User) MaxSpace() int { func (user *User) MaxSpace() int {
return user.apiUser.MaxSpace return user.apiUser.MaxSpace
} }
// GetNotifyCh returns a channel which notifies of events happening to the user (such as deauth, address change) // HasSync returns whether the user has finished syncing.
func (user *User) GetNotifyCh() <-chan events.Event { func (user *User) HasSync() bool {
return user.notifyCh return user.vault.HasSync()
} }
func (user *User) NewGluonConnector(ctx context.Context) (connector.Connector, error) { // AbortSync aborts any ongoing sync.
if user.imapConn != nil { // TODO: This should abort the sync rather than just waiting.
if err := user.imapConn.Close(ctx); err != nil { // Should probably be done automatically when one of the user's IMAP connectors is closed.
return nil, err func (user *User) AbortSync(ctx context.Context) error {
user.syncWG.Wait()
return nil
}
// DoSync performs a sync for the user.
func (user *User) DoSync(ctx context.Context) <-chan error {
errCh := queue.NewQueuedChannel[error](0, 0)
user.syncWG.Go(func() {
defer errCh.Close()
user.eventCh.Enqueue(events.SyncStarted{
UserID: user.ID(),
})
errCh.Enqueue(func() error {
if err := user.syncLabels(ctx, maps.Keys(user.updateCh)...); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := user.syncMessages(ctx); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
user.syncWait()
if err := user.vault.SetSync(true); err != nil {
return fmt.Errorf("failed to set sync status: %w", err)
}
return nil
}())
user.eventCh.Enqueue(events.SyncFinished{
UserID: user.ID(),
})
})
return errCh.GetChannel()
}
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change)
func (user *User) GetEventCh() <-chan events.Event {
return user.eventCh.GetChannel()
}
// NewIMAPConnector returns an IMAP connector for the given address.
// If not in split mode, this function returns an error.
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
var emails []string
switch user.vault.AddressMode() {
case vault.CombinedMode:
if addrID != user.apiAddrs.primary() {
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
} }
emails = user.apiAddrs.emails()
case vault.SplitMode:
emails = []string{user.apiAddrs.email(addrID)}
} }
user.imapConn = newIMAPConnector(user.client, user.updateCh, user.Addresses(), user.vault.BridgePass()) return newIMAPConnector(
user.client,
return user.imapConn, nil user.updateCh[addrID].GetChannel(),
user.vault.BridgePass(),
emails...,
), nil
} }
func (user *User) NewSMTPSession(username string) (smtp.Session, error) { // NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
return newSMTPSession(user.client, username, user.addresses, user.userKR, user.addrKRs, user.settings), nil // In combined mode, this is just the user's primary address.
// In split mode, this is all the user's addresses.
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
imapConn := make(map[string]connector.Connector)
for addrID := range user.updateCh {
conn, err := user.NewIMAPConnector(addrID)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP connector: %w", err)
}
imapConn[addrID] = conn
}
return imapConn, nil
} }
// NewSMTPSession returns an SMTP session for the user.
func (user *User) NewSMTPSession(username string) smtp.Session {
return newSMTPSession(user.client, username, user.apiAddrs.addrMap(), user.settings, user.userKR, user.addrKRs)
}
// Logout logs the user out from the API.
func (user *User) Logout(ctx context.Context) error { func (user *User) Logout(ctx context.Context) error {
return user.client.AuthDelete(ctx) return user.client.AuthDelete(ctx)
} }
// Close closes ongoing connections and cleans up resources.
func (user *User) Close(ctx context.Context) error { func (user *User) Close(ctx context.Context) error {
// Close the user's IMAP connectors. // Wait for ongoing syncs to finish.
if user.imapConn != nil { user.syncWG.Wait()
if err := user.imapConn.Close(ctx); err != nil {
return err
}
}
// Close the user's message builder. // Close the user's message builder.
user.builder.Done() user.builder.Done()
@ -205,15 +339,13 @@ func (user *User) Close(ctx context.Context) error {
// Close the user's API client. // Close the user's API client.
user.client.Close() user.client.Close()
// Close the user's update channels.
for _, updateCh := range user.updateCh {
updateCh.Close()
}
// Close the user's notify channel. // Close the user's notify channel.
close(user.notifyCh) user.eventCh.Close()
return nil return nil
} }
// sort returns the slice, sorted by the given callback.
func sort[T any](slice []T, less func(a, b T) bool) []T {
slices.SortFunc(slice, less)
return slice
}

162
internal/user/user_test.go Normal file
View File

@ -0,0 +1,162 @@
package user_test
import (
"context"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/tests"
"github.com/bradenaw/juniper/iterator"
"github.com/emersion/go-imap"
"github.com/emersion/go-imap/client"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/account"
)
func init() {
user.DefaultEventPeriod = 100 * time.Millisecond
user.DefaultEventJitter = 0
account.GenerateKey = tests.FastGenerateKey
certs.GenerateCert = tests.FastGenerateCert
}
func TestUser_Data(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
// User's ID should be correct.
require.Equal(t, userID, user.ID())
// User's name should be correct.
require.Equal(t, "username", user.Name())
// User's email should be correct.
require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails())
// By default, user should be in combined mode.
require.Equal(t, vault.CombinedMode, user.GetAddressMode())
// By default, user should have a non-empty bridge password.
require.NotEmpty(t, user.BridgePass())
})
})
}
func TestUser_Sync(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
// Get the user's IMAP connectors.
imapConn, err := user.NewIMAPConnectors()
require.NoError(t, err)
// Pretend to be gluon applying all the updates.
go func() {
for _, imapConn := range imapConn {
for update := range imapConn.GetUpdates() {
update.Done()
}
}
}()
// Trigger a user sync.
errCh := user.DoSync(ctx)
// User starts a sync at startup.
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
// User finishes a sync at startup.
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
// The sync completes without error.
require.NoError(t, <-errCh)
})
})
}
func TestUser_Deauth(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
eventCh := user.GetEventCh()
// Revoke the user's auth token.
require.NoError(t, s.RevokeUser(userID))
// The user should eventually be logged out.
require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserDeauth); return ok }, 5*time.Second, 100*time.Millisecond)
})
})
}
func withAPI(t *testing.T, ctx context.Context, username, password string, emails []string, fn func(context.Context, *server.Server, string, []string)) {
server := server.New()
defer server.Close()
var addrIDs []string
userID, addrID, err := server.AddUser(username, password, emails[0])
require.NoError(t, err)
addrIDs = append(addrIDs, addrID)
for _, email := range emails[1:] {
addrID, err := server.AddAddress(userID, email, password)
require.NoError(t, err)
addrIDs = append(addrIDs, addrID)
}
fn(ctx, server, userID, addrIDs)
}
func withUser(t *testing.T, ctx context.Context, apiURL, username, password string, fn func(*user.User)) {
c, apiAuth, err := liteapi.New(liteapi.WithHostURL(apiURL)).NewClientWithLogin(ctx, username, password)
require.NoError(t, err)
defer func() { require.NoError(t, c.Close()) }()
apiUser, apiAddrs, userKR, addrKRs, passphrase, err := c.Unlock(ctx, []byte(password))
require.NoError(t, err)
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
require.NoError(t, err)
require.False(t, corrupt)
vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, passphrase)
require.NoError(t, err)
user, err := user.New(ctx, vaultUser, c, apiUser, apiAddrs, userKR, addrKRs)
require.NoError(t, err)
defer func() { require.NoError(t, user.Close(ctx)) }()
fn(user)
}
func withIMAPClient(t *testing.T, addr string, fn func(*client.Client)) {
c, err := client.Dial(addr)
require.NoError(t, err)
defer c.Close()
fn(c)
}
func fetch(t *testing.T, c *client.Client, seqset string, items ...imap.FetchItem) []*imap.Message {
msgCh := make(chan *imap.Message)
go func() {
require.NoError(t, c.Fetch(must(imap.ParseSeqSet(seqset)), items, msgCh))
}()
return iterator.Collect(iterator.Chan(msgCh))
}
func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}

View File

@ -5,9 +5,14 @@ import (
) )
// RandomToken is a function that returns a random token. // RandomToken is a function that returns a random token.
var RandomToken func(size int) ([]byte, error)
// By default, we use crypto.RandomToken to generate tokens. // By default, we use crypto.RandomToken to generate tokens.
func init() { var RandomToken = crypto.RandomToken
RandomToken = crypto.RandomToken
func newRandomToken(size int) []byte {
token, err := RandomToken(size)
if err != nil {
panic(err)
}
return token
} }

View File

@ -4,6 +4,7 @@ import (
"math/rand" "math/rand"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/updater"
) )
@ -45,15 +46,24 @@ type Settings struct {
FirstStartGUI bool FirstStartGUI bool
} }
type AddressMode int
const (
CombinedMode AddressMode = iota
SplitMode
)
// UserData holds information about a single bridge user. // UserData holds information about a single bridge user.
// The user may or may not be logged in. // The user may or may not be logged in.
type UserData struct { type UserData struct {
UserID string UserID string
Username string Username string
GluonID string GluonKey []byte
GluonKey []byte GluonIDs map[string]string
BridgePass string UIDValidity map[string]imap.UID
BridgePass []byte
AddressMode AddressMode
AuthUID string AuthUID string
AuthRef string AuthRef string

View File

@ -1,5 +1,11 @@
package vault package vault
import (
"encoding/hex"
"github.com/ProtonMail/gluon/imap"
)
type User struct { type User struct {
vault *Vault vault *Vault
userID string userID string
@ -13,16 +19,41 @@ func (user *User) Username() string {
return user.vault.getUser(user.userID).Username return user.vault.getUser(user.userID).Username
} }
func (user *User) GluonID() string { func (user *User) GetGluonIDs() map[string]string {
return user.vault.getUser(user.userID).GluonID return user.vault.getUser(user.userID).GluonIDs
}
func (user *User) SetGluonID(addrID, gluonID string) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.GluonIDs[addrID] = gluonID
})
}
func (user *User) GetUIDValidity(addrID string) (imap.UID, bool) {
validity, ok := user.vault.getUser(user.userID).UIDValidity[addrID]
if !ok {
return imap.UID(0), false
}
return validity, true
}
func (user *User) SetUIDValidity(addrID string, validity imap.UID) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.UIDValidity[addrID] = validity
})
} }
func (user *User) GluonKey() []byte { func (user *User) GluonKey() []byte {
return user.vault.getUser(user.userID).GluonKey return user.vault.getUser(user.userID).GluonKey
} }
func (user *User) AddressMode() AddressMode {
return user.vault.getUser(user.userID).AddressMode
}
func (user *User) BridgePass() string { func (user *User) BridgePass() string {
return user.vault.getUser(user.userID).BridgePass return hex.EncodeToString(user.vault.getUser(user.userID).BridgePass)
} }
func (user *User) AuthUID() string { func (user *User) AuthUID() string {
@ -51,7 +82,7 @@ func (user *User) SetKeyPass(keyPass []byte) error {
}) })
} }
// SetAuth updates the auth secrets for the given user. // SetAuth sets the auth secrets for the given user.
func (user *User) SetAuth(authUID, authRef string) error { func (user *User) SetAuth(authUID, authRef string) error {
return user.vault.modUser(user.userID, func(data *UserData) { return user.vault.modUser(user.userID, func(data *UserData) {
data.AuthUID = authUID data.AuthUID = authUID
@ -59,33 +90,23 @@ func (user *User) SetAuth(authUID, authRef string) error {
}) })
} }
// SetGluonAuth updates the gluon ID and key for the given user. // SetAddressMode sets the address mode for the given user.
func (user *User) SetGluonAuth(gluonID string, gluonKey []byte) error { func (user *User) SetAddressMode(mode AddressMode) error {
return user.vault.modUser(user.userID, func(data *UserData) { return user.vault.modUser(user.userID, func(data *UserData) {
data.GluonID = gluonID data.AddressMode = mode
data.GluonKey = gluonKey
}) })
} }
// SetEventID updates the event ID for the given user. // SetEventID sets the event ID for the given user.
func (user *User) SetEventID(eventID string) error { func (user *User) SetEventID(eventID string) error {
return user.vault.modUser(user.userID, func(data *UserData) { return user.vault.modUser(user.userID, func(data *UserData) {
data.EventID = eventID data.EventID = eventID
}) })
} }
// SetSync updates the sync state for the given user. // SetSync sets the sync state for the given user.
func (user *User) SetSync(hasSync bool) error { func (user *User) SetSync(hasSync bool) error {
return user.vault.modUser(user.userID, func(data *UserData) { return user.vault.modUser(user.userID, func(data *UserData) {
data.HasSync = hasSync data.HasSync = hasSync
}) })
} }
// Clear clears the secrets for the given user.
func (user *User) Clear() error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.AuthUID = ""
data.AuthRef = ""
data.KeyPass = nil
})
}

View File

@ -4,6 +4,7 @@ import (
"encoding/hex" "encoding/hex"
"testing" "testing"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -32,30 +33,48 @@ func TestUser(t *testing.T) {
require.NoError(t, user2.SetSync(false)) require.NoError(t, user2.SetSync(false))
// Set gluon data for user 1 and 2. // Set gluon data for user 1 and 2.
require.NoError(t, user1.SetGluonAuth("gluonID1", []byte("gluonKey1"))) require.NoError(t, user1.SetGluonID("addrID1", "gluonID1"))
require.NoError(t, user2.SetGluonAuth("gluonID2", []byte("gluonKey2"))) require.NoError(t, user2.SetGluonID("addrID2", "gluonID2"))
require.NoError(t, user1.SetUIDValidity("addrID1", imap.UID(1)))
require.NoError(t, user2.SetUIDValidity("addrID2", imap.UID(2)))
// List available users. // List available users.
require.ElementsMatch(t, []string{"userID1", "userID2"}, s.GetUserIDs()) require.ElementsMatch(t, []string{"userID1", "userID2"}, s.GetUserIDs())
// Check gluon information for user 1.
gluonID1, ok := user1.GetGluonIDs()["addrID1"]
require.True(t, ok)
require.Equal(t, "gluonID1", gluonID1)
uidValidity1, ok := user1.GetUIDValidity("addrID1")
require.True(t, ok)
require.Equal(t, imap.UID(1), uidValidity1)
require.NotEmpty(t, user1.GluonKey())
// Get auth information for user 1. // Get auth information for user 1.
require.Equal(t, "userID1", user1.UserID()) require.Equal(t, "userID1", user1.UserID())
require.Equal(t, "user1", user1.Username()) require.Equal(t, "user1", user1.Username())
require.Equal(t, "gluonID1", user1.GluonID())
require.Equal(t, []byte("gluonKey1"), user1.GluonKey())
require.Equal(t, hex.EncodeToString([]byte("token")), user1.BridgePass()) require.Equal(t, hex.EncodeToString([]byte("token")), user1.BridgePass())
require.Equal(t, vault.CombinedMode, user1.AddressMode())
require.Equal(t, "authUID1", user1.AuthUID()) require.Equal(t, "authUID1", user1.AuthUID())
require.Equal(t, "authRef1", user1.AuthRef()) require.Equal(t, "authRef1", user1.AuthRef())
require.Equal(t, []byte("keyPass1"), user1.KeyPass()) require.Equal(t, []byte("keyPass1"), user1.KeyPass())
require.Equal(t, "eventID1", user1.EventID()) require.Equal(t, "eventID1", user1.EventID())
require.Equal(t, true, user1.HasSync()) require.Equal(t, true, user1.HasSync())
// Check gluon information for user 1.
gluonID2, ok := user2.GetGluonIDs()["addrID2"]
require.True(t, ok)
require.Equal(t, "gluonID2", gluonID2)
uidValidity2, ok := user2.GetUIDValidity("addrID2")
require.True(t, ok)
require.Equal(t, imap.UID(2), uidValidity2)
require.NotEmpty(t, user2.GluonKey())
// Get auth information for user 2. // Get auth information for user 2.
require.Equal(t, "userID2", user2.UserID()) require.Equal(t, "userID2", user2.UserID())
require.Equal(t, "user2", user2.Username()) require.Equal(t, "user2", user2.Username())
require.Equal(t, "gluonID2", user2.GluonID())
require.Equal(t, []byte("gluonKey2"), user2.GluonKey())
require.Equal(t, hex.EncodeToString([]byte("token")), user2.BridgePass()) require.Equal(t, hex.EncodeToString([]byte("token")), user2.BridgePass())
require.Equal(t, vault.CombinedMode, user2.AddressMode())
require.Equal(t, "authUID2", user2.AuthUID()) require.Equal(t, "authUID2", user2.AuthUID())
require.Equal(t, "authRef2", user2.AuthRef()) require.Equal(t, "authRef2", user2.AuthRef())
require.Equal(t, []byte("keyPass2"), user2.KeyPass()) require.Equal(t, []byte("keyPass2"), user2.KeyPass())
@ -63,8 +82,8 @@ func TestUser(t *testing.T) {
require.Equal(t, false, user2.HasSync()) require.Equal(t, false, user2.HasSync())
// Clear the users. // Clear the users.
require.NoError(t, user1.Clear()) require.NoError(t, s.ClearUser("userID1"))
require.NoError(t, user2.Clear()) require.NoError(t, s.ClearUser("userID2"))
// Their secrets should now be cleared. // Their secrets should now be cleared.
require.Equal(t, "", user1.AuthUID()) require.Equal(t, "", user1.AuthUID())

View File

@ -4,7 +4,6 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"io/fs" "io/fs"
@ -12,6 +11,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
) )
@ -99,16 +99,16 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
return nil, errors.New("user already exists") return nil, errors.New("user already exists")
} }
tok, err := RandomToken(16)
if err != nil {
return nil, err
}
if err := vault.mod(func(data *Data) { if err := vault.mod(func(data *Data) {
data.Users = append(data.Users, UserData{ data.Users = append(data.Users, UserData{
UserID: userID, UserID: userID,
Username: username, Username: username,
BridgePass: hex.EncodeToString(tok),
GluonKey: newRandomToken(32),
GluonIDs: make(map[string]string),
UIDValidity: make(map[string]imap.UID),
BridgePass: newRandomToken(16),
AddressMode: CombinedMode,
AuthUID: authUID, AuthUID: authUID,
AuthRef: authRef, AuthRef: authRef,
@ -121,6 +121,14 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
return vault.GetUser(userID) return vault.GetUser(userID)
} }
func (vault *Vault) ClearUser(userID string) error {
return vault.modUser(userID, func(data *UserData) {
data.AuthUID = ""
data.AuthRef = ""
data.KeyPass = nil
})
}
// DeleteUser removes the given user from the vault. // DeleteUser removes the given user from the vault.
func (vault *Vault) DeleteUser(userID string) error { func (vault *Vault) DeleteUser(userID string) error {
return vault.mod(func(data *Data) { return vault.mod(func(data *Data) {

View File

@ -13,7 +13,9 @@ type API interface {
GetHostURL() string GetHostURL() string
AddCallWatcher(func(server.Call), ...string) AddCallWatcher(func(server.Call), ...string)
AddUser(username, password, address string) (userID, addrID string, err error) AddUser(username, password, address string) (string, string, error)
AddAddress(userID, address, password string) (string, error)
RemoveAddress(userID, addrID string) error
RevokeUser(userID string) error RevokeUser(userID string) error
GetLabels(userID string) ([]liteapi.Label, error) GetLabels(userID string) ([]liteapi.Label, error)

View File

@ -30,8 +30,8 @@ import (
) )
func init() { func init() {
user.DefaultEventPeriod = time.Second user.DefaultEventPeriod = 100 * time.Millisecond
user.DefaultEventJitter = time.Second user.DefaultEventJitter = 0
} }
type scenario struct { type scenario struct {
@ -76,6 +76,16 @@ func TestFeatures(testingT *testing.T) {
ctx.Step(`^the user agent is "([^"]*)"$`, s.theUserAgentIs) ctx.Step(`^the user agent is "([^"]*)"$`, s.theUserAgentIs)
ctx.Step(`^the value of the "([^"]*)" header in the request to "([^"]*)" is "([^"]*)"$`, s.theValueOfTheHeaderInTheRequestToIs) ctx.Step(`^the value of the "([^"]*)" header in the request to "([^"]*)" is "([^"]*)"$`, s.theValueOfTheHeaderInTheRequestToIs)
// ==== SETUP ====
ctx.Step(`^there exists an account with username "([^"]*)" and password "([^"]*)"$`, s.thereExistsAnAccountWithUsernameAndPassword)
ctx.Step(`^the account "([^"]*)" has additional address "([^"]*)"$`, s.theAccountHasAdditionalAddress)
ctx.Step(`^the account "([^"]*)" no longer has additional address "([^"]*)"$`, s.theAccountNoLongerHasAdditionalAddress)
ctx.Step(`^the account "([^"]*)" has (\d+) custom folders$`, s.theAccountHasCustomFolders)
ctx.Step(`^the account "([^"]*)" has (\d+) custom labels$`, s.theAccountHasCustomLabels)
ctx.Step(`^the account "([^"]*)" has the following custom mailboxes:$`, s.theAccountHasTheFollowingCustomMailboxes)
ctx.Step(`^the address "([^"]*)" of account "([^"]*)" has the following messages in "([^"]*)":$`, s.theAddressOfAccountHasTheFollowingMessagesInMailbox)
ctx.Step(`^the address "([^"]*)" of account "([^"]*)" has (\d+) messages in "([^"]*)"$`, s.theAddressOfAccountHasMessagesInMailbox)
// ==== BRIDGE ==== // ==== BRIDGE ====
ctx.Step(`^bridge starts$`, s.bridgeStarts) ctx.Step(`^bridge starts$`, s.bridgeStarts)
ctx.Step(`^bridge restarts$`, s.bridgeRestarts) ctx.Step(`^bridge restarts$`, s.bridgeRestarts)
@ -85,12 +95,15 @@ func TestFeatures(testingT *testing.T) {
ctx.Step(`^the user has disabled automatic updates$`, s.theUserHasDisabledAutomaticUpdates) ctx.Step(`^the user has disabled automatic updates$`, s.theUserHasDisabledAutomaticUpdates)
ctx.Step(`^the user changes the IMAP port to (\d+)$`, s.theUserChangesTheIMAPPortTo) ctx.Step(`^the user changes the IMAP port to (\d+)$`, s.theUserChangesTheIMAPPortTo)
ctx.Step(`^the user changes the SMTP port to (\d+)$`, s.theUserChangesTheSMTPPortTo) ctx.Step(`^the user changes the SMTP port to (\d+)$`, s.theUserChangesTheSMTPPortTo)
ctx.Step(`^the user sets the address mode of "([^"]*)" to "([^"]*)"$`, s.theUserSetsTheAddressModeOfTo)
ctx.Step(`^the user changes the gluon path$`, s.theUserChangesTheGluonPath) ctx.Step(`^the user changes the gluon path$`, s.theUserChangesTheGluonPath)
ctx.Step(`^the user deletes the gluon files$`, s.theUserDeletesTheGluonFiles) ctx.Step(`^the user deletes the gluon files$`, s.theUserDeletesTheGluonFiles)
ctx.Step(`^the user reports a bug$`, s.theUserReportsABug) ctx.Step(`^the user reports a bug$`, s.theUserReportsABug)
ctx.Step(`^bridge sends a connection up event$`, s.bridgeSendsAConnectionUpEvent) ctx.Step(`^bridge sends a connection up event$`, s.bridgeSendsAConnectionUpEvent)
ctx.Step(`^bridge sends a connection down event$`, s.bridgeSendsAConnectionDownEvent) ctx.Step(`^bridge sends a connection down event$`, s.bridgeSendsAConnectionDownEvent)
ctx.Step(`^bridge sends a deauth event for user "([^"]*)"$`, s.bridgeSendsADeauthEventForUser) ctx.Step(`^bridge sends a deauth event for user "([^"]*)"$`, s.bridgeSendsADeauthEventForUser)
ctx.Step(`^bridge sends an address created event for user "([^"]*)"$`, s.bridgeSendsAnAddressCreatedEventForUser)
ctx.Step(`^bridge sends an address deleted event for user "([^"]*)"$`, s.bridgeSendsAnAddressDeletedEventForUser)
ctx.Step(`^bridge sends sync started and finished events for user "([^"]*)"$`, s.bridgeSendsSyncStartedAndFinishedEventsForUser) ctx.Step(`^bridge sends sync started and finished events for user "([^"]*)"$`, s.bridgeSendsSyncStartedAndFinishedEventsForUser)
ctx.Step(`^bridge sends an update available event for version "([^"]*)"$`, s.bridgeSendsAnUpdateAvailableEventForVersion) ctx.Step(`^bridge sends an update available event for version "([^"]*)"$`, s.bridgeSendsAnUpdateAvailableEventForVersion)
ctx.Step(`^bridge sends a manual update event for version "([^"]*)"$`, s.bridgeSendsAManualUpdateEventForVersion) ctx.Step(`^bridge sends a manual update event for version "([^"]*)"$`, s.bridgeSendsAManualUpdateEventForVersion)
@ -99,12 +112,6 @@ func TestFeatures(testingT *testing.T) {
ctx.Step(`^bridge sends a forced update event$`, s.bridgeSendsAForcedUpdateEvent) ctx.Step(`^bridge sends a forced update event$`, s.bridgeSendsAForcedUpdateEvent)
// ==== USER ==== // ==== USER ====
ctx.Step(`^there exists an account with username "([^"]*)" and password "([^"]*)"$`, s.thereExistsAnAccountWithUsernameAndPassword)
ctx.Step(`^the account "([^"]*)" has (\d+) custom folders$`, s.theAccountHasCustomFolders)
ctx.Step(`^the account "([^"]*)" has (\d+) custom labels$`, s.theAccountHasCustomLabels)
ctx.Step(`^the account "([^"]*)" has the following custom mailboxes:$`, s.theAccountHasTheFollowingCustomMailboxes)
ctx.Step(`^the account "([^"]*)" has the following messages in "([^"]*)":$`, s.theAccountHasTheFollowingMessagesInMailbox)
ctx.Step(`^the account "([^"]*)" has (\d+) messages in "([^"]*)"$`, s.theAccountHasMessagesInMailbox)
ctx.Step(`^the user logs in with username "([^"]*)" and password "([^"]*)"$`, s.userLogsInWithUsernameAndPassword) ctx.Step(`^the user logs in with username "([^"]*)" and password "([^"]*)"$`, s.userLogsInWithUsernameAndPassword)
ctx.Step(`^user "([^"]*)" logs out$`, s.userLogsOut) ctx.Step(`^user "([^"]*)" logs out$`, s.userLogsOut)
ctx.Step(`^user "([^"]*)" is deleted$`, s.userIsDeleted) ctx.Step(`^user "([^"]*)" is deleted$`, s.userIsDeleted)
@ -119,8 +126,10 @@ func TestFeatures(testingT *testing.T) {
ctx.Step(`^user "([^"]*)" connects IMAP client "([^"]*)"$`, s.userConnectsIMAPClient) ctx.Step(`^user "([^"]*)" connects IMAP client "([^"]*)"$`, s.userConnectsIMAPClient)
ctx.Step(`^user "([^"]*)" connects IMAP client "([^"]*)" on port (\d+)$`, s.userConnectsIMAPClientOnPort) ctx.Step(`^user "([^"]*)" connects IMAP client "([^"]*)" on port (\d+)$`, s.userConnectsIMAPClientOnPort)
ctx.Step(`^user "([^"]*)" connects and authenticates IMAP client "([^"]*)"$`, s.userConnectsAndAuthenticatesIMAPClient) ctx.Step(`^user "([^"]*)" connects and authenticates IMAP client "([^"]*)"$`, s.userConnectsAndAuthenticatesIMAPClient)
ctx.Step(`^user "([^"]*)" connects and authenticates IMAP client "([^"]*)" with address "([^"]*)"$`, s.userConnectsAndAuthenticatesIMAPClientWithAddress)
ctx.Step(`^IMAP client "([^"]*)" can authenticate$`, s.imapClientCanAuthenticate) ctx.Step(`^IMAP client "([^"]*)" can authenticate$`, s.imapClientCanAuthenticate)
ctx.Step(`^IMAP client "([^"]*)" cannot authenticate$`, s.imapClientCannotAuthenticate) ctx.Step(`^IMAP client "([^"]*)" cannot authenticate$`, s.imapClientCannotAuthenticate)
ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with address "([^"]*)"$`, s.imapClientCannotAuthenticateWithAddress)
ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with incorrect username$`, s.imapClientCannotAuthenticateWithIncorrectUsername) ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with incorrect username$`, s.imapClientCannotAuthenticateWithIncorrectUsername)
ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with incorrect password$`, s.imapClientCannotAuthenticateWithIncorrectPassword) ctx.Step(`^IMAP client "([^"]*)" cannot authenticate with incorrect password$`, s.imapClientCannotAuthenticateWithIncorrectPassword)
ctx.Step(`^IMAP client "([^"]*)" announces its ID with name "([^"]*)" and version "([^"]*)"$`, s.imapClientAnnouncesItsIDWithNameAndVersion) ctx.Step(`^IMAP client "([^"]*)" announces its ID with name "([^"]*)" and version "([^"]*)"$`, s.imapClientAnnouncesItsIDWithNameAndVersion)
@ -151,6 +160,7 @@ func TestFeatures(testingT *testing.T) {
ctx.Step(`^user "([^"]*)" connects SMTP client "([^"]*)"$`, s.userConnectsSMTPClient) ctx.Step(`^user "([^"]*)" connects SMTP client "([^"]*)"$`, s.userConnectsSMTPClient)
ctx.Step(`^user "([^"]*)" connects SMTP client "([^"]*)" on port (\d+)$`, s.userConnectsSMTPClientOnPort) ctx.Step(`^user "([^"]*)" connects SMTP client "([^"]*)" on port (\d+)$`, s.userConnectsSMTPClientOnPort)
ctx.Step(`^user "([^"]*)" connects and authenticates SMTP client "([^"]*)"$`, s.userConnectsAndAuthenticatesSMTPClient) ctx.Step(`^user "([^"]*)" connects and authenticates SMTP client "([^"]*)"$`, s.userConnectsAndAuthenticatesSMTPClient)
ctx.Step(`^user "([^"]*)" connects and authenticates SMTP client "([^"]*)" with address "([^"]*)"$`, s.userConnectsAndAuthenticatesSMTPClientWithAddress)
ctx.Step(`^SMTP client "([^"]*)" can authenticate$`, s.smtpClientCanAuthenticate) ctx.Step(`^SMTP client "([^"]*)" can authenticate$`, s.smtpClientCanAuthenticate)
ctx.Step(`^SMTP client "([^"]*)" cannot authenticate$`, s.smtpClientCannotAuthenticate) ctx.Step(`^SMTP client "([^"]*)" cannot authenticate$`, s.smtpClientCannotAuthenticate)
ctx.Step(`^SMTP client "([^"]*)" cannot authenticate with incorrect username$`, s.smtpClientCannotAuthenticateWithIncorrectUsername) ctx.Step(`^SMTP client "([^"]*)" cannot authenticate with incorrect username$`, s.smtpClientCannotAuthenticateWithIncorrectUsername)

View File

@ -9,6 +9,7 @@ import (
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
) )
func (s *scenario) bridgeStarts() error { func (s *scenario) bridgeStarts() error {
@ -46,6 +47,19 @@ func (s *scenario) theUserChangesTheSMTPPortTo(port int) error {
return s.t.bridge.SetSMTPPort(port) return s.t.bridge.SetSMTPPort(port)
} }
func (s *scenario) theUserSetsTheAddressModeOfTo(user, mode string) error {
switch mode {
case "split":
return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserID(user), vault.SplitMode)
case "combined":
return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserID(user), vault.CombinedMode)
default:
return fmt.Errorf("unknown address mode %q", mode)
}
}
func (s *scenario) theUserChangesTheGluonPath() error { func (s *scenario) theUserChangesTheGluonPath() error {
gluonDir, err := os.MkdirTemp(s.t.dir, "gluon") gluonDir, err := os.MkdirTemp(s.t.dir, "gluon")
if err != nil { if err != nil {
@ -113,7 +127,7 @@ func (s *scenario) bridgeSendsAConnectionDownEvent() error {
} }
func (s *scenario) bridgeSendsADeauthEventForUser(username string) error { func (s *scenario) bridgeSendsADeauthEventForUser(username string) error {
return try(s.t.userDeauthCh, 5*time.Second, func(event events.UserDeauth) error { return try(s.t.deauthCh, 5*time.Second, func(event events.UserDeauth) error {
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
return fmt.Errorf("expected deauth event for user with ID %s, got %s", wantUserID, event.UserID) return fmt.Errorf("expected deauth event for user with ID %s, got %s", wantUserID, event.UserID)
} }
@ -122,6 +136,26 @@ func (s *scenario) bridgeSendsADeauthEventForUser(username string) error {
}) })
} }
func (s *scenario) bridgeSendsAnAddressCreatedEventForUser(username string) error {
return try(s.t.addrCreatedCh, 5*time.Second, func(event events.UserAddressCreated) error {
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
return fmt.Errorf("expected user address created event for user with ID %s, got %s", wantUserID, event.UserID)
}
return nil
})
}
func (s *scenario) bridgeSendsAnAddressDeletedEventForUser(username string) error {
return try(s.t.addrDeletedCh, 5*time.Second, func(event events.UserAddressDeleted) error {
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
return fmt.Errorf("expected user address deleted event for user with ID %s, got %s", wantUserID, event.UserID)
}
return nil
})
}
func (s *scenario) bridgeSendsSyncStartedAndFinishedEventsForUser(username string) error { func (s *scenario) bridgeSendsSyncStartedAndFinishedEventsForUser(username string) error {
if err := get(s.t.syncStartedCh, func(event events.SyncStarted) error { if err := get(s.t.syncStartedCh, func(event events.SyncStarted) error {
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {

View File

@ -54,10 +54,12 @@ func (t *testCtx) startBridge() error {
t.bridge = bridge t.bridge = bridge
// Connect the event channels. // Connect the event channels.
t.userLoginCh = chToType[events.Event, events.UserLoggedIn](bridge.GetEvents(events.UserLoggedIn{})) t.loginCh = chToType[events.Event, events.UserLoggedIn](bridge.GetEvents(events.UserLoggedIn{}))
t.userLogoutCh = chToType[events.Event, events.UserLoggedOut](bridge.GetEvents(events.UserLoggedOut{})) t.logoutCh = chToType[events.Event, events.UserLoggedOut](bridge.GetEvents(events.UserLoggedOut{}))
t.userDeletedCh = chToType[events.Event, events.UserDeleted](bridge.GetEvents(events.UserDeleted{})) t.deletedCh = chToType[events.Event, events.UserDeleted](bridge.GetEvents(events.UserDeleted{}))
t.userDeauthCh = chToType[events.Event, events.UserDeauth](bridge.GetEvents(events.UserDeauth{})) t.deauthCh = chToType[events.Event, events.UserDeauth](bridge.GetEvents(events.UserDeauth{}))
t.addrCreatedCh = chToType[events.Event, events.UserAddressCreated](bridge.GetEvents(events.UserAddressCreated{}))
t.addrDeletedCh = chToType[events.Event, events.UserAddressDeleted](bridge.GetEvents(events.UserAddressDeleted{}))
t.syncStartedCh = chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{})) t.syncStartedCh = chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{}))
t.syncFinishedCh = chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) t.syncFinishedCh = chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
t.forcedUpdateCh = chToType[events.Event, events.UpdateForced](bridge.GetEvents(events.UpdateForced{})) t.forcedUpdateCh = chToType[events.Event, events.UpdateForced](bridge.GetEvents(events.UpdateForced{}))

View File

@ -14,6 +14,7 @@ import (
"github.com/emersion/go-imap/client" "github.com/emersion/go-imap/client"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server"
"golang.org/x/exp/maps"
) )
var defaultVersion = semver.MustParse("1.0.0") var defaultVersion = semver.MustParse("1.0.0")
@ -32,10 +33,12 @@ type testCtx struct {
bridge *bridge.Bridge bridge *bridge.Bridge
// These channels hold events of various types coming from bridge. // These channels hold events of various types coming from bridge.
userLoginCh <-chan events.UserLoggedIn loginCh <-chan events.UserLoggedIn
userLogoutCh <-chan events.UserLoggedOut logoutCh <-chan events.UserLoggedOut
userDeletedCh <-chan events.UserDeleted deletedCh <-chan events.UserDeleted
userDeauthCh <-chan events.UserDeauth deauthCh <-chan events.UserDeauth
addrCreatedCh <-chan events.UserAddressCreated
addrDeletedCh <-chan events.UserAddressDeleted
syncStartedCh <-chan events.SyncStarted syncStartedCh <-chan events.SyncStarted
syncFinishedCh <-chan events.SyncFinished syncFinishedCh <-chan events.SyncFinished
forcedUpdateCh <-chan events.UpdateForced forcedUpdateCh <-chan events.UpdateForced
@ -43,10 +46,10 @@ type testCtx struct {
updateCh <-chan events.Event updateCh <-chan events.Event
// These maps hold expected userIDByName, their primary addresses and bridge passwords. // These maps hold expected userIDByName, their primary addresses and bridge passwords.
userIDByName map[string]string userIDByName map[string]string
userAddrByID map[string]string userAddrByEmail map[string]map[string]string
userPassByID map[string]string userPassByID map[string]string
addrIDByID map[string]string userBridgePassByID map[string]string
// These are the IMAP and SMTP clients used to connect to bridge. // These are the IMAP and SMTP clients used to connect to bridge.
imapClients map[string]*imapClient imapClients map[string]*imapClient
@ -83,10 +86,10 @@ func newTestCtx(tb testing.TB) *testCtx {
mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion), mocks: bridge.NewMocks(tb, dialer, defaultVersion, defaultVersion),
version: defaultVersion, version: defaultVersion,
userIDByName: make(map[string]string), userIDByName: make(map[string]string),
userAddrByID: make(map[string]string), userAddrByEmail: make(map[string]map[string]string),
userPassByID: make(map[string]string), userPassByID: make(map[string]string),
addrIDByID: make(map[string]string), userBridgePassByID: make(map[string]string),
imapClients: make(map[string]*imapClient), imapClients: make(map[string]*imapClient),
smtpClients: make(map[string]*smtpClient), smtpClients: make(map[string]*smtpClient),
@ -112,12 +115,28 @@ func (t *testCtx) setUserID(username, userID string) {
t.userIDByName[username] = userID t.userIDByName[username] = userID
} }
func (t *testCtx) getUserAddr(userID string) string { func (t *testCtx) getUserAddrID(userID, email string) string {
return t.userAddrByID[userID] return t.userAddrByEmail[userID][email]
} }
func (t *testCtx) setUserAddr(userID, addr string) { func (t *testCtx) getUserAddrs(userID string) []string {
t.userAddrByID[userID] = addr return maps.Keys(t.userAddrByEmail[userID])
}
func (t *testCtx) setUserAddr(userID, addrID, email string) {
if _, ok := t.userAddrByEmail[userID]; !ok {
t.userAddrByEmail[userID] = make(map[string]string)
}
t.userAddrByEmail[userID][email] = addrID
}
func (t *testCtx) unsetUserAddr(userID, wantAddrID string) {
for email, addrID := range t.userAddrByEmail[userID] {
if addrID == wantAddrID {
delete(t.userAddrByEmail[userID], email)
}
}
} }
func (t *testCtx) getUserPass(userID string) string { func (t *testCtx) getUserPass(userID string) string {
@ -128,12 +147,12 @@ func (t *testCtx) setUserPass(userID, pass string) {
t.userPassByID[userID] = pass t.userPassByID[userID] = pass
} }
func (t *testCtx) getAddrID(userID string) string { func (t *testCtx) getUserBridgePass(userID string) string {
return t.addrIDByID[userID] return t.userBridgePassByID[userID]
} }
func (t *testCtx) setAddrID(userID, addrID string) { func (t *testCtx) setUserBridgePass(userID, pass string) {
t.addrIDByID[userID] = addrID t.userBridgePassByID[userID] = pass
} }
func (t *testCtx) getMBoxID(userID string, name string) string { func (t *testCtx) getMBoxID(userID string, name string) string {

48
tests/fast.go Normal file
View File

@ -0,0 +1,48 @@
package tests
import (
"crypto/x509"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
)
var (
preCompPGPKey *crypto.Key
preCompCertPEM []byte
preCompKeyPEM []byte
)
func FastGenerateKey(name, email string, passphrase []byte, keyType string, bits int) (string, error) {
encKey, err := preCompPGPKey.Lock(passphrase)
if err != nil {
return "", err
}
return encKey.Armor()
}
func FastGenerateCert(template *x509.Certificate) ([]byte, []byte, error) {
return preCompCertPEM, preCompKeyPEM, nil
}
func init() {
key, err := crypto.GenerateKey("name", "email", "rsa", 1024)
if err != nil {
panic(err)
}
template, err := certs.NewTLSTemplate()
if err != nil {
panic(err)
}
certPEM, keyPEM, err := certs.GenerateCert(template)
if err != nil {
panic(err)
}
preCompPGPKey = key
preCompCertPEM = certPEM
preCompKeyPEM = keyPEM
}

View File

@ -4,7 +4,7 @@ Feature: IMAP get mailbox info
And the account "user@pm.me" has the following custom mailboxes: And the account "user@pm.me" has the following custom mailboxes:
| name | type | | name | type |
| one | folder | | one | folder |
And the account "user@pm.me" has the following messages in "one": And the address "user@pm.me" of account "user@pm.me" has the following messages in "one":
| sender | recipient | subject | unread | | sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true | | a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false | | b@pm.me | b@pm.me | two | false |

View File

@ -5,7 +5,7 @@ Feature: IMAP copy messages
| name | type | | name | type |
| mbox | folder | | mbox | folder |
| label | label | | label | label |
And the account "user@pm.me" has the following messages in "Inbox": And the address "user@pm.me" of account "user@pm.me" has the following messages in "Inbox":
| sender | recipient | subject | unread | | sender | recipient | subject | unread |
| john.doe@mail.com | user@pm.me | foo | false | | john.doe@mail.com | user@pm.me | foo | false |
| jane.doe@mail.com | name@pm.me | bar | true | | jane.doe@mail.com | name@pm.me | bar | true |

View File

@ -5,7 +5,7 @@ Feature: IMAP remove messages from mailbox
| name | type | | name | type |
| mbox | folder | | mbox | folder |
| label | label | | label | label |
And the account "user@pm.me" has 10 messages in "mbox" And the address "user@pm.me" of account "user@pm.me" has 10 messages in "mbox"
And bridge starts And bridge starts
And the user logs in with username "user@pm.me" and password "password" And the user logs in with username "user@pm.me" and password "password"
And user "user@pm.me" finishes syncing And user "user@pm.me" finishes syncing

View File

@ -0,0 +1,180 @@
Feature: Address mode
Background:
Given there exists an account with username "user@pm.me" and password "password"
And the account "user@pm.me" has additional address "alias@pm.me"
And the account "user@pm.me" has the following custom mailboxes:
| name | type |
| one | folder |
| two | folder |
And the address "user@pm.me" of account "user@pm.me" has the following messages in "one":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
And the address "alias@pm.me" of account "user@pm.me" has the following messages in "two":
| sender | recipient | subject | unread |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
And bridge starts
And the user logs in with username "user@pm.me" and password "password"
And user "user@pm.me" finishes syncing
Scenario: The user is in combined mode
When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me"
Then IMAP client "1" sees the following messages in "Folders/one":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
And IMAP client "1" sees the following messages in "Folders/two":
| sender | recipient | subject | unread |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
And IMAP client "1" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me"
Then IMAP client "2" sees the following messages in "Folders/one":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
And IMAP client "2" sees the following messages in "Folders/two":
| sender | recipient | subject | unread |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
And IMAP client "2" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Scenario: The user is in split mode
Given the user sets the address mode of "user@pm.me" to "split"
And user "user@pm.me" finishes syncing
When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me"
Then IMAP client "1" sees the following messages in "Folders/one":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
And IMAP client "1" sees 0 messages in "Folders/two"
And IMAP client "1" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me"
Then IMAP client "2" sees 0 messages in "Folders/one"
And IMAP client "2" sees the following messages in "Folders/two":
| sender | recipient | subject | unread |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
And IMAP client "2" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Scenario: The user switches from combined to split mode and back
Given the user sets the address mode of "user@pm.me" to "split"
And user "user@pm.me" finishes syncing
And the user sets the address mode of "user@pm.me" to "combined"
And user "user@pm.me" finishes syncing
When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me"
Then IMAP client "1" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me"
Then IMAP client "2" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Scenario: The user adds an address while in combined mode
When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me"
Then IMAP client "1" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me"
Then IMAP client "2" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Given the account "user@pm.me" has additional address "other@pm.me"
And bridge sends an address created event for user "user@pm.me"
When user "user@pm.me" connects and authenticates IMAP client "3" with address "other@pm.me"
Then IMAP client "3" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Scenario: The user adds an address while in split mode
Given the user sets the address mode of "user@pm.me" to "split"
And user "user@pm.me" finishes syncing
When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me"
And IMAP client "1" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me"
And IMAP client "2" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Given the account "user@pm.me" has additional address "other@pm.me"
And bridge sends an address created event for user "user@pm.me"
When user "user@pm.me" connects and authenticates IMAP client "3" with address "other@pm.me"
Then IMAP client "3" eventually sees 0 messages in "All Mail"
Scenario: The user deletes an address while in combined mode
When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me"
Then IMAP client "1" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me"
Then IMAP client "2" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Given the account "user@pm.me" no longer has additional address "alias@pm.me"
And bridge sends an address deleted event for user "user@pm.me"
When user "user@pm.me" connects IMAP client "3"
Then IMAP client "3" cannot authenticate with address "alias@pm.me"
Scenario: The user deletes an address while in split mode
Given the user sets the address mode of "user@pm.me" to "split"
And user "user@pm.me" finishes syncing
When user "user@pm.me" connects and authenticates IMAP client "1" with address "user@pm.me"
And IMAP client "1" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false |
When user "user@pm.me" connects and authenticates IMAP client "2" with address "alias@pm.me"
And IMAP client "2" sees the following messages in "All Mail":
| sender | recipient | subject | unread |
| c@pm.me | c@pm.me | three | true |
| d@pm.me | d@pm.me | four | false |
Given the account "user@pm.me" no longer has additional address "alias@pm.me"
And bridge sends an address deleted event for user "user@pm.me"
When user "user@pm.me" connects IMAP client "3"
Then IMAP client "3" cannot authenticate with address "alias@pm.me"
Scenario: The user makes an alias the primary address while in combined mode
Scenario: The user makes an alias the primary address while in split mode

View File

@ -6,11 +6,11 @@ Feature: Bridge can fully sync an account
| one | folder | | one | folder |
| two | folder | | two | folder |
| three | label | | three | label |
And the account "user@pm.me" has the following messages in "one": And the address "user@pm.me" of account "user@pm.me" has the following messages in "one":
| sender | recipient | subject | unread | | sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true | | a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false | | b@pm.me | b@pm.me | two | false |
And the account "user@pm.me" has the following messages in "two": And the address "user@pm.me" of account "user@pm.me" has the following messages in "two":
| sender | recipient | subject | unread | | sender | recipient | subject | unread |
| a@pm.me | a@pm.me | one | true | | a@pm.me | a@pm.me | one | true |
| b@pm.me | b@pm.me | two | false | | b@pm.me | b@pm.me | two | false |

View File

@ -26,25 +26,39 @@ func (s *scenario) userConnectsIMAPClientOnPort(username, clientID string, port
} }
func (s *scenario) userConnectsAndAuthenticatesIMAPClient(username, clientID string) error { func (s *scenario) userConnectsAndAuthenticatesIMAPClient(username, clientID string) error {
return s.userConnectsAndAuthenticatesIMAPClientWithAddress(username, clientID, s.t.getUserAddrs(s.t.getUserID(username))[0])
}
func (s *scenario) userConnectsAndAuthenticatesIMAPClientWithAddress(username, clientID, address string) error {
if err := s.t.newIMAPClient(s.t.getUserID(username), clientID); err != nil { if err := s.t.newIMAPClient(s.t.getUserID(username), clientID); err != nil {
return err return err
} }
userID, client := s.t.getIMAPClient(clientID) userID, client := s.t.getIMAPClient(clientID)
return client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)) return client.Login(address, s.t.getUserBridgePass(userID))
} }
func (s *scenario) imapClientCanAuthenticate(clientID string) error { func (s *scenario) imapClientCanAuthenticate(clientID string) error {
userID, client := s.t.getIMAPClient(clientID) userID, client := s.t.getIMAPClient(clientID)
return client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)) return client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID))
} }
func (s *scenario) imapClientCannotAuthenticate(clientID string) error { func (s *scenario) imapClientCannotAuthenticate(clientID string) error {
userID, client := s.t.getIMAPClient(clientID) userID, client := s.t.getIMAPClient(clientID)
if err := client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)); err == nil { if err := client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)); err == nil {
return fmt.Errorf("expected error, got nil")
}
return nil
}
func (s *scenario) imapClientCannotAuthenticateWithAddress(clientID, address string) error {
userID, client := s.t.getIMAPClient(clientID)
if err := client.Login(address, s.t.getUserBridgePass(userID)); err == nil {
return fmt.Errorf("expected error, got nil") return fmt.Errorf("expected error, got nil")
} }
@ -54,7 +68,7 @@ func (s *scenario) imapClientCannotAuthenticate(clientID string) error {
func (s *scenario) imapClientCannotAuthenticateWithIncorrectUsername(clientID string) error { func (s *scenario) imapClientCannotAuthenticateWithIncorrectUsername(clientID string) error {
userID, client := s.t.getIMAPClient(clientID) userID, client := s.t.getIMAPClient(clientID)
if err := client.Login(s.t.getUserAddr(userID)+"bad", s.t.getUserPass(userID)); err == nil { if err := client.Login(s.t.getUserAddrs(userID)[0]+"bad", s.t.getUserBridgePass(userID)); err == nil {
return fmt.Errorf("expected error, got nil") return fmt.Errorf("expected error, got nil")
} }
@ -64,7 +78,7 @@ func (s *scenario) imapClientCannotAuthenticateWithIncorrectUsername(clientID st
func (s *scenario) imapClientCannotAuthenticateWithIncorrectPassword(clientID string) error { func (s *scenario) imapClientCannotAuthenticateWithIncorrectPassword(clientID string) error {
userID, client := s.t.getIMAPClient(clientID) userID, client := s.t.getIMAPClient(clientID)
if err := client.Login(s.t.getUserAddr(userID), s.t.getUserPass(userID)+"bad"); err == nil { if err := client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)+"bad"); err == nil {
return fmt.Errorf("expected error, got nil") return fmt.Errorf("expected error, got nil")
} }

View File

@ -1,39 +1,14 @@
package tests package tests
import ( import (
"crypto/x509"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/ProtonMail/proton-bridge/v2/internal/certs"
"gitlab.protontech.ch/go/liteapi/server/account" "gitlab.protontech.ch/go/liteapi/server/account"
) )
func init() { func init() {
key, err := crypto.GenerateKey("name", "email", "rsa", 1024) // Use the fast key generation for tests.
if err != nil { account.GenerateKey = FastGenerateKey
panic(err)
}
account.GenerateKey = func(name, email string, passphrase []byte, keyType string, bits int) (string, error) { // Use the fast cert generation for tests.
encKey, err := key.Lock(passphrase) certs.GenerateCert = FastGenerateCert
if err != nil {
return "", err
}
return encKey.Armor()
}
template, err := certs.NewTLSTemplate()
if err != nil {
panic(err)
}
certPEM, keyPEM, err := certs.GenerateCert(template)
if err != nil {
panic(err)
}
certs.GenerateCert = func(template *x509.Certificate) ([]byte, []byte, error) {
return certPEM, keyPEM, nil
}
} }

View File

@ -16,13 +16,17 @@ func (s *scenario) userConnectsSMTPClientOnPort(username, clientID string, port
} }
func (s *scenario) userConnectsAndAuthenticatesSMTPClient(username, clientID string) error { func (s *scenario) userConnectsAndAuthenticatesSMTPClient(username, clientID string) error {
return s.userConnectsAndAuthenticatesSMTPClientWithAddress(username, clientID, s.t.getUserAddrs(s.t.getUserID(username))[0])
}
func (s *scenario) userConnectsAndAuthenticatesSMTPClientWithAddress(username, clientID, address string) error {
if err := s.t.newSMTPClient(s.t.getUserID(username), clientID); err != nil { if err := s.t.newSMTPClient(s.t.getUserID(username), clientID); err != nil {
return err return err
} }
userID, client := s.t.getSMTPClient(clientID) userID, client := s.t.getSMTPClient(clientID)
s.t.pushError(client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID), constants.Host))) s.t.pushError(client.Auth(smtp.PlainAuth("", address, s.t.getUserBridgePass(userID), constants.Host)))
return nil return nil
} }
@ -30,7 +34,7 @@ func (s *scenario) userConnectsAndAuthenticatesSMTPClient(username, clientID str
func (s *scenario) smtpClientCanAuthenticate(clientID string) error { func (s *scenario) smtpClientCanAuthenticate(clientID string) error {
userID, client := s.t.getSMTPClient(clientID) userID, client := s.t.getSMTPClient(clientID)
if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID), constants.Host)); err != nil { if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID), constants.Host)); err != nil {
return fmt.Errorf("expected no error, got %v", err) return fmt.Errorf("expected no error, got %v", err)
} }
@ -40,7 +44,7 @@ func (s *scenario) smtpClientCanAuthenticate(clientID string) error {
func (s *scenario) smtpClientCannotAuthenticate(clientID string) error { func (s *scenario) smtpClientCannotAuthenticate(clientID string) error {
userID, client := s.t.getSMTPClient(clientID) userID, client := s.t.getSMTPClient(clientID)
if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID), constants.Host)); err == nil { if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID), constants.Host)); err == nil {
return fmt.Errorf("expected error, got nil") return fmt.Errorf("expected error, got nil")
} }
@ -50,7 +54,7 @@ func (s *scenario) smtpClientCannotAuthenticate(clientID string) error {
func (s *scenario) smtpClientCannotAuthenticateWithIncorrectUsername(clientID string) error { func (s *scenario) smtpClientCannotAuthenticateWithIncorrectUsername(clientID string) error {
userID, client := s.t.getSMTPClient(clientID) userID, client := s.t.getSMTPClient(clientID)
if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID)+"bad", s.t.getUserPass(userID), constants.Host)); err == nil { if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0]+"bad", s.t.getUserBridgePass(userID), constants.Host)); err == nil {
return fmt.Errorf("expected error, got nil") return fmt.Errorf("expected error, got nil")
} }
@ -60,7 +64,7 @@ func (s *scenario) smtpClientCannotAuthenticateWithIncorrectUsername(clientID st
func (s *scenario) smtpClientCannotAuthenticateWithIncorrectPassword(clientID string) error { func (s *scenario) smtpClientCannotAuthenticateWithIncorrectPassword(clientID string) error {
userID, client := s.t.getSMTPClient(clientID) userID, client := s.t.getSMTPClient(clientID)
if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddr(userID), s.t.getUserPass(userID)+"bad", constants.Host)); err == nil { if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)+"bad", constants.Host)); err == nil {
return fmt.Errorf("expected error, got nil") return fmt.Errorf("expected error, got nil")
} }

View File

@ -16,6 +16,7 @@ import (
) )
func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error { func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error {
// Create the user.
userID, addrID, err := s.t.api.AddUser(username, password, username) userID, addrID, err := s.t.api.AddUser(username, password, username)
if err != nil { if err != nil {
return err return err
@ -24,11 +25,37 @@ func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, passwor
// Set the ID of this user. // Set the ID of this user.
s.t.setUserID(username, userID) s.t.setUserID(username, userID)
// Set the address ID of this user. // Set the password of this user.
s.t.setAddrID(userID, addrID) s.t.setUserPass(userID, password)
// Set the address of this user (right now just the same as the username, but let's stay flexible). // Set the address of this user (right now just the same as the username, but let's stay flexible).
s.t.setUserAddr(userID, username) s.t.setUserAddr(userID, addrID, username)
return nil
}
func (s *scenario) theAccountHasAdditionalAddress(username, address string) error {
userID := s.t.getUserID(username)
addrID, err := s.t.api.AddAddress(userID, address, s.t.getUserPass(userID))
if err != nil {
return err
}
s.t.setUserAddr(userID, addrID, address)
return nil
}
func (s *scenario) theAccountNoLongerHasAdditionalAddress(username, address string) error {
userID := s.t.getUserID(username)
addrID := s.t.getUserAddrID(userID, address)
if err := s.t.api.RemoveAddress(userID, addrID); err != nil {
return err
}
s.t.unsetUserAddr(userID, addrID)
return nil return nil
} }
@ -84,9 +111,9 @@ func (s *scenario) theAccountHasTheFollowingCustomMailboxes(username string, tab
return nil return nil
} }
func (s *scenario) theAccountHasTheFollowingMessagesInMailbox(username, mailbox string, table *godog.Table) error { func (s *scenario) theAddressOfAccountHasTheFollowingMessagesInMailbox(address, username, mailbox string, table *godog.Table) error {
userID := s.t.getUserID(username) userID := s.t.getUserID(username)
addrID := s.t.getAddrID(userID) addrID := s.t.getUserAddrID(userID, address)
mboxID := s.t.getMBoxID(userID, mailbox) mboxID := s.t.getMBoxID(userID, mailbox)
for _, wantMessage := range parseMessages(table) { for _, wantMessage := range parseMessages(table) {
@ -109,9 +136,9 @@ func (s *scenario) theAccountHasTheFollowingMessagesInMailbox(username, mailbox
return nil return nil
} }
func (s *scenario) theAccountHasMessagesInMailbox(username string, count int, mailbox string) error { func (s *scenario) theAddressOfAccountHasMessagesInMailbox(address, username string, count int, mailbox string) error {
userID := s.t.getUserID(username) userID := s.t.getUserID(username)
addrID := s.t.getAddrID(userID) addrID := s.t.getUserAddrID(userID, address)
mboxID := s.t.getMBoxID(userID, mailbox) mboxID := s.t.getMBoxID(userID, mailbox)
for idx := 0; idx < count; idx++ { for idx := 0; idx < count; idx++ {
@ -148,7 +175,7 @@ func (s *scenario) userLogsInWithUsernameAndPassword(username, password string)
return err return err
} }
s.t.setUserPass(userID, info.BridgePass) s.t.setUserBridgePass(userID, info.BridgePass)
} }
return nil return nil