mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-16 07:06:45 +00:00
feat: tls report cache
This commit is contained in:
@ -43,6 +43,7 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/)
|
|||||||
* Logging out of `Clients` is handled/retried asynchronously by `ClientManager`
|
* Logging out of `Clients` is handled/retried asynchronously by `ClientManager`
|
||||||
* GODT-265 Alternative Routing v2 (more resiliant to short term connection drops)
|
* GODT-265 Alternative Routing v2 (more resiliant to short term connection drops)
|
||||||
* GODT-310 Alternative parsing of `References` header (old parsing probably malformed message IDs)
|
* GODT-310 Alternative parsing of `References` header (old parsing probably malformed message IDs)
|
||||||
|
* GODT-320 Only report the same TLS issue once every 24 hours
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
* Use correct binary name when finding location of addcert.scpt
|
* Use correct binary name when finding location of addcert.scpt
|
||||||
|
|||||||
@ -20,7 +20,6 @@ package pmapi
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@ -31,7 +30,7 @@ type PinningTLSDialer struct {
|
|||||||
dialer TLSDialer
|
dialer TLSDialer
|
||||||
|
|
||||||
// pinChecker is used to check TLS keys of connections.
|
// pinChecker is used to check TLS keys of connections.
|
||||||
pinChecker PinChecker
|
pinChecker pinChecker
|
||||||
|
|
||||||
// tlsIssueNotifier is used to notify something when there is a TLS issue.
|
// tlsIssueNotifier is used to notify something when there is a TLS issue.
|
||||||
tlsIssueNotifier func()
|
tlsIssueNotifier func()
|
||||||
@ -55,7 +54,7 @@ type PinningTLSDialer struct {
|
|||||||
func NewPinningTLSDialer(dialer TLSDialer) *PinningTLSDialer {
|
func NewPinningTLSDialer(dialer TLSDialer) *PinningTLSDialer {
|
||||||
return &PinningTLSDialer{
|
return &PinningTLSDialer{
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
pinChecker: NewPinChecker(TrustedAPIPins),
|
pinChecker: newPinChecker(TrustedAPIPins),
|
||||||
log: logrus.WithField("pkg", "pmapi/tls-pinning"),
|
log: logrus.WithField("pkg", "pmapi/tls-pinning"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -81,16 +80,16 @@ func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.pinChecker.CheckCertificate(conn); err != nil {
|
if err = p.pinChecker.checkCertificate(conn); err != nil {
|
||||||
if p.tlsIssueNotifier != nil {
|
if p.tlsIssueNotifier != nil {
|
||||||
go p.tlsIssueNotifier()
|
go p.tlsIssueNotifier()
|
||||||
}
|
}
|
||||||
|
|
||||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting {
|
if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting {
|
||||||
p.pinChecker.ReportCertIssue(
|
p.pinChecker.reportCertIssue(
|
||||||
|
TLSReportURI,
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
time.Now().Format(time.RFC3339),
|
|
||||||
tlsConn.ConnectionState(),
|
tlsConn.ConnectionState(),
|
||||||
p.appVersion,
|
p.appVersion,
|
||||||
p.userAgent,
|
p.userAgent,
|
||||||
|
|||||||
@ -9,22 +9,30 @@ import (
|
|||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PinChecker struct {
|
type pinChecker struct {
|
||||||
trustedPins []string
|
trustedPins []string
|
||||||
|
sentReports []sentReport
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPinChecker(trustedPins []string) PinChecker {
|
type sentReport struct {
|
||||||
return PinChecker{
|
r tlsReport
|
||||||
|
t time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPinChecker(trustedPins []string) pinChecker {
|
||||||
|
return pinChecker{
|
||||||
trustedPins: trustedPins,
|
trustedPins: trustedPins,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckCertificate returns whether the connection presents a known TLS certificate.
|
// checkCertificate returns whether the connection presents a known TLS certificate.
|
||||||
func (p *PinChecker) CheckCertificate(conn net.Conn) error {
|
func (p *pinChecker) checkCertificate(conn net.Conn) error {
|
||||||
connState := conn.(*tls.Conn).ConnectionState()
|
connState := conn.(*tls.Conn).ConnectionState()
|
||||||
|
|
||||||
for _, peerCert := range connState.PeerCertificates {
|
for _, peerCert := range connState.PeerCertificates {
|
||||||
@ -45,8 +53,8 @@ func certFingerprint(cert *x509.Certificate) string {
|
|||||||
return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
|
return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReportCertIssue reports a TLS key mismatch.
|
// reportCertIssue reports a TLS key mismatch.
|
||||||
func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls.ConnectionState, appVersion, userAgent string) {
|
func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState, appVersion, userAgent string) {
|
||||||
var certChain []string
|
var certChain []string
|
||||||
|
|
||||||
if len(connState.VerifiedChains) > 0 {
|
if len(connState.VerifiedChains) > 0 {
|
||||||
@ -55,9 +63,38 @@ func (p *PinChecker) ReportCertIssue(host, port, datetime string, connState tls.
|
|||||||
certChain = marshalCert7468(connState.PeerCertificates)
|
certChain = marshalCert7468(connState.PeerCertificates)
|
||||||
}
|
}
|
||||||
|
|
||||||
report := NewTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion)
|
r := newTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion)
|
||||||
|
|
||||||
go postCertIssueReport(report, userAgent)
|
if !p.hasRecentlySentReport(r) {
|
||||||
|
p.recordReport(r)
|
||||||
|
go r.sendReport(remoteURI, userAgent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
|
||||||
|
func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
|
||||||
|
var validReports []sentReport
|
||||||
|
|
||||||
|
for _, r := range p.sentReports {
|
||||||
|
if time.Since(r.t) < 24*time.Hour {
|
||||||
|
validReports = append(validReports, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.sentReports = validReports
|
||||||
|
|
||||||
|
for _, r := range p.sentReports {
|
||||||
|
if cmp.Equal(report, r.r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordReport records the given report and the current time so we can check whether we recently sent this report.
|
||||||
|
func (p *pinChecker) recordReport(r tlsReport) {
|
||||||
|
p.sentReports = append(p.sentReports, sentReport{r: r, t: time.Now()})
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
|
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
|
||||||
@ -67,7 +104,7 @@ func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
|
|||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: cert.Raw,
|
Bytes: cert.Raw,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logrus.WithField("pkg", "pmapi/tls-pinning").Errorf("encoding TLS cert: %v", err)
|
logrus.WithField("pkg", "pmapi/tls-pinning").WithError(err).Error("Failed to encode TLS certificate")
|
||||||
}
|
}
|
||||||
pemCerts = append(pemCerts, buffer.String())
|
pemCerts = append(pemCerts, buffer.String())
|
||||||
buffer.Reset()
|
buffer.Reset()
|
||||||
|
|||||||
41
pkg/pmapi/pin_checker_test.go
Normal file
41
pkg/pmapi/pin_checker_test.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package pmapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPinCheckerDoubleReport(t *testing.T) {
|
||||||
|
reportCounter := 0
|
||||||
|
|
||||||
|
reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reportCounter++
|
||||||
|
}))
|
||||||
|
|
||||||
|
pc := newPinChecker(TrustedAPIPins)
|
||||||
|
|
||||||
|
// Report the same issue many times.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
pc.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{}, "3", "useragent")
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should only report once.
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
return reportCounter == 1
|
||||||
|
}, time.Second, time.Millisecond)
|
||||||
|
|
||||||
|
// If we then report something else many times.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
pc.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{}, "3", "useragent")
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should get a second report.
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
return reportCounter == 2
|
||||||
|
}, time.Second, time.Millisecond)
|
||||||
|
}
|
||||||
@ -142,7 +142,6 @@ func (p *proxyProvider) refreshProxyCache() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// canReach returns whether we can reach the given url.
|
// canReach returns whether we can reach the given url.
|
||||||
// NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname.
|
|
||||||
func (p *proxyProvider) canReach(url string) bool {
|
func (p *proxyProvider) canReach(url string) bool {
|
||||||
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
|
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
|
||||||
url = "https://" + url
|
url = "https://" + url
|
||||||
|
|||||||
@ -30,9 +30,9 @@ var TrustedAPIPins = []string{ // nolint[gochecknoglobals]
|
|||||||
// TLSReportURI is the address where TLS reports should be sent.
|
// TLSReportURI is the address where TLS reports should be sent.
|
||||||
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
|
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
|
||||||
|
|
||||||
// TLSReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
|
// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
|
||||||
// When a TLS key mismatch is detected, a TLSReport is posted to TLSReportURI.
|
// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI.
|
||||||
type TLSReport struct {
|
type tlsReport struct {
|
||||||
// DateTime of observed pin validation in time.RFC3339 format.
|
// DateTime of observed pin validation in time.RFC3339 format.
|
||||||
DateTime string `json:"date-time"`
|
DateTime string `json:"date-time"`
|
||||||
|
|
||||||
@ -88,35 +88,37 @@ type TLSReport struct {
|
|||||||
AppVersion string `json:"app-version"`
|
AppVersion string `json:"app-version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTLSReport constructs a new TLSreport configured with the given app version and known pinned public keys.
|
// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys.
|
||||||
func NewTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report TLSReport) {
|
// Temporal things (current date/time) are not set yet -- they are set when sendReport is called.
|
||||||
|
func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) {
|
||||||
// If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway.
|
// If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway.
|
||||||
intPort, _ := strconv.Atoi(port)
|
intPort, _ := strconv.Atoi(port)
|
||||||
|
|
||||||
report = TLSReport{
|
report = tlsReport{
|
||||||
Hostname: host,
|
Hostname: host,
|
||||||
Port: intPort,
|
Port: intPort,
|
||||||
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
|
NotedHostname: server,
|
||||||
IncludeSubdomains: false,
|
ServedCertificateChain: certChain,
|
||||||
NotedHostname: server,
|
KnownPins: knownPins,
|
||||||
ValidatedCertificateChain: []string{},
|
AppVersion: appVersion,
|
||||||
ServedCertificateChain: certChain,
|
|
||||||
KnownPins: knownPins,
|
|
||||||
AppVersion: appVersion,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// postCertIssueReport posts the given TLS report to the standard TLS Report URI.
|
// sendReport posts the given TLS report to the standard TLS Report URI.
|
||||||
func postCertIssueReport(report TLSReport, userAgent string) {
|
func (r tlsReport) sendReport(uri, userAgent string) {
|
||||||
b, err := json.Marshal(report)
|
now := time.Now()
|
||||||
|
r.DateTime = now.Format(time.RFC3339)
|
||||||
|
r.EffectiveExpirationDate = now.Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339)
|
||||||
|
|
||||||
|
b, err := json.Marshal(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("Failed to marshal TLS report")
|
logrus.WithError(err).Error("Failed to marshal TLS report")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", TLSReportURI, bytes.NewReader(b))
|
req, err := http.NewRequest("POST", uri, bytes.NewReader(b))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("Failed to create http request")
|
logrus.WithError(err).Error("Failed to create http request")
|
||||||
return
|
return
|
||||||
@ -125,10 +127,10 @@ func postCertIssueReport(report TLSReport, userAgent string) {
|
|||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", userAgent)
|
req.Header.Set("User-Agent", userAgent)
|
||||||
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
|
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
|
||||||
req.Header.Set("x-pm-appversion", report.AppVersion)
|
req.Header.Set("x-pm-appversion", r.AppVersion)
|
||||||
|
|
||||||
logrus.WithField("request", req).Warn("Reporting TLS mismatch")
|
logrus.WithField("request", req).Warn("Reporting TLS mismatch")
|
||||||
res, err := (&http.Client{}).Do(req)
|
res, err := (&http.Client{Transport: CreateTransportWithDialer(NewBasicTLSDialer())}).Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("Failed to report TLS mismatch")
|
logrus.WithError(err).Error("Failed to report TLS mismatch")
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user