Other(refactor): Less unwieldy user type in Bridge

Instead of the annoying safe.Map type, we just use a normal go map
and mutex pair, and use safe.Lock/safe.RLock as a helper function
This commit is contained in:
James Houlahan
2022-10-26 23:04:25 +02:00
parent 85c0d6f837
commit 2bda47fcad
11 changed files with 458 additions and 227 deletions

View File

@ -51,8 +51,8 @@ type Bridge struct {
vault *vault.Vault vault *vault.Vault
// users holds authorized users. // users holds authorized users.
users *safe.Map[string, *user.User] users map[string]*user.User
goLoad func() usersLock sync.RWMutex
// api manages user API clients. // api manages user API clients.
api *liteapi.Manager api *liteapi.Manager
@ -73,7 +73,6 @@ type Bridge struct {
// updater is the bridge's updater. // updater is the bridge's updater.
updater Updater updater Updater
goUpdate func()
curVersion *semver.Version curVersion *semver.Version
// focusService is used to raise the bridge window when needed. // focusService is used to raise the bridge window when needed.
@ -99,6 +98,12 @@ type Bridge struct {
// tasks manages the bridge's goroutines. // tasks manages the bridge's goroutines.
tasks *xsync.Group tasks *xsync.Group
// goLoad triggers a load of disconnected users from the vault.
goLoad func()
// goUpdate triggers a check/install of updates.
goUpdate func()
} }
// New creates a new bridge. // New creates a new bridge.
@ -133,12 +138,8 @@ func New( //nolint:funlen
// imapEventCh forwards IMAP events from gluon instances to the bridge for processing. // imapEventCh forwards IMAP events from gluon instances to the bridge for processing.
imapEventCh := make(chan imapEvents.Event) imapEventCh := make(chan imapEvents.Event)
// users holds all the bridge's users.
users := safe.NewMap[string, *user.User](nil)
// bridge is the bridge. // bridge is the bridge.
bridge, err := newBridge( bridge, err := newBridge(
users,
tasks, tasks,
imapEventCh, imapEventCh,
@ -180,7 +181,6 @@ func New( //nolint:funlen
// nolint:funlen // nolint:funlen
func newBridge( func newBridge(
users *safe.Map[string, *user.User],
tasks *xsync.Group, tasks *xsync.Group,
imapEventCh chan imapEvents.Event, imapEventCh chan imapEvents.Event,
@ -224,9 +224,9 @@ func newBridge(
return nil, fmt.Errorf("failed to create focus service: %w", err) return nil, fmt.Errorf("failed to create focus service: %w", err)
} }
return &Bridge{ bridge := &Bridge{
vault: vault, vault: vault,
users: users, users: make(map[string]*user.User),
api: api, api: api,
proxyCtl: proxyCtl, proxyCtl: proxyCtl,
@ -235,7 +235,6 @@ func newBridge(
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
imapServer: imapServer, imapServer: imapServer,
imapEventCh: imapEventCh, imapEventCh: imapEventCh,
smtpServer: newSMTPServer(users, tlsConfig, logSMTP),
updater: updater, updater: updater,
curVersion: curVersion, curVersion: curVersion,
@ -249,7 +248,11 @@ func newBridge(
logSMTP: logSMTP, logSMTP: logSMTP,
tasks: tasks, tasks: tasks,
}, nil }
bridge.smtpServer = newSMTPServer(bridge, tlsConfig, logSMTP)
return bridge, nil
} }
// nolint:funlen // nolint:funlen
@ -265,10 +268,10 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
bridge.api.AddStatusObserver(func(status liteapi.Status) { bridge.api.AddStatusObserver(func(status liteapi.Status) {
switch { switch {
case status == liteapi.StatusUp: case status == liteapi.StatusUp:
bridge.onStatusUp() go bridge.onStatusUp()
case status == liteapi.StatusDown: case status == liteapi.StatusDown:
bridge.onStatusDown() go bridge.onStatusDown()
} }
}) })
@ -356,9 +359,11 @@ func (bridge *Bridge) Close(ctx context.Context) error {
} }
// Close all users. // Close all users.
bridge.users.IterValues(func(user *user.User) { safe.RLock(func() {
user.Close() for _, user := range bridge.users {
}) user.Close()
}
}, &bridge.usersLock)
// Stop all ongoing tasks. // Stop all ongoing tasks.
bridge.tasks.Wait() bridge.tasks.Wait()
@ -426,19 +431,23 @@ func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) {
func (bridge *Bridge) onStatusUp() { func (bridge *Bridge) onStatusUp() {
bridge.publish(events.ConnStatusUp{}) bridge.publish(events.ConnStatusUp{})
bridge.goLoad() safe.RLock(func() {
for _, user := range bridge.users {
user.OnStatusUp()
}
}, &bridge.usersLock)
bridge.users.IterValues(func(user *user.User) { bridge.goLoad()
go user.OnStatusUp()
})
} }
func (bridge *Bridge) onStatusDown() { func (bridge *Bridge) onStatusDown() {
bridge.publish(events.ConnStatusDown{}) bridge.publish(events.ConnStatusDown{})
bridge.users.IterValues(func(user *user.User) { safe.RLock(func() {
go user.OnStatusDown() for _, user := range bridge.users {
}) user.OnStatusDown()
}
}, &bridge.usersLock)
bridge.tasks.Once(func(ctx context.Context) { bridge.tasks.Once(func(ctx context.Context) {
backoff := time.Second backoff := time.Second

View File

@ -18,18 +18,24 @@
package bridge package bridge
import ( import (
"fmt"
"strings" "strings"
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig" "github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
) )
// ConfigureAppleMail configures apple mail for the given userID and address.
// If configuring apple mail for Catalina or newer, it ensures Bridge is using SSL.
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error { return safe.RLockRet(func() error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if address == "" { if address == "" {
address = user.Emails()[0] address = user.Emails()[0]
} }
@ -42,7 +48,6 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
addresses = strings.Join(user.Emails(), ",") addresses = strings.Join(user.Emails(), ",")
} }
// If configuring apple mail for Catalina or newer, users should use SSL.
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() { if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
if err := bridge.SetSMTPSSL(true); err != nil { if err := bridge.SetSMTPSSL(true); err != nil {
return err return err
@ -59,11 +64,5 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
addresses, addresses,
user.BridgePass(), user.BridgePass(),
) )
}); !ok { }, &bridge.usersLock)
return ErrNoSuchUser
} else if err != nil {
return fmt.Errorf("failed to configure apple mail: %w", err)
}
return nil
} }

View File

@ -32,6 +32,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/async"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/logging" "github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xsync" "github.com/bradenaw/juniper/xsync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -83,6 +84,44 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error {
return nil return nil
} }
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
}
for addrID, imapConn := range imapConn {
if gluonID, ok := user.GetGluonID(addrID); ok {
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
} else {
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
}
}
return nil
}
// removeIMAPUser disconnects the given user from gluon, optionally also removing its files.
func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withFiles bool) error {
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
}
return nil
}
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) { func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
switch event := event.(type) { switch event := event.(type) {
case imapEvents.SessionAdded: case imapEvents.SessionAdded:

View File

@ -24,8 +24,8 @@ import (
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain" "github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -118,48 +118,50 @@ func (bridge *Bridge) GetGluonDir() string {
} }
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error { func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
if newGluonDir == bridge.GetGluonDir() { return safe.RLockRet(func() error {
return fmt.Errorf("new gluon dir is the same as the old one") if newGluonDir == bridge.GetGluonDir() {
} return fmt.Errorf("new gluon dir is the same as the old one")
}
if err := bridge.closeIMAP(context.Background()); err != nil { if err := bridge.closeIMAP(context.Background()); err != nil {
return fmt.Errorf("failed to close IMAP: %w", err) return fmt.Errorf("failed to close IMAP: %w", err)
} }
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil { if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
return fmt.Errorf("failed to move gluon dir: %w", err) return fmt.Errorf("failed to move gluon dir: %w", err)
} }
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil { if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
return fmt.Errorf("failed to set new gluon dir: %w", err) return fmt.Errorf("failed to set new gluon dir: %w", err)
} }
imapServer, err := newIMAPServer( imapServer, err := newIMAPServer(
bridge.vault.GetGluonDir(), bridge.vault.GetGluonDir(),
bridge.curVersion, bridge.curVersion,
bridge.tlsConfig, bridge.tlsConfig,
bridge.logIMAPClient, bridge.logIMAPClient,
bridge.logIMAPServer, bridge.logIMAPServer,
bridge.imapEventCh, bridge.imapEventCh,
bridge.tasks, bridge.tasks,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to create new IMAP server: %w", err) return fmt.Errorf("failed to create new IMAP server: %w", err)
} }
bridge.imapServer = imapServer bridge.imapServer = imapServer
if err := bridge.users.IterValuesErr(func(user *user.User) error { for _, user := range bridge.users {
return bridge.addIMAPUser(ctx, user) if err := bridge.addIMAPUser(ctx, user); err != nil {
}); err != nil { return fmt.Errorf("failed to add users to new IMAP server: %w", err)
return fmt.Errorf("failed to add users to new IMAP server: %w", err) }
} }
if err := bridge.serveIMAP(); err != nil { if err := bridge.serveIMAP(); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err) return fmt.Errorf("failed to serve IMAP: %w", err)
} }
return nil return nil
}, &bridge.usersLock)
} }
func (bridge *Bridge) GetProxyAllowed() bool { func (bridge *Bridge) GetProxyAllowed() bool {
@ -181,11 +183,13 @@ func (bridge *Bridge) GetShowAllMail() bool {
} }
func (bridge *Bridge) SetShowAllMail(show bool) error { func (bridge *Bridge) SetShowAllMail(show bool) error {
bridge.users.IterValues(func(user *user.User) { return safe.RLockRet(func() error {
user.SetShowAllMail(show) for _, user := range bridge.users {
}) user.SetShowAllMail(show)
}
return bridge.vault.SetShowAllMail(show) return bridge.vault.SetShowAllMail(show)
}, &bridge.usersLock)
} }
func (bridge *Bridge) GetAutostart() bool { func (bridge *Bridge) GetAutostart() bool {
@ -273,14 +277,18 @@ func (bridge *Bridge) SetColorScheme(colorScheme string) error {
} }
func (bridge *Bridge) FactoryReset(ctx context.Context) { func (bridge *Bridge) FactoryReset(ctx context.Context) {
// First delete all users. // Delete all the users.
for _, userID := range bridge.GetUserIDs() { safe.Lock(func() {
if bridge.users.Has(userID) { for _, user := range bridge.users {
if err := bridge.DeleteUser(ctx, userID); err != nil { bridge.logoutUser(ctx, user, true)
logrus.WithError(err).Errorf("Failed to delete user %s", userID) }
for _, user := range bridge.vault.GetUserIDs() {
if err := bridge.vault.DeleteUser(user); err != nil {
logrus.WithError(err).Error("failed to delete vault user")
} }
} }
} }, &bridge.usersLock)
// Then delete all files. // Then delete all files.
if err := bridge.locator.Clear(); err != nil { if err := bridge.locator.Clear(); err != nil {

View File

@ -23,8 +23,6 @@ import (
"fmt" "fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/logging" "github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
@ -57,7 +55,7 @@ func (bridge *Bridge) restartSMTP() error {
return fmt.Errorf("failed to close SMTP: %w", err) return fmt.Errorf("failed to close SMTP: %w", err)
} }
bridge.smtpServer = newSMTPServer(bridge.users, bridge.tlsConfig, bridge.logSMTP) bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
return bridge.serveSMTP() return bridge.serveSMTP()
} }
@ -80,8 +78,8 @@ func (bridge *Bridge) closeSMTP() error {
return nil return nil
} }
func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, shouldLog bool) *smtp.Server { func newSMTPServer(bridge *Bridge, tlsConfig *tls.Config, shouldLog bool) *smtp.Server {
smtpServer := smtp.NewServer(&smtpBackend{users}) smtpServer := smtp.NewServer(&smtpBackend{Bridge: bridge})
smtpServer.TLSConfig = tlsConfig smtpServer.TLSConfig = tlsConfig
smtpServer.Domain = constants.Host smtpServer.Domain = constants.Host
@ -94,6 +92,7 @@ func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, s
log.Warning("================================================") log.Warning("================================================")
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA") log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
log.Warning("================================================") log.Warning("================================================")
smtpServer.Debug = logging.NewSMTPDebugLogger() smtpServer.Debug = logging.NewSMTPDebugLogger()
} }

View File

@ -22,16 +22,15 @@ import (
"io" "io"
"github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
) )
type smtpBackend struct { type smtpBackend struct {
users *safe.Map[string, *user.User] *Bridge
} }
type smtpSession struct { type smtpSession struct {
users *safe.Map[string, *user.User] *Bridge
userID string userID string
authID string authID string
@ -40,15 +39,13 @@ type smtpSession struct {
to []string to []string
} }
func (be *smtpBackend) NewSession(_ *smtp.Conn) (smtp.Session, error) { func (be *smtpBackend) NewSession(*smtp.Conn) (smtp.Session, error) {
return &smtpSession{ return &smtpSession{Bridge: be.Bridge}, nil
users: be.users,
}, nil
} }
func (s *smtpSession) AuthPlain(username, password string) error { func (s *smtpSession) AuthPlain(username, password string) error {
return s.users.ValuesErr(func(users []*user.User) error { return safe.RLockRet(func() error {
for _, user := range users { for _, user := range s.users {
addrID, err := user.CheckAuth(username, []byte(password)) addrID, err := user.CheckAuth(username, []byte(password))
if err != nil { if err != nil {
continue continue
@ -61,7 +58,7 @@ func (s *smtpSession) AuthPlain(username, password string) error {
} }
return fmt.Errorf("invalid username or password") return fmt.Errorf("invalid username or password")
}) }, &s.usersLock)
} }
func (s *smtpSession) Reset() { func (s *smtpSession) Reset() {
@ -88,13 +85,12 @@ func (s *smtpSession) Rcpt(to string) error {
} }
func (s *smtpSession) Data(r io.Reader) error { func (s *smtpSession) Data(r io.Reader) error {
if ok, err := s.users.GetErr(s.userID, func(user *user.User) error { return safe.RLockRet(func() error {
return user.SendMail(s.authID, s.from, s.to, r) user, ok := s.users[s.userID]
}); !ok { if !ok {
return fmt.Errorf("no such user %q", s.userID) return ErrNoSuchUser
} else if err != nil { }
return fmt.Errorf("failed to send mail: %w", err)
}
return nil return user.SendMail(s.authID, s.from, s.to, r)
}, &s.usersLock)
} }

View File

@ -66,32 +66,34 @@ func (bridge *Bridge) GetUserIDs() []string {
// GetUserInfo returns info about the given user. // GetUserInfo returns info about the given user.
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) { func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
if info, ok := safe.MapGetRet(bridge.users, userID, getConnUserInfo); ok { return safe.RLockRetErr(func() (UserInfo, error) {
if user, ok := bridge.users[userID]; ok {
return getConnUserInfo(user), nil
}
var info UserInfo
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
}); err != nil {
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
}
return info, nil return info, nil
} }, &bridge.usersLock)
var info UserInfo
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
}); err != nil {
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
}
return info, nil
} }
// QueryUserInfo queries the user info by username or address. // QueryUserInfo queries the user info by username or address.
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) { func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) { return safe.RLockRetErr(func() (UserInfo, error) {
for _, user := range users { for _, user := range bridge.users {
if user.Match(query) { if user.Match(query) {
return getConnUserInfo(user), nil return getConnUserInfo(user), nil
} }
} }
return UserInfo{}, ErrNoSuchUser return UserInfo{}, ErrNoSuchUser
}) }, &bridge.usersLock)
} }
// LoginAuth begins the login process. It returns an authorized client that might need 2FA. // LoginAuth begins the login process. It returns an authorized client that might need 2FA.
@ -101,7 +103,9 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err) return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
} }
if bridge.users.Has(auth.UserID) { if ok := safe.RLockRet(func() bool {
return mapHas(bridge.users, auth.UID)
}, &bridge.usersLock); ok {
if err := client.AuthDelete(ctx); err != nil { if err := client.AuthDelete(ctx); err != nil {
logrus.WithError(err).Warn("Failed to delete auth") logrus.WithError(err).Warn("Failed to delete auth")
} }
@ -182,31 +186,56 @@ func (bridge *Bridge) LoginFull(
// LogoutUser logs out the given user. // LogoutUser logs out the given user.
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error { func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
if err := bridge.logoutUser(ctx, userID); err != nil { return safe.LockRet(func() error {
return fmt.Errorf("failed to logout user: %w", err) user, ok := bridge.users[userID]
} if !ok {
return ErrNoSuchUser
}
bridge.publish(events.UserLoggedOut{ defer delete(bridge.users, user.ID())
UserID: userID,
})
return nil bridge.logoutUser(ctx, user, true)
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
}, &bridge.usersLock)
} }
// DeleteUser deletes the given user. // DeleteUser deletes the given user.
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
bridge.deleteUser(ctx, userID) return safe.LockRet(func() error {
if !bridge.vault.HasUser(userID) {
return ErrNoSuchUser
}
bridge.publish(events.UserDeleted{ if user, ok := bridge.users[userID]; ok {
UserID: userID, defer delete(bridge.users, user.ID())
}) bridge.logoutUser(ctx, user, true)
}
return nil if err := bridge.vault.DeleteUser(userID); err != nil {
logrus.WithError(err).Error("Failed to delete vault user")
}
bridge.publish(events.UserDeleted{
UserID: userID,
})
return nil
}, &bridge.usersLock)
} }
// SetAddressMode sets the address mode for the given user. // SetAddressMode sets the address mode for the given user.
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error { func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error { return safe.RLockRet(func() error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if user.GetAddressMode() == mode { if user.GetAddressMode() == mode {
return fmt.Errorf("address mode is already %q", mode) return fmt.Errorf("address mode is already %q", mode)
} }
@ -231,13 +260,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
}) })
return nil return nil
}); !ok { }, &bridge.usersLock)
return ErrNoSuchUser
} else if err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
return nil
} }
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) { func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
@ -266,7 +289,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
// loadUsers tries to load each user in the vault that isn't already loaded. // loadUsers tries to load each user in the vault that isn't already loaded.
func (bridge *Bridge) loadUsers(ctx context.Context) error { func (bridge *Bridge) loadUsers(ctx context.Context) error {
return bridge.vault.ForUser(func(user *vault.User) error { return bridge.vault.ForUser(func(user *vault.User) error {
if bridge.users.Has(user.UserID()) || user.AuthUID() == "" { if user.AuthUID() == "" {
return nil
}
if safe.RLockRet(func() bool {
return mapHas(bridge.users, user.UserID())
}, &bridge.usersLock) {
return nil return nil
} }
@ -293,20 +322,16 @@ func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
return fmt.Errorf("failed to set auth: %w", err) return fmt.Errorf("failed to set auth: %w", err)
} }
return try.Catch( apiUser, err := client.GetUser(ctx)
func() error { if err != nil {
apiUser, err := client.GetUser(ctx) return fmt.Errorf("failed to get user: %w", err)
if err != nil { }
return fmt.Errorf("failed to get user: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil { if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
return fmt.Errorf("failed to add user: %w", err) return fmt.Errorf("failed to add user: %w", err)
} }
return nil return nil
},
)
} }
// addUser adds a new user with an already salted mailbox password. // addUser adds a new user with an already salted mailbox password.
@ -364,10 +389,6 @@ func (bridge *Bridge) addUserWithVault(
return fmt.Errorf("failed to create user: %w", err) return fmt.Errorf("failed to create user: %w", err)
} }
if had := bridge.users.Set(apiUser.ID, user); had {
panic("double add")
}
// Connect the user's address(es) to gluon. // Connect the user's address(es) to gluon.
if err := bridge.addIMAPUser(ctx, user); err != nil { if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err) return fmt.Errorf("failed to add IMAP user: %w", err)
@ -395,6 +416,11 @@ func (bridge *Bridge) addUserWithVault(
return nil return nil
}) })
// Finally, save the user in the bridge.
safe.Lock(func() {
bridge.users[apiUser.ID] = user
}, &bridge.usersLock)
return nil return nil
} }
@ -430,75 +456,17 @@ func (bridge *Bridge) newVaultUser(
return user, false, nil return user, false, nil
} }
// addIMAPUser connects the given user to gluon. // logout logs out the given user, optionally logging them out from the API too.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI bool) {
imapConn, err := user.NewIMAPConnectors() if err := bridge.removeIMAPUser(ctx, user, false); err != nil {
if err != nil { logrus.WithError(err).Error("Failed to remove IMAP user")
return fmt.Errorf("failed to create IMAP connectors: %w", err)
} }
for addrID, imapConn := range imapConn { if err := user.Logout(ctx, withAPI); err != nil {
if gluonID, ok := user.GetGluonID(addrID); ok { logrus.WithError(err).Error("Failed to logout user")
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
} else {
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
}
} }
return nil user.Close()
}
// logoutUser logs the given user out from bridge.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
user.Close()
}); !ok {
return ErrNoSuchUser
}
return nil
}
// deleteUser deletes the given user from bridge.
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
user.Close()
}); !ok {
logrus.Debug("The bridge user was not connected")
}
if err := bridge.vault.DeleteUser(userID); err != nil {
logrus.WithError(err).Error("Failed to delete user from vault")
}
} }
// getUserInfo returns information about a disconnected user. // getUserInfo returns information about a disconnected user.
@ -523,3 +491,8 @@ func getConnUserInfo(user *user.User) UserInfo {
MaxSpace: user.MaxSpace(), MaxSpace: user.MaxSpace(),
} }
} }
func mapHas[Key comparable, Val any](m map[Key]Val, key Key) bool {
_, ok := m[key]
return ok
}

