diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 49306d99..87ce3efa 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -51,8 +51,8 @@ type Bridge struct { vault *vault.Vault // users holds authorized users. - users *safe.Map[string, *user.User] - goLoad func() + users map[string]*user.User + usersLock sync.RWMutex // api manages user API clients. api *liteapi.Manager @@ -73,7 +73,6 @@ type Bridge struct { // updater is the bridge's updater. updater Updater - goUpdate func() curVersion *semver.Version // focusService is used to raise the bridge window when needed. @@ -99,6 +98,12 @@ type Bridge struct { // tasks manages the bridge's goroutines. tasks *xsync.Group + + // goLoad triggers a load of disconnected users from the vault. + goLoad func() + + // goUpdate triggers a check/install of updates. + goUpdate func() } // New creates a new bridge. @@ -133,12 +138,8 @@ func New( //nolint:funlen // imapEventCh forwards IMAP events from gluon instances to the bridge for processing. imapEventCh := make(chan imapEvents.Event) - // users holds all the bridge's users. - users := safe.NewMap[string, *user.User](nil) - // bridge is the bridge. bridge, err := newBridge( - users, tasks, imapEventCh, @@ -180,7 +181,6 @@ func New( //nolint:funlen // nolint:funlen func newBridge( - users *safe.Map[string, *user.User], tasks *xsync.Group, imapEventCh chan imapEvents.Event, @@ -224,9 +224,9 @@ func newBridge( return nil, fmt.Errorf("failed to create focus service: %w", err) } - return &Bridge{ + bridge := &Bridge{ vault: vault, - users: users, + users: make(map[string]*user.User), api: api, proxyCtl: proxyCtl, @@ -235,7 +235,6 @@ func newBridge( tlsConfig: tlsConfig, imapServer: imapServer, imapEventCh: imapEventCh, - smtpServer: newSMTPServer(users, tlsConfig, logSMTP), updater: updater, curVersion: curVersion, @@ -249,7 +248,11 @@ func newBridge( logSMTP: logSMTP, tasks: tasks, - }, nil + } + + bridge.smtpServer = newSMTPServer(bridge, tlsConfig, logSMTP) + + return bridge, nil } // nolint:funlen @@ -265,10 +268,10 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error { bridge.api.AddStatusObserver(func(status liteapi.Status) { switch { case status == liteapi.StatusUp: - bridge.onStatusUp() + go bridge.onStatusUp() case status == liteapi.StatusDown: - bridge.onStatusDown() + go bridge.onStatusDown() } }) @@ -356,9 +359,11 @@ func (bridge *Bridge) Close(ctx context.Context) error { } // Close all users. - bridge.users.IterValues(func(user *user.User) { - user.Close() - }) + safe.RLock(func() { + for _, user := range bridge.users { + user.Close() + } + }, &bridge.usersLock) // Stop all ongoing tasks. bridge.tasks.Wait() @@ -426,19 +431,23 @@ func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) { func (bridge *Bridge) onStatusUp() { bridge.publish(events.ConnStatusUp{}) - bridge.goLoad() + safe.RLock(func() { + for _, user := range bridge.users { + user.OnStatusUp() + } + }, &bridge.usersLock) - bridge.users.IterValues(func(user *user.User) { - go user.OnStatusUp() - }) + bridge.goLoad() } func (bridge *Bridge) onStatusDown() { bridge.publish(events.ConnStatusDown{}) - bridge.users.IterValues(func(user *user.User) { - go user.OnStatusDown() - }) + safe.RLock(func() { + for _, user := range bridge.users { + user.OnStatusDown() + } + }, &bridge.usersLock) bridge.tasks.Once(func(ctx context.Context) { backoff := time.Second diff --git a/internal/bridge/configure.go b/internal/bridge/configure.go index c3c27958..edd9d998 100644 --- a/internal/bridge/configure.go +++ b/internal/bridge/configure.go @@ -18,18 +18,24 @@ package bridge import ( - "fmt" "strings" "github.com/ProtonMail/proton-bridge/v2/internal/clientconfig" "github.com/ProtonMail/proton-bridge/v2/internal/constants" - "github.com/ProtonMail/proton-bridge/v2/internal/user" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/vault" ) +// ConfigureAppleMail configures apple mail for the given userID and address. +// If configuring apple mail for Catalina or newer, it ensures Bridge is using SSL. func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { - if ok, err := bridge.users.GetErr(userID, func(user *user.User) error { + return safe.RLockRet(func() error { + user, ok := bridge.users[userID] + if !ok { + return ErrNoSuchUser + } + if address == "" { address = user.Emails()[0] } @@ -42,7 +48,6 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { addresses = strings.Join(user.Emails(), ",") } - // If configuring apple mail for Catalina or newer, users should use SSL. if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() { if err := bridge.SetSMTPSSL(true); err != nil { return err @@ -59,11 +64,5 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { addresses, user.BridgePass(), ) - }); !ok { - return ErrNoSuchUser - } else if err != nil { - return fmt.Errorf("failed to configure apple mail: %w", err) - } - - return nil + }, &bridge.usersLock) } diff --git a/internal/bridge/imap.go b/internal/bridge/imap.go index 688cde7f..dff7bf27 100644 --- a/internal/bridge/imap.go +++ b/internal/bridge/imap.go @@ -32,6 +32,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/logging" + "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xsync" "github.com/sirupsen/logrus" @@ -83,6 +84,44 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error { return nil } +// addIMAPUser connects the given user to gluon. +func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { + imapConn, err := user.NewIMAPConnectors() + if err != nil { + return fmt.Errorf("failed to create IMAP connectors: %w", err) + } + + for addrID, imapConn := range imapConn { + if gluonID, ok := user.GetGluonID(addrID); ok { + if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil { + return fmt.Errorf("failed to load IMAP user: %w", err) + } + } else { + gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey()) + if err != nil { + return fmt.Errorf("failed to add IMAP user: %w", err) + } + + if err := user.SetGluonID(addrID, gluonID); err != nil { + return fmt.Errorf("failed to set IMAP user ID: %w", err) + } + } + } + + return nil +} + +// removeIMAPUser disconnects the given user from gluon, optionally also removing its files. +func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withFiles bool) error { + for _, gluonID := range user.GetGluonIDs() { + if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil { + return fmt.Errorf("failed to remove IMAP user: %w", err) + } + } + + return nil +} + func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) { switch event := event.(type) { case imapEvents.SessionAdded: diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 29f74f77..84e7f06f 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -24,8 +24,8 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/constants" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/updater" - "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/pkg/keychain" "github.com/sirupsen/logrus" @@ -118,48 +118,50 @@ func (bridge *Bridge) GetGluonDir() string { } func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error { - if newGluonDir == bridge.GetGluonDir() { - return fmt.Errorf("new gluon dir is the same as the old one") - } + return safe.RLockRet(func() error { + if newGluonDir == bridge.GetGluonDir() { + return fmt.Errorf("new gluon dir is the same as the old one") + } - if err := bridge.closeIMAP(context.Background()); err != nil { - return fmt.Errorf("failed to close IMAP: %w", err) - } + if err := bridge.closeIMAP(context.Background()); err != nil { + return fmt.Errorf("failed to close IMAP: %w", err) + } - if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil { - return fmt.Errorf("failed to move gluon dir: %w", err) - } + if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil { + return fmt.Errorf("failed to move gluon dir: %w", err) + } - if err := bridge.vault.SetGluonDir(newGluonDir); err != nil { - return fmt.Errorf("failed to set new gluon dir: %w", err) - } + if err := bridge.vault.SetGluonDir(newGluonDir); err != nil { + return fmt.Errorf("failed to set new gluon dir: %w", err) + } - imapServer, err := newIMAPServer( - bridge.vault.GetGluonDir(), - bridge.curVersion, - bridge.tlsConfig, - bridge.logIMAPClient, - bridge.logIMAPServer, - bridge.imapEventCh, - bridge.tasks, - ) - if err != nil { - return fmt.Errorf("failed to create new IMAP server: %w", err) - } + imapServer, err := newIMAPServer( + bridge.vault.GetGluonDir(), + bridge.curVersion, + bridge.tlsConfig, + bridge.logIMAPClient, + bridge.logIMAPServer, + bridge.imapEventCh, + bridge.tasks, + ) + if err != nil { + return fmt.Errorf("failed to create new IMAP server: %w", err) + } - bridge.imapServer = imapServer + bridge.imapServer = imapServer - if err := bridge.users.IterValuesErr(func(user *user.User) error { - return bridge.addIMAPUser(ctx, user) - }); err != nil { - return fmt.Errorf("failed to add users to new IMAP server: %w", err) - } + for _, user := range bridge.users { + if err := bridge.addIMAPUser(ctx, user); err != nil { + return fmt.Errorf("failed to add users to new IMAP server: %w", err) + } + } - if err := bridge.serveIMAP(); err != nil { - return fmt.Errorf("failed to serve IMAP: %w", err) - } + if err := bridge.serveIMAP(); err != nil { + return fmt.Errorf("failed to serve IMAP: %w", err) + } - return nil + return nil + }, &bridge.usersLock) } func (bridge *Bridge) GetProxyAllowed() bool { @@ -181,11 +183,13 @@ func (bridge *Bridge) GetShowAllMail() bool { } func (bridge *Bridge) SetShowAllMail(show bool) error { - bridge.users.IterValues(func(user *user.User) { - user.SetShowAllMail(show) - }) + return safe.RLockRet(func() error { + for _, user := range bridge.users { + user.SetShowAllMail(show) + } - return bridge.vault.SetShowAllMail(show) + return bridge.vault.SetShowAllMail(show) + }, &bridge.usersLock) } func (bridge *Bridge) GetAutostart() bool { @@ -273,14 +277,18 @@ func (bridge *Bridge) SetColorScheme(colorScheme string) error { } func (bridge *Bridge) FactoryReset(ctx context.Context) { - // First delete all users. - for _, userID := range bridge.GetUserIDs() { - if bridge.users.Has(userID) { - if err := bridge.DeleteUser(ctx, userID); err != nil { - logrus.WithError(err).Errorf("Failed to delete user %s", userID) + // Delete all the users. + safe.Lock(func() { + for _, user := range bridge.users { + bridge.logoutUser(ctx, user, true) + } + + for _, user := range bridge.vault.GetUserIDs() { + if err := bridge.vault.DeleteUser(user); err != nil { + logrus.WithError(err).Error("failed to delete vault user") } } - } + }, &bridge.usersLock) // Then delete all files. if err := bridge.locator.Clear(); err != nil { diff --git a/internal/bridge/smtp.go b/internal/bridge/smtp.go index cfb7668a..9e2f0507 100644 --- a/internal/bridge/smtp.go +++ b/internal/bridge/smtp.go @@ -23,8 +23,6 @@ import ( "fmt" "github.com/ProtonMail/proton-bridge/v2/internal/logging" - "github.com/ProtonMail/proton-bridge/v2/internal/safe" - "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/emersion/go-smtp" @@ -57,7 +55,7 @@ func (bridge *Bridge) restartSMTP() error { return fmt.Errorf("failed to close SMTP: %w", err) } - bridge.smtpServer = newSMTPServer(bridge.users, bridge.tlsConfig, bridge.logSMTP) + bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP) return bridge.serveSMTP() } @@ -80,8 +78,8 @@ func (bridge *Bridge) closeSMTP() error { return nil } -func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, shouldLog bool) *smtp.Server { - smtpServer := smtp.NewServer(&smtpBackend{users}) +func newSMTPServer(bridge *Bridge, tlsConfig *tls.Config, shouldLog bool) *smtp.Server { + smtpServer := smtp.NewServer(&smtpBackend{Bridge: bridge}) smtpServer.TLSConfig = tlsConfig smtpServer.Domain = constants.Host @@ -94,6 +92,7 @@ func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, s log.Warning("================================================") log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA") log.Warning("================================================") + smtpServer.Debug = logging.NewSMTPDebugLogger() } diff --git a/internal/bridge/smtp_backend.go b/internal/bridge/smtp_backend.go index cbf870f3..2ef93ce4 100644 --- a/internal/bridge/smtp_backend.go +++ b/internal/bridge/smtp_backend.go @@ -22,16 +22,15 @@ import ( "io" "github.com/ProtonMail/proton-bridge/v2/internal/safe" - "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/emersion/go-smtp" ) type smtpBackend struct { - users *safe.Map[string, *user.User] + *Bridge } type smtpSession struct { - users *safe.Map[string, *user.User] + *Bridge userID string authID string @@ -40,15 +39,13 @@ type smtpSession struct { to []string } -func (be *smtpBackend) NewSession(_ *smtp.Conn) (smtp.Session, error) { - return &smtpSession{ - users: be.users, - }, nil +func (be *smtpBackend) NewSession(*smtp.Conn) (smtp.Session, error) { + return &smtpSession{Bridge: be.Bridge}, nil } func (s *smtpSession) AuthPlain(username, password string) error { - return s.users.ValuesErr(func(users []*user.User) error { - for _, user := range users { + return safe.RLockRet(func() error { + for _, user := range s.users { addrID, err := user.CheckAuth(username, []byte(password)) if err != nil { continue @@ -61,7 +58,7 @@ func (s *smtpSession) AuthPlain(username, password string) error { } return fmt.Errorf("invalid username or password") - }) + }, &s.usersLock) } func (s *smtpSession) Reset() { @@ -88,13 +85,12 @@ func (s *smtpSession) Rcpt(to string) error { } func (s *smtpSession) Data(r io.Reader) error { - if ok, err := s.users.GetErr(s.userID, func(user *user.User) error { - return user.SendMail(s.authID, s.from, s.to, r) - }); !ok { - return fmt.Errorf("no such user %q", s.userID) - } else if err != nil { - return fmt.Errorf("failed to send mail: %w", err) - } + return safe.RLockRet(func() error { + user, ok := s.users[s.userID] + if !ok { + return ErrNoSuchUser + } - return nil + return user.SendMail(s.authID, s.from, s.to, r) + }, &s.usersLock) } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 571d5115..45d1c2f0 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -66,32 +66,34 @@ func (bridge *Bridge) GetUserIDs() []string { // GetUserInfo returns info about the given user. func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) { - if info, ok := safe.MapGetRet(bridge.users, userID, getConnUserInfo); ok { + return safe.RLockRetErr(func() (UserInfo, error) { + if user, ok := bridge.users[userID]; ok { + return getConnUserInfo(user), nil + } + + var info UserInfo + + if err := bridge.vault.GetUser(userID, func(user *vault.User) { + info = getUserInfo(user.UserID(), user.Username(), user.AddressMode()) + }); err != nil { + return UserInfo{}, fmt.Errorf("failed to get user info: %w", err) + } + return info, nil - } - - var info UserInfo - - if err := bridge.vault.GetUser(userID, func(user *vault.User) { - info = getUserInfo(user.UserID(), user.Username(), user.AddressMode()) - }); err != nil { - return UserInfo{}, fmt.Errorf("failed to get user info: %w", err) - } - - return info, nil + }, &bridge.usersLock) } // QueryUserInfo queries the user info by username or address. func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) { - return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) { - for _, user := range users { + return safe.RLockRetErr(func() (UserInfo, error) { + for _, user := range bridge.users { if user.Match(query) { return getConnUserInfo(user), nil } } return UserInfo{}, ErrNoSuchUser - }) + }, &bridge.usersLock) } // LoginAuth begins the login process. It returns an authorized client that might need 2FA. @@ -101,7 +103,9 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [ return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err) } - if bridge.users.Has(auth.UserID) { + if ok := safe.RLockRet(func() bool { + return mapHas(bridge.users, auth.UID) + }, &bridge.usersLock); ok { if err := client.AuthDelete(ctx); err != nil { logrus.WithError(err).Warn("Failed to delete auth") } @@ -182,31 +186,56 @@ func (bridge *Bridge) LoginFull( // LogoutUser logs out the given user. func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error { - if err := bridge.logoutUser(ctx, userID); err != nil { - return fmt.Errorf("failed to logout user: %w", err) - } + return safe.LockRet(func() error { + user, ok := bridge.users[userID] + if !ok { + return ErrNoSuchUser + } - bridge.publish(events.UserLoggedOut{ - UserID: userID, - }) + defer delete(bridge.users, user.ID()) - return nil + bridge.logoutUser(ctx, user, true) + + bridge.publish(events.UserLoggedOut{ + UserID: userID, + }) + + return nil + }, &bridge.usersLock) } // DeleteUser deletes the given user. func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { - bridge.deleteUser(ctx, userID) + return safe.LockRet(func() error { + if !bridge.vault.HasUser(userID) { + return ErrNoSuchUser + } - bridge.publish(events.UserDeleted{ - UserID: userID, - }) + if user, ok := bridge.users[userID]; ok { + defer delete(bridge.users, user.ID()) + bridge.logoutUser(ctx, user, true) + } - return nil + if err := bridge.vault.DeleteUser(userID); err != nil { + logrus.WithError(err).Error("Failed to delete vault user") + } + + bridge.publish(events.UserDeleted{ + UserID: userID, + }) + + return nil + }, &bridge.usersLock) } // SetAddressMode sets the address mode for the given user. func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error { - if ok, err := bridge.users.GetErr(userID, func(user *user.User) error { + return safe.RLockRet(func() error { + user, ok := bridge.users[userID] + if !ok { + return ErrNoSuchUser + } + if user.GetAddressMode() == mode { return fmt.Errorf("address mode is already %q", mode) } @@ -231,13 +260,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va }) return nil - }); !ok { - return ErrNoSuchUser - } else if err != nil { - return fmt.Errorf("failed to set address mode: %w", err) - } - - return nil + }, &bridge.usersLock) } func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) { @@ -266,7 +289,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut // loadUsers tries to load each user in the vault that isn't already loaded. func (bridge *Bridge) loadUsers(ctx context.Context) error { return bridge.vault.ForUser(func(user *vault.User) error { - if bridge.users.Has(user.UserID()) || user.AuthUID() == "" { + if user.AuthUID() == "" { + return nil + } + + if safe.RLockRet(func() bool { + return mapHas(bridge.users, user.UserID()) + }, &bridge.usersLock) { return nil } @@ -293,20 +322,16 @@ func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error { return fmt.Errorf("failed to set auth: %w", err) } - return try.Catch( - func() error { - apiUser, err := client.GetUser(ctx) - if err != nil { - return fmt.Errorf("failed to get user: %w", err) - } + apiUser, err := client.GetUser(ctx) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } - if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil { - return fmt.Errorf("failed to add user: %w", err) - } + if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil { + return fmt.Errorf("failed to add user: %w", err) + } - return nil - }, - ) + return nil } // addUser adds a new user with an already salted mailbox password. @@ -364,10 +389,6 @@ func (bridge *Bridge) addUserWithVault( return fmt.Errorf("failed to create user: %w", err) } - if had := bridge.users.Set(apiUser.ID, user); had { - panic("double add") - } - // Connect the user's address(es) to gluon. if err := bridge.addIMAPUser(ctx, user); err != nil { return fmt.Errorf("failed to add IMAP user: %w", err) @@ -395,6 +416,11 @@ func (bridge *Bridge) addUserWithVault( return nil }) + // Finally, save the user in the bridge. + safe.Lock(func() { + bridge.users[apiUser.ID] = user + }, &bridge.usersLock) + return nil } @@ -430,75 +456,17 @@ func (bridge *Bridge) newVaultUser( return user, false, nil } -// addIMAPUser connects the given user to gluon. -func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error { - imapConn, err := user.NewIMAPConnectors() - if err != nil { - return fmt.Errorf("failed to create IMAP connectors: %w", err) +// logout logs out the given user, optionally logging them out from the API too. +func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI bool) { + if err := bridge.removeIMAPUser(ctx, user, false); err != nil { + logrus.WithError(err).Error("Failed to remove IMAP user") } - for addrID, imapConn := range imapConn { - if gluonID, ok := user.GetGluonID(addrID); ok { - if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil { - return fmt.Errorf("failed to load IMAP user: %w", err) - } - } else { - gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey()) - if err != nil { - return fmt.Errorf("failed to add IMAP user: %w", err) - } - - if err := user.SetGluonID(addrID, gluonID); err != nil { - return fmt.Errorf("failed to set IMAP user ID: %w", err) - } - } + if err := user.Logout(ctx, withAPI); err != nil { + logrus.WithError(err).Error("Failed to logout user") } - return nil -} - -// logoutUser logs the given user out from bridge. -func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { - if ok := bridge.users.GetDelete(userID, func(user *user.User) { - for _, gluonID := range user.GetGluonIDs() { - if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil { - logrus.WithError(err).Error("Failed to remove IMAP user") - } - } - - if err := user.Logout(ctx); err != nil { - logrus.WithError(err).Error("Failed to logout user") - } - - user.Close() - }); !ok { - return ErrNoSuchUser - } - - return nil -} - -// deleteUser deletes the given user from bridge. -func (bridge *Bridge) deleteUser(ctx context.Context, userID string) { - if ok := bridge.users.GetDelete(userID, func(user *user.User) { - for _, gluonID := range user.GetGluonIDs() { - if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil { - logrus.WithError(err).Error("Failed to remove IMAP user") - } - } - - if err := user.Logout(ctx); err != nil { - logrus.WithError(err).Error("Failed to logout user") - } - - user.Close() - }); !ok { - logrus.Debug("The bridge user was not connected") - } - - if err := bridge.vault.DeleteUser(userID); err != nil { - logrus.WithError(err).Error("Failed to delete user from vault") - } + user.Close() } // getUserInfo returns information about a disconnected user. @@ -523,3 +491,8 @@ func getConnUserInfo(user *user.User) UserInfo { MaxSpace: user.MaxSpace(), } } + +func mapHas[Key comparable, Val any](m map[Key]Val, key Key) bool { + _, ok := m[key] + return ok +} diff --git a/internal/bridge/user_events.go b/internal/bridge/user_events.go index b9432172..04dbc50f 100644 --- a/internal/bridge/user_events.go +++ b/internal/bridge/user_events.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/ProtonMail/proton-bridge/v2/internal/events" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" ) @@ -44,9 +45,11 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even } case events.UserDeauth: - if err := bridge.logoutUser(context.Background(), event.UserID); err != nil { - return fmt.Errorf("failed to logout user: %w", err) - } + safe.Lock(func() { + defer delete(bridge.users, user.ID()) + + bridge.logoutUser(ctx, user, false) + }, &bridge.usersLock) } return nil diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index cc225973..b306c938 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -147,6 +147,46 @@ func TestBridge_LoginDeauthLogin(t *testing.T) { }) } +func TestBridge_LoginDeauthRestartLogin(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + var userID string + + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Login the user. + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) + + // Get a channel to receive the deauth event. + eventCh, done := bridge.GetEvents(events.UserDeauth{}) + defer done() + + // Deauth the user. + require.NoError(t, s.RevokeUser(userID)) + + // The user is eventually disconnected. + require.Eventually(t, func() bool { + return len(getConnectedUserIDs(t, bridge)) == 0 + }, 10*time.Second, time.Second) + + // We should get a deauth event. + require.IsType(t, events.UserDeauth{}, <-eventCh) + }) + + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // The user should be disconnected at startup. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + + // Login the user after the disconnection. + newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil)) + require.Equal(t, userID, newUserID) + + // The user is connected again. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + }) + }) +} + func TestBridge_LoginExpireLogin(t *testing.T) { const authLife = 2 * time.Second @@ -449,6 +489,82 @@ func TestBridge_LoginLogoutRepeated(t *testing.T) { }) } +func TestBridge_LogoutOffline(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Login the user. + userID, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + // The user is now connected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + + // Go offline. + netCtl.Disable() + + // We can still log the user out. + require.NoError(t, bridge.LogoutUser(ctx, userID)) + + // The user is now disconnected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + }) +} + +func TestBridge_DeleteDisconnected(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Login the user. + userID, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + // The user is now connected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + + // Logout the user. + require.NoError(t, bridge.LogoutUser(ctx, userID)) + + // The user is now disconnected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + + // Delete the user. + require.NoError(t, bridge.DeleteUser(ctx, userID)) + + // The user is now deleted. + require.Empty(t, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + }) +} + +func TestBridge_DeleteOffline(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Login the user. + userID, err := bridge.LoginFull(ctx, username, password, nil, nil) + require.NoError(t, err) + + // The user is now connected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) + + // Go offline. + netCtl.Disable() + + // We can still log the user out. + require.NoError(t, bridge.DeleteUser(ctx, userID)) + + // The user is now gone. + require.Empty(t, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + }) +} + // getErr returns the error that was passed to it. func getErr[T any](val T, err error) error { return err diff --git a/internal/safe/mutex.go b/internal/safe/mutex.go new file mode 100644 index 00000000..b7102c90 --- /dev/null +++ b/internal/safe/mutex.go @@ -0,0 +1,87 @@ +package safe + +type Mutex interface { + Lock() + Unlock() +} + +func Lock(fn func(), m ...Mutex) { + if len(m) == 0 { + panic("no mutexes provided") + } + + for _, m := range m { + m.Lock() + defer m.Unlock() + } + + fn() +} + +func LockRet[T any](fn func() T, m ...Mutex) T { + var ret T + + Lock(func() { + ret = fn() + }, m...) + + return ret +} + +func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) { + var ret T + + err := LockRet(func() error { + var err error + + ret, err = fn() + + return err + }, m...) + + return ret, err +} + +type RWMutex interface { + Mutex + + RLock() + RUnlock() +} + +func RLock(fn func(), m ...RWMutex) { + if len(m) == 0 { + panic("no mutexes provided") + } + + for _, m := range m { + m.RLock() + defer m.RUnlock() + } + + fn() +} + +func RLockRet[T any](fn func() T, m ...RWMutex) T { + var ret T + + RLock(func() { + ret = fn() + }, m...) + + return ret +} + +func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) { + var err error + + ret := RLockRet(func() T { + var ret T + + ret, err = fn() + + return ret + }, m...) + + return ret, err +} diff --git a/internal/user/user.go b/internal/user/user.go index e0345962..5a63fd80 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -402,11 +402,13 @@ func (user *User) OnStatusDown() { } // Logout logs the user out from the API. -func (user *User) Logout(ctx context.Context) error { +func (user *User) Logout(ctx context.Context, withAPI bool) error { user.tasks.Wait() - if err := user.client.AuthDelete(ctx); err != nil { - return fmt.Errorf("failed to delete auth: %w", err) + if withAPI { + if err := user.client.AuthDelete(ctx); err != nil { + return fmt.Errorf("failed to delete auth: %w", err) + } } if err := user.vault.Clear(); err != nil {