GODT-1779: Remove go-imap

This commit is contained in:
James Houlahan
2022-08-26 17:00:21 +02:00
parent 3b0bc1ca15
commit 39433fe707
593 changed files with 12725 additions and 91626 deletions

View File

@ -0,0 +1,77 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
type TLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
}
// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
return &http.Transport{
DialTLSContext: dialer.DialTLSContext,
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 5 * time.Minute,
ExpectContinueTimeout: 500 * time.Millisecond,
// GODT-126: this was initially 10s but logs from users showed a significant number
// were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
// Bumping to 30s for now to avoid this problem.
ResponseHeaderTimeout: 30 * time.Second,
// If we allow up to 30 seconds for response headers, it is reasonable to allow up
// to 30 seconds for the TLS handshake to take place.
TLSHandshakeTimeout: 30 * time.Second,
}
}
// BasicTLSDialer implements TLSDialer.
type BasicTLSDialer struct {
hostURL string
}
// NewBasicTLSDialer returns a new BasicTLSDialer.
func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
return &BasicTLSDialer{
hostURL: hostURL,
}
}
// DialTLS returns a connection to the given address using the given network.
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
return (&tls.Dialer{
NetDialer: &net.Dialer{
Timeout: 30 * time.Second,
},
Config: &tls.Config{
InsecureSkipVerify: address != d.hostURL,
},
}).DialContext(ctx, network, address)
}

View File

@ -0,0 +1,114 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"crypto/tls"
"net"
)
// TrustedAPIPins contains trusted public keys of the protonmail API and proxies.
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;).
var TrustedAPIPins = []string{ //nolint:gochecknoglobals
// api.protonmail.ch
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot backup
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold backup
// protonmail.com
// \todo remove when sure no one is using it.
`pin-sha256="8joiNBdqaYiQpKskgtkJsqRxF7zN0C0aqfi8DacknnI="`, // current
`pin-sha256="JMI8yrbc6jB1FYGyyWRLFTmDNgIszrNEMGlgy972e7w="`, // hot backup
`pin-sha256="Iu44zU84EOCZ9vx/vz67/MRVrxF1IO4i4NIa8ETwiIY="`, // cold backup
// proton.me
`pin-sha256="CT56BhOTmj5ZIPgb/xD5mH8rY3BLo/MlhP7oPyJUEDo="`, // current
`pin-sha256="35Dx28/uzN3LeltkCBQ8RHK0tlNSa2kCpCRGNp34Gxc="`, // hot backup
`pin-sha256="qYIukVc63DEITct8sFT7ebIq5qsWmuscaIKeJx+5J5A="`, // col backup
// proxies
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // main
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // backup 1
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // backup 2
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // backup 3
}
// TLSReportURI is the address where TLS reports should be sent.
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
// to report errors if the fingerprint check fails.
type PinningTLSDialer struct {
dialer TLSDialer
pinChecker PinChecker
reporter Reporter
tlsIssueCh chan struct{}
}
// Reporter is used to report TLS issues.
type Reporter interface {
ReportCertIssue(reportURI, host, port string, state tls.ConnectionState)
}
// PinChecker is used to check TLS keys of connections.
type PinChecker interface {
CheckCertificate(conn net.Conn) error
}
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
// which present known certificates.
// It checks pins using the given pinChecker and reports issues using the given reporter.
func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
return &PinningTLSDialer{
dialer: dialer,
pinChecker: pinChecker,
reporter: reporter,
tlsIssueCh: make(chan struct{}, 1),
}
}
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := p.dialer.DialTLSContext(ctx, network, address)
if err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
if err := p.pinChecker.CheckCertificate(conn); err != nil {
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
}
p.tlsIssueCh <- struct{}{}
return nil, err
}
return conn, nil
}
// GetTLSIssueCh returns a channel which notifies when a TLS issue is reported.
func (p *PinningTLSDialer) GetTLSIssueCh() <-chan struct{} {
return p.tlsIssueCh
}

View File

