GODT-1524: Logout issues with macOS.

This commit is contained in:
Jakub
2022-03-14 16:37:42 +01:00
parent 2d5ea669a5
commit 6671dd38ea
26 changed files with 411 additions and 135 deletions

4
go.mod
View File

@ -48,10 +48,12 @@ require (
github.com/google/uuid v1.1.1
github.com/hashicorp/go-multierror v1.1.0
github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621
github.com/kr/text v0.2.0 // indirect
github.com/logrusorgru/aurora v2.0.3+incompatible
github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/miekg/dns v1.1.41
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
github.com/olekukonko/tablewriter v0.0.4 // indirect
github.com/pkg/errors v0.9.1

2
go.sum
View File

@ -265,6 +265,8 @@ github.com/kataras/pio v0.0.2/go.mod h1:hAoW0t9UmXi4R5Oyq5Z4irTbaTsOemSrDGUtaTl7
github.com/kataras/sitemap v0.0.5/go.mod h1:KY2eugMKiPwsJgx7+U103YZehfvNGOXURubcGyk0Bz8=
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d h1:gVjhBCfVGl32RIBooOANzfw+0UqX8HU+yPlMv8vypcg=
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d/go.mod h1:W6EbaYmb4RldPn0N3gvVHjY1wmU59kbymhW9NATWhwY=
github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621 h1:aMQ7pA4f06yOVXSulygyGvy4xA94fyzjUGs0iqQdMOI=
github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621/go.mod h1:enrU/ug069Om7vWxuFE6nikLI2BZNwevMiGSo43Kt5w=
github.com/keybase/go.dbus v0.0.0-20200324223359-a94be52c0b03/go.mod h1:a8clEhrrGV/d76/f9r2I41BwANMihfZYV9C223vaxqE=
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=

View File

@ -194,7 +194,7 @@ func (store *Store) BuildAndCacheMessage(ctx context.Context, messageID string)
}
func (store *Store) checkAndRemoveDeletedMessage(err error, msgID string) {
if _, ok := err.(pmapi.ErrUnprocessableEntity); !ok {
if !pmapi.IsUnprocessableEntity(err) {
return
}
l := store.log.WithError(err).WithField("msgID", msgID)

View File

@ -477,7 +477,7 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil {
if _, ok := err.(pmapi.ErrUnprocessableEntity); ok {
if pmapi.IsUnprocessableEntity(err) {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil
continue

View File

@ -223,7 +223,7 @@ func (u *User) UpdateSpace(apiUser *pmapi.User) {
// values from client.CurrentUser()
if apiUser == nil {
var err error
apiUser, err = u.client.GetUser(pmapi.ContextWithoutRetry(context.Background()))
apiUser, err = u.GetClient().GetUser(pmapi.ContextWithoutRetry(context.Background()))
if err != nil {
u.log.WithError(err).Warning("Cannot update user space")
return
@ -280,16 +280,21 @@ func (u *User) unlockIfNecessary() error {
return nil
}
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
u.log.WithError(err).Warn("Could not unlock user")
return nil
if pmapi.IsFailedAuth(err) || pmapi.IsFailedUnlock(err) {
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(err, "failed to unlock user")
}
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
u.log.WithError(err).Warn("Skipping unlock for known reason")
default:
u.log.WithError(err).Error("Unknown unlock issue")
}
return errors.Wrap(err, "failed to unlock user")
return nil
}
// IsCombinedAddressMode returns whether user is set in combined or split mode.

View File

@ -23,6 +23,7 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
@ -46,7 +47,7 @@ func TestNewUserUnlockFails(t *testing.T) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(errors.New("bad password")),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(pmapi.ErrUnlockFailed{OriginalError: errors.New("bad password")}),
// Handle of unlock error.
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),

View File

@ -178,8 +178,10 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden
return connectErr
}
if logoutErr := user.logout(); logoutErr != nil {
logrus.WithError(logoutErr).Warn("Could not logout user")
if pmapi.IsFailedAuth(connectErr) {
if logoutErr := user.logout(); logoutErr != nil {
logrus.WithError(logoutErr).Warn("Could not logout user")
}
}
return errors.Wrap(err, "could not refresh token")
}

View File

@ -24,6 +24,7 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
@ -80,11 +81,11 @@ func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token"))
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, pmapi.ErrBadRequest{OriginalError: errors.New("bad token")})
m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient)
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(false)
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(errors.New("not authorized"))
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(pmapi.ErrBadRequest{OriginalError: errors.New("not authorized")})
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
@ -93,7 +94,6 @@ func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})

