Compare commits

...

3 Commits

9 changed files with 221 additions and 16 deletions

View File

@ -3,6 +3,12 @@
Changelog [format](http://keepachangelog.com/en/1.0.0/)
## Kanmon Bridge 3.21.2
### Fixed
* BRIDGE-406: Fixed faulty certificate chain validation logic. Made certificate pin checks exclusive to leaf certificates.
## Kanmon Bridge 3.21.1
### Changed

View File

@ -12,7 +12,7 @@ ROOT_DIR:=$(realpath .)
.PHONY: build build-gui build-nogui build-launcher versioner hasher
# Keep version hardcoded so app build works also without Git repository.
BRIDGE_APP_VERSION?=3.21.1+git
BRIDGE_APP_VERSION?=3.21.2+git
APP_VERSION:=${BRIDGE_APP_VERSION}
APP_FULL_NAME:=Proton Mail Bridge
APP_VENDOR:=Proton AG

View File

@ -22,6 +22,8 @@ import (
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"time"
)
@ -29,6 +31,11 @@ type TLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
}
type SecureTLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
ShouldSkipCertificateChainVerification(address string) bool
}
func SetBasicTransportTimeouts(t *http.Transport) {
t.MaxIdleConns = 100
t.MaxIdleConnsPerHost = 100
@ -71,6 +78,35 @@ func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
}
}
func extractDomain(hostname string) string {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
return strings.Join(parts[len(parts)-2:], ".")
}
return hostname
}
// ShouldSkipCertificateChainVerification determines whether certificate chain validation should be skipped.
// It compares the domain of the requested address with the configured host URL domain.
// Returns true if the domains don't match (skip verification), false if they do (perform verification).
//
// NOTE: This assumes single-part TLDs (.com, .me) and won't handle multi-part TLDs correctly.
func (d *BasicTLSDialer) ShouldSkipCertificateChainVerification(address string) bool {
parsedURL, err := url.Parse(d.hostURL)
if err != nil {
return true
}
addressHost, _, err := net.SplitHostPort(address)
if err != nil {
addressHost = address
}
hostDomain := extractDomain(parsedURL.Host)
addressDomain := extractDomain(addressHost)
return addressDomain != hostDomain
}
// DialTLSContext returns a connection to the given address using the given network.
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
return (&tls.Dialer{
@ -78,7 +114,7 @@ func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address st
Timeout: 30 * time.Second,
},
Config: &tls.Config{
InsecureSkipVerify: address != d.hostURL, //nolint:gosec
InsecureSkipVerify: d.ShouldSkipCertificateChainVerification(address), //nolint:gosec
},
}).DialContext(ctx, network, address)
}

View File

@ -0,0 +1,134 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestBasicTLSDialer_ShouldSkipCertificateChainVerification(t *testing.T) {
tests := []struct {
hostURL string
address string
expected bool
}{
{
hostURL: "https://mail-api.proton.me",
address: "mail-api.proton.me:443",
expected: false,
},
{
hostURL: "https://proton.me",
address: "proton.me",
expected: false,
},
{
hostURL: "https://api.proton.me",
address: "mail.proton.me:443",
expected: false,
},
{
hostURL: "https://proton.me",
address: "mail-api.proton.me:443",
expected: false,
},
{
hostURL: "https://mail-api.proton.me",
address: "proton.me:443",
expected: false,
},
{
hostURL: "https://mail.google.com",
address: "mail-api.proton.me:443",
expected: true,
},
{
hostURL: "https://mail-api.protonmail.com",
address: "mail-api.proton.me:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "google.com:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "proton.com:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "example.me:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "mail.example.com:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "proton.me",
expected: false,
},
{
hostURL: "https://proton.me:8080",
address: "proton.me:443",
expected: true,
},
{
hostURL: "https://proton.me/api/v1",
address: "proton.me:443",
expected: false,
},
{
hostURL: "https://proton.black",
address: "mail-api.pascal.proton.black",
expected: false,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "mail-api.pascal.proton.black",
expected: false,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "proton.black:332",
expected: false,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "proton.me",
expected: true,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "proton.me:332",
expected: true,
},
}
for _, tt := range tests {
dialer := NewBasicTLSDialer(tt.hostURL)
result := dialer.ShouldSkipCertificateChainVerification(tt.address)
require.Equal(t, tt.expected, result)
}
}

