diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go index 8cee6118..ea10f11e 100644 --- a/internal/config/tls/tls.go +++ b/internal/config/tls/tls.go @@ -69,6 +69,30 @@ func NewTLSTemplate() (*x509.Certificate, error) { }, nil } +// NewPEMKeyPair return a new TLS private key and certificate in PEM encoded format. +func NewPEMKeyPair() (pemCert, pemKey []byte, err error) { + template, err := NewTLSTemplate() + if err != nil { + return nil, nil, errors.Wrap(err, "failed to generate TLS template") + } + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to generate private key") + } + + pemKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to create certificate") + } + + pemCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + return pemCert, pemKey, nil +} + var ErrTLSCertExpiresSoon = fmt.Errorf("TLS certificate will expire soon") // getTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP). @@ -132,6 +156,21 @@ func (t *TLS) GetConfig() (*tls.Config, error) { return nil, errors.Wrap(err, "failed to load keypair") } + return getConfigFromKeyPair(c) +} + +// GetConfigFromPEMKeyPair load a TLS config from PEM encoded certificate and key. +func GetConfigFromPEMKeyPair(permCert, pemKey []byte) (*tls.Config, error) { + c, err := tls.X509KeyPair(permCert, pemKey) + if err != nil { + return nil, errors.Wrap(err, "failed to load keypair") + } + + return getConfigFromKeyPair(c) +} + +func getConfigFromKeyPair(c tls.Certificate) (*tls.Config, error) { + var err error c.Leaf, err = x509.ParseCertificate(c.Certificate[0]) if err != nil { return nil, errors.Wrap(err, "failed to parse certificate") diff --git a/internal/config/tls/tls_test.go b/internal/config/tls/tls_test.go index 0248995c..866c9682 100644 --- a/internal/config/tls/tls_test.go +++ b/internal/config/tls/tls_test.go @@ -75,3 +75,11 @@ func TestGetValidConfig(t *testing.T) { now, notValidAfter := time.Now(), config.Certificates[0].Leaf.NotAfter require.False(t, now.After(notValidAfter), "new certificate expected to be valid at %v but have valid until %v", now, notValidAfter) } + +func TestNewConfig(t *testing.T) { + pemCert, pemKey, err := NewPEMKeyPair() + require.NoError(t, err) + + _, err = GetConfigFromPEMKeyPair(pemCert, pemKey) + require.NoError(t, err) +} diff --git a/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.cpp b/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.cpp index ec592db3..4693aa2b 100644 --- a/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.cpp +++ b/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.cpp @@ -37,7 +37,6 @@ Empty empty; // re-used across client calls. int const maxConnectionTimeSecs = 60; ///< Amount of time after which we consider connection attempts to the server have failed. -int const maxCertificateWaitMsecs = 60 * 1000; ///< Amount of time we wait for he server to generate the certificate. } @@ -91,60 +90,6 @@ GRPCConfig GRPCClient::waitAndRetrieveServiceConfig(qint64 timeoutMs) } -//**************************************************************************************************************************************************** -/// \brief wait for certificate generation by Bridge -/// \return server certificate generated by Bridge -//**************************************************************************************************************************************************** -std::string GRPCClient::getServerCertificate() -{ - QString const certPath = serverCertificatePath(); - QString const certFolder = QFileInfo(certPath).absolutePath(); - QFile file(certPath); - // TODO : the certificate can exist but still be invalid. - // If the certificate is close to its limit, the bridge will generate a new one. - // If we read the certificate before the bridge rewrites it the certificate will be invalid. - if (!file.exists()) - { - // wait for file creation - QFileSystemWatcher watcher(this); - if (!watcher.addPath(certFolder)) - throw Exception("Failed to watch User Config Directory"); - connect(&watcher, &QFileSystemWatcher::directoryChanged, this, &GRPCClient::configFolderChanged); - - // set up an eventLoop to wait for the certIsReady signal or timeout. - QTimer timer; - timer.setSingleShot(true); - QEventLoop loop; - connect(this, &GRPCClient::certIsReady, &loop, &QEventLoop::quit); - connect(&timer, &QTimer::timeout, &loop, &QEventLoop::quit); - timer.start(maxCertificateWaitMsecs); - loop.exec(); - - // timeout case. - if (!timer.isActive()) - throw Exception("Server failed to generate certificate on time"); - //else certIsReadySignal. - } - - if (!file.open(QFile::ReadOnly)) - throw Exception("Failed to read the server certificate"); - QByteArray qbaCert = file.readAll(); - std::string cert(qbaCert.constData(), qbaCert.length()); - file.close(); - return cert; -} - - -//**************************************************************************************************************************************************** -/// \brief Action on UserConfig directory changes, looking for the certificate creation -//**************************************************************************************************************************************************** -void GRPCClient::configFolderChanged() -{ - if (QFileInfo::exists(serverCertificatePath())) - emit certIsReady(); -} - - //**************************************************************************************************************************************************** /// \param[in] log The log //**************************************************************************************************************************************************** @@ -163,7 +108,7 @@ bool GRPCClient::connectToServer(GRPCConfig const &config, QString &outError) try { SslCredentialsOptions opts; - opts.pem_root_certs += this->getServerCertificate(); + opts.pem_root_certs += config.cert.toStdString(); QString const address = QString("127.0.0.1:%1").arg(config.port); channel_ = CreateChannel(address.toStdString(), grpc::SslCredentials(opts)); diff --git a/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.h b/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.h index ca67e8a1..d60f8055 100644 --- a/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.h +++ b/internal/frontend/bridge-gui/bridgepp/bridgepp/GRPC/GRPCClient.h @@ -204,9 +204,6 @@ public: grpc::Status runEventStreamReader(); ///< Retrieve and signal the events in the event stream. grpc::Status stopEventStreamReader(); ///< Stop the event stream. -private slots: - void configFolderChanged(); - private: void logTrace(QString const &message); ///< Log an event. void logError(QString const &message); ///< Log an event. @@ -223,8 +220,6 @@ private: grpc::Status methodWithStringParam(StringParamMethod method, QString const &str); ///< Perform a gRPC call that takes a string as a parameter and returns an Empty. SPUser parseGRPCUser(grpc::User const &grpcUser); ///< Parse a gRPC user struct and return a User. - - std::string getServerCertificate(); ///< Wait until server certificates is generated and retrieve it. void processAppEvent(grpc::AppEvent const &event); ///< Process an 'App' event. void processLoginEvent(grpc::LoginEvent const &event); ///< Process a 'Login' event. void processUpdateEvent(grpc::UpdateEvent const &event); ///< Process an 'Update' event. diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index 6438b02c..a6945ba2 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -31,6 +31,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings" + "github.com/ProtonMail/proton-bridge/v2/internal/config/tls" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/types" "github.com/ProtonMail/proton-bridge/v2/internal/locations" @@ -76,6 +77,7 @@ type Service struct { // nolint:structcheck initializationDone sync.Once firstTimeAutostart sync.Once locations *locations.Locations + pemCert string } // NewService returns a new instance of the service. @@ -108,23 +110,21 @@ func NewService( // set to 1 s.initializing.Add(1) - config, err := bridge.GetTLSConfig() - config.ClientAuth = cryptotls.NoClientCert // skip client auth if the certificate allow it. + tlsConfig, pemCert, err := s.generateTLSConfig() if err != nil { - s.log.WithError(err).Error("could not get TLS config") - panic(err) + s.log.WithError(err).Panic("could not generate gRPC TLS config") } - s.initAutostart() + s.pemCert = string(pemCert) - s.grpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(config))) + s.initAutostart() + s.grpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) RegisterBridgeServer(s.grpcServer, &s) s.listener, err = net.Listen("tcp", "127.0.0.1:0") // Port 0 means that the port is randomly picked by the system. - if err != nil { - s.log.WithError(err).Panic("could not create listener") + s.log.WithError(err).Panic("could not create gRPC listener") } if err := s.saveGRPCServerConfigFile(); err != nil { @@ -404,13 +404,32 @@ func (s *Service) installUpdate() { _ = s.SendEvent(NewUpdateSilentRestartNeededEvent()) } +func (s *Service) generateTLSConfig() (tlsConfig *cryptotls.Config, pemCert []byte, err error) { + pemCert, pemKey, err := tls.NewPEMKeyPair() + if err != nil { + return nil, nil, errors.New("Could not get TLS config") + } + + tlsConfig, err = tls.GetConfigFromPEMKeyPair(pemCert, pemKey) + if err != nil { + return nil, nil, errors.New("Could not get TLS config") + } + + tlsConfig.ClientAuth = cryptotls.NoClientCert // skip client auth if the certificate allow it. + + return tlsConfig, pemCert, nil +} + func (s *Service) saveGRPCServerConfigFile() error { address, ok := s.listener.Addr().(*net.TCPAddr) if !ok { return fmt.Errorf("could not retrieve gRPC service listener address") } - sc := config{Port: address.Port} + sc := config{ + Port: address.Port, + Cert: s.pemCert, + } settingsPath, err := s.locations.ProvideSettingsPath() if err != nil {