mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-21 17:46:48 +00:00
fix(GODT-2606): Improve Vault concurrency scopes
Rewrite the vault to have one RWlock rather then two separate locks for data and reference counts. In certain circumstance, it could be possible that that different requests could end up in undefined states if a user got deleted successfully while at he same time another goroutine/thread is loading the given user. While I have not been able to reproduce this in a test, restricting the access scope to one lock rather than two, should avoid corner cases where logic code is executing outside of the lock scope.
This commit is contained in:
@ -40,11 +40,11 @@ type Vault struct {
|
||||
path string
|
||||
gcm cipher.AEAD
|
||||
|
||||
enc []byte
|
||||
encLock sync.RWMutex
|
||||
enc []byte
|
||||
|
||||
ref map[string]int
|
||||
refLock sync.Mutex
|
||||
ref map[string]int
|
||||
|
||||
lock sync.RWMutex
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
@ -79,14 +79,46 @@ func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHan
|
||||
|
||||
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
||||
func (vault *Vault) GetUserIDs() []string {
|
||||
return xslices.Map(vault.get().Users, func(user UserData) string {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return xslices.Map(vault.getUnsafe().Users, func(user UserData) string {
|
||||
return user.UserID
|
||||
})
|
||||
}
|
||||
|
||||
func (vault *Vault) getUsers() ([]*User, error) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
users := vault.getUnsafe().Users
|
||||
|
||||
result := make([]*User, 0, len(users))
|
||||
|
||||
for _, user := range users {
|
||||
u, err := vault.newUserUnsafe(user.UserID)
|
||||
if err != nil {
|
||||
for _, v := range result {
|
||||
if err := v.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Fait to close user after failed get")
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, u)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// HasUser returns true if the vault contains a user with the given ID.
|
||||
func (vault *Vault) HasUser(userID string) bool {
|
||||
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}) >= 0
|
||||
}
|
||||
@ -106,41 +138,61 @@ func (vault *Vault) GetUser(userID string, fn func(*User)) error {
|
||||
|
||||
// NewUser returns a new vault user. It must be closed before it can be deleted.
|
||||
func (vault *Vault) NewUser(userID string) (*User, error) {
|
||||
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.newUserUnsafe(userID)
|
||||
}
|
||||
|
||||
func (vault *Vault) newUserUnsafe(userID string) (*User, error) {
|
||||
if idx := xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}); idx < 0 {
|
||||
return nil, errors.New("no such user")
|
||||
}
|
||||
|
||||
return vault.attachUser(userID), nil
|
||||
return vault.attachUserUnsafe(userID), nil
|
||||
}
|
||||
|
||||
// ForUser executes a callback for each user in the vault.
|
||||
func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
|
||||
userIDs := vault.GetUserIDs()
|
||||
users, err := vault.getUsers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error {
|
||||
r := parallel.DoContext(context.Background(), parallelism, len(users), func(_ context.Context, idx int) error {
|
||||
defer async.HandlePanic(vault.panicHandler)
|
||||
|
||||
user, err := vault.NewUser(userIDs[idx])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = user.Close() }()
|
||||
|
||||
user := users[idx]
|
||||
return fn(user)
|
||||
})
|
||||
|
||||
for _, u := range users {
|
||||
if err := u.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user after ForUser")
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddUser creates a new user in the vault with the given ID, username and password.
|
||||
// A gluon key is generated using the package's token generator. If a password is found in the password archive for this user,
|
||||
// it is restored, otherwise a new bridge password is generated using the package's token generator.
|
||||
func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
|
||||
}
|
||||
|
||||
func (vault *Vault) addUserUnsafe(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
|
||||
logrus.WithField("userID", userID).Info("Adding vault user")
|
||||
|
||||
var exists bool
|
||||
|
||||
if err := vault.mod(func(data *Data) {
|
||||
if err := vault.modUnsafe(func(data *Data) {
|
||||
if idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}); idx >= 0 {
|
||||
@ -161,13 +213,42 @@ func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef str
|
||||
return nil, errors.New("user already exists")
|
||||
}
|
||||
|
||||
return vault.NewUser(userID)
|
||||
return vault.attachUserUnsafe(userID), nil
|
||||
}
|
||||
|
||||
// GetOrAddUser retrieves an existing user and updates the authRef and keyPass or creates a new user. Returns
|
||||
// the user and whether the user did not exist before.
|
||||
func (vault *Vault) GetOrAddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, bool, error) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
{
|
||||
users := vault.getUnsafe().Users
|
||||
|
||||
idx := xslices.IndexFunc(users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})
|
||||
|
||||
if idx >= 0 {
|
||||
user := vault.attachUserUnsafe(userID)
|
||||
|
||||
if err := user.setAuthAndKeyPassUnsafe(authUID, authRef, keyPass); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return user, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
u, err := vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
|
||||
|
||||
return u, true, err
|
||||
}
|
||||
|
||||
// DeleteUser removes the given user from the vault.
|
||||
func (vault *Vault) DeleteUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Info("Deleting vault user")
|
||||
|
||||
@ -175,7 +256,7 @@ func (vault *Vault) DeleteUser(userID string) error {
|
||||
return fmt.Errorf("user %s is currently in use", userID)
|
||||
}
|
||||
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})
|
||||
@ -189,17 +270,26 @@ func (vault *Vault) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
func (vault *Vault) Migrated() bool {
|
||||
return vault.get().Migrated
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return vault.getUnsafe().Migrated
|
||||
}
|
||||
|
||||
func (vault *Vault) SetMigrated() error {
|
||||
return vault.mod(func(data *Data) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
data.Migrated = true
|
||||
})
|
||||
}
|
||||
|
||||
func (vault *Vault) Reset(gluonDir string) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
*data = newDefaultData(gluonDir)
|
||||
})
|
||||
}
|
||||
@ -209,8 +299,8 @@ func (vault *Vault) Path() string {
|
||||
}
|
||||
|
||||
func (vault *Vault) Close() error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
if len(vault.ref) > 0 {
|
||||
return errors.New("vault is still in use")
|
||||
@ -221,10 +311,7 @@ func (vault *Vault) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vault *Vault) attachUser(userID string) *User {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
func (vault *Vault) attachUserUnsafe(userID string) *User {
|
||||
logrus.WithField("userID", userID).Trace("Attaching vault user")
|
||||
|
||||
vault.ref[userID]++
|
||||
@ -236,8 +323,8 @@ func (vault *Vault) attachUser(userID string) *User {
|
||||
}
|
||||
|
||||
func (vault *Vault) detachUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Trace("Detaching vault user")
|
||||
|
||||
@ -289,10 +376,14 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
||||
}, corrupt, nil
|
||||
}
|
||||
|
||||
func (vault *Vault) get() Data {
|
||||
vault.encLock.RLock()
|
||||
defer vault.encLock.RUnlock()
|
||||
func (vault *Vault) getSafe() Data {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return vault.getUnsafe()
|
||||
}
|
||||
|
||||
func (vault *Vault) getUnsafe() Data {
|
||||
var data Data
|
||||
|
||||
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
||||
@ -302,10 +393,14 @@ func (vault *Vault) get() Data {
|
||||
return data
|
||||
}
|
||||
|
||||
func (vault *Vault) mod(fn func(data *Data)) error {
|
||||
vault.encLock.Lock()
|
||||
defer vault.encLock.Unlock()
|
||||
func (vault *Vault) modSafe(fn func(data *Data)) error {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUnsafe(fn)
|
||||
}
|
||||
|
||||
func (vault *Vault) modUnsafe(fn func(data *Data)) error {
|
||||
var data Data
|
||||
|
||||
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
||||
@ -325,13 +420,31 @@ func (vault *Vault) mod(fn func(data *Data)) error {
|
||||
}
|
||||
|
||||
func (vault *Vault) getUser(userID string) UserData {
|
||||
return vault.get().Users[xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
users := vault.getUnsafe().Users
|
||||
|
||||
idx := xslices.IndexFunc(users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})]
|
||||
})
|
||||
|
||||
if idx < 0 {
|
||||
panic("Unknown user")
|
||||
}
|
||||
|
||||
return users[idx]
|
||||
}
|
||||
|
||||
func (vault *Vault) modUser(userID string, fn func(userData *UserData)) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUserUnsafe(userID, fn)
|
||||
}
|
||||
|
||||
func (vault *Vault) modUserUnsafe(userID string, fn func(userData *UserData)) error {
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user