Other(refactor): Less unwieldy user type in Bridge

Instead of the annoying safe.Map type, we just use a normal go map
and mutex pair, and use safe.Lock/safe.RLock as a helper function
This commit is contained in:
James Houlahan
2022-10-26 23:04:25 +02:00
parent 85c0d6f837
commit 2bda47fcad
11 changed files with 458 additions and 227 deletions

View File

@ -51,8 +51,8 @@ type Bridge struct {
vault *vault.Vault
// users holds authorized users.
users *safe.Map[string, *user.User]
goLoad func()
users map[string]*user.User
usersLock sync.RWMutex
// api manages user API clients.
api *liteapi.Manager
@ -73,7 +73,6 @@ type Bridge struct {
// updater is the bridge's updater.
updater Updater
goUpdate func()
curVersion *semver.Version
// focusService is used to raise the bridge window when needed.
@ -99,6 +98,12 @@ type Bridge struct {
// tasks manages the bridge's goroutines.
tasks *xsync.Group
// goLoad triggers a load of disconnected users from the vault.
goLoad func()
// goUpdate triggers a check/install of updates.
goUpdate func()
}
// New creates a new bridge.
@ -133,12 +138,8 @@ func New( //nolint:funlen
// imapEventCh forwards IMAP events from gluon instances to the bridge for processing.
imapEventCh := make(chan imapEvents.Event)
// users holds all the bridge's users.
users := safe.NewMap[string, *user.User](nil)
// bridge is the bridge.
bridge, err := newBridge(
users,
tasks,
imapEventCh,
@ -180,7 +181,6 @@ func New( //nolint:funlen
// nolint:funlen
func newBridge(
users *safe.Map[string, *user.User],
tasks *xsync.Group,
imapEventCh chan imapEvents.Event,
@ -224,9 +224,9 @@ func newBridge(
return nil, fmt.Errorf("failed to create focus service: %w", err)
}
return &Bridge{
bridge := &Bridge{
vault: vault,
users: users,
users: make(map[string]*user.User),
api: api,
proxyCtl: proxyCtl,
@ -235,7 +235,6 @@ func newBridge(
tlsConfig: tlsConfig,
imapServer: imapServer,
imapEventCh: imapEventCh,
smtpServer: newSMTPServer(users, tlsConfig, logSMTP),
updater: updater,
curVersion: curVersion,
@ -249,7 +248,11 @@ func newBridge(
logSMTP: logSMTP,
tasks: tasks,
}, nil
}
bridge.smtpServer = newSMTPServer(bridge, tlsConfig, logSMTP)
return bridge, nil
}
// nolint:funlen
@ -265,10 +268,10 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
bridge.api.AddStatusObserver(func(status liteapi.Status) {
switch {
case status == liteapi.StatusUp:
bridge.onStatusUp()
go bridge.onStatusUp()
case status == liteapi.StatusDown:
bridge.onStatusDown()
go bridge.onStatusDown()
}
})
@ -356,9 +359,11 @@ func (bridge *Bridge) Close(ctx context.Context) error {
}
// Close all users.
bridge.users.IterValues(func(user *user.User) {
user.Close()
})
safe.RLock(func() {
for _, user := range bridge.users {
user.Close()
}
}, &bridge.usersLock)
// Stop all ongoing tasks.
bridge.tasks.Wait()
@ -426,19 +431,23 @@ func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) {
func (bridge *Bridge) onStatusUp() {
bridge.publish(events.ConnStatusUp{})
bridge.goLoad()
safe.RLock(func() {
for _, user := range bridge.users {
user.OnStatusUp()
}
}, &bridge.usersLock)
bridge.users.IterValues(func(user *user.User) {
go user.OnStatusUp()
})
bridge.goLoad()
}
func (bridge *Bridge) onStatusDown() {
bridge.publish(events.ConnStatusDown{})
bridge.users.IterValues(func(user *user.User) {
go user.OnStatusDown()
})
safe.RLock(func() {
for _, user := range bridge.users {
user.OnStatusDown()
}
}, &bridge.usersLock)
bridge.tasks.Once(func(ctx context.Context) {
backoff := time.Second

View File

@ -18,18 +18,24 @@
package bridge
import (
"fmt"
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
)
// ConfigureAppleMail configures apple mail for the given userID and address.
// If configuring apple mail for Catalina or newer, it ensures Bridge is using SSL.
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
return safe.RLockRet(func() error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if address == "" {
address = user.Emails()[0]
}
@ -42,7 +48,6 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
addresses = strings.Join(user.Emails(), ",")
}
// If configuring apple mail for Catalina or newer, users should use SSL.
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
if err := bridge.SetSMTPSSL(true); err != nil {
return err
@ -59,11 +64,5 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
addresses,
user.BridgePass(),
)
}); !ok {
return ErrNoSuchUser
} else if err != nil {
return fmt.Errorf("failed to configure apple mail: %w", err)
}
return nil
}, &bridge.usersLock)
}

View File

@ -32,6 +32,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/async"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xsync"
"github.com/sirupsen/logrus"
@ -83,6 +84,44 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error {
return nil
}
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
}
for addrID, imapConn := range imapConn {
if gluonID, ok := user.GetGluonID(addrID); ok {
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
} else {
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
}
}
return nil
}
// removeIMAPUser disconnects the given user from gluon, optionally also removing its files.
func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withFiles bool) error {
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
}
return nil
}
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
switch event := event.(type) {
case imapEvents.SessionAdded:

View File

@ -24,8 +24,8 @@ import (
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/sirupsen/logrus"
@ -118,48 +118,50 @@ func (bridge *Bridge) GetGluonDir() string {
}
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
if newGluonDir == bridge.GetGluonDir() {
return fmt.Errorf("new gluon dir is the same as the old one")
}
return safe.RLockRet(func() error {
if newGluonDir == bridge.GetGluonDir() {
return fmt.Errorf("new gluon dir is the same as the old one")
}
if err := bridge.closeIMAP(context.Background()); err != nil {
return fmt.Errorf("failed to close IMAP: %w", err)
}
if err := bridge.closeIMAP(context.Background()); err != nil {
return fmt.Errorf("failed to close IMAP: %w", err)
}
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
return fmt.Errorf("failed to move gluon dir: %w", err)
}
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
return fmt.Errorf("failed to move gluon dir: %w", err)
}
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
return fmt.Errorf("failed to set new gluon dir: %w", err)
}
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
return fmt.Errorf("failed to set new gluon dir: %w", err)
}
imapServer, err := newIMAPServer(
bridge.vault.GetGluonDir(),
bridge.curVersion,
bridge.tlsConfig,
bridge.logIMAPClient,
bridge.logIMAPServer,
bridge.imapEventCh,
bridge.tasks,
)
if err != nil {
return fmt.Errorf("failed to create new IMAP server: %w", err)
}
imapServer, err := newIMAPServer(
bridge.vault.GetGluonDir(),
bridge.curVersion,
bridge.tlsConfig,
bridge.logIMAPClient,
bridge.logIMAPServer,
bridge.imapEventCh,
bridge.tasks,
)
if err != nil {
return fmt.Errorf("failed to create new IMAP server: %w", err)
}
bridge.imapServer = imapServer
bridge.imapServer = imapServer
if err := bridge.users.IterValuesErr(func(user *user.User) error {
return bridge.addIMAPUser(ctx, user)
}); err != nil {
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
}
for _, user := range bridge.users {
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
}
}
if err := bridge.serveIMAP(); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
if err := bridge.serveIMAP(); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
return nil
return nil
}, &bridge.usersLock)
}
func (bridge *Bridge) GetProxyAllowed() bool {
@ -181,11 +183,13 @@ func (bridge *Bridge) GetShowAllMail() bool {
}
func (bridge *Bridge) SetShowAllMail(show bool) error {
bridge.users.IterValues(func(user *user.User) {
user.SetShowAllMail(show)
})
return safe.RLockRet(func() error {
for _, user := range bridge.users {
user.SetShowAllMail(show)
}
return bridge.vault.SetShowAllMail(show)
return bridge.vault.SetShowAllMail(show)
}, &bridge.usersLock)
}
func (bridge *Bridge) GetAutostart() bool {
@ -273,14 +277,18 @@ func (bridge *Bridge) SetColorScheme(colorScheme string) error {
}
func (bridge *Bridge) FactoryReset(ctx context.Context) {
// First delete all users.
for _, userID := range bridge.GetUserIDs() {
if bridge.users.Has(userID) {
if err := bridge.DeleteUser(ctx, userID); err != nil {
logrus.WithError(err).Errorf("Failed to delete user %s", userID)
// Delete all the users.
safe.Lock(func() {
for _, user := range bridge.users {
bridge.logoutUser(ctx, user, true)
}
for _, user := range bridge.vault.GetUserIDs() {
if err := bridge.vault.DeleteUser(user); err != nil {
logrus.WithError(err).Error("failed to delete vault user")
}
}
}
}, &bridge.usersLock)
// Then delete all files.
if err := bridge.locator.Clear(); err != nil {

View File

@ -23,8 +23,6 @@ import (
"fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/emersion/go-smtp"
@ -57,7 +55,7 @@ func (bridge *Bridge) restartSMTP() error {
return fmt.Errorf("failed to close SMTP: %w", err)
}
bridge.smtpServer = newSMTPServer(bridge.users, bridge.tlsConfig, bridge.logSMTP)
bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
return bridge.serveSMTP()
}
@ -80,8 +78,8 @@ func (bridge *Bridge) closeSMTP() error {
return nil
}
func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, shouldLog bool) *smtp.Server {
smtpServer := smtp.NewServer(&smtpBackend{users})
func newSMTPServer(bridge *Bridge, tlsConfig *tls.Config, shouldLog bool) *smtp.Server {
smtpServer := smtp.NewServer(&smtpBackend{Bridge: bridge})
smtpServer.TLSConfig = tlsConfig
smtpServer.Domain = constants.Host
@ -94,6 +92,7 @@ func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, s
log.Warning("================================================")
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
log.Warning("================================================")
smtpServer.Debug = logging.NewSMTPDebugLogger()
}

View File

@ -22,16 +22,15 @@ import (
"io"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/emersion/go-smtp"
)
type smtpBackend struct {
users *safe.Map[string, *user.User]
*Bridge
}
type smtpSession struct {
users *safe.Map[string, *user.User]
*Bridge
userID string
authID string
@ -40,15 +39,13 @@ type smtpSession struct {
to []string
}
func (be *smtpBackend) NewSession(_ *smtp.Conn) (smtp.Session, error) {
return &smtpSession{
users: be.users,
}, nil
func (be *smtpBackend) NewSession(*smtp.Conn) (smtp.Session, error) {
return &smtpSession{Bridge: be.Bridge}, nil
}
func (s *smtpSession) AuthPlain(username, password string) error {
return s.users.ValuesErr(func(users []*user.User) error {
for _, user := range users {
return safe.RLockRet(func() error {
for _, user := range s.users {
addrID, err := user.CheckAuth(username, []byte(password))
if err != nil {
continue
@ -61,7 +58,7 @@ func (s *smtpSession) AuthPlain(username, password string) error {
}
return fmt.Errorf("invalid username or password")
})
}, &s.usersLock)
}
func (s *smtpSession) Reset() {
@ -88,13 +85,12 @@ func (s *smtpSession) Rcpt(to string) error {
}
func (s *smtpSession) Data(r io.Reader) error {
if ok, err := s.users.GetErr(s.userID, func(user *user.User) error {
return user.SendMail(s.authID, s.from, s.to, r)
}); !ok {
return fmt.Errorf("no such user %q", s.userID)
} else if err != nil {
return fmt.Errorf("failed to send mail: %w", err)
}
return safe.RLockRet(func() error {
user, ok := s.users[s.userID]
if !ok {
return ErrNoSuchUser
}
return nil
return user.SendMail(s.authID, s.from, s.to, r)
}, &s.usersLock)
}

View File

@ -66,32 +66,34 @@ func (bridge *Bridge) GetUserIDs() []string {
// GetUserInfo returns info about the given user.
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
if info, ok := safe.MapGetRet(bridge.users, userID, getConnUserInfo); ok {
return safe.RLockRetErr(func() (UserInfo, error) {
if user, ok := bridge.users[userID]; ok {
return getConnUserInfo(user), nil
}
var info UserInfo
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
}); err != nil {
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
}
return info, nil
}
var info UserInfo
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
}); err != nil {
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
}
return info, nil
}, &bridge.usersLock)
}
// QueryUserInfo queries the user info by username or address.
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) {
for _, user := range users {
return safe.RLockRetErr(func() (UserInfo, error) {
for _, user := range bridge.users {
if user.Match(query) {
return getConnUserInfo(user), nil
}
}
return UserInfo{}, ErrNoSuchUser
})
}, &bridge.usersLock)
}
// LoginAuth begins the login process. It returns an authorized client that might need 2FA.
@ -101,7 +103,9 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
}
if bridge.users.Has(auth.UserID) {
if ok := safe.RLockRet(func() bool {
return mapHas(bridge.users, auth.UID)
}, &bridge.usersLock); ok {
if err := client.AuthDelete(ctx); err != nil {
logrus.WithError(err).Warn("Failed to delete auth")
}
@ -182,31 +186,56 @@ func (bridge *Bridge) LoginFull(
// LogoutUser logs out the given user.
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
if err := bridge.logoutUser(ctx, userID); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
return safe.LockRet(func() error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
defer delete(bridge.users, user.ID())
return nil
bridge.logoutUser(ctx, user, true)
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
}, &bridge.usersLock)
}
// DeleteUser deletes the given user.
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
bridge.deleteUser(ctx, userID)
return safe.LockRet(func() error {
if !bridge.vault.HasUser(userID) {
return ErrNoSuchUser
}
bridge.publish(events.UserDeleted{
UserID: userID,
})
if user, ok := bridge.users[userID]; ok {
defer delete(bridge.users, user.ID())
bridge.logoutUser(ctx, user, true)
}
return nil
if err := bridge.vault.DeleteUser(userID); err != nil {
logrus.WithError(err).Error("Failed to delete vault user")
}
bridge.publish(events.UserDeleted{
UserID: userID,
})
return nil
}, &bridge.usersLock)
}
// SetAddressMode sets the address mode for the given user.
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
return safe.RLockRet(func() error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if user.GetAddressMode() == mode {
return fmt.Errorf("address mode is already %q", mode)
}
@ -231,13 +260,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
})
return nil
}); !ok {
return ErrNoSuchUser
} else if err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
return nil
}, &bridge.usersLock)
}
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
@ -266,7 +289,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
// loadUsers tries to load each user in the vault that isn't already loaded.
func (bridge *Bridge) loadUsers(ctx context.Context) error {
return bridge.vault.ForUser(func(user *vault.User) error {
if bridge.users.Has(user.UserID()) || user.AuthUID() == "" {
if user.AuthUID() == "" {
return nil
}
if safe.RLockRet(func() bool {
return mapHas(bridge.users, user.UserID())
}, &bridge.usersLock) {
return nil
}
@ -293,20 +322,16 @@ func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
return fmt.Errorf("failed to set auth: %w", err)
}
return try.Catch(
func() error {
apiUser, err := client.GetUser(ctx)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
apiUser, err := client.GetUser(ctx)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
return fmt.Errorf("failed to add user: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
return fmt.Errorf("failed to add user: %w", err)
}
return nil
},
)
return nil
}
// addUser adds a new user with an already salted mailbox password.
@ -364,10 +389,6 @@ func (bridge *Bridge) addUserWithVault(
return fmt.Errorf("failed to create user: %w", err)
}
if had := bridge.users.Set(apiUser.ID, user); had {
panic("double add")
}
// Connect the user's address(es) to gluon.
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
@ -395,6 +416,11 @@ func (bridge *Bridge) addUserWithVault(
return nil
})
// Finally, save the user in the bridge.
safe.Lock(func() {
bridge.users[apiUser.ID] = user
}, &bridge.usersLock)
return nil
}
@ -430,75 +456,17 @@ func (bridge *Bridge) newVaultUser(
return user, false, nil
}
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
// logout logs out the given user, optionally logging them out from the API too.
func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI bool) {
if err := bridge.removeIMAPUser(ctx, user, false); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
for addrID, imapConn := range imapConn {
if gluonID, ok := user.GetGluonID(addrID); ok {
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
} else {
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
}
if err := user.Logout(ctx, withAPI); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
return nil
}
// logoutUser logs the given user out from bridge.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
user.Close()
}); !ok {
return ErrNoSuchUser
}
return nil
}
// deleteUser deletes the given user from bridge.
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
}
if err := user.Logout(ctx); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
user.Close()
}); !ok {
logrus.Debug("The bridge user was not connected")
}
if err := bridge.vault.DeleteUser(userID); err != nil {
logrus.WithError(err).Error("Failed to delete user from vault")
}
user.Close()
}
// getUserInfo returns information about a disconnected user.
@ -523,3 +491,8 @@ func getConnUserInfo(user *user.User) UserInfo {
MaxSpace: user.MaxSpace(),
}
}
func mapHas[Key comparable, Val any](m map[Key]Val, key Key) bool {
_, ok := m[key]
return ok
}