View File

@ -19,11 +19,11 @@ package pmapi
import (
"context"
"errors"
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
"github.com/pkg/errors"
)
// Address statuses.
@ -201,7 +201,7 @@ func (c *client) unlockAddress(passphrase []byte, address *Address) error {
kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
if err != nil {
return err
return errors.Wrap(err, "cannot unlock address keys for "+address.ID)
}
c.addrKeyRing[address.ID] = kr

View File

@ -51,7 +51,7 @@ type TwoFAInfo struct {
}
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
return twoFAInfo.Enabled > 0
return twoFAInfo.Enabled > TwoFADisabled
}
type TwoFAStatus int
@ -185,7 +185,7 @@ func (c *client) authRefresh(ctx context.Context) error {
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
if err != nil {
if err != ErrNoConnection {
if IsFailedAuth(err) {
c.sendAuthRefresh(nil)
}
return err

View File

@ -0,0 +1,122 @@
// Copyright (c) 2022 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/stretchr/testify/require"
)
type testRefreshResponse struct {
Code int
AccessToken string
ExpiresIn int
TokenType string
Scope string
Scopes []string
UID string
RefreshToken string
LocalID int
r *require.Assertions
}
var tokenID = 0
func newTestRefreshToken(r *require.Assertions) testRefreshResponse {
tokenID++
scopes := []string{
"full",
"self",
"parent",
"user",
"loggedin",
"paid",
"nondelinquent",
"mail",
"verified",
}
return testRefreshResponse{
Code: 1000,
AccessToken: fmt.Sprintf("acc%d", tokenID),
ExpiresIn: 3600,
TokenType: "Bearer",
Scope: strings.Join(scopes, " "),
Scopes: scopes,
UID: fmt.Sprintf("uid%d", tokenID),
RefreshToken: fmt.Sprintf("ref%d", tokenID),
r: r,
}
}
func (r *testRefreshResponse) isCorrectRefreshToken(body io.ReadCloser) int {
request := authRefreshReq{}
err := json.NewDecoder(body).Decode(&request)
r.r.NoError(body.Close())
r.r.NoError(err)
if r.UID != request.UID {
return http.StatusUnprocessableEntity
}
if r.RefreshToken != request.RefreshToken {
return http.StatusBadRequest
}
return http.StatusOK
}
func (r *testRefreshResponse) handleAuthRefresh(response http.ResponseWriter, request *http.Request) {
if code := r.isCorrectRefreshToken(request.Body); code != http.StatusOK {
response.WriteHeader(code)
return
}
tokenID++
r.AccessToken = fmt.Sprintf("acc%d", tokenID)
r.RefreshToken = fmt.Sprintf("ref%d", tokenID)
response.Header().Set("Content-Type", "application/json")
response.WriteHeader(http.StatusOK)
r.r.NoError(json.NewEncoder(response).Encode(r))
}
func (r *testRefreshResponse) wantAuthRefresh() AuthRefresh {
return AuthRefresh{
UID: r.UID,
AccessToken: r.AccessToken,
RefreshToken: r.RefreshToken,
ExpiresIn: int64(r.ExpiresIn),
Scopes: r.Scopes,
}
}
func (r *testRefreshResponse) isAuthorized(header http.Header) bool {
return header.Get("x-pm-uid") == r.UID && header.Get("Authorization") == "Bearer "+r.AccessToken
}
func (r *testRefreshResponse) handleAuthCheckOnly(response http.ResponseWriter, request *http.Request) {
if r.isAuthorized(request.Header) {
response.WriteHeader(http.StatusOK)
} else {
response.WriteHeader(http.StatusUnauthorized)
}
}

View File

@ -25,179 +25,203 @@ import (
"testing"
"time"
a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
func TestAutomaticAuthRefresh(t *testing.T) {
var wantAuthRefresh = &AuthRefresh{
UID: "testUID",
AccessToken: "testAcc",
RefreshToken: "testRef",
ExpiresIn: 100,
}
r := require.New(t)
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testAcc := currentTokens.AccessToken
testRef := currentTokens.RefreshToken
currentTokens.ExpiresIn = 100
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
panic(err)
}
})
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
var gotAuthRefresh *AuthRefresh
c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
NewClient(testUID, testAcc, testRef, time.Now().Add(-time.Second))
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
// Make a request with an access token that already expired one second ago.
_, err := c.GetAddresses(context.Background())
r.NoError(t, err)
r.NoError(err)
wantAuthRefresh := currentTokens.wantAuthRefresh()
// The auth callback should have been called.
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
r.NotNil(gotAuthRefresh)
r.Equal(wantAuthRefresh, *gotAuthRefresh)
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
r.Equal(wantAuthRefresh.AccessToken, cl.acc)
r.Equal(wantAuthRefresh.RefreshToken, cl.ref)
r.WithinDuration(expiresIn(100), cl.exp, time.Second)
}
func Test401AuthRefresh(t *testing.T) {
var wantAuthRefresh = &AuthRefresh{
UID: "testUID",
AccessToken: "testAcc",
RefreshToken: "testRef",
}
r := require.New(t)
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testRef := currentTokens.RefreshToken
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
panic(err)
}
})
var call int
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
call++
if call == 1 {
w.WriteHeader(http.StatusUnauthorized)
} else {
w.WriteHeader(http.StatusOK)
}
})
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
var gotAuthRefresh *AuthRefresh
// Create a new client.
c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
m := New(Config{HostURL: ts.URL})
c := m.NewClient(testUID, "oldAccToken", testRef, time.Now().Add(time.Hour))
// Register an auth handler.
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
// The first request will fail with 401, triggering a refresh and retry.
_, err := c.GetAddresses(context.Background())
r.NoError(t, err)
r.NoError(err)
// The auth callback should have been called.
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
r.NotNil(gotAuthRefresh)
r.Equal(currentTokens.wantAuthRefresh(), *gotAuthRefresh)
}
func Test401RevokedAuth(t *testing.T) {
r := require.New(t)
currentTokens := newTestRefreshToken(r)
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
})
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
})
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
NewClient("badUID", "badAcc", "badRef", time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh.
// The retry will also fail with 401, returning an error.
_, err := c.GetAddresses(context.Background())
r.EqualError(t, err, ErrUnauthorized.Error())
r.True(IsFailedAuth(err))
}
func Test401RevokedAuthTokenUpdate(t *testing.T) {
var oldAuth = &AuthRefresh{
UID: "UID",
AccessToken: "oldAcc",
RefreshToken: "oldRef",
ExpiresIn: 3600,
}
var newAuth = &AuthRefresh{
UID: "UID",
AccessToken: "newAcc",
RefreshToken: "newRef",
}
func Test401OldRefreshToken(t *testing.T) {
r := require.New(t)
currentTokens := newTestRefreshToken(r)
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
c := New(Config{HostURL: ts.URL}).
NewClient(currentTokens.UID, "oldAcc", "oldRef", time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh.
// The retry will also fail with 401, returning an error.
_, err := c.GetAddresses(context.Background())
r.True(IsFailedAuth(err))
}
func Test401NoAccessToken(t *testing.T) {
r := require.New(t)
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testRef := currentTokens.RefreshToken
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
c := New(Config{HostURL: ts.URL}).
NewClient(testUID, "", testRef, time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
_, err := c.GetAddresses(context.Background())
r.NoError(err)
}
func Test401ExpiredAuthUpdateUser(t *testing.T) {
r := require.New(t)
mux := http.NewServeMux()
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testRef := currentTokens.RefreshToken
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
if !currentTokens.isAuthorized(r.Header) {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(newAuth); err != nil {
w.WriteHeader(http.StatusOK)
respObj := struct {
Code int
User *User
}{
Code: 1000,
User: &User{
ID: "MJLke8kWh1BBvG95JBIrZvzpgsZ94hNNgjNHVyhXMiv4g9cn6SgvqiIFR5cigpml2LD_iUk_3DkV29oojTt3eA==",
Name: "jason",
UsedSpace: &usedSpace,
},
}
if err := json.NewEncoder(w).Encode(respObj); err != nil {
panic(err)
}
})
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") == ("Bearer " + oldAuth.AccessToken) {
if !currentTokens.isAuthorized(r.Header) {
w.WriteHeader(http.StatusUnauthorized)
return
}
if r.Header.Get("Authorization") == ("Bearer " + newAuth.AccessToken) {
w.WriteHeader(http.StatusOK)
return
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(testAddressList); err != nil {
panic(err)
}
})
ts := httptest.NewServer(mux)
c := New(Config{HostURL: ts.URL}).
NewClient(oldAuth.UID, oldAuth.AccessToken, oldAuth.RefreshToken, time.Now().Add(time.Hour))
m := New(Config{HostURL: ts.URL})
c, _, err := m.NewClientWithRefresh(context.Background(), testUID, testRef)
r.NoError(err)
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
_, err := c.GetAddresses(context.Background())
r.NoError(t, err)
_, err = c.UpdateUser(context.Background())
r.NoError(err)
}
func TestAuth2FA(t *testing.T) {
r := require.New(t)
twoFACode := "code"
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
var twoFAreq auth2FAReq
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
r.NoError(json.NewDecoder(req.Body).Decode(&twoFAreq))
r.Equal(twoFAreq.TwoFactorCode, twoFACode)
return "/auth/2fa/post_response.json"
},
@ -205,31 +229,33 @@ func TestAuth2FA(t *testing.T) {
defer finish()
err := c.Auth2FA(context.Background(), twoFACode)
r.NoError(t, err)
r.NoError(err)
}
func TestAuth2FA_Fail(t *testing.T) {
r := require.New(t)
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_401_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(t, ErrBad2FACode, err)
r.Equal(ErrBad2FACode, err)
}
func TestAuth2FA_Retry(t *testing.T) {
r := require.New(t)
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_422_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(t, ErrBad2FACodeTryAgain, err)
r.Equal(ErrBad2FACodeTryAgain, err)
}

View File

@ -19,8 +19,6 @@ package pmapi
import (
"context"
"github.com/pkg/errors"
)
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
@ -34,26 +32,26 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
// unlock unlocks the user's keys but without locking the keyring lock first.
// Should only be used internally by methods that first lock the lock.
func (c *client) unlock(ctx context.Context, passphrase []byte) (err error) {
if _, err = c.CurrentUser(ctx); err != nil {
return
func (c *client) unlock(ctx context.Context, passphrase []byte) error {
if _, err := c.CurrentUser(ctx); err != nil {
return err
}
if c.userKeyRing == nil {
if err = c.unlockUser(passphrase); err != nil {
return errors.Wrap(err, "failed to unlock user")
if err := c.unlockUser(passphrase); err != nil {
return ErrUnlockFailed{err}
}
}
for _, address := range c.addresses {
if c.addrKeyRing[address.ID] == nil {
if err = c.unlockAddress(passphrase, address); err != nil {
return errors.Wrap(err, "failed to unlock address")
if err := c.unlockAddress(passphrase, address); err != nil {
return ErrUnlockFailed{err}
}
}
}
return
return nil
}
func (c *client) ReloadKeys(ctx context.Context, passphrase []byte) (err error) {

View File

@ -31,10 +31,58 @@ var (
ErrPasswordWrong = errors.New("wrong password")
)
// ErrUnprocessableEntity ...
type ErrUnprocessableEntity struct {
OriginalError error
}
func IsUnprocessableEntity(err error) bool {
_, ok := err.(ErrUnprocessableEntity)
return ok
}
func (err ErrUnprocessableEntity) Error() string {
return err.OriginalError.Error()
}
// ErrBadRequest ...
type ErrBadRequest struct {
OriginalError error
}
func IsBadRequest(err error) bool {
_, ok := err.(ErrBadRequest)
return ok
}
func (err ErrBadRequest) Error() string {
return err.OriginalError.Error()
}
// ErrAuthFailed ...
type ErrAuthFailed struct {
OriginalError error
}
func IsFailedAuth(err error) bool {
_, ok := err.(ErrAuthFailed)
return ok
}
func (err ErrAuthFailed) Error() string {
return err.OriginalError.Error()
}
// ErrUnlockFailed ...
type ErrUnlockFailed struct {
OriginalError error
}
func IsFailedUnlock(err error) bool {
_, ok := err.(ErrUnlockFailed)
return ok
}
func (err ErrUnlockFailed) Error() string {
return err.OriginalError.Error()
}

View File

@ -33,6 +33,7 @@ type manager struct {
isDown bool
locker sync.Locker
refreshingAuth sync.Locker
connectionObservers []ConnectionObserver
proxyDialer *ProxyTLSDialer
@ -50,6 +51,7 @@ func newManager(cfg Config) *manager {
cfg: cfg,
rc: resty.New().EnableTrace(),
locker: &sync.Mutex{},
refreshingAuth: &sync.Mutex{},
pingMutex: &sync.RWMutex{},
isPinging: false,
setSentryUserIDOnce: sync.Once{},

View File

@ -102,6 +102,9 @@ func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) {
}
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefresh, error) {
m.refreshingAuth.Lock()
defer m.refreshingAuth.Unlock()
var req = authRefreshReq{
UID: uid,
RefreshToken: ref,
@ -117,6 +120,9 @@ func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefres
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"))
if err != nil {
if IsBadRequest(err) || IsUnprocessableEntity(err) {
err = ErrAuthFailed{err}
}
return nil, err
}

View File

@ -59,16 +59,14 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
if apiErr, ok := res.Error().(*Error); ok {
switch {
case apiErr.Code == errCodeUpgradeApplication:
err = ErrUpgradeApplication
if m.cfg.UpgradeApplicationHandler != nil {
m.cfg.UpgradeApplicationHandler()
}
return ErrUpgradeApplication
case apiErr.Code == errCodePasswordWrong:
err = ErrPasswordWrong
return ErrPasswordWrong
case apiErr.Code == errCodeAuthPaidPlanRequired:
err = ErrPaidPlanRequired
case res.StatusCode() == http.StatusUnprocessableEntity:
err = ErrUnprocessableEntity{apiErr}
return ErrPaidPlanRequired
default:
err = apiErr
}
@ -76,6 +74,13 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
err = errors.New(res.Status())
}
switch res.StatusCode() {
case http.StatusUnprocessableEntity:
err = ErrUnprocessableEntity{err}
case http.StatusBadRequest:
err = ErrBadRequest{err}
}
return err
}

View File

@ -46,6 +46,7 @@ type PMAPIController interface {
LockEvents(username string)
UnlockEvents(username string)
RemoveUserMessageWithoutEvent(username, messageID string) error
RevokeSession(username string) error
}
func newPMAPIController(listener listener.Listener) (PMAPIController, pmapi.Manager) {

View File

@ -250,3 +250,10 @@ func (ctl *Controller) RemoveUserMessageWithoutEvent(username string, messageID
return errors.New("message not found")
}
func (ctl *Controller) RevokeSession(username string) error {
for _, session := range ctl.sessionsByUID {
session.uid = "revoked"
}
return nil
}

View File

@ -74,12 +74,12 @@ func (ctl *Controller) createSession(username string, hasFullScope bool) *fakeSe
func (ctl *Controller) refreshSessionIfAuthorized(uid, ref string) (*fakeSession, error) {
session, ok := ctl.sessionsByUID[uid]
if !ok {
return nil, pmapi.ErrUnauthorized
if !ok || session.uid != uid {
return nil, pmapi.ErrAuthFailed{OriginalError: errors.New("bad uid")}
}
if ref != session.ref {
return nil, pmapi.ErrUnauthorized
return nil, pmapi.ErrAuthFailed{OriginalError: errors.New("bad refresh token")}
}
session.ref = ctl.tokenGenerator.next("ref")

View File

@ -133,14 +133,32 @@ func (api *FakePMAPI) authRefresh() error {
session, err := api.controller.refreshSessionIfAuthorized(api.uid, api.ref)
if err != nil {
if pmapi.IsFailedAuth(err) {
go api.handleAuth(nil)
}
return err
}
api.ref = session.ref
api.acc = session.acc
go api.handleAuth(&pmapi.AuthRefresh{
UID: api.uid,
AccessToken: api.acc,
RefreshToken: api.ref,
ExpiresIn: 7200,
Scopes: []string{"full", "self", "user", "mail"},
})
return nil
}
func (api *FakePMAPI) handleAuth(auth *pmapi.AuthRefresh) {
for _, handle := range api.authHandlers {
handle(auth)
}
}
func (api *FakePMAPI) setUser(username string) error {
api.username = username
api.log = api.log.WithField("username", username)

View File

@ -69,7 +69,7 @@ func (m *fakePMAPIManager) NewClientWithRefresh(_ context.Context, uid, ref stri
session, err := m.controller.refreshSessionIfAuthorized(uid, ref)
if err != nil {
return nil, nil, pmapi.ErrUnauthorized
return nil, nil, err
}
user, ok := m.controller.usersByUsername[session.username]

View File

@ -82,6 +82,10 @@ func (api *FakePMAPI) UpdateUser(context.Context) (*pmapi.User, error) {
return nil, err
}
if err := api.checkAndRecordCall(GET, "/addresses", nil); err != nil {
return nil, err
}
return api.user, nil
}

View File

@ -0,0 +1,17 @@
Feature: Session deleted on API
@ignore-live
Scenario: Session revoked after start
Given there is connected user "user"
When session was revoked for "user"
And the event loop of "user" loops once
Then "user" is disconnected
@ignore-live
Scenario: Starting with revoked session
Given there is user "user" which just logged in
And session was revoked for "user"
When bridge starts
Then "user" is disconnected

View File

@ -60,3 +60,7 @@ func (ctl *Controller) GetAuthClient(username string) pmapi.Client {
}
return client
}
func (ctl *Controller) RevokeSession(username string) error {
return errors.New("revoke live session not implemented")
}

View File

@ -29,6 +29,7 @@ func UsersActionsFeatureContext(s *godog.ScenarioContext) {
s.Step(`^user deletes "([^"]*)"$`, userDeletesUser)
s.Step(`^user deletes "([^"]*)" with cache$`, userDeletesUserWithCache)
s.Step(`^"([^"]*)" swaps address "([^"]*)" with address "([^"]*)"$`, swapsAddressWithAddress)
s.Step(`^session was revoked for "([^"]*)"$`, sessionRevoked)
}
func userLogsIn(bddUserID string) error {
@ -123,3 +124,8 @@ func swapsAddressWithAddress(bddUserID, bddAddressID1, bddAddressID2 string) err
return ctx.GetPMAPIController().ReorderAddresses(account.User(), addressIDs)
}
func sessionRevoked(bddUserID string) error {
account := ctx.GetTestAccount(bddUserID)
return ctx.GetPMAPIController().RevokeSession(account.Username())
}