GODT-35: Finish all details and make tests pass

This commit is contained in:
Michal Horejsek
2021-03-11 14:37:15 +01:00
committed by Jakub
parent 2284e9ede1
commit 8109831c07
173 changed files with 4697 additions and 2897 deletions

View File

@ -1,22 +1,40 @@
package pmapi_test
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
)
func TestAutomaticAuthRefresh(t *testing.T) {
var wantAuth = &pmapi.Auth{
var wantAuthRefresh = &AuthRefresh{
UID: "testUID",
AccessToken: "testAcc",
RefreshToken: "testRef",
ExpiresIn: 100,
}
mux := http.NewServeMux()
@ -24,7 +42,7 @@ func TestAutomaticAuthRefresh(t *testing.T) {
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
panic(err)
}
})
@ -35,28 +53,28 @@ func TestAutomaticAuthRefresh(t *testing.T) {
ts := httptest.NewServer(mux)
var gotAuth *pmapi.Auth
var gotAuthRefresh *AuthRefresh
// Create a new client.
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
// Register an auth handler.
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
// Make a request with an access token that already expired one second ago.
if _, err := c.GetAddresses(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
}
_, err := c.GetAddresses(context.Background())
r.NoError(t, err)
// The auth callback should have been called.
if *gotAuth != *wantAuth {
t.Fatal("got unexpected auth", gotAuth)
}
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
}
func Test401AuthRefresh(t *testing.T) {
var wantAuth = &pmapi.Auth{
var wantAuthRefresh = &AuthRefresh{
UID: "testUID",
AccessToken: "testAcc",
RefreshToken: "testRef",
@ -67,7 +85,7 @@ func Test401AuthRefresh(t *testing.T) {
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
panic(err)
}
})
@ -86,24 +104,21 @@ func Test401AuthRefresh(t *testing.T) {
ts := httptest.NewServer(mux)
var gotAuth *pmapi.Auth
var gotAuthRefresh *AuthRefresh
// Create a new client.
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
// Register an auth handler.
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
// The first request will fail with 401, triggering a refresh and retry.
if _, err := c.GetAddresses(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
}
_, err := c.GetAddresses(context.Background())
r.NoError(t, err)
// The auth callback should have been called.
if *gotAuth != *wantAuth {
t.Fatal("got unexpected auth", gotAuth)
}
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
}
func Test401RevokedAuth(t *testing.T) {
@ -119,17 +134,57 @@ func Test401RevokedAuth(t *testing.T) {
ts := httptest.NewServer(mux)
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh.
// The retry will also fail with 401, returning an error.
_, err := c.GetAddresses(context.Background())
if err == nil {
t.Fatal("expected error, instead got", err)
}
if !errors.Is(err, pmapi.ErrUnauthorized) {
t.Fatal("expected error to be ErrUnauthorized, instead got", err)
}
r.EqualError(t, err, ErrUnauthorized.Error())
}
func TestAuth2FA(t *testing.T) {
twoFACode := "code"
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
var twoFAreq auth2FAReq
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
return "/auth/2fa/post_response.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), twoFACode)
r.NoError(t, err)
}
func TestAuth2FA_Fail(t *testing.T) {
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_401_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(t, ErrBad2FACode, err)
}
func TestAuth2FA_Retry(t *testing.T) {
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_422_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(t, ErrBad2FACodeTryAgain, err)
}