@ -0,0 +1,67 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"github.com/ProtonMail/proton-bridge/v2/pkg/algo"
)
// ErrTLSMismatch indicates that no TLS fingerprint match could be found.
var ErrTLSMismatch = errors.New("no TLS fingerprint match found")
type TLSPinChecker struct {
trustedPins []string
}
func NewTLSPinChecker(trustedPins []string) *TLSPinChecker {
return &TLSPinChecker{
trustedPins: trustedPins,
}
}
// checkCertificate returns whether the connection presents a known TLS certificate.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return errors.New("connection is not a TLS connection")
}
connState := tlsConn.ConnectionState()
for _, peerCert := range connState.PeerCertificates {
fingerprint := certFingerprint(peerCert)
for _, pin := range p.trustedPins {
if pin == fingerprint {
return nil
}
}
}
return ErrTLSMismatch
}
func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
}

View File

@ -0,0 +1,118 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"fmt"
"time"
"github.com/go-resty/resty/v2"
)
// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI.
type tlsReport struct {
// DateTime of observed pin validation in time.RFC3339 format.
DateTime string `json:"date-time"`
// Hostname to which the UA made original request that failed pin validation.
Hostname string `json:"hostname"`
// Port to which the UA made original request that failed pin validation.
Port string `json:"port"`
// EffectiveExpirationDate for noted pins in time.RFC3339 format.
EffectiveExpirationDate string `json:"effective-expiration-date"`
// IncludeSubdomains indicates whether or not the UA has noted the
// includeSubDomains directive for the Known Pinned Host.
IncludeSubdomains bool `json:"include-subdomains"`
// NotedHostname indicates the hostname that the UA noted when it noted
// the Known Pinned Host. This field allows operators to understand why
// Pin Validation was performed for, e.g., foo.example.com when the
// noted Known Pinned Host was example.com with includeSubDomains set.
NotedHostname string `json:"noted-hostname"`
// ServedCertificateChain is the certificate chain, as served by
// the Known Pinned Host during TLS session setup. It is provided as an
// array of strings; each string pem1, ... pemN is the Privacy-Enhanced
// Mail (PEM) representation of each X.509 certificate as described in
// [RFC7468].
ServedCertificateChain []string `json:"served-certificate-chain"`
// ValidatedCertificateChain is the certificate chain, as
// constructed by the UA during certificate chain verification. (This
// may differ from the served-certificate-chain.) It is provided as an
// array of strings; each string pem1, ... pemN is the PEM
// representation of each X.509 certificate as described in [RFC7468].
// UAs that build certificate chains in more than one way during the
// validation process SHOULD send the last chain built. In this way,
// they can avoid keeping too much state during the validation process.
ValidatedCertificateChain []string `json:"validated-certificate-chain"`
// The known-pins are the Pins that the UA has noted for the Known
// Pinned Host. They are provided as an array of strings with the
// syntax: known-pin = token "=" quoted-string
// e.g.:
// ```
// "known-pins": [
// 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="',
// "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""
// ]
// ```
KnownPins []string `json:"known-pins"`
// AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit.
AppVersion string `json:"app-version"`
}
// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys.
// Temporal things (current date/time) are not set yet -- they are set when sendReport is called.
func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) {
report = tlsReport{
Hostname: host,
Port: port,
NotedHostname: server,
ServedCertificateChain: certChain,
KnownPins: knownPins,
AppVersion: appVersion,
}
return
}
// sendReport posts the given TLS report to the standard TLS Report URI.
func sendReport(report tlsReport, userAgent, appVersion, hostURL, remoteURI string) error {
now := time.Now()
report.DateTime = now.Format(time.RFC3339)
report.EffectiveExpirationDate = now.Add(365 * 24 * time.Hour).Format(time.RFC3339)
if _, err := resty.New().
SetTransport(CreateTransportWithDialer(NewBasicTLSDialer(hostURL))).
SetHeader("User-Agent", userAgent).
SetHeader("x-pm-appversion", appVersion).
NewRequest().
SetBody(report).
Post(remoteURI); err != nil {
return fmt.Errorf("failed to send TLS report: %w", err)
}
return nil
}

View File

