From 0fd5ca3a2492f4623994762d896547cc93b76f94 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Thu, 23 Apr 2020 14:41:59 +0200 Subject: [PATCH] feat: dialer refactor to support modular dialing/checking/proxying --- Changelog.md | 1 + pkg/config/pmapi_prod.go | 17 +- pkg/pmapi/clientmanager.go | 10 +- pkg/pmapi/dialer.go | 55 +++ pkg/pmapi/dialer_pinning.go | 99 +++++ ...h_proxy_test.go => dialer_pinning_test.go} | 44 ++- pkg/pmapi/dialer_proxy.go | 40 ++ pkg/pmapi/dialer_with_proxy.go | 374 ------------------ pkg/pmapi/pin_checker.go | 77 ++++ pkg/pmapi/proxy.go | 8 +- pkg/pmapi/proxy_test.go | 230 ++++++++--- pkg/pmapi/tlsreport.go | 145 +++++++ 12 files changed, 651 insertions(+), 449 deletions(-) create mode 100644 pkg/pmapi/dialer.go create mode 100644 pkg/pmapi/dialer_pinning.go rename pkg/pmapi/{dialer_with_proxy_test.go => dialer_pinning_test.go} (72%) create mode 100644 pkg/pmapi/dialer_proxy.go delete mode 100644 pkg/pmapi/dialer_with_proxy.go create mode 100644 pkg/pmapi/pin_checker.go create mode 100644 pkg/pmapi/tlsreport.go diff --git a/Changelog.md b/Changelog.md index 053ff7f4..4a9b6861 100644 --- a/Changelog.md +++ b/Changelog.md @@ -26,6 +26,7 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/) * `ClientManager` is the "one source of truth" for the host URL for all `Client`s * Alternative Routing is enabled/disabled by `ClientManager` * Logging out of `Clients` is handled/retried asynchronously by `ClientManager` +* GODT-265 Alternative Routing v2 (more resiliant to short term connection drops) ### Fixed diff --git a/pkg/config/pmapi_prod.go b/pkg/config/pmapi_prod.go index 46eefeb0..d9b27745 100644 --- a/pkg/config/pmapi_prod.go +++ b/pkg/config/pmapi_prod.go @@ -40,11 +40,18 @@ func (c *Config) GetAPIConfig() *pmapi.ClientConfig { } func (c *Config) GetRoundTripper(cm *pmapi.ClientManager, listener listener.Listener) http.RoundTripper { - pin := pmapi.NewDialerWithPinning(cm, c.GetAPIConfig().AppVersion) + // We use a TLS dialer. + basicDialer := pmapi.NewBasicTLSDialer() - pin.ReportCertIssueLocal = func() { - listener.Emit(events.TLSCertIssue, "") - } + // We wrap the TLS dialer in a layer which enforces connections to trusted servers. + pinningDialer := pmapi.NewPinningTLSDialer(basicDialer, c.GetAPIConfig().AppVersion) - return pin.TransportWithPinning() + // We want any pin mismatches to be communicated back to bridge GUI and reported. + pinningDialer.SetTLSIssueNotifier(func() { listener.Emit(events.TLSCertIssue, "") }) + pinningDialer.SetRemoteTLSIssueReporting(true) + + // We wrap the pinning dialer in a layer which adds "alternative routing" feature. + proxyDialer := pmapi.NewProxyTLSDialer(pinningDialer, cm) + + return pmapi.CreateTransportWithDialer(proxyDialer) } diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index 03c17eeb..708c7756 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -230,6 +230,14 @@ func (cm *ClientManager) switchToReachableServer() (proxy string, err error) { return } + // If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen. + if proxy == rootURL { + logrus.Info("The standard API is reachable again; connection drop was only intermittent") + err = ErrAPINotReachable + cm.host = proxy + return + } + logrus.WithField("proxy", proxy).Info("Switching to a proxy") // If the host is currently the rootURL, it's the first time we are enabling a proxy. @@ -243,7 +251,7 @@ func (cm *ClientManager) switchToReachableServer() (proxy string, err error) { cm.host = proxy - return + return proxy, err } // GetToken returns the token for the given userID. diff --git a/pkg/pmapi/dialer.go b/pkg/pmapi/dialer.go new file mode 100644 index 00000000..c3b725e1 --- /dev/null +++ b/pkg/pmapi/dialer.go @@ -0,0 +1,55 @@ +package pmapi + +import ( + "crypto/tls" + "net" + "net/http" + "time" +) + +type TLSDialer interface { + DialTLS(network, address string) (conn net.Conn, err error) +} + +// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections. +func CreateTransportWithDialer(dialer TLSDialer) *http.Transport { + return &http.Transport{ + DialTLS: dialer.DialTLS, + + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: 100, + IdleConnTimeout: 5 * time.Minute, + ExpectContinueTimeout: 500 * time.Millisecond, + + // GODT-126: this was initially 10s but logs from users showed a significant number + // were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect. + // Bumping to 30s for now to avoid this problem. + ResponseHeaderTimeout: 30 * time.Second, + + // If we allow up to 30 seconds for response headers, it is reasonable to allow up + // to 30 seconds for the TLS handshake to take place. + TLSHandshakeTimeout: 30 * time.Second, + } +} + +// BasicTLSDialer implements TLSDialer. +type BasicTLSDialer struct{} + +// NewBasicTLSDialer returns a new BasicTLSDialer. +func NewBasicTLSDialer() *BasicTLSDialer { + return &BasicTLSDialer{} +} + +// DialTLS returns a connection to the given address using the given network. +func (b *BasicTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) { + dialer := &net.Dialer{Timeout: 10 * time.Second} + + var tlsConfig *tls.Config = nil + + // If we are not dialing the standard API then we should skip cert verification checks. + if address != rootURL { + tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec] + } + + return tls.DialWithDialer(dialer, network, address, tlsConfig) +} diff --git a/pkg/pmapi/dialer_pinning.go b/pkg/pmapi/dialer_pinning.go new file mode 100644 index 00000000..358bcba4 --- /dev/null +++ b/pkg/pmapi/dialer_pinning.go @@ -0,0 +1,99 @@ +// Copyright (c) 2020 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "crypto/tls" + "net" + "time" + + "github.com/sirupsen/logrus" +) + +// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and +// to report errors if the fingerprint check fails. +type PinningTLSDialer struct { + dialer TLSDialer + + // pinChecker is used to check TLS keys of connections. + pinChecker PinChecker + + // appVersion is supplied if there is a TLS mismatch. + appVersion string + + // tlsIssueNotifier is used to notify something when there is a TLS issue. + tlsIssueNotifier func() + + // enableRemoteReporting instructs the dialer to report TLS mismatches. + enableRemoteReporting bool + + // A logger for logging messages. + log logrus.FieldLogger +} + +// NewPinningTLSDialer constructs a new dialer which only returns tcp connections to servers +// which present known certificates. +// If enabled, it reports any invalid certificates it finds. +func NewPinningTLSDialer(dialer TLSDialer, appVersion string) *PinningTLSDialer { + return &PinningTLSDialer{ + dialer: dialer, + pinChecker: NewPinChecker(TrustedAPIPins), + appVersion: appVersion, + log: logrus.WithField("pkg", "pmapi/tls-pinning"), + } +} + +func (p *PinningTLSDialer) SetTLSIssueNotifier(notifier func()) { + p.tlsIssueNotifier = notifier +} + +func (p *PinningTLSDialer) SetRemoteTLSIssueReporting(enabled bool) { + p.enableRemoteReporting = enabled +} + +// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins. +func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) { + if conn, err = p.dialer.DialTLS(network, address); err != nil { + return + } + + host, port, err := net.SplitHostPort(address) + if err != nil { + return + } + + if err = p.pinChecker.CheckCertificate(conn); err != nil { + if p.tlsIssueNotifier != nil { + go p.tlsIssueNotifier() + } + + if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting { + p.pinChecker.ReportCertIssue( + host, + port, + time.Now().Format(time.RFC3339), + tlsConn.ConnectionState(), + p.appVersion, + ) + } + + return + } + + return +} diff --git a/pkg/pmapi/dialer_with_proxy_test.go b/pkg/pmapi/dialer_pinning_test.go similarity index 72% rename from pkg/pmapi/dialer_with_proxy_test.go rename to pkg/pmapi/dialer_pinning_test.go index 53f526e3..24c5c2f6 100644 --- a/pkg/pmapi/dialer_with_proxy_test.go +++ b/pkg/pmapi/dialer_pinning_test.go @@ -30,19 +30,21 @@ var testLiveConfig = &ClientConfig{ ClientID: "Bridge", } -func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) { +func createAndSetPinningDialer(cm *ClientManager) (*int, *PinningTLSDialer) { called := 0 - p := NewDialerWithPinning(cm, testLiveConfig.AppVersion) - p.ReportCertIssueLocal = func() { called++ } - cm.SetRoundTripper(p.TransportWithPinning()) - return &called, p + + dialer := NewPinningTLSDialer(NewBasicTLSDialer(), testLiveConfig.AppVersion) + dialer.SetTLSIssueNotifier(func() { called++ }) + cm.SetRoundTripper(CreateTransportWithDialer(dialer)) + + return &called, dialer } func TestTLSPinValid(t *testing.T) { cm := newTestClientManager(testLiveConfig) cm.host = liveAPI rootScheme = "https" - called, _ := setTestDialerWithPinning(cm) + called, _ := createAndSetPinningDialer(cm) client := cm.GetClient("pmapi" + t.Name()) _, err := client.AuthInfo("this.address.is.disabled") @@ -54,9 +56,9 @@ func TestTLSPinValid(t *testing.T) { func TestTLSPinBackup(t *testing.T) { cm := newTestClientManager(testLiveConfig) cm.host = liveAPI - called, p := setTestDialerWithPinning(cm) - p.report.KnownPins[1] = p.report.KnownPins[0] - p.report.KnownPins[0] = "" + called, p := createAndSetPinningDialer(cm) + p.pinChecker.trustedPins[1] = p.pinChecker.trustedPins[0] + p.pinChecker.trustedPins[0] = "" client := cm.GetClient("pmapi" + t.Name()) @@ -70,9 +72,9 @@ func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused] cm := newTestClientManager(testLiveConfig) cm.host = liveAPI - called, p := setTestDialerWithPinning(cm) - for i := 0; i < len(p.report.KnownPins); i++ { - p.report.KnownPins[i] = "testing" + called, p := createAndSetPinningDialer(cm) + for i := 0; i < len(p.pinChecker.trustedPins); i++ { + p.pinChecker.trustedPins[i] = "testing" } client := cm.GetClient("pmapi" + t.Name()) @@ -96,7 +98,7 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused] })) defer ts.Close() - called, _ := setTestDialerWithPinning(cm) + called, _ := createAndSetPinningDialer(cm) client := cm.GetClient("pmapi" + t.Name()) @@ -113,23 +115,23 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused] func _TestTLSSignedCertWrongPublicKey(t *testing.T) { // nolint[unused] cm := newTestClientManager(testLiveConfig) - _, dialer := setTestDialerWithPinning(cm) - _, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443") + _, dialer := createAndSetPinningDialer(cm) + _, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443") Assert(t, err != nil, "expected dial to fail because of wrong public key: ", err.Error()) } func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused] cm := newTestClientManager(testLiveConfig) - _, dialer := setTestDialerWithPinning(cm) - dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`) - _, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443") + _, dialer := createAndSetPinningDialer(cm) + dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`) + _, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443") Assert(t, err == nil, "expected dial to succeed because public key is known and cert is signed by CA: ", err.Error()) } func _TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused] cm := newTestClientManager(testLiveConfig) - _, dialer := setTestDialerWithPinning(cm) - dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`) - _, err := dialer.dialAndCheckFingerprints("tcp", "self-signed.badssl.com:443") + _, dialer := createAndSetPinningDialer(cm) + dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`) + _, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443") Assert(t, err == nil, "expected dial to succeed because public key is known despite cert being self-signed: ", err.Error()) } diff --git a/pkg/pmapi/dialer_proxy.go b/pkg/pmapi/dialer_proxy.go new file mode 100644 index 00000000..11af0128 --- /dev/null +++ b/pkg/pmapi/dialer_proxy.go @@ -0,0 +1,40 @@ +package pmapi + +import ( + "net" +) + +// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails. +type ProxyTLSDialer struct { + dialer TLSDialer + + cm *ClientManager +} + +// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer. +func NewProxyTLSDialer(dialer TLSDialer, cm *ClientManager) *ProxyTLSDialer { + return &ProxyTLSDialer{ + dialer: dialer, + cm: cm, + } +} + +// DialTLS dials the given network/address. If it fails, it retries using a proxy. +func (d *ProxyTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) { + if conn, err = d.dialer.DialTLS(network, address); err == nil { + return + } + + var proxy string + + if proxy, err = d.cm.switchToReachableServer(); err != nil { + return + } + + _, port, err := net.SplitHostPort(address) + if err != nil { + return + } + + return d.dialer.DialTLS(network, net.JoinHostPort(proxy, port)) +} diff --git a/pkg/pmapi/dialer_with_proxy.go b/pkg/pmapi/dialer_with_proxy.go deleted file mode 100644 index d669b73e..00000000 --- a/pkg/pmapi/dialer_with_proxy.go +++ /dev/null @@ -1,374 +0,0 @@ -// Copyright (c) 2020 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package pmapi - -import ( - "bytes" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/json" - "encoding/pem" - "fmt" - "io/ioutil" - "net" - "net/http" - "strconv" - "time" - - "github.com/sirupsen/logrus" -) - -// TLSReport is inspired by https://tools.ietf.org/html/rfc7469#section-3. -type TLSReport struct { - // DateTime of observed pin validation in time.RFC3339 format. - DateTime string `json:"date-time"` - - // Hostname to which the UA made original request that failed pin validation. - Hostname string `json:"hostname"` - - // Port to which the UA made original request that failed pin validation. - Port int `json:"port"` - - // EffectiveExpirationDate for noted pins in time.RFC3339 format. - EffectiveExpirationDate string `json:"effective-expiration-date"` - - // IncludeSubdomains indicates whether or not the UA has noted the - // includeSubDomains directive for the Known Pinned Host. - IncludeSubdomains bool `json:"include-subdomains"` - - // NotedHostname indicates the hostname that the UA noted when it noted - // the Known Pinned Host. This field allows operators to understand why - // Pin Validation was performed for, e.g., foo.example.com when the - // noted Known Pinned Host was example.com with includeSubDomains set. - NotedHostname string `json:"noted-hostname"` - - // ServedCertificateChain is the certificate chain, as served by - // the Known Pinned Host during TLS session setup. It is provided as an - // array of strings; each string pem1, ... pemN is the Privacy-Enhanced - // Mail (PEM) representation of each X.509 certificate as described in - // [RFC7468]. - ServedCertificateChain []string `json:"served-certificate-chain"` - - // ValidatedCertificateChain is the certificate chain, as - // constructed by the UA during certificate chain verification. (This - // may differ from the served-certificate-chain.) It is provided as an - // array of strings; each string pem1, ... pemN is the PEM - // representation of each X.509 certificate as described in [RFC7468]. - // UAs that build certificate chains in more than one way during the - // validation process SHOULD send the last chain built. In this way, - // they can avoid keeping too much state during the validation process. - ValidatedCertificateChain []string `json:"validated-certificate-chain"` - - // The known-pins are the Pins that the UA has noted for the Known - // Pinned Host. They are provided as an array of strings with the - // syntax: known-pin = token "=" quoted-string - // e.g.: - // ``` - // "known-pins": [ - // 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="', - // "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\"" - // ] - // ``` - KnownPins []string `json:"known-pins"` - - // AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit. - AppVersion string `json:"app-version"` -} - -// ErrTLSMatch indicates that no TLS fingerprint match could be found. -var ErrTLSMatch = fmt.Errorf("TLS fingerprint match not found") - -// DialerWithPinning will provide dial function which checks the fingerprints of public cert -// received from contacted server. If no match found among know pinse it will report using -// ReportCertIssueLocal. -type DialerWithPinning struct { - // isReported will stop reporting if true. - isReported bool - - // report stores known pins. - report TLSReport - - // When reportURI is not empty the tls issue report will be send to this URI. - reportURI string - - // ReportCertIssueLocal is used send signal to application about certificate issue. - // It is used only if set. - ReportCertIssueLocal func() - - // cm is used to find and switch to a proxy if necessary. - cm *ClientManager - - // A logger for logging messages. - log logrus.FieldLogger -} - -// NewDialerWithPinning constructs a new dialer with pinned certs. -func NewDialerWithPinning(cm *ClientManager, appVersion string) *DialerWithPinning { - reportURI := "https://reports.protonmail.ch/reports/tls" - - report := TLSReport{ - EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339), - IncludeSubdomains: false, - ValidatedCertificateChain: []string{}, - ServedCertificateChain: []string{}, - AppVersion: appVersion, - - // NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;) - KnownPins: []string{ - `pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current - `pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot - `pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold - `pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main - `pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1 - `pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2 - `pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3 - }, - } - - log := logrus.WithField("pkg", "pmapi/tls-pinning") - - return &DialerWithPinning{ - cm: cm, - isReported: false, - reportURI: reportURI, - report: report, - log: log, - } -} - -func (p *DialerWithPinning) reportCertIssue(connState tls.ConnectionState) { - p.isReported = true - - if p.ReportCertIssueLocal != nil { - go p.ReportCertIssueLocal() - } - - if p.reportURI != "" { - p.report.NotedHostname = connState.ServerName - p.report.ServedCertificateChain = marshalCert7468(connState.PeerCertificates) - - if len(connState.VerifiedChains) > 0 { - p.report.ServedCertificateChain = marshalCert7468( - connState.VerifiedChains[len(connState.VerifiedChains)-1], - ) - } - - go p.reportCertIssueRemote() - } -} - -func (p *DialerWithPinning) reportCertIssueRemote() { - b, err := json.Marshal(p.report) - if err != nil { - p.log.Errorf("marshal request: %v", err) - return - } - - req, err := http.NewRequest("POST", p.reportURI, bytes.NewReader(b)) - if err != nil { - p.log.Errorf("create request: %v", err) - } - - req.Header.Add("Content-Type", "application/json") - req.Header.Set("User-Agent", CurrentUserAgent) - req.Header.Set("x-pm-apiversion", strconv.Itoa(Version)) - req.Header.Set("x-pm-appversion", p.report.AppVersion) - - p.log.Debugf("report req: %+v\n", req) - - c := &http.Client{} - res, err := c.Do(req) - p.log.Debugf("res: %+v\nerr: %v", res, err) - if err != nil { - return - } - _, _ = ioutil.ReadAll(res.Body) - if res.StatusCode != http.StatusOK { - p.log.Errorf("response status: %v", res.Status) - } - _ = res.Body.Close() -} - -func certFingerprint(cert *x509.Certificate) string { - hash := sha256.Sum256(cert.RawSubjectPublicKeyInfo) - return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:])) -} - -func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) { - var buffer bytes.Buffer - for _, cert := range certs { - if err := pem.Encode(&buffer, &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - }); err != nil { - logrus.WithField("pkg", "pmapi/tls-pinning").Errorf("encoding TLS cert: %v", err) - } - pemCerts = append(pemCerts, buffer.String()) - buffer.Reset() - } - - return pemCerts -} - -// TransportWithPinning creates an http.Transport that checks fingerprints when dialing. -func (p *DialerWithPinning) TransportWithPinning() *http.Transport { - return &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialTLS: p.dialAndCheckFingerprints, - MaxIdleConns: 100, - IdleConnTimeout: 5 * time.Minute, - ExpectContinueTimeout: 500 * time.Millisecond, - - // GODT-126: this was initially 10s but logs from users showed a significant number - // were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect. - // Bumping to 30s for now to avoid this problem. - ResponseHeaderTimeout: 30 * time.Second, - - // If we allow up to 30 seconds for response headers, it is reasonable to allow up - // to 30 seconds for the TLS handshake to take place. - TLSHandshakeTimeout: 30 * time.Second, - } -} - -// dialAndCheckFingerprint to set as http.Transport.DialTLS. -// -// * note that when DialTLS is not nil the Transport.TLSClientConfig and Transport.TLSHandshakeTimeout are ignored. -// * dialAndCheckFingerprints fails if certificate is not valid (not signed by authority or not matching hostname). -// * dialAndCheckFingerprints will pass if certificate pin does not have a match, but will send notification using -// p.ReportCertIssueLocal() and p.reportCertIssueRemote() if they are not nil. -func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (conn net.Conn, err error) { - // If DoH is enabled, we hardfail on fingerprint mismatches. - if p.cm.IsProxyAllowed() && p.isReported { - return nil, ErrTLSMatch - } - - // Try to dial the given address but use a proxy if necessary. - if conn, err = p.dialWithProxyFallback(network, address); err != nil { - return - } - - // If cert issue was already reported, we don't want to check fingerprints anymore. - if p.isReported { - return nil, ErrTLSMatch - } - - // Check the cert fingerprint to ensure it is known. - if err = p.checkFingerprints(conn); err != nil { - p.log.WithError(err).Error("Error checking cert fingerprints") - return - } - - return -} - -// dialWithProxyFallback tries to dial the given address but falls back to alternative proxies if need be. -func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) { - p.log.Info("Dialing with proxy fallback") - - // Try to dial, and if it succeeds, then just return. - if conn, err = p.dial(network, address); err == nil { - return - } - - p.log.WithField("address", address).WithError(err).Error("Dialing failed") - - host, port, err := net.SplitHostPort(address) - if err != nil { - return - } - - // If DoH is not allowed, give up. Or, if we are dialing something other than the API - // (e.g. we dial protonmail.com/... to check for updates), there's also no point in - // continuing since a proxy won't help us reach that. - if !p.cm.IsProxyAllowed() || host != p.cm.getHost() { - p.log.WithField("address", address).Debug("Aborting dial, cannot switch to a proxy") - return - } - - // Switch to a proxy and retry the dial. - proxy, err := p.cm.switchToReachableServer() - if err != nil { - return - } - - proxyAddress := net.JoinHostPort(proxy, port) - - p.log.WithField("address", proxyAddress).Debug("Trying dial again using a proxy") - - return p.dial(network, proxyAddress) -} - -// dial returns a connection to the given address using the given network. -func (p *DialerWithPinning) dial(network, address string) (conn net.Conn, err error) { - var port string - if p.report.Hostname, port, err = net.SplitHostPort(address); err != nil { - return - } - if p.report.Port, err = strconv.Atoi(port); err != nil { - return - } - p.report.DateTime = time.Now().Format(time.RFC3339) - - dialer := &net.Dialer{Timeout: 10 * time.Second} - - // If we are not dialing the standard API then we should skip cert verification checks. - var tlsConfig *tls.Config = nil - if address != rootURL { - tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec] - } - - return tls.DialWithDialer(dialer, network, address, tlsConfig) -} - -func (p *DialerWithPinning) checkFingerprints(conn net.Conn) (err error) { - if !checkTLSCerts { - return - } - - connState := conn.(*tls.Conn).ConnectionState() - - hasFingerprintMatch := false - for _, peerCert := range connState.PeerCertificates { - fingerprint := certFingerprint(peerCert) - - for i, pin := range p.report.KnownPins { - if pin == fingerprint { - hasFingerprintMatch = true - - if i != 0 { - p.log.Warnf("Matched fingerprint (%q) was not primary pinned key (was key #%d)", fingerprint, i) - } - - break - } - } - - if hasFingerprintMatch { - break - } - } - - if !hasFingerprintMatch { - p.reportCertIssue(connState) - return ErrTLSMatch - } - - return err -} diff --git a/pkg/pmapi/pin_checker.go b/pkg/pmapi/pin_checker.go new file mode 100644 index 00000000..be8cc699 --- /dev/null +++ b/pkg/pmapi/pin_checker.go @@ -0,0 +1,77 @@ +package pmapi + +import ( + "bytes" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "net" + + "github.com/sirupsen/logrus" +) + +type PinChecker struct { + trustedPins []string +} + +func NewPinChecker(trustedPins []string) PinChecker { + return PinChecker{ + trustedPins: trustedPins, + } +} + +// CheckCertificate returns whether the connection presents a known TLS certificate. +func (p *PinChecker) CheckCertificate(conn net.Conn) error { + connState := conn.(*tls.Conn).ConnectionState() + + for _, peerCert := range connState.PeerCertificates { + fingerprint := certFingerprint(peerCert) + + for _, pin := range p.trustedPins { + if pin == fingerprint { + return nil + } + } + } + + return ErrTLSMismatch +} + +func certFingerprint(cert *x509.Certificate) string { + hash := sha256.Sum256(cert.RawSubjectPublicKeyInfo) + return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:])) +} + +// ReportCertIssue reports a TLS key mismatch. +func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls.ConnectionState, appVersion string) { + var certChain []string + + if len(connState.VerifiedChains) > 0 { + certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1]) + } else { + certChain = marshalCert7468(connState.PeerCertificates) + } + + report := NewTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion) + + go postCertIssueReport(report) +} + +func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) { + var buffer bytes.Buffer + for _, cert := range certs { + if err := pem.Encode(&buffer, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }); err != nil { + logrus.WithField("pkg", "pmapi/tls-pinning").Errorf("encoding TLS cert: %v", err) + } + pemCerts = append(pemCerts, buffer.String()) + buffer.Reset() + } + + return pemCerts +} diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go index 5bf3cadd..ff147123 100644 --- a/pkg/pmapi/proxy.go +++ b/pkg/pmapi/proxy.go @@ -18,7 +18,6 @@ package pmapi import ( - "crypto/tls" "encoding/base64" "strings" "time" @@ -85,7 +84,8 @@ func (p *proxyProvider) findReachableServer() (proxy string, err error) { errResult := make(chan error) go func() { if err = p.refreshProxyCache(); err != nil { - logrus.WithError(err).Warn("Failed to refresh proxy cache, cache may be out of date") + errResult <- errors.Wrap(err, "failed to refresh proxy cache") + return } // We want to switch back to the rootURL if possible. @@ -144,10 +144,12 @@ func (p *proxyProvider) canReach(url string) bool { url = "https://" + url } + pinningDialer := NewPinningTLSDialer(NewBasicTLSDialer(), "") + pinger := resty.New(). SetHostURL(url). SetTimeout(p.lookupTimeout). - SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec] + SetTransport(CreateTransportWithDialer(pinningDialer)) if _, err := pinger.R().Get("/tests/ping"); err != nil { logrus.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy") diff --git a/pkg/pmapi/proxy_test.go b/pkg/pmapi/proxy_test.go index a8e709af..30495b04 100644 --- a/pkg/pmapi/proxy_test.go +++ b/pkg/pmapi/proxy_test.go @@ -18,6 +18,7 @@ package pmapi import ( + "crypto/tls" "net/http" "net/http/httptest" "testing" @@ -32,12 +33,114 @@ const ( TestGoogleProvider = "https://dns.google/dns-query" ) +// getTrustedServer returns a server and sets its public key as one of the pinned ones. +func getTrustedServer() *httptest.Server { + proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + pin := certFingerprint(proxy.Certificate()) + TrustedAPIPins = append(TrustedAPIPins, pin) + + return proxy +} + +// server.crt +const servercrt = ` +-----BEGIN CERTIFICATE----- +MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD +VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp +dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t +T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j +b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx +MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR +BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf +MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR +aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt +r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2 +pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM +GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru +rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c +kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw +ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI +DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu +ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0 +MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3 +LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R +BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5 +L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA +CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P +3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg +yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l +z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc +u53wgIhCJGuX +-----END CERTIFICATE----- +` + +const serverkey = ` +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx +owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG +ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq +UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx +8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr +n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6 +mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1 +EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ +/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F +qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT +VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu +Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h +bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X +TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU +LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f +pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm +nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3 +Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5 +ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo +w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK +PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt +FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0 +ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD +z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2 +FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o +vwRMog6lPhlRhHh/FZ43Cg== +-----END PRIVATE KEY----- +` + +// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones. +func getUntrustedServer() *httptest.Server { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey)) + if err != nil { + panic(err) + } + server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + + server.StartTLS() + return server +} + +// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys. +func closeServer(server *httptest.Server) { + pin := certFingerprint(server.Certificate()) + + for i := range TrustedAPIPins { + if TrustedAPIPins[i] == pin { + TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...) + break + } + } + + server.Close() +} + func TestProxyProvider_FindProxy(t *testing.T) { blockAPI() defer unblockAPI() - proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy.Close() + proxy := getTrustedServer() + defer closeServer(proxy) p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } @@ -51,34 +154,72 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) { blockAPI() defer unblockAPI() - badProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - goodProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + reachableProxy := getTrustedServer() + defer closeServer(reachableProxy) - // Close the bad proxy first so it isn't reachable; we should then choose the good proxy. - badProxy.Close() - defer goodProxy.Close() + // We actually close the unreachable proxy straight away rather than deferring the closure. + unreachableProxy := getTrustedServer() + closeServer(unreachableProxy) p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil } + p.dohLookup = func(q, p string) ([]string, error) { return []string{reachableProxy.URL, unreachableProxy.URL}, nil } url, err := p.findReachableServer() require.NoError(t, err) - require.Equal(t, goodProxy.URL, url) + require.Equal(t, reachableProxy.URL, url) +} + +func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) { + blockAPI() + defer unblockAPI() + + trustedProxy := getTrustedServer() + defer closeServer(trustedProxy) + + untrustedProxy := getUntrustedServer() + defer closeServer(untrustedProxy) + + p := newProxyProvider([]string{"not used"}, "not used") + p.dohLookup = func(q, p string) ([]string, error) { return []string{untrustedProxy.URL, trustedProxy.URL}, nil } + + url, err := p.findReachableServer() + require.NoError(t, err) + require.Equal(t, trustedProxy.URL, url) } func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) { blockAPI() defer unblockAPI() - badProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - anotherBadProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + unreachableProxy1 := getTrustedServer() + closeServer(unreachableProxy1) - // Close the proxies to simulate them not being reachable. - badProxy.Close() - anotherBadProxy.Close() + unreachableProxy2 := getTrustedServer() + closeServer(unreachableProxy2) p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil } + p.dohLookup = func(q, p string) ([]string, error) { + return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil + } + + _, err := p.findReachableServer() + require.Error(t, err) +} + +func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) { + blockAPI() + defer unblockAPI() + + untrustedProxy1 := getUntrustedServer() + defer closeServer(untrustedProxy1) + + untrustedProxy2 := getUntrustedServer() + defer closeServer(untrustedProxy2) + + p := newProxyProvider([]string{"not used"}, "not used") + p.dohLookup = func(q, p string) ([]string, error) { + return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil + } _, err := p.findReachableServer() require.Error(t, err) @@ -88,9 +229,6 @@ func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) { blockAPI() defer unblockAPI() - proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy.Close() - p := newProxyProvider([]string{"not used"}, "not used") p.lookupTimeout = time.Second p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } @@ -124,17 +262,17 @@ func TestProxyProvider_UseProxy(t *testing.T) { cm := newTestClientManager(testClientConfig) - proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy.Close() + trustedProxy := getTrustedServer() + defer closeServer(trustedProxy) p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) - require.Equal(t, proxy.URL, url) - require.Equal(t, proxy.URL, cm.getHost()) + require.Equal(t, trustedProxy.URL, url) + require.Equal(t, trustedProxy.URL, cm.getHost()) } func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { @@ -143,12 +281,12 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { cm := newTestClientManager(testClientConfig) - proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy1.Close() - proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy2.Close() - proxy3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy3.Close() + proxy1 := getTrustedServer() + defer closeServer(proxy1) + proxy2 := getTrustedServer() + defer closeServer(proxy2) + proxy3 := getTrustedServer() + defer closeServer(proxy3) p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p @@ -184,18 +322,18 @@ func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) { cm := newTestClientManager(testClientConfig) - proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy.Close() + trustedProxy := getTrustedServer() + defer closeServer(trustedProxy) p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p cm.proxyUseDuration = time.Second - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) - require.Equal(t, proxy.URL, url) - require.Equal(t, proxy.URL, cm.getHost()) + require.Equal(t, trustedProxy.URL, url) + require.Equal(t, trustedProxy.URL, cm.getHost()) time.Sleep(2 * time.Second) require.Equal(t, rootURL, cm.getHost()) @@ -207,26 +345,27 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab cm := newTestClientManager(testClientConfig) - proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy.Close() + trustedProxy := getTrustedServer() p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) - require.Equal(t, proxy.URL, url) - require.Equal(t, proxy.URL, cm.getHost()) + require.Equal(t, trustedProxy.URL, url) + require.Equal(t, trustedProxy.URL, cm.getHost()) // Simulate that the proxy stops working and that the standard api is reachable again. - proxy.Close() + closeServer(trustedProxy) unblockAPI() time.Sleep(proxyLookupWait) // We should now find the original API URL if it is working again. + // The error should be ErrAPINotReachable because the connection dropped intermittently but + // the original API is now reachable (see Alternative-Routing-v2 spec for details). url, err = cm.switchToReachableServer() - require.NoError(t, err) + require.EqualError(t, err, ErrAPINotReachable.Error()) require.Equal(t, rootURL, url) require.Equal(t, rootURL, cm.getHost()) } @@ -237,10 +376,11 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl cm := newTestClientManager(testClientConfig) - proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy1.Close() - proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer proxy2.Close() + // proxy1 is closed later in this test so we don't defer it here. + proxy1 := getTrustedServer() + + proxy2 := getTrustedServer() + defer closeServer(proxy2) p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p diff --git a/pkg/pmapi/tlsreport.go b/pkg/pmapi/tlsreport.go new file mode 100644 index 00000000..058ecac9 --- /dev/null +++ b/pkg/pmapi/tlsreport.go @@ -0,0 +1,145 @@ +package pmapi + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "strconv" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +// ErrTLSMismatch indicates that no TLS fingerprint match could be found. +var ErrTLSMismatch = errors.New("no TLS fingerprint match found") + +// TrustedAPIPins contains trusted public keys of the protonmail API and proxies. +// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;) +var TrustedAPIPins = []string{ // nolint[gochecknoglobals] + `pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current + `pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot + `pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold + `pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main + `pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1 + `pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2 + `pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3 +} + +// TLSReportURI is the address where TLS reports should be sent. +const TLSReportURI = "https://reports.protonmail.ch/reports/tls" + +// TLSReport is inspired by https://tools.ietf.org/html/rfc7469#section-3. +// When a TLS key mismatch is detected, a TLSReport is posted to TLSReportURI. +type TLSReport struct { + // DateTime of observed pin validation in time.RFC3339 format. + DateTime string `json:"date-time"` + + // Hostname to which the UA made original request that failed pin validation. + Hostname string `json:"hostname"` + + // Port to which the UA made original request that failed pin validation. + Port int `json:"port"` + + // EffectiveExpirationDate for noted pins in time.RFC3339 format. + EffectiveExpirationDate string `json:"effective-expiration-date"` + + // IncludeSubdomains indicates whether or not the UA has noted the + // includeSubDomains directive for the Known Pinned Host. + IncludeSubdomains bool `json:"include-subdomains"` + + // NotedHostname indicates the hostname that the UA noted when it noted + // the Known Pinned Host. This field allows operators to understand why + // Pin Validation was performed for, e.g., foo.example.com when the + // noted Known Pinned Host was example.com with includeSubDomains set. + NotedHostname string `json:"noted-hostname"` + + // ServedCertificateChain is the certificate chain, as served by + // the Known Pinned Host during TLS session setup. It is provided as an + // array of strings; each string pem1, ... pemN is the Privacy-Enhanced + // Mail (PEM) representation of each X.509 certificate as described in + // [RFC7468]. + ServedCertificateChain []string `json:"served-certificate-chain"` + + // ValidatedCertificateChain is the certificate chain, as + // constructed by the UA during certificate chain verification. (This + // may differ from the served-certificate-chain.) It is provided as an + // array of strings; each string pem1, ... pemN is the PEM + // representation of each X.509 certificate as described in [RFC7468]. + // UAs that build certificate chains in more than one way during the + // validation process SHOULD send the last chain built. In this way, + // they can avoid keeping too much state during the validation process. + ValidatedCertificateChain []string `json:"validated-certificate-chain"` + + // The known-pins are the Pins that the UA has noted for the Known + // Pinned Host. They are provided as an array of strings with the + // syntax: known-pin = token "=" quoted-string + // e.g.: + // ``` + // "known-pins": [ + // 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="', + // "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\"" + // ] + // ``` + KnownPins []string `json:"known-pins"` + + // AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit. + AppVersion string `json:"app-version"` +} + +// NewTLSReport constructs a new TLSreport configured with the given app version and known pinned public keys. +func NewTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report TLSReport) { + // If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway. + intPort, _ := strconv.Atoi(port) + + report = TLSReport{ + Hostname: host, + Port: intPort, + EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339), + IncludeSubdomains: false, + NotedHostname: server, + ValidatedCertificateChain: []string{}, + ServedCertificateChain: certChain, + KnownPins: knownPins, + AppVersion: appVersion, + } + + return +} + +// postCertIssueReport posts the given TLS report to the standard TLS Report URI. +func postCertIssueReport(report TLSReport) { + b, err := json.Marshal(report) + if err != nil { + logrus.WithError(err).Error("Failed to marshal TLS report") + return + } + + req, err := http.NewRequest("POST", TLSReportURI, bytes.NewReader(b)) + if err != nil { + logrus.WithError(err).Error("Failed to create http request") + return + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Set("User-Agent", CurrentUserAgent) + req.Header.Set("x-pm-apiversion", strconv.Itoa(Version)) + req.Header.Set("x-pm-appversion", report.AppVersion) + + logrus.WithField("request", req).Warn("Reporting TLS mismatch") + res, err := (&http.Client{}).Do(req) + if err != nil { + logrus.WithError(err).Error("Failed to report TLS mismatch") + return + } + + logrus.WithField("response", res).Error("Reported TLS mismatch") + + if res.StatusCode != http.StatusOK { + logrus.WithField("status", http.StatusOK).Error("StatusCode was not OK") + } + + _, _ = ioutil.ReadAll(res.Body) + _ = res.Body.Close() +}