forked from Silverfish/proton-bridge
GODT-35: Finish all details and make tests pass
This commit is contained in:
@ -1,22 +1,40 @@
|
||||
package pmapi_test
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
ExpiresIn: 100,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@ -24,7 +42,7 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
@ -35,28 +53,28 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// Make a request with an access token that already expired one second ago.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
|
||||
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
|
||||
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
|
||||
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
|
||||
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
|
||||
}
|
||||
|
||||
func Test401AuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
@ -67,7 +85,7 @@ func Test401AuthRefresh(t *testing.T) {
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
@ -86,24 +104,21 @@ func Test401AuthRefresh(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// The first request will fail with 401, triggering a refresh and retry.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
}
|
||||
|
||||
func Test401RevokedAuth(t *testing.T) {
|
||||
@ -119,17 +134,57 @@ func Test401RevokedAuth(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
||||
t.Fatal("expected error to be ErrUnauthorized, instead got", err)
|
||||
}
|
||||
r.EqualError(t, err, ErrUnauthorized.Error())
|
||||
}
|
||||
|
||||
func TestAuth2FA(t *testing.T) {
|
||||
twoFACode := "code"
|
||||
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
|
||||
var twoFAreq auth2FAReq
|
||||
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
|
||||
|
||||
return "/auth/2fa/post_response.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), twoFACode)
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Fail(t *testing.T) {
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_401_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACode, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Retry(t *testing.T) {
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_422_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACodeTryAgain, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user