fix(GODT-3102): Distinguish Vault Decryption from Serialization Errors

Rather than returning whether the vault was corrupt or not return the
error which caused the vault to be considered as corrupt.
This commit is contained in:
Leander Beernaert
2023-11-30 08:31:14 +01:00
parent 7a1c7e8743
commit 1b22c32ef9
10 changed files with 50 additions and 37 deletions

View File

@ -42,7 +42,7 @@ func TestMigratePrefsToVaultWithKeys(t *testing.T) {
// Create a new vault. // Create a new vault.
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{}) vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
// load the old prefs file. // load the old prefs file.
configDir := filepath.Join("testdata", "with_keys") configDir := filepath.Join("testdata", "with_keys")
@ -63,7 +63,7 @@ func TestMigratePrefsToVaultWithoutKeys(t *testing.T) {
// Create a new vault. // Create a new vault.
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{}) vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
// load the old prefs file. // load the old prefs file.
configDir := filepath.Join("testdata", "without_keys") configDir := filepath.Join("testdata", "without_keys")
@ -173,7 +173,7 @@ func TestUserMigration(t *testing.T) {
v, corrupt, err := vault.New(settingsFolder, settingsFolder, token, async.NoopPanicHandler{}) v, corrupt, err := vault.New(settingsFolder, settingsFolder, token, async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
require.NoError(t, migrateOldAccounts(locations, kcl, v)) require.NoError(t, migrateOldAccounts(locations, kcl, v))
require.Equal(t, []string{wantCredentials.UserID}, v.GetUserIDs()) require.Equal(t, []string{wantCredentials.UserID}, v.GetUserIDs())

View File

@ -42,21 +42,25 @@ func WithVault(locations *locations.Locations, keychains *keychain.List, panicHa
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"insecure": insecure, "insecure": insecure,
"corrupt": corrupt, "corrupt": corrupt != nil,
}).Debug("Vault created") }).Debug("Vault created")
if corrupt != nil {
logrus.WithError(corrupt).Warn("Failed to load existing vault, vault has been reset")
}
cert, _ := encVault.GetBridgeTLSCert() cert, _ := encVault.GetBridgeTLSCert()
certs.NewInstaller().LogCertInstallStatus(cert) certs.NewInstaller().LogCertInstallStatus(cert)
// GODT-1950: Add teardown actions (e.g. to close the vault). // GODT-1950: Add teardown actions (e.g. to close the vault).
return fn(encVault, insecure, corrupt) return fn(encVault, insecure, corrupt != nil)
} }
func newVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) { func newVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, error, error) {
vaultDir, err := locations.ProvideSettingsPath() vaultDir, err := locations.ProvideSettingsPath()
if err != nil { if err != nil {
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err) return nil, false, nil, fmt.Errorf("could not get vault dir: %w", err)
} }
logrus.WithField("vaultDir", vaultDir).Debug("Loading vault from directory") logrus.WithField("vaultDir", vaultDir).Debug("Loading vault from directory")
@ -78,12 +82,12 @@ func newVault(locations *locations.Locations, keychains *keychain.List, panicHan
gluonCacheDir, err := locations.ProvideGluonCachePath() gluonCacheDir, err := locations.ProvideGluonCachePath()
if err != nil { if err != nil {
return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err) return nil, false, nil, fmt.Errorf("could not provide gluon path: %w", err)
} }
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler) vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler)
if err != nil { if err != nil {
return nil, false, false, fmt.Errorf("could not create vault: %w", err) return nil, false, corrupt, fmt.Errorf("could not create vault: %w", err)
} }
return vault, insecure, corrupt, nil return vault, insecure, corrupt, nil

View File

@ -133,7 +133,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil) v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil)
require.NoError(tb, err) require.NoError(tb, err)
require.False(tb, corrupt) require.NoError(tb, corrupt)
vaultUser, err := v.AddUser(apiUser.ID, username, username+"@pm.me", apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass) vaultUser, err := v.AddUser(apiUser.ID, username, username+"@pm.me", apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass)
require.NoError(tb, err) require.NoError(tb, err)

View File

