fix(GODT-2387): Ensure vault can be unlocked after factory reset

When performing a factory reset, we don't want to wipe all keychain
entries. The only keychain entry should be the vault's passphrase,
and we need this to be able to decrypt the vault at next startup
(to avoid it being reported as corrupt).
This commit is contained in:
James Houlahan
2023-02-22 13:39:59 +01:00
parent 8534da98ea
commit e9f20aee7a
7 changed files with 68 additions and 50 deletions

View File

@ -41,8 +41,6 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
const vaultSecretName = "bridge-vault-key"
// deleteOldGoIMAPFiles Set with `-ldflags -X app.deleteOldGoIMAPFiles=true` to enable cleanup of old imap cache data. // deleteOldGoIMAPFiles Set with `-ldflags -X app.deleteOldGoIMAPFiles=true` to enable cleanup of old imap cache data.
var deleteOldGoIMAPFiles bool //nolint:gochecknoglobals var deleteOldGoIMAPFiles bool //nolint:gochecknoglobals

View File

@ -18,18 +18,15 @@
package app package app
import ( import (
"encoding/base64"
"fmt" "fmt"
"path" "path"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/certs" "github.com/ProtonMail/proton-bridge/v3/internal/certs"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
) )
func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) error) error { func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) error) error {
@ -82,7 +79,7 @@ func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error)
insecure bool insecure bool
) )
if key, err := getVaultKey(vaultDir); err != nil { if key, err := loadVaultKey(vaultDir); err != nil {
insecure = true insecure = true
// We store the insecure vault in a separate directory // We store the insecure vault in a separate directory
@ -104,42 +101,25 @@ func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error)
return vault, insecure, corrupt, nil return vault, insecure, corrupt, nil
} }
func getVaultKey(vaultDir string) ([]byte, error) { func loadVaultKey(vaultDir string) ([]byte, error) {
helper, err := vault.GetHelper(vaultDir) helper, err := vault.GetHelper(vaultDir)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get keychain helper: %w", err) return nil, fmt.Errorf("could not get keychain helper: %w", err)
} }
keychain, err := keychain.NewKeychain(helper, constants.KeyChainName) kc, err := keychain.NewKeychain(helper, constants.KeyChainName)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create keychain: %w", err) return nil, fmt.Errorf("could not create keychain: %w", err)
} }
secrets, err := keychain.List() has, err := vault.HasVaultKey(kc)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not list keychain: %w", err) return nil, fmt.Errorf("could not check for vault key: %w", err)
} }
if !slices.Contains(secrets, vaultSecretName) { if has {
tok, err := crypto.RandomToken(32) return vault.GetVaultKey(kc)
if err != nil {
return nil, fmt.Errorf("could not generate random token: %w", err)
} }
if err := keychain.Put(vaultSecretName, base64.StdEncoding.EncodeToString(tok)); err != nil { return vault.NewVaultKey(kc)
return nil, fmt.Errorf("could not put keychain item: %w", err)
}
}
_, keyEnc, err := keychain.Get(vaultSecretName)
if err != nil {
return nil, fmt.Errorf("could not get keychain item: %w", err)
}
keyDec, err := base64.StdEncoding.DecodeString(keyEnc)
if err != nil {
return nil, fmt.Errorf("could not decode keychain item: %w", err)
}
return keyDec, nil
} }

View File