@ -0,0 +1,115 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/google/go-cmp/cmp"
"github.com/sirupsen/logrus"
)
type sentReport struct {
r tlsReport
t time.Time
}
type TLSReporter struct {
hostURL string
appVersion string
userAgent *useragent.UserAgent
trustedPins []string
sentReports []sentReport
}
func NewTLSReporter(hostURL, appVersion string, userAgent *useragent.UserAgent, trustedPins []string) *TLSReporter {
return &TLSReporter{
hostURL: hostURL,
appVersion: appVersion,
userAgent: userAgent,
trustedPins: trustedPins,
}
}
// reportCertIssue reports a TLS key mismatch.
func (r *TLSReporter) ReportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
var certChain []string
if len(connState.VerifiedChains) > 0 {
certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1])
} else {
certChain = marshalCert7468(connState.PeerCertificates)
}
report := newTLSReport(host, port, connState.ServerName, certChain, r.trustedPins, r.appVersion)
if !r.hasRecentlySentReport(report) {
r.recordReport(report)
if err := sendReport(report, r.userAgent.GetUserAgent(), r.appVersion, r.hostURL, remoteURI); err != nil {
logrus.WithError(err).Error("Failed to send TLS pinning report")
}
}
}
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
func (r *TLSReporter) hasRecentlySentReport(report tlsReport) bool {
var validReports []sentReport
for _, r := range r.sentReports {
if time.Since(r.t) < 24*time.Hour {
validReports = append(validReports, r)
}
}
r.sentReports = validReports
for _, r := range r.sentReports {
if cmp.Equal(report, r.r) {
return true
}
}
return false
}
// recordReport records the given report and the current time so we can check whether we recently sent this report.
func (r *TLSReporter) recordReport(report tlsReport) {
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
}
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
var buffer bytes.Buffer
for _, cert := range certs {
if err := pem.Encode(&buffer, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}); err != nil {
logrus.WithError(err).Error("Failed to encode TLS certificate")
}
pemCerts = append(pemCerts, buffer.String())
buffer.Reset()
}
return pemCerts
}

View File

@ -0,0 +1,59 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/stretchr/testify/assert"
)
func TestTLSReporter_DoubleReport(t *testing.T) {
reportCounter := 0
reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reportCounter++
}))
r := NewTLSReporter("hostURL", "appVersion", useragent.New(), TrustedAPIPins)
// Report the same issue many times.
for i := 0; i < 10; i++ {
r.ReportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
}
// We should only report once.
assert.Eventually(t, func() bool {
return reportCounter == 1
}, time.Second, time.Millisecond)
// If we then report something else many times.
for i := 0; i < 10; i++ {
r.ReportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
}
// We should get a second report.
assert.Eventually(t, func() bool {
return reportCounter == 2
}, time.Second, time.Millisecond)
}

View File

