mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 20:56:51 +00:00
fix(BRIDGE-406): fixed faulty certificate chain validation logic; made certificate pin checks exclusive to leaf certs;
This commit is contained in:
@ -22,6 +22,8 @@ import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -29,6 +31,11 @@ type TLSDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
||||
}
|
||||
|
||||
type SecureTLSDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
||||
ShouldSkipCertificateChainVerification(address string) bool
|
||||
}
|
||||
|
||||
func SetBasicTransportTimeouts(t *http.Transport) {
|
||||
t.MaxIdleConns = 100
|
||||
t.MaxIdleConnsPerHost = 100
|
||||
@ -71,6 +78,35 @@ func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
|
||||
}
|
||||
}
|
||||
|
||||
func extractDomain(hostname string) string {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
return strings.Join(parts[len(parts)-2:], ".")
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// ShouldSkipCertificateChainVerification determines whether certificate chain validation should be skipped.
|
||||
// It compares the domain of the requested address with the configured host URL domain.
|
||||
// Returns true if the domains don't match (skip verification), false if they do (perform verification).
|
||||
//
|
||||
// NOTE: This assumes single-part TLDs (.com, .me) and won't handle multi-part TLDs correctly.
|
||||
func (d *BasicTLSDialer) ShouldSkipCertificateChainVerification(address string) bool {
|
||||
parsedURL, err := url.Parse(d.hostURL)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
addressHost, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
addressHost = address
|
||||
}
|
||||
|
||||
hostDomain := extractDomain(parsedURL.Host)
|
||||
addressDomain := extractDomain(addressHost)
|
||||
return addressDomain != hostDomain
|
||||
}
|
||||
|
||||
// DialTLSContext returns a connection to the given address using the given network.
|
||||
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
return (&tls.Dialer{
|
||||
@ -78,7 +114,7 @@ func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address st
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
Config: &tls.Config{
|
||||
InsecureSkipVerify: address != d.hostURL, //nolint:gosec
|
||||
InsecureSkipVerify: d.ShouldSkipCertificateChainVerification(address), //nolint:gosec
|
||||
},
|
||||
}).DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user