From ab4776c332906cb13dc61776d7fca0f58f88198d Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Wed, 20 Jan 2021 13:11:06 +0100 Subject: [PATCH] refactor: tidy up tls cert stuff --- internal/app/bridge/bridge.go | 48 ++++++- internal/config/tls/cert_store_darwin.go | 53 +++++++ internal/config/tls/cert_store_linux.go | 26 ++++ internal/config/tls/cert_store_windows.go | 26 ++++ internal/config/tls/tls.go | 166 +++++++++------------- internal/config/tls/tls_test.go | 63 ++++---- 6 files changed, 257 insertions(+), 125 deletions(-) create mode 100644 internal/config/tls/cert_store_darwin.go create mode 100644 internal/config/tls/cert_store_linux.go create mode 100644 internal/config/tls/cert_store_windows.go diff --git a/internal/app/bridge/bridge.go b/internal/app/bridge/bridge.go index d155354d..03982801 100644 --- a/internal/app/bridge/bridge.go +++ b/internal/app/bridge/bridge.go @@ -19,12 +19,14 @@ package bridge import ( + "crypto/tls" "time" "github.com/ProtonMail/proton-bridge/internal/api" "github.com/ProtonMail/proton-bridge/internal/app/base" "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/config/settings" + pkgTLS "github.com/ProtonMail/proton-bridge/internal/config/tls" "github.com/ProtonMail/proton-bridge/internal/constants" "github.com/ProtonMail/proton-bridge/internal/frontend" "github.com/ProtonMail/proton-bridge/internal/frontend/types" @@ -58,9 +60,9 @@ func New(base *base.Base) *cli.App { } func run(b *base.Base, c *cli.Context) error { // nolint[funlen] - tls, err := b.TLS.GetConfig() + tlsConfig, err := loadTLSConfig(b) if err != nil { - logrus.WithError(err).Fatal("Failed to create TLS config") + logrus.WithError(err).Fatal("Failed to load TLS config") } bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.CrashHandler, b.Listener, b.CM, b.Creds) @@ -78,7 +80,7 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen] imap.NewIMAPServer( c.String("log-imap") == "client" || c.String("log-imap") == "all", c.String("log-imap") == "server" || c.String("log-imap") == "all", - imapPort, tls, imapBackend, b.Listener).ListenAndServe() + imapPort, tlsConfig, imapBackend, b.Listener).ListenAndServe() }() go func() { @@ -87,7 +89,7 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen] useSSL := b.Settings.GetBool(settings.SMTPSSLKey) smtp.NewSMTPServer( c.Bool("log-smtp"), - smtpPort, useSSL, tls, smtpBackend, b.Listener).ListenAndServe() + smtpPort, useSSL, tlsConfig, smtpBackend, b.Listener).ListenAndServe() }() // Bridge supports no-window option which we should use for autostart. @@ -140,6 +142,44 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen] return f.Loop() } +func loadTLSConfig(b *base.Base) (*tls.Config, error) { + if !b.TLS.HasCerts() { + if err := generateTLSCerts(b); err != nil { + return nil, err + } + } + + tlsConfig, err := b.TLS.GetConfig() + if err == nil { + return tlsConfig, nil + } + + logrus.WithError(err).Error("Failed to load TLS config, regenerating certificates") + + if err := generateTLSCerts(b); err != nil { + return nil, err + } + + return b.TLS.GetConfig() +} + +func generateTLSCerts(b *base.Base) error { + template, err := pkgTLS.NewTLSTemplate() + if err != nil { + return errors.Wrap(err, "failed to generate TLS template") + } + + if err := b.TLS.GenerateCerts(template); err != nil { + return errors.Wrap(err, "failed to generate TLS certs") + } + + if err := b.TLS.InstallCerts(); err != nil { + return errors.Wrap(err, "failed to install TLS certs") + } + + return nil +} + func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) { version, err := u.Check() if err != nil { diff --git a/internal/config/tls/cert_store_darwin.go b/internal/config/tls/cert_store_darwin.go new file mode 100644 index 00000000..ee14a42c --- /dev/null +++ b/internal/config/tls/cert_store_darwin.go @@ -0,0 +1,53 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package tls + +import "os/exec" + +func addTrustedCert(certPath string) error { + return exec.Command( // nolint[gosec] + "/usr/bin/security", + "execute-with-privileges", + "/usr/bin/security", + "add-trusted-cert", + "-d", + "-r", "trustRoot", + "-p", "ssl", + "-k", "/Library/Keychains/System.keychain", + certPath, + ).Run() +} + +func removeTrustedCert(certPath string) error { + return exec.Command( // nolint[gosec] + "/usr/bin/security", + "execute-with-privileges", + "/usr/bin/security", + "remove-trusted-cert", + "-d", + certPath, + ).Run() +} + +func (t *TLS) InstallCerts() error { + return addTrustedCert(t.getTLSCertPath()) +} + +func (t *TLS) UninstallCerts() error { + return removeTrustedCert(t.getTLSCertPath()) +} diff --git a/internal/config/tls/cert_store_linux.go b/internal/config/tls/cert_store_linux.go new file mode 100644 index 00000000..01a138b8 --- /dev/null +++ b/internal/config/tls/cert_store_linux.go @@ -0,0 +1,26 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package tls + +func (t *TLS) InstallCerts() error { + return nil // Linux doesn't have a root cert store. +} + +func (t *TLS) UninstallCerts() error { + return nil // Linux doesn't have a root cert store. +} diff --git a/internal/config/tls/cert_store_windows.go b/internal/config/tls/cert_store_windows.go new file mode 100644 index 00000000..0fed5159 --- /dev/null +++ b/internal/config/tls/cert_store_windows.go @@ -0,0 +1,26 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package tls + +func (t *TLS) InstallCerts() error { + return nil // NOTE(GODT-986): Install certs to root cert store? +} + +func (t *TLS) UninstallCerts() error { + return nil // NOTE(GODT-986): Uninstall certs from root cert store? +} diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go index 0b9e27e9..0ce76cd4 100644 --- a/internal/config/tls/tls.go +++ b/internal/config/tls/tls.go @@ -28,12 +28,10 @@ import ( "math/big" "net" "os" - "os/exec" "path/filepath" - "runtime" "time" - "github.com/sirupsen/logrus" + "github.com/pkg/errors" ) type TLS struct { @@ -46,24 +44,32 @@ func New(settingsPath string) *TLS { } } -var tlsTemplate = x509.Certificate{ //nolint[gochecknoglobals] - SerialNumber: big.NewInt(-1), - Subject: pkix.Name{ - Country: []string{"CH"}, - Organization: []string{"Proton Technologies AG"}, - OrganizationalUnit: []string{"ProtonMail"}, - CommonName: "127.0.0.1", - }, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - BasicConstraintsValid: true, - IsCA: true, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour), +// NewTLSTemplate creates a new TLS template certificate with a random serial number. +func NewTLSTemplate() (*x509.Certificate, error) { + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, errors.Wrap(err, "failed to generate serial number") + } + + return &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Country: []string{"CH"}, + Organization: []string{"Proton Technologies AG"}, + OrganizationalUnit: []string{"ProtonMail"}, + CommonName: "127.0.0.1", + }, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour), + }, nil } -var ErrTLSCertExpireSoon = fmt.Errorf("TLS certificate will expire soon") +var ErrTLSCertExpiresSoon = fmt.Errorf("TLS certificate will expire soon") // getTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP). func (t *TLS) getTLSCertPath() string { @@ -75,110 +81,78 @@ func (t *TLS) getTLSKeyPath() string { return filepath.Join(t.settingsPath, "key.pem") } -// GenerateConfig generates certs and keys at the given filepaths and returns a TLS Config which holds them. -// See https://golang.org/src/crypto/tls/generate_cert.go -func (t *TLS) GenerateConfig() (tlsConfig *tls.Config, err error) { +// HasCerts returns whether TLS certs have been generated. +func (t *TLS) HasCerts() bool { + if _, err := os.Stat(t.getTLSCertPath()); err != nil { + return false + } + + if _, err := os.Stat(t.getTLSKeyPath()); err != nil { + return false + } + + return true +} + +// GenerateCerts generates certs from the given template. +func (t *TLS) GenerateCerts(template *x509.Certificate) error { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - err = fmt.Errorf("failed to generate private key: %s", err) - return + return errors.Wrap(err, "failed to generate private key") } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) if err != nil { - err = fmt.Errorf("failed to generate serial number: %s", err) - return - } - - tlsTemplate.SerialNumber = serialNumber - derBytes, err := x509.CreateCertificate(rand.Reader, &tlsTemplate, &tlsTemplate, &priv.PublicKey, priv) - if err != nil { - err = fmt.Errorf("failed to create certificate: %s", err) - return + return errors.Wrap(err, "failed to create certificate") } certOut, err := os.Create(t.getTLSCertPath()) if err != nil { - return + return err } - defer certOut.Close() //nolint[errcheck] - err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - if err != nil { - return + defer certOut.Close() // nolint[errcheck] + + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + return err } keyOut, err := os.OpenFile(t.getTLSKeyPath(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { - return + return err } - defer keyOut.Close() //nolint[errcheck] - err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) - if err != nil { - return + defer keyOut.Close() // nolint[errcheck] + + if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil { + return err } - return loadTLSConfig(t.getTLSCertPath(), t.getTLSKeyPath()) + return nil } // GetConfig tries to load TLS config or generate new one which is then returned. -func (t *TLS) GetConfig() (tlsConfig *tls.Config, err error) { - certPath := t.getTLSCertPath() - keyPath := t.getTLSKeyPath() - tlsConfig, err = loadTLSConfig(certPath, keyPath) +func (t *TLS) GetConfig() (*tls.Config, error) { + c, err := tls.LoadX509KeyPair(t.getTLSCertPath(), t.getTLSKeyPath()) if err != nil { - logrus.WithError(err).Warn("Cannot load cert, generating a new one") - tlsConfig, err = t.GenerateConfig() - if err != nil { - return - } - - if runtime.GOOS == "darwin" { - if err := exec.Command( // nolint[gosec] - "/usr/bin/security", - "execute-with-privileges", - "/usr/bin/security", - "add-trusted-cert", - "-d", - "-r", "trustRoot", - "-p", "ssl", - "-k", "/Library/Keychains/System.keychain", - certPath, - ).Run(); err != nil { - logrus.WithError(err).Error("Failed to add cert to system keychain") - } - } - } - - tlsConfig.ServerName = "127.0.0.1" - tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven - - caCertPool := x509.NewCertPool() - caCertPool.AddCert(tlsConfig.Certificates[0].Leaf) - tlsConfig.RootCAs = caCertPool - tlsConfig.ClientCAs = caCertPool - - return tlsConfig, err -} - -func loadTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) { - c, err := tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return + return nil, errors.Wrap(err, "failed to load keypair") } c.Leaf, err = x509.ParseCertificate(c.Certificate[0]) if err != nil { - return - } - - tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{c}, + return nil, errors.Wrap(err, "failed to parse certificate") } if time.Now().Add(31 * 24 * time.Hour).After(c.Leaf.NotAfter) { - err = ErrTLSCertExpireSoon - return + return nil, ErrTLSCertExpiresSoon } - return + + caCertPool := x509.NewCertPool() + caCertPool.AddCert(c.Leaf) + + return &tls.Config{ + Certificates: []tls.Certificate{c}, + ServerName: "127.0.0.1", + ClientAuth: tls.VerifyClientCertIfGiven, + RootCAs: caCertPool, + ClientCAs: caCertPool, + }, nil } diff --git a/internal/config/tls/tls_test.go b/internal/config/tls/tls_test.go index 292682d0..5c41a463 100644 --- a/internal/config/tls/tls_test.go +++ b/internal/config/tls/tls_test.go @@ -19,46 +19,59 @@ package tls import ( "io/ioutil" - "os" - "path/filepath" - "runtime" "testing" "time" "github.com/stretchr/testify/require" ) -func TestTLSKeyRenewal(t *testing.T) { - // Remove keys. - configPath := "/tmp" - certPath := filepath.Join(configPath, "cert.pem") - keyPath := filepath.Join(configPath, "key.pem") - _ = os.Remove(certPath) - _ = os.Remove(keyPath) - +func TestGetOldConfig(t *testing.T) { dir, err := ioutil.TempDir("", "test-tls") require.NoError(t, err) + // Create new tls object. tls := New(dir) - // Put old key there. + // Create new TLS template. + tlsTemplate, err := NewTLSTemplate() + require.NoError(t, err) + + // Make the template be an old key. tlsTemplate.NotBefore = time.Now().Add(-365 * 24 * time.Hour) tlsTemplate.NotAfter = time.Now() - cert, err := tls.GenerateConfig() - require.Equal(t, err, ErrTLSCertExpireSoon) - require.Equal(t, len(cert.Certificates), 1) - time.Sleep(time.Second) - now, notValidAfter := time.Now(), cert.Certificates[0].Leaf.NotAfter - require.True(t, now.After(notValidAfter), "old certificate expected to not be valid at %v but have valid until %v", now, notValidAfter) - // Renew key. + // Generate the certs from the template. + require.NoError(t, tls.GenerateCerts(tlsTemplate)) + + // Generate the config from the certs -- it's going to expire soon so we don't want to use it. + _, err = tls.GetConfig() + require.Equal(t, err, ErrTLSCertExpiresSoon) +} + +func TestGetValidConfig(t *testing.T) { + dir, err := ioutil.TempDir("", "test-tls") + require.NoError(t, err) + + // Create new tls object. + tls := New(dir) + + // Create new TLS template. + tlsTemplate, err := NewTLSTemplate() + require.NoError(t, err) + + // Make the template be a new key. tlsTemplate.NotBefore = time.Now() tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour) - cert, err = tls.GetConfig() - if runtime.GOOS != "darwin" { // Darwin is not supported. - require.NoError(t, err) - } - require.Equal(t, len(cert.Certificates), 1) - now, notValidAfter = time.Now(), cert.Certificates[0].Leaf.NotAfter + + // Generate the certs from the template. + require.NoError(t, tls.GenerateCerts(tlsTemplate)) + + // Generate the config from the certs -- it's not going to expire soon so we want to use it. + config, err := tls.GetConfig() + require.NoError(t, err) + require.Equal(t, len(config.Certificates), 1) + + // Check the cert is valid. + now, notValidAfter := time.Now(), config.Certificates[0].Leaf.NotAfter require.False(t, now.After(notValidAfter), "new certificate expected to be valid at %v but have valid until %v", now, notValidAfter) }