mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
fix(BRIDGE-406): fixed faulty certificate chain validation logic; made certificate pin checks exclusive to leaf certs;
This commit is contained in:
@ -22,6 +22,8 @@ import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -29,6 +31,11 @@ type TLSDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
||||
}
|
||||
|
||||
type SecureTLSDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
||||
ShouldSkipCertificateChainVerification(address string) bool
|
||||
}
|
||||
|
||||
func SetBasicTransportTimeouts(t *http.Transport) {
|
||||
t.MaxIdleConns = 100
|
||||
t.MaxIdleConnsPerHost = 100
|
||||
@ -71,6 +78,35 @@ func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
|
||||
}
|
||||
}
|
||||
|
||||
func extractDomain(hostname string) string {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
return strings.Join(parts[len(parts)-2:], ".")
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// ShouldSkipCertificateChainVerification determines whether certificate chain validation should be skipped.
|
||||
// It compares the domain of the requested address with the configured host URL domain.
|
||||
// Returns true if the domains don't match (skip verification), false if they do (perform verification).
|
||||
//
|
||||
// NOTE: This assumes single-part TLDs (.com, .me) and won't handle multi-part TLDs correctly.
|
||||
func (d *BasicTLSDialer) ShouldSkipCertificateChainVerification(address string) bool {
|
||||
parsedURL, err := url.Parse(d.hostURL)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
addressHost, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
addressHost = address
|
||||
}
|
||||
|
||||
hostDomain := extractDomain(parsedURL.Host)
|
||||
addressDomain := extractDomain(addressHost)
|
||||
return addressDomain != hostDomain
|
||||
}
|
||||
|
||||
// DialTLSContext returns a connection to the given address using the given network.
|
||||
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
return (&tls.Dialer{
|
||||
@ -78,7 +114,7 @@ func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address st
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
Config: &tls.Config{
|
||||
InsecureSkipVerify: address != d.hostURL, //nolint:gosec
|
||||
InsecureSkipVerify: d.ShouldSkipCertificateChainVerification(address), //nolint:gosec
|
||||
},
|
||||
}).DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
134
internal/dialer/dialer_basic_test.go
Normal file
134
internal/dialer/dialer_basic_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
// Copyright (c) 2025 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasicTLSDialer_ShouldSkipCertificateChainVerification(t *testing.T) {
|
||||
tests := []struct {
|
||||
hostURL string
|
||||
address string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
hostURL: "https://mail-api.proton.me",
|
||||
address: "mail-api.proton.me:443",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me",
|
||||
address: "proton.me",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://api.proton.me",
|
||||
address: "mail.proton.me:443",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me",
|
||||
address: "mail-api.proton.me:443",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://mail-api.proton.me",
|
||||
address: "proton.me:443",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://mail.google.com",
|
||||
address: "mail-api.proton.me:443",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://mail-api.protonmail.com",
|
||||
address: "mail-api.proton.me:443",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me",
|
||||
address: "google.com:443",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me",
|
||||
address: "proton.com:443",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me",
|
||||
address: "example.me:443",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me",
|
||||
address: "mail.example.com:443",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me",
|
||||
address: "proton.me",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me:8080",
|
||||
address: "proton.me:443",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.me/api/v1",
|
||||
address: "proton.me:443",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://proton.black",
|
||||
address: "mail-api.pascal.proton.black",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://mail-api.pascal.proton.black",
|
||||
address: "mail-api.pascal.proton.black",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://mail-api.pascal.proton.black",
|
||||
address: "proton.black:332",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
hostURL: "https://mail-api.pascal.proton.black",
|
||||
address: "proton.me",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
hostURL: "https://mail-api.pascal.proton.black",
|
||||
address: "proton.me:332",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
dialer := NewBasicTLSDialer(tt.hostURL)
|
||||
result := dialer.ShouldSkipCertificateChainVerification(tt.address)
|
||||
require.Equal(t, tt.expected, result)
|
||||
}
|
||||
}
|
||||
@ -50,12 +50,12 @@ var TrustedAPIPins = []string{ //nolint:gochecknoglobals
|
||||
}
|
||||
|
||||
// TLSReportURI is the address where TLS reports should be sent.
|
||||
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
|
||||
const TLSReportURI = "https://reports.proton.me/reports/tls"
|
||||
|
||||
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
|
||||
// to report errors if the fingerprint check fails.
|
||||
type PinningTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
dialer SecureTLSDialer
|
||||
pinChecker PinChecker
|
||||
reporter Reporter
|
||||
tlsIssueCh chan struct{}
|
||||
@ -68,13 +68,13 @@ type Reporter interface {
|
||||
|
||||
// PinChecker is used to check TLS keys of connections.
|
||||
type PinChecker interface {
|
||||
CheckCertificate(conn net.Conn) error
|
||||
CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error
|
||||
}
|
||||
|
||||
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
|
||||
// which present known certificates.
|
||||
// It checks pins using the given pinChecker and reports issues using the given reporter.
|
||||
func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
|
||||
func NewPinningTLSDialer(dialer SecureTLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
|
||||
return &PinningTLSDialer{
|
||||
dialer: dialer,
|
||||
pinChecker: pinChecker,
|
||||
@ -85,6 +85,7 @@ func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChec
|
||||
|
||||
// DialTLSContext dials the given network/address, returning an error if the certificates don't match the trusted pins.
|
||||
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
shouldSkipCertificateChainVerification := p.dialer.ShouldSkipCertificateChainVerification(address)
|
||||
conn, err := p.dialer.DialTLSContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -95,7 +96,7 @@ func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.pinChecker.CheckCertificate(conn); err != nil {
|
||||
if err := p.pinChecker.CheckCertificate(conn, shouldSkipCertificateChainVerification); err != nil {
|
||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
|
||||
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
|
||||
}
|
||||
|
||||
@ -41,3 +41,15 @@ func NewTLSPinChecker(trustedPins []string) *TLSPinChecker {
|
||||
func certFingerprint(cert *x509.Certificate) string {
|
||||
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
|
||||
}
|
||||
|
||||
func (p *TLSPinChecker) isCertFoundInKnownPins(cert *x509.Certificate) bool {
|
||||
fingerprint := certFingerprint(cert)
|
||||
|
||||
for _, pin := range p.trustedPins {
|
||||
if pin == fingerprint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@ -25,8 +25,8 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// CheckCertificate returns whether the connection presents a known TLS certificate.
|
||||
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
|
||||
// CheckCertificate verifies that the connection presents a known pinned leaf TLS certificate.
|
||||
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error {
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return errors.New("connection is not a TLS connection")
|
||||
@ -34,14 +34,31 @@ func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
|
||||
|
||||
connState := tlsConn.ConnectionState()
|
||||
|
||||
for _, peerCert := range connState.PeerCertificates {
|
||||
fingerprint := certFingerprint(peerCert)
|
||||
// When certificate chain verification is enabled (e.g., for known API hosts), we expect the TLS handshake to produce verified chains.
|
||||
// We then validate that the leaf certificate of at least one verified chain matches a known pinned public key.
|
||||
if !certificateChainVerificationSkipped {
|
||||
if len(connState.VerifiedChains) == 0 {
|
||||
return errors.New("no verified certificate chains")
|
||||
}
|
||||
|
||||
for _, pin := range p.trustedPins {
|
||||
if pin == fingerprint {
|
||||
for _, chain := range connState.VerifiedChains {
|
||||
// Check if the leaf certificate is one of the trusted pins.
|
||||
if p.isCertFoundInKnownPins(chain[0]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrTLSMismatch
|
||||
}
|
||||
|
||||
// When certificate chain verification is skipped (e.g., for DoH proxies using self-signed certs),
|
||||
// we only validate the leaf certificate against known pinned public keys.
|
||||
if len(connState.PeerCertificates) == 0 {
|
||||
return errors.New("no peer certificates available")
|
||||
}
|
||||
|
||||
if p.isCertFoundInKnownPins(connState.PeerCertificates[0]) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrTLSMismatch
|
||||
|
||||
@ -23,6 +23,6 @@ import "net"
|
||||
|
||||
// CheckCertificate returns whether the connection presents a known TLS certificate.
|
||||
// The QA implementation always returns nil.
|
||||
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
|
||||
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, _ bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -64,8 +64,7 @@ func TestTLSPinInvalid(t *testing.T) {
|
||||
checkTLSIssueHandler(t, 1, called)
|
||||
}
|
||||
|
||||
// Disabled for now we'll need to patch this up.
|
||||
func _TestTLSPinNoMatch(t *testing.T) { //nolint:unused
|
||||
func TestTLSPinNoMatch(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
|
||||
|
||||
Reference in New Issue
Block a user