mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
Other(refactor): Sort safe.Mutex types before locking to prevent deadlocks
This change implements safe.Mutex and safe.RWMutex, which wrap the sync.Mutex and sync.RWMutex types and are assigned a globally unique integer ID. The safe.Lock and safe.RLock methods sort the mutexes by this integer ID before locking to ensure that locks for a given set of mutexes are always performed in the same order, avoiding deadlocks.
This commit is contained in:
@ -17,16 +17,76 @@
|
||||
|
||||
package safe
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var nextMutexID uint64
|
||||
|
||||
// Mutex is a mutex that can be locked and unlocked.
|
||||
type Mutex interface {
|
||||
Lock()
|
||||
Unlock()
|
||||
|
||||
getMutexID() uint64
|
||||
}
|
||||
|
||||
// NewMutex returns a new mutex.
|
||||
func NewMutex() Mutex {
|
||||
return &mutex{
|
||||
mutexID: atomic.AddUint64(&nextMutexID, 1),
|
||||
}
|
||||
}
|
||||
|
||||
type mutex struct {
|
||||
sync.Mutex
|
||||
|
||||
mutexID uint64
|
||||
}
|
||||
|
||||
func (m *mutex) getMutexID() uint64 {
|
||||
return m.mutexID
|
||||
}
|
||||
|
||||
// RWMutex is a mutex that can be locked and unlocked for reading and writing.
|
||||
type RWMutex interface {
|
||||
Mutex
|
||||
|
||||
RLock()
|
||||
RUnlock()
|
||||
}
|
||||
|
||||
// NewRWMutex returns a new read-write mutex.
|
||||
func NewRWMutex() RWMutex {
|
||||
return &rwMutex{
|
||||
mutexID: atomic.AddUint64(&nextMutexID, 1),
|
||||
}
|
||||
}
|
||||
|
||||
type rwMutex struct {
|
||||
sync.RWMutex
|
||||
|
||||
mutexID uint64
|
||||
}
|
||||
|
||||
func (m *rwMutex) getMutexID() uint64 {
|
||||
return m.mutexID
|
||||
}
|
||||
|
||||
// Lock locks one or more mutexes for writing and calls the given function.
|
||||
// The mutexes are locked in a deterministic order to avoid deadlocks.
|
||||
func Lock(fn func(), m ...Mutex) {
|
||||
if len(m) == 0 {
|
||||
panic("no mutexes provided")
|
||||
}
|
||||
|
||||
slices.SortFunc(m, func(a, b Mutex) bool {
|
||||
return a.getMutexID() < b.getMutexID()
|
||||
})
|
||||
|
||||
for _, m := range m {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
@ -35,6 +95,7 @@ func Lock(fn func(), m ...Mutex) {
|
||||
fn()
|
||||
}
|
||||
|
||||
// LockRet locks one or more mutexes for writing and calls the given function, returning a value.
|
||||
func LockRet[T any](fn func() T, m ...Mutex) T {
|
||||
var ret T
|
||||
|
||||
@ -45,6 +106,7 @@ func LockRet[T any](fn func() T, m ...Mutex) T {
|
||||
return ret
|
||||
}
|
||||
|
||||
// LockRetErr locks one or more mutexes for writing and calls the given function, returning a value and an error.
|
||||
func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
|
||||
var ret T
|
||||
|
||||
@ -59,18 +121,17 @@ func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
|
||||
return ret, err
|
||||
}
|
||||
|
||||
type RWMutex interface {
|
||||
Mutex
|
||||
|
||||
RLock()
|
||||
RUnlock()
|
||||
}
|
||||
|
||||
// RLock locks one or more mutexes for reading and calls the given function.
|
||||
// The mutexes are locked in a deterministic order to avoid deadlocks.
|
||||
func RLock(fn func(), m ...RWMutex) {
|
||||
if len(m) == 0 {
|
||||
panic("no mutexes provided")
|
||||
}
|
||||
|
||||
slices.SortFunc(m, func(a, b RWMutex) bool {
|
||||
return a.getMutexID() < b.getMutexID()
|
||||
})
|
||||
|
||||
for _, m := range m {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
@ -79,6 +140,7 @@ func RLock(fn func(), m ...RWMutex) {
|
||||
fn()
|
||||
}
|
||||
|
||||
// RLockRet locks one or more mutexes for reading and calls the given function, returning a value.
|
||||
func RLockRet[T any](fn func() T, m ...RWMutex) T {
|
||||
var ret T
|
||||
|
||||
@ -89,6 +151,7 @@ func RLockRet[T any](fn func() T, m ...RWMutex) T {
|
||||
return ret
|
||||
}
|
||||
|
||||
// RLockRetErr locks one or more mutexes for reading and calls the given function, returning a value and an error.
|
||||
func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) {
|
||||
var err error
|
||||
|
||||
|
||||
Reference in New Issue
Block a user