Other: Safer vault

This commit is contained in:
James Houlahan
2022-10-13 00:08:11 +02:00
parent 593d86f3a7
commit ef2dea89b4
11 changed files with 270 additions and 131 deletions

2
go.mod
View File

@ -39,7 +39,7 @@ require (
github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.0
github.com/urfave/cli/v2 v2.16.3
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f
golang.org/x/exp v0.0.0-20220921164117-439092de6870
golang.org/x/net v0.1.0
golang.org/x/sys v0.1.0

4
go.sum
View File

@ -399,8 +399,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 h1:TWNT/rPSUGjYsNTwWx5Fd029LipSv+h1XuBwFSd5cAo=
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f h1:5gPPdQS+dm0A2GAE0IaGLZwgTKn2Q2dCQeMxgJUD+Nk=
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=

View File

@ -27,6 +27,7 @@ import (
"sort"
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"gitlab.protontech.ch/go/liteapi"
)
@ -41,12 +42,11 @@ func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, descript
if info, err := bridge.QueryUserInfo(username); err == nil {
account = info.Username
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
user, err := bridge.vault.GetUser(userIDs[0])
if err != nil {
if err := bridge.vault.GetUser(userIDs[0], func(user *vault.User) {
account = user.Username()
}); err != nil {
return err
}
account = user.Username()
}
var atts []liteapi.ReportBugAttachment

View File

@ -13,7 +13,6 @@ import (
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
type UserInfo struct {
@ -49,18 +48,21 @@ func (bridge *Bridge) GetUserIDs() []string {
// GetUserInfo returns info about the given user.
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
vaultUser, err := bridge.vault.GetUser(userID)
if err != nil {
return UserInfo{}, err
}
if info, ok := safe.MapGetRet(bridge.users, userID, func(user *user.User) UserInfo {
return getConnUserInfo(user)
}); ok {
return info, nil
}
return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil
var info UserInfo
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
}); err != nil {
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
}
return info, nil
}
// QueryUserInfo queries the user info by username or address.
@ -108,10 +110,6 @@ func (bridge *Bridge) LoginUser(
func() error {
return client.AuthDelete(ctx)
},
func() error {
bridge.deleteUser(ctx, auth.UserID)
return nil
},
)
if err != nil {
return "", fmt.Errorf("failed to login user: %w", err)
@ -126,6 +124,7 @@ func (bridge *Bridge) LoginUser(
// LoginUser authorizes a new bridge user with the given username and password.
// If necessary, a TOTP and mailbox password are requested via the callbacks.
// This is equivalent to doing LoginAuth and LoginUser separately.
func (bridge *Bridge) LoginFull(
ctx context.Context,
username string,
@ -252,11 +251,13 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
func (bridge *Bridge) loadLoop() {
for {
bridge.loadWG.GoTry(func(ok bool) {
if ok {
if !ok {
return
}
if err := bridge.loadUsers(); err != nil {
logrus.WithError(err).Error("Failed to load users")
}
}
})
select {
@ -269,6 +270,7 @@ func (bridge *Bridge) loadLoop() {
}
}
// loadUsers tries to load each user in the vault that isn't already loaded.
func (bridge *Bridge) loadUsers() error {
if err := bridge.vault.ForUser(func(user *vault.User) error {
if bridge.users.Has(user.UserID()) {
@ -317,7 +319,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error {
return fmt.Errorf("failed to create API client: %w", err)
}
if err := try.Catch(
return try.Catch(
func() error {
apiUser, err := client.GetUser(ctx)
if err != nil {
@ -329,14 +331,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error {
func() error {
return client.AuthDelete(ctx)
},
func() error {
return bridge.logoutUser(ctx, user.UserID())
},
); err != nil {
return fmt.Errorf("failed to load user: %w", err)
}
return nil
)
}
// addUser adds a new user with an already salted mailbox password.
@ -347,23 +342,45 @@ func (bridge *Bridge) addUser(
authUID, authRef string,
saltedKeyPass []byte,
) error {
var user *user.User
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
vaultUser, isNew, err := bridge.newVaultUser(client, apiUser, authUID, authRef, saltedKeyPass)
if err != nil {
return fmt.Errorf("failed to add existing user: %w", err)
return fmt.Errorf("failed to add vault user: %w", err)
}
user = existingUser
} else {
newUser, err := bridge.addNewUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
if err := bridge.addUserWithVault(ctx, client, apiUser, vaultUser); err != nil {
if err := vaultUser.Clear(); err != nil {
logrus.WithError(err).Error("Failed to clear vault user")
}
if err := vaultUser.Close(); err != nil {
logrus.WithError(err).Error("Failed to close vault user")
}
if isNew {
if err := bridge.vault.DeleteUser(apiUser.ID); err != nil {
logrus.WithError(err).Error("Failed to delete vault user")
}
}
return fmt.Errorf("failed to add user with vault: %w", err)
}
return nil
}
// addUserWithVault adds a new user to bridge with the given vault.
func (bridge *Bridge) addUserWithVault(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
vaultUser *vault.User,
) error {
user, err := user.New(ctx, vaultUser, client, apiUser)
if err != nil {
return fmt.Errorf("failed to add new user: %w", err)
return fmt.Errorf("failed to create user: %w", err)
}
user = newUser
}
bridge.users.Set(apiUser.ID, user)
// Connect the user's address(es) to gluon.
if err := bridge.addIMAPUser(ctx, user); err != nil {
@ -403,56 +420,37 @@ func (bridge *Bridge) addUser(
return nil
}
func (bridge *Bridge) addNewUser(
ctx context.Context,
// newVaultUser creates a new vault user from the given auth information.
// If one already exists in the vault, its data will be updated.
func (bridge *Bridge) newVaultUser(
client *liteapi.Client,
apiUser liteapi.User,
authUID, authRef string,
saltedKeyPass []byte,
) (*user.User, error) {
vaultUser, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
) (*vault.User, bool, error) {
if !bridge.vault.HasUser(apiUser.ID) {
user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
if err != nil {
return nil, err
return nil, false, fmt.Errorf("failed to add user to vault: %w", err)
}
user, err := user.New(ctx, vaultUser, client, apiUser)
return user, true, nil
}
user, err := bridge.vault.NewUser(apiUser.ID)
if err != nil {
return nil, err
return nil, false, err
}
bridge.users.Set(apiUser.ID, user)
return user, nil
if err := user.SetAuth(authUID, authRef); err != nil {
return nil, false, err
}
func (bridge *Bridge) addExistingUser(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
authUID, authRef string,
saltedKeyPass []byte,
) (*user.User, error) {
vaultUser, err := bridge.vault.GetUser(apiUser.ID)
if err != nil {
return nil, err
if err := user.SetKeyPass(saltedKeyPass); err != nil {
return nil, false, err
}
if err := vaultUser.SetAuth(authUID, authRef); err != nil {
return nil, err
}
if err := vaultUser.SetKeyPass(saltedKeyPass); err != nil {
return nil, err
}
user, err := user.New(ctx, vaultUser, client, apiUser)
if err != nil {
return nil, err
}
bridge.users.Set(apiUser.ID, user)
return user, nil
return user, false, nil
}
// addIMAPUser connects the given user to gluon.

View File

@ -262,21 +262,33 @@ func TestBridge_FailLoginRecover(t *testing.T) {
read += uint64(len(b))
})
var userID string
// Log the user in and record how much data was read.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userID := must(bridge.LoginFull(ctx, username, password, nil, nil))
userID = must(bridge.LoginFull(ctx, username, password, nil, nil))
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
// Now simulate failing to login.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Simulate a partial read.
netCtl.SetReadLimit(read / 2)
netCtl.SetReadLimit(3 * read / 4)
// We should fail to log the user in because we can't fully read its data.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
// There should be no users.
require.Empty(t, bridge.GetUserIDs())
// The user should still be there (but disconnected).
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
// Simulate the network recovering.
netCtl.SetReadLimit(0)
// We should now be able to log the user in.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
})
})
}

View File

@ -8,6 +8,7 @@ import (
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
)
@ -23,7 +24,6 @@ type Service struct {
proto.UnimplementedFocusServer
server *grpc.Server
listener net.Listener
raiseCh chan struct{}
version *semver.Version
}
@ -31,25 +31,23 @@ type Service struct {
// NewService creates a new focus service.
// It listens on the local host and port 1042 (by default).
func NewService(version *semver.Version) (*Service, error) {
listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port)))
if err != nil {
return nil, fmt.Errorf("failed to listen: %w", err)
}
service := &Service{
server: grpc.NewServer(),
listener: listener,
raiseCh: make(chan struct{}, 1),
version: version,
}
proto.RegisterFocusServer(service.server, service)
if listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port))); err != nil {
logrus.WithError(err).Warn("Failed to start focus service")
} else {
go func() {
if err := service.server.Serve(listener); err != nil {
fmt.Printf("failed to serve: %v", err)
}
}()
}
return service, nil
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"sync"
"github.com/bradenaw/juniper/xerrors"
"github.com/sirupsen/logrus"
)
@ -23,7 +24,7 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er
defer func() {
if r := recover(); r != nil {
catch(handlers...)
err = fmt.Errorf("panic: %v", r)
err = xerrors.WithStack(fmt.Errorf("panic: %v", r))
}
}()
@ -38,13 +39,13 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er
func catch(handlers ...func() error) {
defer func() {
if r := recover(); r != nil {
logrus.WithField("panic", r).Error("Panic in catch")
logrus.WithError(xerrors.WithStack(fmt.Errorf("panic: %v", r))).Error("Catch handler panicked")
}
}()
for _, handler := range handlers {
if err := handler(); err != nil {
logrus.WithError(err).Error("Failed to handle error")
logrus.WithError(xerrors.WithStack(err)).Error("Catch handler failed")
}
}
}

View File

@ -341,7 +341,9 @@ func (user *User) Close() error {
user.waitSync()
// Close the user's API client.
user.client.Close()
if err := user.client.Close(); err != nil {
logrus.WithError(err).Error("Failed to close API client")
}
// Close the user's update channels.
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
@ -353,6 +355,11 @@ func (user *User) Close() error {
// Close the user's notify channel.
user.eventCh.Close()
// Close the user's vault.
if err := user.vault.Close(); err != nil {
logrus.WithError(err).Error("Failed to close vault")
}
return nil
}

View File

@ -147,3 +147,8 @@ func (user *User) Clear() error {
data.KeyPass = nil
})
}
// Close closes the user. This allows it to be removed from the vault.
func (user *User) Close() error {
return user.vault.detachUser(user.userID)
}

View File

@ -80,7 +80,11 @@ func TestUser_Delete(t *testing.T) {
// The user should be listed in the store.
require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs())
// Clear the user's auth information.
// Try to delete the user; it should fail because it is still in use.
require.Error(t, s.DeleteUser("userID"))
// Close the user; it should now be deletable.
require.NoError(t, user.Close())
require.NoError(t, s.DeleteUser("userID"))
// The store should have no users again.
@ -122,6 +126,56 @@ func TestUser_SyncStatus(t *testing.T) {
require.Empty(t, user.SyncStatus().LastMessageID)
}
func TestUser_ForEach(t *testing.T) {
// Create a new test vault.
s := newVault(t)
// Create some new users.
user1, err := s.AddUser("userID1", "username1", "authUID1", "authRef1", []byte("keyPass1"))
require.NoError(t, err)
user2, err := s.AddUser("userID2", "username2", "authUID2", "authRef2", []byte("keyPass2"))
require.NoError(t, err)
// Iterate through the users.
s.ForUser(func(user *vault.User) error {
switch user.UserID() {
case "userID1":
require.Equal(t, "username1", user.Username())
require.Equal(t, "authUID1", user.AuthUID())
require.Equal(t, "authRef1", user.AuthRef())
require.Equal(t, "keyPass1", string(user.KeyPass()))
case "userID2":
require.Equal(t, "username2", user.Username())
require.Equal(t, "authUID2", user.AuthUID())
require.Equal(t, "authRef2", user.AuthRef())
require.Equal(t, "keyPass2", string(user.KeyPass()))
default:
t.Fatalf("unexpected user %q", user.UserID())
}
return nil
})
// Try to delete the first user; it should fail because it is still in use.
require.Error(t, s.DeleteUser("userID1"))
// Close the first user; it should now be deletable.
require.NoError(t, user1.Close())
require.NoError(t, s.DeleteUser("userID1"))
// Try to delete the second user; it should fail because it is still in use.
require.Error(t, s.DeleteUser("userID2"))
// Close the second user; it should now be deletable.
require.NoError(t, user2.Close())
require.NoError(t, s.DeleteUser("userID2"))
// The store should have no users again.
require.Empty(t, s.GetUserIDs())
}
/*
func TestUser(t *testing.T) {
// Replace the token generator with a dummy one.

View File

@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io/fs"
"math/rand"
"os"
@ -19,9 +20,13 @@ import (
// Vault is an encrypted data vault that stores bridge and user data.
type Vault struct {
path string
enc []byte
gcm cipher.AEAD
lock sync.RWMutex
enc []byte
encLock sync.RWMutex
ref map[string]int
refLock sync.Mutex
}
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
@ -57,27 +62,45 @@ func (vault *Vault) GetUserIDs() []string {
})
}
// GetUserIDs returns the user IDs and usernames of all users in the vault.
func (vault *Vault) GetUser(userID string) (*User, error) {
// HasUser returns true if the vault contains a user with the given ID.
func (vault *Vault) HasUser(userID string) bool {
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
return user.UserID == userID
}) >= 0
}
// GetUser provides access to a vault user. It returns an error if the user does not exist.
func (vault *Vault) GetUser(userID string, fn func(*User)) error {
user, err := vault.NewUser(userID)
if err != nil {
return err
}
defer func() { _ = user.Close() }()
fn(user)
return nil
}
// NewUser returns a new vault user. It must be closed before it can be deleted.
func (vault *Vault) NewUser(userID string) (*User, error) {
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
return user.UserID == userID
}); idx < 0 {
return nil, errors.New("no such user")
}
return &User{
vault: vault,
userID: userID,
}, nil
return vault.attachUser(userID), nil
}
// ForUser executes a callback for each user in the vault.
func (vault *Vault) ForUser(fn func(*User) error) error {
for _, userID := range vault.GetUserIDs() {
user, err := vault.GetUser(userID)
user, err := vault.NewUser(userID)
if err != nil {
return err
}
defer func() { _ = user.Close() }()
if err := fn(user); err != nil {
return err
@ -102,11 +125,18 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
return nil, err
}
return vault.GetUser(userID)
return vault.NewUser(userID)
}
// DeleteUser removes the given user from the vault.
func (vault *Vault) DeleteUser(userID string) error {
vault.refLock.Lock()
defer vault.refLock.Unlock()
if _, ok := vault.ref[userID]; ok {
return fmt.Errorf("user %s is currently in use", userID)
}
return vault.mod(func(data *Data) {
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
return user.UserID == userID
@ -124,6 +154,35 @@ func (vault *Vault) Close() error {
return nil
}
func (vault *Vault) attachUser(userID string) *User {
vault.refLock.Lock()
defer vault.refLock.Unlock()
vault.ref[userID] += 1
return &User{
vault: vault,
userID: userID,
}
}
func (vault *Vault) detachUser(userID string) error {
vault.refLock.Lock()
defer vault.refLock.Unlock()
if _, ok := vault.ref[userID]; !ok {
return fmt.Errorf("user %s is not attached", userID)
}
vault.ref[userID] -= 1
if vault.ref[userID] == 0 {
delete(vault.ref, userID)
}
return nil
}
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
if _, err := initVault(path, gluonDir, gcm); err != nil {
@ -149,12 +208,17 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
enc = newEnc
}
return &Vault{path: path, enc: enc, gcm: gcm}, corrupt, nil
return &Vault{
path: path,
enc: enc,
gcm: gcm,
ref: make(map[string]int),
}, corrupt, nil
}
func (vault *Vault) get() Data {
vault.lock.RLock()
defer vault.lock.RUnlock()
vault.encLock.RLock()
defer vault.encLock.RUnlock()
dec, err := decrypt(vault.gcm, vault.enc)
if err != nil {
@ -171,8 +235,8 @@ func (vault *Vault) get() Data {
}
func (vault *Vault) mod(fn func(data *Data)) error {
vault.lock.Lock()
defer vault.lock.Unlock()
vault.encLock.Lock()
defer vault.encLock.Unlock()
dec, err := decrypt(vault.gcm, vault.enc)
if err != nil {