feat(BRIDGE-356): Added retry logic for unavailable preferred keychain on Linux; Feature flag support before bridge initialization; Refactored some bits of the code;

This commit is contained in:
Atanas Janeshliev
2025-07-02 16:34:32 +02:00
parent 20183bf984
commit de3fd34998
33 changed files with 716 additions and 87 deletions

View File

@ -31,6 +31,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/crash" "github.com/ProtonMail/proton-bridge/v3/internal/crash"
"github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry" "github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/useragent"
@ -164,7 +165,7 @@ func main() { //nolint:funlen
// On windows, if you use Run(), a terminal stays open; we don't want that. // On windows, if you use Run(), a terminal stays open; we don't want that.
if //goland:noinspection GoBoolExpressions if //goland:noinspection GoBoolExpressions
runtime.GOOS == "windows" { runtime.GOOS == platform.WINDOWS {
err = cmd.Start() err = cmd.Start()
} else { } else {
err = cmd.Run() err = cmd.Run()

View File

@ -39,7 +39,9 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/frontend/theme" "github.com/ProtonMail/proton-bridge/v3/internal/frontend/theme"
"github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry" "github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/useragent"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
@ -285,11 +287,13 @@ func run(c *cli.Context) error {
logrus.WithError(err).Error("Failed to get settings path") logrus.WithError(err).Error("Failed to get settings path")
} }
featureFlags := unleash.GetStartupFeatureFlagsAndStore(constants.APIHost, version, locations.ProvideUnleashStartupCachePath)
return withSingleInstance(settings, locations.GetLockFile(), version, func() error { return withSingleInstance(settings, locations.GetLockFile(), version, func() error {
// Look for available keychains // Look for available keychains
return WithKeychainList(crashHandler, func(keychains *keychain.List) error { return WithKeychainList(crashHandler, func(keychains *keychain.List) error {
// Unlock the encrypted vault. // Unlock the encrypted vault.
return WithVault(reporter, locations, keychains, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error { return WithVault(reporter, locations, keychains, featureFlags, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
if !v.Migrated() { if !v.Migrated() {
// Migrate old settings into the vault. // Migrate old settings into the vault.
if err := migrateOldSettings(v); err != nil { if err := migrateOldSettings(v); err != nil {
@ -577,5 +581,5 @@ func setDeviceCookies(jar *cookies.Jar) error {
} }
func onMacOS() bool { func onMacOS() bool {
return runtime.GOOS == "darwin" return runtime.GOOS == platform.MACOS
} }

View File

@ -138,7 +138,14 @@ func migrateOldAccounts(locations *locations.Locations, keychains *keychain.List
if err != nil { if err != nil {
return fmt.Errorf("failed to get helper: %w", err) return fmt.Errorf("failed to get helper: %w", err)
} }
keychain, _, err := keychain.NewKeychain(helper, "bridge", keychains.GetHelpers(), keychains.GetDefaultHelper())
keychain, _, err := keychain.NewKeychain(
helper, "bridge",
keychains.GetHelpers(),
keychains.GetDefaultHelper(),
0,
make(map[string]bool),
)
if err != nil { if err != nil {
return fmt.Errorf("failed to create keychain: %w", err) return fmt.Errorf("failed to create keychain: %w", err)
} }

View File

@ -31,6 +31,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/cookies" "github.com/ProtonMail/proton-bridge/v3/internal/cookies"
"github.com/ProtonMail/proton-bridge/v3/internal/legacy/credentials" "github.com/ProtonMail/proton-bridge/v3/internal/legacy/credentials"
"github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo" "github.com/ProtonMail/proton-bridge/v3/pkg/algo"
@ -85,7 +86,7 @@ func TestMigratePrefsToVaultWithoutKeys(t *testing.T) {
func TestKeychainMigration(t *testing.T) { func TestKeychainMigration(t *testing.T) {
// Migration tested only for linux. // Migration tested only for linux.
if runtime.GOOS != "linux" { if runtime.GOOS != platform.LINUX {
return return
} }
@ -134,7 +135,13 @@ func TestKeychainMigration(t *testing.T) {
func TestUserMigration(t *testing.T) { func TestUserMigration(t *testing.T) {
kcl := keychain.NewTestKeychainsList() kcl := keychain.NewTestKeychainsList()
kc, _, err := keychain.NewKeychain("mock", "bridge", kcl.GetHelpers(), kcl.GetDefaultHelper()) kc, _, err := keychain.NewKeychain(
"mock", "bridge",
kcl.GetHelpers(),
kcl.GetDefaultHelper(),
0,
make(map[string]bool),
)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, kc.Put("brokenID", "broken")) require.NoError(t, kc.Put("brokenID", "broken"))

View File

@ -20,6 +20,7 @@ package app
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"path" "path"
"runtime" "runtime"
@ -28,18 +29,20 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/certs" "github.com/ProtonMail/proton-bridge/v3/internal/certs"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry" "github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func WithVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error { func WithVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, featureFlags unleash.FeatureFlagStartupStore, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
logrus.Debug("Creating vault") logrus.Debug("Creating vault")
defer logrus.Debug("Vault stopped") defer logrus.Debug("Vault stopped")
// Create the encVault. // Create the encVault.
encVault, insecure, corrupt, err := newVault(reporter, locations, keychains, panicHandler) encVault, insecure, corrupt, err := newVault(reporter, locations, keychains, featureFlags, panicHandler)
if err != nil { if err != nil {
return fmt.Errorf("could not create vault: %w", err) return fmt.Errorf("could not create vault: %w", err)
} }
@ -61,7 +64,7 @@ func WithVault(reporter *sentry.Reporter, locations *locations.Locations, keycha
return fn(encVault, insecure, corrupt != nil) return fn(encVault, insecure, corrupt != nil)
} }
func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, error, error) { func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychains *keychain.List, featureFlags unleash.FeatureFlagStartupStore, panicHandler async.PanicHandler) (*vault.Vault, bool, error, error) {
vaultDir, err := locations.ProvideSettingsPath() vaultDir, err := locations.ProvideSettingsPath()
if err != nil { if err != nil {
return nil, false, nil, fmt.Errorf("could not get vault dir: %w", err) return nil, false, nil, fmt.Errorf("could not get vault dir: %w", err)
@ -75,7 +78,14 @@ func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychai
lastUsedHelper string lastUsedHelper string
) )
if key, helper, err := loadVaultKey(vaultDir, keychains); err != nil { if key, helper, err := loadVaultKey(vaultDir, keychains, featureFlags); err != nil {
if errors.Is(err, keychain.ErrPreferredKeychainNotAvailable) {
if err := vault.IncrementKeychainFailedAttemptCount(vaultDir); err != nil {
logrus.WithError(err).Error("Failed to increment failed keychain attempt count")
}
return &vault.Vault{}, false, nil, err
}
if reporter != nil { if reporter != nil {
if rerr := reporter.ReportMessageWithContext("Could not load/create vault key", map[string]any{ if rerr := reporter.ReportMessageWithContext("Could not load/create vault key", map[string]any{
"keychainDefaultHelper": keychains.GetDefaultHelper(), "keychainDefaultHelper": keychains.GetDefaultHelper(),
@ -108,23 +118,38 @@ func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychai
} }
// Remember the last successfully used keychain on Linux and store that as the user preference. // Remember the last successfully used keychain on Linux and store that as the user preference.
if runtime.GOOS == "linux" { if runtime.GOOS == platform.LINUX {
if err := vault.SetHelper(vaultDir, lastUsedHelper); err != nil { if err := vault.SetHelper(vaultDir, lastUsedHelper); err != nil {
logrus.WithError(err).Error("Could not store last used keychain helper") logrus.WithError(err).Error("Could not store last used keychain helper")
} }
if err := vault.ResetFailedKeychainAttemptCount(vaultDir); err != nil {
logrus.WithError(err).Error("Could not reset and save failed keychain attempt count")
}
} }
return userVault, insecure, corrupt, nil return userVault, insecure, corrupt, nil
} }
// loadVaultKey - loads the key used to encrypt the vault alongside the keychain helper used to access it. // loadVaultKey - loads the key used to encrypt the vault alongside the keychain helper used to access it.
func loadVaultKey(vaultDir string, keychains *keychain.List) (key []byte, keychainHelper string, err error) { func loadVaultKey(vaultDir string, keychains *keychain.List, featureFlags unleash.FeatureFlagStartupStore) (key []byte, keychainHelper string, err error) {
keychainHelper, err = vault.GetHelper(vaultDir) keychainHelper, err = vault.GetHelper(vaultDir)
if err != nil { if err != nil {
return nil, keychainHelper, fmt.Errorf("could not get keychain helper: %w", err) return nil, keychainHelper, fmt.Errorf("could not get keychain helper: %w", err)
} }
kc, keychainHelper, err := keychain.NewKeychain(keychainHelper, constants.KeyChainName, keychains.GetHelpers(), keychains.GetDefaultHelper()) keychainFailedAttemptCount, err := vault.GetKeychainFailedAttemptCount(vaultDir)
if err != nil {
return nil, keychainHelper, fmt.Errorf("could not get keychain failed attempt count: %w", err)
}
kc, keychainHelper, err := keychain.NewKeychain(
keychainHelper, constants.KeyChainName,
keychains.GetHelpers(),
keychains.GetDefaultHelper(),
keychainFailedAttemptCount,
featureFlags,
)
if err != nil { if err != nil {
return nil, keychainHelper, fmt.Errorf("could not create keychain: %w", err) return nil, keychainHelper, fmt.Errorf("could not create keychain: %w", err)
} }
@ -139,6 +164,12 @@ func loadVaultKey(vaultDir string, keychains *keychain.List) (key []byte, keycha
return key, keychainHelper, err return key, keychainHelper, err
} }
if keychain.ShouldRetryPreferredKeychain(featureFlags, keychainHelper) {
if keychainFailedAttemptCount < keychain.MaxFailedKeychainAttemptsLinux {
return nil, keychainHelper, keychain.PreferredKeychainRetryError(keychainFailedAttemptCount)
}
}
return nil, keychainHelper, fmt.Errorf("could not check for vault key: %w", err) return nil, keychainHelper, fmt.Errorf("could not check for vault key: %w", err)
} }

View File

@ -42,6 +42,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/focus" "github.com/ProtonMail/proton-bridge/v3/internal/focus"
"github.com/ProtonMail/proton-bridge/v3/internal/identifier" "github.com/ProtonMail/proton-bridge/v3/internal/identifier"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry" "github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapsmtpserver" "github.com/ProtonMail/proton-bridge/v3/internal/services/imapsmtpserver"
@ -687,7 +688,7 @@ func (bridge *Bridge) HasAPIConnection() bool {
// then we verify whether the gluon cache exists using the "new" username (provided by the DB path in this case) // then we verify whether the gluon cache exists using the "new" username (provided by the DB path in this case)
// if so we modify the cache directory in the user vault. // if so we modify the cache directory in the user vault.
func (bridge *Bridge) verifyUsernameChange() { func (bridge *Bridge) verifyUsernameChange() {
if runtime.GOOS != "darwin" { if runtime.GOOS != platform.MACOS {
return return
} }

View File

@ -28,6 +28,7 @@ import (
"github.com/ProtonMail/go-proton-api/server" "github.com/ProtonMail/go-proton-api/server"
bridgePkg "github.com/ProtonMail/proton-bridge/v3/internal/bridge" bridgePkg "github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/internal/updater/versioncompare" "github.com/ProtonMail/proton-bridge/v3/internal/updater/versioncompare"
"github.com/elastic/go-sysinfo/types" "github.com/elastic/go-sysinfo/types"
@ -331,7 +332,7 @@ func Test_Update_CheckOSVersion_NoUpdate(t *testing.T) {
bridge.CheckForUpdates() bridge.CheckForUpdates()
if runtime.GOOS == "darwin" { if runtime.GOOS == platform.MACOS {
require.Equal(t, events.UpdateNotAvailable{}, <-updateNotAvailableCh) require.Equal(t, events.UpdateNotAvailable{}, <-updateNotAvailableCh)
} else { } else {
require.Equal(t, events.UpdateInstalled{ require.Equal(t, events.UpdateInstalled{
@ -442,7 +443,7 @@ func Test_Update_CheckOSVersion_HasUpdate(t *testing.T) {
bridge.CheckForUpdates() bridge.CheckForUpdates()
if runtime.GOOS == "darwin" { if runtime.GOOS == platform.MACOS {
require.Equal(t, events.UpdateInstalled{ require.Equal(t, events.UpdateInstalled{
Release: expectedUpdateRelease, Release: expectedUpdateRelease,
Silent: true, Silent: true,

View File

@ -37,6 +37,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/user" "github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
@ -77,7 +78,7 @@ func TestBridge_User_RefreshEvent(t *testing.T) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
syncCh, closeCh := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) syncCh, closeCh := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
if runtime.GOOS != "windows" { if runtime.GOOS != platform.WINDOWS {
require.Equal(t, userID, (<-syncCh).UserID) require.Equal(t, userID, (<-syncCh).UserID)
} }
require.Equal(t, userID, (<-syncCh).UserID) require.Equal(t, userID, (<-syncCh).UserID)

View File

@ -21,6 +21,8 @@ package constants
import ( import (
"fmt" "fmt"
"runtime" "runtime"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
) )
const VendorName = "protonmail" const VendorName = "protonmail"
@ -72,13 +74,13 @@ const (
// nolint:goconst // nolint:goconst
func getAPIOS() string { func getAPIOS() string {
switch runtime.GOOS { switch runtime.GOOS {
case "darwin": case platform.MACOS:
return "macos" return "macos"
case "linux": case platform.LINUX:
return "linux" return "linux"
case "windows": case platform.WINDOWS:
return "windows" return "windows"
default: default:

View File

@ -28,6 +28,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/pkg/restarter" "github.com/ProtonMail/proton-bridge/v3/pkg/restarter"
"github.com/abiosoft/ishell" "github.com/abiosoft/ishell"
@ -148,7 +149,7 @@ func New(
fe.AddCmd(dohCmd) fe.AddCmd(dohCmd)
//goland:noinspection GoBoolExpressions //goland:noinspection GoBoolExpressions
if runtime.GOOS == "darwin" { if runtime.GOOS == platform.MACOS {
// Apple Mail commands. // Apple Mail commands.
configureCmd := &ishell.Cmd{ configureCmd := &ishell.Cmd{
Name: "configure-apple-mail", Name: "configure-apple-mail",
@ -165,7 +166,7 @@ func New(
} }
//goland:noinspection GoBoolExpressions //goland:noinspection GoBoolExpressions
if runtime.GOOS == "darwin" { if runtime.GOOS == platform.MACOS {
certCmd.AddCmd(&ishell.Cmd{ certCmd.AddCmd(&ishell.Cmd{
Name: "status", Name: "status",
Help: "Check if the TLS certificate used by Bridge is installed in the OS keychain", Help: "Check if the TLS certificate used by Bridge is installed in the OS keychain",

View File

@ -39,6 +39,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/certs" "github.com/ProtonMail/proton-bridge/v3/internal/certs"
"github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/hv" "github.com/ProtonMail/proton-bridge/v3/internal/hv"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/ProtonMail/proton-bridge/v3/internal/service"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
@ -685,5 +686,5 @@ func computeFileSocketPath() (string, error) {
// useFileSocket return true iff file socket should be used for the gRPC service. // useFileSocket return true iff file socket should be used for the gRPC service.
func useFileSocket() bool { func useFileSocket() bool {
//goland:noinspection GoBoolExpressions //goland:noinspection GoBoolExpressions
return runtime.GOOS != "windows" return runtime.GOOS != platform.WINDOWS
} }

View File

@ -32,6 +32,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/frontend/theme" "github.com/ProtonMail/proton-bridge/v3/internal/frontend/theme"
"github.com/ProtonMail/proton-bridge/v3/internal/hv" "github.com/ProtonMail/proton-bridge/v3/internal/hv"
"github.com/ProtonMail/proton-bridge/v3/internal/kb" "github.com/ProtonMail/proton-bridge/v3/internal/kb"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/ProtonMail/proton-bridge/v3/internal/service"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
@ -688,7 +689,7 @@ func (s *Service) SetDiskCachePath(_ context.Context, newPath *wrapperspb.String
path := newPath.Value path := newPath.Value
//goland:noinspection GoBoolExpressions //goland:noinspection GoBoolExpressions
if (runtime.GOOS == "windows") && (path[0] == '/') { if (runtime.GOOS == platform.WINDOWS) && (path[0] == '/') {
path = path[1:] path = path[1:]
} }

View File

@ -19,6 +19,8 @@ package theme
import ( import (
"runtime" "runtime"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
) )
type Theme string type Theme string
@ -34,7 +36,7 @@ func IsAvailable(have Theme) bool {
func DefaultTheme() Theme { func DefaultTheme() Theme {
switch runtime.GOOS { switch runtime.GOOS {
case "darwin", "windows": case platform.MACOS, platform.WINDOWS:
return detectSystemTheme() return detectSystemTheme()
default: default:
return Light return Light

View File

@ -24,6 +24,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/pkg/files" "github.com/ProtonMail/proton-bridge/v3/pkg/files"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -90,7 +91,7 @@ func (l *Locations) getLicenseFilePath() string {
} }
switch runtime.GOOS { switch runtime.GOOS {
case "linux": case platform.LINUX:
// Most Linux distributions. // Most Linux distributions.
path := "/usr/share/doc/protonmail/" + l.configName + "/LICENSE" path := "/usr/share/doc/protonmail/" + l.configName + "/LICENSE"
if _, err := os.Stat(path); err == nil { if _, err := os.Stat(path); err == nil {
@ -98,7 +99,7 @@ func (l *Locations) getLicenseFilePath() string {
} }
// Arch distributions. // Arch distributions.
return "/usr/share/licenses/protonmail-" + l.configName + "/LICENSE" return "/usr/share/licenses/protonmail-" + l.configName + "/LICENSE"
case "darwin": //nolint:goconst case platform.MACOS: //nolint:goconst
path := filepath.Join(filepath.Dir(os.Args[0]), "..", "Resources", "LICENSE") path := filepath.Join(filepath.Dir(os.Args[0]), "..", "Resources", "LICENSE")
if _, err := os.Stat(path); err == nil { if _, err := os.Stat(path); err == nil {
return path return path
@ -109,7 +110,7 @@ func (l *Locations) getLicenseFilePath() string {
// or may not work, depends where user installed the app and how // or may not work, depends where user installed the app and how
// user started the app. // user started the app.
return "/Applications/Proton Mail Bridge.app/Contents/Resources/LICENSE" return "/Applications/Proton Mail Bridge.app/Contents/Resources/LICENSE"
case "windows": case platform.WINDOWS:
path := filepath.Join(filepath.Dir(os.Args[0]), "LICENSE.txt") path := filepath.Join(filepath.Dir(os.Args[0]), "LICENSE.txt")
if _, err := os.Stat(path); err == nil { if _, err := os.Stat(path); err == nil {
return path return path
@ -206,6 +207,14 @@ func (l *Locations) ProvideUnleashCachePath() (string, error) {
return l.getUnleashCachePath(), nil return l.getUnleashCachePath(), nil
} }
func (l *Locations) ProvideUnleashStartupCachePath() (string, error) {
if err := os.MkdirAll(l.getUnleashStartupCachePath(), 0o700); err != nil {
return "", err
}
return l.getUnleashStartupCachePath(), nil
}
func (l *Locations) getGluonCachePath() string { func (l *Locations) getGluonCachePath() string {
return filepath.Join(l.userData, "gluon") return filepath.Join(l.userData, "gluon")
} }
@ -244,6 +253,10 @@ func (l *Locations) getNotificationsCachePath() string {
func (l *Locations) getUnleashCachePath() string { return filepath.Join(l.userCache, "unleash_cache") } func (l *Locations) getUnleashCachePath() string { return filepath.Join(l.userCache, "unleash_cache") }
func (l *Locations) getUnleashStartupCachePath() string {
return filepath.Join(l.userCache, "unleash_startup_cache")
}
// Clear removes everything except the lock and update files. // Clear removes everything except the lock and update files.
func (l *Locations) Clear(except ...string) error { func (l *Locations) Clear(except ...string) error {
return files.Remove( return files.Remove(

View File

@ -22,6 +22,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
) )
// Provider provides standard locations. // Provider provides standard locations.
@ -95,7 +97,7 @@ func (p *DefaultProvider) UserCache() string {
// This is necessary because os.UserDataDir() is not implemented by the Go standard library, sadly. // This is necessary because os.UserDataDir() is not implemented by the Go standard library, sadly.
// On non-linux systems, it is the same as os.UserConfigDir(). // On non-linux systems, it is the same as os.UserConfigDir().
func userDataDir() (string, error) { func userDataDir() (string, error) {
if runtime.GOOS != "linux" { if runtime.GOOS != platform.LINUX {
return os.UserConfigDir() return os.UserConfigDir()
} }

View File

@ -0,0 +1,24 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package platform
const (
MACOS = "darwin"
LINUX = "linux"
WINDOWS = "windows"
)

View File

@ -286,7 +286,7 @@ func (m *LabelConflictManager) NewInternalLabelConflictResolver(connectors []*Co
mailboxFetch: m.generateMailboxFetcher(connectors), mailboxFetch: m.generateMailboxFetcher(connectors),
mailboxMessageCountFetch: m.generateMailboxMessageCountFetcher(connectors), mailboxMessageCountFetch: m.generateMailboxMessageCountFetcher(connectors),
userLabelConflictResolver: m.NewConflictResolver(connectors), userLabelConflictResolver: m.NewConflictResolver(connectors),
allowNonEmptyMailboxDeletion: m.featureFlagProvider.GetFlagValue(unleash.ItnternalLabelConflictNonEmptyMailboxDeletion), allowNonEmptyMailboxDeletion: m.featureFlagProvider.GetFlagValue(unleash.InternalLabelConflictNonEmptyMailboxDeletion),
client: m.client, client: m.client,
reporter: m.reporter, reporter: m.reporter,
log: logrus.WithFields(logrus.Fields{ log: logrus.WithFields(logrus.Fields{

View File

@ -38,14 +38,15 @@ var pollJitter = 2 * time.Minute //nolint:gochecknoglobals
const filename = "unleash_flags" const filename = "unleash_flags"
const ( const (
EventLoopNotificationDisabled = "InboxBridgeEventLoopNotificationDisabled" EventLoopNotificationDisabled = "InboxBridgeEventLoopNotificationDisabled"
IMAPAuthenticateCommandDisabled = "InboxBridgeImapAuthenticateCommandDisabled" IMAPAuthenticateCommandDisabled = "InboxBridgeImapAuthenticateCommandDisabled"
UserRemovalGluonDataCleanupDisabled = "InboxBridgeUserRemovalGluonDataCleanupDisabled" UserRemovalGluonDataCleanupDisabled = "InboxBridgeUserRemovalGluonDataCleanupDisabled"
UpdateUseNewVersionFileStructureDisabled = "InboxBridgeUpdateWithOsFilterDisabled" UpdateUseNewVersionFileStructureDisabled = "InboxBridgeUpdateWithOsFilterDisabled"
LabelConflictResolverDisabled = "InboxBridgeLabelConflictResolverDisabled" LabelConflictResolverDisabled = "InboxBridgeLabelConflictResolverDisabled"
SMTPSubmissionRequestSentryReportDisabled = "InboxBridgeSmtpSubmissionRequestSentryReportDisabled" SMTPSubmissionRequestSentryReportDisabled = "InboxBridgeSmtpSubmissionRequestSentryReportDisabled"
InternalLabelConflictResolverDisabled = "InboxBridgeUnexpectedFoldersLabelsStartupFixupDisabled" InternalLabelConflictResolverDisabled = "InboxBridgeUnexpectedFoldersLabelsStartupFixupDisabled"
ItnternalLabelConflictNonEmptyMailboxDeletion = "InboxBridgeUnknownNonEmptyMailboxDeletion" InternalLabelConflictNonEmptyMailboxDeletion = "InboxBridgeUnknownNonEmptyMailboxDeletion"
LinuxVaultPreferredKeychainNotAvailableRetryDisabled = "InboxBridgeLinuxVaultPreferredKeychainNotAvailableRetryDisabled"
) )
type FeatureFlagValueProvider interface { type FeatureFlagValueProvider interface {

132
internal/unleash/startup.go Normal file
View File

@ -0,0 +1,132 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package unleash
import (
"context"
"encoding/json"
"os"
"path/filepath"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
const startupCacheFilename = "unleash_startup_flags.json"
var logger = logrus.WithField("pkg", "unleash-startup") //nolint:gochecknoglobals
type FeatureFlagStartupStore map[string]bool
func (f FeatureFlagStartupStore) GetFlagValue(key string) bool {
val, ok := f[key]
if !ok {
return false
}
return val
}
func newAPIOptions(
apiURL string,
version *semver.Version,
) []proton.Option {
return []proton.Option{
proton.WithHostURL(apiURL),
proton.WithAppVersion(constants.AppVersion(version.Original())),
proton.WithLogger(logrus.WithField("pkg", "gpa/unleash-startup")),
proton.WithRetryCount(0),
}
}
func readStartupCacheFile(filepath string) (map[string]bool, error) {
ffStore := make(map[string]bool)
if filepath == "" {
return ffStore, nil
}
file, err := os.Open(filepath) //nolint:gosec
if err != nil {
return ffStore, err
}
defer func(file *os.File) {
if err := file.Close(); err != nil {
logger.WithError(err).Error("Unable to close cache file after read")
}
}(file)
if err := json.NewDecoder(file).Decode(&ffStore); err != nil {
return ffStore, err
}
return ffStore, nil
}
func saveStartupCacheFile(ffStore map[string]bool, filepath string) error {
if filepath == "" {
return nil
}
file, err := os.Create(filepath) //nolint:gosec
if err != nil {
return err
}
defer func(file *os.File) {
if err := file.Close(); err != nil {
logger.WithError(err).Error("Unable to close cache file after write")
}
}(file)
if err := json.NewEncoder(file).Encode(ffStore); err != nil {
return err
}
return nil
}
func GetStartupFeatureFlagsAndStore(apiURL string, curVersion *semver.Version, unleashCachePathProvider func() (string, error)) map[string]bool {
var cacheFilepath string
cacheDir, err := unleashCachePathProvider()
if err != nil {
logger.WithError(err).Warn("Unable to obtain feature flag cache filepath")
} else {
cacheFilepath = filepath.Clean(filepath.Join(cacheDir, startupCacheFilename))
}
ffStore, err := readStartupCacheFile(cacheFilepath)
if err != nil {
logger.WithError(err).Warn("An issue occurred when reading the cache file")
}
manager := proton.New(newAPIOptions(apiURL, curVersion)...)
featureFlagResult, err := manager.GetFeatures(context.Background(), uuid.New())
if err == nil {
ffStore = readResponseData(featureFlagResult)
} else {
logger.WithError(err).Warn("Failed to obtain feature flags from API")
}
if err := saveStartupCacheFile(ffStore, cacheFilepath); err != nil {
logger.WithError(err).Warn("An issue occurred when saving the cache file")
}
return ffStore
}

View File

@ -0,0 +1,156 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package unleash
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/Masterminds/semver/v3"
"github.com/stretchr/testify/require"
)
func TestReadStartupCacheFile_Success(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "valid_cache")
file, err := os.Create(filePath)
require.NoError(t, err)
testData := map[string]bool{
"feature1": true,
"feature2": false,
}
err = json.NewEncoder(file).Encode(testData)
require.NoError(t, err)
err = file.Close()
require.NoError(t, err)
startupCache, err := readStartupCacheFile(filePath)
require.NoError(t, err)
require.Equal(t, testData, startupCache)
}
func TestReadStartupCacheFile_InvalidFilePath(t *testing.T) {
filePath := "badFilepath/hello"
startupCache, err := readStartupCacheFile(filePath)
require.Error(t, err)
require.Empty(t, startupCache)
}
func TestSaveStartupCacheFile_Success(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test_cache")
testData := map[string]bool{
"feature1": true,
"feature2": false,
"feature3": true,
}
err := saveStartupCacheFile(testData, filePath)
require.NoError(t, err)
savedData, err := readStartupCacheFile(filePath)
require.NoError(t, err)
require.Equal(t, testData, savedData)
}
func TestSaveStartupCacheFile_InvalidFilePath(t *testing.T) {
badFilePath := "/some_random_dir/hey/hello"
testData := map[string]bool{
"feature1": true,
"feature2": false,
}
err := saveStartupCacheFile(testData, badFilePath)
require.Error(t, err)
}
func TestGetStartupFeatureFlagsAndStore_FakeAPIURL(t *testing.T) {
apiURL := "https://example.com"
cacheProvider := func() (string, error) {
return t.TempDir(), nil
}
version, err := semver.NewVersion("3.99.99+test")
require.NoError(t, err)
featureFlags := GetStartupFeatureFlagsAndStore(apiURL, version, cacheProvider)
require.Empty(t, featureFlags)
}
func TestGetStartupFeatureFlagsAndStore_RealAPIURL(t *testing.T) {
apiURL := "https://mail-api.proton.me"
cacheProvider := func() (string, error) {
return t.TempDir(), nil
}
version, err := semver.NewVersion("3.99.99+test")
require.NoError(t, err)
featureFlags := GetStartupFeatureFlagsAndStore(apiURL, version, cacheProvider)
require.NotEmpty(t, featureFlags)
}
func TestGetStartupFeatureFlagsAndStore_FeatureFlagCacheRetention(t *testing.T) {
fakeAPIURL := "https://example.com"
realAPIURL := "https://mail-api.proton.me"
cacheDir := t.TempDir()
cacheProvider := func() (string, error) {
return cacheDir, nil
}
version, err := semver.NewVersion("3.99.99+test")
require.NoError(t, err)
featureFlags := GetStartupFeatureFlagsAndStore(realAPIURL, version, cacheProvider)
require.NotEmpty(t, featureFlags)
featureFlagsFromCache := GetStartupFeatureFlagsAndStore(fakeAPIURL, version, cacheProvider)
require.NotEmpty(t, featureFlagsFromCache)
require.Equal(t, featureFlags, featureFlagsFromCache)
}
func Test(t *testing.T) {
fakeAPIURL := "https://example.com"
tmpDir := t.TempDir()
cacheProvider := func() (string, error) {
return tmpDir, nil
}
filePath := filepath.Join(tmpDir, startupCacheFilename)
testData := map[string]bool{
"feature1": true,
"feature2": false,
"feature3": true,
}
err := saveStartupCacheFile(testData, filePath)
require.NoError(t, err)
version, err := semver.NewVersion("3.99.99+git")
require.NoError(t, err)
featureFlagsFromCache := GetStartupFeatureFlagsAndStore(fakeAPIURL, version, cacheProvider)
require.NotEmpty(t, featureFlagsFromCache)
require.Equal(t, testData, featureFlagsFromCache)
}

View File

@ -26,6 +26,7 @@ import (
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/versioner" "github.com/ProtonMail/proton-bridge/v3/internal/versioner"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -194,7 +195,7 @@ func (u *Updater) getVersionFileURLLegacy() string {
// - https://protonmail.com/download/darwin/universal/v1/version.json // - https://protonmail.com/download/darwin/universal/v1/version.json
func (u *Updater) getVersionFileURL() string { func (u *Updater) getVersionFileURL() string {
switch u.platform { switch u.platform {
case "darwin": case platform.MACOS:
return fmt.Sprintf("%v/%v/%v/universal/v%v/version.json", Host, u.product, u.platform, u.version) return fmt.Sprintf("%v/%v/%v/universal/v%v/version.json", Host, u.product, u.platform, u.version)
default: default:
return fmt.Sprintf("%v/%v/%v/x86/v%v/version.json", Host, u.product, u.platform, u.version) return fmt.Sprintf("%v/%v/%v/x86/v%v/version.json", Host, u.product, u.platform, u.version)

View File

@ -22,6 +22,7 @@ import (
"strings" "strings"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
) )
// IsCatalinaOrNewer checks whether the host is macOS Catalina 10.15.x or higher. // IsCatalinaOrNewer checks whether the host is macOS Catalina 10.15.x or higher.
@ -44,7 +45,7 @@ func getMinBigSur() *semver.Version { return semver.MustParse("20.0.0") }
func getMinVentura() *semver.Version { return semver.MustParse("22.0.0") } func getMinVentura() *semver.Version { return semver.MustParse("22.0.0") }
func isThisDarwinNewerOrEqual(minVersion *semver.Version) bool { func isThisDarwinNewerOrEqual(minVersion *semver.Version) bool {
if runtime.GOOS != "darwin" { if runtime.GOOS != platform.MACOS {
return false return false
} }

View File

@ -76,6 +76,34 @@ func SetHelper(vaultDir, helper string) error {
return settings.Save(vaultDir) return settings.Save(vaultDir)
} }
func GetKeychainFailedAttemptCount(vaultDir string) (int, error) {
keychainState, err := LoadKeychainState(vaultDir)
if err != nil {
return 0, err
}
return keychainState.FailedAttempts, nil
}
func IncrementKeychainFailedAttemptCount(vaultDir string) error {
keychainState, err := LoadKeychainState(vaultDir)
if err != nil {
return err
}
keychainState.FailedAttempts++
return keychainState.Save(vaultDir)
}
// ResetFailedKeychainAttemptCount - resets the failed keychain attempt count, and stores the data in the appropriate helper file.
func ResetFailedKeychainAttemptCount(vaultDir string) error {
keychainState, err := LoadKeychainState(vaultDir)
if err != nil {
return err
}
return keychainState.ResetAndSave(vaultDir)
}
func GetVaultKey(kc *keychain.Keychain) ([]byte, error) { func GetVaultKey(kc *keychain.Keychain) ([]byte, error) {
_, keyEnc, err := kc.Get(vaultSecretName) _, keyEnc, err := kc.Get(vaultSecretName)
if err != nil { if err != nil {

View File

@ -17,15 +17,7 @@
package vault package vault
import ( import "github.com/ProtonMail/proton-bridge/v3/internal/vault/storage"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
)
const keychainSettingsFileName = "keychain.json" const keychainSettingsFileName = "keychain.json"
@ -35,40 +27,15 @@ type KeychainSettings struct {
DisableTest bool // Is the keychain test on startup disabled? DisableTest bool // Is the keychain test on startup disabled?
} }
var keychainSettingsFile = storage.NewJSONStorageFile[KeychainSettings](keychainSettingsFileName, "keychain settings") //nolint:gochecknoglobals
// LoadKeychainSettings load keychain settings from the vaultDir folder, or returns a default one if the file // LoadKeychainSettings load keychain settings from the vaultDir folder, or returns a default one if the file
// does not exists or is invalid. // does not exists or is invalid.
func LoadKeychainSettings(vaultDir string) (KeychainSettings, error) { func LoadKeychainSettings(vaultDir string) (KeychainSettings, error) {
path := filepath.Join(vaultDir, keychainSettingsFileName) return keychainSettingsFile.Load(vaultDir)
bytes, err := os.ReadFile(path) //nolint:gosec
if err != nil {
if errors.Is(err, os.ErrNotExist) {
logrus.
WithFields(logrus.Fields{"pkg": "vault", "path": path}).
Trace("Keychain settings file does not exists, default values will be used")
return KeychainSettings{}, nil
}
return KeychainSettings{}, err
}
var result KeychainSettings
if err := json.Unmarshal(bytes, &result); err != nil {
return KeychainSettings{}, fmt.Errorf("keychain settings file is invalid settings: %w", err)
}
return result, nil
} }
// Save saves the keychain settings in a file in the vaultDir folder. // Save saves the keychain settings in a file in the vaultDir folder.
func (k KeychainSettings) Save(vaultDir string) error { func (k KeychainSettings) Save(vaultDir string) error {
bytes, err := json.MarshalIndent(k, "", " ") return keychainSettingsFile.Save(vaultDir, k)
if err != nil {
return err
}
if err = os.MkdirAll(vaultDir, 0o700); err != nil {
return err
}
path := filepath.Join(vaultDir, keychainSettingsFileName)
return os.WriteFile(path, bytes, 0o600)
} }

View File

@ -0,0 +1,53 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package vault
import (
"runtime"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/vault/storage"
)
const keychainStateFileName = "keychain_state.json"
type KeychainState struct {
FailedAttempts int
}
var keychainStateFile = storage.NewJSONStorageFile[KeychainState](keychainStateFileName, "keychain state") //nolint:gochecknoglobals
func LoadKeychainState(vaultDir string) (KeychainState, error) {
if runtime.GOOS != platform.LINUX {
return KeychainState{}, nil
}
return keychainStateFile.Load(vaultDir)
}
func (k KeychainState) Save(vaultDir string) error {
if runtime.GOOS != platform.LINUX {
return nil
}
return keychainStateFile.Save(vaultDir, k)
}
func (k KeychainState) ResetAndSave(vaultDir string) error {
k.FailedAttempts = 0
return k.Save(vaultDir)
}

View File

@ -0,0 +1,75 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package vault
import (
"runtime"
"testing"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/stretchr/testify/require"
)
func TestKeychainState(t *testing.T) {
dir := t.TempDir()
// Load a non-existing keychain state file. It should return the defaults if it does not exist and no error will be thrown.
keychainState, err := LoadKeychainState(dir)
require.NoError(t, err)
require.Equal(t, KeychainState{}, keychainState)
// Increment the failed attempt count. The function call will save the data to the file.
err = IncrementKeychainFailedAttemptCount(dir)
require.NoError(t, err)
// Load the state from the now existing file. We isolate the behaviour of the helper to Linux.
// Thus, a nil state is expected on other OS'.
keychainState, err = LoadKeychainState(dir)
require.NoError(t, err)
if runtime.GOOS == platform.LINUX {
require.Equal(t, KeychainState{
FailedAttempts: 1,
}, keychainState)
} else {
require.Equal(t, KeychainState{}, keychainState)
}
// Increment again.
err = IncrementKeychainFailedAttemptCount(dir)
require.NoError(t, err)
// Same thing, we only expect linux to have data.
keychainState, err = LoadKeychainState(dir)
require.NoError(t, err)
if runtime.GOOS == platform.LINUX {
require.Equal(t, KeychainState{
FailedAttempts: 2,
}, keychainState)
} else {
require.Equal(t, KeychainState{}, keychainState)
}
// Reset the failed attempt count.
err = ResetFailedKeychainAttemptCount(dir)
require.NoError(t, err)
// All OS' states should match in this case.
keychainState, err = LoadKeychainState(dir)
require.NoError(t, err)
require.Equal(t, KeychainState{}, keychainState)
}

View File

@ -0,0 +1,75 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package storage
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
)
type JSONFile[T any] struct {
fileName string
fileType string
}
func NewJSONStorageFile[T any](fileName, fileType string) *JSONFile[T] {
return &JSONFile[T]{
fileName: fileName,
fileType: fileType,
}
}
func (jf *JSONFile[T]) Load(vaultDir string) (T, error) {
var result T
path := filepath.Join(vaultDir, jf.fileName)
bytes, err := os.ReadFile(path) //nolint:gosec
if err != nil {
if errors.Is(err, os.ErrNotExist) {
logrus.
WithFields(logrus.Fields{"pkg": "vault", "path": path}).
Tracef("%s file does not exists, default values will be used", jf.fileType)
return result, nil
}
return result, err
}
if err := json.Unmarshal(bytes, &result); err != nil {
return result, fmt.Errorf("%s file has invalid data: %w", jf.fileType, err)
}
return result, nil
}
func (jf *JSONFile[T]) Save(vaultDir string, data T) error {
bytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
}
if err = os.MkdirAll(vaultDir, 0o700); err != nil {
return err
}
path := filepath.Join(vaultDir, jf.fileName)
return os.WriteFile(path, bytes, 0o600)
}

View File

@ -20,6 +20,8 @@ package versioner
import ( import (
"os" "os"
"runtime" "runtime"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
) )
// fileExists returns whether the given file exists. // fileExists returns whether the given file exists.
@ -30,7 +32,7 @@ func fileExists(path string) bool {
// fileIsExecutable returns the given filepath and true if it exists. // fileIsExecutable returns the given filepath and true if it exists.
func fileIsExecutable(path string) bool { func fileIsExecutable(path string) bool {
if runtime.GOOS == "windows" { if runtime.GOOS == platform.WINDOWS {
return true return true
} }

View File

@ -27,6 +27,8 @@ import (
"time" "time"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/docker/docker-credential-helpers/credentials" "github.com/docker/docker-credential-helpers/credentials"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -37,6 +39,10 @@ type helperConstructor func(string) (credentials.Helper, error)
// Version is the keychain data version. // Version is the keychain data version.
const Version = "k11" const Version = "k11"
// MaxFailedKeychainAttemptsLinux defines the number of failed attempts allowed for the preferred keychain on Linux.
// Since counting starts at 0, a value of 2 allows for 3 total attempts.
const MaxFailedKeychainAttemptsLinux = 2
var ( var (
// ErrNoKeychain indicates that no suitable keychain implementation could be loaded. // ErrNoKeychain indicates that no suitable keychain implementation could be loaded.
ErrNoKeychain = errors.New("no keychain") //nolint:gochecknoglobals ErrNoKeychain = errors.New("no keychain") //nolint:gochecknoglobals
@ -45,6 +51,8 @@ var (
ErrMacKeychainRebuild = errors.New("keychain error -25293") ErrMacKeychainRebuild = errors.New("keychain error -25293")
ErrKeychainNoItem = errors.New("no such keychain item") ErrKeychainNoItem = errors.New("no such keychain item")
ErrPreferredKeychainNotAvailable = errors.New("preferred keychain is not available or usable")
) )
func IsErrKeychainNoItem(err error) bool { func IsErrKeychainNoItem(err error) bool {
@ -82,15 +90,39 @@ func (kcl *List) GetDefaultHelper() string {
return kcl.defaultHelper return kcl.defaultHelper
} }
func PreferredKeychainRetryError(attemptCount int) error {
return fmt.Errorf("%w, %d attempts remaining till vault reset", ErrPreferredKeychainNotAvailable, MaxFailedKeychainAttemptsLinux-attemptCount)
}
func ShouldRetryPreferredKeychain(featureFlags unleash.FeatureFlagStartupStore, preferredKeychain string) bool {
return !featureFlags.GetFlagValue(unleash.LinuxVaultPreferredKeychainNotAvailableRetryDisabled) &&
runtime.GOOS == platform.LINUX && preferredKeychain != ""
}
// NewKeychain creates a new native keychain. It also returns the keychain helper used to access the keychain. // NewKeychain creates a new native keychain. It also returns the keychain helper used to access the keychain.
func NewKeychain(preferred, keychainName string, helpers Helpers, defaultHelper string) (kc *Keychain, usedKeychainHelper string, err error) { func NewKeychain(
preferred, keychainName string,
helpers Helpers,
defaultHelper string,
keychainFailedAttemptCount int,
featureFlags unleash.FeatureFlagStartupStore,
) (kc *Keychain, usedKeychainHelper string, err error) {
// There must be at least one keychain helper available. // There must be at least one keychain helper available.
if len(helpers) < 1 { if len(helpers) < 1 {
return nil, "", ErrNoKeychain return nil, "", ErrNoKeychain
} }
// If the preferred keychain is unsupported, fallback to the default one. // If the preferred keychain is unsupported, fallback to the default one.
// For linux, keep on exiting early before wiping the vault until we've exceeded the allowed retry count.
if _, ok := helpers[preferred]; !ok { if _, ok := helpers[preferred]; !ok {
if ShouldRetryPreferredKeychain(featureFlags, preferred) {
if keychainFailedAttemptCount < MaxFailedKeychainAttemptsLinux {
return nil, "", PreferredKeychainRetryError(keychainFailedAttemptCount)
}
logrus.Errorf("%s, max attempts have been exceeded, resetting vault", ErrPreferredKeychainNotAvailable)
}
preferred = defaultHelper preferred = defaultHelper
} }
@ -242,7 +274,7 @@ func isUsable(helper credentials.Helper, err error) bool { //nolint:unused
func getTestCredentials() *credentials.Credentials { //nolint:unused func getTestCredentials() *credentials.Credentials { //nolint:unused
// On macOS, a handful of users experience failures of the test credentials. // On macOS, a handful of users experience failures of the test credentials.
if runtime.GOOS == "darwin" { if runtime.GOOS == platform.MACOS {
return &credentials.Credentials{ return &credentials.Credentials{
ServerURL: hostURL(constants.KeyChainName) + fmt.Sprintf("/check_%v", time.Now().UTC().UnixMicro()), ServerURL: hostURL(constants.KeyChainName) + fmt.Sprintf("/check_%v", time.Now().UTC().UnixMicro()),
Username: "", // username is ignored on macOS, it's extracted from splitting the server URL Username: "", // username is ignored on macOS, it's extracted from splitting the server URL

View File

@ -120,7 +120,12 @@ func TestIsErrKeychainNoItem(t *testing.T) {
helpers := NewList().GetHelpers() helpers := NewList().GetHelpers()
for helperName := range helpers { for helperName := range helpers {
kc, _, err := NewKeychain(helperName, "bridge-test", helpers, helperName) kc, _, err := NewKeychain(
helperName, "bridge-test",
helpers, helperName,
0,
make(map[string]bool),
)
r.NoError(err) r.NoError(err)
_, _, err = kc.Get("non-existing") _, _, err = kc.Get("non-existing")

View File

@ -25,6 +25,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/ProtonMail/proton-bridge/v3/internal/platform"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -91,7 +92,7 @@ func UntarToDir(r io.Reader, dir string) error {
if _, err := io.Copy(f, lr); err != nil { if _, err := io.Copy(f, lr); err != nil {
return err return err
} }
if runtime.GOOS != "windows" { if runtime.GOOS != platform.WINDOWS {
if err := f.Chmod(header.FileInfo().Mode()); err != nil { if err := f.Chmod(header.FileInfo().Mode()); err != nil {
return err return err
} }

View File

@ -62,7 +62,7 @@ func main() {
func getRollout(_ *cli.Context) error { func getRollout(_ *cli.Context) error {
return app.WithLocations(func(locations *locations.Locations) error { return app.WithLocations(func(locations *locations.Locations) error {
return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error {
return app.WithVault(nil, locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, _, _ bool) error { return app.WithVault(nil, locations, keychains, make(map[string]bool), async.NoopPanicHandler{}, func(vault *vault.Vault, _, _ bool) error {
fmt.Println(vault.GetUpdateRollout()) fmt.Println(vault.GetUpdateRollout())
return nil return nil
}) })
@ -73,7 +73,7 @@ func getRollout(_ *cli.Context) error {
func setRollout(c *cli.Context) error { func setRollout(c *cli.Context) error {
return app.WithLocations(func(locations *locations.Locations) error { return app.WithLocations(func(locations *locations.Locations) error {
return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error {
return app.WithVault(nil, locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, _, _ bool) error { return app.WithVault(nil, locations, keychains, make(map[string]bool), async.NoopPanicHandler{}, func(vault *vault.Vault, _, _ bool) error {
clamped := max(0.0, min(1.0, c.Float64("value"))) clamped := max(0.0, min(1.0, c.Float64("value")))
if err := vault.SetUpdateRollout(clamped); err != nil { if err := vault.SetUpdateRollout(clamped); err != nil {
return err return err

View File

@ -27,6 +27,7 @@ import (
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/proton-bridge/v3/internal/app" "github.com/ProtonMail/proton-bridge/v3/internal/app"
"github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -52,7 +53,7 @@ func main() {
func readAction(c *cli.Context) error { func readAction(c *cli.Context) error {
return app.WithLocations(func(locations *locations.Locations) error { return app.WithLocations(func(locations *locations.Locations) error {
return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error {
return app.WithVault(nil, locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { return app.WithVault(nil, locations, keychains, make(unleash.FeatureFlagStartupStore), async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error {
if _, err := os.Stdout.Write(vault.ExportJSON()); err != nil { if _, err := os.Stdout.Write(vault.ExportJSON()); err != nil {
return fmt.Errorf("failed to write vault: %w", err) return fmt.Errorf("failed to write vault: %w", err)
} }
@ -66,7 +67,7 @@ func readAction(c *cli.Context) error {
func writeAction(c *cli.Context) error { func writeAction(c *cli.Context) error {
return app.WithLocations(func(locations *locations.Locations) error { return app.WithLocations(func(locations *locations.Locations) error {
return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error { return app.WithKeychainList(async.NoopPanicHandler{}, func(keychains *keychain.List) error {
return app.WithVault(nil, locations, keychains, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { return app.WithVault(nil, locations, keychains, make(unleash.FeatureFlagStartupStore), async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error {
b, err := io.ReadAll(os.Stdin) b, err := io.ReadAll(os.Stdin)
if err != nil { if err != nil {
return fmt.Errorf("failed to read vault: %w", err) return fmt.Errorf("failed to read vault: %w", err)