Files
proton-bridge/pkg/pmapi/auth_server_test.go
2022-03-23 09:50:17 +01:00

123 lines
3.1 KiB
Go

// Copyright (c) 2022 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/stretchr/testify/require"
)
type testRefreshResponse struct {
Code int
AccessToken string
ExpiresIn int
TokenType string
Scope string
Scopes []string
UID string
RefreshToken string
LocalID int
r *require.Assertions
}
var tokenID = 0
func newTestRefreshToken(r *require.Assertions) testRefreshResponse {
tokenID++
scopes := []string{
"full",
"self",
"parent",
"user",
"loggedin",
"paid",
"nondelinquent",
"mail",
"verified",
}
return testRefreshResponse{
Code: 1000,
AccessToken: fmt.Sprintf("acc%d", tokenID),
ExpiresIn: 3600,
TokenType: "Bearer",
Scope: strings.Join(scopes, " "),
Scopes: scopes,
UID: fmt.Sprintf("uid%d", tokenID),
RefreshToken: fmt.Sprintf("ref%d", tokenID),
r: r,
}
}
func (r *testRefreshResponse) isCorrectRefreshToken(body io.ReadCloser) int {
request := authRefreshReq{}
err := json.NewDecoder(body).Decode(&request)
r.r.NoError(body.Close())
r.r.NoError(err)
if r.UID != request.UID {
return http.StatusUnprocessableEntity
}
if r.RefreshToken != request.RefreshToken {
return http.StatusBadRequest
}
return http.StatusOK
}
func (r *testRefreshResponse) handleAuthRefresh(response http.ResponseWriter, request *http.Request) {
if code := r.isCorrectRefreshToken(request.Body); code != http.StatusOK {
response.WriteHeader(code)
return
}
tokenID++
r.AccessToken = fmt.Sprintf("acc%d", tokenID)
r.RefreshToken = fmt.Sprintf("ref%d", tokenID)
response.Header().Set("Content-Type", "application/json")
response.WriteHeader(http.StatusOK)
r.r.NoError(json.NewEncoder(response).Encode(r))
}
func (r *testRefreshResponse) wantAuthRefresh() AuthRefresh {
return AuthRefresh{
UID: r.UID,
AccessToken: r.AccessToken,
RefreshToken: r.RefreshToken,
ExpiresIn: int64(r.ExpiresIn),
Scopes: r.Scopes,
}
}
func (r *testRefreshResponse) isAuthorized(header http.Header) bool {
return header.Get("x-pm-uid") == r.UID && header.Get("Authorization") == "Bearer "+r.AccessToken
}
func (r *testRefreshResponse) handleAuthCheckOnly(response http.ResponseWriter, request *http.Request) {
if r.isAuthorized(request.Header) {
response.WriteHeader(http.StatusOK)
} else {
response.WriteHeader(http.StatusUnauthorized)
}
}