From 2450511555f2faffe64811ec9400ea7f738c4b54 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Wed, 12 Oct 2022 12:44:06 +0200 Subject: [PATCH] Other: Put back split login process in backend --- internal/bridge/bridge_test.go | 14 +-- internal/bridge/settings_test.go | 2 +- internal/bridge/sync_test.go | 4 +- internal/bridge/user.go | 100 ++++++++++++++-------- internal/bridge/user_test.go | 38 ++++---- internal/frontend/cli/accounts.go | 44 +++++++--- internal/frontend/grpc/service_methods.go | 2 +- tests/user_test.go | 2 +- 8 files changed, 126 insertions(+), 80 deletions(-) diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 68b96adc..f8ad5517 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -55,7 +55,7 @@ func TestBridge_ConnStatus(t *testing.T) { netCtl.Disable() // Trigger some operation that will fail due to the network disconnect. - _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + _, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.Error(t, err) // Wait for the event. @@ -65,7 +65,7 @@ func TestBridge_ConnStatus(t *testing.T) { netCtl.Enable() // Trigger some operation that will succeed due to the network reconnect. - userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + userID, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.NoError(t, err) require.NotEmpty(t, userID) @@ -125,7 +125,7 @@ func TestBridge_UserAgent(t *testing.T) { require.Contains(t, bridge.GetCurrentUserAgent(), "platform") // Login the user. - _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + _, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.NoError(t, err) // Assert that the user agent was sent to the API. @@ -150,7 +150,7 @@ func TestBridge_Cookies(t *testing.T) { // Start bridge and add a user so that API assigns us a session ID via cookie. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + _, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.NoError(t, err) }) @@ -273,7 +273,7 @@ func TestBridge_ForceUpdate(t *testing.T) { s.SetMinAppVersion(v2_4_0) // Try to login the user. It will fail because the bridge is too old. - _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + _, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.Error(t, err) // We should get an update required event. @@ -288,7 +288,7 @@ func TestBridge_BadVaultKey(t *testing.T) { // Login a user. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + newUserID, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.NoError(t, err) userID = newUserID @@ -316,7 +316,7 @@ func TestBridge_MissingGluonDir(t *testing.T) { var gluonDir string withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + _, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.NoError(t, err) // Move the gluon dir. diff --git a/internal/bridge/settings_test.go b/internal/bridge/settings_test.go index 28f06858..315dc13a 100644 --- a/internal/bridge/settings_test.go +++ b/internal/bridge/settings_test.go @@ -15,7 +15,7 @@ func TestBridge_Settings_GluonDir(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Create a user. - _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) + _, err := bridge.LoginFull(context.Background(), username, password, nil, nil) require.NoError(t, err) // Create a new location for the Gluon data. diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index 3f02a48f..9c764b50 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -48,7 +48,7 @@ func TestBridge_Sync(t *testing.T) { syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) defer done() - userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil) + userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) require.NoError(t, err) require.Equal(t, userID, (<-syncCh).UserID) @@ -83,7 +83,7 @@ func TestBridge_Sync(t *testing.T) { syncCh, done := chToType[events.Event, events.SyncFailed](bridge.GetEvents(events.SyncFailed{})) defer done() - userID, err := bridge.LoginUser(ctx, "imap", password, nil, nil) + userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil) require.NoError(t, err) require.Equal(t, userID, (<-syncCh).UserID) diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 4481364d..19431a3c 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -72,50 +72,33 @@ func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) { return UserInfo{}, ErrNoSuchUser } -// LoginUser authorizes a new bridge user with the given username and password. -// If necessary, a TOTP and mailbox password are requested via the callbacks. -func (bridge *Bridge) LoginUser( - ctx context.Context, - username string, - password []byte, - getTOTP func() (string, error), - getKeyPass func() ([]byte, error), -) (string, error) { +// LoginAuth begins the login process. It returns an authorized client that might need 2FA. +func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password []byte) (*liteapi.Client, liteapi.Auth, error) { client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password) if err != nil { - return "", fmt.Errorf("failed to create new API client: %w", err) + return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err) } + if _, ok := bridge.users[auth.UserID]; ok { + if err := client.AuthDelete(ctx); err != nil { + logrus.WithError(err).Warn("Failed to delete auth") + } + + return nil, liteapi.Auth{}, ErrUserAlreadyLoggedIn + } + + return client, auth, nil +} + +// LoginUser finishes the user login process using the client and auth received from LoginAuth. +func (bridge *Bridge) LoginUser( + ctx context.Context, + client *liteapi.Client, + auth liteapi.Auth, + keyPass []byte, +) (string, error) { userID, err := try.CatchVal( func() (string, error) { - if _, ok := bridge.users[auth.UserID]; ok { - return "", ErrUserAlreadyLoggedIn - } - - if auth.TwoFA.Enabled == liteapi.TOTPEnabled { - totp, err := getTOTP() - if err != nil { - return "", fmt.Errorf("failed to get TOTP: %w", err) - } - - if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil { - return "", fmt.Errorf("failed to authorize 2FA: %w", err) - } - } - - var keyPass []byte - - if auth.PasswordMode == liteapi.TwoPasswordMode { - userKeyPass, err := getKeyPass() - if err != nil { - return "", fmt.Errorf("failed to get key password: %w", err) - } - - keyPass = userKeyPass - } else { - keyPass = password - } - return bridge.loginUser(ctx, client, auth.UID, auth.RefreshToken, keyPass) }, func() error { @@ -137,6 +120,47 @@ func (bridge *Bridge) LoginUser( return userID, nil } +// LoginUser authorizes a new bridge user with the given username and password. +// If necessary, a TOTP and mailbox password are requested via the callbacks. +func (bridge *Bridge) LoginFull( + ctx context.Context, + username string, + password []byte, + getTOTP func() (string, error), + getKeyPass func() ([]byte, error), +) (string, error) { + client, auth, err := bridge.LoginAuth(ctx, username, password) + if err != nil { + return "", fmt.Errorf("failed to begin login process: %w", err) + } + + if auth.TwoFA.Enabled == liteapi.TOTPEnabled { + totp, err := getTOTP() + if err != nil { + return "", fmt.Errorf("failed to get TOTP: %w", err) + } + + if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil { + return "", fmt.Errorf("failed to authorize 2FA: %w", err) + } + } + + var keyPass []byte + + if auth.PasswordMode == liteapi.TwoPasswordMode { + userKeyPass, err := getKeyPass() + if err != nil { + return "", fmt.Errorf("failed to get key password: %w", err) + } + + keyPass = userKeyPass + } else { + keyPass = password + } + + return bridge.LoginUser(ctx, client, auth, keyPass) +} + // LogoutUser logs out the given user. func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error { if err := bridge.logoutUser(ctx, userID); err != nil { diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index e9bb79ce..46e8709c 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -31,7 +31,7 @@ func TestBridge_Login(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID, err := bridge.LoginUser(ctx, username, password, nil, nil) + userID, err := bridge.LoginFull(ctx, username, password, nil, nil) require.NoError(t, err) // The user is now connected. @@ -45,7 +45,7 @@ func TestBridge_LoginLogoutLogin(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) // The user is now connected. require.Equal(t, []string{userID}, bridge.GetUserIDs()) @@ -59,7 +59,7 @@ func TestBridge_LoginLogoutLogin(t *testing.T) { require.Empty(t, getConnectedUserIDs(t, bridge)) // Login the user again. - newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil)) require.Equal(t, userID, newUserID) // The user is connected again. @@ -73,7 +73,7 @@ func TestBridge_LoginDeleteLogin(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) // The user is now connected. require.Equal(t, []string{userID}, bridge.GetUserIDs()) @@ -87,7 +87,7 @@ func TestBridge_LoginDeleteLogin(t *testing.T) { require.Empty(t, getConnectedUserIDs(t, bridge)) // Login the user again. - newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil)) require.Equal(t, userID, newUserID) // The user is connected again. @@ -101,7 +101,7 @@ func TestBridge_LoginDeauthLogin(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) // Get a channel to receive the deauth event. eventCh, done := bridge.GetEvents(events.UserDeauth{}) @@ -119,7 +119,7 @@ func TestBridge_LoginDeauthLogin(t *testing.T) { require.IsType(t, events.UserDeauth{}, <-eventCh) // Login the user after the disconnection. - newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + newUserID := must(bridge.LoginFull(ctx, username, password, nil, nil)) require.Equal(t, userID, newUserID) // The user is connected again. @@ -137,7 +137,7 @@ func TestBridge_LoginExpireLogin(t *testing.T) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. Its auth will only be valid for a short time. - userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) // Wait until the auth expires. time.Sleep(authLife) @@ -154,7 +154,7 @@ func TestBridge_FailToLoad(t *testing.T) { // Login the user. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) }) // Deauth the user while bridge is stopped. @@ -174,7 +174,7 @@ func TestBridge_LoadWithoutInternet(t *testing.T) { // Login the user. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) }) // Simulate loss of internet connection. @@ -204,7 +204,7 @@ func TestBridge_LoginRestart(t *testing.T) { var userID string withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) }) withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { @@ -220,7 +220,7 @@ func TestBridge_LoginLogoutRestart(t *testing.T) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) // Logout the user. require.NoError(t, bridge.LogoutUser(ctx, userID)) @@ -240,7 +240,7 @@ func TestBridge_LoginDeleteRestart(t *testing.T) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) // Delete the user. require.NoError(t, bridge.DeleteUser(ctx, userID)) @@ -264,7 +264,7 @@ func TestBridge_FailLoginRecover(t *testing.T) { // Log the user in and record how much data was read. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userID := must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) require.NoError(t, bridge.LogoutUser(ctx, userID)) }) @@ -273,7 +273,7 @@ func TestBridge_FailLoginRecover(t *testing.T) { // We should fail to log the user in because we can't fully read its data. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - require.Error(t, getErr(bridge.LoginUser(ctx, username, password, nil, nil))) + require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) // There should be no users. require.Empty(t, bridge.GetUserIDs()) @@ -284,7 +284,7 @@ func TestBridge_FailLoginRecover(t *testing.T) { func TestBridge_FailLoadRecover(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - must(bridge.LoginUser(ctx, username, password, nil, nil)) + must(bridge.LoginFull(ctx, username, password, nil, nil)) }) var read uint64 @@ -318,7 +318,7 @@ func TestBridge_BridgePass(t *testing.T) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID = must(bridge.LoginUser(ctx, username, password, nil, nil)) + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) // Retrieve the bridge pass. pass = must(bridge.GetUserInfo(userID)).BridgePass @@ -327,7 +327,7 @@ func TestBridge_BridgePass(t *testing.T) { require.NoError(t, bridge.LogoutUser(ctx, userID)) // Log the user back in. - must(bridge.LoginUser(ctx, username, password, nil, nil)) + must(bridge.LoginFull(ctx, username, password, nil, nil)) // The bridge pass should be the same. require.Equal(t, pass, pass) @@ -348,7 +348,7 @@ func TestBridge_AddressMode(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Login the user. - userID, err := bridge.LoginUser(ctx, username, password, nil, nil) + userID, err := bridge.LoginFull(ctx, username, password, nil, nil) require.NoError(t, err) // Get the user's info. diff --git a/internal/frontend/cli/accounts.go b/internal/frontend/cli/accounts.go index 43b22f32..08728a44 100644 --- a/internal/frontend/cli/accounts.go +++ b/internal/frontend/cli/accounts.go @@ -25,6 +25,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/abiosoft/ishell" + "gitlab.protontech.ch/go/liteapi" ) func (f *frontendCLI) listAccounts(c *ishell.Context) { @@ -126,17 +127,38 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { //nolint:funlen f.Println("Authenticating ... ") - userID, err := f.bridge.LoginUser( - context.Background(), - loginName, - []byte(password), - func() (string, error) { - return f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty), nil - }, - func() ([]byte, error) { - return []byte(f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)), nil - }, - ) + client, auth, err := f.bridge.LoginAuth(context.Background(), loginName, []byte(password)) + if err != nil { + f.printAndLogError("Cannot login: ", err) + return + } + + if auth.TwoFA.Enabled == liteapi.TOTPEnabled { + code := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty) + if code == "" { + f.printAndLogError("Cannot login: need two factor code") + return + } + + if err := client.Auth2FA(context.Background(), liteapi.Auth2FAReq{TwoFactorCode: code}); err != nil { + f.printAndLogError("Cannot login: ", err) + return + } + } + + var keyPass []byte + + if auth.PasswordMode == liteapi.TwoPasswordMode { + keyPass = []byte(f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)) + if len(keyPass) == 0 { + f.printAndLogError("Cannot login: need mailbox password") + return + } + } else { + keyPass = []byte(password) + } + + userID, err := f.bridge.LoginUser(context.Background(), client, auth, keyPass) if err != nil { f.processAPIError(err) return diff --git a/internal/frontend/grpc/service_methods.go b/internal/frontend/grpc/service_methods.go index 6de3c4b4..ddb0d450 100644 --- a/internal/frontend/grpc/service_methods.go +++ b/internal/frontend/grpc/service_methods.go @@ -368,7 +368,7 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt // - bad credentials // - bad proton plan // - user already exists - userID, err := s.bridge.LoginUser(context.Background(), login.Username, password, nil, nil) + userID, err := s.bridge.LoginFull(context.Background(), login.Username, password, nil, nil) if err != nil { s.log.WithError(err).Error("Cannot login user") _ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot login user")) diff --git a/tests/user_test.go b/tests/user_test.go index cc80f414..b5bcdfcb 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -156,7 +156,7 @@ func (s *scenario) theAddressOfAccountHasMessagesInMailbox(address, username str } func (s *scenario) userLogsInWithUsernameAndPassword(username, password string) error { - userID, err := s.t.bridge.LoginUser(context.Background(), username, []byte(password), nil, nil) + userID, err := s.t.bridge.LoginFull(context.Background(), username, []byte(password), nil, nil) if err != nil { s.t.pushError(err) } else {