Other: Fix race conditions in internal/dialer

Some race conditions came from the tests themselves. But we had a race
condition reading the proxyAddress; this change protects it with a
mutex.
This commit is contained in:
James Houlahan
2022-10-24 16:41:17 +02:00
parent d6260d960c
commit 350544e801
3 changed files with 29 additions and 9 deletions

View File

@ -19,6 +19,7 @@ package dialer
import (
"context"
"sync/atomic"
"testing"
"time"
@ -109,8 +110,8 @@ func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
}
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
called := 0
func createClientWithPinningDialer(hostURL string) (*atomicUint64, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
called := &atomicUint64{}
reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
@ -118,11 +119,11 @@ func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TL
go func() {
for range dialer.GetTLSIssueCh() {
called++
called.add(1)
}
}()
return &called, dialer, reporter, checker, liteapi.New(
return called, dialer, reporter, checker, liteapi.New(
liteapi.WithHostURL(hostURL),
liteapi.WithTransport(CreateTransportWithDialer(dialer)),
)
@ -134,24 +135,32 @@ func copyTrustedPins(pinChecker *TLSPinChecker) {
pinChecker.trustedPins = copiedPins
}
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast uint64, called *atomicUint64) {
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
a.Eventually(
t,
func() bool {
if wantCalledAtLeast == 0 {
return *called == 0
return called.load() == 0
}
// Dialer can do more attempts resulting in more calls.
return *called >= wantCalledAtLeast
return called.load() >= wantCalledAtLeast
},
time.Second,
10*time.Millisecond,
)
// Repeated again so it generates nice message.
if wantCalledAtLeast == 0 {
r.Equal(t, 0, *called)
r.Equal(t, uint64(0), called.load())
} else {
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
r.GreaterOrEqual(t, called.load(), wantCalledAtLeast)
}
}
type atomicUint64 struct {
v uint64
}
func (x *atomicUint64) load() uint64 { return atomic.LoadUint64(&x.v) }
func (x *atomicUint64) add(delta uint64) uint64 { return atomic.AddUint64(&x.v, delta) }

View File

@ -77,9 +77,11 @@ func formatAsAddress(rawURL string) string {
// DialTLSContext dials the given network/address. If it fails, it retries using a proxy.
func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
d.locker.RLock()
if address == d.directAddress {
address = d.proxyAddress
}
d.locker.RUnlock()
conn, err := d.dialer.DialTLSContext(ctx, network, address)
if err == nil || !d.allowProxy {
@ -90,6 +92,9 @@ func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address st
return nil, err
}
d.locker.RLock()
defer d.locker.RUnlock()
return d.dialer.DialTLSContext(ctx, network, d.proxyAddress)
}

View File

@ -197,10 +197,16 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
d.locker.Lock()
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
d.locker.Unlock()
time.Sleep(2 * time.Second)
d.locker.Lock()
require.Equal(t, ":443", d.proxyAddress)
d.locker.Unlock()
}
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {