From cc14b523cb10f66b7097775da54285aab6cd2ca5 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Mon, 25 May 2020 17:15:35 +0200 Subject: [PATCH] fix: correct doh timeouts --- Changelog.md | 1 + pkg/pmapi/proxy.go | 159 +++++++++++++++++++++++----------------- pkg/pmapi/proxy_test.go | 67 ++++++++++------- 3 files changed, 133 insertions(+), 94 deletions(-) diff --git a/Changelog.md b/Changelog.md index df1dd619..767fd60f 100644 --- a/Changelog.md +++ b/Changelog.md @@ -14,6 +14,7 @@ Changelog [format](http://keepachangelog.com/en/1.0.0/) ### Fixed * GODT-356 Fix crash when removing account while mail client is fetching messages (regression from GODT-204) * GODT-390 Don't logout user if AuthRefresh fails because internet was off. +* GODT-358 Bad timeouts with Alternative Routing ## [v1.2.7] Donghai-hotfix - beta (2020-05-07) diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go index 694cd330..aa9df688 100644 --- a/pkg/pmapi/proxy.go +++ b/pkg/pmapi/proxy.go @@ -18,8 +18,10 @@ package pmapi import ( + "context" "encoding/base64" "strings" + "sync" "time" "github.com/go-resty/resty/v2" @@ -29,11 +31,12 @@ import ( ) const ( - proxyUseDuration = 24 * time.Hour - proxySearchTimeout = 20 * time.Second - proxyQueryTimeout = 20 * time.Second - proxyLookupWait = 5 * time.Second - proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz" + proxyUseDuration = 24 * time.Hour + proxyLookupWait = 5 * time.Second + proxyCacheRefreshTimeout = 20 * time.Second + proxyDoHTimeout = 20 * time.Second + proxyCanReachTimeout = 20 * time.Second + proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz" ) var dohProviders = []string{ //nolint[gochecknoglobals] @@ -44,13 +47,15 @@ var dohProviders = []string{ //nolint[gochecknoglobals] // proxyProvider manages known proxies. type proxyProvider struct { // dohLookup is used to look up the given query at the given DoH provider, returning the TXT records> - dohLookup func(query, provider string) (urls []string, err error) + dohLookup func(ctx context.Context, query, provider string) (urls []string, err error) providers []string // List of known doh providers. query string // The query string used to find proxies. proxyCache []string // All known proxies, cached in case DoH providers are unreachable. - findTimeout, lookupTimeout time.Duration // Timeouts for DNS query and proxy search. + cacheRefreshTimeout time.Duration + dohTimeout time.Duration + canReachTimeout time.Duration lastLookup time.Time // The time at which we last attempted to find a proxy. } @@ -59,10 +64,11 @@ type proxyProvider struct { // to retrieve DNS records for the given query string. func newProxyProvider(providers []string, query string) (p *proxyProvider) { // nolint[unparam] p = &proxyProvider{ - providers: providers, - query: query, - findTimeout: proxySearchTimeout, - lookupTimeout: proxyQueryTimeout, + providers: providers, + query: query, + cacheRefreshTimeout: proxyCacheRefreshTimeout, + dohTimeout: proxyDoHTimeout, + canReachTimeout: proxyCanReachTimeout, } // Use the default DNS lookup method; this can be overridden if necessary. @@ -72,77 +78,89 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { // } // findReachableServer returns a working API server (either proxy or standard API). -// It returns an error if the process takes longer than ProxySearchTime. func (p *proxyProvider) findReachableServer() (proxy string, err error) { + logrus.Debug("Trying to find a reachable server") + if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) { return "", errors.New("not looking for a proxy, too soon") } p.lastLookup = time.Now() - proxyResult := make(chan string) - errResult := make(chan error) + // We use a waitgroup to wait for both + // a) the check whether the API is reachable, and + // b) the DoH queries. + // This is because the Alternative Routes v2 spec says: + // Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished) + var wg sync.WaitGroup + var apiReachable bool + + wg.Add(2) + go func() { - if err = p.refreshProxyCache(); err != nil { - errResult <- errors.Wrap(err, "failed to refresh proxy cache") - return - } - - // We want to switch back to the rootURL if possible. - if p.canReach(rootURL) { - proxyResult <- rootURL - return - } - - for _, proxy := range p.proxyCache { - if p.canReach(proxy) { - proxyResult <- proxy - return - } - } - - errResult <- errors.New("no proxy available") + defer wg.Done() + apiReachable = p.canReach(rootURL) }() - select { - case <-time.After(p.findTimeout): - logrus.Error("Timed out finding a proxy server") - return "", errors.New("timed out finding a proxy") + go func() { + defer wg.Done() + err = p.refreshProxyCache() + }() - case proxy = <-proxyResult: - logrus.WithField("proxy", proxy).Info("Found proxy server") - return + wg.Wait() - case err = <-errResult: - logrus.WithError(err).Error("Failed to find available proxy server") + if apiReachable { + proxy = rootURL return } + + if err != nil { + return + } + + for _, url := range p.proxyCache { + if p.canReach(url) { + proxy = url + return + } + } + + return "", errors.New("no reachable server could be found") } // refreshProxyCache loads the latest proxies from the known providers. -// It includes the standard API. +// If the process takes longer than proxyCacheRefreshTimeout, an error is returned. func (p *proxyProvider) refreshProxyCache() error { logrus.Info("Refreshing proxy cache") - for _, provider := range p.providers { - proxies, err := p.dohLookup(p.query, provider) + ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout) + defer cancel() - if err == nil { - p.proxyCache = proxies + resultChan := make(chan []string) - logrus.WithField("proxies", proxies).Info("Available proxies") - - return nil + go func() { + for _, provider := range p.providers { + if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil { + resultChan <- proxies + return + } } + }() - logrus.WithError(err).Warn("Lookup failed, trying another provider") + select { + case result := <-resultChan: + p.proxyCache = result + return nil + + case <-ctx.Done(): + return errors.New("timed out while refreshing proxy cache") } - - return errors.New("lookup failed with all DoH providers") } // canReach returns whether we can reach the given url. func (p *proxyProvider) canReach(url string) bool { + logrus.WithField("url", url).Debug("Trying to ping proxy") + if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") { url = "https://" + url } @@ -151,7 +169,7 @@ func (p *proxyProvider) canReach(url string) bool { pinger := resty.New(). SetHostURL(url). - SetTimeout(p.lookupTimeout). + SetTimeout(p.canReachTimeout). SetTransport(CreateTransportWithDialer(dialer)) if _, err := pinger.R().Get("/tests/ping"); err != nil { @@ -165,10 +183,13 @@ func (p *proxyProvider) canReach(url string) bool { // defaultDoHLookup is the default implementation of the proxy manager's DoH lookup. // It looks up DNS TXT records for the given query URL using the given DoH provider. // It returns a list of all found TXT records. -// If the whole process takes more than ProxyQueryTime then an error is returned. -func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []string, err error) { - dataResult := make(chan []string) - errResult := make(chan error) +// If the whole process takes more than proxyDoHTimeout then an error is returned. +func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) { + ctx, cancel := context.WithTimeout(ctx, p.dohTimeout) + defer cancel() + + dataChan, errChan := make(chan []string), make(chan error) + go func() { // Build new DNS request in RFC1035 format. dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT) @@ -176,7 +197,7 @@ func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []stri // Pack the DNS request message into wire format. rawRequest, err := dnsRequest.Pack() if err != nil { - errResult <- errors.Wrap(err, "failed to pack DNS request") + errChan <- errors.Wrap(err, "failed to pack DNS request") return } @@ -184,16 +205,16 @@ func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []stri encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest) // Make DoH request to the given DoH provider. - rawResponse, err := resty.New().R().SetQueryParam("dns", encodedRequest).Get(dohProvider) + rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider) if err != nil { - errResult <- errors.Wrap(err, "failed to make DoH request") + errChan <- errors.Wrap(err, "failed to make DoH request") return } // Unpack the DNS response. dnsResponse := new(dns.Msg) if err = dnsResponse.Unpack(rawResponse.Body()); err != nil { - errResult <- errors.Wrap(err, "failed to unpack DNS response") + errChan <- errors.Wrap(err, "failed to unpack DNS response") return } @@ -204,20 +225,20 @@ func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []stri } } - dataResult <- data + dataChan <- data }() select { - case <-time.After(p.lookupTimeout): - logrus.WithField("provider", dohProvider).Error("Timed out querying DNS records") - return []string{}, errors.New("timed out querying DNS records") - - case data = <-dataResult: + case data = <-dataChan: logrus.WithField("data", data).Info("Received TXT records") return - case err = <-errResult: + case err = <-errChan: logrus.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records") return + + case <-ctx.Done(): + logrus.WithField("provider", dohProvider).Error("Timed out querying DNS records") + return []string{}, errors.New("timed out querying DNS records") } } diff --git a/pkg/pmapi/proxy_test.go b/pkg/pmapi/proxy_test.go index 30495b04..ababdf4d 100644 --- a/pkg/pmapi/proxy_test.go +++ b/pkg/pmapi/proxy_test.go @@ -18,6 +18,7 @@ package pmapi import ( + "context" "crypto/tls" "net/http" "net/http/httptest" @@ -35,7 +36,15 @@ const ( // getTrustedServer returns a server and sets its public key as one of the pinned ones. func getTrustedServer() *httptest.Server { - proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + return getTrustedServerWithHandler( + http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + // Do nothing. + }), + ) +} + +func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server { + proxy := httptest.NewTLSServer(handler) pin := certFingerprint(proxy.Certificate()) TrustedAPIPins = append(TrustedAPIPins, pin) @@ -143,7 +152,7 @@ func TestProxyProvider_FindProxy(t *testing.T) { defer closeServer(proxy) p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil } url, err := p.findReachableServer() require.NoError(t, err) @@ -162,7 +171,9 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) { closeServer(unreachableProxy) p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { return []string{reachableProxy.URL, unreachableProxy.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { + return []string{reachableProxy.URL, unreachableProxy.URL}, nil + } url, err := p.findReachableServer() require.NoError(t, err) @@ -180,7 +191,9 @@ func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) { defer closeServer(untrustedProxy) p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { return []string{untrustedProxy.URL, trustedProxy.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { + return []string{untrustedProxy.URL, trustedProxy.URL}, nil + } url, err := p.findReachableServer() require.NoError(t, err) @@ -198,7 +211,7 @@ func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) { closeServer(unreachableProxy2) p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil } @@ -217,7 +230,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) { defer closeServer(untrustedProxy2) p := newProxyProvider([]string{"not used"}, "not used") - p.dohLookup = func(q, p string) ([]string, error) { + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil } @@ -225,34 +238,38 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) { require.Error(t, err) } -func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) { +func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) { blockAPI() defer unblockAPI() p := newProxyProvider([]string{"not used"}, "not used") - p.lookupTimeout = time.Second - p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } + p.cacheRefreshTimeout = 1 * time.Second + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } - // The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second. + // We should fail to refresh the proxy cache because the doh provider + // takes 2 seconds to respond but we timeout after just 1 second. _, err := p.findReachableServer() + require.Error(t, err) } -func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) { +func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) { blockAPI() defer unblockAPI() - slowProxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { time.Sleep(2 * time.Second) })) - defer slowProxy.Close() + defer closeServer(slowProxy) p := newProxyProvider([]string{"not used"}, "not used") - p.findTimeout = time.Second - p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } + p.canReachTimeout = 1 * time.Second + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } - // The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second. + // We should fail to reach the returned proxy because it takes 2 seconds + // to reach it and we only allow 1. _, err := p.findReachableServer() + require.Error(t, err) } @@ -268,7 +285,7 @@ func TestProxyProvider_UseProxy(t *testing.T) { p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p - p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, trustedProxy.URL, url) @@ -291,7 +308,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy1.URL, url) @@ -300,7 +317,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { // Have to wait so as to not get rejected. time.Sleep(proxyLookupWait) - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil } url, err = cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy2.URL, url) @@ -309,7 +326,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { // Have to wait so as to not get rejected. time.Sleep(proxyLookupWait) - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil } url, err = cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy3.URL, url) @@ -329,7 +346,7 @@ func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) { cm.proxyProvider = p cm.proxyUseDuration = time.Second - p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, trustedProxy.URL, url) @@ -350,7 +367,7 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab p := newProxyProvider([]string{"not used"}, "not used") cm.proxyProvider = p - p.dohLookup = func(q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, trustedProxy.URL, url) @@ -386,7 +403,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl cm.proxyProvider = p // Find a proxy. - p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy1.URL, url) @@ -408,7 +425,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) - records, err := p.dohLookup(TestDoHQuery, TestQuad9Provider) + records, err := p.dohLookup(context.Background(), TestDoHQuery, TestQuad9Provider) require.NoError(t, err) require.NotEmpty(t, records) } @@ -416,7 +433,7 @@ func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { func TestProxyProvider_DoHLookup_Google(t *testing.T) { p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) - records, err := p.dohLookup(TestDoHQuery, TestGoogleProvider) + records, err := p.dohLookup(context.Background(), TestDoHQuery, TestGoogleProvider) require.NoError(t, err) require.NotEmpty(t, records) }