diff --git a/internal/dialer/dialer_basic.go b/internal/dialer/dialer_basic.go index a0dfca57..cf60419c 100644 --- a/internal/dialer/dialer_basic.go +++ b/internal/dialer/dialer_basic.go @@ -22,6 +22,8 @@ import ( "crypto/tls" "net" "net/http" + "net/url" + "strings" "time" ) @@ -29,6 +31,11 @@ type TLSDialer interface { DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) } +type SecureTLSDialer interface { + DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) + ShouldSkipCertificateChainVerification(address string) bool +} + func SetBasicTransportTimeouts(t *http.Transport) { t.MaxIdleConns = 100 t.MaxIdleConnsPerHost = 100 @@ -71,6 +78,35 @@ func NewBasicTLSDialer(hostURL string) *BasicTLSDialer { } } +func extractDomain(hostname string) string { + parts := strings.Split(hostname, ".") + if len(parts) >= 2 { + return strings.Join(parts[len(parts)-2:], ".") + } + return hostname +} + +// ShouldSkipCertificateChainVerification determines whether certificate chain validation should be skipped. +// It compares the domain of the requested address with the configured host URL domain. +// Returns true if the domains don't match (skip verification), false if they do (perform verification). +// +// NOTE: This assumes single-part TLDs (.com, .me) and won't handle multi-part TLDs correctly. +func (d *BasicTLSDialer) ShouldSkipCertificateChainVerification(address string) bool { + parsedURL, err := url.Parse(d.hostURL) + if err != nil { + return true + } + + addressHost, _, err := net.SplitHostPort(address) + if err != nil { + addressHost = address + } + + hostDomain := extractDomain(parsedURL.Host) + addressDomain := extractDomain(addressHost) + return addressDomain != hostDomain +} + // DialTLSContext returns a connection to the given address using the given network. func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) { return (&tls.Dialer{ @@ -78,7 +114,7 @@ func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address st Timeout: 30 * time.Second, }, Config: &tls.Config{ - InsecureSkipVerify: address != d.hostURL, //nolint:gosec + InsecureSkipVerify: d.ShouldSkipCertificateChainVerification(address), //nolint:gosec }, }).DialContext(ctx, network, address) } diff --git a/internal/dialer/dialer_basic_test.go b/internal/dialer/dialer_basic_test.go new file mode 100644 index 00000000..3d81cf0f --- /dev/null +++ b/internal/dialer/dialer_basic_test.go @@ -0,0 +1,134 @@ +// Copyright (c) 2025 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package dialer + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBasicTLSDialer_ShouldSkipCertificateChainVerification(t *testing.T) { + tests := []struct { + hostURL string + address string + expected bool + }{ + { + hostURL: "https://mail-api.proton.me", + address: "mail-api.proton.me:443", + expected: false, + }, + { + hostURL: "https://proton.me", + address: "proton.me", + expected: false, + }, + { + hostURL: "https://api.proton.me", + address: "mail.proton.me:443", + expected: false, + }, + { + hostURL: "https://proton.me", + address: "mail-api.proton.me:443", + expected: false, + }, + { + hostURL: "https://mail-api.proton.me", + address: "proton.me:443", + expected: false, + }, + { + hostURL: "https://mail.google.com", + address: "mail-api.proton.me:443", + expected: true, + }, + { + hostURL: "https://mail-api.protonmail.com", + address: "mail-api.proton.me:443", + expected: true, + }, + { + hostURL: "https://proton.me", + address: "google.com:443", + expected: true, + }, + { + hostURL: "https://proton.me", + address: "proton.com:443", + expected: true, + }, + { + hostURL: "https://proton.me", + address: "example.me:443", + expected: true, + }, + { + hostURL: "https://proton.me", + address: "mail.example.com:443", + expected: true, + }, + { + hostURL: "https://proton.me", + address: "proton.me", + expected: false, + }, + { + hostURL: "https://proton.me:8080", + address: "proton.me:443", + expected: true, + }, + { + hostURL: "https://proton.me/api/v1", + address: "proton.me:443", + expected: false, + }, + { + hostURL: "https://proton.black", + address: "mail-api.pascal.proton.black", + expected: false, + }, + { + hostURL: "https://mail-api.pascal.proton.black", + address: "mail-api.pascal.proton.black", + expected: false, + }, + { + hostURL: "https://mail-api.pascal.proton.black", + address: "proton.black:332", + expected: false, + }, + { + hostURL: "https://mail-api.pascal.proton.black", + address: "proton.me", + expected: true, + }, + { + hostURL: "https://mail-api.pascal.proton.black", + address: "proton.me:332", + expected: true, + }, + } + + for _, tt := range tests { + dialer := NewBasicTLSDialer(tt.hostURL) + result := dialer.ShouldSkipCertificateChainVerification(tt.address) + require.Equal(t, tt.expected, result) + } +} diff --git a/internal/dialer/dialer_pinning.go b/internal/dialer/dialer_pinning.go index a547ac60..fde85284 100644 --- a/internal/dialer/dialer_pinning.go +++ b/internal/dialer/dialer_pinning.go @@ -50,12 +50,12 @@ var TrustedAPIPins = []string{ //nolint:gochecknoglobals } // TLSReportURI is the address where TLS reports should be sent. -const TLSReportURI = "https://reports.protonmail.ch/reports/tls" +const TLSReportURI = "https://reports.proton.me/reports/tls" // PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and // to report errors if the fingerprint check fails. type PinningTLSDialer struct { - dialer TLSDialer + dialer SecureTLSDialer pinChecker PinChecker reporter Reporter tlsIssueCh chan struct{} @@ -68,13 +68,13 @@ type Reporter interface { // PinChecker is used to check TLS keys of connections. type PinChecker interface { - CheckCertificate(conn net.Conn) error + CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error } // NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers // which present known certificates. // It checks pins using the given pinChecker and reports issues using the given reporter. -func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer { +func NewPinningTLSDialer(dialer SecureTLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer { return &PinningTLSDialer{ dialer: dialer, pinChecker: pinChecker, @@ -85,6 +85,7 @@ func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChec // DialTLSContext dials the given network/address, returning an error if the certificates don't match the trusted pins. func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { + shouldSkipCertificateChainVerification := p.dialer.ShouldSkipCertificateChainVerification(address) conn, err := p.dialer.DialTLSContext(ctx, network, address) if err != nil { return nil, err @@ -95,7 +96,7 @@ func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address return nil, err } - if err := p.pinChecker.CheckCertificate(conn); err != nil { + if err := p.pinChecker.CheckCertificate(conn, shouldSkipCertificateChainVerification); err != nil { if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil { p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState()) } diff --git a/internal/dialer/dialer_pinning_checker.go b/internal/dialer/dialer_pinning_checker.go index 58cebe5f..38a21903 100644 --- a/internal/dialer/dialer_pinning_checker.go +++ b/internal/dialer/dialer_pinning_checker.go @@ -41,3 +41,15 @@ func NewTLSPinChecker(trustedPins []string) *TLSPinChecker { func certFingerprint(cert *x509.Certificate) string { return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo))) } + +func (p *TLSPinChecker) isCertFoundInKnownPins(cert *x509.Certificate) bool { + fingerprint := certFingerprint(cert) + + for _, pin := range p.trustedPins { + if pin == fingerprint { + return true + } + } + + return false +} diff --git a/internal/dialer/dialer_pinning_checker_default.go b/internal/dialer/dialer_pinning_checker_default.go index 64ff1b62..e8428490 100644 --- a/internal/dialer/dialer_pinning_checker_default.go +++ b/internal/dialer/dialer_pinning_checker_default.go @@ -25,8 +25,8 @@ import ( "net" ) -// CheckCertificate returns whether the connection presents a known TLS certificate. -func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error { +// CheckCertificate verifies that the connection presents a known pinned leaf TLS certificate. +func (p *TLSPinChecker) CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error { tlsConn, ok := conn.(*tls.Conn) if !ok { return errors.New("connection is not a TLS connection") @@ -34,14 +34,31 @@ func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error { connState := tlsConn.ConnectionState() - for _, peerCert := range connState.PeerCertificates { - fingerprint := certFingerprint(peerCert) + // When certificate chain verification is enabled (e.g., for known API hosts), we expect the TLS handshake to produce verified chains. + // We then validate that the leaf certificate of at least one verified chain matches a known pinned public key. + if !certificateChainVerificationSkipped { + if len(connState.VerifiedChains) == 0 { + return errors.New("no verified certificate chains") + } - for _, pin := range p.trustedPins { - if pin == fingerprint { + for _, chain := range connState.VerifiedChains { + // Check if the leaf certificate is one of the trusted pins. + if p.isCertFoundInKnownPins(chain[0]) { return nil } } + + return ErrTLSMismatch + } + + // When certificate chain verification is skipped (e.g., for DoH proxies using self-signed certs), + // we only validate the leaf certificate against known pinned public keys. + if len(connState.PeerCertificates) == 0 { + return errors.New("no peer certificates available") + } + + if p.isCertFoundInKnownPins(connState.PeerCertificates[0]) { + return nil } return ErrTLSMismatch diff --git a/internal/dialer/dialer_pinning_checker_qa.go b/internal/dialer/dialer_pinning_checker_qa.go index 2b8e1bcd..87047dcf 100644 --- a/internal/dialer/dialer_pinning_checker_qa.go +++ b/internal/dialer/dialer_pinning_checker_qa.go @@ -23,6 +23,6 @@ import "net" // CheckCertificate returns whether the connection presents a known TLS certificate. // The QA implementation always returns nil. -func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error { +func (p *TLSPinChecker) CheckCertificate(conn net.Conn, _ bool) error { return nil } diff --git a/internal/dialer/dialer_pinning_test.go b/internal/dialer/dialer_pinning_test.go index e74b81ad..b8f8f9ab 100644 --- a/internal/dialer/dialer_pinning_test.go +++ b/internal/dialer/dialer_pinning_test.go @@ -64,8 +64,7 @@ func TestTLSPinInvalid(t *testing.T) { checkTLSIssueHandler(t, 1, called) } -// Disabled for now we'll need to patch this up. -func _TestTLSPinNoMatch(t *testing.T) { //nolint:unused +func TestTLSPinNoMatch(t *testing.T) { skipIfProxyIsSet(t) called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())