diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 21262186..797b1da4 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -52,7 +52,7 @@ type Bridge struct { // users holds authorized users. users map[string]*user.User - usersLock sync.RWMutex + usersLock safe.RWMutex // api manages user API clients. api *liteapi.Manager @@ -226,7 +226,9 @@ func newBridge( bridge := &Bridge{ vault: vault, - users: make(map[string]*user.User), + + users: make(map[string]*user.User), + usersLock: safe.NewRWMutex(), api: api, proxyCtl: proxyCtl, @@ -363,7 +365,7 @@ func (bridge *Bridge) Close(ctx context.Context) { for _, user := range bridge.users { user.Close() } - }, &bridge.usersLock) + }, bridge.usersLock) // Stop all ongoing tasks. bridge.tasks.Wait() @@ -433,7 +435,7 @@ func (bridge *Bridge) onStatusUp() { for _, user := range bridge.users { user.OnStatusUp() } - }, &bridge.usersLock) + }, bridge.usersLock) bridge.goLoad() } @@ -445,7 +447,7 @@ func (bridge *Bridge) onStatusDown() { for _, user := range bridge.users { user.OnStatusDown() } - }, &bridge.usersLock) + }, bridge.usersLock) bridge.tasks.Once(func(ctx context.Context) { backoff := time.Second diff --git a/internal/bridge/configure.go b/internal/bridge/configure.go index edd9d998..020c709f 100644 --- a/internal/bridge/configure.go +++ b/internal/bridge/configure.go @@ -64,5 +64,5 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error { addresses, user.BridgePass(), ) - }, &bridge.usersLock) + }, bridge.usersLock) } diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 84e7f06f..579e8d1c 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -161,7 +161,7 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error } return nil - }, &bridge.usersLock) + }, bridge.usersLock) } func (bridge *Bridge) GetProxyAllowed() bool { @@ -189,7 +189,7 @@ func (bridge *Bridge) SetShowAllMail(show bool) error { } return bridge.vault.SetShowAllMail(show) - }, &bridge.usersLock) + }, bridge.usersLock) } func (bridge *Bridge) GetAutostart() bool { @@ -288,7 +288,7 @@ func (bridge *Bridge) FactoryReset(ctx context.Context) { logrus.WithError(err).Error("failed to delete vault user") } } - }, &bridge.usersLock) + }, bridge.usersLock) // Then delete all files. if err := bridge.locator.Clear(); err != nil { diff --git a/internal/bridge/smtp_backend.go b/internal/bridge/smtp_backend.go index 2ef93ce4..b8df5569 100644 --- a/internal/bridge/smtp_backend.go +++ b/internal/bridge/smtp_backend.go @@ -58,7 +58,7 @@ func (s *smtpSession) AuthPlain(username, password string) error { } return fmt.Errorf("invalid username or password") - }, &s.usersLock) + }, s.usersLock) } func (s *smtpSession) Reset() { @@ -92,5 +92,5 @@ func (s *smtpSession) Data(r io.Reader) error { } return user.SendMail(s.authID, s.from, s.to, r) - }, &s.usersLock) + }, s.usersLock) } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 45d1c2f0..10ae5774 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -80,7 +80,7 @@ func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) { } return info, nil - }, &bridge.usersLock) + }, bridge.usersLock) } // QueryUserInfo queries the user info by username or address. @@ -93,7 +93,7 @@ func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) { } return UserInfo{}, ErrNoSuchUser - }, &bridge.usersLock) + }, bridge.usersLock) } // LoginAuth begins the login process. It returns an authorized client that might need 2FA. @@ -105,7 +105,7 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [ if ok := safe.RLockRet(func() bool { return mapHas(bridge.users, auth.UID) - }, &bridge.usersLock); ok { + }, bridge.usersLock); ok { if err := client.AuthDelete(ctx); err != nil { logrus.WithError(err).Warn("Failed to delete auth") } @@ -201,7 +201,7 @@ func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error { }) return nil - }, &bridge.usersLock) + }, bridge.usersLock) } // DeleteUser deletes the given user. @@ -225,7 +225,7 @@ func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error { }) return nil - }, &bridge.usersLock) + }, bridge.usersLock) } // SetAddressMode sets the address mode for the given user. @@ -260,7 +260,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va }) return nil - }, &bridge.usersLock) + }, bridge.usersLock) } func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, authUID, authRef string, keyPass []byte) (string, error) { @@ -295,7 +295,7 @@ func (bridge *Bridge) loadUsers(ctx context.Context) error { if safe.RLockRet(func() bool { return mapHas(bridge.users, user.UserID()) - }, &bridge.usersLock) { + }, bridge.usersLock) { return nil } @@ -419,7 +419,7 @@ func (bridge *Bridge) addUserWithVault( // Finally, save the user in the bridge. safe.Lock(func() { bridge.users[apiUser.ID] = user - }, &bridge.usersLock) + }, bridge.usersLock) return nil } diff --git a/internal/bridge/user_events.go b/internal/bridge/user_events.go index 04dbc50f..fd5214f6 100644 --- a/internal/bridge/user_events.go +++ b/internal/bridge/user_events.go @@ -49,7 +49,7 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even defer delete(bridge.users, user.ID()) bridge.logoutUser(ctx, user, false) - }, &bridge.usersLock) + }, bridge.usersLock) } return nil diff --git a/internal/safe/mutex.go b/internal/safe/mutex.go index 8e21f4b5..bb8c6b44 100644 --- a/internal/safe/mutex.go +++ b/internal/safe/mutex.go @@ -17,16 +17,76 @@ package safe +import ( + "sync" + "sync/atomic" + + "golang.org/x/exp/slices" +) + +var nextMutexID uint64 + +// Mutex is a mutex that can be locked and unlocked. type Mutex interface { Lock() Unlock() + + getMutexID() uint64 } +// NewMutex returns a new mutex. +func NewMutex() Mutex { + return &mutex{ + mutexID: atomic.AddUint64(&nextMutexID, 1), + } +} + +type mutex struct { + sync.Mutex + + mutexID uint64 +} + +func (m *mutex) getMutexID() uint64 { + return m.mutexID +} + +// RWMutex is a mutex that can be locked and unlocked for reading and writing. +type RWMutex interface { + Mutex + + RLock() + RUnlock() +} + +// NewRWMutex returns a new read-write mutex. +func NewRWMutex() RWMutex { + return &rwMutex{ + mutexID: atomic.AddUint64(&nextMutexID, 1), + } +} + +type rwMutex struct { + sync.RWMutex + + mutexID uint64 +} + +func (m *rwMutex) getMutexID() uint64 { + return m.mutexID +} + +// Lock locks one or more mutexes for writing and calls the given function. +// The mutexes are locked in a deterministic order to avoid deadlocks. func Lock(fn func(), m ...Mutex) { if len(m) == 0 { panic("no mutexes provided") } + slices.SortFunc(m, func(a, b Mutex) bool { + return a.getMutexID() < b.getMutexID() + }) + for _, m := range m { m.Lock() defer m.Unlock() @@ -35,6 +95,7 @@ func Lock(fn func(), m ...Mutex) { fn() } +// LockRet locks one or more mutexes for writing and calls the given function, returning a value. func LockRet[T any](fn func() T, m ...Mutex) T { var ret T @@ -45,6 +106,7 @@ func LockRet[T any](fn func() T, m ...Mutex) T { return ret } +// LockRetErr locks one or more mutexes for writing and calls the given function, returning a value and an error. func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) { var ret T @@ -59,18 +121,17 @@ func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) { return ret, err } -type RWMutex interface { - Mutex - - RLock() - RUnlock() -} - +// RLock locks one or more mutexes for reading and calls the given function. +// The mutexes are locked in a deterministic order to avoid deadlocks. func RLock(fn func(), m ...RWMutex) { if len(m) == 0 { panic("no mutexes provided") } + slices.SortFunc(m, func(a, b RWMutex) bool { + return a.getMutexID() < b.getMutexID() + }) + for _, m := range m { m.RLock() defer m.RUnlock() @@ -79,6 +140,7 @@ func RLock(fn func(), m ...RWMutex) { fn() } +// RLockRet locks one or more mutexes for reading and calls the given function, returning a value. func RLockRet[T any](fn func() T, m ...RWMutex) T { var ret T @@ -89,6 +151,7 @@ func RLockRet[T any](fn func() T, m ...RWMutex) T { return ret } +// RLockRetErr locks one or more mutexes for reading and calls the given function, returning a value and an error. func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) { var err error diff --git a/internal/user/events.go b/internal/user/events.go index d7628c8e..6d652f2a 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -70,7 +70,7 @@ func (user *User) handleUserEvent(_ context.Context, userEvent liteapi.User) err }) return nil - }, &user.apiUserLock) + }, user.apiUserLock) } // handleAddressEvents handles the given address events. @@ -132,7 +132,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad }) return nil - }, &user.apiAddrsLock, &user.updateChLock) + }, user.apiAddrsLock, user.updateChLock) } func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.AddressEvent) error { //nolint:unparam @@ -150,7 +150,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.Addr }) return nil - }, &user.apiAddrsLock) + }, user.apiAddrsLock) } func (user *User) handleDeleteAddressEvent(_ context.Context, event liteapi.AddressEvent) error { @@ -174,7 +174,7 @@ func (user *User) handleDeleteAddressEvent(_ context.Context, event liteapi.Addr }) return nil - }, &user.apiAddrsLock, &user.updateChLock) + }, user.apiAddrsLock, user.updateChLock) } // handleLabelEvents handles the given label events. @@ -220,7 +220,7 @@ func (user *User) handleCreateLabelEvent(_ context.Context, event liteapi.LabelE }) return nil - }, &user.apiLabelsLock, &user.updateChLock) + }, user.apiLabelsLock, user.updateChLock) } func (user *User) handleUpdateLabelEvent(_ context.Context, event liteapi.LabelEvent) error { //nolint:unparam @@ -242,7 +242,7 @@ func (user *User) handleUpdateLabelEvent(_ context.Context, event liteapi.LabelE }) return nil - }, &user.apiLabelsLock, &user.updateChLock) + }, user.apiLabelsLock, user.updateChLock) } func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelEvent) error { //nolint:unparam @@ -265,7 +265,7 @@ func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelE }) return nil - }, &user.apiLabelsLock, &user.updateChLock) + }, user.apiLabelsLock, user.updateChLock) } // handleMessageEvents handles the given message events. @@ -307,7 +307,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me return nil }) - }, &user.apiUserLock, &user.apiAddrsLock, &user.updateChLock) + }, user.apiUserLock, user.apiAddrsLock, user.updateChLock) } func (user *User) handleUpdateMessageEvent(_ context.Context, event liteapi.MessageEvent) error { //nolint:unparam @@ -322,7 +322,7 @@ func (user *User) handleUpdateMessageEvent(_ context.Context, event liteapi.Mess user.updateCh[event.Message.AddressID].Enqueue(update) return nil - }, &user.updateChLock) + }, user.updateChLock) } func getMailboxName(label liteapi.Label) []string { diff --git a/internal/user/imap.go b/internal/user/imap.go index 107c8bd1..f57b16e7 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -88,7 +88,7 @@ func (conn *imapConnector) GetMailbox(ctx context.Context, mailboxID imap.Mailbo } return toIMAPMailbox(mailbox, conn.flags, conn.permFlags, conn.attrs), nil - }, &conn.apiLabelsLock) + }, conn.apiLabelsLock) } // CreateMailbox creates a label with the given name. @@ -157,7 +157,7 @@ func (conn *imapConnector) createFolder(ctx context.Context, name []string) (ima } return toIMAPMailbox(label, conn.flags, conn.permFlags, conn.attrs), nil - }, &conn.apiLabelsLock) + }, conn.apiLabelsLock) } // UpdateMailboxName sets the name of the label with the given ID. @@ -232,7 +232,7 @@ func (conn *imapConnector) updateFolder(ctx context.Context, labelID imap.Mailbo } return nil - }, &conn.apiLabelsLock) + }, conn.apiLabelsLock) } // DeleteMailbox deletes the label with the given ID. @@ -350,7 +350,7 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [ func (conn *imapConnector) GetUpdates() <-chan imap.Update { return safe.RLockRet(func() <-chan imap.Update { return conn.updateCh[conn.addrID].GetChannel() - }, &conn.updateChLock) + }, conn.updateChLock) } // GetUIDValidity returns the default UID validity for this user. @@ -407,7 +407,7 @@ func (conn *imapConnector) importMessage( return nil }) - }, &conn.apiUserLock, &conn.apiAddrsLock); err != nil { + }, conn.apiUserLock, conn.apiAddrsLock); err != nil { return imap.Message{}, nil, err } diff --git a/internal/user/sync.go b/internal/user/sync.go index e3c9d599..50ca4179 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -110,7 +110,7 @@ func (user *User) sync(ctx context.Context) error { return nil }) - }, &user.apiUserLock, &user.apiAddrsLock, &user.updateChLock) + }, user.apiUserLock, user.apiAddrsLock, user.updateChLock) } func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error { diff --git a/internal/user/user.go b/internal/user/user.go index c6e5d68c..66d2b6fc 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -59,16 +59,16 @@ type User struct { sendHash *sendRecorder apiUser liteapi.User - apiUserLock sync.RWMutex + apiUserLock safe.RWMutex apiAddrs map[string]liteapi.Address - apiAddrsLock sync.RWMutex + apiAddrsLock safe.RWMutex apiLabels map[string]liteapi.Label - apiLabelsLock sync.RWMutex + apiLabelsLock safe.RWMutex updateCh map[string]*queue.QueuedChannel[imap.Update] - updateChLock sync.RWMutex + updateChLock safe.RWMutex tasks *xsync.Group abortable async.Abortable @@ -144,10 +144,17 @@ func New( eventCh: queue.NewQueuedChannel[events.Event](0, 0), sendHash: newSendRecorder(sendEntryExpiry), - apiUser: apiUser, - apiAddrs: groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), - apiLabels: groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), - updateCh: updateCh, + apiUser: apiUser, + apiUserLock: safe.NewRWMutex(), + + apiAddrs: groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), + apiAddrsLock: safe.NewRWMutex(), + + apiLabels: groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), + apiLabelsLock: safe.NewRWMutex(), + + updateCh: updateCh, + updateChLock: safe.NewRWMutex(), tasks: xsync.NewGroup(context.Background()), @@ -210,14 +217,14 @@ func New( func (user *User) ID() string { return safe.RLockRet(func() string { return user.apiUser.ID - }, &user.apiUserLock) + }, user.apiUserLock) } // Name returns the user's username. func (user *User) Name() string { return safe.RLockRet(func() string { return user.apiUser.Name - }, &user.apiUserLock) + }, user.apiUserLock) } // Match matches the given query against the user's username and email addresses. @@ -234,7 +241,7 @@ func (user *User) Match(query string) bool { } return false - }, &user.apiUserLock, &user.apiAddrsLock) + }, user.apiUserLock, user.apiAddrsLock) } // Emails returns all the user's email addresses via the callback. @@ -243,7 +250,7 @@ func (user *User) Emails() []string { return xslices.Map(maps.Values(user.apiAddrs), func(addr liteapi.Address) string { return addr.Email }) - }, &user.apiAddrsLock) + }, user.apiAddrsLock) } // GetAddressMode returns the user's current address mode. @@ -286,7 +293,7 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er } return nil - }, &user.apiAddrsLock, &user.updateChLock) + }, user.apiAddrsLock, user.updateChLock) } // GetGluonIDs returns the users gluon IDs. @@ -323,14 +330,14 @@ func (user *User) BridgePass() []byte { func (user *User) UsedSpace() int { return safe.RLockRet(func() int { return user.apiUser.UsedSpace - }, &user.apiUserLock) + }, user.apiUserLock) } // MaxSpace returns the amount of space the user can use on the API. func (user *User) MaxSpace() int { return safe.RLockRet(func() int { return user.apiUser.MaxSpace - }, &user.apiUserLock) + }, user.apiUserLock) } // GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change). @@ -367,7 +374,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) { } return imapConn, nil - }, &user.apiAddrsLock) + }, user.apiAddrsLock) } // SendMail sends an email from the given address to the given recipients. @@ -483,7 +490,7 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader) return nil }) - }, &user.apiUserLock, &user.apiAddrsLock) + }, user.apiUserLock, user.apiAddrsLock) } // CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user. @@ -506,7 +513,7 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) { } return "", fmt.Errorf("invalid email") - }, &user.apiAddrsLock) + }, user.apiAddrsLock) } // OnStatusUp is called when the connection goes up. @@ -549,7 +556,7 @@ func (user *User) Close() { for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) { updateCh.CloseAndDiscardQueued() } - }, &user.updateChLock) + }, user.updateChLock) // Close the user's notify channel. user.eventCh.CloseAndDiscardQueued()