@ -25,11 +25,9 @@ import (
"path/filepath" "path/filepath"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -310,6 +308,9 @@ func (bridge *Bridge) SetColorScheme(colorScheme string) error {
return bridge.vault.SetColorScheme(colorScheme) return bridge.vault.SetColorScheme(colorScheme)
} }
// FactoryReset deletes all users, wipes the vault, and deletes all files.
// Note: it does not clear the keychain. The only entry in the keychain is the vault password,
// which we need at next startup to decrypt the vault.
func (bridge *Bridge) FactoryReset(ctx context.Context) { func (bridge *Bridge) FactoryReset(ctx context.Context) {
// Delete all the users. // Delete all the users.
safe.Lock(func() { safe.Lock(func() {
@ -326,22 +327,10 @@ func (bridge *Bridge) FactoryReset(ctx context.Context) {
logrus.WithError(err).Error("Failed to reset vault") logrus.WithError(err).Error("Failed to reset vault")
} }
// Then delete all files. // Lastly, delete all files except the vault.
if err := bridge.locator.Clear(); err != nil { if err := bridge.locator.Clear(bridge.vault.Path()); err != nil {
logrus.WithError(err).Error("Failed to clear data paths") logrus.WithError(err).Error("Failed to clear data paths")
} }
// Lastly clear the keychain.
vaultDir, err := bridge.locator.ProvideSettingsPath()
if err != nil {
logrus.WithError(err).Error("Failed to get vault dir")
} else if helper, err := vault.GetHelper(vaultDir); err != nil {
logrus.WithError(err).Error("Failed to get keychain helper")
} else if keychain, err := keychain.NewKeychain(helper, constants.KeyChainName); err != nil {
logrus.WithError(err).Error("Failed to get keychain")
} else if err := keychain.Clear(); err != nil {
logrus.WithError(err).Error("Failed to clear keychain")
}
} }
func getPort(addr net.Addr) int { func getPort(addr net.Addr) int {

View File

@ -30,7 +30,7 @@ type Locator interface {
ProvideGluonDataPath() (string, error) ProvideGluonDataPath() (string, error)
GetLicenseFilePath() string GetLicenseFilePath() string
GetDependencyLicensesLink() string GetDependencyLicensesLink() string
Clear() error Clear(...string) error
} }
type Identifier interface { type Identifier interface {

View File

@ -217,14 +217,13 @@ func (l *Locations) getUpdatesPath() string {
} }
// Clear removes everything except the lock and update files. // Clear removes everything except the lock and update files.
func (l *Locations) Clear() error { func (l *Locations) Clear(except ...string) error {
return files.Remove( return files.Remove(
l.userConfig, l.userConfig,
l.userData, l.userData,
l.userCache, l.userCache,
).Except( ).Except(
l.GetGuiLockFile(), append(except, l.GetGuiLockFile(), l.getUpdatesPath())...,
l.getUpdatesPath(),
).Do() ).Do()
} }

View File

@ -18,13 +18,21 @@
package vault package vault
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io/fs" "io/fs"
"os" "os"
"path/filepath" "path/filepath"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"golang.org/x/exp/slices"
) )
const vaultSecretName = "bridge-vault-key"
type Keychain struct { type Keychain struct {
Helper string Helper string
} }
@ -60,3 +68,43 @@ func SetHelper(vaultDir, helper string) error {
return os.WriteFile(getKeychainPrefPath(vaultDir), b, 0o600) return os.WriteFile(getKeychainPrefPath(vaultDir), b, 0o600)
} }
func HasVaultKey(kc *keychain.Keychain) (bool, error) {
secrets, err := kc.List()
if err != nil {
return false, fmt.Errorf("could not list keychain: %w", err)
}
return slices.Contains(secrets, vaultSecretName), nil
}
func GetVaultKey(kc *keychain.Keychain) ([]byte, error) {
_, keyEnc, err := kc.Get(vaultSecretName)
if err != nil {
return nil, fmt.Errorf("could not get keychain item: %w", err)
}
keyDec, err := base64.StdEncoding.DecodeString(keyEnc)
if err != nil {
return nil, fmt.Errorf("could not decode keychain item: %w", err)
}
return keyDec, nil
}
func SetVaultKey(kc *keychain.Keychain, key []byte) error {
return kc.Put(vaultSecretName, base64.StdEncoding.EncodeToString(key))
}
func NewVaultKey(kc *keychain.Keychain) ([]byte, error) {
tok, err := crypto.RandomToken(32)
if err != nil {
return nil, fmt.Errorf("could not generate random token: %w", err)
}
if err := kc.Put(vaultSecretName, base64.StdEncoding.EncodeToString(tok)); err != nil {
return nil, fmt.Errorf("could not put keychain item: %w", err)
}
return tok, nil
}

View File

@ -191,6 +191,10 @@ func (vault *Vault) Reset(gluonDir string) error {
}) })
} }
func (vault *Vault) Path() string {
return vault.path
}
func (vault *Vault) Close() error { func (vault *Vault) Close() error {
vault.refLock.Lock() vault.refLock.Lock()
defer vault.refLock.Unlock() defer vault.refLock.Unlock()