GODT-1975: Migrate keychain secrets

This commit is contained in:
James Houlahan
2022-11-21 01:38:51 +01:00
parent 7ed8d76d84
commit 48dfdabaf4
10 changed files with 742 additions and 10 deletions

View File

@ -183,6 +183,11 @@ func run(c *cli.Context) error { //nolint:funlen
return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler) error {
// Load the locations where we store our files.
return WithLocations(func(locations *locations.Locations) error {
// Migrate the keychain helper.
if err := migrateKeychainHelper(locations); err != nil {
logrus.WithError(err).Error("Failed to migrate keychain helper")
}
// Initialize logging.
return withLogging(c, crashHandler, locations, func() error {
// If there was an error during migration, log it now.
@ -194,10 +199,23 @@ func run(c *cli.Context) error { //nolint:funlen
return withSingleInstance(locations, version, func() error {
// Unlock the encrypted vault.
return WithVault(locations, func(vault *vault.Vault, insecure, corrupt bool) error {
if !vault.Migrated() {
// Migrate old settings into the vault.
if err := migrateOldSettings(vault); err != nil {
logrus.WithError(err).Error("Failed to migrate old settings")
}
// Migrate old accounts into the vault.
if err := migrateOldAccounts(locations, vault); err != nil {
logrus.WithError(err).Error("Failed to migrate old accounts")
}
// The vault has been migrated.
if err := vault.SetMigrated(); err != nil {
logrus.WithError(err).Error("Failed to mark vault as migrated")
}
}
// Load the cookies from the vault.
return withCookieJar(vault, func(cookieJar http.CookieJar) error {
// Create a new bridge instance.

View File

@ -27,8 +27,11 @@ import (
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/legacy/credentials"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/allan-simon/go-singleinstance"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
@ -36,7 +39,9 @@ import (
)
// nolint:gosec
func migrateOldSettings(vault *vault.Vault) error {
func migrateKeychainHelper(locations *locations.Locations) error {
logrus.Info("Migrating keychain helper")
configDir, err := os.UserConfigDir()
if err != nil {
return fmt.Errorf("failed to get user config dir: %w", err)
@ -47,7 +52,88 @@ func migrateOldSettings(vault *vault.Vault) error {
return fmt.Errorf("failed to read old prefs file: %w", err)
}
return migratePrefsToVault(vault, b)
var prefs struct {
Helper string `json:"preferred_keychain"`
}
if err := json.Unmarshal(b, &prefs); err != nil {
return fmt.Errorf("failed to unmarshal old prefs file: %w", err)
}
settings, err := locations.ProvideSettingsPath()
if err != nil {
return fmt.Errorf("failed to get settings path: %w", err)
}
return vault.SetHelper(settings, prefs.Helper)
}
// nolint:gosec
func migrateOldSettings(v *vault.Vault) error {
logrus.Info("Migrating settings")
configDir, err := os.UserConfigDir()
if err != nil {
return fmt.Errorf("failed to get user config dir: %w", err)
}
b, err := os.ReadFile(filepath.Join(configDir, "protonmail", "bridge", "prefs.json"))
if err != nil {
return fmt.Errorf("failed to read old prefs file: %w", err)
}
return migratePrefsToVault(v, b)
}
func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error {
logrus.Info("Migrating accounts")
settings, err := locations.ProvideSettingsPath()
if err != nil {
return fmt.Errorf("failed to get settings path: %w", err)
}
helper, err := vault.GetHelper(settings)
if err != nil {
return fmt.Errorf("failed to get helper: %w", err)
}
keychain, err := keychain.NewKeychain(helper, "bridge")
if err != nil {
return fmt.Errorf("failed to create keychain: %w", err)
}
store := credentials.NewStore(keychain)
users, err := store.List()
if err != nil {
return fmt.Errorf("failed to create credentials store: %w", err)
}
for _, userID := range users {
logrus.WithField("userID", userID).Info("Migrating account")
creds, err := store.Get(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
authUID, authRef, err := creds.SplitAPIToken()
if err != nil {
return fmt.Errorf("failed to split api token: %w", err)
}
user, err := v.AddUser(creds.UserID, creds.EmailList()[0], authUID, authRef, creds.MailboxPassword)
if err != nil {
return fmt.Errorf("failed to add user: %w", err)
}
if err := user.Close(); err != nil {
return fmt.Errorf("failed to close user: %w", err)
}
}
return nil
}
// nolint:funlen

View File

@ -30,7 +30,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestMigrateOldVaultFromJSON(t *testing.T) {
func TestMigratePrefsToVault(t *testing.T) {
// Create a new vault.
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
require.NoError(t, err)

View File

@ -0,0 +1,136 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package credentials implements our struct stored in keychain.
// Store struct is kind of like a database client.
// Credentials struct is kind of like one record from the database.
package credentials
import (
"encoding/base64"
"errors"
"fmt"
"strings"
"github.com/sirupsen/logrus"
)
const (
sep = "\x00"
itemLengthBridge = 9
itemLengthImportExport = 6 // Old format for Import-Export.
)
var (
log = logrus.WithField("pkg", "credentials") //nolint:gochecknoglobals
ErrWrongFormat = errors.New("malformed credentials")
)
type Credentials struct {
UserID, // Do not marshal; used as a key.
Name,
Emails,
APIToken string
MailboxPassword []byte
BridgePassword,
Version string
Timestamp int64
IsHidden, // Deprecated.
IsCombinedAddressMode bool
}
func (s *Credentials) Marshal() string {
items := []string{
s.Name, // 0
s.Emails, // 1
s.APIToken, // 2
string(s.MailboxPassword), // 3
s.BridgePassword, // 4
s.Version, // 5
"", // 6
"", // 7
"", // 8
}
items[6] = fmt.Sprint(s.Timestamp)
if s.IsHidden {
items[7] = "1"
}
if s.IsCombinedAddressMode {
items[8] = "1"
}
str := strings.Join(items, sep)
return base64.StdEncoding.EncodeToString([]byte(str))
}
func (s *Credentials) Unmarshal(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
items := strings.Split(string(b), sep)
if len(items) != itemLengthBridge && len(items) != itemLengthImportExport {
return ErrWrongFormat
}
s.Name = items[0]
s.Emails = items[1]
s.APIToken = items[2]
s.MailboxPassword = []byte(items[3])
switch len(items) {
case itemLengthBridge:
s.BridgePassword = items[4]
s.Version = items[5]
if _, err = fmt.Sscan(items[6], &s.Timestamp); err != nil {
s.Timestamp = 0
}
if s.IsHidden = false; items[7] == "1" {
s.IsHidden = true
}
if s.IsCombinedAddressMode = false; items[8] == "1" {
s.IsCombinedAddressMode = true
}
case itemLengthImportExport:
s.Version = items[4]
if _, err = fmt.Sscan(items[5], &s.Timestamp); err != nil {
s.Timestamp = 0
}
}
return nil
}
func (s *Credentials) EmailList() []string {
return strings.Split(s.Emails, ";")
}
func (s *Credentials) SplitAPIToken() (string, string, error) {
split := strings.Split(s.APIToken, ":")
if len(split) != 2 {
return "", "", errors.New("malformed API token")
}
return split[0], split[1], nil
}

View File

@ -0,0 +1,67 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package credentials
import (
"encoding/base64"
"fmt"
"strings"
"testing"
"time"
r "github.com/stretchr/testify/require"
)
var wantCredentials = Credentials{
UserID: "1",
Name: "name",
Emails: "email1;email2",
APIToken: "token",
MailboxPassword: []byte("mailbox pass"),
BridgePassword: "bridge pass",
Version: "k11",
Timestamp: time.Now().Unix(),
IsHidden: false,
IsCombinedAddressMode: false,
}
func TestUnmarshallBridge(t *testing.T) {
encoded := wantCredentials.Marshal()
haveCredentials := Credentials{UserID: "1"}
r.NoError(t, haveCredentials.Unmarshal(encoded))
r.Equal(t, wantCredentials, haveCredentials)
}
func TestUnmarshallImportExport(t *testing.T) {
items := []string{
wantCredentials.Name,
wantCredentials.Emails,
wantCredentials.APIToken,
string(wantCredentials.MailboxPassword),
"k11",
fmt.Sprint(wantCredentials.Timestamp),
}
str := strings.Join(items, sep)
encoded := base64.StdEncoding.EncodeToString([]byte(str))
haveCredentials := Credentials{UserID: "1"}
haveCredentials.BridgePassword = wantCredentials.BridgePassword // This one is not used.
r.NoError(t, haveCredentials.Unmarshal(encoded))
r.Equal(t, wantCredentials, haveCredentials)
}

View File

@ -0,0 +1,118 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package credentials
import (
"errors"
"fmt"
"sort"
"sync"
)
var storeLocker = sync.RWMutex{} //nolint:gochecknoglobals
// Store is an encrypted credentials store.
type Store struct {
secrets Keychain
}
type Keychain interface {
List() ([]string, error)
Get(string) (string, string, error)
Put(string, string) error
Delete(string) error
}
// NewStore creates a new encrypted credentials store.
func NewStore(keychain Keychain) *Store {
return &Store{secrets: keychain}
}
// List returns a list of usernames that have credentials stored.
func (s *Store) List() (userIDs []string, err error) {
storeLocker.RLock()
defer storeLocker.RUnlock()
log.Trace("Listing credentials in credentials store")
var allUserIDs []string
if allUserIDs, err = s.secrets.List(); err != nil {
log.WithError(err).Error("Could not list credentials")
return
}
credentialList := []*Credentials{}
for _, userID := range allUserIDs {
creds, getErr := s.get(userID)
if getErr != nil {
log.WithField("userID", userID).WithError(getErr).Warn("Failed to get credentials")
continue
}
// Disabled credentials
if creds.Timestamp == 0 {
continue
}
credentialList = append(credentialList, creds)
}
sort.Slice(credentialList, func(i, j int) bool {
return credentialList[i].Timestamp < credentialList[j].Timestamp
})
for _, credentials := range credentialList {
userIDs = append(userIDs, credentials.UserID)
}
return userIDs, err
}
func (s *Store) Get(userID string) (creds *Credentials, err error) {
storeLocker.RLock()
defer storeLocker.RUnlock()
return s.get(userID)
}
func (s *Store) get(userID string) (*Credentials, error) {
log := log.WithField("user", userID)
_, secret, err := s.secrets.Get(userID)
if err != nil {
return nil, err
}
if secret == "" {
return nil, errors.New("secret is empty")
}
credentials := &Credentials{UserID: userID}
if err := credentials.Unmarshal(secret); err != nil {
log.WithError(fmt.Errorf("malformed secret: %w", err)).Error("Could not unmarshal secret")
if err := s.secrets.Delete(userID); err != nil {
log.WithError(err).Error("Failed to remove malformed secret")
}
return nil, err
}
return credentials, nil
}

View File

@ -0,0 +1,298 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package credentials
import (
"bytes"
"encoding/base64"
"encoding/gob"
"encoding/json"
"fmt"
"strings"
"testing"
r "github.com/stretchr/testify/require"
)
const (
testSep = "\n"
secretFormat = "%v" + testSep + // UserID,
"%v" + testSep + // Name,
"%v" + testSep + // Emails,
"%v" + testSep + // APIToken,
"%v" + testSep + // Mailbox,
"%v" + testSep + // BridgePassword,
"%v" + testSep + // Version string
"%v" + testSep + // Timestamp,
"%v" + testSep + // IsHidden,
"%v" // IsCombinedAddressMode
)
// the best would be to run this test on mac, win, and linux separately
type testCredentials struct {
UserID,
Name,
Emails,
APIToken,
Mailbox,
BridgePassword,
Version string
Timestamp int64
IsHidden,
IsCombinedAddressMode bool
}
func init() { //nolint:gochecknoinits
gob.Register(testCredentials{})
}
func (s *testCredentials) MarshalGob() string {
buf := bytes.Buffer{}
enc := gob.NewEncoder(&buf)
if err := enc.Encode(s); err != nil {
return ""
}
log.Infof("MarshalGob: %#v\n", buf.String())
return base64.StdEncoding.EncodeToString(buf.Bytes())
}
func (s *testCredentials) Clear() {
s.UserID = ""
s.Name = ""
s.Emails = ""
s.APIToken = ""
s.Mailbox = ""
s.BridgePassword = ""
s.Version = ""
s.Timestamp = 0
s.IsHidden = false
s.IsCombinedAddressMode = false
}
func (s *testCredentials) UnmarshalGob(secret string) error {
s.Clear()
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
log.Infoln("decode base64", b)
return err
}
buf := bytes.NewBuffer(b)
dec := gob.NewDecoder(buf)
if err = dec.Decode(s); err != nil {
log.Info("decode gob", b, buf.Bytes())
return err
}
return nil
}
func (s *testCredentials) ToJSON() string {
if b, err := json.Marshal(s); err == nil {
log.Infof("MarshalJSON: %#v\n", string(b))
return base64.StdEncoding.EncodeToString(b)
}
return ""
}
func (s *testCredentials) FromJSON(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
if err = json.Unmarshal(b, s); err == nil {
return nil
}
return err
}
func (s *testCredentials) MarshalFmt() string {
buf := bytes.Buffer{}
fmt.Fprintf(
&buf, secretFormat,
s.UserID,
s.Name,
s.Emails,
s.APIToken,
s.Mailbox,
s.BridgePassword,
s.Version,
s.Timestamp,
s.IsHidden,
s.IsCombinedAddressMode,
)
log.Infof("MarshalFmt: %#v\n", buf.String())
return base64.StdEncoding.EncodeToString(buf.Bytes())
}
func (s *testCredentials) UnmarshalFmt(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
buf := bytes.NewBuffer(b)
log.Infoln("decode fmt", b, buf.Bytes())
_, err = fmt.Fscanf(
buf, secretFormat,
&s.UserID,
&s.Name,
&s.Emails,
&s.APIToken,
&s.Mailbox,
&s.BridgePassword,
&s.Version,
&s.Timestamp,
&s.IsHidden,
&s.IsCombinedAddressMode,
)
if err != nil {
return err
}
return nil
}
func (s *testCredentials) MarshalStrings() string { // this is the most space efficient
items := []string{
s.UserID, // 0
s.Name, // 1
s.Emails, // 2
s.APIToken, // 3
s.Mailbox, // 4
s.BridgePassword, // 5
s.Version, // 6
}
items = append(items, fmt.Sprint(s.Timestamp)) // 7
if s.IsHidden { // 8
items = append(items, "1")
} else {
items = append(items, "")
}
if s.IsCombinedAddressMode { // 9
items = append(items, "1")
} else {
items = append(items, "")
}
str := strings.Join(items, sep)
log.Infof("MarshalJoin: %#v\n", str)
return base64.StdEncoding.EncodeToString([]byte(str))
}
func (s *testCredentials) UnmarshalStrings(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
items := strings.Split(string(b), sep)
if len(items) != 10 {
return ErrWrongFormat
}
s.UserID = items[0]
s.Name = items[1]
s.Emails = items[2]
s.APIToken = items[3]
s.Mailbox = items[4]
s.BridgePassword = items[5]
s.Version = items[6]
if _, err = fmt.Sscanf(items[7], "%d", &s.Timestamp); err != nil {
s.Timestamp = 0
}
if s.IsHidden = false; items[8] == "1" {
s.IsHidden = true
}
if s.IsCombinedAddressMode = false; items[9] == "1" {
s.IsCombinedAddressMode = true
}
return nil
}
func (s *testCredentials) IsSame(rhs *testCredentials) bool {
return s.Name == rhs.Name &&
s.Emails == rhs.Emails &&
s.APIToken == rhs.APIToken &&
s.Mailbox == rhs.Mailbox &&
s.BridgePassword == rhs.BridgePassword &&
s.Version == rhs.Version &&
s.Timestamp == rhs.Timestamp &&
s.IsHidden == rhs.IsHidden &&
s.IsCombinedAddressMode == rhs.IsCombinedAddressMode
}
func TestMarshalFormats(t *testing.T) {
input := testCredentials{UserID: "007", Emails: "ja@pm.me;jakub@cu.th", Timestamp: 152469263742, IsHidden: true}
log.Infof("input %#v\n", input)
secretStrings := input.MarshalStrings()
log.Infof("secretStrings %#v %d\n", secretStrings, len(secretStrings))
secretGob := input.MarshalGob()
log.Infof("secretGob %#v %d\n", secretGob, len(secretGob))
secretJSON := input.ToJSON()
log.Infof("secretJSON %#v %d\n", secretJSON, len(secretJSON))
secretFmt := input.MarshalFmt()
log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt))
output := testCredentials{APIToken: "refresh"}
r.NoError(t, output.UnmarshalStrings(secretStrings))
log.Infof("strings out %#v \n", output)
r.True(t, input.IsSame(&output), "strings out not same")
output = testCredentials{APIToken: "refresh"}
r.NoError(t, output.UnmarshalGob(secretGob))
log.Infof("gob out %#v\n \n", output)
r.Equal(t, input, output)
output = testCredentials{APIToken: "refresh"}
r.NoError(t, output.FromJSON(secretJSON))
log.Infof("json out %#v \n", output)
r.True(t, input.IsSame(&output), "json out not same")
/*
// Simple Fscanf not working!
output = testCredentials{APIToken: "refresh"}
r.NoError(t, output.UnmarshalFmt(secretFmt))
log.Infof("fmt out %#v \n", output)
r.True(t, input.IsSame(&output), "fmt out not same")
*/
}
func TestMarshal(t *testing.T) {
input := Credentials{
UserID: "",
Name: "007",
Emails: "ja@pm.me;aj@cus.tom",
APIToken: "sdfdsfsdfsdfsdf",
MailboxPassword: []byte("cdcdcdcd"),
BridgePassword: "wew123",
Version: "k11",
Timestamp: 152469263742,
IsHidden: true,
IsCombinedAddressMode: false,
}
log.Infof("input %#v\n", input)
secret := input.Marshal()
log.Infof("secret %#v %d\n", secret, len(secret))
output := Credentials{APIToken: "refresh"}
r.NoError(t, output.Unmarshal(secret))
log.Infof("output %#v\n", output)
r.Equal(t, input, output)
}

View File

@ -30,8 +30,6 @@ type Keychain struct {
}
func GetHelper(vaultDir string) (string, error) {
var keychain Keychain
filePath := filepath.Clean(filepath.Join(vaultDir, "keychain.json"))
if _, err := os.Stat(filePath); errors.Is(err, fs.ErrNotExist) {
@ -43,6 +41,8 @@ func GetHelper(vaultDir string) (string, error) {
return "", err
}
var keychain Keychain
if err := json.Unmarshal(b, &keychain); err != nil {
return "", err
}
@ -56,7 +56,5 @@ func SetHelper(vaultDir, helper string) error {
return err
}
filePath := filepath.Clean(filepath.Join(vaultDir, "keychain.json"))
return os.WriteFile(filePath, b, 0o600)
return os.WriteFile(filepath.Clean(filepath.Join(vaultDir, "keychain.json")), b, 0o600)
}

View File

@ -22,6 +22,7 @@ type Data struct {
Users []UserData
Cookies []byte
Certs Certs
Migrated bool
}
func newDefaultData(gluonDir string) Data {

View File

@ -175,6 +175,16 @@ func (vault *Vault) DeleteUser(userID string) error {
})
}
func (vault *Vault) Migrated() bool {
return vault.get().Migrated
}
func (vault *Vault) SetMigrated() error {
return vault.mod(func(data *Data) {
data.Migrated = true
})
}
func (vault *Vault) Close() error {
vault.refLock.Lock()
defer vault.refLock.Unlock()