@ -0,0 +1,157 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
func getRootURL() string {
return "https://api.protonmail.ch"
}
func TestTLSPinValid(t *testing.T) {
called, _, _, _, cm := createClientWithPinningDialer(getRootURL())
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
checkTLSIssueHandler(t, 0, called)
}
func TestTLSPinBackup(t *testing.T) {
called, _, _, checker, cm := createClientWithPinningDialer(getRootURL())
copyTrustedPins(checker)
checker.trustedPins[1] = checker.trustedPins[0]
checker.trustedPins[0] = ""
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
checkTLSIssueHandler(t, 0, called)
}
func TestTLSPinInvalid(t *testing.T) {
s := server.NewTLS()
defer s.Close()
called, _, _, _, cm := createClientWithPinningDialer(s.GetHostURL())
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
checkTLSIssueHandler(t, 1, called)
}
func TestTLSPinNoMatch(t *testing.T) {
skipIfProxyIsSet(t)
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
copyTrustedPins(checker)
for i := 0; i < len(checker.trustedPins); i++ {
checker.trustedPins[i] = "testing"
}
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
// Check that it will be reported only once per session, but notified every time.
r.Equal(t, 1, len(reporter.sentReports))
checkTLSIssueHandler(t, 2, called)
}
func TestTLSSignedCertWrongPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _, _, _ := createClientWithPinningDialer("")
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
r.Error(t, err, "expected dial to fail because of wrong public key")
}
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _, checker, _ := createClientWithPinningDialer("")
copyTrustedPins(checker)
checker.trustedPins = append(checker.trustedPins, `pin-sha256="LwnIKjNLV3z243ap8y0yXNPghsqE76J08Eq3COvUt2E="`)
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
}
func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _, checker, _ := createClientWithPinningDialer("")
copyTrustedPins(checker)
checker.trustedPins = append(checker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
_, err := dialer.DialTLSContext(context.Background(), "tcp", "self-signed.badssl.com:443")
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
}
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
called := 0
reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(hostURL), reporter, checker)
go func() {
for range dialer.GetTLSIssueCh() {
called++
}
}()
return &called, dialer, reporter, checker, liteapi.New(
liteapi.WithHostURL(hostURL),
liteapi.WithTransport(CreateTransportWithDialer(dialer)),
)
}
func copyTrustedPins(pinChecker *TLSPinChecker) {
copiedPins := make([]string, len(pinChecker.trustedPins))
copy(copiedPins, pinChecker.trustedPins)
pinChecker.trustedPins = copiedPins
}
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
a.Eventually(
t,
func() bool {
if wantCalledAtLeast == 0 {
return *called == 0
}
// Dialer can do more attempts resulting in more calls.
return *called >= wantCalledAtLeast
},
time.Second,
10*time.Millisecond,
)
// Repeated again so it generates nice message.
if wantCalledAtLeast == 0 {
r.Equal(t, 0, *called)
} else {
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
}
}

View File

@ -0,0 +1,152 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"net"
"net/url"
"sync"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
var ErrNoConnection = errors.New("no connection")
// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
type ProxyTLSDialer struct {
dialer TLSDialer
locker sync.RWMutex
directAddress string
proxyAddress string
allowProxy bool
proxyProvider *proxyProvider
proxyUseDuration time.Duration
}
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
func NewProxyTLSDialer(dialer TLSDialer, hostURL string) *ProxyTLSDialer {
return &ProxyTLSDialer{
dialer: dialer,
locker: sync.RWMutex{},
directAddress: formatAsAddress(hostURL),
proxyAddress: formatAsAddress(hostURL),
proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders),
proxyUseDuration: proxyUseDuration,
}
}
// formatAsAddress returns URL as `host:port` for easy comparison in DialTLS.
func formatAsAddress(rawURL string) string {
url, err := url.Parse(rawURL)
if err != nil {
// This means wrong configuration.
// Developer should get feedback right away.
panic(err)
}
host := url.Host
if host == "" {
host = url.Path
}
port := "443"
if url.Scheme == "http" {
port = "80"
}
return net.JoinHostPort(host, port)
}
// DialTLS dials the given network/address. If it fails, it retries using a proxy.
func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
if address == d.directAddress {
address = d.proxyAddress
}
conn, err := d.dialer.DialTLSContext(ctx, network, address)
if err == nil || !d.allowProxy {
return conn, err
}
if err := d.switchToReachableServer(); err != nil {
return nil, err
}
return d.dialer.DialTLSContext(ctx, network, d.proxyAddress)
}
// switchToReachableServer switches to using a reachable server (either proxy or standard API).
func (d *ProxyTLSDialer) switchToReachableServer() error {
d.locker.Lock()
defer d.locker.Unlock()
logrus.Info("Attempting to switch to a proxy")
proxy, err := d.proxyProvider.findReachableServer()
if err != nil {
return errors.Wrap(err, "failed to find a usable proxy")
}
proxyAddress := formatAsAddress(proxy)
// If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
if proxyAddress == d.directAddress {
logrus.Info("The standard API is reachable again; connection drop was only intermittent")
d.proxyAddress = proxyAddress
return ErrNoConnection
}
logrus.WithField("proxy", proxyAddress).Info("Switching to a proxy")
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
// This means we want to disable it again in 24 hours.
if d.proxyAddress == d.directAddress {
go func() {
<-time.After(d.proxyUseDuration)
d.locker.Lock()
defer d.locker.Unlock()
d.proxyAddress = d.directAddress
}()
}
d.proxyAddress = proxyAddress
return nil
}
// AllowProxy allows the dialer to switch to a proxy if need be.
func (d *ProxyTLSDialer) AllowProxy() {
d.locker.Lock()
defer d.locker.Unlock()
d.allowProxy = true
}
// DisallowProxy prevents the dialer from switching to a proxy if need be.
func (d *ProxyTLSDialer) DisallowProxy() {
d.locker.Lock()
defer d.locker.Unlock()
d.allowProxy = false
d.proxyAddress = d.directAddress
}

