From 54c013012e2222b2e09eaafb003a3eaffcd4171f Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Fri, 17 Feb 2023 14:02:39 +0100 Subject: [PATCH] feat(GODT-2374): Import TLS certs via shell --- internal/app/vault.go | 4 ++- internal/bridge/bridge.go | 2 +- internal/bridge/tls.go | 6 +++- internal/frontend/cli/frontend.go | 12 ++++--- internal/frontend/cli/system.go | 30 +++++++++++++++++ internal/vault/certs.go | 54 ++++++++++++++++++++++++++++--- internal/vault/certs_test.go | 5 +-- internal/vault/types_certs.go | 4 +++ 8 files changed, 104 insertions(+), 13 deletions(-) diff --git a/internal/app/vault.go b/internal/app/vault.go index d49be879..895e320d 100644 --- a/internal/app/vault.go +++ b/internal/app/vault.go @@ -51,7 +51,9 @@ func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) if installed := encVault.GetCertsInstalled(); !installed { logrus.Debug("Installing certificates") - if err := certs.NewInstaller().InstallCert(encVault.GetBridgeTLSCert()); err != nil { + certPEM, _ := encVault.GetBridgeTLSCert() + + if err := certs.NewInstaller().InstallCert(certPEM); err != nil { return fmt.Errorf("failed to install certs: %w", err) } diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index e08d4308..27cde2a4 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -556,7 +556,7 @@ func (bridge *Bridge) onStatusDown(ctx context.Context) { } func loadTLSConfig(vault *vault.Vault) (*tls.Config, error) { - cert, err := tls.X509KeyPair(vault.GetBridgeTLSCert(), vault.GetBridgeTLSKey()) + cert, err := tls.X509KeyPair(vault.GetBridgeTLSCert()) if err != nil { return nil, err } diff --git a/internal/bridge/tls.go b/internal/bridge/tls.go index b62e3470..c5e50fce 100644 --- a/internal/bridge/tls.go +++ b/internal/bridge/tls.go @@ -18,5 +18,9 @@ package bridge func (bridge *Bridge) GetBridgeTLSCert() ([]byte, []byte) { - return bridge.vault.GetBridgeTLSCert(), bridge.vault.GetBridgeTLSKey() + return bridge.vault.GetBridgeTLSCert() +} + +func (bridge *Bridge) SetBridgeTLSCertPath(certPath, keyPath string) error { + return bridge.vault.SetBridgeTLSCertPath(certPath, keyPath) } diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index e22856ca..07719c4f 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -135,12 +135,16 @@ func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan e fe.AddCmd(configureCmd) // TLS commands. - exportTLSCmd := &ishell.Cmd{ - Name: "export-tls", + fe.AddCmd(&ishell.Cmd{ + Name: "export-tls-cert", Help: "Export the TLS certificate used by the Bridge", Func: fe.exportTLSCerts, - } - fe.AddCmd(exportTLSCmd) + }) + fe.AddCmd(&ishell.Cmd{ + Name: "import-tls-cert", + Help: "Import a TLS certificate to be used by the Bridge", + Func: fe.importTLSCerts, + }) // All mail visibility commands. allMailCmd := &ishell.Cmd{ diff --git a/internal/frontend/cli/system.go b/internal/frontend/cli/system.go index c92df83f..dc1b12ca 100644 --- a/internal/frontend/cli/system.go +++ b/internal/frontend/cli/system.go @@ -226,6 +226,27 @@ func (f *frontendCLI) exportTLSCerts(c *ishell.Context) { } } +func (f *frontendCLI) importTLSCerts(c *ishell.Context) { + certPath := f.readStringInAttempts("Enter the path to the cert.pem file", c.ReadLine, f.isFile) + if certPath == "" { + f.printAndLogError(errors.New("failed to get cert path")) + return + } + + keyPath := f.readStringInAttempts("Enter the path to the key.pem file", c.ReadLine, f.isFile) + if keyPath == "" { + f.printAndLogError(errors.New("failed to get key path")) + return + } + + if err := f.bridge.SetBridgeTLSCertPath(certPath, keyPath); err != nil { + f.printAndLogError(err) + return + } + + f.Println("TLS certificate imported. Restart Bridge to use it.") +} + func (f *frontendCLI) isPortFree(port string) bool { port = strings.ReplaceAll(port, ":", "") if port == "" { @@ -252,3 +273,12 @@ func (f *frontendCLI) isCacheLocationUsable(location string) bool { return stat.IsDir() } + +func (f *frontendCLI) isFile(location string) bool { + stat, err := os.Stat(location) + if err != nil { + return false + } + + return !stat.IsDir() +} diff --git a/internal/vault/certs.go b/internal/vault/certs.go index 77086f9b..75294293 100644 --- a/internal/vault/certs.go +++ b/internal/vault/certs.go @@ -17,12 +17,40 @@ package vault -func (vault *Vault) GetBridgeTLSCert() []byte { - return vault.get().Certs.Bridge.Cert +import ( + "crypto/tls" + "fmt" + "os" + "path/filepath" + + "github.com/sirupsen/logrus" +) + +// GetBridgeTLSCert returns the PEM-encoded certificate for the bridge. +// If CertPEMPath is set, it will attempt to read the certificate from the file. +// Otherwise, or on read/validation failure, it will return the certificate from the vault. +func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) { + if certPath, keyPath := vault.get().Certs.CustomCertPath, vault.get().Certs.CustomKeyPath; certPath != "" && keyPath != "" { + if certPEM, keyPEM, err := readPEMCert(certPath, keyPath); err == nil { + return certPEM, keyPEM + } + + logrus.Error("Failed to read certificate from file, using default") + } + + return vault.get().Certs.Bridge.Cert, vault.get().Certs.Bridge.Key } -func (vault *Vault) GetBridgeTLSKey() []byte { - return vault.get().Certs.Bridge.Key +// SetBridgeTLSCertPath sets the path to PEM-encoded certificates for the bridge. +func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error { + if _, _, err := readPEMCert(certPath, keyPath); err != nil { + return fmt.Errorf("invalid certificate: %w", err) + } + + return vault.mod(func(data *Data) { + data.Certs.CustomCertPath = certPath + data.Certs.CustomKeyPath = keyPath + }) } func (vault *Vault) GetCertsInstalled() bool { @@ -34,3 +62,21 @@ func (vault *Vault) SetCertsInstalled(installed bool) error { data.Certs.Installed = installed }) } + +func readPEMCert(certPEMPath, keyPEMPath string) ([]byte, []byte, error) { + certPEM, err := os.ReadFile(filepath.Clean(certPEMPath)) + if err != nil { + return nil, nil, err + } + + keyPEM, err := os.ReadFile(filepath.Clean(keyPEMPath)) + if err != nil { + return nil, nil, err + } + + if _, err := tls.X509KeyPair(certPEM, keyPEM); err != nil { + return nil, nil, err + } + + return certPEM, keyPEM, nil +} diff --git a/internal/vault/certs_test.go b/internal/vault/certs_test.go index 243b575f..0a3d7fde 100644 --- a/internal/vault/certs_test.go +++ b/internal/vault/certs_test.go @@ -28,8 +28,9 @@ func TestVault_TLSCerts(t *testing.T) { s := newVault(t) // Check the default bridge TLS certs. - require.NotEmpty(t, s.GetBridgeTLSCert()) - require.NotEmpty(t, s.GetBridgeTLSKey()) + cert, key := s.GetBridgeTLSCert() + require.NotEmpty(t, cert) + require.NotEmpty(t, key) // Check the certificates are not installed. require.False(t, s.GetCertsInstalled()) diff --git a/internal/vault/types_certs.go b/internal/vault/types_certs.go index b478199d..195a43a1 100644 --- a/internal/vault/types_certs.go +++ b/internal/vault/types_certs.go @@ -22,6 +22,10 @@ import "github.com/ProtonMail/proton-bridge/v3/internal/certs" type Certs struct { Bridge Cert Installed bool + + // If non-empty, the path to the PEM-encoded certificate file. + CustomCertPath string + CustomKeyPath string } type Cert struct {