fix: correct doh timeouts

This commit is contained in:
James Houlahan
2020-05-25 17:15:35 +02:00
parent ad877431de
commit cc14b523cb
3 changed files with 133 additions and 94 deletions

View File

@ -18,8 +18,10 @@
package pmapi
import (
"context"
"encoding/base64"
"strings"
"sync"
"time"
"github.com/go-resty/resty/v2"
@ -29,11 +31,12 @@ import (
)
// Tunables for proxy discovery.
// NOTE(review): this block is a rendered diff without +/- markers, so several names are
// declared twice; presumably proxySearchTimeout and proxyQueryTimeout are the removed
// constants and the three proxy*Timeout names further down replace them — confirm
// against the commit.
const (
proxyUseDuration = 24 * time.Hour
proxySearchTimeout = 20 * time.Second // Default for proxyProvider.findTimeout.
proxyQueryTimeout = 20 * time.Second // Default for proxyProvider.lookupTimeout.
proxyLookupWait = 5 * time.Second
proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
proxyUseDuration = 24 * time.Hour
proxyLookupWait = 5 * time.Second // Minimum gap between two findReachableServer attempts.
proxyCacheRefreshTimeout = 20 * time.Second // Default for proxyProvider.cacheRefreshTimeout.
proxyDoHTimeout = 20 * time.Second // Default for proxyProvider.dohTimeout.
proxyCanReachTimeout = 20 * time.Second // Default for proxyProvider.canReachTimeout.
proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
)
var dohProviders = []string{ //nolint[gochecknoglobals]
@ -44,13 +47,15 @@ var dohProviders = []string{ //nolint[gochecknoglobals]
// proxyProvider manages known proxies.
type proxyProvider struct {
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records.
// NOTE(review): dohLookup is listed twice because this is a rendered diff without +/- markers;
// the variant taking a context.Context is the one called by refreshProxyCache's goroutine.
dohLookup func(query, provider string) (urls []string, err error)
dohLookup func(ctx context.Context, query, provider string) (urls []string, err error)
providers []string // List of known doh providers.
query string // The query string used to find proxies.
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
findTimeout, lookupTimeout time.Duration // Timeouts for DNS query and proxy search.
cacheRefreshTimeout time.Duration // Deadline of the context created in refreshProxyCache.
dohTimeout time.Duration // Deadline applied to a single lookup in defaultDoHLookup.
canReachTimeout time.Duration // Timeout set on the HTTP client used by canReach.
lastLookup time.Time // The time at which we last attempted to find a proxy.
}
@ -59,10 +64,11 @@ type proxyProvider struct {
// to retrieve DNS records for the given query string.
func newProxyProvider(providers []string, query string) (p *proxyProvider) { // nolint[unparam]
p = &proxyProvider{
providers: providers,
query: query,
findTimeout: proxySearchTimeout,
lookupTimeout: proxyQueryTimeout,
providers: providers,
query: query,
cacheRefreshTimeout: proxyCacheRefreshTimeout,
dohTimeout: proxyDoHTimeout,
canReachTimeout: proxyCanReachTimeout,
}
// Use the default DNS lookup method; this can be overridden if necessary.
@ -72,77 +78,89 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { //
}
// findReachableServer returns a working API server (either proxy or standard API).
// It returns an error if the process takes longer than ProxySearchTime.
// NOTE(review): the body below is a rendered diff without +/- markers. Two implementations
// are interleaved: one that reports through proxyResult/errResult and a select with a
// p.findTimeout timer, and one that runs two goroutines under a sync.WaitGroup. Read it
// alongside the commit to tell which lines belong to which version.
func (p *proxyProvider) findReachableServer() (proxy string, err error) {
logrus.Debug("Trying to find a reachable server")
// Rate-limit lookups: refuse to search again within proxyLookupWait of the last attempt.
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
return "", errors.New("not looking for a proxy, too soon")
}
p.lastLookup = time.Now()
// Unbuffered result channels used by the select-based variant.
proxyResult := make(chan string)
errResult := make(chan error)
// We use a waitgroup to wait for both
// a) the check whether the API is reachable, and
// b) the DoH queries.
// This is because the Alternative Routes v2 spec says:
// Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished)
var wg sync.WaitGroup
var apiReachable bool
wg.Add(2)
go func() {
if err = p.refreshProxyCache(); err != nil {
errResult <- errors.Wrap(err, "failed to refresh proxy cache")
return
}
// We want to switch back to the rootURL if possible.
if p.canReach(rootURL) {
proxyResult <- rootURL
return
}
for _, proxy := range p.proxyCache {
if p.canReach(proxy) {
proxyResult <- proxy
return
}
}
errResult <- errors.New("no proxy available")
// WaitGroup variant: ping the standard API and record the outcome in apiReachable.
defer wg.Done()
apiReachable = p.canReach(rootURL)
}()
select {
case <-time.After(p.findTimeout):
logrus.Error("Timed out finding a proxy server")
return "", errors.New("timed out finding a proxy")
// WaitGroup variant: refresh the proxy cache; the result is stored in the named return err.
go func() {
defer wg.Done()
err = p.refreshProxyCache()
}()
case proxy = <-proxyResult:
logrus.WithField("proxy", proxy).Info("Found proxy server")
return
// Block until both the API ping and the cache refresh have finished.
wg.Wait()
case err = <-errResult:
logrus.WithError(err).Error("Failed to find available proxy server")
// Prefer the standard API (rootURL) whenever it answered the ping.
if apiReachable {
proxy = rootURL
return
}
// Propagate a cache refresh failure instead of trying previously cached proxies.
if err != nil {
return
}
// Otherwise return the first cached proxy that responds.
for _, url := range p.proxyCache {
if p.canReach(url) {
proxy = url
return
}
}
return "", errors.New("no reachable server could be found")
}
// refreshProxyCache loads the latest proxies from the known providers.
// It includes the standard API.
// If the process takes longer than proxyCacheRefreshTimeout, an error is returned.
// NOTE(review): the body below is a rendered diff without +/- markers. A sequential loop
// over p.providers is interleaved with a variant that runs the lookups in a goroutine
// bounded by a context; read it alongside the commit.
func (p *proxyProvider) refreshProxyCache() error {
logrus.Info("Refreshing proxy cache")
for _, provider := range p.providers {
proxies, err := p.dohLookup(p.query, provider)
// Context variant: bound the whole refresh by p.cacheRefreshTimeout.
ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout)
defer cancel()
if err == nil {
p.proxyCache = proxies
// NOTE(review): resultChan is unbuffered. If the timeout fires first, a later send on it
// has no receiver and the goroutine stays blocked; a buffer of one would avoid that — confirm.
resultChan := make(chan []string)
logrus.WithField("proxies", proxies).Info("Available proxies")
return nil
// Try the providers in order and deliver the first successful lookup.
// NOTE(review): nothing is sent when every provider fails, so the select below then
// only returns once ctx expires.
go func() {
for _, provider := range p.providers {
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
resultChan <- proxies
return
}
}
}()
logrus.WithError(err).Warn("Lookup failed, trying another provider")
// Store the first result in the cache, or give up when the context is done.
select {
case result := <-resultChan:
p.proxyCache = result
return nil
case <-ctx.Done():
return errors.New("timed out while refreshing proxy cache")
}
return errors.New("lookup failed with all DoH providers")
}
// canReach returns whether we can reach the given url.
func (p *proxyProvider) canReach(url string) bool {
logrus.WithField("url", url).Debug("Trying to ping proxy")
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
@ -151,7 +169,7 @@ func (p *proxyProvider) canReach(url string) bool {
pinger := resty.New().
SetHostURL(url).
SetTimeout(p.lookupTimeout).
SetTimeout(p.canReachTimeout).
SetTransport(CreateTransportWithDialer(dialer))
if _, err := pinger.R().Get("/tests/ping"); err != nil {
@ -165,10 +183,13 @@ func (p *proxyProvider) canReach(url string) bool {
// defaultDoHLookup is the default implementation of the proxy manager's DoH lookup.
// It looks up DNS TXT records for the given query URL using the given DoH provider.
// It returns a list of all found TXT records.
// If the whole process takes more than ProxyQueryTime then an error is returned.
func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []string, err error) {
dataResult := make(chan []string)
errResult := make(chan error)
// If the whole process takes more than proxyDoHTimeout then an error is returned.
func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) {
ctx, cancel := context.WithTimeout(ctx, p.dohTimeout)
defer cancel()
dataChan, errChan := make(chan []string), make(chan error)
go func() {
// Build new DNS request in RFC1035 format.
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)
@ -176,7 +197,7 @@ func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []stri
// Pack the DNS request message into wire format.
rawRequest, err := dnsRequest.Pack()
if err != nil {
errResult <- errors.Wrap(err, "failed to pack DNS request")
errChan <- errors.Wrap(err, "failed to pack DNS request")
return
}
@ -184,16 +205,16 @@ func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []stri
encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest)
// Make DoH request to the given DoH provider.
rawResponse, err := resty.New().R().SetQueryParam("dns", encodedRequest).Get(dohProvider)
rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider)
if err != nil {
errResult <- errors.Wrap(err, "failed to make DoH request")
errChan <- errors.Wrap(err, "failed to make DoH request")
return
}
// Unpack the DNS response.
dnsResponse := new(dns.Msg)
if err = dnsResponse.Unpack(rawResponse.Body()); err != nil {
errResult <- errors.Wrap(err, "failed to unpack DNS response")
errChan <- errors.Wrap(err, "failed to unpack DNS response")
return
}
@ -204,20 +225,20 @@ func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []stri
}
}
dataResult <- data
dataChan <- data
}()
select {
case <-time.After(p.lookupTimeout):
logrus.WithField("provider", dohProvider).Error("Timed out querying DNS records")
return []string{}, errors.New("timed out querying DNS records")
case data = <-dataResult:
case data = <-dataChan:
logrus.WithField("data", data).Info("Received TXT records")
return
case err = <-errResult:
case err = <-errChan:
logrus.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records")
return
case <-ctx.Done():
logrus.WithField("provider", dohProvider).Error("Timed out querying DNS records")
return []string{}, errors.New("timed out querying DNS records")
}
}