feat: tls report cache

This commit is contained in:
James Houlahan
2020-05-14 13:34:48 +02:00
parent 6147c214c3
commit bbf1364e30
6 changed files with 117 additions and 38 deletions

View File

@ -43,6 +43,7 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/)
* Logging out of `Clients` is handled/retried asynchronously by `ClientManager` * Logging out of `Clients` is handled/retried asynchronously by `ClientManager`
* GODT-265 Alternative Routing v2 (more resiliant to short term connection drops) * GODT-265 Alternative Routing v2 (more resiliant to short term connection drops)
* GODT-310 Alternative parsing of `References` header (old parsing probably malformed message IDs) * GODT-310 Alternative parsing of `References` header (old parsing probably malformed message IDs)
* GODT-320 Only report the same TLS issue once every 24 hours
### Fixed ### Fixed
* Use correct binary name when finding location of addcert.scpt * Use correct binary name when finding location of addcert.scpt

View File

@ -20,7 +20,6 @@ package pmapi
import ( import (
"crypto/tls" "crypto/tls"
"net" "net"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -31,7 +30,7 @@ type PinningTLSDialer struct {
dialer TLSDialer dialer TLSDialer
// pinChecker is used to check TLS keys of connections. // pinChecker is used to check TLS keys of connections.
pinChecker PinChecker pinChecker pinChecker
// tlsIssueNotifier is used to notify something when there is a TLS issue. // tlsIssueNotifier is used to notify something when there is a TLS issue.
tlsIssueNotifier func() tlsIssueNotifier func()
@ -55,7 +54,7 @@ type PinningTLSDialer struct {
func NewPinningTLSDialer(dialer TLSDialer) *PinningTLSDialer { func NewPinningTLSDialer(dialer TLSDialer) *PinningTLSDialer {
return &PinningTLSDialer{ return &PinningTLSDialer{
dialer: dialer, dialer: dialer,
pinChecker: NewPinChecker(TrustedAPIPins), pinChecker: newPinChecker(TrustedAPIPins),
log: logrus.WithField("pkg", "pmapi/tls-pinning"), log: logrus.WithField("pkg", "pmapi/tls-pinning"),
} }
} }
@ -81,16 +80,16 @@ func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err
return return
} }
if err = p.pinChecker.CheckCertificate(conn); err != nil { if err = p.pinChecker.checkCertificate(conn); err != nil {
if p.tlsIssueNotifier != nil { if p.tlsIssueNotifier != nil {
go p.tlsIssueNotifier() go p.tlsIssueNotifier()
} }
if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting { if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting {
p.pinChecker.ReportCertIssue( p.pinChecker.reportCertIssue(
TLSReportURI,
host, host,
port, port,
time.Now().Format(time.RFC3339),
tlsConn.ConnectionState(), tlsConn.ConnectionState(),
p.appVersion, p.appVersion,
p.userAgent, p.userAgent,

View File

@ -9,22 +9,30 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net" "net"
"time"
"github.com/google/go-cmp/cmp"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type PinChecker struct { type pinChecker struct {
trustedPins []string trustedPins []string
sentReports []sentReport
} }
func NewPinChecker(trustedPins []string) PinChecker { type sentReport struct {
return PinChecker{ r tlsReport
t time.Time
}
func newPinChecker(trustedPins []string) pinChecker {
return pinChecker{
trustedPins: trustedPins, trustedPins: trustedPins,
} }
} }
// CheckCertificate returns whether the connection presents a known TLS certificate. // checkCertificate returns whether the connection presents a known TLS certificate.
func (p *PinChecker) CheckCertificate(conn net.Conn) error { func (p *pinChecker) checkCertificate(conn net.Conn) error {
connState := conn.(*tls.Conn).ConnectionState() connState := conn.(*tls.Conn).ConnectionState()
for _, peerCert := range connState.PeerCertificates { for _, peerCert := range connState.PeerCertificates {
@ -45,8 +53,8 @@ func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:])) return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
} }
// ReportCertIssue reports a TLS key mismatch. // reportCertIssue reports a TLS key mismatch.
func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls.ConnectionState, appVersion, userAgent string) { func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState, appVersion, userAgent string) {
var certChain []string var certChain []string
if len(connState.VerifiedChains) > 0 { if len(connState.VerifiedChains) > 0 {
@ -55,9 +63,38 @@ func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls.
certChain = marshalCert7468(connState.PeerCertificates) certChain = marshalCert7468(connState.PeerCertificates)
} }
report := NewTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion) r := newTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion)
go postCertIssueReport(report, userAgent) if !p.hasRecentlySentReport(r) {
p.recordReport(r)
go r.sendReport(remoteURI, userAgent)
}
}
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
var validReports []sentReport
for _, r := range p.sentReports {
if time.Since(r.t) < 24*time.Hour {
validReports = append(validReports, r)
}
}
p.sentReports = validReports
for _, r := range p.sentReports {
if cmp.Equal(report, r.r) {
return true
}
}
return false
}
// recordReport records the given report and the current time so we can check whether we recently sent this report.
func (p *pinChecker) recordReport(r tlsReport) {
p.sentReports = append(p.sentReports, sentReport{r: r, t: time.Now()})
} }
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) { func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
@ -67,7 +104,7 @@ func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Raw, Bytes: cert.Raw,
}); err != nil { }); err != nil {
logrus.WithField("pkg", "pmapi/tls-pinning").Errorf("encoding TLS cert: %v", err) logrus.WithField("pkg", "pmapi/tls-pinning").WithError(err).Error("Failed to encode TLS certificate")
} }
pemCerts = append(pemCerts, buffer.String()) pemCerts = append(pemCerts, buffer.String())
buffer.Reset() buffer.Reset()

