From 6671dd38ea50cc7d3623d06496bc0577d6de5f16 Mon Sep 17 00:00:00 2001 From: Jakub Date: Mon, 14 Mar 2022 16:37:42 +0100 Subject: [PATCH] GODT-1524: Logout issues with macOS. --- go.mod | 4 +- go.sum | 2 + internal/store/cache.go | 2 +- internal/store/event_loop.go | 2 +- internal/users/user.go | 21 +- internal/users/user_new_test.go | 3 +- internal/users/users.go | 6 +- internal/users/users_new_test.go | 6 +- pkg/pmapi/addresses.go | 4 +- pkg/pmapi/auth.go | 4 +- pkg/pmapi/auth_server_test.go | 122 +++++++++++ pkg/pmapi/auth_test.go | 216 +++++++++++--------- pkg/pmapi/client_keys.go | 18 +- pkg/pmapi/errors.go | 48 +++++ pkg/pmapi/manager.go | 2 + pkg/pmapi/manager_auth.go | 6 + pkg/pmapi/response.go | 15 +- test/context/pmapi_controller.go | 1 + test/fakeapi/controller_control.go | 7 + test/fakeapi/controller_session.go | 6 +- test/fakeapi/fakeapi.go | 18 ++ test/fakeapi/manager.go | 2 +- test/fakeapi/user.go | 4 + test/features/users/revoked_session.feature | 17 ++ test/liveapi/users.go | 4 + test/users_actions_test.go | 6 + 26 files changed, 411 insertions(+), 135 deletions(-) create mode 100644 pkg/pmapi/auth_server_test.go create mode 100644 test/features/users/revoked_session.feature diff --git a/go.mod b/go.mod index 688cb79d..19aca587 100644 --- a/go.mod +++ b/go.mod @@ -48,10 +48,12 @@ require ( github.com/google/uuid v1.1.1 github.com/hashicorp/go-multierror v1.1.0 github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7 - github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d + github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621 + github.com/kr/text v0.2.0 // indirect github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mattn/go-runewidth v0.0.9 // indirect github.com/miekg/dns v1.1.41 + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce github.com/olekukonko/tablewriter v0.0.4 // indirect github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index f1170518..e12c82f5 100644 --- a/go.sum +++ b/go.sum @@ -265,6 +265,8 @@ github.com/kataras/pio v0.0.2/go.mod h1:hAoW0t9UmXi4R5Oyq5Z4irTbaTsOemSrDGUtaTl7 github.com/kataras/sitemap v0.0.5/go.mod h1:KY2eugMKiPwsJgx7+U103YZehfvNGOXURubcGyk0Bz8= github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d h1:gVjhBCfVGl32RIBooOANzfw+0UqX8HU+yPlMv8vypcg= github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d/go.mod h1:W6EbaYmb4RldPn0N3gvVHjY1wmU59kbymhW9NATWhwY= +github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621 h1:aMQ7pA4f06yOVXSulygyGvy4xA94fyzjUGs0iqQdMOI= +github.com/keybase/go-keychain v0.0.0-20211119201326-e02f34051621/go.mod h1:enrU/ug069Om7vWxuFE6nikLI2BZNwevMiGSo43Kt5w= github.com/keybase/go.dbus v0.0.0-20200324223359-a94be52c0b03/go.mod h1:a8clEhrrGV/d76/f9r2I41BwANMihfZYV9C223vaxqE= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= diff --git a/internal/store/cache.go b/internal/store/cache.go index 37da020b..baa08cd8 100644 --- a/internal/store/cache.go +++ b/internal/store/cache.go @@ -194,7 +194,7 @@ func (store *Store) BuildAndCacheMessage(ctx context.Context, messageID string) } func (store *Store) checkAndRemoveDeletedMessage(err error, msgID string) { - if _, ok := err.(pmapi.ErrUnprocessableEntity); !ok { + if !pmapi.IsUnprocessableEntity(err) { return } l := store.log.WithError(err).WithField("msgID", msgID) diff --git a/internal/store/event_loop.go b/internal/store/event_loop.go index 503f9ad1..5c3cb093 100644 --- a/internal/store/event_loop.go +++ b/internal/store/event_loop.go @@ -477,7 +477,7 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...") if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil { - if _, ok := err.(pmapi.ErrUnprocessableEntity); ok { + if pmapi.IsUnprocessableEntity(err) { msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API") err = nil continue diff --git a/internal/users/user.go b/internal/users/user.go index a5c946bf..31aefb00 100644 --- a/internal/users/user.go +++ b/internal/users/user.go @@ -223,7 +223,7 @@ func (u *User) UpdateSpace(apiUser *pmapi.User) { // values from client.CurrentUser() if apiUser == nil { var err error - apiUser, err = u.client.GetUser(pmapi.ContextWithoutRetry(context.Background())) + apiUser, err = u.GetClient().GetUser(pmapi.ContextWithoutRetry(context.Background())) if err != nil { u.log.WithError(err).Warning("Cannot update user space") return @@ -280,16 +280,21 @@ func (u *User) unlockIfNecessary() error { return nil } - switch errors.Cause(err) { - case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication: - u.log.WithError(err).Warn("Could not unlock user") - return nil + if pmapi.IsFailedAuth(err) || pmapi.IsFailedUnlock(err) { + if logoutErr := u.logout(); logoutErr != nil { + u.log.WithError(logoutErr).Warn("Could not logout user") + } + return errors.Wrap(err, "failed to unlock user") } - if logoutErr := u.logout(); logoutErr != nil { - u.log.WithError(logoutErr).Warn("Could not logout user") + switch errors.Cause(err) { + case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication: + u.log.WithError(err).Warn("Skipping unlock for known reason") + default: + u.log.WithError(err).Error("Unknown unlock issue") } - return errors.Wrap(err, "failed to unlock user") + + return nil } // IsCombinedAddressMode returns whether user is set in combined or split mode. diff --git a/internal/users/user_new_test.go b/internal/users/user_new_test.go index b742335f..5ca56f65 100644 --- a/internal/users/user_new_test.go +++ b/internal/users/user_new_test.go @@ -23,6 +23,7 @@ import ( "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/users/credentials" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" r "github.com/stretchr/testify/require" ) @@ -46,7 +47,7 @@ func TestNewUserUnlockFails(t *testing.T) { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()), m.pmapiClient.EXPECT().IsUnlocked().Return(false), - m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(errors.New("bad password")), + m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(pmapi.ErrUnlockFailed{OriginalError: errors.New("bad password")}), // Handle of unlock error. m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil), diff --git a/internal/users/users.go b/internal/users/users.go index 3763bf04..c0eda61b 100644 --- a/internal/users/users.go +++ b/internal/users/users.go @@ -178,8 +178,10 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden return connectErr } - if logoutErr := user.logout(); logoutErr != nil { - logrus.WithError(logoutErr).Warn("Could not logout user") + if pmapi.IsFailedAuth(connectErr) { + if logoutErr := user.logout(); logoutErr != nil { + logrus.WithError(logoutErr).Warn("Could not logout user") + } } return errors.Wrap(err, "could not refresh token") } diff --git a/internal/users/users_new_test.go b/internal/users/users_new_test.go index 5120c270..a7833a89 100644 --- a/internal/users/users_new_test.go +++ b/internal/users/users_new_test.go @@ -24,6 +24,7 @@ import ( "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/users/credentials" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" r "github.com/stretchr/testify/require" ) @@ -80,11 +81,11 @@ func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token")) + m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, pmapi.ErrBadRequest{OriginalError: errors.New("bad token")}) m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient) m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()) m.pmapiClient.EXPECT().IsUnlocked().Return(false) - m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(errors.New("not authorized")) + m.pmapiClient.EXPECT().Unlock(gomock.Any(), testCredentials.MailboxPassword).Return(pmapi.ErrBadRequest{OriginalError: errors.New("not authorized")}) m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) @@ -93,7 +94,6 @@ func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) { m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) diff --git a/pkg/pmapi/addresses.go b/pkg/pmapi/addresses.go index 4540f66a..4f0f2ab9 100644 --- a/pkg/pmapi/addresses.go +++ b/pkg/pmapi/addresses.go @@ -19,11 +19,11 @@ package pmapi import ( "context" - "errors" "strings" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/go-resty/resty/v2" + "github.com/pkg/errors" ) // Address statuses. @@ -201,7 +201,7 @@ func (c *client) unlockAddress(passphrase []byte, address *Address) error { kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing) if err != nil { - return err + return errors.Wrap(err, "cannot unlock address keys for "+address.ID) } c.addrKeyRing[address.ID] = kr diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index ded6efc7..160f3c7a 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -51,7 +51,7 @@ type TwoFAInfo struct { } func (twoFAInfo TwoFAInfo) hasTwoFactor() bool { - return twoFAInfo.Enabled > 0 + return twoFAInfo.Enabled > TwoFADisabled } type TwoFAStatus int @@ -185,7 +185,7 @@ func (c *client) authRefresh(ctx context.Context) error { auth, err := c.manager.authRefresh(ctx, c.uid, c.ref) if err != nil { - if err != ErrNoConnection { + if IsFailedAuth(err) { c.sendAuthRefresh(nil) } return err diff --git a/pkg/pmapi/auth_server_test.go b/pkg/pmapi/auth_server_test.go new file mode 100644 index 00000000..129b6df6 --- /dev/null +++ b/pkg/pmapi/auth_server_test.go @@ -0,0 +1,122 @@ +// Copyright (c) 2022 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/stretchr/testify/require" +) + +type testRefreshResponse struct { + Code int + AccessToken string + ExpiresIn int + TokenType string + Scope string + Scopes []string + UID string + RefreshToken string + LocalID int + + r *require.Assertions +} + +var tokenID = 0 + +func newTestRefreshToken(r *require.Assertions) testRefreshResponse { + tokenID++ + scopes := []string{ + "full", + "self", + "parent", + "user", + "loggedin", + "paid", + "nondelinquent", + "mail", + "verified", + } + return testRefreshResponse{ + Code: 1000, + AccessToken: fmt.Sprintf("acc%d", tokenID), + ExpiresIn: 3600, + TokenType: "Bearer", + Scope: strings.Join(scopes, " "), + Scopes: scopes, + UID: fmt.Sprintf("uid%d", tokenID), + RefreshToken: fmt.Sprintf("ref%d", tokenID), + r: r, + } +} + +func (r *testRefreshResponse) isCorrectRefreshToken(body io.ReadCloser) int { + request := authRefreshReq{} + err := json.NewDecoder(body).Decode(&request) + r.r.NoError(body.Close()) + r.r.NoError(err) + + if r.UID != request.UID { + return http.StatusUnprocessableEntity + } + if r.RefreshToken != request.RefreshToken { + return http.StatusBadRequest + } + return http.StatusOK +} + +func (r *testRefreshResponse) handleAuthRefresh(response http.ResponseWriter, request *http.Request) { + if code := r.isCorrectRefreshToken(request.Body); code != http.StatusOK { + response.WriteHeader(code) + return + } + + tokenID++ + r.AccessToken = fmt.Sprintf("acc%d", tokenID) + r.RefreshToken = fmt.Sprintf("ref%d", tokenID) + + response.Header().Set("Content-Type", "application/json") + response.WriteHeader(http.StatusOK) + r.r.NoError(json.NewEncoder(response).Encode(r)) +} + +func (r *testRefreshResponse) wantAuthRefresh() AuthRefresh { + return AuthRefresh{ + UID: r.UID, + AccessToken: r.AccessToken, + RefreshToken: r.RefreshToken, + ExpiresIn: int64(r.ExpiresIn), + Scopes: r.Scopes, + } +} + +func (r *testRefreshResponse) isAuthorized(header http.Header) bool { + return header.Get("x-pm-uid") == r.UID && header.Get("Authorization") == "Bearer "+r.AccessToken +} + +func (r *testRefreshResponse) handleAuthCheckOnly(response http.ResponseWriter, request *http.Request) { + if r.isAuthorized(request.Header) { + response.WriteHeader(http.StatusOK) + } else { + response.WriteHeader(http.StatusUnauthorized) + } +} diff --git a/pkg/pmapi/auth_test.go b/pkg/pmapi/auth_test.go index 9ce123fb..fbcad362 100644 --- a/pkg/pmapi/auth_test.go +++ b/pkg/pmapi/auth_test.go @@ -25,179 +25,203 @@ import ( "testing" "time" - a "github.com/stretchr/testify/assert" - r "github.com/stretchr/testify/require" + "github.com/stretchr/testify/require" ) func TestAutomaticAuthRefresh(t *testing.T) { - var wantAuthRefresh = &AuthRefresh{ - UID: "testUID", - AccessToken: "testAcc", - RefreshToken: "testRef", - ExpiresIn: 100, - } - + r := require.New(t) mux := http.NewServeMux() - mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + currentTokens := newTestRefreshToken(r) + testUID := currentTokens.UID + testAcc := currentTokens.AccessToken + testRef := currentTokens.RefreshToken + currentTokens.ExpiresIn = 100 - if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil { - panic(err) - } - }) - - mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) + mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh) + mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly) ts := httptest.NewServer(mux) var gotAuthRefresh *AuthRefresh c := New(Config{HostURL: ts.URL}). - NewClient("uid", "acc", "ref", time.Now().Add(-time.Second)) + NewClient(testUID, testAcc, testRef, time.Now().Add(-time.Second)) c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth }) // Make a request with an access token that already expired one second ago. _, err := c.GetAddresses(context.Background()) - r.NoError(t, err) + r.NoError(err) + + wantAuthRefresh := currentTokens.wantAuthRefresh() // The auth callback should have been called. - a.Equal(t, *wantAuthRefresh, *gotAuthRefresh) + r.NotNil(gotAuthRefresh) + r.Equal(wantAuthRefresh, *gotAuthRefresh) cl := c.(*client) //nolint[forcetypeassert] we want to panic here - a.Equal(t, wantAuthRefresh.AccessToken, cl.acc) - a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref) - a.WithinDuration(t, expiresIn(100), cl.exp, time.Second) + r.Equal(wantAuthRefresh.AccessToken, cl.acc) + r.Equal(wantAuthRefresh.RefreshToken, cl.ref) + r.WithinDuration(expiresIn(100), cl.exp, time.Second) } func Test401AuthRefresh(t *testing.T) { - var wantAuthRefresh = &AuthRefresh{ - UID: "testUID", - AccessToken: "testAcc", - RefreshToken: "testRef", - } + r := require.New(t) + currentTokens := newTestRefreshToken(r) + testUID := currentTokens.UID + testRef := currentTokens.RefreshToken mux := http.NewServeMux() - mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil { - panic(err) - } - }) - - var call int - - mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) { - call++ - - if call == 1 { - w.WriteHeader(http.StatusUnauthorized) - } else { - w.WriteHeader(http.StatusOK) - } - }) + mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh) + mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly) ts := httptest.NewServer(mux) - var gotAuthRefresh *AuthRefresh // Create a new client. - c := New(Config{HostURL: ts.URL}). - NewClient("uid", "acc", "ref", time.Now().Add(time.Hour)) + m := New(Config{HostURL: ts.URL}) + c := m.NewClient(testUID, "oldAccToken", testRef, time.Now().Add(time.Hour)) // Register an auth handler. c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth }) // The first request will fail with 401, triggering a refresh and retry. _, err := c.GetAddresses(context.Background()) - r.NoError(t, err) + r.NoError(err) // The auth callback should have been called. - r.Equal(t, *wantAuthRefresh, *gotAuthRefresh) + r.NotNil(gotAuthRefresh) + r.Equal(currentTokens.wantAuthRefresh(), *gotAuthRefresh) } func Test401RevokedAuth(t *testing.T) { + r := require.New(t) + currentTokens := newTestRefreshToken(r) + mux := http.NewServeMux() - mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - }) - - mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - }) + mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh) + mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly) ts := httptest.NewServer(mux) c := New(Config{HostURL: ts.URL}). - NewClient("uid", "acc", "ref", time.Now().Add(time.Hour)) + NewClient("badUID", "badAcc", "badRef", time.Now().Add(time.Hour)) // The request will fail with 401, triggering a refresh. // The retry will also fail with 401, returning an error. _, err := c.GetAddresses(context.Background()) - r.EqualError(t, err, ErrUnauthorized.Error()) + r.True(IsFailedAuth(err)) } -func Test401RevokedAuthTokenUpdate(t *testing.T) { - var oldAuth = &AuthRefresh{ - UID: "UID", - AccessToken: "oldAcc", - RefreshToken: "oldRef", - ExpiresIn: 3600, - } - - var newAuth = &AuthRefresh{ - UID: "UID", - AccessToken: "newAcc", - RefreshToken: "newRef", - } +func Test401OldRefreshToken(t *testing.T) { + r := require.New(t) + currentTokens := newTestRefreshToken(r) mux := http.NewServeMux() - mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh) + mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly) + + ts := httptest.NewServer(mux) + + c := New(Config{HostURL: ts.URL}). + NewClient(currentTokens.UID, "oldAcc", "oldRef", time.Now().Add(time.Hour)) + + // The request will fail with 401, triggering a refresh. + // The retry will also fail with 401, returning an error. + _, err := c.GetAddresses(context.Background()) + r.True(IsFailedAuth(err)) +} + +func Test401NoAccessToken(t *testing.T) { + r := require.New(t) + currentTokens := newTestRefreshToken(r) + testUID := currentTokens.UID + testRef := currentTokens.RefreshToken + mux := http.NewServeMux() + mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh) + mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly) + + ts := httptest.NewServer(mux) + + c := New(Config{HostURL: ts.URL}). + NewClient(testUID, "", testRef, time.Now().Add(time.Hour)) + + // The request will fail with 401, triggering a refresh. After the refresh it should succeed. + _, err := c.GetAddresses(context.Background()) + r.NoError(err) +} + +func Test401ExpiredAuthUpdateUser(t *testing.T) { + r := require.New(t) + mux := http.NewServeMux() + currentTokens := newTestRefreshToken(r) + testUID := currentTokens.UID + testRef := currentTokens.RefreshToken + + mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh) + + mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) { + if !currentTokens.isAuthorized(r.Header) { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(newAuth); err != nil { + w.WriteHeader(http.StatusOK) + respObj := struct { + Code int + User *User + }{ + Code: 1000, + User: &User{ + ID: "MJLke8kWh1BBvG95JBIrZvzpgsZ94hNNgjNHVyhXMiv4g9cn6SgvqiIFR5cigpml2LD_iUk_3DkV29oojTt3eA==", + Name: "jason", + UsedSpace: &usedSpace, + }, + } + if err := json.NewEncoder(w).Encode(respObj); err != nil { panic(err) } }) mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") == ("Bearer " + oldAuth.AccessToken) { + if !currentTokens.isAuthorized(r.Header) { w.WriteHeader(http.StatusUnauthorized) return } - if r.Header.Get("Authorization") == ("Bearer " + newAuth.AccessToken) { - w.WriteHeader(http.StatusOK) - return + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(testAddressList); err != nil { + panic(err) } }) ts := httptest.NewServer(mux) - - c := New(Config{HostURL: ts.URL}). - NewClient(oldAuth.UID, oldAuth.AccessToken, oldAuth.RefreshToken, time.Now().Add(time.Hour)) + m := New(Config{HostURL: ts.URL}) + c, _, err := m.NewClientWithRefresh(context.Background(), testUID, testRef) + r.NoError(err) // The request will fail with 401, triggering a refresh. After the refresh it should succeed. - _, err := c.GetAddresses(context.Background()) - r.NoError(t, err) + _, err = c.UpdateUser(context.Background()) + r.NoError(err) } func TestAuth2FA(t *testing.T) { + r := require.New(t) twoFACode := "code" finish, c := newTestClientCallbacks(t, func(tb testing.TB, w http.ResponseWriter, req *http.Request) string { - r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa")) + r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa")) var twoFAreq auth2FAReq - r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq)) - r.Equal(t, twoFAreq.TwoFactorCode, twoFACode) + r.NoError(json.NewDecoder(req.Body).Decode(&twoFAreq)) + r.Equal(twoFAreq.TwoFactorCode, twoFACode) return "/auth/2fa/post_response.json" }, @@ -205,31 +229,33 @@ func TestAuth2FA(t *testing.T) { defer finish() err := c.Auth2FA(context.Background(), twoFACode) - r.NoError(t, err) + r.NoError(err) } func TestAuth2FA_Fail(t *testing.T) { + r := require.New(t) finish, c := newTestClientCallbacks(t, func(tb testing.TB, w http.ResponseWriter, req *http.Request) string { - r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa")) + r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa")) return "/auth/2fa/post_401_bad_password.json" }, ) defer finish() err := c.Auth2FA(context.Background(), "code") - r.Equal(t, ErrBad2FACode, err) + r.Equal(ErrBad2FACode, err) } func TestAuth2FA_Retry(t *testing.T) { + r := require.New(t) finish, c := newTestClientCallbacks(t, func(tb testing.TB, w http.ResponseWriter, req *http.Request) string { - r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa")) + r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa")) return "/auth/2fa/post_422_bad_password.json" }, ) defer finish() err := c.Auth2FA(context.Background(), "code") - r.Equal(t, ErrBad2FACodeTryAgain, err) + r.Equal(ErrBad2FACodeTryAgain, err) } diff --git a/pkg/pmapi/client_keys.go b/pkg/pmapi/client_keys.go index 64a3a625..b6844d0c 100644 --- a/pkg/pmapi/client_keys.go +++ b/pkg/pmapi/client_keys.go @@ -19,8 +19,6 @@ package pmapi import ( "context" - - "github.com/pkg/errors" ) // Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings. @@ -34,26 +32,26 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) { // unlock unlocks the user's keys but without locking the keyring lock first. // Should only be used internally by methods that first lock the lock. -func (c *client) unlock(ctx context.Context, passphrase []byte) (err error) { - if _, err = c.CurrentUser(ctx); err != nil { - return +func (c *client) unlock(ctx context.Context, passphrase []byte) error { + if _, err := c.CurrentUser(ctx); err != nil { + return err } if c.userKeyRing == nil { - if err = c.unlockUser(passphrase); err != nil { - return errors.Wrap(err, "failed to unlock user") + if err := c.unlockUser(passphrase); err != nil { + return ErrUnlockFailed{err} } } for _, address := range c.addresses { if c.addrKeyRing[address.ID] == nil { - if err = c.unlockAddress(passphrase, address); err != nil { - return errors.Wrap(err, "failed to unlock address") + if err := c.unlockAddress(passphrase, address); err != nil { + return ErrUnlockFailed{err} } } } - return + return nil } func (c *client) ReloadKeys(ctx context.Context, passphrase []byte) (err error) { diff --git a/pkg/pmapi/errors.go b/pkg/pmapi/errors.go index beb6873d..2819b20c 100644 --- a/pkg/pmapi/errors.go +++ b/pkg/pmapi/errors.go @@ -31,10 +31,58 @@ var ( ErrPasswordWrong = errors.New("wrong password") ) +// ErrUnprocessableEntity ... type ErrUnprocessableEntity struct { OriginalError error } +func IsUnprocessableEntity(err error) bool { + _, ok := err.(ErrUnprocessableEntity) + return ok +} + func (err ErrUnprocessableEntity) Error() string { return err.OriginalError.Error() } + +// ErrBadRequest ... +type ErrBadRequest struct { + OriginalError error +} + +func IsBadRequest(err error) bool { + _, ok := err.(ErrBadRequest) + return ok +} + +func (err ErrBadRequest) Error() string { + return err.OriginalError.Error() +} + +// ErrAuthFailed ... +type ErrAuthFailed struct { + OriginalError error +} + +func IsFailedAuth(err error) bool { + _, ok := err.(ErrAuthFailed) + return ok +} + +func (err ErrAuthFailed) Error() string { + return err.OriginalError.Error() +} + +// ErrUnlockFailed ... +type ErrUnlockFailed struct { + OriginalError error +} + +func IsFailedUnlock(err error) bool { + _, ok := err.(ErrUnlockFailed) + return ok +} + +func (err ErrUnlockFailed) Error() string { + return err.OriginalError.Error() +} diff --git a/pkg/pmapi/manager.go b/pkg/pmapi/manager.go index 6a94ec3c..5aa1a559 100644 --- a/pkg/pmapi/manager.go +++ b/pkg/pmapi/manager.go @@ -33,6 +33,7 @@ type manager struct { isDown bool locker sync.Locker + refreshingAuth sync.Locker connectionObservers []ConnectionObserver proxyDialer *ProxyTLSDialer @@ -50,6 +51,7 @@ func newManager(cfg Config) *manager { cfg: cfg, rc: resty.New().EnableTrace(), locker: &sync.Mutex{}, + refreshingAuth: &sync.Mutex{}, pingMutex: &sync.RWMutex{}, isPinging: false, setSentryUserIDOnce: sync.Once{}, diff --git a/pkg/pmapi/manager_auth.go b/pkg/pmapi/manager_auth.go index 4ebf15f2..67c1d42d 100644 --- a/pkg/pmapi/manager_auth.go +++ b/pkg/pmapi/manager_auth.go @@ -102,6 +102,9 @@ func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) { } func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefresh, error) { + m.refreshingAuth.Lock() + defer m.refreshingAuth.Unlock() + var req = authRefreshReq{ UID: uid, RefreshToken: ref, @@ -117,6 +120,9 @@ func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefres _, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh")) if err != nil { + if IsBadRequest(err) || IsUnprocessableEntity(err) { + err = ErrAuthFailed{err} + } return nil, err } diff --git a/pkg/pmapi/response.go b/pkg/pmapi/response.go index ae683e1d..bc399569 100644 --- a/pkg/pmapi/response.go +++ b/pkg/pmapi/response.go @@ -59,16 +59,14 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error { if apiErr, ok := res.Error().(*Error); ok { switch { case apiErr.Code == errCodeUpgradeApplication: - err = ErrUpgradeApplication if m.cfg.UpgradeApplicationHandler != nil { m.cfg.UpgradeApplicationHandler() } + return ErrUpgradeApplication case apiErr.Code == errCodePasswordWrong: - err = ErrPasswordWrong + return ErrPasswordWrong case apiErr.Code == errCodeAuthPaidPlanRequired: - err = ErrPaidPlanRequired - case res.StatusCode() == http.StatusUnprocessableEntity: - err = ErrUnprocessableEntity{apiErr} + return ErrPaidPlanRequired default: err = apiErr } @@ -76,6 +74,13 @@ func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error { err = errors.New(res.Status()) } + switch res.StatusCode() { + case http.StatusUnprocessableEntity: + err = ErrUnprocessableEntity{err} + case http.StatusBadRequest: + err = ErrBadRequest{err} + } + return err } diff --git a/test/context/pmapi_controller.go b/test/context/pmapi_controller.go index e3ff1dfe..d941bdf7 100644 --- a/test/context/pmapi_controller.go +++ b/test/context/pmapi_controller.go @@ -46,6 +46,7 @@ type PMAPIController interface { LockEvents(username string) UnlockEvents(username string) RemoveUserMessageWithoutEvent(username, messageID string) error + RevokeSession(username string) error } func newPMAPIController(listener listener.Listener) (PMAPIController, pmapi.Manager) { diff --git a/test/fakeapi/controller_control.go b/test/fakeapi/controller_control.go index 1b7e35e5..c3345d60 100644 --- a/test/fakeapi/controller_control.go +++ b/test/fakeapi/controller_control.go @@ -250,3 +250,10 @@ func (ctl *Controller) RemoveUserMessageWithoutEvent(username string, messageID return errors.New("message not found") } + +func (ctl *Controller) RevokeSession(username string) error { + for _, session := range ctl.sessionsByUID { + session.uid = "revoked" + } + return nil +} diff --git a/test/fakeapi/controller_session.go b/test/fakeapi/controller_session.go index 7c0cb1aa..a4a0a5f3 100644 --- a/test/fakeapi/controller_session.go +++ b/test/fakeapi/controller_session.go @@ -74,12 +74,12 @@ func (ctl *Controller) createSession(username string, hasFullScope bool) *fakeSe func (ctl *Controller) refreshSessionIfAuthorized(uid, ref string) (*fakeSession, error) { session, ok := ctl.sessionsByUID[uid] - if !ok { - return nil, pmapi.ErrUnauthorized + if !ok || session.uid != uid { + return nil, pmapi.ErrAuthFailed{OriginalError: errors.New("bad uid")} } if ref != session.ref { - return nil, pmapi.ErrUnauthorized + return nil, pmapi.ErrAuthFailed{OriginalError: errors.New("bad refresh token")} } session.ref = ctl.tokenGenerator.next("ref") diff --git a/test/fakeapi/fakeapi.go b/test/fakeapi/fakeapi.go index 78fe06d6..771085d7 100644 --- a/test/fakeapi/fakeapi.go +++ b/test/fakeapi/fakeapi.go @@ -133,14 +133,32 @@ func (api *FakePMAPI) authRefresh() error { session, err := api.controller.refreshSessionIfAuthorized(api.uid, api.ref) if err != nil { + if pmapi.IsFailedAuth(err) { + go api.handleAuth(nil) + } return err } api.ref = session.ref api.acc = session.acc + + go api.handleAuth(&pmapi.AuthRefresh{ + UID: api.uid, + AccessToken: api.acc, + RefreshToken: api.ref, + ExpiresIn: 7200, + Scopes: []string{"full", "self", "user", "mail"}, + }) + return nil } +func (api *FakePMAPI) handleAuth(auth *pmapi.AuthRefresh) { + for _, handle := range api.authHandlers { + handle(auth) + } +} + func (api *FakePMAPI) setUser(username string) error { api.username = username api.log = api.log.WithField("username", username) diff --git a/test/fakeapi/manager.go b/test/fakeapi/manager.go index 37493564..cd0a8692 100644 --- a/test/fakeapi/manager.go +++ b/test/fakeapi/manager.go @@ -69,7 +69,7 @@ func (m *fakePMAPIManager) NewClientWithRefresh(_ context.Context, uid, ref stri session, err := m.controller.refreshSessionIfAuthorized(uid, ref) if err != nil { - return nil, nil, pmapi.ErrUnauthorized + return nil, nil, err } user, ok := m.controller.usersByUsername[session.username] diff --git a/test/fakeapi/user.go b/test/fakeapi/user.go index 992ebbd8..e7b71e9f 100644 --- a/test/fakeapi/user.go +++ b/test/fakeapi/user.go @@ -82,6 +82,10 @@ func (api *FakePMAPI) UpdateUser(context.Context) (*pmapi.User, error) { return nil, err } + if err := api.checkAndRecordCall(GET, "/addresses", nil); err != nil { + return nil, err + } + return api.user, nil } diff --git a/test/features/users/revoked_session.feature b/test/features/users/revoked_session.feature new file mode 100644 index 00000000..fa46c8c3 --- /dev/null +++ b/test/features/users/revoked_session.feature @@ -0,0 +1,17 @@ +Feature: Session deleted on API + + @ignore-live + Scenario: Session revoked after start + Given there is connected user "user" + When session was revoked for "user" + And the event loop of "user" loops once + Then "user" is disconnected + + + @ignore-live + Scenario: Starting with revoked session + Given there is user "user" which just logged in + And session was revoked for "user" + When bridge starts + Then "user" is disconnected + diff --git a/test/liveapi/users.go b/test/liveapi/users.go index 63bba89a..79fc03e5 100644 --- a/test/liveapi/users.go +++ b/test/liveapi/users.go @@ -60,3 +60,7 @@ func (ctl *Controller) GetAuthClient(username string) pmapi.Client { } return client } + +func (ctl *Controller) RevokeSession(username string) error { + return errors.New("revoke live session not implemented") +} diff --git a/test/users_actions_test.go b/test/users_actions_test.go index 638fa8d4..b0d7905d 100644 --- a/test/users_actions_test.go +++ b/test/users_actions_test.go @@ -29,6 +29,7 @@ func UsersActionsFeatureContext(s *godog.ScenarioContext) { s.Step(`^user deletes "([^"]*)"$`, userDeletesUser) s.Step(`^user deletes "([^"]*)" with cache$`, userDeletesUserWithCache) s.Step(`^"([^"]*)" swaps address "([^"]*)" with address "([^"]*)"$`, swapsAddressWithAddress) + s.Step(`^session was revoked for "([^"]*)"$`, sessionRevoked) } func userLogsIn(bddUserID string) error { @@ -123,3 +124,8 @@ func swapsAddressWithAddress(bddUserID, bddAddressID1, bddAddressID2 string) err return ctx.GetPMAPIController().ReorderAddresses(account.User(), addressIDs) } + +func sessionRevoked(bddUserID string) error { + account := ctx.GetTestAccount(bddUserID) + return ctx.GetPMAPIController().RevokeSession(account.Username()) +}