Other: Fix race conditions in internal/dialer

Some race conditions came from the tests themselves. But we had a race
condition reading the proxyAddress; this change protects it with a
mutex.
This commit is contained in:
James Houlahan
2022-10-24 16:41:17 +02:00
parent d6260d960c
commit 350544e801
3 changed files with 29 additions and 9 deletions

View File

@ -19,6 +19,7 @@ package dialer
import (
"context"
"sync/atomic"
"testing"
"time"
@ -109,8 +110,8 @@ func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
}
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
called := 0
func createClientWithPinningDialer(hostURL string) (*atomicUint64, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
called := &atomicUint64{}
reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
@ -118,11 +119,11 @@ func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TL
go func() {
for range dialer.GetTLSIssueCh() {
called++
called.add(1)
}
}()
return &called, dialer, reporter, checker, liteapi.New(
return called, dialer, reporter, checker, liteapi.New(
liteapi.WithHostURL(hostURL),
liteapi.WithTransport(CreateTransportWithDialer(dialer)),
)
@ -134,24 +135,32 @@ func copyTrustedPins(pinChecker *TLSPinChecker) {
pinChecker.trustedPins = copiedPins
}
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast uint64, called *atomicUint64) {
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
a.Eventually(
t,
func() bool {
if wantCalledAtLeast == 0 {
return *called == 0
return called.load() == 0
}
// Dialer can do more attempts resulting in more calls.
return *called >= wantCalledAtLeast
return called.load() >= wantCalledAtLeast
},
time.Second,
10*time.Millisecond,
)
// Repeated again so it generates nice message.
if wantCalledAtLeast == 0 {
r.Equal(t, 0, *called)
r.Equal(t, uint64(0), called.load())
} else {
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
r.GreaterOrEqual(t, called.load(), wantCalledAtLeast)
}
}
type atomicUint64 struct {
v uint64
}
func (x *atomicUint64) load() uint64 { return atomic.LoadUint64(&x.v) }
func (x *atomicUint64) add(delta uint64) uint64 { return atomic.AddUint64(&x.v, delta) }