View File

@ -50,12 +50,12 @@ var TrustedAPIPins = []string{ //nolint:gochecknoglobals
}
// TLSReportURI is the address where TLS reports should be sent.
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
const TLSReportURI = "https://reports.proton.me/reports/tls"
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
// to report errors if the fingerprint check fails.
type PinningTLSDialer struct {
dialer TLSDialer
dialer SecureTLSDialer
pinChecker PinChecker
reporter Reporter
tlsIssueCh chan struct{}
@ -68,13 +68,13 @@ type Reporter interface {
// PinChecker is used to check TLS keys of connections.
type PinChecker interface {
CheckCertificate(conn net.Conn) error
CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error
}
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
// which present known certificates.
// It checks pins using the given pinChecker and reports issues using the given reporter.
func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
func NewPinningTLSDialer(dialer SecureTLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
return &PinningTLSDialer{
dialer: dialer,
pinChecker: pinChecker,
@ -85,6 +85,7 @@ func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChec
// DialTLSContext dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
shouldSkipCertificateChainVerification := p.dialer.ShouldSkipCertificateChainVerification(address)
conn, err := p.dialer.DialTLSContext(ctx, network, address)
if err != nil {
return nil, err
@ -95,7 +96,7 @@ func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address
return nil, err
}
if err := p.pinChecker.CheckCertificate(conn); err != nil {
if err := p.pinChecker.CheckCertificate(conn, shouldSkipCertificateChainVerification); err != nil {
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
}

View File

@ -41,3 +41,15 @@ func NewTLSPinChecker(trustedPins []string) *TLSPinChecker {
func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
}
func (p *TLSPinChecker) isCertFoundInKnownPins(cert *x509.Certificate) bool {
fingerprint := certFingerprint(cert)
for _, pin := range p.trustedPins {
if pin == fingerprint {
return true
}
}
return false
}

View File

@ -25,8 +25,8 @@ import (
"net"
)
// CheckCertificate returns whether the connection presents a known TLS certificate.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
// CheckCertificate verifies that the connection presents a known pinned leaf TLS certificate.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return errors.New("connection is not a TLS connection")
@ -34,14 +34,31 @@ func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
connState := tlsConn.ConnectionState()
for _, peerCert := range connState.PeerCertificates {
fingerprint := certFingerprint(peerCert)
// When certificate chain verification is enabled (e.g., for known API hosts), we expect the TLS handshake to produce verified chains.
// We then validate that the leaf certificate of at least one verified chain matches a known pinned public key.
if !certificateChainVerificationSkipped {
if len(connState.VerifiedChains) == 0 {
return errors.New("no verified certificate chains")
}
for _, pin := range p.trustedPins {
if pin == fingerprint {
for _, chain := range connState.VerifiedChains {
// Check if the leaf certificate is one of the trusted pins.
if p.isCertFoundInKnownPins(chain[0]) {
return nil
}
}
return ErrTLSMismatch
}
// When certificate chain verification is skipped (e.g., for DoH proxies using self-signed certs),
// we only validate the leaf certificate against known pinned public keys.
if len(connState.PeerCertificates) == 0 {
return errors.New("no peer certificates available")
}
if p.isCertFoundInKnownPins(connState.PeerCertificates[0]) {
return nil
}
return ErrTLSMismatch

View File

@ -23,6 +23,6 @@ import "net"
// CheckCertificate returns whether the connection presents a known TLS certificate.
// The QA implementation always returns nil.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, _ bool) error {
return nil
}

View File

@ -64,8 +64,7 @@ func TestTLSPinInvalid(t *testing.T) {
checkTLSIssueHandler(t, 1, called)
}
// Disabled for now we'll need to patch this up.
func _TestTLSPinNoMatch(t *testing.T) { //nolint:unused
func TestTLSPinNoMatch(t *testing.T) {
skipIfProxyIsSet(t)
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())