forked from Silverfish/proton-bridge
fix: correct doh timeouts
This commit is contained in:
@ -18,6 +18,7 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -35,7 +36,15 @@ const (
|
||||
|
||||
// getTrustedServer returns a server and sets its public key as one of the pinned ones.
|
||||
func getTrustedServer() *httptest.Server {
|
||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
return getTrustedServerWithHandler(
|
||||
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
// Do nothing.
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server {
|
||||
proxy := httptest.NewTLSServer(handler)
|
||||
|
||||
pin := certFingerprint(proxy.Certificate())
|
||||
TrustedAPIPins = append(TrustedAPIPins, pin)
|
||||
@ -143,7 +152,7 @@ func TestProxyProvider_FindProxy(t *testing.T) {
|
||||
defer closeServer(proxy)
|
||||
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
require.NoError(t, err)
|
||||
@ -162,7 +171,9 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
|
||||
closeServer(unreachableProxy)
|
||||
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{reachableProxy.URL, unreachableProxy.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{reachableProxy.URL, unreachableProxy.URL}, nil
|
||||
}
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
require.NoError(t, err)
|
||||
@ -180,7 +191,9 @@ func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
|
||||
defer closeServer(untrustedProxy)
|
||||
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{untrustedProxy.URL, trustedProxy.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy.URL, trustedProxy.URL}, nil
|
||||
}
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
require.NoError(t, err)
|
||||
@ -198,7 +211,7 @@ func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
|
||||
closeServer(unreachableProxy2)
|
||||
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) {
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
|
||||
}
|
||||
|
||||
@ -217,7 +230,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
|
||||
defer closeServer(untrustedProxy2)
|
||||
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) {
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
|
||||
}
|
||||
|
||||
@ -225,34 +238,38 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) {
|
||||
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.lookupTimeout = time.Second
|
||||
p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
|
||||
p.cacheRefreshTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
|
||||
|
||||
// The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second.
|
||||
// We should fail to refresh the proxy cache because the doh provider
|
||||
// takes 2 seconds to respond but we timeout after just 1 second.
|
||||
_, err := p.findReachableServer()
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) {
|
||||
func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
slowProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
time.Sleep(2 * time.Second)
|
||||
}))
|
||||
defer slowProxy.Close()
|
||||
defer closeServer(slowProxy)
|
||||
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.findTimeout = time.Second
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
|
||||
p.canReachTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
|
||||
|
||||
// The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second.
|
||||
// We should fail to reach the returned proxy because it takes 2 seconds
|
||||
// to reach it and we only allow 1.
|
||||
_, err := p.findReachableServer()
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@ -268,7 +285,7 @@ func TestProxyProvider_UseProxy(t *testing.T) {
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
cm.proxyProvider = p
|
||||
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
url, err := cm.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, trustedProxy.URL, url)
|
||||
@ -291,7 +308,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
cm.proxyProvider = p
|
||||
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
||||
url, err := cm.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxy1.URL, url)
|
||||
@ -300,7 +317,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
|
||||
url, err = cm.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxy2.URL, url)
|
||||
@ -309,7 +326,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
|
||||
url, err = cm.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxy3.URL, url)
|
||||
@ -329,7 +346,7 @@ func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
cm.proxyProvider = p
|
||||
cm.proxyUseDuration = time.Second
|
||||
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
url, err := cm.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, trustedProxy.URL, url)
|
||||
@ -350,7 +367,7 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
cm.proxyProvider = p
|
||||
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
url, err := cm.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, trustedProxy.URL, url)
|
||||
@ -386,7 +403,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl
|
||||
cm.proxyProvider = p
|
||||
|
||||
// Find a proxy.
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
||||
url, err := cm.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxy1.URL, url)
|
||||
@ -408,7 +425,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl
|
||||
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
records, err := p.dohLookup(TestDoHQuery, TestQuad9Provider)
|
||||
records, err := p.dohLookup(context.Background(), TestDoHQuery, TestQuad9Provider)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, records)
|
||||
}
|
||||
@ -416,7 +433,7 @@ func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
|
||||
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
records, err := p.dohLookup(TestDoHQuery, TestGoogleProvider)
|
||||
records, err := p.dohLookup(context.Background(), TestDoHQuery, TestGoogleProvider)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user