From 350544e801c54e3e7e4e9a339bcacff0db46eba7 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Mon, 24 Oct 2022 16:41:17 +0200 Subject: [PATCH] Other: Fix race conditions in internal/dialer Some race conditions came from the tests themselves. But we had a race condition reading the proxyAddress; this change protects it with a mutex. --- internal/dialer/dialer_pinning_test.go | 27 +++++++++++++++++--------- internal/dialer/dialer_proxy.go | 5 +++++ internal/dialer/dialer_proxy_test.go | 6 ++++++ 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/internal/dialer/dialer_pinning_test.go b/internal/dialer/dialer_pinning_test.go index 61971f53..aee31d65 100644 --- a/internal/dialer/dialer_pinning_test.go +++ b/internal/dialer/dialer_pinning_test.go @@ -19,6 +19,7 @@ package dialer import ( "context" + "sync/atomic" "testing" "time" @@ -109,8 +110,8 @@ func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed") } -func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) { - called := 0 +func createClientWithPinningDialer(hostURL string) (*atomicUint64, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) { + called := &atomicUint64{} reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins) checker := NewTLSPinChecker(TrustedAPIPins) @@ -118,11 +119,11 @@ func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TL go func() { for range dialer.GetTLSIssueCh() { - called++ + called.add(1) } }() - return &called, dialer, reporter, checker, liteapi.New( + return called, dialer, reporter, checker, liteapi.New( liteapi.WithHostURL(hostURL), liteapi.WithTransport(CreateTransportWithDialer(dialer)), ) @@ -134,24 +135,32 @@ func copyTrustedPins(pinChecker *TLSPinChecker) { pinChecker.trustedPins = copiedPins } -func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) { +func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast uint64, called *atomicUint64) { // TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called. a.Eventually( t, func() bool { if wantCalledAtLeast == 0 { - return *called == 0 + return called.load() == 0 } // Dialer can do more attempts resulting in more calls. - return *called >= wantCalledAtLeast + return called.load() >= wantCalledAtLeast }, time.Second, 10*time.Millisecond, ) // Repeated again so it generates nice message. if wantCalledAtLeast == 0 { - r.Equal(t, 0, *called) + r.Equal(t, uint64(0), called.load()) } else { - r.GreaterOrEqual(t, *called, wantCalledAtLeast) + r.GreaterOrEqual(t, called.load(), wantCalledAtLeast) } } + +type atomicUint64 struct { + v uint64 +} + +func (x *atomicUint64) load() uint64 { return atomic.LoadUint64(&x.v) } + +func (x *atomicUint64) add(delta uint64) uint64 { return atomic.AddUint64(&x.v, delta) } diff --git a/internal/dialer/dialer_proxy.go b/internal/dialer/dialer_proxy.go index 4ad7f4b1..fddb6091 100644 --- a/internal/dialer/dialer_proxy.go +++ b/internal/dialer/dialer_proxy.go @@ -77,9 +77,11 @@ func formatAsAddress(rawURL string) string { // DialTLSContext dials the given network/address. If it fails, it retries using a proxy. func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { + d.locker.RLock() if address == d.directAddress { address = d.proxyAddress } + d.locker.RUnlock() conn, err := d.dialer.DialTLSContext(ctx, network, address) if err == nil || !d.allowProxy { @@ -90,6 +92,9 @@ func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address st return nil, err } + d.locker.RLock() + defer d.locker.RUnlock() + return d.dialer.DialTLSContext(ctx, network, d.proxyAddress) } diff --git a/internal/dialer/dialer_proxy_test.go b/internal/dialer/dialer_proxy_test.go index f6963f9c..66d60794 100644 --- a/internal/dialer/dialer_proxy_test.go +++ b/internal/dialer/dialer_proxy_test.go @@ -197,10 +197,16 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) { provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } err := d.switchToReachableServer() require.NoError(t, err) + + d.locker.Lock() require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress) + d.locker.Unlock() time.Sleep(2 * time.Second) + + d.locker.Lock() require.Equal(t, ":443", d.proxyAddress) + d.locker.Unlock() } func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {