mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-18 08:06:59 +00:00
GODT-35: Finish all details and make tests pass
This commit is contained in:
@ -1,16 +1,39 @@
|
||||
package pmapi_test
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testForceUpgradeBody = `{
|
||||
"Code":5003,
|
||||
"Error":"Upgrade!"
|
||||
}`
|
||||
|
||||
func TestHandleTooManyRequests(t *testing.T) {
|
||||
var numCalls int
|
||||
|
||||
@ -24,21 +47,17 @@ func TestHandleTooManyRequests(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should succeed because the 5th retry should succeed (429s are retried).
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The server should be called 5 times.
|
||||
// The first four calls should return 429 and the last call should return 200.
|
||||
if numCalls != 5 {
|
||||
t.Fatal("expected numCalls to be 5, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 5, numCalls)
|
||||
}
|
||||
|
||||
func TestHandleUnprocessableEntity(t *testing.T) {
|
||||
@ -49,27 +68,16 @@ func TestHandleUnprocessableEntity(t *testing.T) {
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
}))
|
||||
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should fail because the first call should fail (422s are not retried).
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
// API-side errors get ErrAPIFailure
|
||||
if !errors.Is(err, pmapi.ErrAPIFailure) {
|
||||
t.Fatal("expected error to be ErrAPIFailure, instead got", err)
|
||||
}
|
||||
|
||||
r.EqualError(t, err, "422 Unprocessable Entity")
|
||||
// The server should be called 1 time.
|
||||
// The first call should return 422.
|
||||
if numCalls != 1 {
|
||||
t.Fatal("expected numCalls to be 1, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 1, numCalls)
|
||||
}
|
||||
|
||||
func TestHandleDialFailure(t *testing.T) {
|
||||
@ -81,24 +89,17 @@ func TestHandleDialFailure(t *testing.T) {
|
||||
}))
|
||||
|
||||
// The failingRoundTripper will fail the first 5 times it is used.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set a custom transport.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(5))
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should succeed because the last retry should succeed (dial errors are retried).
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The server should be called 1 time.
|
||||
// The first 4 attempts don't reach the server.
|
||||
if numCalls != 1 {
|
||||
t.Fatal("expected numCalls to be 1, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 1, numCalls)
|
||||
}
|
||||
|
||||
func TestHandleTooManyDialFailures(t *testing.T) {
|
||||
@ -112,28 +113,15 @@ func TestHandleTooManyDialFailures(t *testing.T) {
|
||||
// The failingRoundTripper will fail the first 10 times it is used.
|
||||
// This is more than the number of retries we permit.
|
||||
// Thus, dials will fail.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set a custom transport.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(10))
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should fail because every dial will fail and we'll run out of retries.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrNoConnection) {
|
||||
t.Fatal("expected error to be ErrNoConnection, instead got", err)
|
||||
}
|
||||
|
||||
r.EqualError(t, err, "no internet connection")
|
||||
// The server should never be called.
|
||||
if numCalls != 0 {
|
||||
t.Fatal("expected numCalls to be 0, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 0, numCalls)
|
||||
}
|
||||
|
||||
func TestRetriesWithContextTimeout(t *testing.T) {
|
||||
@ -150,24 +138,16 @@ func TestRetriesWithContextTimeout(t *testing.T) {
|
||||
}))
|
||||
|
||||
// Theoretically, this should succeed; on the fifth retry, we'll get StatusOK.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set the retry count to 5.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// However, that will take ~5s, and we only allow 1s in the context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
// However, that will take ~0.5s, and we only allow 10ms in the context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Thus, it will fail.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("expected error to be DeadlineExceeded, instead got", err)
|
||||
}
|
||||
r.EqualError(t, err, context.DeadlineExceeded.Error())
|
||||
}
|
||||
|
||||
func TestObserveConnectionStatus(t *testing.T) {
|
||||
@ -177,36 +157,24 @@ func TestObserveConnectionStatus(t *testing.T) {
|
||||
|
||||
var onDown, onUp bool
|
||||
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set a custom transport.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(10))
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// Add a connection observer.
|
||||
m.AddConnectionObserver(pmapi.NewConnectionObserver(func() { onDown = true }, func() { onUp = true }))
|
||||
m.AddConnectionObserver(NewConnectionObserver(func() { onDown = true }, func() { onUp = true }))
|
||||
|
||||
// The call should fail because every dial will fail and we'll run out of retries.
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if onDown != true || onUp == true {
|
||||
t.Fatal("expected onDown to have been called and onUp to not have been called")
|
||||
}
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.Error(t, err)
|
||||
r.False(t, onUp)
|
||||
r.True(t, onDown)
|
||||
|
||||
onDown, onUp = false, false
|
||||
|
||||
// The call should succeed because the last dial attempt will succeed.
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
|
||||
if onDown == true || onUp != true {
|
||||
t.Fatal("expected onUp to have been called and onDown to not have been called")
|
||||
}
|
||||
_, err = m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
r.True(t, onUp)
|
||||
r.False(t, onDown)
|
||||
}
|
||||
|
||||
func TestReturnErrNoConnection(t *testing.T) {
|
||||
@ -215,19 +183,27 @@ func TestReturnErrNoConnection(t *testing.T) {
|
||||
}))
|
||||
|
||||
// We will fail more times than we retry, so requests should fail with ErrNoConnection.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(10))
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should fail because every dial will fail and we'll run out of retries.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
r.EqualError(t, err, "no internet connection")
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrNoConnection) {
|
||||
t.Fatal("expected error to be ErrNoConnection, instead got", err)
|
||||
}
|
||||
func TestReturnErrUpgradeApplication(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("content-type", "application/json")
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
fmt.Fprint(w, testForceUpgradeBody)
|
||||
}))
|
||||
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
|
||||
// The call should fail because every call return force upgrade error.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.EqualError(t, err, ErrUpgradeApplication.Error())
|
||||
}
|
||||
|
||||
type failingRoundTripper struct {
|
||||
|
||||
Reference in New Issue
Block a user