From 3d53bf747750f4fe22d3b0f0c0647d1f85b0c312 Mon Sep 17 00:00:00 2001 From: Xavier Michelon Date: Thu, 8 Aug 2024 10:54:11 +0200 Subject: [PATCH] feat(BRIDGE-116): add command-line switches to enable/disable keychain check on macOS. --- internal/app/app.go | 68 +++++++++++++++++++--- internal/app/app_test.go | 65 +++++++++++++++++++++ internal/vault/helper.go | 64 ++++++++++---------- internal/vault/helper_test.go | 46 +++++++++++++++ internal/vault/keychain_settings.go | 74 ++++++++++++++++++++++++ internal/vault/keychain_settings_test.go | 58 +++++++++++++++++++ pkg/keychain/helper_darwin.go | 13 +++-- pkg/keychain/helper_linux.go | 2 +- pkg/keychain/helper_windows.go | 2 +- pkg/keychain/keychain.go | 4 +- pkg/keychain/keychain_test.go | 2 +- utils/bridge-rollout/bridge-rollout.go | 4 +- utils/vault-editor/main.go | 4 +- 13 files changed, 355 insertions(+), 51 deletions(-) create mode 100644 internal/app/app_test.go create mode 100644 internal/vault/helper_test.go create mode 100644 internal/vault/keychain_settings.go create mode 100644 internal/vault/keychain_settings_test.go diff --git a/internal/app/app.go b/internal/app/app.go index 54ffd1be..a0cb6c37 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -79,11 +79,13 @@ const ( // Hidden flags. const ( - flagLauncher = "launcher" - flagNoWindow = "no-window" - flagParentPID = "parent-pid" - flagSoftwareRenderer = "software-renderer" - FlagSessionID = "session-id" + flagLauncher = "launcher" + flagNoWindow = "no-window" + flagParentPID = "parent-pid" + flagSoftwareRenderer = "software-renderer" + flagEnableKeychainTest = "enable-keychain-test" + flagDisableKeychainTest = "disable-keychain-test" + FlagSessionID = "session-id" ) const ( @@ -91,6 +93,20 @@ const ( appShortName = "bridge" ) +var cliFlagEnableKeychainTest = &cli.BoolFlag{ //nolint:gochecknoglobals + Name: flagEnableKeychainTest, + Usage: "Enable the keychain test", + Hidden: true, + Value: false, +} //nolint:gochecknoglobals + +var cliFlagDisableKeychainTest = &cli.BoolFlag{ //nolint:gochecknoglobals + Name: flagDisableKeychainTest, + Usage: "Disable the keychain test", + Hidden: true, + Value: false, +} + func New() *cli.App { app := cli.NewApp() @@ -168,6 +184,9 @@ func New() *cli.App { Name: FlagSessionID, Hidden: true, }, + // the two flags below were introduced by BRIDGE-116 + cliFlagEnableKeychainTest, + cliFlagDisableKeychainTest, } app.Action = run @@ -238,7 +257,8 @@ func run(c *cli.Context) error { return withSingleInstance(settings, locations.GetLockFile(), version, func() error { // Look for available keychains - return WithKeychainList(crashHandler, func(keychains *keychain.List) error { + skipKeychainTest := checkSkipKeychainTest(c, settings) + return WithKeychainList(crashHandler, skipKeychainTest, func(keychains *keychain.List) error { // Unlock the encrypted vault. return WithVault(locations, keychains, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error { if !v.Migrated() { @@ -502,11 +522,11 @@ func withCookieJar(vault *vault.Vault, fn func(http.CookieJar) error) error { } // WithKeychainList init the list of usable keychains. -func WithKeychainList(panicHandler async.PanicHandler, fn func(*keychain.List) error) error { +func WithKeychainList(panicHandler async.PanicHandler, skipKeychainTest bool, fn func(*keychain.List) error) error { logrus.Debug("Creating keychain list") defer logrus.Debug("Keychain list stop") defer async.HandlePanic(panicHandler) - return fn(keychain.NewList()) + return fn(keychain.NewList(skipKeychainTest)) } func setDeviceCookies(jar *cookies.Jar) error { @@ -526,3 +546,35 @@ func setDeviceCookies(jar *cookies.Jar) error { return nil } + +func checkSkipKeychainTest(c *cli.Context, settingsDir string) bool { + if runtime.GOOS != "darwin" { + return false + } + + enable := c.Bool(flagEnableKeychainTest) + disable := c.Bool(flagDisableKeychainTest) + + skip, err := vault.GetShouldSkipKeychainTest(settingsDir) + if err != nil { + logrus.WithError(err).Error("Could not load keychain settings.") + } + + if (!enable) && (!disable) { + return skip + } + + // if both switches are passed, 'enable' has priority + if disable { + skip = true + } + if enable { + skip = false + } + + if err := vault.SetShouldSkipKeychainTest(settingsDir, skip); err != nil { + logrus.WithError(err).Error("Could not save keychain settings.") + } + + return skip +} diff --git a/internal/app/app_test.go b/internal/app/app_test.go new file mode 100644 index 00000000..116d12ac --- /dev/null +++ b/internal/app/app_test.go @@ -0,0 +1,65 @@ +// Copyright (c) 2024 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package app + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" +) + +func TestCheckSkipKeychainTest(t *testing.T) { + var expectedResult bool + dir := t.TempDir() + app := cli.App{ + Flags: []cli.Flag{ + cliFlagEnableKeychainTest, + cliFlagDisableKeychainTest, + }, + Action: func(c *cli.Context) error { + require.Equal(t, expectedResult, checkSkipKeychainTest(c, dir)) + return nil + }, + } + + noArgs := []string{"appName"} + enableArgs := []string{"appName", "-" + flagEnableKeychainTest} + disableArgs := []string{"appName", "-" + flagDisableKeychainTest} + bothArgs := []string{"appName", "-" + flagDisableKeychainTest, "-" + flagEnableKeychainTest} + + const trueOnlyOnMac = runtime.GOOS == "darwin" + + expectedResult = false + require.NoError(t, app.Run(noArgs)) + + expectedResult = trueOnlyOnMac + require.NoError(t, app.Run(disableArgs)) + require.NoError(t, app.Run(noArgs)) + + expectedResult = false + require.NoError(t, app.Run(enableArgs)) + require.NoError(t, app.Run(noArgs)) + + expectedResult = trueOnlyOnMac + require.NoError(t, app.Run(disableArgs)) + + expectedResult = false + require.NoError(t, app.Run(bothArgs)) +} diff --git a/internal/vault/helper.go b/internal/vault/helper.go index b0a965eb..9baff2c4 100644 --- a/internal/vault/helper.go +++ b/internal/vault/helper.go @@ -19,53 +19,57 @@ package vault import ( "encoding/base64" - "encoding/json" - "errors" "fmt" - "io/fs" - "os" - "path/filepath" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" + "github.com/sirupsen/logrus" ) const vaultSecretName = "bridge-vault-key" -type Keychain struct { - Helper string -} - -func getKeychainPrefPath(vaultDir string) string { - return filepath.Clean(filepath.Join(vaultDir, "keychain.json")) -} - -func GetHelper(vaultDir string) (string, error) { - if _, err := os.Stat(getKeychainPrefPath(vaultDir)); errors.Is(err, fs.ErrNotExist) { - return "", nil - } - - b, err := os.ReadFile(getKeychainPrefPath(vaultDir)) +func GetShouldSkipKeychainTest(vaultDir string) (bool, error) { + settings, err := LoadKeychainSettings(vaultDir) if err != nil { - return "", err + return false, err } - var keychain Keychain - - if err := json.Unmarshal(b, &keychain); err != nil { - return "", err - } - - return keychain.Helper, nil + return settings.DisableTest, nil } -func SetHelper(vaultDir, helper string) error { - b, err := json.MarshalIndent(Keychain{Helper: helper}, "", " ") +func SetShouldSkipKeychainTest(vaultDir string, skip bool) error { + settings, err := LoadKeychainSettings(vaultDir) if err != nil { return err } - return os.WriteFile(getKeychainPrefPath(vaultDir), b, 0o600) + log := logrus.WithFields(logrus.Fields{"pkg": "vault", "skipKeychainTest": skip}) + if skip == settings.DisableTest { + log.Info("Skipping change of keychain test setting as value is not modified") + return nil + } + + logrus.WithFields(logrus.Fields{"pkg": "vault", "skipKeychainTest": skip}).Info("Setting keychain test skip option") + settings.DisableTest = skip + return settings.Save(vaultDir) +} + +func GetHelper(vaultDir string) (string, error) { + settings, err := LoadKeychainSettings(vaultDir) + if err != nil { + return "", err + } + return settings.Helper, nil +} + +func SetHelper(vaultDir, helper string) error { + settings, err := LoadKeychainSettings(vaultDir) + if err != nil { + return err + } + + settings.Helper = helper + return settings.Save(vaultDir) } func GetVaultKey(kc *keychain.Keychain) ([]byte, error) { diff --git a/internal/vault/helper_test.go b/internal/vault/helper_test.go new file mode 100644 index 00000000..d2a520d6 --- /dev/null +++ b/internal/vault/helper_test.go @@ -0,0 +1,46 @@ +// Copyright (c) 2024 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package vault + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestShouldSkipKeychainTestAccessors(t *testing.T) { + dir := t.TempDir() + skip, err := GetShouldSkipKeychainTest(dir) + require.NoError(t, err) + require.False(t, skip) + require.NoError(t, SetShouldSkipKeychainTest(dir, true)) + skip, err = GetShouldSkipKeychainTest(dir) + require.NoError(t, err) + require.True(t, skip) +} + +func TestHelperAccessors(t *testing.T) { + dir := t.TempDir() + helper, err := GetHelper(dir) + require.NoError(t, err) + require.Zero(t, len(helper)) + require.NoError(t, SetHelper(dir, "keychain")) + helper, err = GetHelper(dir) + require.NoError(t, err) + require.Equal(t, "keychain", helper) +} diff --git a/internal/vault/keychain_settings.go b/internal/vault/keychain_settings.go new file mode 100644 index 00000000..ac6cb357 --- /dev/null +++ b/internal/vault/keychain_settings.go @@ -0,0 +1,74 @@ +// Copyright (c) 2024 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package vault + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/sirupsen/logrus" +) + +const keychainSettingsFileName = "keychain.json" + +// KeychainSettings holds settings related to the keychain. It is serialized in the vault directory. +type KeychainSettings struct { + Helper string // The helper used for keychain. + DisableTest bool // Is the keychain test on startup disabled? +} + +// LoadKeychainSettings load keychain settings from the vaultDir folder, or returns a default one if the file +// does not exists or is invalid. +func LoadKeychainSettings(vaultDir string) (KeychainSettings, error) { + path := filepath.Join(vaultDir, keychainSettingsFileName) + bytes, err := os.ReadFile(path) //nolint:gosec + if err != nil { + if errors.Is(err, os.ErrNotExist) { + logrus. + WithFields(logrus.Fields{"pkg": "vault", "path": path}). + Trace("Keychain settings file does not exists, default values will be used") + return KeychainSettings{}, nil + } + return KeychainSettings{}, err + } + + var result KeychainSettings + if err := json.Unmarshal(bytes, &result); err != nil { + return KeychainSettings{}, fmt.Errorf("keychain settings file is invalid settings: %w", err) + } + + return result, nil +} + +// Save saves the keychain settings in a file in the vaultDir folder. +func (k KeychainSettings) Save(vaultDir string) error { + bytes, err := json.MarshalIndent(k, "", " ") + if err != nil { + return err + } + + if err = os.MkdirAll(vaultDir, 0o700); err != nil { + return err + } + + path := filepath.Join(vaultDir, keychainSettingsFileName) + return os.WriteFile(path, bytes, 0o600) +} diff --git a/internal/vault/keychain_settings_test.go b/internal/vault/keychain_settings_test.go new file mode 100644 index 00000000..bbf7596f --- /dev/null +++ b/internal/vault/keychain_settings_test.go @@ -0,0 +1,58 @@ +// Copyright (c) 2024 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package vault + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestKeychainSettingsIO(t *testing.T) { + dir := t.TempDir() + + // test loading non existing file. no error but loads defaults. + settings, err := LoadKeychainSettings(dir) + require.NoError(t, err) + require.Equal(t, settings, KeychainSettings{}) + + // test file creation + settings.Helper = "dummy1" + settings.DisableTest = true + require.NoError(t, settings.Save(dir)) + + // test reading existing file + readBack, err := LoadKeychainSettings(dir) + require.NoError(t, err) + require.Equal(t, settings, readBack) + + // test file overwrite and read back + settings.Helper = "dummy2" + require.NoError(t, settings.Save(dir)) + readBack, err = LoadKeychainSettings(dir) + require.NoError(t, err) + require.Equal(t, settings, readBack) + + // test error on invalid content + settingsFilePath := filepath.Join(dir, keychainSettingsFileName) + require.NoError(t, os.WriteFile(settingsFilePath, []byte("][INVALID"), 0o600)) + _, err = LoadKeychainSettings(dir) + require.Error(t, err) +} diff --git a/pkg/keychain/helper_darwin.go b/pkg/keychain/helper_darwin.go index 5d9147fd..867d2fec 100644 --- a/pkg/keychain/helper_darwin.go +++ b/pkg/keychain/helper_darwin.go @@ -31,15 +31,20 @@ const ( MacOSKeychain = "macos-keychain" ) -func listHelpers() (Helpers, string) { +func listHelpers(skipKeychainTest bool) (Helpers, string) { helpers := make(Helpers) // MacOS always provides a keychain. - if isUsable(newMacOSHelper("")) { + if skipKeychainTest { + logrus.WithField("pkg", "keychain").Info("Skipping macOS keychain test") helpers[MacOSKeychain] = newMacOSHelper - logrus.WithField("keychain", "MacOSKeychain").Info("Keychain is usable.") } else { - logrus.WithField("keychain", "MacOSKeychain").Debug("Keychain is not available.") + if isUsable(newMacOSHelper("")) { + helpers[MacOSKeychain] = newMacOSHelper + logrus.WithField("keychain", "MacOSKeychain").Info("Keychain is usable.") + } else { + logrus.WithField("keychain", "MacOSKeychain").Debug("Keychain is not available.") + } } // Use MacOSKeychain by default. diff --git a/pkg/keychain/helper_linux.go b/pkg/keychain/helper_linux.go index ce531faa..7a395bc6 100644 --- a/pkg/keychain/helper_linux.go +++ b/pkg/keychain/helper_linux.go @@ -31,7 +31,7 @@ const ( SecretServiceDBus = "secret-service-dbus" ) -func listHelpers() (Helpers, string) { +func listHelpers(_ bool) (Helpers, string) { helpers := make(Helpers) if isUsable(newDBusHelper("")) { diff --git a/pkg/keychain/helper_windows.go b/pkg/keychain/helper_windows.go index 6ac141ff..506601d9 100644 --- a/pkg/keychain/helper_windows.go +++ b/pkg/keychain/helper_windows.go @@ -25,7 +25,7 @@ import ( const WindowsCredentials = "windows-credentials" -func listHelpers() (Helpers, string) { +func listHelpers(_ bool) (Helpers, string) { helpers := make(Helpers) // Windows always provides a keychain. if isUsable(newWinCredHelper("")) { diff --git a/pkg/keychain/keychain.go b/pkg/keychain/keychain.go index 95a53efd..d734f398 100644 --- a/pkg/keychain/keychain.go +++ b/pkg/keychain/keychain.go @@ -62,9 +62,9 @@ type List struct { // NewList checks availability of every keychains detected on the User Operating System // This will ask the user to unlock keychain(s) to check their usability. // This should only be called once. -func NewList() *List { +func NewList(skipKeychainTest bool) *List { var list = List{locker: &sync.Mutex{}} - list.helpers, list.defaultHelper = listHelpers() + list.helpers, list.defaultHelper = listHelpers(skipKeychainTest) return &list } diff --git a/pkg/keychain/keychain_test.go b/pkg/keychain/keychain_test.go index becbb4d3..62ee659d 100644 --- a/pkg/keychain/keychain_test.go +++ b/pkg/keychain/keychain_test.go @@ -117,7 +117,7 @@ func TestInsertReadRemove(t *testing.T) { func TestIsErrKeychainNoItem(t *testing.T) { r := require.New(t) - helpers := NewList().GetHelpers() + helpers := NewList(false).GetHelpers() for helperName := range helpers { kc, err := NewKeychain(helperName, "bridge-test", helpers, helperName) diff --git a/utils/bridge-rollout/bridge-rollout.go b/utils/bridge-rollout/bridge-rollout.go index deb468c1..4df9d0df 100644 --- a/utils/bridge-rollout/bridge-rollout.go +++ b/utils/bridge-rollout/bridge-rollout.go @@ -61,7 +61,7 @@ func main() { func getRollout(_ *cli.Context) error { return app.WithLocations(func(locations *locations.Locations) error { - return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { + return app.WithKeychainList(async.NoopPanicHandler{}, false, func(keychains *keychain.List) error { return app.WithVault(locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, _, _ bool) error { fmt.Println(vault.GetUpdateRollout()) return nil @@ -72,7 +72,7 @@ func getRollout(_ *cli.Context) error { func setRollout(c *cli.Context) error { return app.WithLocations(func(locations *locations.Locations) error { - return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { + return app.WithKeychainList(async.NoopPanicHandler{}, false, func(keychains *keychain.List) error { return app.WithVault(locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, _, _ bool) error { clamped := max(0.0, min(1.0, c.Float64("value"))) if err := vault.SetUpdateRollout(clamped); err != nil { diff --git a/utils/vault-editor/main.go b/utils/vault-editor/main.go index 20c63ba0..a922e1ea 100644 --- a/utils/vault-editor/main.go +++ b/utils/vault-editor/main.go @@ -51,7 +51,7 @@ func main() { func readAction(c *cli.Context) error { return app.WithLocations(func(locations *locations.Locations) error { - return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { + return app.WithKeychainList(async.NoopPanicHandler{}, false, func(keychains *keychain.List) error { return app.WithVault(locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { if _, err := os.Stdout.Write(vault.ExportJSON()); err != nil { return fmt.Errorf("failed to write vault: %w", err) @@ -65,7 +65,7 @@ func readAction(c *cli.Context) error { func writeAction(c *cli.Context) error { return app.WithLocations(func(locations *locations.Locations) error { - return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { + return app.WithKeychainList(async.NoopPanicHandler{}, false, func(keychains *keychain.List) error { return app.WithVault(locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { b, err := io.ReadAll(os.Stdin) if err != nil {