mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
fix(BRIDGE-406): fixed faulty certificate chain validation logic; made certificate pin checks exclusive to leaf certs;
This commit is contained in:
@ -22,6 +22,8 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,6 +31,11 @@ type TLSDialer interface {
|
|||||||
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SecureTLSDialer interface {
|
||||||
|
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
||||||
|
ShouldSkipCertificateChainVerification(address string) bool
|
||||||
|
}
|
||||||
|
|
||||||
func SetBasicTransportTimeouts(t *http.Transport) {
|
func SetBasicTransportTimeouts(t *http.Transport) {
|
||||||
t.MaxIdleConns = 100
|
t.MaxIdleConns = 100
|
||||||
t.MaxIdleConnsPerHost = 100
|
t.MaxIdleConnsPerHost = 100
|
||||||
@ -71,6 +78,35 @@ func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractDomain(hostname string) string {
|
||||||
|
parts := strings.Split(hostname, ".")
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
return strings.Join(parts[len(parts)-2:], ".")
|
||||||
|
}
|
||||||
|
return hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldSkipCertificateChainVerification determines whether certificate chain validation should be skipped.
|
||||||
|
// It compares the domain of the requested address with the configured host URL domain.
|
||||||
|
// Returns true if the domains don't match (skip verification), false if they do (perform verification).
|
||||||
|
//
|
||||||
|
// NOTE: This assumes single-part TLDs (.com, .me) and won't handle multi-part TLDs correctly.
|
||||||
|
func (d *BasicTLSDialer) ShouldSkipCertificateChainVerification(address string) bool {
|
||||||
|
parsedURL, err := url.Parse(d.hostURL)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
addressHost, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
addressHost = address
|
||||||
|
}
|
||||||
|
|
||||||
|
hostDomain := extractDomain(parsedURL.Host)
|
||||||
|
addressDomain := extractDomain(addressHost)
|
||||||
|
return addressDomain != hostDomain
|
||||||
|
}
|
||||||
|
|
||||||
// DialTLSContext returns a connection to the given address using the given network.
|
// DialTLSContext returns a connection to the given address using the given network.
|
||||||
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||||
return (&tls.Dialer{
|
return (&tls.Dialer{
|
||||||
@ -78,7 +114,7 @@ func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address st
|
|||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
},
|
},
|
||||||
Config: &tls.Config{
|
Config: &tls.Config{
|
||||||
InsecureSkipVerify: address != d.hostURL, //nolint:gosec
|
InsecureSkipVerify: d.ShouldSkipCertificateChainVerification(address), //nolint:gosec
|
||||||
},
|
},
|
||||||
}).DialContext(ctx, network, address)
|
}).DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|||||||
134
internal/dialer/dialer_basic_test.go
Normal file
134
internal/dialer/dialer_basic_test.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
// Copyright (c) 2025 Proton AG
|
||||||
|
//
|
||||||
|
// This file is part of Proton Mail Bridge.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBasicTLSDialer_ShouldSkipCertificateChainVerification(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
hostURL string
|
||||||
|
address string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
hostURL: "https://mail-api.proton.me",
|
||||||
|
address: "mail-api.proton.me:443",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me",
|
||||||
|
address: "proton.me",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://api.proton.me",
|
||||||
|
address: "mail.proton.me:443",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me",
|
||||||
|
address: "mail-api.proton.me:443",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://mail-api.proton.me",
|
||||||
|
address: "proton.me:443",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://mail.google.com",
|
||||||
|
address: "mail-api.proton.me:443",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://mail-api.protonmail.com",
|
||||||
|
address: "mail-api.proton.me:443",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me",
|
||||||
|
address: "google.com:443",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me",
|
||||||
|
address: "proton.com:443",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me",
|
||||||
|
address: "example.me:443",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me",
|
||||||
|
address: "mail.example.com:443",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me",
|
||||||
|
address: "proton.me",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me:8080",
|
||||||
|
address: "proton.me:443",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.me/api/v1",
|
||||||
|
address: "proton.me:443",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://proton.black",
|
||||||
|
address: "mail-api.pascal.proton.black",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://mail-api.pascal.proton.black",
|
||||||
|
address: "mail-api.pascal.proton.black",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://mail-api.pascal.proton.black",
|
||||||
|
address: "proton.black:332",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://mail-api.pascal.proton.black",
|
||||||
|
address: "proton.me",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hostURL: "https://mail-api.pascal.proton.black",
|
||||||
|
address: "proton.me:332",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
dialer := NewBasicTLSDialer(tt.hostURL)
|
||||||
|
result := dialer.ShouldSkipCertificateChainVerification(tt.address)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -50,12 +50,12 @@ var TrustedAPIPins = []string{ //nolint:gochecknoglobals
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TLSReportURI is the address where TLS reports should be sent.
|
// TLSReportURI is the address where TLS reports should be sent.
|
||||||
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
|
const TLSReportURI = "https://reports.proton.me/reports/tls"
|
||||||
|
|
||||||
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
|
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
|
||||||
// to report errors if the fingerprint check fails.
|
// to report errors if the fingerprint check fails.
|
||||||
type PinningTLSDialer struct {
|
type PinningTLSDialer struct {
|
||||||
dialer TLSDialer
|
dialer SecureTLSDialer
|
||||||
pinChecker PinChecker
|
pinChecker PinChecker
|
||||||
reporter Reporter
|
reporter Reporter
|
||||||
tlsIssueCh chan struct{}
|
tlsIssueCh chan struct{}
|
||||||
@ -68,13 +68,13 @@ type Reporter interface {
|
|||||||
|
|
||||||
// PinChecker is used to check TLS keys of connections.
|
// PinChecker is used to check TLS keys of connections.
|
||||||
type PinChecker interface {
|
type PinChecker interface {
|
||||||
CheckCertificate(conn net.Conn) error
|
CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
|
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
|
||||||
// which present known certificates.
|
// which present known certificates.
|
||||||
// It checks pins using the given pinChecker and reports issues using the given reporter.
|
// It checks pins using the given pinChecker and reports issues using the given reporter.
|
||||||
func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
|
func NewPinningTLSDialer(dialer SecureTLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
|
||||||
return &PinningTLSDialer{
|
return &PinningTLSDialer{
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
pinChecker: pinChecker,
|
pinChecker: pinChecker,
|
||||||
@ -85,6 +85,7 @@ func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChec
|
|||||||
|
|
||||||
// DialTLSContext dials the given network/address, returning an error if the certificates don't match the trusted pins.
|
// DialTLSContext dials the given network/address, returning an error if the certificates don't match the trusted pins.
|
||||||
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
shouldSkipCertificateChainVerification := p.dialer.ShouldSkipCertificateChainVerification(address)
|
||||||
conn, err := p.dialer.DialTLSContext(ctx, network, address)
|
conn, err := p.dialer.DialTLSContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -95,7 +96,7 @@ func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.pinChecker.CheckCertificate(conn); err != nil {
|
if err := p.pinChecker.CheckCertificate(conn, shouldSkipCertificateChainVerification); err != nil {
|
||||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
|
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
|
||||||
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
|
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -41,3 +41,15 @@ func NewTLSPinChecker(trustedPins []string) *TLSPinChecker {
|
|||||||
func certFingerprint(cert *x509.Certificate) string {
|
func certFingerprint(cert *x509.Certificate) string {
|
||||||
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
|
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *TLSPinChecker) isCertFoundInKnownPins(cert *x509.Certificate) bool {
|
||||||
|
fingerprint := certFingerprint(cert)
|
||||||
|
|
||||||
|
for _, pin := range p.trustedPins {
|
||||||
|
if pin == fingerprint {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@ -25,8 +25,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CheckCertificate returns whether the connection presents a known TLS certificate.
|
// CheckCertificate verifies that the connection presents a known pinned leaf TLS certificate.
|
||||||
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
|
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error {
|
||||||
tlsConn, ok := conn.(*tls.Conn)
|
tlsConn, ok := conn.(*tls.Conn)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("connection is not a TLS connection")
|
return errors.New("connection is not a TLS connection")
|
||||||
@ -34,15 +34,32 @@ func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
|
|||||||
|
|
||||||
connState := tlsConn.ConnectionState()
|
connState := tlsConn.ConnectionState()
|
||||||
|
|
||||||
for _, peerCert := range connState.PeerCertificates {
|
// When certificate chain verification is enabled (e.g., for known API hosts), we expect the TLS handshake to produce verified chains.
|
||||||
fingerprint := certFingerprint(peerCert)
|
// We then validate that the leaf certificate of at least one verified chain matches a known pinned public key.
|
||||||
|
if !certificateChainVerificationSkipped {
|
||||||
for _, pin := range p.trustedPins {
|
if len(connState.VerifiedChains) == 0 {
|
||||||
if pin == fingerprint {
|
return errors.New("no verified certificate chains")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, chain := range connState.VerifiedChains {
|
||||||
|
// Check if the leaf certificate is one of the trusted pins.
|
||||||
|
if p.isCertFoundInKnownPins(chain[0]) {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ErrTLSMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// When certificate chain verification is skipped (e.g., for DoH proxies using self-signed certs),
|
||||||
|
// we only validate the leaf certificate against known pinned public keys.
|
||||||
|
if len(connState.PeerCertificates) == 0 {
|
||||||
|
return errors.New("no peer certificates available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.isCertFoundInKnownPins(connState.PeerCertificates[0]) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return ErrTLSMismatch
|
return ErrTLSMismatch
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,6 +23,6 @@ import "net"
|
|||||||
|
|
||||||
// CheckCertificate returns whether the connection presents a known TLS certificate.
|
// CheckCertificate returns whether the connection presents a known TLS certificate.
|
||||||
// The QA implementation always returns nil.
|
// The QA implementation always returns nil.
|
||||||
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
|
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, _ bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -64,8 +64,7 @@ func TestTLSPinInvalid(t *testing.T) {
|
|||||||
checkTLSIssueHandler(t, 1, called)
|
checkTLSIssueHandler(t, 1, called)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disabled for now we'll need to patch this up.
|
func TestTLSPinNoMatch(t *testing.T) {
|
||||||
func _TestTLSPinNoMatch(t *testing.T) { //nolint:unused
|
|
||||||
skipIfProxyIsSet(t)
|
skipIfProxyIsSet(t)
|
||||||
|
|
||||||
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
|
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
|
||||||
|
|||||||
Reference in New Issue
Block a user