View File

@ -22,6 +22,7 @@ import (
"fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
)
@ -44,9 +45,11 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
}
case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
safe.Lock(func() {
defer delete(bridge.users, user.ID())
bridge.logoutUser(ctx, user, false)
}, &bridge.usersLock)
}
return nil

View File

@ -147,6 +147,46 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
})
}
func TestBridge_LoginDeauthRestartLogin(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginFull(ctx, username, password, nil, nil))
// Get a channel to receive the deauth event.
eventCh, done := bridge.GetEvents(events.UserDeauth{})
defer done()
// Deauth the user.
require.NoError(t, s.RevokeUser(userID))
// The user is eventually disconnected.
require.Eventually(t, func() bool {
return len(getConnectedUserIDs(t, bridge)) == 0
}, 10*time.Second, time.Second)
// We should get a deauth event.
require.IsType(t, events.UserDeauth{}, <-eventCh)
})
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user should be disconnected at startup.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Login the user after the disconnection.
newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginExpireLogin(t *testing.T) {
const authLife = 2 * time.Second
@ -449,6 +489,82 @@ func TestBridge_LoginLogoutRepeated(t *testing.T) {
})
}
func TestBridge_LogoutOffline(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Go offline.
netCtl.Disable()
// We can still log the user out.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// The user is now disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_DeleteDisconnected(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Logout the user.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// The user is now disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Delete the user.
require.NoError(t, bridge.DeleteUser(ctx, userID))
// The user is now deleted.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_DeleteOffline(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Go offline.
netCtl.Disable()
// We can still log the user out.
require.NoError(t, bridge.DeleteUser(ctx, userID))
// The user is now gone.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
// getErr returns the error that was passed to it.
func getErr[T any](val T, err error) error {
return err

87
internal/safe/mutex.go Normal file
View File

@ -0,0 +1,87 @@
package safe
type Mutex interface {
Lock()
Unlock()
}
func Lock(fn func(), m ...Mutex) {
if len(m) == 0 {
panic("no mutexes provided")
}
for _, m := range m {
m.Lock()
defer m.Unlock()
}
fn()
}
func LockRet[T any](fn func() T, m ...Mutex) T {
var ret T
Lock(func() {
ret = fn()
}, m...)
return ret
}
func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
var ret T
err := LockRet(func() error {
var err error
ret, err = fn()
return err
}, m...)
return ret, err
}
type RWMutex interface {
Mutex
RLock()
RUnlock()
}
func RLock(fn func(), m ...RWMutex) {
if len(m) == 0 {
panic("no mutexes provided")
}
for _, m := range m {
m.RLock()
defer m.RUnlock()
}
fn()
}
func RLockRet[T any](fn func() T, m ...RWMutex) T {
var ret T
RLock(func() {
ret = fn()
}, m...)
return ret
}
func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) {
var err error
ret := RLockRet(func() T {
var ret T
ret, err = fn()
return ret
}, m...)
return ret, err
}

View File

@ -402,11 +402,13 @@ func (user *User) OnStatusDown() {
}
// Logout logs the user out from the API.
func (user *User) Logout(ctx context.Context) error {
func (user *User) Logout(ctx context.Context, withAPI bool) error {
user.tasks.Wait()
if err := user.client.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
if withAPI {
if err := user.client.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
}
}
if err := user.vault.Clear(); err != nil {