fix(BRIDGE-108): fixed GetInitials when empty username is passed; we now overwrite the username in the vault if its value changed, each time we refresh user auth

This commit is contained in:
Atanas Janeshliev
2024-09-05 08:14:30 +00:00
parent 99e6f00aaa
commit bfe67f3005
4 changed files with 87 additions and 0 deletions

View File

@ -34,6 +34,11 @@ var (
// getInitials based on webapp implementation:
// https://github.com/ProtonMail/WebClients/blob/55d96a8b4afaaa4372fc5f1ef34953f2070fd7ec/packages/shared/lib/helpers/string.ts#L145
func getInitials(fullName string) string {
fullName = strings.TrimSpace(fullName)
if fullName == "" {
return "?"
}
words := strings.Split(
reMultiSpaces.ReplaceAllString(fullName, " "),
" ",

View File

@ -0,0 +1,45 @@
// Copyright (c) 2024 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package grpc
import (
"testing"
"github.com/stretchr/testify/require"
)
func Test_GetInitials(t *testing.T) {
require.Equal(t, "?", getInitials(""))
require.Equal(t, "T", getInitials(" test"))
require.Equal(t, "T", getInitials("test "))
require.Equal(t, "T", getInitials(" test "))
require.Equal(t, "JD", getInitials(" John Doe "))
require.Equal(t, "J", getInitials(" JohnDoe@proton.me "))
require.Equal(t, "JD", getInitials("\t\r\n John Doe \t\r\n "))
require.Equal(t, "T", getInitials("TestTestman"))
require.Equal(t, "TT", getInitials("Test Testman"))
require.Equal(t, "J", getInitials("JamesJoyce"))
require.Equal(t, "J", getInitials("JamesJoyceJeremy"))
require.Equal(t, "J", getInitials("james.joyce"))
require.Equal(t, "JJ", getInitials("James Joyce"))
require.Equal(t, "JM", getInitials("James Joyce Mahabharata"))
require.Equal(t, "JL", getInitials("James Joyce Jeremy Lin"))
require.Equal(t, "JM", getInitials("Jean Michel"))
require.Equal(t, "GC", getInitials("George Michael Carrie"))
}

View File

@ -19,6 +19,7 @@ package vault
import (
"fmt"
"strings"
"github.com/bradenaw/juniper/xslices"
"golang.org/x/exp/slices"
@ -37,6 +38,10 @@ func (user *User) Username() string {
return user.vault.getUser(user.userID).Username
}
func (user *User) usernameUnsafe() string {
return user.vault.getUserUnsafe(user.userID).Username
}
// PrimaryEmail returns the user's primary email address.
func (user *User) PrimaryEmail() string {
return user.vault.getUser(user.userID).PrimaryEmail
@ -242,3 +247,15 @@ func (user *User) SetShouldSync(shouldResync bool) error {
func (user *User) GetShouldResync() bool {
return user.vault.getUser(user.userID).ShouldResync
}
// updateUsernameUnsafe - updates the username of the relevant user, provided that the new username is not empty
// and differs from the previous. Writes are not performed if this case is not met.
// Should only be called from contexts where the vault mutex is already locked.
func (user *User) updateUsernameUnsafe(username string) error {
if strings.TrimSpace(username) == "" || user.usernameUnsafe() == username {
return nil
}
return user.vault.modUserUnsafe(user.userID, func(userData *UserData) {
userData.Username = username
})
}

View File

@ -240,6 +240,10 @@ func (vault *Vault) GetOrAddUser(userID, username, primaryEmail, authUID, authRe
return nil, false, err
}
if err := user.updateUsernameUnsafe(username); err != nil {
return nil, false, err
}
return user, false, nil
}
}
@ -450,6 +454,22 @@ func (vault *Vault) getUser(userID string) UserData {
return users[idx]
}
// getUserUnsafe - fetches the relevant UserData.
// Should only be called from contexts in which the vault mutex has been read locked.
func (vault *Vault) getUserUnsafe(userID string) UserData {
users := vault.getUnsafe().Users
idx := xslices.IndexFunc(users, func(user UserData) bool {
return user.UserID == userID
})
if idx < 0 {
panic("Unknown user")
}
return users[idx]
}
func (vault *Vault) modUser(userID string, fn func(userData *UserData)) error {
vault.lock.Lock()
defer vault.lock.Unlock()