forked from Silverfish/proton-bridge
Other(refactor): Less unwieldy user type in Bridge
Instead of the annoying safe.Map type, we just use a normal go map and mutex pair, and use safe.Lock/safe.RLock as a helper function
This commit is contained in:
@ -51,8 +51,8 @@ type Bridge struct {
|
||||
vault *vault.Vault
|
||||
|
||||
// users holds authorized users.
|
||||
users *safe.Map[string, *user.User]
|
||||
goLoad func()
|
||||
users map[string]*user.User
|
||||
usersLock sync.RWMutex
|
||||
|
||||
// api manages user API clients.
|
||||
api *liteapi.Manager
|
||||
@ -73,7 +73,6 @@ type Bridge struct {
|
||||
|
||||
// updater is the bridge's updater.
|
||||
updater Updater
|
||||
goUpdate func()
|
||||
curVersion *semver.Version
|
||||
|
||||
// focusService is used to raise the bridge window when needed.
|
||||
@ -99,6 +98,12 @@ type Bridge struct {
|
||||
|
||||
// tasks manages the bridge's goroutines.
|
||||
tasks *xsync.Group
|
||||
|
||||
// goLoad triggers a load of disconnected users from the vault.
|
||||
goLoad func()
|
||||
|
||||
// goUpdate triggers a check/install of updates.
|
||||
goUpdate func()
|
||||
}
|
||||
|
||||
// New creates a new bridge.
|
||||
@ -133,12 +138,8 @@ func New( //nolint:funlen
|
||||
// imapEventCh forwards IMAP events from gluon instances to the bridge for processing.
|
||||
imapEventCh := make(chan imapEvents.Event)
|
||||
|
||||
// users holds all the bridge's users.
|
||||
users := safe.NewMap[string, *user.User](nil)
|
||||
|
||||
// bridge is the bridge.
|
||||
bridge, err := newBridge(
|
||||
users,
|
||||
tasks,
|
||||
imapEventCh,
|
||||
|
||||
@ -180,7 +181,6 @@ func New( //nolint:funlen
|
||||
|
||||
// nolint:funlen
|
||||
func newBridge(
|
||||
users *safe.Map[string, *user.User],
|
||||
tasks *xsync.Group,
|
||||
imapEventCh chan imapEvents.Event,
|
||||
|
||||
@ -224,9 +224,9 @@ func newBridge(
|
||||
return nil, fmt.Errorf("failed to create focus service: %w", err)
|
||||
}
|
||||
|
||||
return &Bridge{
|
||||
bridge := &Bridge{
|
||||
vault: vault,
|
||||
users: users,
|
||||
users: make(map[string]*user.User),
|
||||
|
||||
api: api,
|
||||
proxyCtl: proxyCtl,
|
||||
@ -235,7 +235,6 @@ func newBridge(
|
||||
tlsConfig: tlsConfig,
|
||||
imapServer: imapServer,
|
||||
imapEventCh: imapEventCh,
|
||||
smtpServer: newSMTPServer(users, tlsConfig, logSMTP),
|
||||
|
||||
updater: updater,
|
||||
curVersion: curVersion,
|
||||
@ -249,7 +248,11 @@ func newBridge(
|
||||
logSMTP: logSMTP,
|
||||
|
||||
tasks: tasks,
|
||||
}, nil
|
||||
}
|
||||
|
||||
bridge.smtpServer = newSMTPServer(bridge, tlsConfig, logSMTP)
|
||||
|
||||
return bridge, nil
|
||||
}
|
||||
|
||||
// nolint:funlen
|
||||
@ -265,10 +268,10 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
|
||||
bridge.api.AddStatusObserver(func(status liteapi.Status) {
|
||||
switch {
|
||||
case status == liteapi.StatusUp:
|
||||
bridge.onStatusUp()
|
||||
go bridge.onStatusUp()
|
||||
|
||||
case status == liteapi.StatusDown:
|
||||
bridge.onStatusDown()
|
||||
go bridge.onStatusDown()
|
||||
}
|
||||
})
|
||||
|
||||
@ -356,9 +359,11 @@ func (bridge *Bridge) Close(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Close all users.
|
||||
bridge.users.IterValues(func(user *user.User) {
|
||||
user.Close()
|
||||
})
|
||||
safe.RLock(func() {
|
||||
for _, user := range bridge.users {
|
||||
user.Close()
|
||||
}
|
||||
}, &bridge.usersLock)
|
||||
|
||||
// Stop all ongoing tasks.
|
||||
bridge.tasks.Wait()
|
||||
@ -426,19 +431,23 @@ func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) {
|
||||
func (bridge *Bridge) onStatusUp() {
|
||||
bridge.publish(events.ConnStatusUp{})
|
||||
|
||||
bridge.goLoad()
|
||||
safe.RLock(func() {
|
||||
for _, user := range bridge.users {
|
||||
user.OnStatusUp()
|
||||
}
|
||||
}, &bridge.usersLock)
|
||||
|
||||
bridge.users.IterValues(func(user *user.User) {
|
||||
go user.OnStatusUp()
|
||||
})
|
||||
bridge.goLoad()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) onStatusDown() {
|
||||
bridge.publish(events.ConnStatusDown{})
|
||||
|
||||
bridge.users.IterValues(func(user *user.User) {
|
||||
go user.OnStatusDown()
|
||||
})
|
||||
safe.RLock(func() {
|
||||
for _, user := range bridge.users {
|
||||
user.OnStatusDown()
|
||||
}
|
||||
}, &bridge.usersLock)
|
||||
|
||||
bridge.tasks.Once(func(ctx context.Context) {
|
||||
backoff := time.Second
|
||||
|
||||
@ -18,18 +18,24 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
)
|
||||
|
||||
// ConfigureAppleMail configures apple mail for the given userID and address.
|
||||
// If configuring apple mail for Catalina or newer, it ensures Bridge is using SSL.
|
||||
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
|
||||
return safe.RLockRet(func() error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if address == "" {
|
||||
address = user.Emails()[0]
|
||||
}
|
||||
@ -42,7 +48,6 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||
addresses = strings.Join(user.Emails(), ",")
|
||||
}
|
||||
|
||||
// If configuring apple mail for Catalina or newer, users should use SSL.
|
||||
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
|
||||
if err := bridge.SetSMTPSSL(true); err != nil {
|
||||
return err
|
||||
@ -59,11 +64,5 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||
addresses,
|
||||
user.BridgePass(),
|
||||
)
|
||||
}); !ok {
|
||||
return ErrNoSuchUser
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to configure apple mail: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
@ -32,6 +32,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/async"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xsync"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -83,6 +84,44 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addIMAPUser connects the given user to gluon.
|
||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IMAP connectors: %w", err)
|
||||
}
|
||||
|
||||
for addrID, imapConn := range imapConn {
|
||||
if gluonID, ok := user.GetGluonID(addrID); ok {
|
||||
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
||||
return fmt.Errorf("failed to load IMAP user: %w", err)
|
||||
}
|
||||
} else {
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeIMAPUser disconnects the given user from gluon, optionally also removing its files.
|
||||
func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withFiles bool) error {
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
|
||||
switch event := event.(type) {
|
||||
case imapEvents.SessionAdded:
|
||||
|
||||
@ -24,8 +24,8 @@ import (
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -118,48 +118,50 @@ func (bridge *Bridge) GetGluonDir() string {
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
|
||||
if newGluonDir == bridge.GetGluonDir() {
|
||||
return fmt.Errorf("new gluon dir is the same as the old one")
|
||||
}
|
||||
return safe.RLockRet(func() error {
|
||||
if newGluonDir == bridge.GetGluonDir() {
|
||||
return fmt.Errorf("new gluon dir is the same as the old one")
|
||||
}
|
||||
|
||||
if err := bridge.closeIMAP(context.Background()); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP: %w", err)
|
||||
}
|
||||
if err := bridge.closeIMAP(context.Background()); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP: %w", err)
|
||||
}
|
||||
|
||||
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
|
||||
return fmt.Errorf("failed to move gluon dir: %w", err)
|
||||
}
|
||||
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
|
||||
return fmt.Errorf("failed to move gluon dir: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
|
||||
return fmt.Errorf("failed to set new gluon dir: %w", err)
|
||||
}
|
||||
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
|
||||
return fmt.Errorf("failed to set new gluon dir: %w", err)
|
||||
}
|
||||
|
||||
imapServer, err := newIMAPServer(
|
||||
bridge.vault.GetGluonDir(),
|
||||
bridge.curVersion,
|
||||
bridge.tlsConfig,
|
||||
bridge.logIMAPClient,
|
||||
bridge.logIMAPServer,
|
||||
bridge.imapEventCh,
|
||||
bridge.tasks,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new IMAP server: %w", err)
|
||||
}
|
||||
imapServer, err := newIMAPServer(
|
||||
bridge.vault.GetGluonDir(),
|
||||
bridge.curVersion,
|
||||
bridge.tlsConfig,
|
||||
bridge.logIMAPClient,
|
||||
bridge.logIMAPServer,
|
||||
bridge.imapEventCh,
|
||||
bridge.tasks,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new IMAP server: %w", err)
|
||||
}
|
||||
|
||||
bridge.imapServer = imapServer
|
||||
bridge.imapServer = imapServer
|
||||
|
||||
if err := bridge.users.IterValuesErr(func(user *user.User) error {
|
||||
return bridge.addIMAPUser(ctx, user)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
|
||||
}
|
||||
for _, user := range bridge.users {
|
||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.serveIMAP(); err != nil {
|
||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
if err := bridge.serveIMAP(); err != nil {
|
||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetProxyAllowed() bool {
|
||||
@ -181,11 +183,13 @@ func (bridge *Bridge) GetShowAllMail() bool {
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetShowAllMail(show bool) error {
|
||||
bridge.users.IterValues(func(user *user.User) {
|
||||
user.SetShowAllMail(show)
|
||||
})
|
||||
return safe.RLockRet(func() error {
|
||||
for _, user := range bridge.users {
|
||||
user.SetShowAllMail(show)
|
||||
}
|
||||
|
||||
return bridge.vault.SetShowAllMail(show)
|
||||
return bridge.vault.SetShowAllMail(show)
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAutostart() bool {
|
||||
@ -273,14 +277,18 @@ func (bridge *Bridge) SetColorScheme(colorScheme string) error {
|
||||
}
|
||||
|
||||
func (bridge *Bridge) FactoryReset(ctx context.Context) {
|
||||
// First delete all users.
|
||||
for _, userID := range bridge.GetUserIDs() {
|
||||
if bridge.users.Has(userID) {
|
||||
if err := bridge.DeleteUser(ctx, userID); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to delete user %s", userID)
|
||||
// Delete all the users.
|
||||
safe.Lock(func() {
|
||||
for _, user := range bridge.users {
|
||||
bridge.logoutUser(ctx, user, true)
|
||||
}
|
||||
|
||||
for _, user := range bridge.vault.GetUserIDs() {
|
||||
if err := bridge.vault.DeleteUser(user); err != nil {
|
||||
logrus.WithError(err).Error("failed to delete vault user")
|
||||
}
|
||||
}
|
||||
}
|
||||
}, &bridge.usersLock)
|
||||
|
||||
// Then delete all files.
|
||||
if err := bridge.locator.Clear(); err != nil {
|
||||
|
||||
@ -23,8 +23,6 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/emersion/go-smtp"
|
||||
@ -57,7 +55,7 @@ func (bridge *Bridge) restartSMTP() error {
|
||||
return fmt.Errorf("failed to close SMTP: %w", err)
|
||||
}
|
||||
|
||||
bridge.smtpServer = newSMTPServer(bridge.users, bridge.tlsConfig, bridge.logSMTP)
|
||||
bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
|
||||
|
||||
return bridge.serveSMTP()
|
||||
}
|
||||
@ -80,8 +78,8 @@ func (bridge *Bridge) closeSMTP() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, shouldLog bool) *smtp.Server {
|
||||
smtpServer := smtp.NewServer(&smtpBackend{users})
|
||||
func newSMTPServer(bridge *Bridge, tlsConfig *tls.Config, shouldLog bool) *smtp.Server {
|
||||
smtpServer := smtp.NewServer(&smtpBackend{Bridge: bridge})
|
||||
|
||||
smtpServer.TLSConfig = tlsConfig
|
||||
smtpServer.Domain = constants.Host
|
||||
@ -94,6 +92,7 @@ func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, s
|
||||
log.Warning("================================================")
|
||||
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
|
||||
log.Warning("================================================")
|
||||
|
||||
smtpServer.Debug = logging.NewSMTPDebugLogger()
|
||||
}
|
||||
|
||||
|
||||
@ -22,16 +22,15 @@ import (
|
||||
"io"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/emersion/go-smtp"
|
||||
)
|
||||
|
||||
type smtpBackend struct {
|
||||
users *safe.Map[string, *user.User]
|
||||
*Bridge
|
||||
}
|
||||
|
||||
type smtpSession struct {
|
||||
users *safe.Map[string, *user.User]
|
||||
*Bridge
|
||||
|
||||
userID string
|
||||
authID string
|
||||
@ -40,15 +39,13 @@ type smtpSession struct {
|
||||
to []string
|
||||
}
|
||||
|
||||
func (be *smtpBackend) NewSession(_ *smtp.Conn) (smtp.Session, error) {
|
||||
return &smtpSession{
|
||||
users: be.users,
|
||||
}, nil
|
||||
func (be *smtpBackend) NewSession(*smtp.Conn) (smtp.Session, error) {
|
||||
return &smtpSession{Bridge: be.Bridge}, nil
|
||||
}
|
||||
|
||||
func (s *smtpSession) AuthPlain(username, password string) error {
|
||||
return s.users.ValuesErr(func(users []*user.User) error {
|
||||
for _, user := range users {
|
||||
return safe.RLockRet(func() error {
|
||||
for _, user := range s.users {
|
||||
addrID, err := user.CheckAuth(username, []byte(password))
|
||||
if err != nil {
|
||||
continue
|
||||
@ -61,7 +58,7 @@ func (s *smtpSession) AuthPlain(username, password string) error {
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid username or password")
|
||||
})
|
||||
}, &s.usersLock)
|
||||
}
|
||||
|
||||
func (s *smtpSession) Reset() {
|
||||
@ -88,13 +85,12 @@ func (s *smtpSession) Rcpt(to string) error {
|
||||
}
|
||||
|
||||
func (s *smtpSession) Data(r io.Reader) error {
|
||||
if ok, err := s.users.GetErr(s.userID, func(user *user.User) error {
|
||||
return user.SendMail(s.authID, s.from, s.to, r)
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such user %q", s.userID)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to send mail: %w", err)
|
||||
}
|
||||
return safe.RLockRet(func() error {
|
||||
user, ok := s.users[s.userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
return nil
|
||||
return user.SendMail(s.authID, s.from, s.to, r)
|
||||
}, &s.usersLock)
|
||||
}
|
||||
|
||||
@ -66,32 +66,34 @@ func (bridge *Bridge) GetUserIDs() []string {
|
||||
|
||||
// GetUserInfo returns info about the given user.
|
||||
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
||||
if info, ok := safe.MapGetRet(bridge.users, userID, getConnUserInfo); ok {
|
||||
return safe.RLockRetErr(func() (UserInfo, error) {
|
||||
if user, ok := bridge.users[userID]; ok {
|
||||
return getConnUserInfo(user), nil
|
||||
}
|
||||
|
||||
var info UserInfo
|
||||
|
||||
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
||||
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
||||
}); err != nil {
|
||||
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
var info UserInfo
|
||||
|
||||
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
||||
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
||||
}); err != nil {
|
||||
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// QueryUserInfo queries the user info by username or address.
|
||||
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
|
||||
return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) {
|
||||
for _, user := range users {
|
||||
return safe.RLockRetErr(func() (UserInfo, error) {
|
||||
for _, user := range bridge.users {
|
||||
if user.Match(query) {
|
||||
return getConnUserInfo(user), nil
|
||||
}
|
||||
}
|
||||
|
||||
return UserInfo{}, ErrNoSuchUser
|
||||
})
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// LoginAuth begins the login process. It returns an authorized client that might need 2FA.
|
||||
@ -101,7 +103,9 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [
|
||||
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
|
||||
}
|
||||
|
||||
if bridge.users.Has(auth.UserID) {
|
||||
if ok := safe.RLockRet(func() bool {
|
||||
return mapHas(bridge.users, auth.UID)
|
||||
}, &bridge.usersLock); ok {
|
||||
if err := client.AuthDelete(ctx); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to delete auth")
|
||||
}
|
||||
@ -182,31 +186,56 @@ func (bridge *Bridge) LoginFull(
|
||||
|
||||
// LogoutUser logs out the given user.
|
||||
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
||||
if err := bridge.logoutUser(ctx, userID); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
return safe.LockRet(func() error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
defer delete(bridge.users, user.ID())
|
||||
|
||||
return nil
|
||||
bridge.logoutUser(ctx, user, true)
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// DeleteUser deletes the given user.
|
||||
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
||||
bridge.deleteUser(ctx, userID)
|
||||
return safe.LockRet(func() error {
|
||||
if !bridge.vault.HasUser(userID) {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
bridge.publish(events.UserDeleted{
|
||||
UserID: userID,
|
||||
})
|
||||
if user, ok := bridge.users[userID]; ok {
|
||||
defer delete(bridge.users, user.ID())
|
||||
bridge.logoutUser(ctx, user, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to delete vault user")
|
||||
}
|
||||
|
||||
bridge.publish(events.UserDeleted{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
// SetAddressMode sets the address mode for the given user.
|
||||
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
|
||||
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
|
||||
return safe.RLockRet(func() error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if user.GetAddressMode() == mode {
|
||||
return fmt.Errorf("address mode is already %q", mode)
|
||||
}
|
||||
@ -231,13 +260,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
||||
})
|
||||
|
||||
return nil
|
||||
}); !ok {
|
||||
return ErrNoSuchUser
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
|
||||
@ -266,7 +289,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
|
||||
// loadUsers tries to load each user in the vault that isn't already loaded.
|
||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
||||
return bridge.vault.ForUser(func(user *vault.User) error {
|
||||
if bridge.users.Has(user.UserID()) || user.AuthUID() == "" {
|
||||
if user.AuthUID() == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if safe.RLockRet(func() bool {
|
||||
return mapHas(bridge.users, user.UserID())
|
||||
}, &bridge.usersLock) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -293,20 +322,16 @@ func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
||||
return fmt.Errorf("failed to set auth: %w", err)
|
||||
}
|
||||
|
||||
return try.Catch(
|
||||
func() error {
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// addUser adds a new user with an already salted mailbox password.
|
||||
@ -364,10 +389,6 @@ func (bridge *Bridge) addUserWithVault(
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
if had := bridge.users.Set(apiUser.ID, user); had {
|
||||
panic("double add")
|
||||
}
|
||||
|
||||
// Connect the user's address(es) to gluon.
|
||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
@ -395,6 +416,11 @@ func (bridge *Bridge) addUserWithVault(
|
||||
return nil
|
||||
})
|
||||
|
||||
// Finally, save the user in the bridge.
|
||||
safe.Lock(func() {
|
||||
bridge.users[apiUser.ID] = user
|
||||
}, &bridge.usersLock)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -430,75 +456,17 @@ func (bridge *Bridge) newVaultUser(
|
||||
return user, false, nil
|
||||
}
|
||||
|
||||
// addIMAPUser connects the given user to gluon.
|
||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IMAP connectors: %w", err)
|
||||
// logout logs out the given user, optionally logging them out from the API too.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI bool) {
|
||||
if err := bridge.removeIMAPUser(ctx, user, false); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
|
||||
for addrID, imapConn := range imapConn {
|
||||
if gluonID, ok := user.GetGluonID(addrID); ok {
|
||||
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
||||
return fmt.Errorf("failed to load IMAP user: %w", err)
|
||||
}
|
||||
} else {
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||
}
|
||||
}
|
||||
if err := user.Logout(ctx, withAPI); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logoutUser logs the given user out from bridge.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
||||
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
user.Close()
|
||||
}); !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteUser deletes the given user from bridge.
|
||||
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
||||
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
user.Close()
|
||||
}); !ok {
|
||||
logrus.Debug("The bridge user was not connected")
|
||||
}
|
||||
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to delete user from vault")
|
||||
}
|
||||
user.Close()
|
||||
}
|
||||
|
||||
// getUserInfo returns information about a disconnected user.
|
||||
@ -523,3 +491,8 @@ func getConnUserInfo(user *user.User) UserInfo {
|
||||
MaxSpace: user.MaxSpace(),
|
||||
}
|
||||
}
|
||||
|
||||
func mapHas[Key comparable, Val any](m map[Key]Val, key Key) bool {
|
||||
_, ok := m[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
)
|
||||
@ -44,9 +45,11 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
|
||||
}
|
||||
|
||||
case events.UserDeauth:
|
||||
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
|
||||
return fmt.Errorf("failed to logout user: %w", err)
|
||||
}
|
||||
safe.Lock(func() {
|
||||
defer delete(bridge.users, user.ID())
|
||||
|
||||
bridge.logoutUser(ctx, user, false)
|
||||
}, &bridge.usersLock)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -147,6 +147,46 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeauthRestartLogin(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||
|
||||
// Get a channel to receive the deauth event.
|
||||
eventCh, done := bridge.GetEvents(events.UserDeauth{})
|
||||
defer done()
|
||||
|
||||
// Deauth the user.
|
||||
require.NoError(t, s.RevokeUser(userID))
|
||||
|
||||
// The user is eventually disconnected.
|
||||
require.Eventually(t, func() bool {
|
||||
return len(getConnectedUserIDs(t, bridge)) == 0
|
||||
}, 10*time.Second, time.Second)
|
||||
|
||||
// We should get a deauth event.
|
||||
require.IsType(t, events.UserDeauth{}, <-eventCh)
|
||||
})
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user should be disconnected at startup.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Login the user after the disconnection.
|
||||
newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||
require.Equal(t, userID, newUserID)
|
||||
|
||||
// The user is connected again.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||
const authLife = 2 * time.Second
|
||||
|
||||
@ -449,6 +489,82 @@ func TestBridge_LoginLogoutRepeated(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LogoutOffline(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The user is now connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Go offline.
|
||||
netCtl.Disable()
|
||||
|
||||
// We can still log the user out.
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
|
||||
// The user is now disconnected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_DeleteDisconnected(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The user is now connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Logout the user.
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
|
||||
// The user is now disconnected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Delete the user.
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
|
||||
// The user is now deleted.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_DeleteOffline(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The user is now connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Go offline.
|
||||
netCtl.Disable()
|
||||
|
||||
// We can still log the user out.
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
|
||||
// The user is now gone.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// getErr returns the error that was passed to it.
|
||||
func getErr[T any](val T, err error) error {
|
||||
return err
|
||||
|
||||
87
internal/safe/mutex.go
Normal file
87
internal/safe/mutex.go
Normal file
@ -0,0 +1,87 @@
|
||||
package safe
|
||||
|
||||
type Mutex interface {
|
||||
Lock()
|
||||
Unlock()
|
||||
}
|
||||
|
||||
func Lock(fn func(), m ...Mutex) {
|
||||
if len(m) == 0 {
|
||||
panic("no mutexes provided")
|
||||
}
|
||||
|
||||
for _, m := range m {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
}
|
||||
|
||||
fn()
|
||||
}
|
||||
|
||||
func LockRet[T any](fn func() T, m ...Mutex) T {
|
||||
var ret T
|
||||
|
||||
Lock(func() {
|
||||
ret = fn()
|
||||
}, m...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
|
||||
var ret T
|
||||
|
||||
err := LockRet(func() error {
|
||||
var err error
|
||||
|
||||
ret, err = fn()
|
||||
|
||||
return err
|
||||
}, m...)
|
||||
|
||||
return ret, err
|
||||
}
|
||||
|
||||
type RWMutex interface {
|
||||
Mutex
|
||||
|
||||
RLock()
|
||||
RUnlock()
|
||||
}
|
||||
|
||||
func RLock(fn func(), m ...RWMutex) {
|
||||
if len(m) == 0 {
|
||||
panic("no mutexes provided")
|
||||
}
|
||||
|
||||
for _, m := range m {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
}
|
||||
|
||||
fn()
|
||||
}
|
||||
|
||||
func RLockRet[T any](fn func() T, m ...RWMutex) T {
|
||||
var ret T
|
||||
|
||||
RLock(func() {
|
||||
ret = fn()
|
||||
}, m...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) {
|
||||
var err error
|
||||
|
||||
ret := RLockRet(func() T {
|
||||
var ret T
|
||||
|
||||
ret, err = fn()
|
||||
|
||||
return ret
|
||||
}, m...)
|
||||
|
||||
return ret, err
|
||||
}
|
||||
@ -402,11 +402,13 @@ func (user *User) OnStatusDown() {
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
func (user *User) Logout(ctx context.Context) error {
|
||||
func (user *User) Logout(ctx context.Context, withAPI bool) error {
|
||||
user.tasks.Wait()
|
||||
|
||||
if err := user.client.AuthDelete(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete auth: %w", err)
|
||||
if withAPI {
|
||||
if err := user.client.AuthDelete(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete auth: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.vault.Clear(); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user