mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
Other: Fix race conditions in internal/dialer
Some race conditions came from the tests themselves. But we had a race condition reading the proxyAddress; this change protects it with a mutex.
This commit is contained in:
@ -19,6 +19,7 @@ package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -109,8 +110,8 @@ func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
|
||||
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
|
||||
}
|
||||
|
||||
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
|
||||
called := 0
|
||||
func createClientWithPinningDialer(hostURL string) (*atomicUint64, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
|
||||
called := &atomicUint64{}
|
||||
|
||||
reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins)
|
||||
checker := NewTLSPinChecker(TrustedAPIPins)
|
||||
@ -118,11 +119,11 @@ func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TL
|
||||
|
||||
go func() {
|
||||
for range dialer.GetTLSIssueCh() {
|
||||
called++
|
||||
called.add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
return &called, dialer, reporter, checker, liteapi.New(
|
||||
return called, dialer, reporter, checker, liteapi.New(
|
||||
liteapi.WithHostURL(hostURL),
|
||||
liteapi.WithTransport(CreateTransportWithDialer(dialer)),
|
||||
)
|
||||
@ -134,24 +135,32 @@ func copyTrustedPins(pinChecker *TLSPinChecker) {
|
||||
pinChecker.trustedPins = copiedPins
|
||||
}
|
||||
|
||||
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
|
||||
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast uint64, called *atomicUint64) {
|
||||
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
|
||||
a.Eventually(
|
||||
t,
|
||||
func() bool {
|
||||
if wantCalledAtLeast == 0 {
|
||||
return *called == 0
|
||||
return called.load() == 0
|
||||
}
|
||||
// Dialer can do more attempts resulting in more calls.
|
||||
return *called >= wantCalledAtLeast
|
||||
return called.load() >= wantCalledAtLeast
|
||||
},
|
||||
time.Second,
|
||||
10*time.Millisecond,
|
||||
)
|
||||
// Repeated again so it generates nice message.
|
||||
if wantCalledAtLeast == 0 {
|
||||
r.Equal(t, 0, *called)
|
||||
r.Equal(t, uint64(0), called.load())
|
||||
} else {
|
||||
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
|
||||
r.GreaterOrEqual(t, called.load(), wantCalledAtLeast)
|
||||
}
|
||||
}
|
||||
|
||||
type atomicUint64 struct {
|
||||
v uint64
|
||||
}
|
||||
|
||||
func (x *atomicUint64) load() uint64 { return atomic.LoadUint64(&x.v) }
|
||||
|
||||
func (x *atomicUint64) add(delta uint64) uint64 { return atomic.AddUint64(&x.v, delta) }
|
||||
|
||||
@ -77,9 +77,11 @@ func formatAsAddress(rawURL string) string {
|
||||
|
||||
// DialTLSContext dials the given network/address. If it fails, it retries using a proxy.
|
||||
func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d.locker.RLock()
|
||||
if address == d.directAddress {
|
||||
address = d.proxyAddress
|
||||
}
|
||||
d.locker.RUnlock()
|
||||
|
||||
conn, err := d.dialer.DialTLSContext(ctx, network, address)
|
||||
if err == nil || !d.allowProxy {
|
||||
@ -90,6 +92,9 @@ func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address st
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.locker.RLock()
|
||||
defer d.locker.RUnlock()
|
||||
|
||||
return d.dialer.DialTLSContext(ctx, network, d.proxyAddress)
|
||||
}
|
||||
|
||||
|
||||
@ -197,10 +197,16 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
|
||||
d.locker.Lock()
|
||||
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
|
||||
d.locker.Unlock()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
d.locker.Lock()
|
||||
require.Equal(t, ":443", d.proxyAddress)
|
||||
d.locker.Unlock()
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user