mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-15 22:56:48 +00:00
Other: Safer vault
This commit is contained in:
@ -147,3 +147,8 @@ func (user *User) Clear() error {
|
||||
data.KeyPass = nil
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the user. This allows it to be removed from the vault.
|
||||
func (user *User) Close() error {
|
||||
return user.vault.detachUser(user.userID)
|
||||
}
|
||||
|
||||
@ -80,7 +80,11 @@ func TestUser_Delete(t *testing.T) {
|
||||
// The user should be listed in the store.
|
||||
require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs())
|
||||
|
||||
// Clear the user's auth information.
|
||||
// Try to delete the user; it should fail because it is still in use.
|
||||
require.Error(t, s.DeleteUser("userID"))
|
||||
|
||||
// Close the user; it should now be deletable.
|
||||
require.NoError(t, user.Close())
|
||||
require.NoError(t, s.DeleteUser("userID"))
|
||||
|
||||
// The store should have no users again.
|
||||
@ -122,6 +126,56 @@ func TestUser_SyncStatus(t *testing.T) {
|
||||
require.Empty(t, user.SyncStatus().LastMessageID)
|
||||
}
|
||||
|
||||
func TestUser_ForEach(t *testing.T) {
|
||||
// Create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// Create some new users.
|
||||
user1, err := s.AddUser("userID1", "username1", "authUID1", "authRef1", []byte("keyPass1"))
|
||||
require.NoError(t, err)
|
||||
user2, err := s.AddUser("userID2", "username2", "authUID2", "authRef2", []byte("keyPass2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Iterate through the users.
|
||||
s.ForUser(func(user *vault.User) error {
|
||||
switch user.UserID() {
|
||||
case "userID1":
|
||||
require.Equal(t, "username1", user.Username())
|
||||
require.Equal(t, "authUID1", user.AuthUID())
|
||||
require.Equal(t, "authRef1", user.AuthRef())
|
||||
require.Equal(t, "keyPass1", string(user.KeyPass()))
|
||||
|
||||
case "userID2":
|
||||
require.Equal(t, "username2", user.Username())
|
||||
require.Equal(t, "authUID2", user.AuthUID())
|
||||
require.Equal(t, "authRef2", user.AuthRef())
|
||||
require.Equal(t, "keyPass2", string(user.KeyPass()))
|
||||
|
||||
default:
|
||||
t.Fatalf("unexpected user %q", user.UserID())
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Try to delete the first user; it should fail because it is still in use.
|
||||
require.Error(t, s.DeleteUser("userID1"))
|
||||
|
||||
// Close the first user; it should now be deletable.
|
||||
require.NoError(t, user1.Close())
|
||||
require.NoError(t, s.DeleteUser("userID1"))
|
||||
|
||||
// Try to delete the second user; it should fail because it is still in use.
|
||||
require.Error(t, s.DeleteUser("userID2"))
|
||||
|
||||
// Close the second user; it should now be deletable.
|
||||
require.NoError(t, user2.Close())
|
||||
require.NoError(t, s.DeleteUser("userID2"))
|
||||
|
||||
// The store should have no users again.
|
||||
require.Empty(t, s.GetUserIDs())
|
||||
}
|
||||
|
||||
/*
|
||||
func TestUser(t *testing.T) {
|
||||
// Replace the token generator with a dummy one.
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math/rand"
|
||||
"os"
|
||||
@ -19,9 +20,13 @@ import (
|
||||
// Vault is an encrypted data vault that stores bridge and user data.
|
||||
type Vault struct {
|
||||
path string
|
||||
enc []byte
|
||||
gcm cipher.AEAD
|
||||
lock sync.RWMutex
|
||||
|
||||
enc []byte
|
||||
encLock sync.RWMutex
|
||||
|
||||
ref map[string]int
|
||||
refLock sync.Mutex
|
||||
}
|
||||
|
||||
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
|
||||
@ -57,27 +62,45 @@ func (vault *Vault) GetUserIDs() []string {
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
||||
func (vault *Vault) GetUser(userID string) (*User, error) {
|
||||
// HasUser returns true if the vault contains a user with the given ID.
|
||||
func (vault *Vault) HasUser(userID string) bool {
|
||||
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}) >= 0
|
||||
}
|
||||
|
||||
// GetUser provides access to a vault user. It returns an error if the user does not exist.
|
||||
func (vault *Vault) GetUser(userID string, fn func(*User)) error {
|
||||
user, err := vault.NewUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = user.Close() }()
|
||||
|
||||
fn(user)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUser returns a new vault user. It must be closed before it can be deleted.
|
||||
func (vault *Vault) NewUser(userID string) (*User, error) {
|
||||
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}); idx < 0 {
|
||||
return nil, errors.New("no such user")
|
||||
}
|
||||
|
||||
return &User{
|
||||
vault: vault,
|
||||
userID: userID,
|
||||
}, nil
|
||||
return vault.attachUser(userID), nil
|
||||
}
|
||||
|
||||
// ForUser executes a callback for each user in the vault.
|
||||
func (vault *Vault) ForUser(fn func(*User) error) error {
|
||||
for _, userID := range vault.GetUserIDs() {
|
||||
user, err := vault.GetUser(userID)
|
||||
user, err := vault.NewUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = user.Close() }()
|
||||
|
||||
if err := fn(user); err != nil {
|
||||
return err
|
||||
@ -102,11 +125,18 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return vault.GetUser(userID)
|
||||
return vault.NewUser(userID)
|
||||
}
|
||||
|
||||
// DeleteUser removes the given user from the vault.
|
||||
func (vault *Vault) DeleteUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
if _, ok := vault.ref[userID]; ok {
|
||||
return fmt.Errorf("user %s is currently in use", userID)
|
||||
}
|
||||
|
||||
return vault.mod(func(data *Data) {
|
||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
@ -124,6 +154,35 @@ func (vault *Vault) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vault *Vault) attachUser(userID string) *User {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
vault.ref[userID] += 1
|
||||
|
||||
return &User{
|
||||
vault: vault,
|
||||
userID: userID,
|
||||
}
|
||||
}
|
||||
|
||||
func (vault *Vault) detachUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
if _, ok := vault.ref[userID]; !ok {
|
||||
return fmt.Errorf("user %s is not attached", userID)
|
||||
}
|
||||
|
||||
vault.ref[userID] -= 1
|
||||
|
||||
if vault.ref[userID] == 0 {
|
||||
delete(vault.ref, userID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
||||
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
|
||||
if _, err := initVault(path, gluonDir, gcm); err != nil {
|
||||
@ -149,12 +208,17 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
||||
enc = newEnc
|
||||
}
|
||||
|
||||
return &Vault{path: path, enc: enc, gcm: gcm}, corrupt, nil
|
||||
return &Vault{
|
||||
path: path,
|
||||
enc: enc,
|
||||
gcm: gcm,
|
||||
ref: make(map[string]int),
|
||||
}, corrupt, nil
|
||||
}
|
||||
|
||||
func (vault *Vault) get() Data {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
vault.encLock.RLock()
|
||||
defer vault.encLock.RUnlock()
|
||||
|
||||
dec, err := decrypt(vault.gcm, vault.enc)
|
||||
if err != nil {
|
||||
@ -171,8 +235,8 @@ func (vault *Vault) get() Data {
|
||||
}
|
||||
|
||||
func (vault *Vault) mod(fn func(data *Data)) error {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
vault.encLock.Lock()
|
||||
defer vault.encLock.Unlock()
|
||||
|
||||
dec, err := decrypt(vault.gcm, vault.enc)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user