mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-11 13:16:53 +00:00
GODT-1524: Logout issues with macOS.
This commit is contained in:
4
go.mod
4
go.mod
@ -48,10 +48,12 @@ require (
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/hashicorp/go-multierror v1.1.0
|
||||
github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7
|
||||
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
|
||||
github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||
github.com/mattn/go-runewidth v0.0.9 // indirect
|
||||
github.com/miekg/dns v1.1.41
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
|
||||
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
||||
github.com/olekukonko/tablewriter v0.0.4 // indirect
|
||||
github.com/pkg/errors v0.9.1
|
||||
|
||||
2
go.sum
2
go.sum
@ -265,6 +265,8 @@ github.com/kataras/pio v0.0.2/go.mod h1:hAoW0t9UmXi4R5Oyq5Z4irTbaTsOemSrDGUtaTl7
|
||||
github.com/kataras/sitemap v0.0.5/go.mod h1:KY2eugMKiPwsJgx7+U103YZehfvNGOXURubcGyk0Bz8=
|
||||
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d h1:gVjhBCfVGl32RIBooOANzfw+0UqX8HU+yPlMv8vypcg=
|
||||
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d/go.mod h1:W6EbaYmb4RldPn0N3gvVHjY1wmU59kbymhW9NATWhwY=
|
||||
github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621 h1:aMQ7pA4f06yOVXSulygyGvy4xA94fyzjUGs0iqQdMOI=
|
||||
github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621/go.mod h1:enrU/ug069Om7vWxuFE6nikLI2BZNwevMiGSo43Kt5w=
|
||||
github.com/keybase/go.dbus v0.0.0-20200324223359-a94be52c0b03/go.mod h1:a8clEhrrGV/d76/f9r2I41BwANMihfZYV9C223vaxqE=
|
||||
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
|
||||
@ -194,7 +194,7 @@ func (store *Store) BuildAndCacheMessage(ctx context.Context, messageID string)
|
||||
}
|
||||
|
||||
func (store *Store) checkAndRemoveDeletedMessage(err error, msgID string) {
|
||||
if _, ok := err.(pmapi.ErrUnprocessableEntity); !ok {
|
||||
if !pmapi.IsUnprocessableEntity(err) {
|
||||
return
|
||||
}
|
||||
l := store.log.WithError(err).WithField("msgID", msgID)
|
||||
|
||||
@ -477,7 +477,7 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
|
||||
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
|
||||
|
||||
if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil {
|
||||
if _, ok := err.(pmapi.ErrUnprocessableEntity); ok {
|
||||
if pmapi.IsUnprocessableEntity(err) {
|
||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||
err = nil
|
||||
continue
|
||||
|
||||
@ -223,7 +223,7 @@ func (u *User) UpdateSpace(apiUser *pmapi.User) {
|
||||
// values from client.CurrentUser()
|
||||
if apiUser == nil {
|
||||
var err error
|
||||
apiUser, err = u.client.GetUser(pmapi.ContextWithoutRetry(context.Background()))
|
||||
apiUser, err = u.GetClient().GetUser(pmapi.ContextWithoutRetry(context.Background()))
|
||||
if err != nil {
|
||||
u.log.WithError(err).Warning("Cannot update user space")
|
||||
return
|
||||
@ -280,16 +280,21 @@ func (u *User) unlockIfNecessary() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch errors.Cause(err) {
|
||||
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
|
||||
u.log.WithError(err).Warn("Could not unlock user")
|
||||
return nil
|
||||
if pmapi.IsFailedAuth(err) || pmapi.IsFailedUnlock(err) {
|
||||
if logoutErr := u.logout(); logoutErr != nil {
|
||||
u.log.WithError(logoutErr).Warn("Could not logout user")
|
||||
}
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
if logoutErr := u.logout(); logoutErr != nil {
|
||||
u.log.WithError(logoutErr).Warn("Could not logout user")
|
||||
switch errors.Cause(err) {
|
||||
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
|
||||
u.log.WithError(err).Warn("Skipping unlock for known reason")
|
||||
default:
|
||||
u.log.WithError(err).Error("Unknown unlock issue")
|
||||
}
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsCombinedAddressMode returns whether user is set in combined or split mode.
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -46,7 +47,7 @@ func TestNewUserUnlockFails(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(errors.New("bad password")),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(pmapi.ErrUnlockFailed{OriginalError: errors.New("bad password")}),
|
||||
|
||||
// Handle of unlock error.
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||
|
||||
@ -178,8 +178,10 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden
|
||||
return connectErr
|
||||
}
|
||||
|
||||
if logoutErr := user.logout(); logoutErr != nil {
|
||||
logrus.WithError(logoutErr).Warn("Could not logout user")
|
||||
if pmapi.IsFailedAuth(connectErr) {
|
||||
if logoutErr := user.logout(); logoutErr != nil {
|
||||
logrus.WithError(logoutErr).Warn("Could not logout user")
|
||||
}
|
||||
}
|
||||
return errors.Wrap(err, "could not refresh token")
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -80,11 +81,11 @@ func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token"))
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, pmapi.ErrBadRequest{OriginalError: errors.New("bad token")})
|
||||
m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient)
|
||||
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false)
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(errors.New("not authorized"))
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(pmapi.ErrBadRequest{OriginalError: errors.New("not authorized")})
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
@ -93,7 +94,6 @@ func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
|
||||
@ -19,11 +19,11 @@ package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Address statuses.
|
||||
@ -201,7 +201,7 @@ func (c *client) unlockAddress(passphrase []byte, address *Address) error {
|
||||
|
||||
kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "cannot unlock address keys for "+address.ID)
|
||||
}
|
||||
|
||||
c.addrKeyRing[address.ID] = kr
|
||||
|
||||
@ -51,7 +51,7 @@ type TwoFAInfo struct {
|
||||
}
|
||||
|
||||
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
|
||||
return twoFAInfo.Enabled > 0
|
||||
return twoFAInfo.Enabled > TwoFADisabled
|
||||
}
|
||||
|
||||
type TwoFAStatus int
|
||||
@ -185,7 +185,7 @@ func (c *client) authRefresh(ctx context.Context) error {
|
||||
|
||||
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
|
||||
if err != nil {
|
||||
if err != ErrNoConnection {
|
||||
if IsFailedAuth(err) {
|
||||
c.sendAuthRefresh(nil)
|
||||
}
|
||||
return err
|
||||
|
||||
122
pkg/pmapi/auth_server_test.go
Normal file
122
pkg/pmapi/auth_server_test.go
Normal file
@ -0,0 +1,122 @@
|
||||
// Copyright (c) 2022 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testRefreshResponse struct {
|
||||
Code int
|
||||
AccessToken string
|
||||
ExpiresIn int
|
||||
TokenType string
|
||||
Scope string
|
||||
Scopes []string
|
||||
UID string
|
||||
RefreshToken string
|
||||
LocalID int
|
||||
|
||||
r *require.Assertions
|
||||
}
|
||||
|
||||
var tokenID = 0
|
||||
|
||||
func newTestRefreshToken(r *require.Assertions) testRefreshResponse {
|
||||
tokenID++
|
||||
scopes := []string{
|
||||
"full",
|
||||
"self",
|
||||
"parent",
|
||||
"user",
|
||||
"loggedin",
|
||||
"paid",
|
||||
"nondelinquent",
|
||||
"mail",
|
||||
"verified",
|
||||
}
|
||||
return testRefreshResponse{
|
||||
Code: 1000,
|
||||
AccessToken: fmt.Sprintf("acc%d", tokenID),
|
||||
ExpiresIn: 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: strings.Join(scopes, " "),
|
||||
Scopes: scopes,
|
||||
UID: fmt.Sprintf("uid%d", tokenID),
|
||||
RefreshToken: fmt.Sprintf("ref%d", tokenID),
|
||||
r: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) isCorrectRefreshToken(body io.ReadCloser) int {
|
||||
request := authRefreshReq{}
|
||||
err := json.NewDecoder(body).Decode(&request)
|
||||
r.r.NoError(body.Close())
|
||||
r.r.NoError(err)
|
||||
|
||||
if r.UID != request.UID {
|
||||
return http.StatusUnprocessableEntity
|
||||
}
|
||||
if r.RefreshToken != request.RefreshToken {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) handleAuthRefresh(response http.ResponseWriter, request *http.Request) {
|
||||
if code := r.isCorrectRefreshToken(request.Body); code != http.StatusOK {
|
||||
response.WriteHeader(code)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID++
|
||||
r.AccessToken = fmt.Sprintf("acc%d", tokenID)
|
||||
r.RefreshToken = fmt.Sprintf("ref%d", tokenID)
|
||||
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
response.WriteHeader(http.StatusOK)
|
||||
r.r.NoError(json.NewEncoder(response).Encode(r))
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) wantAuthRefresh() AuthRefresh {
|
||||
return AuthRefresh{
|
||||
UID: r.UID,
|
||||
AccessToken: r.AccessToken,
|
||||
RefreshToken: r.RefreshToken,
|
||||
ExpiresIn: int64(r.ExpiresIn),
|
||||
Scopes: r.Scopes,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) isAuthorized(header http.Header) bool {
|
||||
return header.Get("x-pm-uid") == r.UID && header.Get("Authorization") == "Bearer "+r.AccessToken
|
||||
}
|
||||
|
||||
func (r *testRefreshResponse) handleAuthCheckOnly(response http.ResponseWriter, request *http.Request) {
|
||||
if r.isAuthorized(request.Header) {
|
||||
response.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
response.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
@ -25,179 +25,203 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
ExpiresIn: 100,
|
||||
}
|
||||
|
||||
r := require.New(t)
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testAcc := currentTokens.AccessToken
|
||||
testRef := currentTokens.RefreshToken
|
||||
currentTokens.ExpiresIn = 100
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
||||
NewClient(testUID, testAcc, testRef, time.Now().Add(-time.Second))
|
||||
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// Make a request with an access token that already expired one second ago.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
r.NoError(err)
|
||||
|
||||
wantAuthRefresh := currentTokens.wantAuthRefresh()
|
||||
|
||||
// The auth callback should have been called.
|
||||
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
r.NotNil(gotAuthRefresh)
|
||||
r.Equal(wantAuthRefresh, *gotAuthRefresh)
|
||||
|
||||
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
|
||||
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
|
||||
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
|
||||
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
|
||||
r.Equal(wantAuthRefresh.AccessToken, cl.acc)
|
||||
r.Equal(wantAuthRefresh.RefreshToken, cl.ref)
|
||||
r.WithinDuration(expiresIn(100), cl.exp, time.Second)
|
||||
}
|
||||
|
||||
func Test401AuthRefresh(t *testing.T) {
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
}
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testRef := currentTokens.RefreshToken
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
var call int
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
call++
|
||||
|
||||
if call == 1 {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
c := m.NewClient(testUID, "oldAccToken", testRef, time.Now().Add(time.Hour))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// The first request will fail with 401, triggering a refresh and retry.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
r.NoError(err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
r.NotNil(gotAuthRefresh)
|
||||
r.Equal(currentTokens.wantAuthRefresh(), *gotAuthRefresh)
|
||||
}
|
||||
|
||||
func Test401RevokedAuth(t *testing.T) {
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
NewClient("badUID", "badAcc", "badRef", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.EqualError(t, err, ErrUnauthorized.Error())
|
||||
r.True(IsFailedAuth(err))
|
||||
}
|
||||
|
||||
func Test401RevokedAuthTokenUpdate(t *testing.T) {
|
||||
var oldAuth = &AuthRefresh{
|
||||
UID: "UID",
|
||||
AccessToken: "oldAcc",
|
||||
RefreshToken: "oldRef",
|
||||
ExpiresIn: 3600,
|
||||
}
|
||||
|
||||
var newAuth = &AuthRefresh{
|
||||
UID: "UID",
|
||||
AccessToken: "newAcc",
|
||||
RefreshToken: "newRef",
|
||||
}
|
||||
func Test401OldRefreshToken(t *testing.T) {
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient(currentTokens.UID, "oldAcc", "oldRef", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.True(IsFailedAuth(err))
|
||||
}
|
||||
|
||||
func Test401NoAccessToken(t *testing.T) {
|
||||
r := require.New(t)
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testRef := currentTokens.RefreshToken
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient(testUID, "", testRef, time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(err)
|
||||
}
|
||||
|
||||
func Test401ExpiredAuthUpdateUser(t *testing.T) {
|
||||
r := require.New(t)
|
||||
mux := http.NewServeMux()
|
||||
currentTokens := newTestRefreshToken(r)
|
||||
testUID := currentTokens.UID
|
||||
testRef := currentTokens.RefreshToken
|
||||
|
||||
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
|
||||
|
||||
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !currentTokens.isAuthorized(r.Header) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(newAuth); err != nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
respObj := struct {
|
||||
Code int
|
||||
User *User
|
||||
}{
|
||||
Code: 1000,
|
||||
User: &User{
|
||||
ID: "MJLke8kWh1BBvG95JBIrZvzpgsZ94hNNgjNHVyhXMiv4g9cn6SgvqiIFR5cigpml2LD_iUk_3DkV29oojTt3eA==",
|
||||
Name: "jason",
|
||||
UsedSpace: &usedSpace,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(respObj); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") == ("Bearer " + oldAuth.AccessToken) {
|
||||
if !currentTokens.isAuthorized(r.Header) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("Authorization") == ("Bearer " + newAuth.AccessToken) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(testAddressList); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient(oldAuth.UID, oldAuth.AccessToken, oldAuth.RefreshToken, time.Now().Add(time.Hour))
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
c, _, err := m.NewClientWithRefresh(context.Background(), testUID, testRef)
|
||||
r.NoError(err)
|
||||
|
||||
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
_, err = c.UpdateUser(context.Background())
|
||||
r.NoError(err)
|
||||
}
|
||||
|
||||
func TestAuth2FA(t *testing.T) {
|
||||
r := require.New(t)
|
||||
twoFACode := "code"
|
||||
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
|
||||
var twoFAreq auth2FAReq
|
||||
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
|
||||
r.NoError(json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||
r.Equal(twoFAreq.TwoFactorCode, twoFACode)
|
||||
|
||||
return "/auth/2fa/post_response.json"
|
||||
},
|
||||
@ -205,31 +229,33 @@ func TestAuth2FA(t *testing.T) {
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), twoFACode)
|
||||
r.NoError(t, err)
|
||||
r.NoError(err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Fail(t *testing.T) {
|
||||
r := require.New(t)
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_401_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACode, err)
|
||||
r.Equal(ErrBad2FACode, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Retry(t *testing.T) {
|
||||
r := require.New(t)
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_422_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACodeTryAgain, err)
|
||||
r.Equal(ErrBad2FACodeTryAgain, err)
|
||||
}
|
||||
|
||||
@ -19,8 +19,6 @@ package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
|
||||
@ -34,26 +32,26 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
|
||||
// unlock unlocks the user's keys but without locking the keyring lock first.
|
||||
// Should only be used internally by methods that first lock the lock.
|
||||
func (c *client) unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
if _, err = c.CurrentUser(ctx); err != nil {
|
||||
return
|
||||
func (c *client) unlock(ctx context.Context, passphrase []byte) error {
|
||||
if _, err := c.CurrentUser(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.userKeyRing == nil {
|
||||
if err = c.unlockUser(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
if err := c.unlockUser(passphrase); err != nil {
|
||||
return ErrUnlockFailed{err}
|
||||
}
|
||||
}
|
||||
|
||||
for _, address := range c.addresses {
|
||||
if c.addrKeyRing[address.ID] == nil {
|
||||
if err = c.unlockAddress(passphrase, address); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock address")
|
||||
if err := c.unlockAddress(passphrase, address); err != nil {
|
||||
return ErrUnlockFailed{err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) ReloadKeys(ctx context.Context, passphrase []byte) (err error) {
|
||||
|
||||
@ -31,10 +31,58 @@ var (
|
||||
ErrPasswordWrong = errors.New("wrong password")
|
||||
)
|
||||
|
||||
// ErrUnprocessableEntity ...
|
||||
type ErrUnprocessableEntity struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsUnprocessableEntity(err error) bool {
|
||||
_, ok := err.(ErrUnprocessableEntity)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrUnprocessableEntity) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
// ErrBadRequest ...
|
||||
type ErrBadRequest struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsBadRequest(err error) bool {
|
||||
_, ok := err.(ErrBadRequest)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrBadRequest) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
// ErrAuthFailed ...
|
||||
type ErrAuthFailed struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsFailedAuth(err error) bool {
|
||||
_, ok := err.(ErrAuthFailed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrAuthFailed) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
// ErrUnlockFailed ...
|
||||
type ErrUnlockFailed struct {
|
||||
OriginalError error
|
||||
}
|
||||
|
||||
func IsFailedUnlock(err error) bool {
|
||||
_, ok := err.(ErrUnlockFailed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrUnlockFailed) Error() string {
|
||||
return err.OriginalError.Error()
|
||||
}
|
||||
|
||||
@ -33,6 +33,7 @@ type manager struct {
|
||||
|
||||
isDown bool
|
||||
locker sync.Locker
|
||||
refreshingAuth sync.Locker
|
||||
connectionObservers []ConnectionObserver
|
||||
proxyDialer *ProxyTLSDialer
|
||||
|
||||
@ -50,6 +51,7 @@ func newManager(cfg Config) *manager {
|
||||
cfg: cfg,
|
||||
rc: resty.New().EnableTrace(),
|
||||
locker: &sync.Mutex{},
|
||||
refreshingAuth: &sync.Mutex{},
|
||||
pingMutex: &sync.RWMutex{},
|
||||
isPinging: false,
|
||||
setSentryUserIDOnce: sync.Once{},
|
||||
|
||||
@ -102,6 +102,9 @@ func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) {
|
||||
}
|
||||
|
||||
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefresh, error) {
|
||||
m.refreshingAuth.Lock()
|
||||
defer m.refreshingAuth.Unlock()
|
||||
|
||||
var req = authRefreshReq{
|
||||
UID: uid,
|
||||
RefreshToken: ref,
|
||||
@ -117,6 +120,9 @@ func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefres
|
||||
|
||||
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"))
|
||||
if err != nil {
|
||||
if IsBadRequest(err) || IsUnprocessableEntity(err) {
|
||||
err = ErrAuthFailed{err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@ -59,16 +59,14 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
|
||||
if apiErr, ok := res.Error().(*Error); ok {
|
||||
switch {
|
||||
case apiErr.Code == errCodeUpgradeApplication:
|
||||
err = ErrUpgradeApplication
|
||||
if m.cfg.UpgradeApplicationHandler != nil {
|
||||
m.cfg.UpgradeApplicationHandler()
|
||||
}
|
||||
return ErrUpgradeApplication
|
||||
case apiErr.Code == errCodePasswordWrong:
|
||||
err = ErrPasswordWrong
|
||||
return ErrPasswordWrong
|
||||
case apiErr.Code == errCodeAuthPaidPlanRequired:
|
||||
err = ErrPaidPlanRequired
|
||||
case res.StatusCode() == http.StatusUnprocessableEntity:
|
||||
err = ErrUnprocessableEntity{apiErr}
|
||||
return ErrPaidPlanRequired
|
||||
default:
|
||||
err = apiErr
|
||||
}
|
||||
@ -76,6 +74,13 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
|
||||
err = errors.New(res.Status())
|
||||
}
|
||||
|
||||
switch res.StatusCode() {
|
||||
case http.StatusUnprocessableEntity:
|
||||
err = ErrUnprocessableEntity{err}
|
||||
case http.StatusBadRequest:
|
||||
err = ErrBadRequest{err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -46,6 +46,7 @@ type PMAPIController interface {
|
||||
LockEvents(username string)
|
||||
UnlockEvents(username string)
|
||||
RemoveUserMessageWithoutEvent(username, messageID string) error
|
||||
RevokeSession(username string) error
|
||||
}
|
||||
|
||||
func newPMAPIController(listener listener.Listener) (PMAPIController, pmapi.Manager) {
|
||||
|
||||
@ -250,3 +250,10 @@ func (ctl *Controller) RemoveUserMessageWithoutEvent(username string, messageID
|
||||
|
||||
return errors.New("message not found")
|
||||
}
|
||||
|
||||
func (ctl *Controller) RevokeSession(username string) error {
|
||||
for _, session := range ctl.sessionsByUID {
|
||||
session.uid = "revoked"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -74,12 +74,12 @@ func (ctl *Controller) createSession(username string, hasFullScope bool) *fakeSe
|
||||
|
||||
func (ctl *Controller) refreshSessionIfAuthorized(uid, ref string) (*fakeSession, error) {
|
||||
session, ok := ctl.sessionsByUID[uid]
|
||||
if !ok {
|
||||
return nil, pmapi.ErrUnauthorized
|
||||
if !ok || session.uid != uid {
|
||||
return nil, pmapi.ErrAuthFailed{OriginalError: errors.New("bad uid")}
|
||||
}
|
||||
|
||||
if ref != session.ref {
|
||||
return nil, pmapi.ErrUnauthorized
|
||||
return nil, pmapi.ErrAuthFailed{OriginalError: errors.New("bad refresh token")}
|
||||
}
|
||||
|
||||
session.ref = ctl.tokenGenerator.next("ref")
|
||||
|
||||
@ -133,14 +133,32 @@ func (api *FakePMAPI) authRefresh() error {
|
||||
|
||||
session, err := api.controller.refreshSessionIfAuthorized(api.uid, api.ref)
|
||||
if err != nil {
|
||||
if pmapi.IsFailedAuth(err) {
|
||||
go api.handleAuth(nil)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
api.ref = session.ref
|
||||
api.acc = session.acc
|
||||
|
||||
go api.handleAuth(&pmapi.AuthRefresh{
|
||||
UID: api.uid,
|
||||
AccessToken: api.acc,
|
||||
RefreshToken: api.ref,
|
||||
ExpiresIn: 7200,
|
||||
Scopes: []string{"full", "self", "user", "mail"},
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) handleAuth(auth *pmapi.AuthRefresh) {
|
||||
for _, handle := range api.authHandlers {
|
||||
handle(auth)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) setUser(username string) error {
|
||||
api.username = username
|
||||
api.log = api.log.WithField("username", username)
|
||||
|
||||
@ -69,7 +69,7 @@ func (m *fakePMAPIManager) NewClientWithRefresh(_ context.Context, uid, ref stri
|
||||
|
||||
session, err := m.controller.refreshSessionIfAuthorized(uid, ref)
|
||||
if err != nil {
|
||||
return nil, nil, pmapi.ErrUnauthorized
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
user, ok := m.controller.usersByUsername[session.username]
|
||||
|
||||
@ -82,6 +82,10 @@ func (api *FakePMAPI) UpdateUser(context.Context) (*pmapi.User, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := api.checkAndRecordCall(GET, "/addresses", nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return api.user, nil
|
||||
}
|
||||
|
||||
|
||||
17
test/features/users/revoked_session.feature
Normal file
17
test/features/users/revoked_session.feature
Normal file
@ -0,0 +1,17 @@
|
||||
Feature: Session deleted on API
|
||||
|
||||
@ignore-live
|
||||
Scenario: Session revoked after start
|
||||
Given there is connected user "user"
|
||||
When session was revoked for "user"
|
||||
And the event loop of "user" loops once
|
||||
Then "user" is disconnected
|
||||
|
||||
|
||||
@ignore-live
|
||||
Scenario: Starting with revoked session
|
||||
Given there is user "user" which just logged in
|
||||
And session was revoked for "user"
|
||||
When bridge starts
|
||||
Then "user" is disconnected
|
||||
|
||||
@ -60,3 +60,7 @@ func (ctl *Controller) GetAuthClient(username string) pmapi.Client {
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func (ctl *Controller) RevokeSession(username string) error {
|
||||
return errors.New("revoke live session not implemented")
|
||||
}
|
||||
|
||||
@ -29,6 +29,7 @@ func UsersActionsFeatureContext(s *godog.ScenarioContext) {
|
||||
s.Step(`^user deletes "([^"]*)"$`, userDeletesUser)
|
||||
s.Step(`^user deletes "([^"]*)" with cache$`, userDeletesUserWithCache)
|
||||
s.Step(`^"([^"]*)" swaps address "([^"]*)" with address "([^"]*)"$`, swapsAddressWithAddress)
|
||||
s.Step(`^session was revoked for "([^"]*)"$`, sessionRevoked)
|
||||
}
|
||||
|
||||
func userLogsIn(bddUserID string) error {
|
||||
@ -123,3 +124,8 @@ func swapsAddressWithAddress(bddUserID, bddAddressID1, bddAddressID2 string) err
|
||||
|
||||
return ctx.GetPMAPIController().ReorderAddresses(account.User(), addressIDs)
|
||||
}
|
||||
|
||||
func sessionRevoked(bddUserID string) error {
|
||||
account := ctx.GetTestAccount(bddUserID)
|
||||
return ctx.GetPMAPIController().RevokeSession(account.Username())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user