View File

@ -22,6 +22,7 @@ import (
"fmt" "fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
) )
@ -44,9 +45,11 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
} }
case events.UserDeauth: case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil { safe.Lock(func() {
return fmt.Errorf("failed to logout user: %w", err) defer delete(bridge.users, user.ID())
}
bridge.logoutUser(ctx, user, false)
}, &bridge.usersLock)
} }
return nil return nil

View File

@ -147,6 +147,46 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
}) })
} }
func TestBridge_LoginDeauthRestartLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginFull(ctx, username, password, nil, nil))
// Get a channel to receive the deauth event.
eventCh, done := bridge.GetEvents(events.UserDeauth{})
defer done()
// Deauth the user.
require.NoError(t, s.RevokeUser(userID))
// The user is eventually disconnected.
require.Eventually(t, func() bool {
return len(getConnectedUserIDs(t, bridge)) == 0
}, 10*time.Second, time.Second)
// We should get a deauth event.
require.IsType(t, events.UserDeauth{}, <-eventCh)
})
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user should be disconnected at startup.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Login the user after the disconnection.
newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginExpireLogin(t *testing.T) { func TestBridge_LoginExpireLogin(t *testing.T) {
const authLife = 2 * time.Second const authLife = 2 * time.Second
@ -449,6 +489,82 @@ func TestBridge_LoginLogoutRepeated(t *testing.T) {
}) })
} }
func TestBridge_LogoutOffline(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Go offline.
netCtl.Disable()
// We can still log the user out.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// The user is now disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_DeleteDisconnected(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Logout the user.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// The user is now disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Delete the user.
require.NoError(t, bridge.DeleteUser(ctx, userID))
// The user is now deleted.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_DeleteOffline(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Go offline.
netCtl.Disable()
// We can still log the user out.
require.NoError(t, bridge.DeleteUser(ctx, userID))
// The user is now gone.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
// getErr returns the error that was passed to it. // getErr returns the error that was passed to it.
func getErr[T any](val T, err error) error { func getErr[T any](val T, err error) error {
return err return err

87
internal/safe/mutex.go Normal file
View File

@ -0,0 +1,87 @@
package safe
type Mutex interface {
Lock()
Unlock()
}
func Lock(fn func(), m ...Mutex) {
if len(m) == 0 {
panic("no mutexes provided")
}
for _, m := range m {
m.Lock()
defer m.Unlock()
}
fn()
}
func LockRet[T any](fn func() T, m ...Mutex) T {
var ret T
Lock(func() {
ret = fn()
}, m...)
return ret
}
func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
var ret T
err := LockRet(func() error {
var err error
ret, err = fn()
return err
}, m...)
return ret, err
}
type RWMutex interface {
Mutex
RLock()
RUnlock()
}
func RLock(fn func(), m ...RWMutex) {
if len(m) == 0 {
panic("no mutexes provided")
}
for _, m := range m {
m.RLock()
defer m.RUnlock()
}
fn()
}
func RLockRet[T any](fn func() T, m ...RWMutex) T {
var ret T
RLock(func() {
ret = fn()
}, m...)
return ret
}
func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) {
var err error
ret := RLockRet(func() T {
var ret T
ret, err = fn()
return ret
}, m...)
return ret, err
}

View File

@ -402,11 +402,13 @@ func (user *User) OnStatusDown() {
} }
// Logout logs the user out from the API. // Logout logs the user out from the API.
func (user *User) Logout(ctx context.Context) error { func (user *User) Logout(ctx context.Context, withAPI bool) error {
user.tasks.Wait() user.tasks.Wait()
if err := user.client.AuthDelete(ctx); err != nil { if withAPI {
return fmt.Errorf("failed to delete auth: %w", err) if err := user.client.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
}
} }
if err := user.vault.Clear(); err != nil { if err := user.vault.Clear(); err != nil {