diff --git a/Changelog.md b/Changelog.md index 3b506ad6..f4cedb62 100644 --- a/Changelog.md +++ b/Changelog.md @@ -43,6 +43,7 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/) * Logging out of `Clients` is handled/retried asynchronously by `ClientManager` * GODT-265 Alternative Routing v2 (more resiliant to short term connection drops) * GODT-310 Alternative parsing of `References` header (old parsing probably malformed message IDs) +* GODT-320 Only report the same TLS issue once every 24 hours ### Fixed * Use correct binary name when finding location of addcert.scpt diff --git a/pkg/pmapi/dialer_pinning.go b/pkg/pmapi/dialer_pinning.go index 25132f3b..f0abe42d 100644 --- a/pkg/pmapi/dialer_pinning.go +++ b/pkg/pmapi/dialer_pinning.go @@ -20,7 +20,6 @@ package pmapi import ( "crypto/tls" "net" - "time" "github.com/sirupsen/logrus" ) @@ -31,7 +30,7 @@ type PinningTLSDialer struct { dialer TLSDialer // pinChecker is used to check TLS keys of connections. - pinChecker PinChecker + pinChecker pinChecker // tlsIssueNotifier is used to notify something when there is a TLS issue. tlsIssueNotifier func() @@ -55,7 +54,7 @@ type PinningTLSDialer struct { func NewPinningTLSDialer(dialer TLSDialer) *PinningTLSDialer { return &PinningTLSDialer{ dialer: dialer, - pinChecker: NewPinChecker(TrustedAPIPins), + pinChecker: newPinChecker(TrustedAPIPins), log: logrus.WithField("pkg", "pmapi/tls-pinning"), } } @@ -81,16 +80,16 @@ func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err return } - if err = p.pinChecker.CheckCertificate(conn); err != nil { + if err = p.pinChecker.checkCertificate(conn); err != nil { if p.tlsIssueNotifier != nil { go p.tlsIssueNotifier() } if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting { - p.pinChecker.ReportCertIssue( + p.pinChecker.reportCertIssue( + TLSReportURI, host, port, - time.Now().Format(time.RFC3339), tlsConn.ConnectionState(), p.appVersion, p.userAgent, diff --git a/pkg/pmapi/pin_checker.go b/pkg/pmapi/pin_checker.go index 8798ffe6..7818e3f5 100644 --- a/pkg/pmapi/pin_checker.go +++ b/pkg/pmapi/pin_checker.go @@ -9,22 +9,30 @@ import ( "encoding/pem" "fmt" "net" + "time" + "github.com/google/go-cmp/cmp" "github.com/sirupsen/logrus" ) -type PinChecker struct { +type pinChecker struct { trustedPins []string + sentReports []sentReport } -func NewPinChecker(trustedPins []string) PinChecker { - return PinChecker{ +type sentReport struct { + r tlsReport + t time.Time +} + +func newPinChecker(trustedPins []string) pinChecker { + return pinChecker{ trustedPins: trustedPins, } } -// CheckCertificate returns whether the connection presents a known TLS certificate. -func (p *PinChecker) CheckCertificate(conn net.Conn) error { +// checkCertificate returns whether the connection presents a known TLS certificate. +func (p *pinChecker) checkCertificate(conn net.Conn) error { connState := conn.(*tls.Conn).ConnectionState() for _, peerCert := range connState.PeerCertificates { @@ -45,8 +53,8 @@ func certFingerprint(cert *x509.Certificate) string { return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:])) } -// ReportCertIssue reports a TLS key mismatch. -func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls.ConnectionState, appVersion, userAgent string) { +// reportCertIssue reports a TLS key mismatch. +func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState, appVersion, userAgent string) { var certChain []string if len(connState.VerifiedChains) > 0 { @@ -55,9 +63,38 @@ func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls. certChain = marshalCert7468(connState.PeerCertificates) } - report := NewTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion) + r := newTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion) - go postCertIssueReport(report, userAgent) + if !p.hasRecentlySentReport(r) { + p.recordReport(r) + go r.sendReport(remoteURI, userAgent) + } +} + +// hasRecentlySentReport returns whether the report was already sent within the last 24 hours. +func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool { + var validReports []sentReport + + for _, r := range p.sentReports { + if time.Since(r.t) < 24*time.Hour { + validReports = append(validReports, r) + } + } + + p.sentReports = validReports + + for _, r := range p.sentReports { + if cmp.Equal(report, r.r) { + return true + } + } + + return false +} + +// recordReport records the given report and the current time so we can check whether we recently sent this report. +func (p *pinChecker) recordReport(r tlsReport) { + p.sentReports = append(p.sentReports, sentReport{r: r, t: time.Now()}) } func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) { @@ -67,7 +104,7 @@ func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) { Type: "CERTIFICATE", Bytes: cert.Raw, }); err != nil { - logrus.WithField("pkg", "pmapi/tls-pinning").Errorf("encoding TLS cert: %v", err) + logrus.WithField("pkg", "pmapi/tls-pinning").WithError(err).Error("Failed to encode TLS certificate") } pemCerts = append(pemCerts, buffer.String()) buffer.Reset() diff --git a/pkg/pmapi/pin_checker_test.go b/pkg/pmapi/pin_checker_test.go new file mode 100644 index 00000000..29d5b3ae --- /dev/null +++ b/pkg/pmapi/pin_checker_test.go @@ -0,0 +1,41 @@ +package pmapi + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestPinCheckerDoubleReport(t *testing.T) { + reportCounter := 0 + + reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reportCounter++ + })) + + pc := newPinChecker(TrustedAPIPins) + + // Report the same issue many times. + for i := 0; i < 10; i++ { + pc.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{}, "3", "useragent") + } + + // We should only report once. + assert.Eventually(t, func() bool { + return reportCounter == 1 + }, time.Second, time.Millisecond) + + // If we then report something else many times. + for i := 0; i < 10; i++ { + pc.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{}, "3", "useragent") + } + + // We should get a second report. + assert.Eventually(t, func() bool { + return reportCounter == 2 + }, time.Second, time.Millisecond) +} diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go index 19ac2b8a..694cd330 100644 --- a/pkg/pmapi/proxy.go +++ b/pkg/pmapi/proxy.go @@ -142,7 +142,6 @@ func (p *proxyProvider) refreshProxyCache() error { } // canReach returns whether we can reach the given url. -// NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname. func (p *proxyProvider) canReach(url string) bool { if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") { url = "https://" + url diff --git a/pkg/pmapi/tlsreport.go b/pkg/pmapi/tlsreport.go index 2e57e0d4..9c366ac7 100644 --- a/pkg/pmapi/tlsreport.go +++ b/pkg/pmapi/tlsreport.go @@ -30,9 +30,9 @@ var TrustedAPIPins = []string{ // nolint[gochecknoglobals] // TLSReportURI is the address where TLS reports should be sent. const TLSReportURI = "https://reports.protonmail.ch/reports/tls" -// TLSReport is inspired by https://tools.ietf.org/html/rfc7469#section-3. -// When a TLS key mismatch is detected, a TLSReport is posted to TLSReportURI. -type TLSReport struct { +// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3. +// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI. +type tlsReport struct { // DateTime of observed pin validation in time.RFC3339 format. DateTime string `json:"date-time"` @@ -88,35 +88,37 @@ type TLSReport struct { AppVersion string `json:"app-version"` } -// NewTLSReport constructs a new TLSreport configured with the given app version and known pinned public keys. -func NewTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report TLSReport) { +// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys. +// Temporal things (current date/time) are not set yet -- they are set when sendReport is called. +func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) { // If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway. intPort, _ := strconv.Atoi(port) - report = TLSReport{ - Hostname: host, - Port: intPort, - EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339), - IncludeSubdomains: false, - NotedHostname: server, - ValidatedCertificateChain: []string{}, - ServedCertificateChain: certChain, - KnownPins: knownPins, - AppVersion: appVersion, + report = tlsReport{ + Hostname: host, + Port: intPort, + NotedHostname: server, + ServedCertificateChain: certChain, + KnownPins: knownPins, + AppVersion: appVersion, } return } -// postCertIssueReport posts the given TLS report to the standard TLS Report URI. -func postCertIssueReport(report TLSReport, userAgent string) { - b, err := json.Marshal(report) +// sendReport posts the given TLS report to the standard TLS Report URI. +func (r tlsReport) sendReport(uri, userAgent string) { + now := time.Now() + r.DateTime = now.Format(time.RFC3339) + r.EffectiveExpirationDate = now.Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339) + + b, err := json.Marshal(r) if err != nil { logrus.WithError(err).Error("Failed to marshal TLS report") return } - req, err := http.NewRequest("POST", TLSReportURI, bytes.NewReader(b)) + req, err := http.NewRequest("POST", uri, bytes.NewReader(b)) if err != nil { logrus.WithError(err).Error("Failed to create http request") return @@ -125,10 +127,10 @@ func postCertIssueReport(report TLSReport, userAgent string) { req.Header.Add("Content-Type", "application/json") req.Header.Set("User-Agent", userAgent) req.Header.Set("x-pm-apiversion", strconv.Itoa(Version)) - req.Header.Set("x-pm-appversion", report.AppVersion) + req.Header.Set("x-pm-appversion", r.AppVersion) logrus.WithField("request", req).Warn("Reporting TLS mismatch") - res, err := (&http.Client{}).Do(req) + res, err := (&http.Client{Transport: CreateTransportWithDialer(NewBasicTLSDialer())}).Do(req) if err != nil { logrus.WithError(err).Error("Failed to report TLS mismatch") return