feat(GODT-2277): Move Keychain helpers creation in main.

This commit is contained in:
Romain Le Jeune
2023-11-08 13:05:57 +00:00
parent 96904b160f
commit e8d9534b9c
17 changed files with 243 additions and 134 deletions

View File

@ -41,6 +41,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/sentry" "github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/useragent"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/ProtonMail/proton-bridge/v3/pkg/restarter" "github.com/ProtonMail/proton-bridge/v3/pkg/restarter"
"github.com/pkg/profile" "github.com/pkg/profile"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -234,56 +235,59 @@ func run(c *cli.Context) error {
} }
return withSingleInstance(settings, locations.GetLockFile(), version, func() error { return withSingleInstance(settings, locations.GetLockFile(), version, func() error {
// Unlock the encrypted vault. // Look for available keychains
return WithVault(locations, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error { return withKeychainList(func(keychains *keychain.List) error {
if !v.Migrated() { // Unlock the encrypted vault.
// Migrate old settings into the vault. return WithVault(locations, keychains, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
if err := migrateOldSettings(v); err != nil { if !v.Migrated() {
logrus.WithError(err).Error("Failed to migrate old settings") // Migrate old settings into the vault.
} if err := migrateOldSettings(v); err != nil {
logrus.WithError(err).Error("Failed to migrate old settings")
// Migrate old accounts into the vault.
if err := migrateOldAccounts(locations, v); err != nil {
logrus.WithError(err).Error("Failed to migrate old accounts")
}
// The vault has been migrated.
if err := v.SetMigrated(); err != nil {
logrus.WithError(err).Error("Failed to mark vault as migrated")
}
}
logrus.WithFields(logrus.Fields{
"lastVersion": v.GetLastVersion().String(),
"showAllMail": v.GetShowAllMail(),
"updateCh": v.GetUpdateChannel(),
"autoUpdate": v.GetAutoUpdate(),
"rollout": v.GetUpdateRollout(),
"DoH": v.GetProxyAllowed(),
}).Info("Vault loaded")
// Load the cookies from the vault.
return withCookieJar(v, func(cookieJar http.CookieJar) error {
// Create a new bridge instance.
return withBridge(c, exe, locations, version, identifier, crashHandler, reporter, v, cookieJar, func(b *bridge.Bridge, eventCh <-chan events.Event) error {
if insecure {
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
b.PushError(bridge.ErrVaultInsecure)
} }
if corrupt { // Migrate old accounts into the vault.
logrus.Warn("The vault is corrupt and has been wiped") if err := migrateOldAccounts(locations, keychains, v); err != nil {
b.PushError(bridge.ErrVaultCorrupt) logrus.WithError(err).Error("Failed to migrate old accounts")
} }
// Remove old updates files // The vault has been migrated.
b.RemoveOldUpdates() if err := v.SetMigrated(); err != nil {
logrus.WithError(err).Error("Failed to mark vault as migrated")
}
}
// Start telemetry heartbeat process logrus.WithFields(logrus.Fields{
b.StartHeartbeat(b) "lastVersion": v.GetLastVersion().String(),
"showAllMail": v.GetShowAllMail(),
"updateCh": v.GetUpdateChannel(),
"autoUpdate": v.GetAutoUpdate(),
"rollout": v.GetUpdateRollout(),
"DoH": v.GetProxyAllowed(),
}).Info("Vault loaded")
// Run the frontend. // Load the cookies from the vault.
return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID)) return withCookieJar(v, func(cookieJar http.CookieJar) error {
// Create a new bridge instance.
return withBridge(c, exe, locations, version, identifier, crashHandler, reporter, v, cookieJar, keychains, func(b *bridge.Bridge, eventCh <-chan events.Event) error {
if insecure {
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
b.PushError(bridge.ErrVaultInsecure)
}
if corrupt {
logrus.Warn("The vault is corrupt and has been wiped")
b.PushError(bridge.ErrVaultCorrupt)
}
// Remove old updates files
b.RemoveOldUpdates()
// Start telemetry heartbeat process
b.StartHeartbeat(b)
// Run the frontend.
return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID))
})
}) })
}) })
}) })
@ -480,6 +484,13 @@ func withCookieJar(vault *vault.Vault, fn func(http.CookieJar) error) error {
return fn(persister) return fn(persister)
} }
// List usable keychains.
func withKeychainList(fn func(*keychain.List) error) error {
logrus.Debug("Creating keychain list")
defer logrus.Debug("Keychain list stop")
return fn(keychain.NewList())
}
func setDeviceCookies(jar *cookies.Jar) error { func setDeviceCookies(jar *cookies.Jar) error {
url, err := url.Parse(constants.APIHost) url, err := url.Parse(constants.APIHost)
if err != nil { if err != nil {

View File

@ -37,6 +37,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/useragent"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/internal/versioner" "github.com/ProtonMail/proton-bridge/v3/internal/versioner"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
@ -55,6 +56,7 @@ func withBridge(
reporter *sentry.Reporter, reporter *sentry.Reporter,
vault *vault.Vault, vault *vault.Vault,
cookieJar http.CookieJar, cookieJar http.CookieJar,
keychains *keychain.List,
fn func(*bridge.Bridge, <-chan events.Event) error, fn func(*bridge.Bridge, <-chan events.Event) error,
) error { ) error {
logrus.Debug("Creating bridge") logrus.Debug("Creating bridge")
@ -97,6 +99,7 @@ func withBridge(
autostarter, autostarter,
updater, updater,
version, version,
keychains,
// The API stuff. // The API stuff.
constants.APIHost, constants.APIHost,

View File

@ -122,7 +122,7 @@ func migrateOldSettingsWithDir(configDir string, v *vault.Vault) error {
return v.SetBridgeTLSCertKey(certPEM, keyPEM) return v.SetBridgeTLSCertKey(certPEM, keyPEM)
} }
func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error { func migrateOldAccounts(locations *locations.Locations, keychains *keychain.List, v *vault.Vault) error {
logrus.Info("Migrating accounts") logrus.Info("Migrating accounts")
settings, err := locations.ProvideSettingsPath() settings, err := locations.ProvideSettingsPath()
@ -134,8 +134,7 @@ func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to get helper: %w", err) return fmt.Errorf("failed to get helper: %w", err)
} }
keychain, err := keychain.NewKeychain(helper, "bridge", keychains.GetHelpers(), keychains.GetDefaultHelper())
keychain, err := keychain.NewKeychain(helper, "bridge")
if err != nil { if err != nil {
return fmt.Errorf("failed to create keychain: %w", err) return fmt.Errorf("failed to create keychain: %w", err)
} }

View File

@ -35,7 +35,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo" "github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
dockerCredentials "github.com/docker/docker-credential-helpers/credentials"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -133,11 +132,9 @@ func TestKeychainMigration(t *testing.T) {
} }
func TestUserMigration(t *testing.T) { func TestUserMigration(t *testing.T) {
keychainHelper := keychain.NewTestHelper() kcl := keychain.NewTestKeychainsList()
keychain.Helpers["mock"] = func(string) (dockerCredentials.Helper, error) { return keychainHelper, nil } kc, err := keychain.NewKeychain("mock", "bridge", kcl.GetHelpers(), kcl.GetDefaultHelper())
kc, err := keychain.NewKeychain("mock", "bridge")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, kc.Put("brokenID", "broken")) require.NoError(t, kc.Put("brokenID", "broken"))
@ -178,7 +175,7 @@ func TestUserMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)
require.NoError(t, migrateOldAccounts(locations, v)) require.NoError(t, migrateOldAccounts(locations, kcl, v))
require.Equal(t, []string{wantCredentials.UserID}, v.GetUserIDs()) require.Equal(t, []string{wantCredentials.UserID}, v.GetUserIDs())
require.NoError(t, v.GetUser(wantCredentials.UserID, func(u *vault.User) { require.NoError(t, v.GetUser(wantCredentials.UserID, func(u *vault.User) {

View File

@ -29,12 +29,12 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func WithVault(locations *locations.Locations, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error { func WithVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
logrus.Debug("Creating vault") logrus.Debug("Creating vault")
defer logrus.Debug("Vault stopped") defer logrus.Debug("Vault stopped")
// Create the encVault. // Create the encVault.
encVault, insecure, corrupt, err := newVault(locations, panicHandler) encVault, insecure, corrupt, err := newVault(locations, keychains, panicHandler)
if err != nil { if err != nil {
return fmt.Errorf("could not create vault: %w", err) return fmt.Errorf("could not create vault: %w", err)
} }
@ -49,7 +49,7 @@ func WithVault(locations *locations.Locations, panicHandler async.PanicHandler,
return fn(encVault, insecure, corrupt) return fn(encVault, insecure, corrupt)
} }
func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) { func newVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) {
vaultDir, err := locations.ProvideSettingsPath() vaultDir, err := locations.ProvideSettingsPath()
if err != nil { if err != nil {
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err) return nil, false, false, fmt.Errorf("could not get vault dir: %w", err)
@ -62,7 +62,7 @@ func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (
insecure bool insecure bool
) )
if key, err := loadVaultKey(vaultDir); err != nil { if key, err := loadVaultKey(vaultDir, keychains); err != nil {
logrus.WithError(err).Error("Could not load/create vault key") logrus.WithError(err).Error("Could not load/create vault key")
insecure = true insecure = true
@ -85,13 +85,13 @@ func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (
return vault, insecure, corrupt, nil return vault, insecure, corrupt, nil
} }
func loadVaultKey(vaultDir string) ([]byte, error) { func loadVaultKey(vaultDir string, keychains *keychain.List) ([]byte, error) {
helper, err := vault.GetHelper(vaultDir) helper, err := vault.GetHelper(vaultDir)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get keychain helper: %w", err) return nil, fmt.Errorf("could not get keychain helper: %w", err)
} }
kc, err := keychain.NewKeychain(helper, constants.KeyChainName) kc, err := keychain.NewKeychain(helper, constants.KeyChainName, keychains.GetHelpers(), keychains.GetDefaultHelper())
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create keychain: %w", err) return nil, fmt.Errorf("could not create keychain: %w", err)
} }

View File

@ -45,6 +45,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
"github.com/ProtonMail/proton-bridge/v3/internal/user" "github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -82,6 +83,9 @@ type Bridge struct {
newVersion *semver.Version newVersion *semver.Version
newVersionLock safe.RWMutex newVersionLock safe.RWMutex
// keychains is the utils that own usable keychains found in the OS.
keychains *keychain.List
// focusService is used to raise the bridge window when needed. // focusService is used to raise the bridge window when needed.
focusService *focus.Service focusService *focus.Service
@ -138,6 +142,7 @@ func New(
autostarter Autostarter, // the autostarter to manage autostart settings autostarter Autostarter, // the autostarter to manage autostart settings
updater Updater, // the updater to fetch and install updates updater Updater, // the updater to fetch and install updates
curVersion *semver.Version, // the current version of the bridge curVersion *semver.Version, // the current version of the bridge
keychains *keychain.List, // usable keychains
apiURL string, // the URL of the API to use apiURL string, // the URL of the API to use
cookieJar http.CookieJar, // the cookie jar to use cookieJar http.CookieJar, // the cookie jar to use
@ -171,6 +176,7 @@ func New(
autostarter, autostarter,
updater, updater,
curVersion, curVersion,
keychains,
panicHandler, panicHandler,
reporter, reporter,
@ -204,6 +210,7 @@ func newBridge(
autostarter Autostarter, autostarter Autostarter,
updater Updater, updater Updater,
curVersion *semver.Version, curVersion *semver.Version,
keychains *keychain.List,
panicHandler async.PanicHandler, panicHandler async.PanicHandler,
reporter reporter.Reporter, reporter reporter.Reporter,
@ -256,6 +263,8 @@ func newBridge(
newVersion: curVersion, newVersion: curVersion,
newVersionLock: safe.NewRWMutex(), newVersionLock: safe.NewRWMutex(),
keychains: keychains,
panicHandler: panicHandler, panicHandler: panicHandler,
reporter: reporter, reporter: reporter,

View File

@ -49,6 +49,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/user" "github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/useragent"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/ProtonMail/proton-bridge/v3/tests" "github.com/ProtonMail/proton-bridge/v3/tests"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
imapid "github.com/emersion/go-imap-id" imapid "github.com/emersion/go-imap-id"
@ -950,6 +951,7 @@ func withBridgeNoMocks(
mocks.Autostarter, mocks.Autostarter,
mocks.Updater, mocks.Updater,
v2_3_0, v2_3_0,
keychain.NewTestKeychainsList(),
// The API stuff. // The API stuff.
apiURL, apiURL,

View File

@ -26,7 +26,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -81,7 +80,7 @@ func (bridge *Bridge) SetLastHeartbeatSent(timestamp time.Time) error {
} }
func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) { func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) {
bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), keychain.DefaultHelper) bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper())
// Check for heartbeat when triggered. // Check for heartbeat when triggered.
bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) { bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) {
@ -104,7 +103,7 @@ func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) {
if val, err := bridge.GetKeychainApp(); err != nil { if val, err := bridge.GetKeychainApp(); err != nil {
bridge.heartbeat.SetKeyChainPref(val) bridge.heartbeat.SetKeyChainPref(val)
} else { } else {
bridge.heartbeat.SetKeyChainPref(keychain.DefaultHelper) bridge.heartbeat.SetKeyChainPref(bridge.keychains.GetDefaultHelper())
} }
bridge.heartbeat.SetPrevVersion(bridge.GetLastVersion().String()) bridge.heartbeat.SetPrevVersion(bridge.GetLastVersion().String())

View File

@ -0,0 +1,24 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import "golang.org/x/exp/maps"
func (bridge *Bridge) GetHelpersNames() []string {
return maps.Keys(bridge.keychains.GetHelpers())
}

View File

@ -33,10 +33,8 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/ProtonMail/proton-bridge/v3/internal/service"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/ProtonMail/proton-bridge/v3/pkg/ports" "github.com/ProtonMail/proton-bridge/v3/pkg/ports"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/runtime/protoimpl" "google.golang.org/protobuf/runtime/protoimpl"
@ -712,7 +710,7 @@ func (s *Service) IsPortFree(_ context.Context, port *wrapperspb.Int32Value) (*w
func (s *Service) AvailableKeychains(_ context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) { func (s *Service) AvailableKeychains(_ context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
s.log.Debug("AvailableKeychains") s.log.Debug("AvailableKeychains")
return &AvailableKeychainsResponse{Keychains: maps.Keys(keychain.Helpers)}, nil return &AvailableKeychainsResponse{Keychains: s.bridge.GetHelpersNames()}, nil
} }
func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) { func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) {

View File

@ -31,14 +31,18 @@ const (
MacOSKeychain = "macos-keychain" MacOSKeychain = "macos-keychain"
) )
func init() { //nolint:gochecknoinits func listHelpers() (Helpers, string) {
Helpers = make(map[string]helperConstructor) helpers := make(Helpers)
// MacOS always provides a keychain. // MacOS always provides a keychain.
Helpers[MacOSKeychain] = newMacOSHelper if isUsable(newMacOSHelper("")) {
helpers[MacOSKeychain] = newMacOSHelper
} else {
logrus.WithField("keychain", "MacOSKeychain").Warn("Keychain is not available.")
}
// Use MacOSKeychain by default. // Use MacOSKeychain by default.
DefaultHelper = MacOSKeychain return helpers, MacOSKeychain
} }
func parseError(original error) error { func parseError(original error) error {

View File

@ -18,8 +18,6 @@
package keychain package keychain
import ( import (
"reflect"
"github.com/docker/docker-credential-helpers/credentials" "github.com/docker/docker-credential-helpers/credentials"
"github.com/docker/docker-credential-helpers/pass" "github.com/docker/docker-credential-helpers/pass"
"github.com/docker/docker-credential-helpers/secretservice" "github.com/docker/docker-credential-helpers/secretservice"
@ -33,30 +31,37 @@ const (
SecretServiceDBus = "secret-service-dbus" SecretServiceDBus = "secret-service-dbus"
) )
func init() { //nolint:gochecknoinits func listHelpers() (Helpers, string) {
Helpers = make(map[string]helperConstructor) helpers := make(Helpers)
if isUsable(newDBusHelper("")) { if isUsable(newDBusHelper("")) {
Helpers[SecretServiceDBus] = newDBusHelper helpers[SecretServiceDBus] = newDBusHelper
} else {
logrus.WithField("keychain", "SecretServiceDBus").Warn("Keychain is not available.")
} }
if _, err := execabs.LookPath("gnome-keyring"); err == nil && isUsable(newSecretServiceHelper("")) { if _, err := execabs.LookPath("gnome-keyring"); err == nil && isUsable(newSecretServiceHelper("")) {
Helpers[SecretService] = newSecretServiceHelper helpers[SecretService] = newSecretServiceHelper
} else {
logrus.WithField("keychain", "SecretService").Warn("Keychain is not available.")
} }
if _, err := execabs.LookPath("pass"); err == nil && isUsable(newPassHelper("")) { if _, err := execabs.LookPath("pass"); err == nil && isUsable(newPassHelper("")) {
Helpers[Pass] = newPassHelper helpers[Pass] = newPassHelper
} else {
logrus.WithField("keychain", "Pass").Warn("Keychain is not available.")
} }
DefaultHelper = SecretServiceDBus defaultHelper := SecretServiceDBus
// If Pass is available, use it by default. // If Pass is available, use it by default.
// Otherwise, if SecretService is available, use it by default. // Otherwise, if SecretService is available, use it by default.
if _, ok := Helpers[Pass]; ok { if _, ok := helpers[Pass]; ok {
DefaultHelper = Pass defaultHelper = Pass
} else if _, ok := Helpers[SecretService]; ok { } else if _, ok := helpers[SecretService]; ok {
DefaultHelper = SecretService defaultHelper = SecretService
} }
return helpers, defaultHelper
} }
func newDBusHelper(string) (credentials.Helper, error) { func newDBusHelper(string) (credentials.Helper, error) {
@ -70,36 +75,3 @@ func newPassHelper(string) (credentials.Helper, error) {
func newSecretServiceHelper(string) (credentials.Helper, error) { func newSecretServiceHelper(string) (credentials.Helper, error) {
return &secretservice.Secretservice{}, nil return &secretservice.Secretservice{}, nil
} }
// isUsable returns whether the credentials helper is usable.
func isUsable(helper credentials.Helper, err error) bool {
l := logrus.WithField("helper", reflect.TypeOf(helper))
if err != nil {
l.WithError(err).Warn("Keychain helper couldn't be created")
return false
}
creds := &credentials.Credentials{
ServerURL: "bridge/check",
Username: "check",
Secret: "check",
}
if err := helper.Add(creds); err != nil {
l.WithError(err).Warn("Failed to add test credentials to keychain")
return false
}
if _, _, err := helper.Get(creds.ServerURL); err != nil {
l.WithError(err).Warn("Failed to get test credentials from keychain")
return false
}
if err := helper.Delete(creds.ServerURL); err != nil {
l.WithError(err).Warn("Failed to delete test credentials from keychain")
return false
}
return true
}

View File

@ -20,18 +20,21 @@ package keychain
import ( import (
"github.com/docker/docker-credential-helpers/credentials" "github.com/docker/docker-credential-helpers/credentials"
"github.com/docker/docker-credential-helpers/wincred" "github.com/docker/docker-credential-helpers/wincred"
"github.com/sirupsen/logrus"
) )
const WindowsCredentials = "windows-credentials" const WindowsCredentials = "windows-credentials"
func init() { //nolint:gochecknoinits func listHelpers() (Helpers, string) {
Helpers = make(map[string]helperConstructor) helpers := make(Helpers)
// Windows always provides a keychain. // Windows always provides a keychain.
Helpers[WindowsCredentials] = newWinCredHelper if isUsable(newWinCredHelper("")) {
helpers[WindowsCredentials] = newWinCredHelper
} else {
logrus.WithField("keychain", "WindowsCredentials").Warn("Keychain is not available.")
}
// Use WindowsCredentials by default. // Use WindowsCredentials by default.
DefaultHelper = WindowsCredentials return helpers, WindowsCredentials
} }
func newWinCredHelper(string) (credentials.Helper, error) { func newWinCredHelper(string) (credentials.Helper, error) {

View File

@ -21,9 +21,12 @@ package keychain
import ( import (
"errors" "errors"
"fmt" "fmt"
"reflect"
"sync" "sync"
"time"
"github.com/docker/docker-credential-helpers/credentials" "github.com/docker/docker-credential-helpers/credentials"
"github.com/sirupsen/logrus"
) )
// helperConstructor constructs a keychain helperConstructor. // helperConstructor constructs a keychain helperConstructor.
@ -38,28 +41,53 @@ var (
// ErrMacKeychainRebuild is returned on macOS with blocked or corrupted keychain. // ErrMacKeychainRebuild is returned on macOS with blocked or corrupted keychain.
ErrMacKeychainRebuild = errors.New("keychain error -25293") ErrMacKeychainRebuild = errors.New("keychain error -25293")
// Helpers holds all discovered keychain helpers. It is populated in init().
Helpers map[string]helperConstructor //nolint:gochecknoglobals
// DefaultHelper is the default helper to use if the user hasn't yet set a preference.
DefaultHelper string //nolint:gochecknoglobals
) )
type Helpers map[string]helperConstructor
type List struct {
helpers Helpers
defaultHelper string
locker sync.Locker
}
// NewList checks availability of every keychains detected on the User Operating System
// This will ask the user to unlock keychain(s) to check their usability.
// This should only be called once.
func NewList() *List {
var list = List{locker: &sync.Mutex{}}
list.helpers, list.defaultHelper = listHelpers()
return &list
}
func (kcl *List) GetHelpers() Helpers {
kcl.locker.Lock()
defer kcl.locker.Unlock()
return kcl.helpers
}
func (kcl *List) GetDefaultHelper() string {
kcl.locker.Lock()
defer kcl.locker.Unlock()
return kcl.defaultHelper
}
// NewKeychain creates a new native keychain. // NewKeychain creates a new native keychain.
func NewKeychain(preferred, keychainName string) (*Keychain, error) { func NewKeychain(preferred, keychainName string, helpers Helpers, defaultHelper string) (*Keychain, error) {
// There must be at least one keychain helper available. // There must be at least one keychain helper available.
if len(Helpers) < 1 { if len(helpers) < 1 {
return nil, ErrNoKeychain return nil, ErrNoKeychain
} }
// If the preferred keychain is unsupported, fallback to the default one. // If the preferred keychain is unsupported, fallback to the default one.
if _, ok := Helpers[preferred]; !ok { if _, ok := helpers[preferred]; !ok {
preferred = DefaultHelper preferred = defaultHelper
} }
// Load the user's preferred keychain helper. // Load the user's preferred keychain helper.
helperConstructor, ok := Helpers[preferred] helperConstructor, ok := helpers[preferred]
if !ok { if !ok {
return nil, ErrNoKeychain return nil, ErrNoKeychain
} }
@ -163,3 +191,49 @@ func (kc *Keychain) Put(userID, secret string) error {
func (kc *Keychain) secretURL(userID string) string { func (kc *Keychain) secretURL(userID string) string {
return fmt.Sprintf("%v/%v", kc.url, userID) return fmt.Sprintf("%v/%v", kc.url, userID)
} }
// isUsable returns whether the credentials helper is usable.
func isUsable(helper credentials.Helper, err error) bool {
l := logrus.WithField("helper", reflect.TypeOf(helper))
if err != nil {
l.WithError(err).Warn("Keychain helper couldn't be created")
return false
}
creds := &credentials.Credentials{
ServerURL: "bridge/check",
Username: "check",
Secret: "check",
}
if err := retry(func() error {
return helper.Add(creds)
}); err != nil {
l.WithError(err).Warn("Failed to add test credentials to keychain")
return false
}
if _, _, err := helper.Get(creds.ServerURL); err != nil {
l.WithError(err).Warn("Failed to get test credentials from keychain")
return false
}
if err := helper.Delete(creds.ServerURL); err != nil {
l.WithError(err).Warn("Failed to delete test credentials from keychain")
return false
}
return true
}
func retry(condition func() error) error {
var maxRetry = 5
for r := 0; ; r++ {
err := condition()
if err == nil || r >= maxRetry {
return err
}
time.Sleep(200 * time.Millisecond)
}
}

View File

@ -17,10 +17,22 @@
package keychain package keychain
import "github.com/docker/docker-credential-helpers/credentials" import (
"sync"
"github.com/docker/docker-credential-helpers/credentials"
)
type TestHelper map[string]*credentials.Credentials type TestHelper map[string]*credentials.Credentials
func NewTestKeychainsList() *List {
keychainHelper := NewTestHelper()
helpers := make(Helpers)
helpers["mock"] = func(string) (credentials.Helper, error) { return keychainHelper, nil }
var list = List{helpers: helpers, defaultHelper: "mock", locker: &sync.Mutex{}}
return &list
}
func NewTestHelper() TestHelper { func NewTestHelper() TestHelper {
return make(TestHelper) return make(TestHelper)
} }

View File

@ -39,6 +39,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/ProtonMail/proton-bridge/v3/internal/service"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/useragent"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
@ -153,6 +154,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
t.mocks.Autostarter, t.mocks.Autostarter,
t.mocks.Updater, t.mocks.Updater,
t.version, t.version,
keychain.NewTestKeychainsList(),
// API stuff // API stuff
t.api.GetHostURL(), t.api.GetHostURL(),

View File

@ -50,7 +50,7 @@ func main() {
func readAction(c *cli.Context) error { func readAction(c *cli.Context) error {
return app.WithLocations(func(locations *locations.Locations) error { return app.WithLocations(func(locations *locations.Locations) error {
return app.WithVault(locations, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { return app.WithVault(locations, nil, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error {
if _, err := os.Stdout.Write(vault.ExportJSON()); err != nil { if _, err := os.Stdout.Write(vault.ExportJSON()); err != nil {
return fmt.Errorf("failed to write vault: %w", err) return fmt.Errorf("failed to write vault: %w", err)
} }
@ -62,7 +62,7 @@ func readAction(c *cli.Context) error {
func writeAction(c *cli.Context) error { func writeAction(c *cli.Context) error {
return app.WithLocations(func(locations *locations.Locations) error { return app.WithLocations(func(locations *locations.Locations) error {
return app.WithVault(locations, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { return app.WithVault(locations, nil, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error {
b, err := io.ReadAll(os.Stdin) b, err := io.ReadAll(os.Stdin)
if err != nil { if err != nil {
return fmt.Errorf("failed to read vault: %w", err) return fmt.Errorf("failed to read vault: %w", err)