@ -55,7 +55,7 @@ func TestMigrate(t *testing.T) {
// Migrate the vault. // Migrate the vault.
s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key"), async.NoopPanicHandler{}) s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
// Check the migrated vault. // Check the migrated vault.
require.Equal(t, "v2.3.x-gluon-dir", s.GetGluonCacheDir()) require.Equal(t, "v2.3.x-gluon-dir", s.GetGluonCacheDir())

View File

@ -68,7 +68,7 @@ func TestVault_Settings_GluonDir(t *testing.T) {
// create a new test vault. // create a new test vault.
s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key"), async.NoopPanicHandler{}) s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
// Check the default gluon dir. // Check the default gluon dir.
require.Equal(t, "/path/to/gluon", s.GetGluonCacheDir()) require.Equal(t, "/path/to/gluon", s.GetGluonCacheDir())

View File

@ -19,6 +19,7 @@ package vault
import ( import (
"crypto/cipher" "crypto/cipher"
"fmt"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
@ -34,12 +35,12 @@ func unmarshalFile[T any](gcm cipher.AEAD, b []byte, data *T) error {
var f File var f File
if err := msgpack.Unmarshal(b, &f); err != nil { if err := msgpack.Unmarshal(b, &f); err != nil {
return err return fmt.Errorf("%w: %v", ErrUnmarshal, err)
} }
dec, err := gcm.Open(nil, f.Data[:gcm.NonceSize()], f.Data[gcm.NonceSize():], nil) dec, err := gcm.Open(nil, f.Data[:gcm.NonceSize()], f.Data[gcm.NonceSize():], nil)
if err != nil { if err != nil {
return err return fmt.Errorf("%w: %v", ErrDecryptFailed, err)
} }
for v := f.Version; v < Current; v++ { for v := f.Version; v < Current; v++ {
@ -48,7 +49,11 @@ func unmarshalFile[T any](gcm cipher.AEAD, b []byte, data *T) error {
} }
} }
return msgpack.Unmarshal(dec, data) if err := msgpack.Unmarshal(dec, data); err != nil {
return fmt.Errorf("%w: %v", ErrUnmarshal, err)
}
return nil
} }
func marshalFile[T any](gcm cipher.AEAD, t T) ([]byte, error) { func marshalFile[T any](gcm cipher.AEAD, t T) ([]byte, error) {

View File

@ -49,27 +49,31 @@ type Vault struct {
panicHandler async.PanicHandler panicHandler async.PanicHandler
} }
var ErrDecryptFailed = errors.New("failed to decrypt vault")
var ErrUnmarshal = errors.New("vault contents are corrupt")
// New constructs a new encrypted data vault at the given filepath using the given encryption key. // New constructs a new encrypted data vault at the given filepath using the given encryption key.
func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHandler) (*Vault, bool, error) { // The first error is a corruption error for an existing vault, the second errors refrain to all other errors.
func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHandler) (*Vault, error, error) {
if err := os.MkdirAll(vaultDir, 0o700); err != nil { if err := os.MkdirAll(vaultDir, 0o700); err != nil {
return nil, false, err return nil, nil, err
} }
hash256 := sha256.Sum256(key) hash256 := sha256.Sum256(key)
aes, err := aes.NewCipher(hash256[:]) aes, err := aes.NewCipher(hash256[:])
if err != nil { if err != nil {
return nil, false, err return nil, nil, err
} }
gcm, err := cipher.NewGCM(aes) gcm, err := cipher.NewGCM(aes)
if err != nil { if err != nil {
return nil, false, err return nil, nil, err
} }
vault, corrupt, err := newVault(filepath.Join(vaultDir, "vault.enc"), gluonCacheDir, gcm) vault, corrupt, err := newVault(filepath.Join(vaultDir, "vault.enc"), gluonCacheDir, gcm)
if err != nil { if err != nil {
return nil, false, err return nil, corrupt, err
} }
vault.panicHandler = panicHandler vault.panicHandler = panicHandler
@ -341,28 +345,28 @@ func (vault *Vault) detachUser(userID string) error {
return nil return nil
} }
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) { func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, error, error) {
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) { if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
if _, err := initVault(path, gluonDir, gcm); err != nil { if _, err := initVault(path, gluonDir, gcm); err != nil {
return nil, false, err return nil, nil, err
} }
} }
enc, err := os.ReadFile(filepath.Clean(path)) enc, err := os.ReadFile(filepath.Clean(path))
if err != nil { if err != nil {
return nil, false, err return nil, nil, err
} }
var corrupt bool var corrupt error
if err := unmarshalFile(gcm, enc, new(Data)); err != nil { if err := unmarshalFile(gcm, enc, new(Data)); err != nil {
corrupt = true corrupt = err
} }
if corrupt { if corrupt != nil {
newEnc, err := initVault(path, gluonDir, gcm) newEnc, err := initVault(path, gluonDir, gcm)
if err != nil { if err != nil {
return nil, false, err return nil, corrupt, err
} }
enc = newEnc enc = newEnc

View File

@ -34,7 +34,7 @@ func BenchmarkVault(b *testing.B) {
// Create a new vault. // Create a new vault.
s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{}) s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(b, err) require.NoError(b, err)
require.False(b, corrupt) require.NoError(b, corrupt)
// Add 10kB of cookies to the vault. // Add 10kB of cookies to the vault.
require.NoError(b, s.SetCookies(bytes.Repeat([]byte("a"), 10_000))) require.NoError(b, s.SetCookies(bytes.Repeat([]byte("a"), 10_000)))

View File

@ -34,19 +34,19 @@ func TestVault_Corrupt(t *testing.T) {
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{}) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
} }
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{}) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
} }
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key"), async.NoopPanicHandler{}) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.True(t, corrupt) require.ErrorIs(t, corrupt, vault.ErrDecryptFailed)
} }
} }
@ -56,13 +56,13 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{}) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
} }
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{}) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
} }
{ {
@ -75,7 +75,7 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{}) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.True(t, corrupt) require.ErrorIs(t, corrupt, vault.ErrUnmarshal)
} }
} }
@ -103,7 +103,7 @@ func newVault(t *testing.T) *vault.Vault {
s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{}) s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.NoError(t, corrupt)
return s return s
} }

View File

@ -112,8 +112,8 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey, async.NoopPanicHandler{}) vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey, async.NoopPanicHandler{})
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create vault: %w", err) return nil, fmt.Errorf("could not create vault: %w", err)
} else if corrupt { } else if corrupt != nil {
return nil, fmt.Errorf("vault is corrupt") return nil, fmt.Errorf("vault is corrupt: %w", corrupt)
} }
t.vault = vault t.vault = vault