mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2026-02-14 12:48:33 +00:00
Other(refactor): Less unwieldy user type in Bridge
Instead of the annoying safe.Map type, we just use a normal go map and mutex pair, and use safe.Lock/safe.RLock as a helper function
This commit is contained in:
@ -51,8 +51,8 @@ type Bridge struct {
|
|||||||
vault *vault.Vault
|
vault *vault.Vault
|
||||||
|
|
||||||
// users holds authorized users.
|
// users holds authorized users.
|
||||||
users *safe.Map[string, *user.User]
|
users map[string]*user.User
|
||||||
goLoad func()
|
usersLock sync.RWMutex
|
||||||
|
|
||||||
// api manages user API clients.
|
// api manages user API clients.
|
||||||
api *liteapi.Manager
|
api *liteapi.Manager
|
||||||
@ -73,7 +73,6 @@ type Bridge struct {
|
|||||||
|
|
||||||
// updater is the bridge's updater.
|
// updater is the bridge's updater.
|
||||||
updater Updater
|
updater Updater
|
||||||
goUpdate func()
|
|
||||||
curVersion *semver.Version
|
curVersion *semver.Version
|
||||||
|
|
||||||
// focusService is used to raise the bridge window when needed.
|
// focusService is used to raise the bridge window when needed.
|
||||||
@ -99,6 +98,12 @@ type Bridge struct {
|
|||||||
|
|
||||||
// tasks manages the bridge's goroutines.
|
// tasks manages the bridge's goroutines.
|
||||||
tasks *xsync.Group
|
tasks *xsync.Group
|
||||||
|
|
||||||
|
// goLoad triggers a load of disconnected users from the vault.
|
||||||
|
goLoad func()
|
||||||
|
|
||||||
|
// goUpdate triggers a check/install of updates.
|
||||||
|
goUpdate func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new bridge.
|
// New creates a new bridge.
|
||||||
@ -133,12 +138,8 @@ func New( //nolint:funlen
|
|||||||
// imapEventCh forwards IMAP events from gluon instances to the bridge for processing.
|
// imapEventCh forwards IMAP events from gluon instances to the bridge for processing.
|
||||||
imapEventCh := make(chan imapEvents.Event)
|
imapEventCh := make(chan imapEvents.Event)
|
||||||
|
|
||||||
// users holds all the bridge's users.
|
|
||||||
users := safe.NewMap[string, *user.User](nil)
|
|
||||||
|
|
||||||
// bridge is the bridge.
|
// bridge is the bridge.
|
||||||
bridge, err := newBridge(
|
bridge, err := newBridge(
|
||||||
users,
|
|
||||||
tasks,
|
tasks,
|
||||||
imapEventCh,
|
imapEventCh,
|
||||||
|
|
||||||
@ -180,7 +181,6 @@ func New( //nolint:funlen
|
|||||||
|
|
||||||
// nolint:funlen
|
// nolint:funlen
|
||||||
func newBridge(
|
func newBridge(
|
||||||
users *safe.Map[string, *user.User],
|
|
||||||
tasks *xsync.Group,
|
tasks *xsync.Group,
|
||||||
imapEventCh chan imapEvents.Event,
|
imapEventCh chan imapEvents.Event,
|
||||||
|
|
||||||
@ -224,9 +224,9 @@ func newBridge(
|
|||||||
return nil, fmt.Errorf("failed to create focus service: %w", err)
|
return nil, fmt.Errorf("failed to create focus service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Bridge{
|
bridge := &Bridge{
|
||||||
vault: vault,
|
vault: vault,
|
||||||
users: users,
|
users: make(map[string]*user.User),
|
||||||
|
|
||||||
api: api,
|
api: api,
|
||||||
proxyCtl: proxyCtl,
|
proxyCtl: proxyCtl,
|
||||||
@ -235,7 +235,6 @@ func newBridge(
|
|||||||
tlsConfig: tlsConfig,
|
tlsConfig: tlsConfig,
|
||||||
imapServer: imapServer,
|
imapServer: imapServer,
|
||||||
imapEventCh: imapEventCh,
|
imapEventCh: imapEventCh,
|
||||||
smtpServer: newSMTPServer(users, tlsConfig, logSMTP),
|
|
||||||
|
|
||||||
updater: updater,
|
updater: updater,
|
||||||
curVersion: curVersion,
|
curVersion: curVersion,
|
||||||
@ -249,7 +248,11 @@ func newBridge(
|
|||||||
logSMTP: logSMTP,
|
logSMTP: logSMTP,
|
||||||
|
|
||||||
tasks: tasks,
|
tasks: tasks,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
bridge.smtpServer = newSMTPServer(bridge, tlsConfig, logSMTP)
|
||||||
|
|
||||||
|
return bridge, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:funlen
|
// nolint:funlen
|
||||||
@ -265,10 +268,10 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
|
|||||||
bridge.api.AddStatusObserver(func(status liteapi.Status) {
|
bridge.api.AddStatusObserver(func(status liteapi.Status) {
|
||||||
switch {
|
switch {
|
||||||
case status == liteapi.StatusUp:
|
case status == liteapi.StatusUp:
|
||||||
bridge.onStatusUp()
|
go bridge.onStatusUp()
|
||||||
|
|
||||||
case status == liteapi.StatusDown:
|
case status == liteapi.StatusDown:
|
||||||
bridge.onStatusDown()
|
go bridge.onStatusDown()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -356,9 +359,11 @@ func (bridge *Bridge) Close(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close all users.
|
// Close all users.
|
||||||
bridge.users.IterValues(func(user *user.User) {
|
safe.RLock(func() {
|
||||||
user.Close()
|
for _, user := range bridge.users {
|
||||||
})
|
user.Close()
|
||||||
|
}
|
||||||
|
}, &bridge.usersLock)
|
||||||
|
|
||||||
// Stop all ongoing tasks.
|
// Stop all ongoing tasks.
|
||||||
bridge.tasks.Wait()
|
bridge.tasks.Wait()
|
||||||
@ -426,19 +431,23 @@ func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) {
|
|||||||
func (bridge *Bridge) onStatusUp() {
|
func (bridge *Bridge) onStatusUp() {
|
||||||
bridge.publish(events.ConnStatusUp{})
|
bridge.publish(events.ConnStatusUp{})
|
||||||
|
|
||||||
bridge.goLoad()
|
safe.RLock(func() {
|
||||||
|
for _, user := range bridge.users {
|
||||||
|
user.OnStatusUp()
|
||||||
|
}
|
||||||
|
}, &bridge.usersLock)
|
||||||
|
|
||||||
bridge.users.IterValues(func(user *user.User) {
|
bridge.goLoad()
|
||||||
go user.OnStatusUp()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) onStatusDown() {
|
func (bridge *Bridge) onStatusDown() {
|
||||||
bridge.publish(events.ConnStatusDown{})
|
bridge.publish(events.ConnStatusDown{})
|
||||||
|
|
||||||
bridge.users.IterValues(func(user *user.User) {
|
safe.RLock(func() {
|
||||||
go user.OnStatusDown()
|
for _, user := range bridge.users {
|
||||||
})
|
user.OnStatusDown()
|
||||||
|
}
|
||||||
|
}, &bridge.usersLock)
|
||||||
|
|
||||||
bridge.tasks.Once(func(ctx context.Context) {
|
bridge.tasks.Once(func(ctx context.Context) {
|
||||||
backoff := time.Second
|
backoff := time.Second
|
||||||
|
|||||||
@ -18,18 +18,24 @@
|
|||||||
package bridge
|
package bridge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
|
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ConfigureAppleMail configures apple mail for the given userID and address.
|
||||||
|
// If configuring apple mail for Catalina or newer, it ensures Bridge is using SSL.
|
||||||
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||||
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
|
return safe.RLockRet(func() error {
|
||||||
|
user, ok := bridge.users[userID]
|
||||||
|
if !ok {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
}
|
||||||
|
|
||||||
if address == "" {
|
if address == "" {
|
||||||
address = user.Emails()[0]
|
address = user.Emails()[0]
|
||||||
}
|
}
|
||||||
@ -42,7 +48,6 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
|||||||
addresses = strings.Join(user.Emails(), ",")
|
addresses = strings.Join(user.Emails(), ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
// If configuring apple mail for Catalina or newer, users should use SSL.
|
|
||||||
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
|
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
|
||||||
if err := bridge.SetSMTPSSL(true); err != nil {
|
if err := bridge.SetSMTPSSL(true); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -59,11 +64,5 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
|||||||
addresses,
|
addresses,
|
||||||
user.BridgePass(),
|
user.BridgePass(),
|
||||||
)
|
)
|
||||||
}); !ok {
|
}, &bridge.usersLock)
|
||||||
return ErrNoSuchUser
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("failed to configure apple mail: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v2/internal/async"
|
"github.com/ProtonMail/proton-bridge/v2/internal/async"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/bradenaw/juniper/xsync"
|
"github.com/bradenaw/juniper/xsync"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -83,6 +84,44 @@ func (bridge *Bridge) closeIMAP(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addIMAPUser connects the given user to gluon.
|
||||||
|
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||||
|
imapConn, err := user.NewIMAPConnectors()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create IMAP connectors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for addrID, imapConn := range imapConn {
|
||||||
|
if gluonID, ok := user.GetGluonID(addrID); ok {
|
||||||
|
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
||||||
|
return fmt.Errorf("failed to load IMAP user: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||||
|
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeIMAPUser disconnects the given user from gluon, optionally also removing its files.
|
||||||
|
func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withFiles bool) error {
|
||||||
|
for _, gluonID := range user.GetGluonIDs() {
|
||||||
|
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
|
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
|
||||||
switch event := event.(type) {
|
switch event := event.(type) {
|
||||||
case imapEvents.SessionAdded:
|
case imapEvents.SessionAdded:
|
||||||
|
|||||||
@ -24,8 +24,8 @@ import (
|
|||||||
|
|
||||||
"github.com/Masterminds/semver/v3"
|
"github.com/Masterminds/semver/v3"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -118,48 +118,50 @@ func (bridge *Bridge) GetGluonDir() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
|
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
|
||||||
if newGluonDir == bridge.GetGluonDir() {
|
return safe.RLockRet(func() error {
|
||||||
return fmt.Errorf("new gluon dir is the same as the old one")
|
if newGluonDir == bridge.GetGluonDir() {
|
||||||
}
|
return fmt.Errorf("new gluon dir is the same as the old one")
|
||||||
|
}
|
||||||
|
|
||||||
if err := bridge.closeIMAP(context.Background()); err != nil {
|
if err := bridge.closeIMAP(context.Background()); err != nil {
|
||||||
return fmt.Errorf("failed to close IMAP: %w", err)
|
return fmt.Errorf("failed to close IMAP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
|
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
|
||||||
return fmt.Errorf("failed to move gluon dir: %w", err)
|
return fmt.Errorf("failed to move gluon dir: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
|
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
|
||||||
return fmt.Errorf("failed to set new gluon dir: %w", err)
|
return fmt.Errorf("failed to set new gluon dir: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
imapServer, err := newIMAPServer(
|
imapServer, err := newIMAPServer(
|
||||||
bridge.vault.GetGluonDir(),
|
bridge.vault.GetGluonDir(),
|
||||||
bridge.curVersion,
|
bridge.curVersion,
|
||||||
bridge.tlsConfig,
|
bridge.tlsConfig,
|
||||||
bridge.logIMAPClient,
|
bridge.logIMAPClient,
|
||||||
bridge.logIMAPServer,
|
bridge.logIMAPServer,
|
||||||
bridge.imapEventCh,
|
bridge.imapEventCh,
|
||||||
bridge.tasks,
|
bridge.tasks,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create new IMAP server: %w", err)
|
return fmt.Errorf("failed to create new IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.imapServer = imapServer
|
bridge.imapServer = imapServer
|
||||||
|
|
||||||
if err := bridge.users.IterValuesErr(func(user *user.User) error {
|
for _, user := range bridge.users {
|
||||||
return bridge.addIMAPUser(ctx, user)
|
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||||
}); err != nil {
|
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
|
||||||
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bridge.serveIMAP(); err != nil {
|
if err := bridge.serveIMAP(); err != nil {
|
||||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
}, &bridge.usersLock)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) GetProxyAllowed() bool {
|
func (bridge *Bridge) GetProxyAllowed() bool {
|
||||||
@ -181,11 +183,13 @@ func (bridge *Bridge) GetShowAllMail() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) SetShowAllMail(show bool) error {
|
func (bridge *Bridge) SetShowAllMail(show bool) error {
|
||||||
bridge.users.IterValues(func(user *user.User) {
|
return safe.RLockRet(func() error {
|
||||||
user.SetShowAllMail(show)
|
for _, user := range bridge.users {
|
||||||
})
|
user.SetShowAllMail(show)
|
||||||
|
}
|
||||||
|
|
||||||
return bridge.vault.SetShowAllMail(show)
|
return bridge.vault.SetShowAllMail(show)
|
||||||
|
}, &bridge.usersLock)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) GetAutostart() bool {
|
func (bridge *Bridge) GetAutostart() bool {
|
||||||
@ -273,14 +277,18 @@ func (bridge *Bridge) SetColorScheme(colorScheme string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) FactoryReset(ctx context.Context) {
|
func (bridge *Bridge) FactoryReset(ctx context.Context) {
|
||||||
// First delete all users.
|
// Delete all the users.
|
||||||
for _, userID := range bridge.GetUserIDs() {
|
safe.Lock(func() {
|
||||||
if bridge.users.Has(userID) {
|
for _, user := range bridge.users {
|
||||||
if err := bridge.DeleteUser(ctx, userID); err != nil {
|
bridge.logoutUser(ctx, user, true)
|
||||||
logrus.WithError(err).Errorf("Failed to delete user %s", userID)
|
}
|
||||||
|
|
||||||
|
for _, user := range bridge.vault.GetUserIDs() {
|
||||||
|
if err := bridge.vault.DeleteUser(user); err != nil {
|
||||||
|
logrus.WithError(err).Error("failed to delete vault user")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}, &bridge.usersLock)
|
||||||
|
|
||||||
// Then delete all files.
|
// Then delete all files.
|
||||||
if err := bridge.locator.Clear(); err != nil {
|
if err := bridge.locator.Clear(); err != nil {
|
||||||
|
|||||||
@ -23,8 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
@ -57,7 +55,7 @@ func (bridge *Bridge) restartSMTP() error {
|
|||||||
return fmt.Errorf("failed to close SMTP: %w", err)
|
return fmt.Errorf("failed to close SMTP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.smtpServer = newSMTPServer(bridge.users, bridge.tlsConfig, bridge.logSMTP)
|
bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
|
||||||
|
|
||||||
return bridge.serveSMTP()
|
return bridge.serveSMTP()
|
||||||
}
|
}
|
||||||
@ -80,8 +78,8 @@ func (bridge *Bridge) closeSMTP() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, shouldLog bool) *smtp.Server {
|
func newSMTPServer(bridge *Bridge, tlsConfig *tls.Config, shouldLog bool) *smtp.Server {
|
||||||
smtpServer := smtp.NewServer(&smtpBackend{users})
|
smtpServer := smtp.NewServer(&smtpBackend{Bridge: bridge})
|
||||||
|
|
||||||
smtpServer.TLSConfig = tlsConfig
|
smtpServer.TLSConfig = tlsConfig
|
||||||
smtpServer.Domain = constants.Host
|
smtpServer.Domain = constants.Host
|
||||||
@ -94,6 +92,7 @@ func newSMTPServer(users *safe.Map[string, *user.User], tlsConfig *tls.Config, s
|
|||||||
log.Warning("================================================")
|
log.Warning("================================================")
|
||||||
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
|
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
|
||||||
log.Warning("================================================")
|
log.Warning("================================================")
|
||||||
|
|
||||||
smtpServer.Debug = logging.NewSMTPDebugLogger()
|
smtpServer.Debug = logging.NewSMTPDebugLogger()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -22,16 +22,15 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type smtpBackend struct {
|
type smtpBackend struct {
|
||||||
users *safe.Map[string, *user.User]
|
*Bridge
|
||||||
}
|
}
|
||||||
|
|
||||||
type smtpSession struct {
|
type smtpSession struct {
|
||||||
users *safe.Map[string, *user.User]
|
*Bridge
|
||||||
|
|
||||||
userID string
|
userID string
|
||||||
authID string
|
authID string
|
||||||
@ -40,15 +39,13 @@ type smtpSession struct {
|
|||||||
to []string
|
to []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (be *smtpBackend) NewSession(_ *smtp.Conn) (smtp.Session, error) {
|
func (be *smtpBackend) NewSession(*smtp.Conn) (smtp.Session, error) {
|
||||||
return &smtpSession{
|
return &smtpSession{Bridge: be.Bridge}, nil
|
||||||
users: be.users,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *smtpSession) AuthPlain(username, password string) error {
|
func (s *smtpSession) AuthPlain(username, password string) error {
|
||||||
return s.users.ValuesErr(func(users []*user.User) error {
|
return safe.RLockRet(func() error {
|
||||||
for _, user := range users {
|
for _, user := range s.users {
|
||||||
addrID, err := user.CheckAuth(username, []byte(password))
|
addrID, err := user.CheckAuth(username, []byte(password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
@ -61,7 +58,7 @@ func (s *smtpSession) AuthPlain(username, password string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("invalid username or password")
|
return fmt.Errorf("invalid username or password")
|
||||||
})
|
}, &s.usersLock)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *smtpSession) Reset() {
|
func (s *smtpSession) Reset() {
|
||||||
@ -88,13 +85,12 @@ func (s *smtpSession) Rcpt(to string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *smtpSession) Data(r io.Reader) error {
|
func (s *smtpSession) Data(r io.Reader) error {
|
||||||
if ok, err := s.users.GetErr(s.userID, func(user *user.User) error {
|
return safe.RLockRet(func() error {
|
||||||
return user.SendMail(s.authID, s.from, s.to, r)
|
user, ok := s.users[s.userID]
|
||||||
}); !ok {
|
if !ok {
|
||||||
return fmt.Errorf("no such user %q", s.userID)
|
return ErrNoSuchUser
|
||||||
} else if err != nil {
|
}
|
||||||
return fmt.Errorf("failed to send mail: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return user.SendMail(s.authID, s.from, s.to, r)
|
||||||
|
}, &s.usersLock)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -66,32 +66,34 @@ func (bridge *Bridge) GetUserIDs() []string {
|
|||||||
|
|
||||||
// GetUserInfo returns info about the given user.
|
// GetUserInfo returns info about the given user.
|
||||||
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
||||||
if info, ok := safe.MapGetRet(bridge.users, userID, getConnUserInfo); ok {
|
return safe.RLockRetErr(func() (UserInfo, error) {
|
||||||
|
if user, ok := bridge.users[userID]; ok {
|
||||||
|
return getConnUserInfo(user), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var info UserInfo
|
||||||
|
|
||||||
|
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
||||||
|
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
||||||
|
}); err != nil {
|
||||||
|
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return info, nil
|
return info, nil
|
||||||
}
|
}, &bridge.usersLock)
|
||||||
|
|
||||||
var info UserInfo
|
|
||||||
|
|
||||||
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
|
||||||
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
|
||||||
}); err != nil {
|
|
||||||
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return info, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryUserInfo queries the user info by username or address.
|
// QueryUserInfo queries the user info by username or address.
|
||||||
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
|
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
|
||||||
return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) {
|
return safe.RLockRetErr(func() (UserInfo, error) {
|
||||||
for _, user := range users {
|
for _, user := range bridge.users {
|
||||||
if user.Match(query) {
|
if user.Match(query) {
|
||||||
return getConnUserInfo(user), nil
|
return getConnUserInfo(user), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return UserInfo{}, ErrNoSuchUser
|
return UserInfo{}, ErrNoSuchUser
|
||||||
})
|
}, &bridge.usersLock)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginAuth begins the login process. It returns an authorized client that might need 2FA.
|
// LoginAuth begins the login process. It returns an authorized client that might need 2FA.
|
||||||
@ -101,7 +103,9 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [
|
|||||||
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
|
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if bridge.users.Has(auth.UserID) {
|
if ok := safe.RLockRet(func() bool {
|
||||||
|
return mapHas(bridge.users, auth.UID)
|
||||||
|
}, &bridge.usersLock); ok {
|
||||||
if err := client.AuthDelete(ctx); err != nil {
|
if err := client.AuthDelete(ctx); err != nil {
|
||||||
logrus.WithError(err).Warn("Failed to delete auth")
|
logrus.WithError(err).Warn("Failed to delete auth")
|
||||||
}
|
}
|
||||||
@ -182,31 +186,56 @@ func (bridge *Bridge) LoginFull(
|
|||||||
|
|
||||||
// LogoutUser logs out the given user.
|
// LogoutUser logs out the given user.
|
||||||
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
||||||
if err := bridge.logoutUser(ctx, userID); err != nil {
|
return safe.LockRet(func() error {
|
||||||
return fmt.Errorf("failed to logout user: %w", err)
|
user, ok := bridge.users[userID]
|
||||||
}
|
if !ok {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
}
|
||||||
|
|
||||||
bridge.publish(events.UserLoggedOut{
|
defer delete(bridge.users, user.ID())
|
||||||
UserID: userID,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
bridge.logoutUser(ctx, user, true)
|
||||||
|
|
||||||
|
bridge.publish(events.UserLoggedOut{
|
||||||
|
UserID: userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, &bridge.usersLock)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser deletes the given user.
|
// DeleteUser deletes the given user.
|
||||||
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
||||||
bridge.deleteUser(ctx, userID)
|
return safe.LockRet(func() error {
|
||||||
|
if !bridge.vault.HasUser(userID) {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
}
|
||||||
|
|
||||||
bridge.publish(events.UserDeleted{
|
if user, ok := bridge.users[userID]; ok {
|
||||||
UserID: userID,
|
defer delete(bridge.users, user.ID())
|
||||||
})
|
bridge.logoutUser(ctx, user, true)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to delete vault user")
|
||||||
|
}
|
||||||
|
|
||||||
|
bridge.publish(events.UserDeleted{
|
||||||
|
UserID: userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, &bridge.usersLock)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAddressMode sets the address mode for the given user.
|
// SetAddressMode sets the address mode for the given user.
|
||||||
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
|
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
|
||||||
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
|
return safe.RLockRet(func() error {
|
||||||
|
user, ok := bridge.users[userID]
|
||||||
|
if !ok {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
}
|
||||||
|
|
||||||
if user.GetAddressMode() == mode {
|
if user.GetAddressMode() == mode {
|
||||||
return fmt.Errorf("address mode is already %q", mode)
|
return fmt.Errorf("address mode is already %q", mode)
|
||||||
}
|
}
|
||||||
@ -231,13 +260,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
|||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); !ok {
|
}, &bridge.usersLock)
|
||||||
return ErrNoSuchUser
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("failed to set address mode: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
|
func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) {
|
||||||
@ -266,7 +289,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
|
|||||||
// loadUsers tries to load each user in the vault that isn't already loaded.
|
// loadUsers tries to load each user in the vault that isn't already loaded.
|
||||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
||||||
return bridge.vault.ForUser(func(user *vault.User) error {
|
return bridge.vault.ForUser(func(user *vault.User) error {
|
||||||
if bridge.users.Has(user.UserID()) || user.AuthUID() == "" {
|
if user.AuthUID() == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if safe.RLockRet(func() bool {
|
||||||
|
return mapHas(bridge.users, user.UserID())
|
||||||
|
}, &bridge.usersLock) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,20 +322,16 @@ func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
|||||||
return fmt.Errorf("failed to set auth: %w", err)
|
return fmt.Errorf("failed to set auth: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return try.Catch(
|
apiUser, err := client.GetUser(ctx)
|
||||||
func() error {
|
if err != nil {
|
||||||
apiUser, err := client.GetUser(ctx)
|
return fmt.Errorf("failed to get user: %w", err)
|
||||||
if err != nil {
|
}
|
||||||
return fmt.Errorf("failed to get user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
|
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass(), false); err != nil {
|
||||||
return fmt.Errorf("failed to add user: %w", err)
|
return fmt.Errorf("failed to add user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addUser adds a new user with an already salted mailbox password.
|
// addUser adds a new user with an already salted mailbox password.
|
||||||
@ -364,10 +389,6 @@ func (bridge *Bridge) addUserWithVault(
|
|||||||
return fmt.Errorf("failed to create user: %w", err)
|
return fmt.Errorf("failed to create user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if had := bridge.users.Set(apiUser.ID, user); had {
|
|
||||||
panic("double add")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect the user's address(es) to gluon.
|
// Connect the user's address(es) to gluon.
|
||||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||||
@ -395,6 +416,11 @@ func (bridge *Bridge) addUserWithVault(
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Finally, save the user in the bridge.
|
||||||
|
safe.Lock(func() {
|
||||||
|
bridge.users[apiUser.ID] = user
|
||||||
|
}, &bridge.usersLock)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,75 +456,17 @@ func (bridge *Bridge) newVaultUser(
|
|||||||
return user, false, nil
|
return user, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addIMAPUser connects the given user to gluon.
|
// logout logs out the given user, optionally logging them out from the API too.
|
||||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI bool) {
|
||||||
imapConn, err := user.NewIMAPConnectors()
|
if err := bridge.removeIMAPUser(ctx, user, false); err != nil {
|
||||||
if err != nil {
|
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||||
return fmt.Errorf("failed to create IMAP connectors: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for addrID, imapConn := range imapConn {
|
if err := user.Logout(ctx, withAPI); err != nil {
|
||||||
if gluonID, ok := user.GetGluonID(addrID); ok {
|
logrus.WithError(err).Error("Failed to logout user")
|
||||||
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
|
||||||
return fmt.Errorf("failed to load IMAP user: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
|
||||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
user.Close()
|
||||||
}
|
|
||||||
|
|
||||||
// logoutUser logs the given user out from bridge.
|
|
||||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
|
||||||
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
|
|
||||||
for _, gluonID := range user.GetGluonIDs() {
|
|
||||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.Logout(ctx); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to logout user")
|
|
||||||
}
|
|
||||||
|
|
||||||
user.Close()
|
|
||||||
}); !ok {
|
|
||||||
return ErrNoSuchUser
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteUser deletes the given user from bridge.
|
|
||||||
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
|
||||||
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
|
|
||||||
for _, gluonID := range user.GetGluonIDs() {
|
|
||||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.Logout(ctx); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to logout user")
|
|
||||||
}
|
|
||||||
|
|
||||||
user.Close()
|
|
||||||
}); !ok {
|
|
||||||
logrus.Debug("The bridge user was not connected")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to delete user from vault")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserInfo returns information about a disconnected user.
|
// getUserInfo returns information about a disconnected user.
|
||||||
@ -523,3 +491,8 @@ func getConnUserInfo(user *user.User) UserInfo {
|
|||||||
MaxSpace: user.MaxSpace(),
|
MaxSpace: user.MaxSpace(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mapHas[Key comparable, Val any](m map[Key]Val, key Key) bool {
|
||||||
|
_, ok := m[key]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|||||||
@ -22,6 +22,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
)
|
)
|
||||||
@ -44,9 +45,11 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
|
|||||||
}
|
}
|
||||||
|
|
||||||
case events.UserDeauth:
|
case events.UserDeauth:
|
||||||
if err := bridge.logoutUser(context.Background(), event.UserID); err != nil {
|
safe.Lock(func() {
|
||||||
return fmt.Errorf("failed to logout user: %w", err)
|
defer delete(bridge.users, user.ID())
|
||||||
}
|
|
||||||
|
bridge.logoutUser(ctx, user, false)
|
||||||
|
}, &bridge.usersLock)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -147,6 +147,46 @@ func TestBridge_LoginDeauthLogin(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBridge_LoginDeauthRestartLogin(t *testing.T) {
|
||||||
|
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
var userID string
|
||||||
|
|
||||||
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
// Login the user.
|
||||||
|
userID = must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
|
// Get a channel to receive the deauth event.
|
||||||
|
eventCh, done := bridge.GetEvents(events.UserDeauth{})
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
// Deauth the user.
|
||||||
|
require.NoError(t, s.RevokeUser(userID))
|
||||||
|
|
||||||
|
// The user is eventually disconnected.
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return len(getConnectedUserIDs(t, bridge)) == 0
|
||||||
|
}, 10*time.Second, time.Second)
|
||||||
|
|
||||||
|
// We should get a deauth event.
|
||||||
|
require.IsType(t, events.UserDeauth{}, <-eventCh)
|
||||||
|
})
|
||||||
|
|
||||||
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
// The user should be disconnected at startup.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
|
|
||||||
|
// Login the user after the disconnection.
|
||||||
|
newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||||
|
require.Equal(t, userID, newUserID)
|
||||||
|
|
||||||
|
// The user is connected again.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestBridge_LoginExpireLogin(t *testing.T) {
|
func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||||
const authLife = 2 * time.Second
|
const authLife = 2 * time.Second
|
||||||
|
|
||||||
@ -449,6 +489,82 @@ func TestBridge_LoginLogoutRepeated(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBridge_LogoutOffline(t *testing.T) {
|
||||||
|
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
// Login the user.
|
||||||
|
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// The user is now connected.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||||
|
|
||||||
|
// Go offline.
|
||||||
|
netCtl.Disable()
|
||||||
|
|
||||||
|
// We can still log the user out.
|
||||||
|
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||||
|
|
||||||
|
// The user is now disconnected.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBridge_DeleteDisconnected(t *testing.T) {
|
||||||
|
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
// Login the user.
|
||||||
|
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// The user is now connected.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||||
|
|
||||||
|
// Logout the user.
|
||||||
|
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||||
|
|
||||||
|
// The user is now disconnected.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
|
|
||||||
|
// Delete the user.
|
||||||
|
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||||
|
|
||||||
|
// The user is now deleted.
|
||||||
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBridge_DeleteOffline(t *testing.T) {
|
||||||
|
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
// Login the user.
|
||||||
|
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// The user is now connected.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||||
|
|
||||||
|
// Go offline.
|
||||||
|
netCtl.Disable()
|
||||||
|
|
||||||
|
// We can still log the user out.
|
||||||
|
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||||
|
|
||||||
|
// The user is now gone.
|
||||||
|
require.Empty(t, bridge.GetUserIDs())
|
||||||
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// getErr returns the error that was passed to it.
|
// getErr returns the error that was passed to it.
|
||||||
func getErr[T any](val T, err error) error {
|
func getErr[T any](val T, err error) error {
|
||||||
return err
|
return err
|
||||||
|
|||||||
87
internal/safe/mutex.go
Normal file
87
internal/safe/mutex.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package safe
|
||||||
|
|
||||||
|
type Mutex interface {
|
||||||
|
Lock()
|
||||||
|
Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Lock(fn func(), m ...Mutex) {
|
||||||
|
if len(m) == 0 {
|
||||||
|
panic("no mutexes provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range m {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
func LockRet[T any](fn func() T, m ...Mutex) T {
|
||||||
|
var ret T
|
||||||
|
|
||||||
|
Lock(func() {
|
||||||
|
ret = fn()
|
||||||
|
}, m...)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
|
||||||
|
var ret T
|
||||||
|
|
||||||
|
err := LockRet(func() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ret, err = fn()
|
||||||
|
|
||||||
|
return err
|
||||||
|
}, m...)
|
||||||
|
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type RWMutex interface {
|
||||||
|
Mutex
|
||||||
|
|
||||||
|
RLock()
|
||||||
|
RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RLock(fn func(), m ...RWMutex) {
|
||||||
|
if len(m) == 0 {
|
||||||
|
panic("no mutexes provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range m {
|
||||||
|
m.RLock()
|
||||||
|
defer m.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RLockRet[T any](fn func() T, m ...RWMutex) T {
|
||||||
|
var ret T
|
||||||
|
|
||||||
|
RLock(func() {
|
||||||
|
ret = fn()
|
||||||
|
}, m...)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ret := RLockRet(func() T {
|
||||||
|
var ret T
|
||||||
|
|
||||||
|
ret, err = fn()
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}, m...)
|
||||||
|
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
@ -402,11 +402,13 @@ func (user *User) OnStatusDown() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Logout logs the user out from the API.
|
// Logout logs the user out from the API.
|
||||||
func (user *User) Logout(ctx context.Context) error {
|
func (user *User) Logout(ctx context.Context, withAPI bool) error {
|
||||||
user.tasks.Wait()
|
user.tasks.Wait()
|
||||||
|
|
||||||
if err := user.client.AuthDelete(ctx); err != nil {
|
if withAPI {
|
||||||
return fmt.Errorf("failed to delete auth: %w", err)
|
if err := user.client.AuthDelete(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete auth: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.vault.Clear(); err != nil {
|
if err := user.vault.Clear(); err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user