From 48dfdabaf4d44b1007081e065ec16b8cfa970a00 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Mon, 21 Nov 2022 01:38:51 +0100 Subject: [PATCH] GODT-1975: Migrate keychain secrets --- internal/app/app.go | 22 +- internal/app/migration.go | 90 +++++- internal/app/migration_test.go | 2 +- internal/legacy/credentials/credentials.go | 136 ++++++++ .../legacy/credentials/credentials_test.go | 67 ++++ internal/legacy/credentials/store.go | 118 +++++++ internal/legacy/credentials/store_test.go | 298 ++++++++++++++++++ internal/vault/helper.go | 8 +- internal/vault/types_data.go | 1 + internal/vault/vault.go | 10 + 10 files changed, 742 insertions(+), 10 deletions(-) create mode 100644 internal/legacy/credentials/credentials.go create mode 100644 internal/legacy/credentials/credentials_test.go create mode 100644 internal/legacy/credentials/store.go create mode 100644 internal/legacy/credentials/store_test.go diff --git a/internal/app/app.go b/internal/app/app.go index 5558fceb..caefc980 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -183,6 +183,11 @@ func run(c *cli.Context) error { //nolint:funlen return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler) error { // Load the locations where we store our files. return WithLocations(func(locations *locations.Locations) error { + // Migrate the keychain helper. + if err := migrateKeychainHelper(locations); err != nil { + logrus.WithError(err).Error("Failed to migrate keychain helper") + } + // Initialize logging. return withLogging(c, crashHandler, locations, func() error { // If there was an error during migration, log it now. @@ -194,8 +199,21 @@ func run(c *cli.Context) error { //nolint:funlen return withSingleInstance(locations, version, func() error { // Unlock the encrypted vault. return WithVault(locations, func(vault *vault.Vault, insecure, corrupt bool) error { - if err := migrateOldSettings(vault); err != nil { - logrus.WithError(err).Error("Failed to migrate old settings") + if !vault.Migrated() { + // Migrate old settings into the vault. + if err := migrateOldSettings(vault); err != nil { + logrus.WithError(err).Error("Failed to migrate old settings") + } + + // Migrate old accounts into the vault. + if err := migrateOldAccounts(locations, vault); err != nil { + logrus.WithError(err).Error("Failed to migrate old accounts") + } + + // The vault has been migrated. + if err := vault.SetMigrated(); err != nil { + logrus.WithError(err).Error("Failed to mark vault as migrated") + } } // Load the cookies from the vault. diff --git a/internal/app/migration.go b/internal/app/migration.go index 68514925..196e2ad1 100644 --- a/internal/app/migration.go +++ b/internal/app/migration.go @@ -27,8 +27,11 @@ import ( "time" "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/proton-bridge/v2/internal/legacy/credentials" + "github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/ProtonMail/proton-bridge/v2/pkg/keychain" "github.com/allan-simon/go-singleinstance" "github.com/hashicorp/go-multierror" "github.com/pkg/errors" @@ -36,7 +39,9 @@ import ( ) // nolint:gosec -func migrateOldSettings(vault *vault.Vault) error { +func migrateKeychainHelper(locations *locations.Locations) error { + logrus.Info("Migrating keychain helper") + configDir, err := os.UserConfigDir() if err != nil { return fmt.Errorf("failed to get user config dir: %w", err) @@ -47,7 +52,88 @@ func migrateOldSettings(vault *vault.Vault) error { return fmt.Errorf("failed to read old prefs file: %w", err) } - return migratePrefsToVault(vault, b) + var prefs struct { + Helper string `json:"preferred_keychain"` + } + + if err := json.Unmarshal(b, &prefs); err != nil { + return fmt.Errorf("failed to unmarshal old prefs file: %w", err) + } + + settings, err := locations.ProvideSettingsPath() + if err != nil { + return fmt.Errorf("failed to get settings path: %w", err) + } + + return vault.SetHelper(settings, prefs.Helper) +} + +// nolint:gosec +func migrateOldSettings(v *vault.Vault) error { + logrus.Info("Migrating settings") + + configDir, err := os.UserConfigDir() + if err != nil { + return fmt.Errorf("failed to get user config dir: %w", err) + } + + b, err := os.ReadFile(filepath.Join(configDir, "protonmail", "bridge", "prefs.json")) + if err != nil { + return fmt.Errorf("failed to read old prefs file: %w", err) + } + + return migratePrefsToVault(v, b) +} + +func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error { + logrus.Info("Migrating accounts") + + settings, err := locations.ProvideSettingsPath() + if err != nil { + return fmt.Errorf("failed to get settings path: %w", err) + } + + helper, err := vault.GetHelper(settings) + if err != nil { + return fmt.Errorf("failed to get helper: %w", err) + } + + keychain, err := keychain.NewKeychain(helper, "bridge") + if err != nil { + return fmt.Errorf("failed to create keychain: %w", err) + } + + store := credentials.NewStore(keychain) + + users, err := store.List() + if err != nil { + return fmt.Errorf("failed to create credentials store: %w", err) + } + + for _, userID := range users { + logrus.WithField("userID", userID).Info("Migrating account") + + creds, err := store.Get(userID) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + + authUID, authRef, err := creds.SplitAPIToken() + if err != nil { + return fmt.Errorf("failed to split api token: %w", err) + } + + user, err := v.AddUser(creds.UserID, creds.EmailList()[0], authUID, authRef, creds.MailboxPassword) + if err != nil { + return fmt.Errorf("failed to add user: %w", err) + } + + if err := user.Close(); err != nil { + return fmt.Errorf("failed to close user: %w", err) + } + } + + return nil } // nolint:funlen diff --git a/internal/app/migration_test.go b/internal/app/migration_test.go index 74690636..dcac4cf9 100644 --- a/internal/app/migration_test.go +++ b/internal/app/migration_test.go @@ -30,7 +30,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMigrateOldVaultFromJSON(t *testing.T) { +func TestMigratePrefsToVault(t *testing.T) { // Create a new vault. vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) require.NoError(t, err) diff --git a/internal/legacy/credentials/credentials.go b/internal/legacy/credentials/credentials.go new file mode 100644 index 00000000..b0ce2a69 --- /dev/null +++ b/internal/legacy/credentials/credentials.go @@ -0,0 +1,136 @@ +// Copyright (c) 2022 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +// Package credentials implements our struct stored in keychain. +// Store struct is kind of like a database client. +// Credentials struct is kind of like one record from the database. +package credentials + +import ( + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/sirupsen/logrus" +) + +const ( + sep = "\x00" + + itemLengthBridge = 9 + itemLengthImportExport = 6 // Old format for Import-Export. +) + +var ( + log = logrus.WithField("pkg", "credentials") //nolint:gochecknoglobals + + ErrWrongFormat = errors.New("malformed credentials") +) + +type Credentials struct { + UserID, // Do not marshal; used as a key. + Name, + Emails, + APIToken string + MailboxPassword []byte + BridgePassword, + Version string + Timestamp int64 + IsHidden, // Deprecated. + IsCombinedAddressMode bool +} + +func (s *Credentials) Marshal() string { + items := []string{ + s.Name, // 0 + s.Emails, // 1 + s.APIToken, // 2 + string(s.MailboxPassword), // 3 + s.BridgePassword, // 4 + s.Version, // 5 + "", // 6 + "", // 7 + "", // 8 + } + + items[6] = fmt.Sprint(s.Timestamp) + + if s.IsHidden { + items[7] = "1" + } + + if s.IsCombinedAddressMode { + items[8] = "1" + } + + str := strings.Join(items, sep) + return base64.StdEncoding.EncodeToString([]byte(str)) +} + +func (s *Credentials) Unmarshal(secret string) error { + b, err := base64.StdEncoding.DecodeString(secret) + if err != nil { + return err + } + items := strings.Split(string(b), sep) + + if len(items) != itemLengthBridge && len(items) != itemLengthImportExport { + return ErrWrongFormat + } + + s.Name = items[0] + s.Emails = items[1] + s.APIToken = items[2] + s.MailboxPassword = []byte(items[3]) + + switch len(items) { + case itemLengthBridge: + s.BridgePassword = items[4] + s.Version = items[5] + if _, err = fmt.Sscan(items[6], &s.Timestamp); err != nil { + s.Timestamp = 0 + } + if s.IsHidden = false; items[7] == "1" { + s.IsHidden = true + } + if s.IsCombinedAddressMode = false; items[8] == "1" { + s.IsCombinedAddressMode = true + } + + case itemLengthImportExport: + s.Version = items[4] + if _, err = fmt.Sscan(items[5], &s.Timestamp); err != nil { + s.Timestamp = 0 + } + } + return nil +} + +func (s *Credentials) EmailList() []string { + return strings.Split(s.Emails, ";") +} + +func (s *Credentials) SplitAPIToken() (string, string, error) { + split := strings.Split(s.APIToken, ":") + + if len(split) != 2 { + return "", "", errors.New("malformed API token") + } + + return split[0], split[1], nil +} diff --git a/internal/legacy/credentials/credentials_test.go b/internal/legacy/credentials/credentials_test.go new file mode 100644 index 00000000..fb8460b7 --- /dev/null +++ b/internal/legacy/credentials/credentials_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2022 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package credentials + +import ( + "encoding/base64" + "fmt" + "strings" + "testing" + "time" + + r "github.com/stretchr/testify/require" +) + +var wantCredentials = Credentials{ + UserID: "1", + Name: "name", + Emails: "email1;email2", + APIToken: "token", + MailboxPassword: []byte("mailbox pass"), + BridgePassword: "bridge pass", + Version: "k11", + Timestamp: time.Now().Unix(), + IsHidden: false, + IsCombinedAddressMode: false, +} + +func TestUnmarshallBridge(t *testing.T) { + encoded := wantCredentials.Marshal() + haveCredentials := Credentials{UserID: "1"} + r.NoError(t, haveCredentials.Unmarshal(encoded)) + r.Equal(t, wantCredentials, haveCredentials) +} + +func TestUnmarshallImportExport(t *testing.T) { + items := []string{ + wantCredentials.Name, + wantCredentials.Emails, + wantCredentials.APIToken, + string(wantCredentials.MailboxPassword), + "k11", + fmt.Sprint(wantCredentials.Timestamp), + } + + str := strings.Join(items, sep) + encoded := base64.StdEncoding.EncodeToString([]byte(str)) + + haveCredentials := Credentials{UserID: "1"} + haveCredentials.BridgePassword = wantCredentials.BridgePassword // This one is not used. + r.NoError(t, haveCredentials.Unmarshal(encoded)) + r.Equal(t, wantCredentials, haveCredentials) +} diff --git a/internal/legacy/credentials/store.go b/internal/legacy/credentials/store.go new file mode 100644 index 00000000..c2ff320a --- /dev/null +++ b/internal/legacy/credentials/store.go @@ -0,0 +1,118 @@ +// Copyright (c) 2022 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package credentials + +import ( + "errors" + "fmt" + "sort" + "sync" +) + +var storeLocker = sync.RWMutex{} //nolint:gochecknoglobals + +// Store is an encrypted credentials store. +type Store struct { + secrets Keychain +} + +type Keychain interface { + List() ([]string, error) + Get(string) (string, string, error) + Put(string, string) error + Delete(string) error +} + +// NewStore creates a new encrypted credentials store. +func NewStore(keychain Keychain) *Store { + return &Store{secrets: keychain} +} + +// List returns a list of usernames that have credentials stored. +func (s *Store) List() (userIDs []string, err error) { + storeLocker.RLock() + defer storeLocker.RUnlock() + + log.Trace("Listing credentials in credentials store") + + var allUserIDs []string + if allUserIDs, err = s.secrets.List(); err != nil { + log.WithError(err).Error("Could not list credentials") + return + } + + credentialList := []*Credentials{} + for _, userID := range allUserIDs { + creds, getErr := s.get(userID) + if getErr != nil { + log.WithField("userID", userID).WithError(getErr).Warn("Failed to get credentials") + continue + } + + // Disabled credentials + if creds.Timestamp == 0 { + continue + } + + credentialList = append(credentialList, creds) + } + + sort.Slice(credentialList, func(i, j int) bool { + return credentialList[i].Timestamp < credentialList[j].Timestamp + }) + + for _, credentials := range credentialList { + userIDs = append(userIDs, credentials.UserID) + } + + return userIDs, err +} + +func (s *Store) Get(userID string) (creds *Credentials, err error) { + storeLocker.RLock() + defer storeLocker.RUnlock() + + return s.get(userID) +} + +func (s *Store) get(userID string) (*Credentials, error) { + log := log.WithField("user", userID) + + _, secret, err := s.secrets.Get(userID) + if err != nil { + return nil, err + } + + if secret == "" { + return nil, errors.New("secret is empty") + } + + credentials := &Credentials{UserID: userID} + + if err := credentials.Unmarshal(secret); err != nil { + log.WithError(fmt.Errorf("malformed secret: %w", err)).Error("Could not unmarshal secret") + + if err := s.secrets.Delete(userID); err != nil { + log.WithError(err).Error("Failed to remove malformed secret") + } + + return nil, err + } + + return credentials, nil +} diff --git a/internal/legacy/credentials/store_test.go b/internal/legacy/credentials/store_test.go new file mode 100644 index 00000000..1b96deff --- /dev/null +++ b/internal/legacy/credentials/store_test.go @@ -0,0 +1,298 @@ +// Copyright (c) 2022 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package credentials + +import ( + "bytes" + "encoding/base64" + "encoding/gob" + "encoding/json" + "fmt" + "strings" + "testing" + + r "github.com/stretchr/testify/require" +) + +const ( + testSep = "\n" + secretFormat = "%v" + testSep + // UserID, + "%v" + testSep + // Name, + "%v" + testSep + // Emails, + "%v" + testSep + // APIToken, + "%v" + testSep + // Mailbox, + "%v" + testSep + // BridgePassword, + "%v" + testSep + // Version string + "%v" + testSep + // Timestamp, + "%v" + testSep + // IsHidden, + "%v" // IsCombinedAddressMode +) + +// the best would be to run this test on mac, win, and linux separately + +type testCredentials struct { + UserID, + Name, + Emails, + APIToken, + Mailbox, + BridgePassword, + Version string + Timestamp int64 + IsHidden, + IsCombinedAddressMode bool +} + +func init() { //nolint:gochecknoinits + gob.Register(testCredentials{}) +} + +func (s *testCredentials) MarshalGob() string { + buf := bytes.Buffer{} + enc := gob.NewEncoder(&buf) + if err := enc.Encode(s); err != nil { + return "" + } + log.Infof("MarshalGob: %#v\n", buf.String()) + return base64.StdEncoding.EncodeToString(buf.Bytes()) +} + +func (s *testCredentials) Clear() { + s.UserID = "" + s.Name = "" + s.Emails = "" + s.APIToken = "" + s.Mailbox = "" + s.BridgePassword = "" + s.Version = "" + s.Timestamp = 0 + s.IsHidden = false + s.IsCombinedAddressMode = false +} + +func (s *testCredentials) UnmarshalGob(secret string) error { + s.Clear() + b, err := base64.StdEncoding.DecodeString(secret) + if err != nil { + log.Infoln("decode base64", b) + return err + } + buf := bytes.NewBuffer(b) + dec := gob.NewDecoder(buf) + if err = dec.Decode(s); err != nil { + log.Info("decode gob", b, buf.Bytes()) + return err + } + return nil +} + +func (s *testCredentials) ToJSON() string { + if b, err := json.Marshal(s); err == nil { + log.Infof("MarshalJSON: %#v\n", string(b)) + return base64.StdEncoding.EncodeToString(b) + } + return "" +} + +func (s *testCredentials) FromJSON(secret string) error { + b, err := base64.StdEncoding.DecodeString(secret) + if err != nil { + return err + } + if err = json.Unmarshal(b, s); err == nil { + return nil + } + return err +} + +func (s *testCredentials) MarshalFmt() string { + buf := bytes.Buffer{} + fmt.Fprintf( + &buf, secretFormat, + s.UserID, + s.Name, + s.Emails, + s.APIToken, + s.Mailbox, + s.BridgePassword, + s.Version, + s.Timestamp, + s.IsHidden, + s.IsCombinedAddressMode, + ) + log.Infof("MarshalFmt: %#v\n", buf.String()) + return base64.StdEncoding.EncodeToString(buf.Bytes()) +} + +func (s *testCredentials) UnmarshalFmt(secret string) error { + b, err := base64.StdEncoding.DecodeString(secret) + if err != nil { + return err + } + buf := bytes.NewBuffer(b) + log.Infoln("decode fmt", b, buf.Bytes()) + _, err = fmt.Fscanf( + buf, secretFormat, + &s.UserID, + &s.Name, + &s.Emails, + &s.APIToken, + &s.Mailbox, + &s.BridgePassword, + &s.Version, + &s.Timestamp, + &s.IsHidden, + &s.IsCombinedAddressMode, + ) + if err != nil { + return err + } + return nil +} + +func (s *testCredentials) MarshalStrings() string { // this is the most space efficient + items := []string{ + s.UserID, // 0 + s.Name, // 1 + s.Emails, // 2 + s.APIToken, // 3 + s.Mailbox, // 4 + s.BridgePassword, // 5 + s.Version, // 6 + } + items = append(items, fmt.Sprint(s.Timestamp)) // 7 + + if s.IsHidden { // 8 + items = append(items, "1") + } else { + items = append(items, "") + } + + if s.IsCombinedAddressMode { // 9 + items = append(items, "1") + } else { + items = append(items, "") + } + + str := strings.Join(items, sep) + + log.Infof("MarshalJoin: %#v\n", str) + return base64.StdEncoding.EncodeToString([]byte(str)) +} + +func (s *testCredentials) UnmarshalStrings(secret string) error { + b, err := base64.StdEncoding.DecodeString(secret) + if err != nil { + return err + } + items := strings.Split(string(b), sep) + if len(items) != 10 { + return ErrWrongFormat + } + + s.UserID = items[0] + s.Name = items[1] + s.Emails = items[2] + s.APIToken = items[3] + s.Mailbox = items[4] + s.BridgePassword = items[5] + s.Version = items[6] + if _, err = fmt.Sscanf(items[7], "%d", &s.Timestamp); err != nil { + s.Timestamp = 0 + } + if s.IsHidden = false; items[8] == "1" { + s.IsHidden = true + } + if s.IsCombinedAddressMode = false; items[9] == "1" { + s.IsCombinedAddressMode = true + } + return nil +} + +func (s *testCredentials) IsSame(rhs *testCredentials) bool { + return s.Name == rhs.Name && + s.Emails == rhs.Emails && + s.APIToken == rhs.APIToken && + s.Mailbox == rhs.Mailbox && + s.BridgePassword == rhs.BridgePassword && + s.Version == rhs.Version && + s.Timestamp == rhs.Timestamp && + s.IsHidden == rhs.IsHidden && + s.IsCombinedAddressMode == rhs.IsCombinedAddressMode +} + +func TestMarshalFormats(t *testing.T) { + input := testCredentials{UserID: "007", Emails: "ja@pm.me;jakub@cu.th", Timestamp: 152469263742, IsHidden: true} + log.Infof("input %#v\n", input) + + secretStrings := input.MarshalStrings() + log.Infof("secretStrings %#v %d\n", secretStrings, len(secretStrings)) + secretGob := input.MarshalGob() + log.Infof("secretGob %#v %d\n", secretGob, len(secretGob)) + secretJSON := input.ToJSON() + log.Infof("secretJSON %#v %d\n", secretJSON, len(secretJSON)) + secretFmt := input.MarshalFmt() + log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt)) + + output := testCredentials{APIToken: "refresh"} + r.NoError(t, output.UnmarshalStrings(secretStrings)) + log.Infof("strings out %#v \n", output) + r.True(t, input.IsSame(&output), "strings out not same") + + output = testCredentials{APIToken: "refresh"} + r.NoError(t, output.UnmarshalGob(secretGob)) + log.Infof("gob out %#v\n \n", output) + r.Equal(t, input, output) + + output = testCredentials{APIToken: "refresh"} + r.NoError(t, output.FromJSON(secretJSON)) + log.Infof("json out %#v \n", output) + r.True(t, input.IsSame(&output), "json out not same") + + /* + // Simple Fscanf not working! + output = testCredentials{APIToken: "refresh"} + r.NoError(t, output.UnmarshalFmt(secretFmt)) + log.Infof("fmt out %#v \n", output) + r.True(t, input.IsSame(&output), "fmt out not same") + */ +} + +func TestMarshal(t *testing.T) { + input := Credentials{ + UserID: "", + Name: "007", + Emails: "ja@pm.me;aj@cus.tom", + APIToken: "sdfdsfsdfsdfsdf", + MailboxPassword: []byte("cdcdcdcd"), + BridgePassword: "wew123", + Version: "k11", + Timestamp: 152469263742, + IsHidden: true, + IsCombinedAddressMode: false, + } + log.Infof("input %#v\n", input) + + secret := input.Marshal() + log.Infof("secret %#v %d\n", secret, len(secret)) + + output := Credentials{APIToken: "refresh"} + r.NoError(t, output.Unmarshal(secret)) + log.Infof("output %#v\n", output) + r.Equal(t, input, output) +} diff --git a/internal/vault/helper.go b/internal/vault/helper.go index 92c0b3f6..a39ab01c 100644 --- a/internal/vault/helper.go +++ b/internal/vault/helper.go @@ -30,8 +30,6 @@ type Keychain struct { } func GetHelper(vaultDir string) (string, error) { - var keychain Keychain - filePath := filepath.Clean(filepath.Join(vaultDir, "keychain.json")) if _, err := os.Stat(filePath); errors.Is(err, fs.ErrNotExist) { @@ -43,6 +41,8 @@ func GetHelper(vaultDir string) (string, error) { return "", err } + var keychain Keychain + if err := json.Unmarshal(b, &keychain); err != nil { return "", err } @@ -56,7 +56,5 @@ func SetHelper(vaultDir, helper string) error { return err } - filePath := filepath.Clean(filepath.Join(vaultDir, "keychain.json")) - - return os.WriteFile(filePath, b, 0o600) + return os.WriteFile(filepath.Clean(filepath.Join(vaultDir, "keychain.json")), b, 0o600) } diff --git a/internal/vault/types_data.go b/internal/vault/types_data.go index a285a9f0..fd9dae04 100644 --- a/internal/vault/types_data.go +++ b/internal/vault/types_data.go @@ -22,6 +22,7 @@ type Data struct { Users []UserData Cookies []byte Certs Certs + Migrated bool } func newDefaultData(gluonDir string) Data { diff --git a/internal/vault/vault.go b/internal/vault/vault.go index 88b68f30..5d9469ce 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -175,6 +175,16 @@ func (vault *Vault) DeleteUser(userID string) error { }) } +func (vault *Vault) Migrated() bool { + return vault.get().Migrated +} + +func (vault *Vault) SetMigrated() error { + return vault.mod(func(data *Data) { + data.Migrated = true + }) +} + func (vault *Vault) Close() error { vault.refLock.Lock() defer vault.refLock.Unlock()