forked from Silverfish/proton-bridge
Other: Safer vault
This commit is contained in:
2
go.mod
2
go.mod
@ -39,7 +39,7 @@ require (
|
||||
github.com/sirupsen/logrus v1.9.0
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/urfave/cli/v2 v2.16.3
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f
|
||||
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
||||
golang.org/x/net v0.1.0
|
||||
golang.org/x/sys v0.1.0
|
||||
|
||||
4
go.sum
4
go.sum
@ -399,8 +399,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
||||
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 h1:TWNT/rPSUGjYsNTwWx5Fd029LipSv+h1XuBwFSd5cAo=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f h1:5gPPdQS+dm0A2GAE0IaGLZwgTKn2Q2dCQeMxgJUD+Nk=
|
||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
|
||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||
|
||||
@ -27,6 +27,7 @@ import (
|
||||
"sort"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
@ -41,12 +42,11 @@ func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, descript
|
||||
if info, err := bridge.QueryUserInfo(username); err == nil {
|
||||
account = info.Username
|
||||
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
|
||||
user, err := bridge.vault.GetUser(userIDs[0])
|
||||
if err != nil {
|
||||
if err := bridge.vault.GetUser(userIDs[0], func(user *vault.User) {
|
||||
account = user.Username()
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account = user.Username()
|
||||
}
|
||||
|
||||
var atts []liteapi.ReportBugAttachment
|
||||
|
||||
@ -13,7 +13,6 @@ import (
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
@ -49,18 +48,21 @@ func (bridge *Bridge) GetUserIDs() []string {
|
||||
|
||||
// GetUserInfo returns info about the given user.
|
||||
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
||||
vaultUser, err := bridge.vault.GetUser(userID)
|
||||
if err != nil {
|
||||
return UserInfo{}, err
|
||||
}
|
||||
|
||||
if info, ok := safe.MapGetRet(bridge.users, userID, func(user *user.User) UserInfo {
|
||||
return getConnUserInfo(user)
|
||||
}); ok {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil
|
||||
var info UserInfo
|
||||
|
||||
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
||||
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
||||
}); err != nil {
|
||||
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// QueryUserInfo queries the user info by username or address.
|
||||
@ -108,10 +110,6 @@ func (bridge *Bridge) LoginUser(
|
||||
func() error {
|
||||
return client.AuthDelete(ctx)
|
||||
},
|
||||
func() error {
|
||||
bridge.deleteUser(ctx, auth.UserID)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to login user: %w", err)
|
||||
@ -126,6 +124,7 @@ func (bridge *Bridge) LoginUser(
|
||||
|
||||
// LoginUser authorizes a new bridge user with the given username and password.
|
||||
// If necessary, a TOTP and mailbox password are requested via the callbacks.
|
||||
// This is equivalent to doing LoginAuth and LoginUser separately.
|
||||
func (bridge *Bridge) LoginFull(
|
||||
ctx context.Context,
|
||||
username string,
|
||||
@ -252,11 +251,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
|
||||
func (bridge *Bridge) loadLoop() {
|
||||
for {
|
||||
bridge.loadWG.GoTry(func(ok bool) {
|
||||
if ok {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := bridge.loadUsers(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load users")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
select {
|
||||
@ -269,6 +270,7 @@ func (bridge *Bridge) loadLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// loadUsers tries to load each user in the vault that isn't already loaded.
|
||||
func (bridge *Bridge) loadUsers() error {
|
||||
if err := bridge.vault.ForUser(func(user *vault.User) error {
|
||||
if bridge.users.Has(user.UserID()) {
|
||||
@ -317,7 +319,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error {
|
||||
return fmt.Errorf("failed to create API client: %w", err)
|
||||
}
|
||||
|
||||
if err := try.Catch(
|
||||
return try.Catch(
|
||||
func() error {
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
@ -329,14 +331,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error {
|
||||
func() error {
|
||||
return client.AuthDelete(ctx)
|
||||
},
|
||||
func() error {
|
||||
return bridge.logoutUser(ctx, user.UserID())
|
||||
},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to load user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
)
|
||||
}
|
||||
|
||||
// addUser adds a new user with an already salted mailbox password.
|
||||
@ -347,23 +342,45 @@ func (bridge *Bridge) addUser(
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) error {
|
||||
var user *user.User
|
||||
|
||||
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
|
||||
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
|
||||
vaultUser, isNew, err := bridge.newVaultUser(client, apiUser, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add existing user: %w", err)
|
||||
return fmt.Errorf("failed to add vault user: %w", err)
|
||||
}
|
||||
|
||||
user = existingUser
|
||||
} else {
|
||||
newUser, err := bridge.addNewUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
|
||||
if err := bridge.addUserWithVault(ctx, client, apiUser, vaultUser); err != nil {
|
||||
if err := vaultUser.Clear(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to clear vault user")
|
||||
}
|
||||
|
||||
if err := vaultUser.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close vault user")
|
||||
}
|
||||
|
||||
if isNew {
|
||||
if err := bridge.vault.DeleteUser(apiUser.ID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to delete vault user")
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to add user with vault: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addUserWithVault adds a new user to bridge with the given vault.
|
||||
func (bridge *Bridge) addUserWithVault(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
vaultUser *vault.User,
|
||||
) error {
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add new user: %w", err)
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
user = newUser
|
||||
}
|
||||
bridge.users.Set(apiUser.ID, user)
|
||||
|
||||
// Connect the user's address(es) to gluon.
|
||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||
@ -403,56 +420,37 @@ func (bridge *Bridge) addUser(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) addNewUser(
|
||||
ctx context.Context,
|
||||
// newVaultUser creates a new vault user from the given auth information.
|
||||
// If one already exists in the vault, its data will be updated.
|
||||
func (bridge *Bridge) newVaultUser(
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) (*user.User, error) {
|
||||
vaultUser, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
|
||||
) (*vault.User, bool, error) {
|
||||
if !bridge.vault.HasUser(apiUser.ID) {
|
||||
user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, fmt.Errorf("failed to add user to vault: %w", err)
|
||||
}
|
||||
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
||||
return user, true, nil
|
||||
}
|
||||
|
||||
user, err := bridge.vault.NewUser(apiUser.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
bridge.users.Set(apiUser.ID, user)
|
||||
|
||||
return user, nil
|
||||
if err := user.SetAuth(authUID, authRef); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
func (bridge *Bridge) addExistingUser(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) (*user.User, error) {
|
||||
vaultUser, err := bridge.vault.GetUser(apiUser.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err := user.SetKeyPass(saltedKeyPass); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if err := vaultUser.SetAuth(authUID, authRef); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := vaultUser.SetKeyPass(saltedKeyPass); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bridge.users.Set(apiUser.ID, user)
|
||||
|
||||
return user, nil
|
||||
return user, false, nil
|
||||
}
|
||||
|
||||
// addIMAPUser connects the given user to gluon.
|
||||
|
||||
@ -262,21 +262,33 @@ func TestBridge_FailLoginRecover(t *testing.T) {
|
||||
read += uint64(len(b))
|
||||
})
|
||||
|
||||
var userID string
|
||||
|
||||
// Log the user in and record how much data was read.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userID := must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||
userID = must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
})
|
||||
|
||||
// Now simulate failing to login.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Simulate a partial read.
|
||||
netCtl.SetReadLimit(read / 2)
|
||||
netCtl.SetReadLimit(3 * read / 4)
|
||||
|
||||
// We should fail to log the user in because we can't fully read its data.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
||||
|
||||
// There should be no users.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
// The user should still be there (but disconnected).
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
|
||||
// Simulate the network recovering.
|
||||
netCtl.SetReadLimit(0)
|
||||
|
||||
// We should now be able to log the user in.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
|
||||
"github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
@ -23,7 +24,6 @@ type Service struct {
|
||||
proto.UnimplementedFocusServer
|
||||
|
||||
server *grpc.Server
|
||||
listener net.Listener
|
||||
raiseCh chan struct{}
|
||||
version *semver.Version
|
||||
}
|
||||
@ -31,25 +31,23 @@ type Service struct {
|
||||
// NewService creates a new focus service.
|
||||
// It listens on the local host and port 1042 (by default).
|
||||
func NewService(version *semver.Version) (*Service, error) {
|
||||
listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
server: grpc.NewServer(),
|
||||
listener: listener,
|
||||
raiseCh: make(chan struct{}, 1),
|
||||
version: version,
|
||||
}
|
||||
|
||||
proto.RegisterFocusServer(service.server, service)
|
||||
|
||||
if listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port))); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to start focus service")
|
||||
} else {
|
||||
go func() {
|
||||
if err := service.server.Serve(listener); err != nil {
|
||||
fmt.Printf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/bradenaw/juniper/xerrors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -23,7 +24,7 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
catch(handlers...)
|
||||
err = fmt.Errorf("panic: %v", r)
|
||||
err = xerrors.WithStack(fmt.Errorf("panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
@ -38,13 +39,13 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er
|
||||
func catch(handlers ...func() error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logrus.WithField("panic", r).Error("Panic in catch")
|
||||
logrus.WithError(xerrors.WithStack(fmt.Errorf("panic: %v", r))).Error("Catch handler panicked")
|
||||
}
|
||||
}()
|
||||
|
||||
for _, handler := range handlers {
|
||||
if err := handler(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to handle error")
|
||||
logrus.WithError(xerrors.WithStack(err)).Error("Catch handler failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -341,7 +341,9 @@ func (user *User) Close() error {
|
||||
user.waitSync()
|
||||
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
if err := user.client.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close API client")
|
||||
}
|
||||
|
||||
// Close the user's update channels.
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
@ -353,6 +355,11 @@ func (user *User) Close() error {
|
||||
// Close the user's notify channel.
|
||||
user.eventCh.Close()
|
||||
|
||||
// Close the user's vault.
|
||||
if err := user.vault.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close vault")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -147,3 +147,8 @@ func (user *User) Clear() error {
|
||||
data.KeyPass = nil
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the user. This allows it to be removed from the vault.
|
||||
func (user *User) Close() error {
|
||||
return user.vault.detachUser(user.userID)
|
||||
}
|
||||
|
||||
@ -80,7 +80,11 @@ func TestUser_Delete(t *testing.T) {
|
||||
// The user should be listed in the store.
|
||||
require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs())
|
||||
|
||||
// Clear the user's auth information.
|
||||
// Try to delete the user; it should fail because it is still in use.
|
||||
require.Error(t, s.DeleteUser("userID"))
|
||||
|
||||
// Close the user; it should now be deletable.
|
||||
require.NoError(t, user.Close())
|
||||
require.NoError(t, s.DeleteUser("userID"))
|
||||
|
||||
// The store should have no users again.
|
||||
@ -122,6 +126,56 @@ func TestUser_SyncStatus(t *testing.T) {
|
||||
require.Empty(t, user.SyncStatus().LastMessageID)
|
||||
}
|
||||
|
||||
func TestUser_ForEach(t *testing.T) {
|
||||
// Create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// Create some new users.
|
||||
user1, err := s.AddUser("userID1", "username1", "authUID1", "authRef1", []byte("keyPass1"))
|
||||
require.NoError(t, err)
|
||||
user2, err := s.AddUser("userID2", "username2", "authUID2", "authRef2", []byte("keyPass2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Iterate through the users.
|
||||
s.ForUser(func(user *vault.User) error {
|
||||
switch user.UserID() {
|
||||
case "userID1":
|
||||
require.Equal(t, "username1", user.Username())
|
||||
require.Equal(t, "authUID1", user.AuthUID())
|
||||
require.Equal(t, "authRef1", user.AuthRef())
|
||||
require.Equal(t, "keyPass1", string(user.KeyPass()))
|
||||
|
||||
case "userID2":
|
||||
require.Equal(t, "username2", user.Username())
|
||||
require.Equal(t, "authUID2", user.AuthUID())
|
||||
require.Equal(t, "authRef2", user.AuthRef())
|
||||
require.Equal(t, "keyPass2", string(user.KeyPass()))
|
||||
|
||||
default:
|
||||
t.Fatalf("unexpected user %q", user.UserID())
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Try to delete the first user; it should fail because it is still in use.
|
||||
require.Error(t, s.DeleteUser("userID1"))
|
||||
|
||||
// Close the first user; it should now be deletable.
|
||||
require.NoError(t, user1.Close())
|
||||
require.NoError(t, s.DeleteUser("userID1"))
|
||||
|
||||
// Try to delete the second user; it should fail because it is still in use.
|
||||
require.Error(t, s.DeleteUser("userID2"))
|
||||
|
||||
// Close the second user; it should now be deletable.
|
||||
require.NoError(t, user2.Close())
|
||||
require.NoError(t, s.DeleteUser("userID2"))
|
||||
|
||||
// The store should have no users again.
|
||||
require.Empty(t, s.GetUserIDs())
|
||||
}
|
||||
|
||||
/*
|
||||
func TestUser(t *testing.T) {
|
||||
// Replace the token generator with a dummy one.
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math/rand"
|
||||
"os"
|
||||
@ -19,9 +20,13 @@ import (
|
||||
// Vault is an encrypted data vault that stores bridge and user data.
|
||||
type Vault struct {
|
||||
path string
|
||||
enc []byte
|
||||
gcm cipher.AEAD
|
||||
lock sync.RWMutex
|
||||
|
||||
enc []byte
|
||||
encLock sync.RWMutex
|
||||
|
||||
ref map[string]int
|
||||
refLock sync.Mutex
|
||||
}
|
||||
|
||||
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
|
||||
@ -57,27 +62,45 @@ func (vault *Vault) GetUserIDs() []string {
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
||||
func (vault *Vault) GetUser(userID string) (*User, error) {
|
||||
// HasUser returns true if the vault contains a user with the given ID.
|
||||
func (vault *Vault) HasUser(userID string) bool {
|
||||
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}) >= 0
|
||||
}
|
||||
|
||||
// GetUser provides access to a vault user. It returns an error if the user does not exist.
|
||||
func (vault *Vault) GetUser(userID string, fn func(*User)) error {
|
||||
user, err := vault.NewUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = user.Close() }()
|
||||
|
||||
fn(user)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUser returns a new vault user. It must be closed before it can be deleted.
|
||||
func (vault *Vault) NewUser(userID string) (*User, error) {
|
||||
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}); idx < 0 {
|
||||
return nil, errors.New("no such user")
|
||||
}
|
||||
|
||||
return &User{
|
||||
vault: vault,
|
||||
userID: userID,
|
||||
}, nil
|
||||
return vault.attachUser(userID), nil
|
||||
}
|
||||
|
||||
// ForUser executes a callback for each user in the vault.
|
||||
func (vault *Vault) ForUser(fn func(*User) error) error {
|
||||
for _, userID := range vault.GetUserIDs() {
|
||||
user, err := vault.GetUser(userID)
|
||||
user, err := vault.NewUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = user.Close() }()
|
||||
|
||||
if err := fn(user); err != nil {
|
||||
return err
|
||||
@ -102,11 +125,18 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return vault.GetUser(userID)
|
||||
return vault.NewUser(userID)
|
||||
}
|
||||
|
||||
// DeleteUser removes the given user from the vault.
|
||||
func (vault *Vault) DeleteUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
if _, ok := vault.ref[userID]; ok {
|
||||
return fmt.Errorf("user %s is currently in use", userID)
|
||||
}
|
||||
|
||||
return vault.mod(func(data *Data) {
|
||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
@ -124,6 +154,35 @@ func (vault *Vault) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vault *Vault) attachUser(userID string) *User {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
vault.ref[userID] += 1
|
||||
|
||||
return &User{
|
||||
vault: vault,
|
||||
userID: userID,
|
||||
}
|
||||
}
|
||||
|
||||
func (vault *Vault) detachUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
if _, ok := vault.ref[userID]; !ok {
|
||||
return fmt.Errorf("user %s is not attached", userID)
|
||||
}
|
||||
|
||||
vault.ref[userID] -= 1
|
||||
|
||||
if vault.ref[userID] == 0 {
|
||||
delete(vault.ref, userID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
||||
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
|
||||
if _, err := initVault(path, gluonDir, gcm); err != nil {
|
||||
@ -149,12 +208,17 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
||||
enc = newEnc
|
||||
}
|
||||
|
||||
return &Vault{path: path, enc: enc, gcm: gcm}, corrupt, nil
|
||||
return &Vault{
|
||||
path: path,
|
||||
enc: enc,
|
||||
gcm: gcm,
|
||||
ref: make(map[string]int),
|
||||
}, corrupt, nil
|
||||
}
|
||||
|
||||
func (vault *Vault) get() Data {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
vault.encLock.RLock()
|
||||
defer vault.encLock.RUnlock()
|
||||
|
||||
dec, err := decrypt(vault.gcm, vault.enc)
|
||||
if err != nil {
|
||||
@ -171,8 +235,8 @@ func (vault *Vault) get() Data {
|
||||
}
|
||||
|
||||
func (vault *Vault) mod(fn func(data *Data)) error {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
vault.encLock.Lock()
|
||||
defer vault.encLock.Unlock()
|
||||
|
||||
dec, err := decrypt(vault.gcm, vault.enc)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user