forked from Silverfish/proton-bridge
GODT-2100: Load users in parallel at startup
This commit is contained in:
@ -25,6 +25,7 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon"
|
||||
@ -192,7 +193,7 @@ func getGluonDir(encVault *vault.Vault) (string, error) {
|
||||
}
|
||||
|
||||
if empty {
|
||||
if err := encVault.ForUser(func(user *vault.User) error {
|
||||
if err := encVault.ForUser(runtime.NumCPU(), func(user *vault.User) error {
|
||||
return user.ClearSyncStatus()
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("failed to reset user sync status: %w", err)
|
||||
|
||||
@ -20,6 +20,7 @@ package bridge
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/async"
|
||||
@ -319,7 +320,7 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
|
||||
|
||||
// loadUsers tries to load each user in the vault that isn't already loaded.
|
||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
||||
return bridge.vault.ForUser(func(user *vault.User) error {
|
||||
return bridge.vault.ForUser(runtime.NumCPU(), func(user *vault.User) error {
|
||||
if user.AuthUID() == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package vault_test
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
@ -154,7 +155,7 @@ func TestUser_ForEach(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Iterate through the users.
|
||||
err = s.ForUser(func(user *vault.User) error {
|
||||
err = s.ForUser(runtime.NumCPU(), func(user *vault.User) error {
|
||||
switch user.UserID() {
|
||||
case "userID1":
|
||||
require.Equal(t, "username1", user.Username())
|
||||
@ -194,99 +195,3 @@ func TestUser_ForEach(t *testing.T) {
|
||||
// The store should have no users again.
|
||||
require.Empty(t, s.GetUserIDs())
|
||||
}
|
||||
|
||||
/*
|
||||
func TestUser(t *testing.T) {
|
||||
// Replace the token generator with a dummy one.
|
||||
vault.RandomToken = func(size int) ([]byte, error) {
|
||||
return []byte("token"), nil
|
||||
}
|
||||
|
||||
// create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// Set auth information for user 1 and 2.
|
||||
user1, err := s.AddUser("userID1", "user1", "authUID1", "authRef1", []byte("keyPass1"))
|
||||
require.NoError(t, err)
|
||||
user2, err := s.AddUser("userID2", "user2", "authUID2", "authRef2", []byte("keyPass2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set event IDs for user 1 and 2.
|
||||
require.NoError(t, user1.SetEventID("eventID1"))
|
||||
require.NoError(t, user2.SetEventID("eventID2"))
|
||||
|
||||
// Set sync state for user 1 and 2.
|
||||
require.NoError(t, user1.SetSync(true))
|
||||
require.NoError(t, user2.SetSync(false))
|
||||
|
||||
// Set gluon data for user 1 and 2.
|
||||
require.NoError(t, user1.SetGluonID("addrID1", "gluonID1"))
|
||||
require.NoError(t, user2.SetGluonID("addrID2", "gluonID2"))
|
||||
require.NoError(t, user1.SetUIDValidity("addrID1", imap.UID(1)))
|
||||
require.NoError(t, user2.SetUIDValidity("addrID2", imap.UID(2)))
|
||||
|
||||
// List available users.
|
||||
require.ElementsMatch(t, []string{"userID1", "userID2"}, s.GetUserIDs())
|
||||
|
||||
// Check gluon information for user 1.
|
||||
gluonID1, ok := user1.GetGluonIDs()["addrID1"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gluonID1", gluonID1)
|
||||
uidValidity1, ok := user1.GetUIDValidity("addrID1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, imap.UID(1), uidValidity1)
|
||||
require.NotEmpty(t, user1.GluonKey())
|
||||
|
||||
// Get auth information for user 1.
|
||||
require.Equal(t, "userID1", user1.UserID())
|
||||
require.Equal(t, "user1", user1.Username())
|
||||
require.Equal(t, hex.EncodeToString([]byte("token")), string(user1.BridgePass()))
|
||||
require.Equal(t, vault.CombinedMode, user1.AddressMode())
|
||||
require.Equal(t, "authUID1", user1.AuthUID())
|
||||
require.Equal(t, "authRef1", user1.AuthRef())
|
||||
require.Equal(t, []byte("keyPass1"), user1.KeyPass())
|
||||
require.Equal(t, "eventID1", user1.EventID())
|
||||
require.Equal(t, true, user1.HasSync())
|
||||
|
||||
// Check gluon information for user 1.
|
||||
gluonID2, ok := user2.GetGluonIDs()["addrID2"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gluonID2", gluonID2)
|
||||
uidValidity2, ok := user2.GetUIDValidity("addrID2")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, imap.UID(2), uidValidity2)
|
||||
require.NotEmpty(t, user2.GluonKey())
|
||||
|
||||
// Get auth information for user 2.
|
||||
require.Equal(t, "userID2", user2.UserID())
|
||||
require.Equal(t, "user2", user2.Username())
|
||||
require.Equal(t, hex.EncodeToString([]byte("token")), string(user2.BridgePass()))
|
||||
require.Equal(t, vault.CombinedMode, user2.AddressMode())
|
||||
require.Equal(t, "authUID2", user2.AuthUID())
|
||||
require.Equal(t, "authRef2", user2.AuthRef())
|
||||
require.Equal(t, []byte("keyPass2"), user2.KeyPass())
|
||||
require.Equal(t, "eventID2", user2.EventID())
|
||||
require.Equal(t, false, user2.HasSync())
|
||||
|
||||
// Clear the users.
|
||||
require.NoError(t, s.ClearUser("userID1"))
|
||||
require.NoError(t, s.ClearUser("userID2"))
|
||||
|
||||
// Their secrets should now be cleared.
|
||||
require.Equal(t, "", user1.AuthUID())
|
||||
require.Equal(t, "", user1.AuthRef())
|
||||
require.Empty(t, user1.KeyPass())
|
||||
|
||||
// Get auth information for user 2.
|
||||
require.Equal(t, "", user2.AuthUID())
|
||||
require.Equal(t, "", user2.AuthRef())
|
||||
require.Empty(t, user2.KeyPass())
|
||||
|
||||
// Delete auth information for user 1.
|
||||
require.NoError(t, s.DeleteUser("userID1"))
|
||||
|
||||
// List available userIDs. User 1 should be gone.
|
||||
require.ElementsMatch(t, []string{"userID2"}, s.GetUserIDs())
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha256"
|
||||
@ -28,6 +29,7 @@ import (
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@ -109,20 +111,18 @@ func (vault *Vault) NewUser(userID string) (*User, error) {
|
||||
}
|
||||
|
||||
// ForUser executes a callback for each user in the vault.
|
||||
func (vault *Vault) ForUser(fn func(*User) error) error {
|
||||
for _, userID := range vault.GetUserIDs() {
|
||||
user, err := vault.NewUser(userID)
|
||||
func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
|
||||
userIDs := vault.GetUserIDs()
|
||||
|
||||
return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error {
|
||||
user, err := vault.NewUser(userIDs[idx])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = user.Close() }()
|
||||
|
||||
if err := fn(user); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return fn(user)
|
||||
})
|
||||
}
|
||||
|
||||
// AddUser creates a new user in the vault with the given ID and username.
|
||||
|
||||
@ -19,6 +19,7 @@ package vault_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
@ -49,7 +50,7 @@ func BenchmarkVault(b *testing.B) {
|
||||
|
||||
// Time how quickly we can iterate through the users and get their key pass and bridge pass.
|
||||
for i := 0; i < b.N; i++ {
|
||||
require.NoError(b, s.ForUser(func(user *vault.User) error {
|
||||
require.NoError(b, s.ForUser(runtime.NumCPU(), func(user *vault.User) error {
|
||||
require.NotEmpty(b, user.KeyPass())
|
||||
require.NotEmpty(b, user.BridgePass())
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user