forked from Silverfish/proton-bridge
feat(GODT-2277): Move Keychain helpers creation in main.
This commit is contained in:
@ -41,6 +41,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/restarter"
|
||||
"github.com/pkg/profile"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -234,56 +235,59 @@ func run(c *cli.Context) error {
|
||||
}
|
||||
|
||||
return withSingleInstance(settings, locations.GetLockFile(), version, func() error {
|
||||
// Unlock the encrypted vault.
|
||||
return WithVault(locations, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
|
||||
if !v.Migrated() {
|
||||
// Migrate old settings into the vault.
|
||||
if err := migrateOldSettings(v); err != nil {
|
||||
logrus.WithError(err).Error("Failed to migrate old settings")
|
||||
}
|
||||
|
||||
// Migrate old accounts into the vault.
|
||||
if err := migrateOldAccounts(locations, v); err != nil {
|
||||
logrus.WithError(err).Error("Failed to migrate old accounts")
|
||||
}
|
||||
|
||||
// The vault has been migrated.
|
||||
if err := v.SetMigrated(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to mark vault as migrated")
|
||||
}
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"lastVersion": v.GetLastVersion().String(),
|
||||
"showAllMail": v.GetShowAllMail(),
|
||||
"updateCh": v.GetUpdateChannel(),
|
||||
"autoUpdate": v.GetAutoUpdate(),
|
||||
"rollout": v.GetUpdateRollout(),
|
||||
"DoH": v.GetProxyAllowed(),
|
||||
}).Info("Vault loaded")
|
||||
|
||||
// Load the cookies from the vault.
|
||||
return withCookieJar(v, func(cookieJar http.CookieJar) error {
|
||||
// Create a new bridge instance.
|
||||
return withBridge(c, exe, locations, version, identifier, crashHandler, reporter, v, cookieJar, func(b *bridge.Bridge, eventCh <-chan events.Event) error {
|
||||
if insecure {
|
||||
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
|
||||
b.PushError(bridge.ErrVaultInsecure)
|
||||
// Look for available keychains
|
||||
return withKeychainList(func(keychains *keychain.List) error {
|
||||
// Unlock the encrypted vault.
|
||||
return WithVault(locations, keychains, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
|
||||
if !v.Migrated() {
|
||||
// Migrate old settings into the vault.
|
||||
if err := migrateOldSettings(v); err != nil {
|
||||
logrus.WithError(err).Error("Failed to migrate old settings")
|
||||
}
|
||||
|
||||
if corrupt {
|
||||
logrus.Warn("The vault is corrupt and has been wiped")
|
||||
b.PushError(bridge.ErrVaultCorrupt)
|
||||
// Migrate old accounts into the vault.
|
||||
if err := migrateOldAccounts(locations, keychains, v); err != nil {
|
||||
logrus.WithError(err).Error("Failed to migrate old accounts")
|
||||
}
|
||||
|
||||
// Remove old updates files
|
||||
b.RemoveOldUpdates()
|
||||
// The vault has been migrated.
|
||||
if err := v.SetMigrated(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to mark vault as migrated")
|
||||
}
|
||||
}
|
||||
|
||||
// Start telemetry heartbeat process
|
||||
b.StartHeartbeat(b)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"lastVersion": v.GetLastVersion().String(),
|
||||
"showAllMail": v.GetShowAllMail(),
|
||||
"updateCh": v.GetUpdateChannel(),
|
||||
"autoUpdate": v.GetAutoUpdate(),
|
||||
"rollout": v.GetUpdateRollout(),
|
||||
"DoH": v.GetProxyAllowed(),
|
||||
}).Info("Vault loaded")
|
||||
|
||||
// Run the frontend.
|
||||
return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID))
|
||||
// Load the cookies from the vault.
|
||||
return withCookieJar(v, func(cookieJar http.CookieJar) error {
|
||||
// Create a new bridge instance.
|
||||
return withBridge(c, exe, locations, version, identifier, crashHandler, reporter, v, cookieJar, keychains, func(b *bridge.Bridge, eventCh <-chan events.Event) error {
|
||||
if insecure {
|
||||
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
|
||||
b.PushError(bridge.ErrVaultInsecure)
|
||||
}
|
||||
|
||||
if corrupt {
|
||||
logrus.Warn("The vault is corrupt and has been wiped")
|
||||
b.PushError(bridge.ErrVaultCorrupt)
|
||||
}
|
||||
|
||||
// Remove old updates files
|
||||
b.RemoveOldUpdates()
|
||||
|
||||
// Start telemetry heartbeat process
|
||||
b.StartHeartbeat(b)
|
||||
|
||||
// Run the frontend.
|
||||
return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -480,6 +484,13 @@ func withCookieJar(vault *vault.Vault, fn func(http.CookieJar) error) error {
|
||||
return fn(persister)
|
||||
}
|
||||
|
||||
// List usable keychains.
|
||||
func withKeychainList(fn func(*keychain.List) error) error {
|
||||
logrus.Debug("Creating keychain list")
|
||||
defer logrus.Debug("Keychain list stop")
|
||||
return fn(keychain.NewList())
|
||||
}
|
||||
|
||||
func setDeviceCookies(jar *cookies.Jar) error {
|
||||
url, err := url.Parse(constants.APIHost)
|
||||
if err != nil {
|
||||
|
||||
@ -37,6 +37,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/versioner"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
@ -55,6 +56,7 @@ func withBridge(
|
||||
reporter *sentry.Reporter,
|
||||
vault *vault.Vault,
|
||||
cookieJar http.CookieJar,
|
||||
keychains *keychain.List,
|
||||
fn func(*bridge.Bridge, <-chan events.Event) error,
|
||||
) error {
|
||||
logrus.Debug("Creating bridge")
|
||||
@ -97,6 +99,7 @@ func withBridge(
|
||||
autostarter,
|
||||
updater,
|
||||
version,
|
||||
keychains,
|
||||
|
||||
// The API stuff.
|
||||
constants.APIHost,
|
||||
|
||||
@ -122,7 +122,7 @@ func migrateOldSettingsWithDir(configDir string, v *vault.Vault) error {
|
||||
return v.SetBridgeTLSCertKey(certPEM, keyPEM)
|
||||
}
|
||||
|
||||
func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error {
|
||||
func migrateOldAccounts(locations *locations.Locations, keychains *keychain.List, v *vault.Vault) error {
|
||||
logrus.Info("Migrating accounts")
|
||||
|
||||
settings, err := locations.ProvideSettingsPath()
|
||||
@ -134,8 +134,7 @@ func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get helper: %w", err)
|
||||
}
|
||||
|
||||
keychain, err := keychain.NewKeychain(helper, "bridge")
|
||||
keychain, err := keychain.NewKeychain(helper, "bridge", keychains.GetHelpers(), keychains.GetDefaultHelper())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create keychain: %w", err)
|
||||
}
|
||||
|
||||
@ -35,7 +35,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
dockerCredentials "github.com/docker/docker-credential-helpers/credentials"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -133,11 +132,9 @@ func TestKeychainMigration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUserMigration(t *testing.T) {
|
||||
keychainHelper := keychain.NewTestHelper()
|
||||
kcl := keychain.NewTestKeychainsList()
|
||||
|
||||
keychain.Helpers["mock"] = func(string) (dockerCredentials.Helper, error) { return keychainHelper, nil }
|
||||
|
||||
kc, err := keychain.NewKeychain("mock", "bridge")
|
||||
kc, err := keychain.NewKeychain("mock", "bridge", kcl.GetHelpers(), kcl.GetDefaultHelper())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, kc.Put("brokenID", "broken"))
|
||||
@ -178,7 +175,7 @@ func TestUserMigration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
require.NoError(t, migrateOldAccounts(locations, v))
|
||||
require.NoError(t, migrateOldAccounts(locations, kcl, v))
|
||||
require.Equal(t, []string{wantCredentials.UserID}, v.GetUserIDs())
|
||||
|
||||
require.NoError(t, v.GetUser(wantCredentials.UserID, func(u *vault.User) {
|
||||
|
||||
@ -29,12 +29,12 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func WithVault(locations *locations.Locations, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
|
||||
func WithVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
|
||||
logrus.Debug("Creating vault")
|
||||
defer logrus.Debug("Vault stopped")
|
||||
|
||||
// Create the encVault.
|
||||
encVault, insecure, corrupt, err := newVault(locations, panicHandler)
|
||||
encVault, insecure, corrupt, err := newVault(locations, keychains, panicHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create vault: %w", err)
|
||||
}
|
||||
@ -49,7 +49,7 @@ func WithVault(locations *locations.Locations, panicHandler async.PanicHandler,
|
||||
return fn(encVault, insecure, corrupt)
|
||||
}
|
||||
|
||||
func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) {
|
||||
func newVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) {
|
||||
vaultDir, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err)
|
||||
@ -62,7 +62,7 @@ func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (
|
||||
insecure bool
|
||||
)
|
||||
|
||||
if key, err := loadVaultKey(vaultDir); err != nil {
|
||||
if key, err := loadVaultKey(vaultDir, keychains); err != nil {
|
||||
logrus.WithError(err).Error("Could not load/create vault key")
|
||||
insecure = true
|
||||
|
||||
@ -85,13 +85,13 @@ func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (
|
||||
return vault, insecure, corrupt, nil
|
||||
}
|
||||
|
||||
func loadVaultKey(vaultDir string) ([]byte, error) {
|
||||
func loadVaultKey(vaultDir string, keychains *keychain.List) ([]byte, error) {
|
||||
helper, err := vault.GetHelper(vaultDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get keychain helper: %w", err)
|
||||
}
|
||||
|
||||
kc, err := keychain.NewKeychain(helper, constants.KeyChainName)
|
||||
kc, err := keychain.NewKeychain(helper, constants.KeyChainName, keychains.GetHelpers(), keychains.GetDefaultHelper())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create keychain: %w", err)
|
||||
}
|
||||
|
||||
@ -45,6 +45,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -82,6 +83,9 @@ type Bridge struct {
|
||||
newVersion *semver.Version
|
||||
newVersionLock safe.RWMutex
|
||||
|
||||
// keychains is the utils that own usable keychains found in the OS.
|
||||
keychains *keychain.List
|
||||
|
||||
// focusService is used to raise the bridge window when needed.
|
||||
focusService *focus.Service
|
||||
|
||||
@ -138,6 +142,7 @@ func New(
|
||||
autostarter Autostarter, // the autostarter to manage autostart settings
|
||||
updater Updater, // the updater to fetch and install updates
|
||||
curVersion *semver.Version, // the current version of the bridge
|
||||
keychains *keychain.List, // usable keychains
|
||||
|
||||
apiURL string, // the URL of the API to use
|
||||
cookieJar http.CookieJar, // the cookie jar to use
|
||||
@ -171,6 +176,7 @@ func New(
|
||||
autostarter,
|
||||
updater,
|
||||
curVersion,
|
||||
keychains,
|
||||
panicHandler,
|
||||
reporter,
|
||||
|
||||
@ -204,6 +210,7 @@ func newBridge(
|
||||
autostarter Autostarter,
|
||||
updater Updater,
|
||||
curVersion *semver.Version,
|
||||
keychains *keychain.List,
|
||||
panicHandler async.PanicHandler,
|
||||
reporter reporter.Reporter,
|
||||
|
||||
@ -256,6 +263,8 @@ func newBridge(
|
||||
newVersion: curVersion,
|
||||
newVersionLock: safe.NewRWMutex(),
|
||||
|
||||
keychains: keychains,
|
||||
|
||||
panicHandler: panicHandler,
|
||||
reporter: reporter,
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
"github.com/ProtonMail/proton-bridge/v3/tests"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
imapid "github.com/emersion/go-imap-id"
|
||||
@ -950,6 +951,7 @@ func withBridgeNoMocks(
|
||||
mocks.Autostarter,
|
||||
mocks.Updater,
|
||||
v2_3_0,
|
||||
keychain.NewTestKeychainsList(),
|
||||
|
||||
// The API stuff.
|
||||
apiURL,
|
||||
|
||||
@ -26,7 +26,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -81,7 +80,7 @@ func (bridge *Bridge) SetLastHeartbeatSent(timestamp time.Time) error {
|
||||
}
|
||||
|
||||
func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) {
|
||||
bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), keychain.DefaultHelper)
|
||||
bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper())
|
||||
|
||||
// Check for heartbeat when triggered.
|
||||
bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) {
|
||||
@ -104,7 +103,7 @@ func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) {
|
||||
if val, err := bridge.GetKeychainApp(); err != nil {
|
||||
bridge.heartbeat.SetKeyChainPref(val)
|
||||
} else {
|
||||
bridge.heartbeat.SetKeyChainPref(keychain.DefaultHelper)
|
||||
bridge.heartbeat.SetKeyChainPref(bridge.keychains.GetDefaultHelper())
|
||||
}
|
||||
bridge.heartbeat.SetPrevVersion(bridge.GetLastVersion().String())
|
||||
|
||||
|
||||
24
internal/bridge/keychain.go
Normal file
24
internal/bridge/keychain.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
import "golang.org/x/exp/maps"
|
||||
|
||||
func (bridge *Bridge) GetHelpersNames() []string {
|
||||
return maps.Keys(bridge.keychains.GetHelpers())
|
||||
}
|
||||
@ -33,10 +33,8 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/service"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/ports"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/runtime/protoimpl"
|
||||
@ -712,7 +710,7 @@ func (s *Service) IsPortFree(_ context.Context, port *wrapperspb.Int32Value) (*w
|
||||
func (s *Service) AvailableKeychains(_ context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
|
||||
s.log.Debug("AvailableKeychains")
|
||||
|
||||
return &AvailableKeychainsResponse{Keychains: maps.Keys(keychain.Helpers)}, nil
|
||||
return &AvailableKeychainsResponse{Keychains: s.bridge.GetHelpersNames()}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
|
||||
Reference in New Issue
Block a user