mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2026-02-14 04:48:32 +00:00
Other(refactor): Less unwieldy user type in Bridge
Instead of the annoying safe.Map type, we just use a normal go map and mutex pair, and use safe.Lock/safe.RLock as a helper function
This commit is contained in:
@ -66,32 +66,34 @@ func (bridge *Bridge) GetUserIDs() []string {
|
||||
|
||||
// GetUserInfo returns info about the given user.
|
||||
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
||||
if info, ok := safe.MapGetRet(bridge.users, userID, getConnUserInfo); ok {
|
||||
return safe.RLockRetErr(func() (UserInfo, error) {
|
||||
if user, ok := bridge.users[userID]; ok {
|
||||
return getConnUserInfo(user), nil
|
||||
}
|
||||
|
||||
var info UserInfo
|
||||
|
||||
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
||||
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
||||
}); err != nil {
|
||||
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
var info UserInfo
|
||||
|
||||
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
||||
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
||||
}); err != nil {
|
||||
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// QueryUserInfo queries the user info by username or address.
|
||||
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
|
||||
return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) {
|
||||
for _, user := range users {
|
||||
return safe.RLockRetErr(func() (UserInfo, error) {
|
||||
for _, user := range bridge.users {
|
||||
if user.Match(query) {
|
||||
return getConnUserInfo(user), nil
|
||||
}
|
||||
}
|
||||
|
||||
return UserInfo{}, ErrNoSuchUser
|
||||
})
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// LoginAuth begins the login process. It returns an authorized client that might need 2FA.
|
||||
@ -101,7 +103,9 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [
|
||||
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
|
||||
}
|
||||
|
||||
if bridge.users.Has(auth.UserID) {
|
||||
if ok := safe.RLockRet(func() bool {
|
||||
return mapHas(bridge.users, auth.UID)
|
||||
}, &bridge.usersLock); ok {
|
||||
if err := client.AuthDelete(ctx); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to delete auth")
|
||||
}
|
||||
@ -182,31 +186,56 @@ func (bridge *Bridge) LoginFull(
|
||||
|
||||
// LogoutUser logs out the given user.
|
||||
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
||||
if err := bridge.logoutUser(ctx, userID); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
return safe.LockRet(func() error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
defer delete(bridge.users, user.ID())
|
||||
|
||||
return nil
|
||||
bridge.logoutUser(ctx, user, true)
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// DeleteUser deletes the given user.
|
||||
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
||||
bridge.deleteUser(ctx, userID)
|
||||
return safe.LockRet(func() error {
|
||||
if !bridge.vault.HasUser(userID) {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
bridge.publish(events.UserDeleted{
|
||||
UserID: userID,
|
||||
})
|
||||
if user, ok := bridge.users[userID]; ok {
|
||||
defer delete(bridge.users, user.ID())
|
||||
bridge.logoutUser(ctx, user, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to delete vault user")
|
||||
}
|
||||
|
||||
bridge.publish(events.UserDeleted{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// SetAddressMode sets the address mode for the given user.
|
||||
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
|
||||
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
|
||||
return safe.RLockRet(func() error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if user.GetAddressMode() == mode {
|
||||
return fmt.Errorf("address mode is already %q", mode)
|
||||
}
|
||||
@ -231,13 +260,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
||||
})
|
||||
|
||||
return nil
|
||||
}); !ok {
|
||||
return ErrNoSuchUser
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
|
||||
@ -266,7 +289,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
|
||||
// loadUsers tries to load each user in the vault that isn't already loaded.
|
||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
||||
return bridge.vault.ForUser(func(user *vault.User) error {
|
||||
if bridge.users.Has(user.UserID()) || user.AuthUID() == "" {
|
||||
if user.AuthUID() == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if safe.RLockRet(func() bool {
|
||||
return mapHas(bridge.users, user.UserID())
|
||||
}, &bridge.usersLock) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -293,20 +322,16 @@ func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
||||
return fmt.Errorf("failed to set auth: %w", err)
|
||||
}
|
||||
|
||||
return try.Catch(
|
||||
func() error {
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// addUser adds a new user with an already salted mailbox password.
|
||||
@ -364,10 +389,6 @@ func (bridge *Bridge) addUserWithVault(
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
if had := bridge.users.Set(apiUser.ID, user); had {
|
||||
panic("double add")
|
||||
}
|
||||
|
||||
// Connect the user's address(es) to gluon.
|
||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
@ -395,6 +416,11 @@ func (bridge *Bridge) addUserWithVault(
|
||||
return nil
|
||||
})
|
||||
|
||||
// Finally, save the user in the bridge.
|
||||
safe.Lock(func() {
|
||||
bridge.users[apiUser.ID] = user
|
||||
}, &bridge.usersLock)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -430,75 +456,17 @@ func (bridge *Bridge) newVaultUser(
|
||||
return user, false, nil
|
||||
}
|
||||
|
||||
// addIMAPUser connects the given user to gluon.
|
||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IMAP connectors: %w", err)
|
||||
// logout logs out the given user, optionally logging them out from the API too.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI bool) {
|
||||
if err := bridge.removeIMAPUser(ctx, user, false); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
|
||||
for addrID, imapConn := range imapConn {
|
||||
if gluonID, ok := user.GetGluonID(addrID); ok {
|
||||
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
||||
return fmt.Errorf("failed to load IMAP user: %w", err)
|
||||
}
|
||||
} else {
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||
}
|
||||
}
|
||||
if err := user.Logout(ctx, withAPI); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logoutUser logs the given user out from bridge.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
||||
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
user.Close()
|
||||
}); !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteUser deletes the given user from bridge.
|
||||
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
||||
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
user.Close()
|
||||
}); !ok {
|
||||
logrus.Debug("The bridge user was not connected")
|
||||
}
|
||||
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to delete user from vault")
|
||||
}
|
||||
user.Close()
|
||||
}
|
||||
|
||||
// getUserInfo returns information about a disconnected user.
|
||||
@ -523,3 +491,8 @@ func getConnUserInfo(user *user.User) UserInfo {
|
||||
MaxSpace: user.MaxSpace(),
|
||||
}
|
||||
}
|
||||
|
||||
func mapHas[Key comparable, Val any](m map[Key]Val, key Key) bool {
|
||||
_, ok := m[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user