feat(BRIDGE-356): Added retry logic for unavailable preferred keychain on Linux; Feature flag support before bridge initialization; Refactored some bits of the code;
This commit is contained in:
@ -39,7 +39,9 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/frontend/theme"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
@ -285,11 +287,13 @@ func run(c *cli.Context) error {
|
||||
logrus.WithError(err).Error("Failed to get settings path")
|
||||
}
|
||||
|
||||
featureFlags := unleash.GetStartupFeatureFlagsAndStore(constants.APIHost, version, locations.ProvideUnleashStartupCachePath)
|
||||
|
||||
return withSingleInstance(settings, locations.GetLockFile(), version, func() error {
|
||||
// Look for available keychains
|
||||
return WithKeychainList(crashHandler, func(keychains *keychain.List) error {
|
||||
// Unlock the encrypted vault.
|
||||
return WithVault(reporter, locations, keychains, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
|
||||
return WithVault(reporter, locations, keychains, featureFlags, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
|
||||
if !v.Migrated() {
|
||||
// Migrate old settings into the vault.
|
||||
if err := migrateOldSettings(v); err != nil {
|
||||
@ -577,5 +581,5 @@ func setDeviceCookies(jar *cookies.Jar) error {
|
||||
}
|
||||
|
||||
func onMacOS() bool {
|
||||
return runtime.GOOS == "darwin"
|
||||
return runtime.GOOS == platform.MACOS
|
||||
}
|
||||
|
||||
@ -138,7 +138,14 @@ func migrateOldAccounts(locations *locations.Locations, keychains *keychain.List
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get helper: %w", err)
|
||||
}
|
||||
keychain, _, err := keychain.NewKeychain(helper, "bridge", keychains.GetHelpers(), keychains.GetDefaultHelper())
|
||||
|
||||
keychain, _, err := keychain.NewKeychain(
|
||||
helper, "bridge",
|
||||
keychains.GetHelpers(),
|
||||
keychains.GetDefaultHelper(),
|
||||
0,
|
||||
make(map[string]bool),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create keychain: %w", err)
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/cookies"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/legacy/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
||||
@ -85,7 +86,7 @@ func TestMigratePrefsToVaultWithoutKeys(t *testing.T) {
|
||||
|
||||
func TestKeychainMigration(t *testing.T) {
|
||||
// Migration tested only for linux.
|
||||
if runtime.GOOS != "linux" {
|
||||
if runtime.GOOS != platform.LINUX {
|
||||
return
|
||||
}
|
||||
|
||||
@ -134,7 +135,13 @@ func TestKeychainMigration(t *testing.T) {
|
||||
func TestUserMigration(t *testing.T) {
|
||||
kcl := keychain.NewTestKeychainsList()
|
||||
|
||||
kc, _, err := keychain.NewKeychain("mock", "bridge", kcl.GetHelpers(), kcl.GetDefaultHelper())
|
||||
kc, _, err := keychain.NewKeychain(
|
||||
"mock", "bridge",
|
||||
kcl.GetHelpers(),
|
||||
kcl.GetDefaultHelper(),
|
||||
0,
|
||||
make(map[string]bool),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, kc.Put("brokenID", "broken"))
|
||||
|
||||
@ -20,6 +20,7 @@ package app
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"runtime"
|
||||
@ -28,18 +29,20 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/certs"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func WithVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
|
||||
func WithVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, featureFlags unleash.FeatureFlagStartupStore, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
|
||||
logrus.Debug("Creating vault")
|
||||
defer logrus.Debug("Vault stopped")
|
||||
|
||||
// Create the encVault.
|
||||
encVault, insecure, corrupt, err := newVault(reporter, locations, keychains, panicHandler)
|
||||
encVault, insecure, corrupt, err := newVault(reporter, locations, keychains, featureFlags, panicHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create vault: %w", err)
|
||||
}
|
||||
@ -61,7 +64,7 @@ func WithVault(reporter *sentry.Reporter, locations *locations.Locations, keycha
|
||||
return fn(encVault, insecure, corrupt != nil)
|
||||
}
|
||||
|
||||
func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, error, error) {
|
||||
func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, featureFlags unleash.FeatureFlagStartupStore, panicHandler async.PanicHandler) (*vault.Vault, bool, error, error) {
|
||||
vaultDir, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return nil, false, nil, fmt.Errorf("could not get vault dir: %w", err)
|
||||
@ -75,7 +78,14 @@ func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychai
|
||||
lastUsedHelper string
|
||||
)
|
||||
|
||||
if key, helper, err := loadVaultKey(vaultDir, keychains); err != nil {
|
||||
if key, helper, err := loadVaultKey(vaultDir, keychains, featureFlags); err != nil {
|
||||
if errors.Is(err, keychain.ErrPreferredKeychainNotAvailable) {
|
||||
if err := vault.IncrementKeychainFailedAttemptCount(vaultDir); err != nil {
|
||||
logrus.WithError(err).Error("Failed to increment failed keychain attempt count")
|
||||
}
|
||||
return &vault.Vault{}, false, nil, err
|
||||
}
|
||||
|
||||
if reporter != nil {
|
||||
if rerr := reporter.ReportMessageWithContext("Could not load/create vault key", map[string]any{
|
||||
"keychainDefaultHelper": keychains.GetDefaultHelper(),
|
||||
@ -108,23 +118,38 @@ func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychai
|
||||
}
|
||||
|
||||
// Remember the last successfully used keychain on Linux and store that as the user preference.
|
||||
if runtime.GOOS == "linux" {
|
||||
if runtime.GOOS == platform.LINUX {
|
||||
if err := vault.SetHelper(vaultDir, lastUsedHelper); err != nil {
|
||||
logrus.WithError(err).Error("Could not store last used keychain helper")
|
||||
}
|
||||
|
||||
if err := vault.ResetFailedKeychainAttemptCount(vaultDir); err != nil {
|
||||
logrus.WithError(err).Error("Could not reset and save failed keychain attempt count")
|
||||
}
|
||||
}
|
||||
|
||||
return userVault, insecure, corrupt, nil
|
||||
}
|
||||
|
||||
// loadVaultKey - loads the key used to encrypt the vault alongside the keychain helper used to access it.
|
||||
func loadVaultKey(vaultDir string, keychains *keychain.List) (key []byte, keychainHelper string, err error) {
|
||||
func loadVaultKey(vaultDir string, keychains *keychain.List, featureFlags unleash.FeatureFlagStartupStore) (key []byte, keychainHelper string, err error) {
|
||||
keychainHelper, err = vault.GetHelper(vaultDir)
|
||||
if err != nil {
|
||||
return nil, keychainHelper, fmt.Errorf("could not get keychain helper: %w", err)
|
||||
}
|
||||
|
||||
kc, keychainHelper, err := keychain.NewKeychain(keychainHelper, constants.KeyChainName, keychains.GetHelpers(), keychains.GetDefaultHelper())
|
||||
keychainFailedAttemptCount, err := vault.GetKeychainFailedAttemptCount(vaultDir)
|
||||
if err != nil {
|
||||
return nil, keychainHelper, fmt.Errorf("could not get keychain failed attempt count: %w", err)
|
||||
}
|
||||
|
||||
kc, keychainHelper, err := keychain.NewKeychain(
|
||||
keychainHelper, constants.KeyChainName,
|
||||
keychains.GetHelpers(),
|
||||
keychains.GetDefaultHelper(),
|
||||
keychainFailedAttemptCount,
|
||||
featureFlags,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, keychainHelper, fmt.Errorf("could not create keychain: %w", err)
|
||||
}
|
||||
@ -139,6 +164,12 @@ func loadVaultKey(vaultDir string, keychains *keychain.List) (key []byte, keycha
|
||||
return key, keychainHelper, err
|
||||
}
|
||||
|
||||
if keychain.ShouldRetryPreferredKeychain(featureFlags, keychainHelper) {
|
||||
if keychainFailedAttemptCount < keychain.MaxFailedKeychainAttemptsLinux {
|
||||
return nil, keychainHelper, keychain.PreferredKeychainRetryError(keychainFailedAttemptCount)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, keychainHelper, fmt.Errorf("could not check for vault key: %w", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user