View File

@ -0,0 +1,256 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"encoding/base64"
"strings"
"sync"
"time"
"github.com/go-resty/resty/v2"
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
proxyUseDuration = 24 * time.Hour
proxyLookupWait = 5 * time.Second
proxyCacheRefreshTimeout = 20 * time.Second
proxyDoHTimeout = 20 * time.Second
proxyCanReachTimeout = 20 * time.Second
proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
Quad9Provider = "https://dns11.quad9.net/dns-query"
Quad9PortProvider = "https://dns11.quad9.net:5053/dns-query"
GoogleProvider = "https://dns.google/dns-query"
)
var DoHProviders = []string{ //nolint:gochecknoglobals
Quad9Provider,
Quad9PortProvider,
GoogleProvider,
}
// proxyProvider manages known proxies.
type proxyProvider struct {
dialer TLSDialer
hostURL string
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
dohLookup func(ctx context.Context, query, provider string) (urls []string, err error)
providers []string // List of known doh providers.
query string // The query string used to find proxies.
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
cacheRefreshTimeout time.Duration
dohTimeout time.Duration
canReachTimeout time.Duration
lastLookup time.Time // The time at which we last attempted to find a proxy.
}
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
// to retrieve DNS records for the given query string.
func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *proxyProvider) {
p = &proxyProvider{
dialer: dialer,
hostURL: hostURL,
providers: providers,
query: proxyQuery,
cacheRefreshTimeout: proxyCacheRefreshTimeout,
dohTimeout: proxyDoHTimeout,
canReachTimeout: proxyCanReachTimeout,
}
// Use the default DNS lookup method; this can be overridden if necessary.
p.dohLookup = p.defaultDoHLookup
return
}
// findReachableServer returns a working API server (either proxy or standard API).
func (p *proxyProvider) findReachableServer() (proxy string, err error) {
logrus.Debug("Trying to find a reachable server")
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
return "", errors.New("not looking for a proxy, too soon")
}
p.lastLookup = time.Now()
// We use a waitgroup to wait for both
// a) the check whether the API is reachable, and
// b) the DoH queries.
// This is because the Alternative Routes v2 spec says:
// Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished)
var wg sync.WaitGroup
var apiReachable bool
wg.Add(2)
go func() {
defer wg.Done()
apiReachable = p.canReach(p.hostURL)
}()
go func() {
defer wg.Done()
err = p.refreshProxyCache()
}()
wg.Wait()
if apiReachable {
proxy = p.hostURL
return
}
if err != nil {
return
}
for _, url := range p.proxyCache {
if p.canReach(url) {
proxy = url
return
}
}
return "", errors.New("no reachable server could be found")
}
// refreshProxyCache loads the latest proxies from the known providers.
// If the process takes longer than proxyCacheRefreshTimeout, an error is returned.
func (p *proxyProvider) refreshProxyCache() error {
logrus.Info("Refreshing proxy cache")
ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout)
defer cancel()
resultChan := make(chan []string)
go func() {
for _, provider := range p.providers {
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
resultChan <- proxies
return
}
}
// If no dohLoopkup worked, cancel right after it's done to not
// block refreshing for the whole cacheRefreshTimeout.
cancel()
}()
select {
case result := <-resultChan:
p.proxyCache = result
return nil
case <-ctx.Done():
return errors.New("timed out while refreshing proxy cache")
}
}
// canReach returns whether we can reach the given url.
func (p *proxyProvider) canReach(url string) bool {
logrus.WithField("url", url).Debug("Trying to ping proxy")
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
pinger := resty.New().
SetBaseURL(url).
SetTimeout(p.canReachTimeout).
SetTransport(CreateTransportWithDialer(p.dialer))
if _, err := pinger.R().Get("/tests/ping"); err != nil {
logrus.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
return false
}
return true
}
// defaultDoHLookup is the default implementation of the proxy manager's DoH lookup.
// It looks up DNS TXT records for the given query URL using the given DoH provider.
// It returns a list of all found TXT records.
// If the whole process takes more than proxyDoHTimeout then an error is returned.
func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) {
ctx, cancel := context.WithTimeout(ctx, p.dohTimeout)
defer cancel()
dataChan, errChan := make(chan []string), make(chan error)
go func() {
// Build new DNS request in RFC1035 format.
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)
// Pack the DNS request message into wire format.
rawRequest, err := dnsRequest.Pack()
if err != nil {
errChan <- errors.Wrap(err, "failed to pack DNS request")
return
}
// Encode wire-format DNS request message as base64url (RFC4648) without padding chars.
encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest)
// Make DoH request to the given DoH provider.
rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider)
if err != nil {
errChan <- errors.Wrap(err, "failed to make DoH request")
return
}
// Unpack the DNS response.
dnsResponse := new(dns.Msg)
if err = dnsResponse.Unpack(rawResponse.Body()); err != nil {
errChan <- errors.Wrap(err, "failed to unpack DNS response")
return
}
// Pick out the TXT answers.
for _, answer := range dnsResponse.Answer {
if t, ok := answer.(*dns.TXT); ok {
data = append(data, t.Txt...)
}
}
dataChan <- data
}()
select {
case data = <-dataChan:
logrus.WithField("data", data).Info("Received TXT records")
return
case err = <-errChan:
logrus.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records")
return
case <-ctx.Done():
logrus.WithField("provider", dohProvider).Error("Timed out querying DNS records")
return []string{}, errors.New("timed out querying DNS records")
}
}