View File

@ -0,0 +1,41 @@
package pmapi
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPinCheckerDoubleReport(t *testing.T) {
reportCounter := 0
reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reportCounter++
}))
pc := newPinChecker(TrustedAPIPins)
// Report the same issue many times.
for i := 0; i < 10; i++ {
pc.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{}, "3", "useragent")
}
// We should only report once.
assert.Eventually(t, func() bool {
return reportCounter == 1
}, time.Second, time.Millisecond)
// If we then report something else many times.
for i := 0; i < 10; i++ {
pc.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{}, "3", "useragent")
}
// We should get a second report.
assert.Eventually(t, func() bool {
return reportCounter == 2
}, time.Second, time.Millisecond)
}

View File

@ -142,7 +142,6 @@ func (p *proxyProvider) refreshProxyCache() error {
} }
// canReach returns whether we can reach the given url. // canReach returns whether we can reach the given url.
// NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname.
func (p *proxyProvider) canReach(url string) bool { func (p *proxyProvider) canReach(url string) bool {
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") { if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url url = "https://" + url

View File

@ -30,9 +30,9 @@ var TrustedAPIPins = []string{ // nolint[gochecknoglobals]
// TLSReportURI is the address where TLS reports should be sent. // TLSReportURI is the address where TLS reports should be sent.
const TLSReportURI = "https://reports.protonmail.ch/reports/tls" const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
// TLSReport is inspired by https://tools.ietf.org/html/rfc7469#section-3. // tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
// When a TLS key mismatch is detected, a TLSReport is posted to TLSReportURI. // When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI.
type TLSReport struct { type tlsReport struct {
// DateTime of observed pin validation in time.RFC3339 format. // DateTime of observed pin validation in time.RFC3339 format.
DateTime string `json:"date-time"` DateTime string `json:"date-time"`
@ -88,35 +88,37 @@ type TLSReport struct {
AppVersion string `json:"app-version"` AppVersion string `json:"app-version"`
} }
// NewTLSReport constructs a new TLSreport configured with the given app version and known pinned public keys. // newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys.
func NewTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report TLSReport) { // Temporal things (current date/time) are not set yet -- they are set when sendReport is called.
func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) {
// If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway. // If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway.
intPort, _ := strconv.Atoi(port) intPort, _ := strconv.Atoi(port)
report = TLSReport{ report = tlsReport{
Hostname: host, Hostname: host,
Port: intPort, Port: intPort,
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339), NotedHostname: server,
IncludeSubdomains: false, ServedCertificateChain: certChain,
NotedHostname: server, KnownPins: knownPins,
ValidatedCertificateChain: []string{}, AppVersion: appVersion,
ServedCertificateChain: certChain,
KnownPins: knownPins,
AppVersion: appVersion,
} }
return return
} }
// postCertIssueReport posts the given TLS report to the standard TLS Report URI. // sendReport posts the given TLS report to the standard TLS Report URI.
func postCertIssueReport(report TLSReport, userAgent string) { func (r tlsReport) sendReport(uri, userAgent string) {
b, err := json.Marshal(report) now := time.Now()
r.DateTime = now.Format(time.RFC3339)
r.EffectiveExpirationDate = now.Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339)
b, err := json.Marshal(r)
if err != nil { if err != nil {
logrus.WithError(err).Error("Failed to marshal TLS report") logrus.WithError(err).Error("Failed to marshal TLS report")
return return
} }
req, err := http.NewRequest("POST", TLSReportURI, bytes.NewReader(b)) req, err := http.NewRequest("POST", uri, bytes.NewReader(b))
if err != nil { if err != nil {
logrus.WithError(err).Error("Failed to create http request") logrus.WithError(err).Error("Failed to create http request")
return return
@ -125,10 +127,10 @@ func postCertIssueReport(report TLSReport, userAgent string) {
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent) req.Header.Set("User-Agent", userAgent)
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version)) req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
req.Header.Set("x-pm-appversion", report.AppVersion) req.Header.Set("x-pm-appversion", r.AppVersion)
logrus.WithField("request", req).Warn("Reporting TLS mismatch") logrus.WithField("request", req).Warn("Reporting TLS mismatch")
res, err := (&http.Client{}).Do(req) res, err := (&http.Client{Transport: CreateTransportWithDialer(NewBasicTLSDialer())}).Do(req)
if err != nil { if err != nil {
logrus.WithError(err).Error("Failed to report TLS mismatch") logrus.WithError(err).Error("Failed to report TLS mismatch")
return return