diff --git a/Changelog.md b/Changelog.md
index 053ff7f4..4a9b6861 100644
--- a/Changelog.md
+++ b/Changelog.md
@@ -26,6 +26,7 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/)
* `ClientManager` is the "one source of truth" for the host URL for all `Client`s
* Alternative Routing is enabled/disabled by `ClientManager`
* Logging out of `Clients` is handled/retried asynchronously by `ClientManager`
+* GODT-265 Alternative Routing v2 (more resiliant to short term connection drops)
### Fixed
diff --git a/pkg/config/pmapi_prod.go b/pkg/config/pmapi_prod.go
index 46eefeb0..d9b27745 100644
--- a/pkg/config/pmapi_prod.go
+++ b/pkg/config/pmapi_prod.go
@@ -40,11 +40,18 @@ func (c *Config) GetAPIConfig() *pmapi.ClientConfig {
}
func (c *Config) GetRoundTripper(cm *pmapi.ClientManager, listener listener.Listener) http.RoundTripper {
- pin := pmapi.NewDialerWithPinning(cm, c.GetAPIConfig().AppVersion)
+ // We use a TLS dialer.
+ basicDialer := pmapi.NewBasicTLSDialer()
- pin.ReportCertIssueLocal = func() {
- listener.Emit(events.TLSCertIssue, "")
- }
+ // We wrap the TLS dialer in a layer which enforces connections to trusted servers.
+ pinningDialer := pmapi.NewPinningTLSDialer(basicDialer, c.GetAPIConfig().AppVersion)
- return pin.TransportWithPinning()
+ // We want any pin mismatches to be communicated back to bridge GUI and reported.
+ pinningDialer.SetTLSIssueNotifier(func() { listener.Emit(events.TLSCertIssue, "") })
+ pinningDialer.SetRemoteTLSIssueReporting(true)
+
+ // We wrap the pinning dialer in a layer which adds "alternative routing" feature.
+ proxyDialer := pmapi.NewProxyTLSDialer(pinningDialer, cm)
+
+ return pmapi.CreateTransportWithDialer(proxyDialer)
}
diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go
index 03c17eeb..708c7756 100644
--- a/pkg/pmapi/clientmanager.go
+++ b/pkg/pmapi/clientmanager.go
@@ -230,6 +230,14 @@ func (cm *ClientManager) switchToReachableServer() (proxy string, err error) {
return
}
+ // If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
+ if proxy == rootURL {
+ logrus.Info("The standard API is reachable again; connection drop was only intermittent")
+ err = ErrAPINotReachable
+ cm.host = proxy
+ return
+ }
+
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
@@ -243,7 +251,7 @@ func (cm *ClientManager) switchToReachableServer() (proxy string, err error) {
cm.host = proxy
- return
+ return proxy, err
}
// GetToken returns the token for the given userID.
diff --git a/pkg/pmapi/dialer.go b/pkg/pmapi/dialer.go
new file mode 100644
index 00000000..c3b725e1
--- /dev/null
+++ b/pkg/pmapi/dialer.go
@@ -0,0 +1,55 @@
+package pmapi
+
+import (
+ "crypto/tls"
+ "net"
+ "net/http"
+ "time"
+)
+
+type TLSDialer interface {
+ DialTLS(network, address string) (conn net.Conn, err error)
+}
+
+// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
+func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
+ return &http.Transport{
+ DialTLS: dialer.DialTLS,
+
+ Proxy: http.ProxyFromEnvironment,
+ MaxIdleConns: 100,
+ IdleConnTimeout: 5 * time.Minute,
+ ExpectContinueTimeout: 500 * time.Millisecond,
+
+ // GODT-126: this was initially 10s but logs from users showed a significant number
+ // were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
+ // Bumping to 30s for now to avoid this problem.
+ ResponseHeaderTimeout: 30 * time.Second,
+
+ // If we allow up to 30 seconds for response headers, it is reasonable to allow up
+ // to 30 seconds for the TLS handshake to take place.
+ TLSHandshakeTimeout: 30 * time.Second,
+ }
+}
+
+// BasicTLSDialer implements TLSDialer.
+type BasicTLSDialer struct{}
+
+// NewBasicTLSDialer returns a new BasicTLSDialer.
+func NewBasicTLSDialer() *BasicTLSDialer {
+ return &BasicTLSDialer{}
+}
+
+// DialTLS returns a connection to the given address using the given network.
+func (b *BasicTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
+ dialer := &net.Dialer{Timeout: 10 * time.Second}
+
+ var tlsConfig *tls.Config = nil
+
+ // If we are not dialing the standard API then we should skip cert verification checks.
+ if address != rootURL {
+ tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
+ }
+
+ return tls.DialWithDialer(dialer, network, address, tlsConfig)
+}
diff --git a/pkg/pmapi/dialer_pinning.go b/pkg/pmapi/dialer_pinning.go
new file mode 100644
index 00000000..358bcba4
--- /dev/null
+++ b/pkg/pmapi/dialer_pinning.go
@@ -0,0 +1,99 @@
+// Copyright (c) 2020 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package pmapi
+
+import (
+ "crypto/tls"
+ "net"
+ "time"
+
+ "github.com/sirupsen/logrus"
+)
+
+// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
+// to report errors if the fingerprint check fails.
+type PinningTLSDialer struct {
+ dialer TLSDialer
+
+ // pinChecker is used to check TLS keys of connections.
+ pinChecker PinChecker
+
+ // appVersion is supplied if there is a TLS mismatch.
+ appVersion string
+
+ // tlsIssueNotifier is used to notify something when there is a TLS issue.
+ tlsIssueNotifier func()
+
+ // enableRemoteReporting instructs the dialer to report TLS mismatches.
+ enableRemoteReporting bool
+
+ // A logger for logging messages.
+ log logrus.FieldLogger
+}
+
+// NewPinningTLSDialer constructs a new dialer which only returns tcp connections to servers
+// which present known certificates.
+// If enabled, it reports any invalid certificates it finds.
+func NewPinningTLSDialer(dialer TLSDialer, appVersion string) *PinningTLSDialer {
+ return &PinningTLSDialer{
+ dialer: dialer,
+ pinChecker: NewPinChecker(TrustedAPIPins),
+ appVersion: appVersion,
+ log: logrus.WithField("pkg", "pmapi/tls-pinning"),
+ }
+}
+
+func (p *PinningTLSDialer) SetTLSIssueNotifier(notifier func()) {
+ p.tlsIssueNotifier = notifier
+}
+
+func (p *PinningTLSDialer) SetRemoteTLSIssueReporting(enabled bool) {
+ p.enableRemoteReporting = enabled
+}
+
+// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
+func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
+ if conn, err = p.dialer.DialTLS(network, address); err != nil {
+ return
+ }
+
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return
+ }
+
+ if err = p.pinChecker.CheckCertificate(conn); err != nil {
+ if p.tlsIssueNotifier != nil {
+ go p.tlsIssueNotifier()
+ }
+
+ if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting {
+ p.pinChecker.ReportCertIssue(
+ host,
+ port,
+ time.Now().Format(time.RFC3339),
+ tlsConn.ConnectionState(),
+ p.appVersion,
+ )
+ }
+
+ return
+ }
+
+ return
+}
diff --git a/pkg/pmapi/dialer_with_proxy_test.go b/pkg/pmapi/dialer_pinning_test.go
similarity index 72%
rename from pkg/pmapi/dialer_with_proxy_test.go
rename to pkg/pmapi/dialer_pinning_test.go
index 53f526e3..24c5c2f6 100644
--- a/pkg/pmapi/dialer_with_proxy_test.go
+++ b/pkg/pmapi/dialer_pinning_test.go
@@ -30,19 +30,21 @@ var testLiveConfig = &ClientConfig{
ClientID: "Bridge",
}
-func setTestDialerWithPinning(cm *ClientManager) (*int, *DialerWithPinning) {
+func createAndSetPinningDialer(cm *ClientManager) (*int, *PinningTLSDialer) {
called := 0
- p := NewDialerWithPinning(cm, testLiveConfig.AppVersion)
- p.ReportCertIssueLocal = func() { called++ }
- cm.SetRoundTripper(p.TransportWithPinning())
- return &called, p
+
+ dialer := NewPinningTLSDialer(NewBasicTLSDialer(), testLiveConfig.AppVersion)
+ dialer.SetTLSIssueNotifier(func() { called++ })
+ cm.SetRoundTripper(CreateTransportWithDialer(dialer))
+
+ return &called, dialer
}
func TestTLSPinValid(t *testing.T) {
cm := newTestClientManager(testLiveConfig)
cm.host = liveAPI
rootScheme = "https"
- called, _ := setTestDialerWithPinning(cm)
+ called, _ := createAndSetPinningDialer(cm)
client := cm.GetClient("pmapi" + t.Name())
_, err := client.AuthInfo("this.address.is.disabled")
@@ -54,9 +56,9 @@ func TestTLSPinValid(t *testing.T) {
func TestTLSPinBackup(t *testing.T) {
cm := newTestClientManager(testLiveConfig)
cm.host = liveAPI
- called, p := setTestDialerWithPinning(cm)
- p.report.KnownPins[1] = p.report.KnownPins[0]
- p.report.KnownPins[0] = ""
+ called, p := createAndSetPinningDialer(cm)
+ p.pinChecker.trustedPins[1] = p.pinChecker.trustedPins[0]
+ p.pinChecker.trustedPins[0] = ""
client := cm.GetClient("pmapi" + t.Name())
@@ -70,9 +72,9 @@ func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
cm.host = liveAPI
- called, p := setTestDialerWithPinning(cm)
- for i := 0; i < len(p.report.KnownPins); i++ {
- p.report.KnownPins[i] = "testing"
+ called, p := createAndSetPinningDialer(cm)
+ for i := 0; i < len(p.pinChecker.trustedPins); i++ {
+ p.pinChecker.trustedPins[i] = "testing"
}
client := cm.GetClient("pmapi" + t.Name())
@@ -96,7 +98,7 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
}))
defer ts.Close()
- called, _ := setTestDialerWithPinning(cm)
+ called, _ := createAndSetPinningDialer(cm)
client := cm.GetClient("pmapi" + t.Name())
@@ -113,23 +115,23 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
func _TestTLSSignedCertWrongPublicKey(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
- _, dialer := setTestDialerWithPinning(cm)
- _, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443")
+ _, dialer := createAndSetPinningDialer(cm)
+ _, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
Assert(t, err != nil, "expected dial to fail because of wrong public key: ", err.Error())
}
func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
- _, dialer := setTestDialerWithPinning(cm)
- dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
- _, err := dialer.dialAndCheckFingerprints("tcp", "rsa4096.badssl.com:443")
+ _, dialer := createAndSetPinningDialer(cm)
+ dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
+ _, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
Assert(t, err == nil, "expected dial to succeed because public key is known and cert is signed by CA: ", err.Error())
}
func _TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
- _, dialer := setTestDialerWithPinning(cm)
- dialer.report.KnownPins = append(dialer.report.KnownPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
- _, err := dialer.dialAndCheckFingerprints("tcp", "self-signed.badssl.com:443")
+ _, dialer := createAndSetPinningDialer(cm)
+ dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
+ _, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443")
Assert(t, err == nil, "expected dial to succeed because public key is known despite cert being self-signed: ", err.Error())
}
diff --git a/pkg/pmapi/dialer_proxy.go b/pkg/pmapi/dialer_proxy.go
new file mode 100644
index 00000000..11af0128
--- /dev/null
+++ b/pkg/pmapi/dialer_proxy.go
@@ -0,0 +1,40 @@
+package pmapi
+
+import (
+ "net"
+)
+
+// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
+type ProxyTLSDialer struct {
+ dialer TLSDialer
+
+ cm *ClientManager
+}
+
+// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
+func NewProxyTLSDialer(dialer TLSDialer, cm *ClientManager) *ProxyTLSDialer {
+ return &ProxyTLSDialer{
+ dialer: dialer,
+ cm: cm,
+ }
+}
+
+// DialTLS dials the given network/address. If it fails, it retries using a proxy.
+func (d *ProxyTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
+ if conn, err = d.dialer.DialTLS(network, address); err == nil {
+ return
+ }
+
+ var proxy string
+
+ if proxy, err = d.cm.switchToReachableServer(); err != nil {
+ return
+ }
+
+ _, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return
+ }
+
+ return d.dialer.DialTLS(network, net.JoinHostPort(proxy, port))
+}
diff --git a/pkg/pmapi/dialer_with_proxy.go b/pkg/pmapi/dialer_with_proxy.go
deleted file mode 100644
index d669b73e..00000000
--- a/pkg/pmapi/dialer_with_proxy.go
+++ /dev/null
@@ -1,374 +0,0 @@
-// Copyright (c) 2020 Proton Technologies AG
-//
-// This file is part of ProtonMail Bridge.
-//
-// ProtonMail Bridge is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// ProtonMail Bridge is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with ProtonMail Bridge. If not, see .
-
-package pmapi
-
-import (
- "bytes"
- "crypto/sha256"
- "crypto/tls"
- "crypto/x509"
- "encoding/base64"
- "encoding/json"
- "encoding/pem"
- "fmt"
- "io/ioutil"
- "net"
- "net/http"
- "strconv"
- "time"
-
- "github.com/sirupsen/logrus"
-)
-
-// TLSReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
-type TLSReport struct {
- // DateTime of observed pin validation in time.RFC3339 format.
- DateTime string `json:"date-time"`
-
- // Hostname to which the UA made original request that failed pin validation.
- Hostname string `json:"hostname"`
-
- // Port to which the UA made original request that failed pin validation.
- Port int `json:"port"`
-
- // EffectiveExpirationDate for noted pins in time.RFC3339 format.
- EffectiveExpirationDate string `json:"effective-expiration-date"`
-
- // IncludeSubdomains indicates whether or not the UA has noted the
- // includeSubDomains directive for the Known Pinned Host.
- IncludeSubdomains bool `json:"include-subdomains"`
-
- // NotedHostname indicates the hostname that the UA noted when it noted
- // the Known Pinned Host. This field allows operators to understand why
- // Pin Validation was performed for, e.g., foo.example.com when the
- // noted Known Pinned Host was example.com with includeSubDomains set.
- NotedHostname string `json:"noted-hostname"`
-
- // ServedCertificateChain is the certificate chain, as served by
- // the Known Pinned Host during TLS session setup. It is provided as an
- // array of strings; each string pem1, ... pemN is the Privacy-Enhanced
- // Mail (PEM) representation of each X.509 certificate as described in
- // [RFC7468].
- ServedCertificateChain []string `json:"served-certificate-chain"`
-
- // ValidatedCertificateChain is the certificate chain, as
- // constructed by the UA during certificate chain verification. (This
- // may differ from the served-certificate-chain.) It is provided as an
- // array of strings; each string pem1, ... pemN is the PEM
- // representation of each X.509 certificate as described in [RFC7468].
- // UAs that build certificate chains in more than one way during the
- // validation process SHOULD send the last chain built. In this way,
- // they can avoid keeping too much state during the validation process.
- ValidatedCertificateChain []string `json:"validated-certificate-chain"`
-
- // The known-pins are the Pins that the UA has noted for the Known
- // Pinned Host. They are provided as an array of strings with the
- // syntax: known-pin = token "=" quoted-string
- // e.g.:
- // ```
- // "known-pins": [
- // 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="',
- // "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""
- // ]
- // ```
- KnownPins []string `json:"known-pins"`
-
- // AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit.
- AppVersion string `json:"app-version"`
-}
-
-// ErrTLSMatch indicates that no TLS fingerprint match could be found.
-var ErrTLSMatch = fmt.Errorf("TLS fingerprint match not found")
-
-// DialerWithPinning will provide dial function which checks the fingerprints of public cert
-// received from contacted server. If no match found among know pinse it will report using
-// ReportCertIssueLocal.
-type DialerWithPinning struct {
- // isReported will stop reporting if true.
- isReported bool
-
- // report stores known pins.
- report TLSReport
-
- // When reportURI is not empty the tls issue report will be send to this URI.
- reportURI string
-
- // ReportCertIssueLocal is used send signal to application about certificate issue.
- // It is used only if set.
- ReportCertIssueLocal func()
-
- // cm is used to find and switch to a proxy if necessary.
- cm *ClientManager
-
- // A logger for logging messages.
- log logrus.FieldLogger
-}
-
-// NewDialerWithPinning constructs a new dialer with pinned certs.
-func NewDialerWithPinning(cm *ClientManager, appVersion string) *DialerWithPinning {
- reportURI := "https://reports.protonmail.ch/reports/tls"
-
- report := TLSReport{
- EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
- IncludeSubdomains: false,
- ValidatedCertificateChain: []string{},
- ServedCertificateChain: []string{},
- AppVersion: appVersion,
-
- // NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
- KnownPins: []string{
- `pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
- `pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
- `pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
- `pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
- `pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
- `pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
- `pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
- },
- }
-
- log := logrus.WithField("pkg", "pmapi/tls-pinning")
-
- return &DialerWithPinning{
- cm: cm,
- isReported: false,
- reportURI: reportURI,
- report: report,
- log: log,
- }
-}
-
-func (p *DialerWithPinning) reportCertIssue(connState tls.ConnectionState) {
- p.isReported = true
-
- if p.ReportCertIssueLocal != nil {
- go p.ReportCertIssueLocal()
- }
-
- if p.reportURI != "" {
- p.report.NotedHostname = connState.ServerName
- p.report.ServedCertificateChain = marshalCert7468(connState.PeerCertificates)
-
- if len(connState.VerifiedChains) > 0 {
- p.report.ServedCertificateChain = marshalCert7468(
- connState.VerifiedChains[len(connState.VerifiedChains)-1],
- )
- }
-
- go p.reportCertIssueRemote()
- }
-}
-
-func (p *DialerWithPinning) reportCertIssueRemote() {
- b, err := json.Marshal(p.report)
- if err != nil {
- p.log.Errorf("marshal request: %v", err)
- return
- }
-
- req, err := http.NewRequest("POST", p.reportURI, bytes.NewReader(b))
- if err != nil {
- p.log.Errorf("create request: %v", err)
- }
-
- req.Header.Add("Content-Type", "application/json")
- req.Header.Set("User-Agent", CurrentUserAgent)
- req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
- req.Header.Set("x-pm-appversion", p.report.AppVersion)
-
- p.log.Debugf("report req: %+v\n", req)
-
- c := &http.Client{}
- res, err := c.Do(req)
- p.log.Debugf("res: %+v\nerr: %v", res, err)
- if err != nil {
- return
- }
- _, _ = ioutil.ReadAll(res.Body)
- if res.StatusCode != http.StatusOK {
- p.log.Errorf("response status: %v", res.Status)
- }
- _ = res.Body.Close()
-}
-
-func certFingerprint(cert *x509.Certificate) string {
- hash := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
- return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
-}
-
-func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
- var buffer bytes.Buffer
- for _, cert := range certs {
- if err := pem.Encode(&buffer, &pem.Block{
- Type: "CERTIFICATE",
- Bytes: cert.Raw,
- }); err != nil {
- logrus.WithField("pkg", "pmapi/tls-pinning").Errorf("encoding TLS cert: %v", err)
- }
- pemCerts = append(pemCerts, buffer.String())
- buffer.Reset()
- }
-
- return pemCerts
-}
-
-// TransportWithPinning creates an http.Transport that checks fingerprints when dialing.
-func (p *DialerWithPinning) TransportWithPinning() *http.Transport {
- return &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- DialTLS: p.dialAndCheckFingerprints,
- MaxIdleConns: 100,
- IdleConnTimeout: 5 * time.Minute,
- ExpectContinueTimeout: 500 * time.Millisecond,
-
- // GODT-126: this was initially 10s but logs from users showed a significant number
- // were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
- // Bumping to 30s for now to avoid this problem.
- ResponseHeaderTimeout: 30 * time.Second,
-
- // If we allow up to 30 seconds for response headers, it is reasonable to allow up
- // to 30 seconds for the TLS handshake to take place.
- TLSHandshakeTimeout: 30 * time.Second,
- }
-}
-
-// dialAndCheckFingerprint to set as http.Transport.DialTLS.
-//
-// * note that when DialTLS is not nil the Transport.TLSClientConfig and Transport.TLSHandshakeTimeout are ignored.
-// * dialAndCheckFingerprints fails if certificate is not valid (not signed by authority or not matching hostname).
-// * dialAndCheckFingerprints will pass if certificate pin does not have a match, but will send notification using
-// p.ReportCertIssueLocal() and p.reportCertIssueRemote() if they are not nil.
-func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (conn net.Conn, err error) {
- // If DoH is enabled, we hardfail on fingerprint mismatches.
- if p.cm.IsProxyAllowed() && p.isReported {
- return nil, ErrTLSMatch
- }
-
- // Try to dial the given address but use a proxy if necessary.
- if conn, err = p.dialWithProxyFallback(network, address); err != nil {
- return
- }
-
- // If cert issue was already reported, we don't want to check fingerprints anymore.
- if p.isReported {
- return nil, ErrTLSMatch
- }
-
- // Check the cert fingerprint to ensure it is known.
- if err = p.checkFingerprints(conn); err != nil {
- p.log.WithError(err).Error("Error checking cert fingerprints")
- return
- }
-
- return
-}
-
-// dialWithProxyFallback tries to dial the given address but falls back to alternative proxies if need be.
-func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) {
- p.log.Info("Dialing with proxy fallback")
-
- // Try to dial, and if it succeeds, then just return.
- if conn, err = p.dial(network, address); err == nil {
- return
- }
-
- p.log.WithField("address", address).WithError(err).Error("Dialing failed")
-
- host, port, err := net.SplitHostPort(address)
- if err != nil {
- return
- }
-
- // If DoH is not allowed, give up. Or, if we are dialing something other than the API
- // (e.g. we dial protonmail.com/... to check for updates), there's also no point in
- // continuing since a proxy won't help us reach that.
- if !p.cm.IsProxyAllowed() || host != p.cm.getHost() {
- p.log.WithField("address", address).Debug("Aborting dial, cannot switch to a proxy")
- return
- }
-
- // Switch to a proxy and retry the dial.
- proxy, err := p.cm.switchToReachableServer()
- if err != nil {
- return
- }
-
- proxyAddress := net.JoinHostPort(proxy, port)
-
- p.log.WithField("address", proxyAddress).Debug("Trying dial again using a proxy")
-
- return p.dial(network, proxyAddress)
-}
-
-// dial returns a connection to the given address using the given network.
-func (p *DialerWithPinning) dial(network, address string) (conn net.Conn, err error) {
- var port string
- if p.report.Hostname, port, err = net.SplitHostPort(address); err != nil {
- return
- }
- if p.report.Port, err = strconv.Atoi(port); err != nil {
- return
- }
- p.report.DateTime = time.Now().Format(time.RFC3339)
-
- dialer := &net.Dialer{Timeout: 10 * time.Second}
-
- // If we are not dialing the standard API then we should skip cert verification checks.
- var tlsConfig *tls.Config = nil
- if address != rootURL {
- tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
- }
-
- return tls.DialWithDialer(dialer, network, address, tlsConfig)
-}
-
-func (p *DialerWithPinning) checkFingerprints(conn net.Conn) (err error) {
- if !checkTLSCerts {
- return
- }
-
- connState := conn.(*tls.Conn).ConnectionState()
-
- hasFingerprintMatch := false
- for _, peerCert := range connState.PeerCertificates {
- fingerprint := certFingerprint(peerCert)
-
- for i, pin := range p.report.KnownPins {
- if pin == fingerprint {
- hasFingerprintMatch = true
-
- if i != 0 {
- p.log.Warnf("Matched fingerprint (%q) was not primary pinned key (was key #%d)", fingerprint, i)
- }
-
- break
- }
- }
-
- if hasFingerprintMatch {
- break
- }
- }
-
- if !hasFingerprintMatch {
- p.reportCertIssue(connState)
- return ErrTLSMatch
- }
-
- return err
-}
diff --git a/pkg/pmapi/pin_checker.go b/pkg/pmapi/pin_checker.go
new file mode 100644
index 00000000..be8cc699
--- /dev/null
+++ b/pkg/pmapi/pin_checker.go
@@ -0,0 +1,77 @@
+package pmapi
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/pem"
+ "fmt"
+ "net"
+
+ "github.com/sirupsen/logrus"
+)
+
+type PinChecker struct {
+ trustedPins []string
+}
+
+func NewPinChecker(trustedPins []string) PinChecker {
+ return PinChecker{
+ trustedPins: trustedPins,
+ }
+}
+
+// CheckCertificate returns whether the connection presents a known TLS certificate.
+func (p *PinChecker) CheckCertificate(conn net.Conn) error {
+ connState := conn.(*tls.Conn).ConnectionState()
+
+ for _, peerCert := range connState.PeerCertificates {
+ fingerprint := certFingerprint(peerCert)
+
+ for _, pin := range p.trustedPins {
+ if pin == fingerprint {
+ return nil
+ }
+ }
+ }
+
+ return ErrTLSMismatch
+}
+
+func certFingerprint(cert *x509.Certificate) string {
+ hash := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
+ return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
+}
+
+// ReportCertIssue reports a TLS key mismatch.
+func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls.ConnectionState, appVersion string) {
+ var certChain []string
+
+ if len(connState.VerifiedChains) > 0 {
+ certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1])
+ } else {
+ certChain = marshalCert7468(connState.PeerCertificates)
+ }
+
+ report := NewTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion)
+
+ go postCertIssueReport(report)
+}
+
+func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
+ var buffer bytes.Buffer
+ for _, cert := range certs {
+ if err := pem.Encode(&buffer, &pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: cert.Raw,
+ }); err != nil {
+ logrus.WithField("pkg", "pmapi/tls-pinning").Errorf("encoding TLS cert: %v", err)
+ }
+ pemCerts = append(pemCerts, buffer.String())
+ buffer.Reset()
+ }
+
+ return pemCerts
+}
diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go
index 5bf3cadd..ff147123 100644
--- a/pkg/pmapi/proxy.go
+++ b/pkg/pmapi/proxy.go
@@ -18,7 +18,6 @@
package pmapi
import (
- "crypto/tls"
"encoding/base64"
"strings"
"time"
@@ -85,7 +84,8 @@ func (p *proxyProvider) findReachableServer() (proxy string, err error) {
errResult := make(chan error)
go func() {
if err = p.refreshProxyCache(); err != nil {
- logrus.WithError(err).Warn("Failed to refresh proxy cache, cache may be out of date")
+ errResult <- errors.Wrap(err, "failed to refresh proxy cache")
+ return
}
// We want to switch back to the rootURL if possible.
@@ -144,10 +144,12 @@ func (p *proxyProvider) canReach(url string) bool {
url = "https://" + url
}
+ pinningDialer := NewPinningTLSDialer(NewBasicTLSDialer(), "")
+
pinger := resty.New().
SetHostURL(url).
SetTimeout(p.lookupTimeout).
- SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec]
+ SetTransport(CreateTransportWithDialer(pinningDialer))
if _, err := pinger.R().Get("/tests/ping"); err != nil {
logrus.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
diff --git a/pkg/pmapi/proxy_test.go b/pkg/pmapi/proxy_test.go
index a8e709af..30495b04 100644
--- a/pkg/pmapi/proxy_test.go
+++ b/pkg/pmapi/proxy_test.go
@@ -18,6 +18,7 @@
package pmapi
import (
+ "crypto/tls"
"net/http"
"net/http/httptest"
"testing"
@@ -32,12 +33,114 @@ const (
TestGoogleProvider = "https://dns.google/dns-query"
)
+// getTrustedServer returns a server and sets its public key as one of the pinned ones.
+func getTrustedServer() *httptest.Server {
+ proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+
+ pin := certFingerprint(proxy.Certificate())
+ TrustedAPIPins = append(TrustedAPIPins, pin)
+
+ return proxy
+}
+
+// server.crt
+const servercrt = `
+-----BEGIN CERTIFICATE-----
+MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD
+VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp
+dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t
+T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j
+b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx
+MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR
+BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf
+MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR
+aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI
+hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt
+r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2
+pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM
+GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru
+rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c
+kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw
+ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI
+DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu
+ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0
+MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3
+LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R
+BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5
+L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA
+CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P
+3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg
+yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l
+z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc
+u53wgIhCJGuX
+-----END CERTIFICATE-----
+`
+
+const serverkey = `
+-----BEGIN PRIVATE KEY-----
+MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx
+owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG
+ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq
+UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx
+8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr
+n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6
+mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1
+EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ
+/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F
+qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT
+VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu
+Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h
+bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X
+TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU
+LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f
+pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm
+nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3
+Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5
+ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo
+w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK
+PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt
+FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0
+ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD
+z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2
+FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o
+vwRMog6lPhlRhHh/FZ43Cg==
+-----END PRIVATE KEY-----
+`
+
+// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones.
+func getUntrustedServer() *httptest.Server {
+ server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+
+ cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey))
+ if err != nil {
+ panic(err)
+ }
+ server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
+
+ server.StartTLS()
+ return server
+}
+
+// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys.
+func closeServer(server *httptest.Server) {
+ pin := certFingerprint(server.Certificate())
+
+ for i := range TrustedAPIPins {
+ if TrustedAPIPins[i] == pin {
+ TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...)
+ break
+ }
+ }
+
+ server.Close()
+}
+
func TestProxyProvider_FindProxy(t *testing.T) {
blockAPI()
defer unblockAPI()
- proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy.Close()
+ proxy := getTrustedServer()
+ defer closeServer(proxy)
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
@@ -51,34 +154,72 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
blockAPI()
defer unblockAPI()
- badProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- goodProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+ reachableProxy := getTrustedServer()
+ defer closeServer(reachableProxy)
- // Close the bad proxy first so it isn't reachable; we should then choose the good proxy.
- badProxy.Close()
- defer goodProxy.Close()
+ // We actually close the unreachable proxy straight away rather than deferring the closure.
+ unreachableProxy := getTrustedServer()
+ closeServer(unreachableProxy)
p := newProxyProvider([]string{"not used"}, "not used")
- p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil }
+ p.dohLookup = func(q, p string) ([]string, error) { return []string{reachableProxy.URL, unreachableProxy.URL}, nil }
url, err := p.findReachableServer()
require.NoError(t, err)
- require.Equal(t, goodProxy.URL, url)
+ require.Equal(t, reachableProxy.URL, url)
+}
+
+func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
+ blockAPI()
+ defer unblockAPI()
+
+ trustedProxy := getTrustedServer()
+ defer closeServer(trustedProxy)
+
+ untrustedProxy := getUntrustedServer()
+ defer closeServer(untrustedProxy)
+
+ p := newProxyProvider([]string{"not used"}, "not used")
+ p.dohLookup = func(q, p string) ([]string, error) { return []string{untrustedProxy.URL, trustedProxy.URL}, nil }
+
+ url, err := p.findReachableServer()
+ require.NoError(t, err)
+ require.Equal(t, trustedProxy.URL, url)
}
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
blockAPI()
defer unblockAPI()
- badProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- anotherBadProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+ unreachableProxy1 := getTrustedServer()
+ closeServer(unreachableProxy1)
- // Close the proxies to simulate them not being reachable.
- badProxy.Close()
- anotherBadProxy.Close()
+ unreachableProxy2 := getTrustedServer()
+ closeServer(unreachableProxy2)
p := newProxyProvider([]string{"not used"}, "not used")
- p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil }
+ p.dohLookup = func(q, p string) ([]string, error) {
+ return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
+ }
+
+ _, err := p.findReachableServer()
+ require.Error(t, err)
+}
+
+func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
+ blockAPI()
+ defer unblockAPI()
+
+ untrustedProxy1 := getUntrustedServer()
+ defer closeServer(untrustedProxy1)
+
+ untrustedProxy2 := getUntrustedServer()
+ defer closeServer(untrustedProxy2)
+
+ p := newProxyProvider([]string{"not used"}, "not used")
+ p.dohLookup = func(q, p string) ([]string, error) {
+ return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
+ }
_, err := p.findReachableServer()
require.Error(t, err)
@@ -88,9 +229,6 @@ func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) {
blockAPI()
defer unblockAPI()
- proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy.Close()
-
p := newProxyProvider([]string{"not used"}, "not used")
p.lookupTimeout = time.Second
p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
@@ -124,17 +262,17 @@ func TestProxyProvider_UseProxy(t *testing.T) {
cm := newTestClientManager(testClientConfig)
- proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy.Close()
+ trustedProxy := getTrustedServer()
+ defer closeServer(trustedProxy)
p := newProxyProvider([]string{"not used"}, "not used")
cm.proxyProvider = p
- p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
+ p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
url, err := cm.switchToReachableServer()
require.NoError(t, err)
- require.Equal(t, proxy.URL, url)
- require.Equal(t, proxy.URL, cm.getHost())
+ require.Equal(t, trustedProxy.URL, url)
+ require.Equal(t, trustedProxy.URL, cm.getHost())
}
func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
@@ -143,12 +281,12 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
cm := newTestClientManager(testClientConfig)
- proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy1.Close()
- proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy2.Close()
- proxy3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy3.Close()
+ proxy1 := getTrustedServer()
+ defer closeServer(proxy1)
+ proxy2 := getTrustedServer()
+ defer closeServer(proxy2)
+ proxy3 := getTrustedServer()
+ defer closeServer(proxy3)
p := newProxyProvider([]string{"not used"}, "not used")
cm.proxyProvider = p
@@ -184,18 +322,18 @@ func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
cm := newTestClientManager(testClientConfig)
- proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy.Close()
+ trustedProxy := getTrustedServer()
+ defer closeServer(trustedProxy)
p := newProxyProvider([]string{"not used"}, "not used")
cm.proxyProvider = p
cm.proxyUseDuration = time.Second
- p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
+ p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
url, err := cm.switchToReachableServer()
require.NoError(t, err)
- require.Equal(t, proxy.URL, url)
- require.Equal(t, proxy.URL, cm.getHost())
+ require.Equal(t, trustedProxy.URL, url)
+ require.Equal(t, trustedProxy.URL, cm.getHost())
time.Sleep(2 * time.Second)
require.Equal(t, rootURL, cm.getHost())
@@ -207,26 +345,27 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab
cm := newTestClientManager(testClientConfig)
- proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy.Close()
+ trustedProxy := getTrustedServer()
p := newProxyProvider([]string{"not used"}, "not used")
cm.proxyProvider = p
- p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
+ p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
url, err := cm.switchToReachableServer()
require.NoError(t, err)
- require.Equal(t, proxy.URL, url)
- require.Equal(t, proxy.URL, cm.getHost())
+ require.Equal(t, trustedProxy.URL, url)
+ require.Equal(t, trustedProxy.URL, cm.getHost())
// Simulate that the proxy stops working and that the standard api is reachable again.
- proxy.Close()
+ closeServer(trustedProxy)
unblockAPI()
time.Sleep(proxyLookupWait)
// We should now find the original API URL if it is working again.
+ // The error should be ErrAPINotReachable because the connection dropped intermittently but
+ // the original API is now reachable (see Alternative-Routing-v2 spec for details).
url, err = cm.switchToReachableServer()
- require.NoError(t, err)
+ require.EqualError(t, err, ErrAPINotReachable.Error())
require.Equal(t, rootURL, url)
require.Equal(t, rootURL, cm.getHost())
}
@@ -237,10 +376,11 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl
cm := newTestClientManager(testClientConfig)
- proxy1 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy1.Close()
- proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- defer proxy2.Close()
+ // proxy1 is closed later in this test so we don't defer it here.
+ proxy1 := getTrustedServer()
+
+ proxy2 := getTrustedServer()
+ defer closeServer(proxy2)
p := newProxyProvider([]string{"not used"}, "not used")
cm.proxyProvider = p
diff --git a/pkg/pmapi/tlsreport.go b/pkg/pmapi/tlsreport.go
new file mode 100644
index 00000000..058ecac9
--- /dev/null
+++ b/pkg/pmapi/tlsreport.go
@@ -0,0 +1,145 @@
+package pmapi
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "strconv"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/sirupsen/logrus"
+)
+
+// ErrTLSMismatch indicates that no TLS fingerprint match could be found.
+var ErrTLSMismatch = errors.New("no TLS fingerprint match found")
+
+// TrustedAPIPins contains trusted public keys of the protonmail API and proxies.
+// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
+var TrustedAPIPins = []string{ // nolint[gochecknoglobals]
+ `pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
+ `pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
+ `pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
+ `pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
+ `pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
+ `pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
+ `pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
+}
+
+// TLSReportURI is the address where TLS reports should be sent.
+const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
+
+// TLSReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
+// When a TLS key mismatch is detected, a TLSReport is posted to TLSReportURI.
+type TLSReport struct {
+ // DateTime of observed pin validation in time.RFC3339 format.
+ DateTime string `json:"date-time"`
+
+ // Hostname to which the UA made original request that failed pin validation.
+ Hostname string `json:"hostname"`
+
+ // Port to which the UA made original request that failed pin validation.
+ Port int `json:"port"`
+
+ // EffectiveExpirationDate for noted pins in time.RFC3339 format.
+ EffectiveExpirationDate string `json:"effective-expiration-date"`
+
+ // IncludeSubdomains indicates whether or not the UA has noted the
+ // includeSubDomains directive for the Known Pinned Host.
+ IncludeSubdomains bool `json:"include-subdomains"`
+
+ // NotedHostname indicates the hostname that the UA noted when it noted
+ // the Known Pinned Host. This field allows operators to understand why
+ // Pin Validation was performed for, e.g., foo.example.com when the
+ // noted Known Pinned Host was example.com with includeSubDomains set.
+ NotedHostname string `json:"noted-hostname"`
+
+ // ServedCertificateChain is the certificate chain, as served by
+ // the Known Pinned Host during TLS session setup. It is provided as an
+ // array of strings; each string pem1, ... pemN is the Privacy-Enhanced
+ // Mail (PEM) representation of each X.509 certificate as described in
+ // [RFC7468].
+ ServedCertificateChain []string `json:"served-certificate-chain"`
+
+ // ValidatedCertificateChain is the certificate chain, as
+ // constructed by the UA during certificate chain verification. (This
+ // may differ from the served-certificate-chain.) It is provided as an
+ // array of strings; each string pem1, ... pemN is the PEM
+ // representation of each X.509 certificate as described in [RFC7468].
+ // UAs that build certificate chains in more than one way during the
+ // validation process SHOULD send the last chain built. In this way,
+ // they can avoid keeping too much state during the validation process.
+ ValidatedCertificateChain []string `json:"validated-certificate-chain"`
+
+ // The known-pins are the Pins that the UA has noted for the Known
+ // Pinned Host. They are provided as an array of strings with the
+ // syntax: known-pin = token "=" quoted-string
+ // e.g.:
+ // ```
+ // "known-pins": [
+ // 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="',
+ // "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""
+ // ]
+ // ```
+ KnownPins []string `json:"known-pins"`
+
+ // AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit.
+ AppVersion string `json:"app-version"`
+}
+
+// NewTLSReport constructs a new TLSreport configured with the given app version and known pinned public keys.
+func NewTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report TLSReport) {
+ // If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway.
+ intPort, _ := strconv.Atoi(port)
+
+ report = TLSReport{
+ Hostname: host,
+ Port: intPort,
+ EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
+ IncludeSubdomains: false,
+ NotedHostname: server,
+ ValidatedCertificateChain: []string{},
+ ServedCertificateChain: certChain,
+ KnownPins: knownPins,
+ AppVersion: appVersion,
+ }
+
+ return
+}
+
+// postCertIssueReport posts the given TLS report to the standard TLS Report URI.
+func postCertIssueReport(report TLSReport) {
+ b, err := json.Marshal(report)
+ if err != nil {
+ logrus.WithError(err).Error("Failed to marshal TLS report")
+ return
+ }
+
+ req, err := http.NewRequest("POST", TLSReportURI, bytes.NewReader(b))
+ if err != nil {
+ logrus.WithError(err).Error("Failed to create http request")
+ return
+ }
+
+ req.Header.Add("Content-Type", "application/json")
+ req.Header.Set("User-Agent", CurrentUserAgent)
+ req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
+ req.Header.Set("x-pm-appversion", report.AppVersion)
+
+ logrus.WithField("request", req).Warn("Reporting TLS mismatch")
+ res, err := (&http.Client{}).Do(req)
+ if err != nil {
+ logrus.WithError(err).Error("Failed to report TLS mismatch")
+ return
+ }
+
+ logrus.WithField("response", res).Error("Reported TLS mismatch")
+
+ if res.StatusCode != http.StatusOK {
+ logrus.WithField("status", http.StatusOK).Error("StatusCode was not OK")
+ }
+
+ _, _ = ioutil.ReadAll(res.Body)
+ _ = res.Body.Close()
+}