View File

@ -0,0 +1,191 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"net/http"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
r "github.com/stretchr/testify/require"
)
func TestProxyProvider_FindProxy(t *testing.T) {
proxy := getTrustedServer()
defer closeServer(proxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, proxy.URL, url)
}
func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
reachableProxy := getTrustedServer()
defer closeServer(reachableProxy)
// We actually close the unreachable proxy straight away rather than deferring the closure.
unreachableProxy := getTrustedServer()
closeServer(unreachableProxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{reachableProxy.URL, unreachableProxy.URL}, nil
}
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, reachableProxy.URL, url)
}
func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
untrustedProxy := getUntrustedServer()
defer closeServer(untrustedProxy)
reporter := NewTLSReporter("", "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
p := newProxyProvider(dialer, "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy.URL, trustedProxy.URL}, nil
}
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, trustedProxy.URL, url)
}
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
unreachableProxy1 := getTrustedServer()
closeServer(unreachableProxy1)
unreachableProxy2 := getTrustedServer()
closeServer(unreachableProxy2)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
}
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
untrustedProxy1 := getUntrustedServer()
defer closeServer(untrustedProxy1)
untrustedProxy2 := getUntrustedServer()
defer closeServer(untrustedProxy2)
reporter := NewTLSReporter("", "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
p := newProxyProvider(dialer, "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
}
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.cacheRefreshTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
// We should fail to refresh the proxy cache because the doh provider
// takes 2 seconds to respond but we timeout after just 1 second.
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
time.Sleep(2 * time.Second)
}))
defer closeServer(slowProxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.canReachTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
// We should fail to reach the returned proxy because it takes 2 seconds
// to reach it and we only allow 1.
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
// DISABLEDTestProxyProvider_DoHLookup_Quad9Port cannot run on CI due to custom
// port filter. Basic functionality should be covered by other tests. Keeping
// code here to be able to run it locally if needed.
func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
skipIfProxyIsSet(t)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
url, err := p.findReachableServer()
r.NoError(t, err)
r.NotEmpty(t, url)
}
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
skipIfProxyIsSet(t)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider})
url, err := p.findReachableServer()
r.NoError(t, err)
r.NotEmpty(t, url)
}

