mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-11 05:06:51 +00:00
GODT-1524: Logout issues with macOS.
This commit is contained in:
@ -19,11 +19,11 @@ package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Address statuses.
|
||||
@ -201,7 +201,7 @@ func (c *client) unlockAddress(passphrase []byte, address *Address) error {
|
||||
|
||||
kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "cannot unlock address keys for "+address.ID)
|
||||
}
|
||||
|
||||
c.addrKeyRing[address.ID] = kr
|
||||
|
||||
@ -51,7 +51,7 @@ type TwoFAInfo struct {
|
||||
}
|
||||
|
||||
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
|
||||
return twoFAInfo.Enabled > 0
|
||||
return twoFAInfo.Enabled > TwoFADisabled
|
||||
}
|
||||
|
||||
type TwoFAStatus int
|
||||
@ -185,7 +185,7 @@ func (c *client) authRefresh(ctx context.Context) error {
|
||||
|
||||
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
|
||||
if err != nil {
|
||||
if err != ErrNoConnection {
|
||||
if IsFailedAuth(err) {
|
||||
c.sendAuthRefresh(nil)
|
||||
}
|
||||
return err
|
||||
|
||||
122
pkg/pmapi/auth_server_test.go
Normal file
122
pkg/pmapi/auth_server_test.go
Normal file
@ -0,0 +1,122 @@
|
||||
// Copyright (c) 2022 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testRefreshResponse struct {
|
||||
Code int
|
||||
AccessToken string
|
||||
ExpiresIn int
|
||||
TokenType string
|
||||
Scope string
|
||||
Scopes []string
|
||||
UID string
|
||||
RefreshToken string
|
||||
LocalID int
|
||||
|
||||
r *require.Assertions
|
||||
}
|
||||
|
||||
var tokenID = 0
|
||||
|
||||
func newTestRefreshToken(r *require.Assertions) testRefreshResponse {
|
||||
tokenID++
|
||||
scopes := []string{
|
||||
"full",
|
||||
"self",
|
||||
"parent",
|
||||
"user",
|
||||
"loggedin",
|
||||
"paid",
|
||||
"nondelinquent",
|
||||
"mail",
|
||||
"verified",
|
||||
}
|
||||
return testRefreshResponse{
|
||||
Code: 1000,
|
||||
AccessToken: fmt.Sprintf("acc%d", tokenID),
|
||||
ExpiresIn: 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: strings.Join(scopes, " "),
|
||||
Scopes: scopes,
|
||||
UID: fmt.Sprintf("uid%d", tokenID),
|
||||
RefreshToken: fmt.Sprintf("ref%d", tokenID),
|
||||
r: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) isCorrectRefreshToken(body io.ReadCloser) int {
|
||||
request := authRefreshReq{}
|
||||
err := json.NewDecoder(body).Decode(&request)
|
||||
r.r.NoError(body.Close())
|
||||
r.r.NoError(err)
|
||||
|
||||
if r.UID != request.UID {
|
||||
return http.StatusUnprocessableEntity
|
||||
}
|
||||
if r.RefreshToken != request.RefreshToken {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) handleAuthRefresh(response http.ResponseWriter, request *http.Request) {
|
||||
if code := r.isCorrectRefreshToken(request.Body); code != http.StatusOK {
|
||||
response.WriteHeader(code)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID++
|
||||
r.AccessToken = fmt.Sprintf("acc%d", tokenID)
|
||||
r.RefreshToken = fmt.Sprintf("ref%d", tokenID)
|
||||
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
response.WriteHeader(http.StatusOK)
|
||||
r.r.NoError(json.NewEncoder(response).Encode(r))
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) wantAuthRefresh() AuthRefresh {
|
||||
return AuthRefresh{
|
||||
UID: r.UID,
|
||||
AccessToken: r.AccessToken,
|
||||
RefreshToken: r.RefreshToken,
|
||||
ExpiresIn: int64(r.ExpiresIn),
|
||||
Scopes: r.Scopes,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) isAuthorized(header http.Header) bool {
|
||||
return header.Get("x-pm-uid") == r.UID && header.Get("Authorization") == "Bearer "+r.AccessToken
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) handleAuthCheckOnly(response http.ResponseWriter, request *http.Request) {
|
||||
if r.isAuthorized(request.Header) {
|
||||
response.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
response.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
@ -25,179 +25,203 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
ExpiresIn: 100,
|
||||
}
|
||||
|
||||
r := require.New(t)
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testAcc := currentTokens.AccessToken
|
||||
testRef := currentTokens.RefreshToken
|
||||
currentTokens.ExpiresIn = 100
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
||||
NewClient(testUID, testAcc, testRef, time.Now().Add(-time.Second))
|
||||
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// Make a request with an access token that already expired one second ago.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
r.NoError(err)
|
||||
|
||||
wantAuthRefresh := currentTokens.wantAuthRefresh()
|
||||
|
||||
// The auth callback should have been called.
|
||||
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
r.NotNil(gotAuthRefresh)
|
||||
r.Equal(wantAuthRefresh, *gotAuthRefresh)
|
||||
|
||||
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
|
||||
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
|
||||
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
|
||||
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
|
||||
r.Equal(wantAuthRefresh.AccessToken, cl.acc)
|
||||
r.Equal(wantAuthRefresh.RefreshToken, cl.ref)
|
||||
r.WithinDuration(expiresIn(100), cl.exp, time.Second)
|
||||
}
|
||||
|
||||
func Test401AuthRefresh(t *testing.T) {
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
}
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testRef := currentTokens.RefreshToken
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
var call int
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
call++
|
||||
|
||||
if call == 1 {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
c := m.NewClient(testUID, "oldAccToken", testRef, time.Now().Add(time.Hour))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// The first request will fail with 401, triggering a refresh and retry.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
r.NoError(err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
r.NotNil(gotAuthRefresh)
|
||||
r.Equal(currentTokens.wantAuthRefresh(), *gotAuthRefresh)
|
||||
}
|
||||
|
||||
func Test401RevokedAuth(t *testing.T) {
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
NewClient("badUID", "badAcc", "badRef", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.EqualError(t, err, ErrUnauthorized.Error())
|
||||
r.True(IsFailedAuth(err))
|
||||
}
|
||||
|
||||
func Test401RevokedAuthTokenUpdate(t *testing.T) {
|
||||
var oldAuth = &AuthRefresh{
|
||||
UID: "UID",
|
||||
AccessToken: "oldAcc",
|
||||
RefreshToken: "oldRef",
|
||||
ExpiresIn: 3600,
|
||||
}
|
||||
|
||||
var newAuth = &AuthRefresh{
|
||||
UID: "UID",
|
||||
AccessToken: "newAcc",
|
||||
RefreshToken: "newRef",
|
||||
}
|
||||
func Test401OldRefreshToken(t *testing.T) {
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient(currentTokens.UID, "oldAcc", "oldRef", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.True(IsFailedAuth(err))
|
||||
}
|
||||
|
||||
func Test401NoAccessToken(t *testing.T) {
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testRef := currentTokens.RefreshToken
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient(testUID, "", testRef, time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(err)
|
||||
}
|
||||
|
||||
func Test401ExpiredAuthUpdateUser(t *testing.T) {
|
||||
r := require.New(t)
|
||||
mux := http.NewServeMux()
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testRef := currentTokens.RefreshToken
|
||||
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
|
||||
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !currentTokens.isAuthorized(r.Header) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(newAuth); err != nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
respObj := struct {
|
||||
Code int
|
||||
User *User
|
||||
}{
|
||||
Code: 1000,
|
||||
User: &User{
|
||||
ID: "MJLke8kWh1BBvG95JBIrZvzpgsZ94hNNgjNHVyhXMiv4g9cn6SgvqiIFR5cigpml2LD_iUk_3DkV29oojTt3eA==",
|
||||
Name: "jason",
|
||||
UsedSpace: &usedSpace,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(respObj); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") == ("Bearer " + oldAuth.AccessToken) {
|
||||
if !currentTokens.isAuthorized(r.Header) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("Authorization") == ("Bearer " + newAuth.AccessToken) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(testAddressList); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient(oldAuth.UID, oldAuth.AccessToken, oldAuth.RefreshToken, time.Now().Add(time.Hour))
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
c, _, err := m.NewClientWithRefresh(context.Background(), testUID, testRef)
|
||||
r.NoError(err)
|
||||
|
||||
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
_, err = c.UpdateUser(context.Background())
|
||||
r.NoError(err)
|
||||
}
|
||||
|
||||
func TestAuth2FA(t *testing.T) {
|
||||
r := require.New(t)
|
||||
twoFACode := "code"
|
||||
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
|
||||
var twoFAreq auth2FAReq
|
||||
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
|
||||
r.NoError(json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||
r.Equal(twoFAreq.TwoFactorCode, twoFACode)
|
||||
|
||||
return "/auth/2fa/post_response.json"
|
||||
},
|
||||
@ -205,31 +229,33 @@ func TestAuth2FA(t *testing.T) {
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), twoFACode)
|
||||
r.NoError(t, err)
|
||||
r.NoError(err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Fail(t *testing.T) {
|
||||
r := require.New(t)
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_401_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACode, err)
|
||||
r.Equal(ErrBad2FACode, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Retry(t *testing.T) {
|
||||
r := require.New(t)
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_422_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACodeTryAgain, err)
|
||||
r.Equal(ErrBad2FACodeTryAgain, err)
|
||||
}
|
||||
|
||||
@ -19,8 +19,6 @@ package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
|
||||
@ -34,26 +32,26 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
|
||||
// unlock unlocks the user's keys but without locking the keyring lock first.
|
||||
// Should only be used internally by methods that first lock the lock.
|
||||
func (c *client) unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
if _, err = c.CurrentUser(ctx); err != nil {
|
||||
return
|
||||
func (c *client) unlock(ctx context.Context, passphrase []byte) error {
|
||||
if _, err := c.CurrentUser(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.userKeyRing == nil {
|
||||
if err = c.unlockUser(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
if err := c.unlockUser(passphrase); err != nil {
|
||||
return ErrUnlockFailed{err}
|
||||
}
|
||||
}
|
||||
|
||||
for _, address := range c.addresses {
|
||||
if c.addrKeyRing[address.ID] == nil {
|
||||
if err = c.unlockAddress(passphrase, address); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock address")
|
||||
if err := c.unlockAddress(passphrase, address); err != nil {
|
||||
return ErrUnlockFailed{err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) ReloadKeys(ctx context.Context, passphrase []byte) (err error) {
|
||||
|
||||
@ -31,10 +31,58 @@ var (
|
||||
ErrPasswordWrong = errors.New("wrong password")
|
||||
)
|
||||
|
||||
// ErrUnprocessableEntity ...
|
||||
type ErrUnprocessableEntity struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsUnprocessableEntity(err error) bool {
|
||||
_, ok := err.(ErrUnprocessableEntity)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrUnprocessableEntity) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
// ErrBadRequest ...
|
||||
type ErrBadRequest struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsBadRequest(err error) bool {
|
||||
_, ok := err.(ErrBadRequest)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrBadRequest) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
// ErrAuthFailed ...
|
||||
type ErrAuthFailed struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsFailedAuth(err error) bool {
|
||||
_, ok := err.(ErrAuthFailed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrAuthFailed) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
// ErrUnlockFailed ...
|
||||
type ErrUnlockFailed struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsFailedUnlock(err error) bool {
|
||||
_, ok := err.(ErrUnlockFailed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrUnlockFailed) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
@ -33,6 +33,7 @@ type manager struct {
|
||||
|
||||
isDown bool
|
||||
locker sync.Locker
|
||||
refreshingAuth sync.Locker
|
||||
connectionObservers []ConnectionObserver
|
||||
proxyDialer *ProxyTLSDialer
|
||||
|
||||
@ -50,6 +51,7 @@ func newManager(cfg Config) *manager {
|
||||
cfg: cfg,
|
||||
rc: resty.New().EnableTrace(),
|
||||
locker: &sync.Mutex{},
|
||||
refreshingAuth: &sync.Mutex{},
|
||||
pingMutex: &sync.RWMutex{},
|
||||
isPinging: false,
|
||||
setSentryUserIDOnce: sync.Once{},
|
||||
|
||||
@ -102,6 +102,9 @@ func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) {
|
||||
}
|
||||
|
||||
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefresh, error) {
|
||||
m.refreshingAuth.Lock()
|
||||
defer m.refreshingAuth.Unlock()
|
||||
|
||||
var req = authRefreshReq{
|
||||
UID: uid,
|
||||
RefreshToken: ref,
|
||||
@ -117,6 +120,9 @@ func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefres
|
||||
|
||||
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"))
|
||||
if err != nil {
|
||||
if IsBadRequest(err) || IsUnprocessableEntity(err) {
|
||||
err = ErrAuthFailed{err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@ -59,16 +59,14 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
|
||||
if apiErr, ok := res.Error().(*Error); ok {
|
||||
switch {
|
||||
case apiErr.Code == errCodeUpgradeApplication:
|
||||
err = ErrUpgradeApplication
|
||||
if m.cfg.UpgradeApplicationHandler != nil {
|
||||
m.cfg.UpgradeApplicationHandler()
|
||||
}
|
||||
return ErrUpgradeApplication
|
||||
case apiErr.Code == errCodePasswordWrong:
|
||||
err = ErrPasswordWrong
|
||||
return ErrPasswordWrong
|
||||
case apiErr.Code == errCodeAuthPaidPlanRequired:
|
||||
err = ErrPaidPlanRequired
|
||||
case res.StatusCode() == http.StatusUnprocessableEntity:
|
||||
err = ErrUnprocessableEntity{apiErr}
|
||||
return ErrPaidPlanRequired
|
||||
default:
|
||||
err = apiErr
|
||||
}
|
||||
@ -76,6 +74,13 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
|
||||
err = errors.New(res.Status())
|
||||
}
|
||||
|
||||
switch res.StatusCode() {
|
||||
case http.StatusUnprocessableEntity:
|
||||
err = ErrUnprocessableEntity{err}
|
||||
case http.StatusBadRequest:
|
||||
err = ErrBadRequest{err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user