Other(refactor): Sort safe.Mutex types before locking to prevent deadlocks

This change implements safe.Mutex and safe.RWMutex, which wrap the
sync.Mutex and sync.RWMutex types and are assigned a globally unique
integer ID. The safe.Lock and safe.RLock methods sort the mutexes
by this integer ID before locking to ensure that locks for a given
set of mutexes are always performed in the same order, avoiding
deadlocks.
This commit is contained in:
James Houlahan
2022-10-27 01:21:40 +02:00
parent 5a4f733518
commit d4da325e57
11 changed files with 133 additions and 61 deletions

View File

@ -17,16 +17,76 @@
package safe
import (
"sync"
"sync/atomic"
"golang.org/x/exp/slices"
)
var nextMutexID uint64
// Mutex is a mutex that can be locked and unlocked.
type Mutex interface {
Lock()
Unlock()
getMutexID() uint64
}
// NewMutex returns a new mutex.
func NewMutex() Mutex {
return &mutex{
mutexID: atomic.AddUint64(&nextMutexID, 1),
}
}
type mutex struct {
sync.Mutex
mutexID uint64
}
func (m *mutex) getMutexID() uint64 {
return m.mutexID
}
// RWMutex is a mutex that can be locked and unlocked for reading and writing.
type RWMutex interface {
Mutex
RLock()
RUnlock()
}
// NewRWMutex returns a new read-write mutex.
func NewRWMutex() RWMutex {
return &rwMutex{
mutexID: atomic.AddUint64(&nextMutexID, 1),
}
}
type rwMutex struct {
sync.RWMutex
mutexID uint64
}
func (m *rwMutex) getMutexID() uint64 {
return m.mutexID
}
// Lock locks one or more mutexes for writing and calls the given function.
// The mutexes are locked in a deterministic order to avoid deadlocks.
func Lock(fn func(), m ...Mutex) {
if len(m) == 0 {
panic("no mutexes provided")
}
slices.SortFunc(m, func(a, b Mutex) bool {
return a.getMutexID() < b.getMutexID()
})
for _, m := range m {
m.Lock()
defer m.Unlock()
@ -35,6 +95,7 @@ func Lock(fn func(), m ...Mutex) {
fn()
}
// LockRet locks one or more mutexes for writing and calls the given function, returning a value.
func LockRet[T any](fn func() T, m ...Mutex) T {
var ret T
@ -45,6 +106,7 @@ func LockRet[T any](fn func() T, m ...Mutex) T {
return ret
}
// LockRetErr locks one or more mutexes for writing and calls the given function, returning a value and an error.
func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
var ret T
@ -59,18 +121,17 @@ func LockRetErr[T any](fn func() (T, error), m ...Mutex) (T, error) {
return ret, err
}
type RWMutex interface {
Mutex
RLock()
RUnlock()
}
// RLock locks one or more mutexes for reading and calls the given function.
// The mutexes are locked in a deterministic order to avoid deadlocks.
func RLock(fn func(), m ...RWMutex) {
if len(m) == 0 {
panic("no mutexes provided")
}
slices.SortFunc(m, func(a, b RWMutex) bool {
return a.getMutexID() < b.getMutexID()
})
for _, m := range m {
m.RLock()
defer m.RUnlock()
@ -79,6 +140,7 @@ func RLock(fn func(), m ...RWMutex) {
fn()
}
// RLockRet locks one or more mutexes for reading and calls the given function, returning a value.
func RLockRet[T any](fn func() T, m ...RWMutex) T {
var ret T
@ -89,6 +151,7 @@ func RLockRet[T any](fn func() T, m ...RWMutex) T {
return ret
}
// RLockRetErr locks one or more mutexes for reading and calls the given function, returning a value and an error.
func RLockRetErr[T any](fn func() (T, error), m ...RWMutex) (T, error) {
var err error