View File

@ -0,0 +1,273 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// getTrustedServer returns a server and sets its public key as one of the pinned ones.
func getTrustedServer() *httptest.Server {
return getTrustedServerWithHandler(
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
// Do nothing.
}),
)
}
func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server {
proxy := httptest.NewTLSServer(handler)
pin := certFingerprint(proxy.Certificate())
TrustedAPIPins = append(TrustedAPIPins, pin)
return proxy
}
const servercrt = `
-----BEGIN CERTIFICATE-----
MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD
VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp
dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t
T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j
b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx
MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR
BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf
MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR
aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI
hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt
r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2
pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM
GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru
rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c
kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw
ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI
DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu
ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0
MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3
LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R
BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5
L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA
CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P
3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg
yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l
z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc
u53wgIhCJGuX
-----END CERTIFICATE-----
`
const serverkey = `
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx
owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG
ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq
UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx
8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr
n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6
mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1
EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ
/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F
qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT
VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu
Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h
bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X
TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU
LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f
pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm
nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3
Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5
ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo
w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK
PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt
FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0
ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD
z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2
FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o
vwRMog6lPhlRhHh/FZ43Cg==
-----END PRIVATE KEY-----
`
// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones.
func getUntrustedServer() *httptest.Server {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey))
if err != nil {
panic(err)
}
server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
server.StartTLS()
return server
}
// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys.
func closeServer(server *httptest.Server) {
pin := certFingerprint(server.Certificate())
for i := range TrustedAPIPins {
if TrustedAPIPins[i] == pin {
TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...)
break
}
}
server.Close()
}
func TestProxyDialer_UseProxy(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
}
func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) {
proxy1 := getTrustedServer()
defer closeServer(proxy1)
proxy2 := getTrustedServer()
defer closeServer(proxy2)
proxy3 := getTrustedServer()
defer closeServer(proxy3)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy3.URL), d.proxyAddress)
}
func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
d.proxyUseDuration = time.Second
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
time.Sleep(2 * time.Second)
require.Equal(t, ":443", d.proxyAddress)
}
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
trustedProxy := getTrustedServer()
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
// Simulate that the proxy stops working and that the standard api is reachable again.
closeServer(trustedProxy)
d.directAddress = formatAsAddress(getRootURL())
provider.hostURL = getRootURL()
time.Sleep(proxyLookupWait)
// We should now find the original API URL if it is working again.
// The error should be ErrAPINotReachable because the connection dropped intermittently but
// the original API is now reachable (see Alternative-Routing-v2 spec for details).
err = d.switchToReachableServer()
require.Error(t, err)
require.Equal(t, formatAsAddress(getRootURL()), d.proxyAddress)
}
func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
// proxy1 is closed later in this test so we don't defer it here.
proxy1 := getTrustedServer()
proxy2 := getTrustedServer()
defer closeServer(proxy2)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
// The proxy stops working and the protonmail API is still blocked.
closeServer(proxy1)
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
}
func TestFormatAsAddress(t *testing.T) {
r := require.New(t)
testData := map[string]string{
"sub.domain.tld": "sub.domain.tld:443",
"http://sub.domain.tld": "sub.domain.tld:80",
"https://sub.domain.tld": "sub.domain.tld:443",
"ftp://sub.domain.tld": "sub.domain.tld:443",
"//sub.domain.tld": "sub.domain.tld:443",
}
for rawURL, wantURL := range testData {
r.Equal(wantURL, formatAsAddress(rawURL))
}
}

View File

@ -0,0 +1,16 @@
package dialer
import (
"testing"
"golang.org/x/net/http/httpproxy"
)
// skipIfProxyIsSet skips the tests if HTTPS proxy is set.
// Should be used for tests depending on proper certificate checks which
// is not possible under our CI setup.
func skipIfProxyIsSet(t *testing.T) {
if httpproxy.FromEnvironment().HTTPSProxy != "" {
t.SkipNow()
}
}