diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 526850b1..3c8e15e4 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -584,29 +584,7 @@ func (bridge *Bridge) newVaultUser( authUID, authRef string, saltedKeyPass []byte, ) (*vault.User, bool, error) { - if !bridge.vault.HasUser(apiUser.ID) { - user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass) - if err != nil { - return nil, false, fmt.Errorf("failed to add user to vault: %w", err) - } - - return user, true, nil - } - - user, err := bridge.vault.NewUser(apiUser.ID) - if err != nil { - return nil, false, err - } - - if err := user.SetAuth(authUID, authRef); err != nil { - return nil, false, err - } - - if err := user.SetKeyPass(saltedKeyPass); err != nil { - return nil, false, err - } - - return user, false, nil + return bridge.vault.GetOrAddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass) } // logout logs out the given user, optionally logging them out from the API too. diff --git a/internal/vault/certs.go b/internal/vault/certs.go index 24be5f38..ee48e203 100644 --- a/internal/vault/certs.go +++ b/internal/vault/certs.go @@ -30,7 +30,12 @@ import ( // If CertPEMPath is set, it will attempt to read the certificate from the file. // Otherwise, or on read/validation failure, it will return the certificate from the vault. func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) { - if certPath, keyPath := vault.get().Certs.CustomCertPath, vault.get().Certs.CustomKeyPath; certPath != "" && keyPath != "" { + vault.lock.RLock() + defer vault.lock.RUnlock() + + certs := vault.getUnsafe().Certs + + if certPath, keyPath := certs.CustomCertPath, certs.CustomKeyPath; certPath != "" && keyPath != "" { if certPEM, keyPEM, err := readPEMCert(certPath, keyPath); err == nil { return certPEM, keyPEM } @@ -38,7 +43,7 @@ func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) { logrus.Error("Failed to read certificate from file, using default") } - return vault.get().Certs.Bridge.Cert, vault.get().Certs.Bridge.Key + return certs.Bridge.Cert, certs.Bridge.Key } // SetBridgeTLSCertPath sets the path to PEM-encoded certificates for the bridge. @@ -47,7 +52,7 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error { return fmt.Errorf("invalid certificate: %w", err) } - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Certs.CustomCertPath = certPath data.Certs.CustomKeyPath = keyPath }) @@ -55,18 +60,18 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error { // SetBridgeTLSCertKey sets the path to PEM-encoded certificates for the bridge. func (vault *Vault) SetBridgeTLSCertKey(cert, key []byte) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Certs.Bridge.Cert = cert data.Certs.Bridge.Key = key }) } func (vault *Vault) GetCertsInstalled() bool { - return vault.get().Certs.Installed + return vault.getSafe().Certs.Installed } func (vault *Vault) SetCertsInstalled(installed bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Certs.Installed = installed }) } diff --git a/internal/vault/cookies.go b/internal/vault/cookies.go index afbe6a57..96ccda66 100644 --- a/internal/vault/cookies.go +++ b/internal/vault/cookies.go @@ -18,11 +18,11 @@ package vault func (vault *Vault) GetCookies() ([]byte, error) { - return vault.get().Cookies, nil + return vault.getSafe().Cookies, nil } func (vault *Vault) SetCookies(cookies []byte) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Cookies = cookies }) } diff --git a/internal/vault/settings.go b/internal/vault/settings.go index f625825e..9a7815b3 100644 --- a/internal/vault/settings.go +++ b/internal/vault/settings.go @@ -33,72 +33,72 @@ const ( // GetIMAPPort sets the port that the IMAP server should listen on. func (vault *Vault) GetIMAPPort() int { - return vault.get().Settings.IMAPPort + return vault.getSafe().Settings.IMAPPort } // SetIMAPPort sets the port that the IMAP server should listen on. func (vault *Vault) SetIMAPPort(port int) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.IMAPPort = port }) } // GetSMTPPort sets the port that the SMTP server should listen on. func (vault *Vault) GetSMTPPort() int { - return vault.get().Settings.SMTPPort + return vault.getSafe().Settings.SMTPPort } // SetSMTPPort sets the port that the SMTP server should listen on. func (vault *Vault) SetSMTPPort(port int) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.SMTPPort = port }) } // GetIMAPSSL sets whether the IMAP server should use SSL. func (vault *Vault) GetIMAPSSL() bool { - return vault.get().Settings.IMAPSSL + return vault.getSafe().Settings.IMAPSSL } // SetIMAPSSL sets whether the IMAP server should use SSL. func (vault *Vault) SetIMAPSSL(ssl bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.IMAPSSL = ssl }) } // GetSMTPSSL sets whether the SMTP server should use SSL. func (vault *Vault) GetSMTPSSL() bool { - return vault.get().Settings.SMTPSSL + return vault.getSafe().Settings.SMTPSSL } // SetSMTPSSL sets whether the SMTP server should use SSL. func (vault *Vault) SetSMTPSSL(ssl bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.SMTPSSL = ssl }) } // GetGluonCacheDir sets the directory where the gluon should store its data. func (vault *Vault) GetGluonCacheDir() string { - return vault.get().Settings.GluonDir + return vault.getSafe().Settings.GluonDir } // SetGluonDir sets the directory where the gluon should store its data. func (vault *Vault) SetGluonDir(dir string) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.GluonDir = dir }) } // GetUpdateChannel sets the update channel. func (vault *Vault) GetUpdateChannel() updater.Channel { - return vault.get().Settings.UpdateChannel + return vault.getSafe().Settings.UpdateChannel } // SetUpdateChannel sets the update channel. func (vault *Vault) SetUpdateChannel(channel updater.Channel) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.UpdateChannel = channel }) } @@ -106,7 +106,7 @@ func (vault *Vault) SetUpdateChannel(channel updater.Channel) error { // GetUpdateRollout sets the update rollout. func (vault *Vault) GetUpdateRollout() float64 { // The rollout value 0.6046602879796196 is forbidden. The RNG was not seeded when it was picked (GODT-2319). - rollout := vault.get().Settings.UpdateRollout + rollout := vault.getSafe().Settings.UpdateRollout if math.Abs(rollout-ForbiddenRollout) >= 0.00000001 { return rollout } @@ -120,110 +120,110 @@ func (vault *Vault) GetUpdateRollout() float64 { // SetUpdateRollout sets the update rollout. func (vault *Vault) SetUpdateRollout(rollout float64) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.UpdateRollout = rollout }) } // GetColorScheme sets the color scheme to be used by the bridge GUI. func (vault *Vault) GetColorScheme() string { - return vault.get().Settings.ColorScheme + return vault.getSafe().Settings.ColorScheme } // SetColorScheme sets the color scheme to be used by the bridge GUI. func (vault *Vault) SetColorScheme(colorScheme string) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.ColorScheme = colorScheme }) } // GetProxyAllowed sets whether the bridge is allowed to use alternative routing. func (vault *Vault) GetProxyAllowed() bool { - return vault.get().Settings.ProxyAllowed + return vault.getSafe().Settings.ProxyAllowed } // SetProxyAllowed sets whether the bridge is allowed to use alternative routing. func (vault *Vault) SetProxyAllowed(allowed bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.ProxyAllowed = allowed }) } // GetShowAllMail sets whether the bridge should show the All Mail folder. func (vault *Vault) GetShowAllMail() bool { - return vault.get().Settings.ShowAllMail + return vault.getSafe().Settings.ShowAllMail } // SetShowAllMail sets whether the bridge should show the All Mail folder. func (vault *Vault) SetShowAllMail(showAllMail bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.ShowAllMail = showAllMail }) } // GetAutostart sets whether the bridge should autostart. func (vault *Vault) GetAutostart() bool { - return vault.get().Settings.Autostart + return vault.getSafe().Settings.Autostart } // SetAutostart sets whether the bridge should autostart. func (vault *Vault) SetAutostart(autostart bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.Autostart = autostart }) } // GetAutoUpdate sets whether the bridge should automatically update. func (vault *Vault) GetAutoUpdate() bool { - return vault.get().Settings.AutoUpdate + return vault.getSafe().Settings.AutoUpdate } // SetAutoUpdate sets whether the bridge should automatically update. func (vault *Vault) SetAutoUpdate(autoUpdate bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.AutoUpdate = autoUpdate }) } // GetTelemetryDisabled checks whether telemetry is disabled. func (vault *Vault) GetTelemetryDisabled() bool { - return vault.get().Settings.TelemetryDisabled + return vault.getSafe().Settings.TelemetryDisabled } // SetTelemetryDisabled sets whether telemetry is disabled. func (vault *Vault) SetTelemetryDisabled(telemetryDisabled bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.TelemetryDisabled = telemetryDisabled }) } // GetLastVersion returns the last version of the bridge that was run. func (vault *Vault) GetLastVersion() *semver.Version { - return semver.MustParse(vault.get().Settings.LastVersion) + return semver.MustParse(vault.getSafe().Settings.LastVersion) } // SetLastVersion sets the last version of the bridge that was run. func (vault *Vault) SetLastVersion(version *semver.Version) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.LastVersion = version.String() }) } // GetFirstStart returns whether this is the first time the bridge has been started. func (vault *Vault) GetFirstStart() bool { - return vault.get().Settings.FirstStart + return vault.getSafe().Settings.FirstStart } // SetFirstStart sets whether this is the first time the bridge has been started. func (vault *Vault) SetFirstStart(firstStart bool) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.FirstStart = firstStart }) } // GetMaxSyncMemory returns the maximum amount of memory the sync process should use. func (vault *Vault) GetMaxSyncMemory() uint64 { - v := vault.get().Settings.MaxSyncMemory + v := vault.getSafe().Settings.MaxSyncMemory // can be zero if never written to vault before. if v == 0 { return DefaultMaxSyncMemory @@ -234,14 +234,14 @@ func (vault *Vault) GetMaxSyncMemory() uint64 { // SetMaxSyncMemory sets the maximum amount of memory the sync process should use. func (vault *Vault) SetMaxSyncMemory(maxMemory uint64) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.MaxSyncMemory = maxMemory }) } // GetLastUserAgent returns the last user agent recorded by bridge. func (vault *Vault) GetLastUserAgent() string { - v := vault.get().Settings.LastUserAgent + v := vault.getSafe().Settings.LastUserAgent // Handle case where there may be no value. if len(v) == 0 { @@ -253,19 +253,19 @@ func (vault *Vault) GetLastUserAgent() string { // SetLastUserAgent store the last user agent recorded by bridge. func (vault *Vault) SetLastUserAgent(userAgent string) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.LastUserAgent = userAgent }) } // GetLastHeartbeatSent returns the last time heartbeat was sent. func (vault *Vault) GetLastHeartbeatSent() time.Time { - return vault.get().Settings.LastHeartbeatSent + return vault.getSafe().Settings.LastHeartbeatSent } // SetLastHeartbeatSent store the last time heartbeat was sent. func (vault *Vault) SetLastHeartbeatSent(timestamp time.Time) error { - return vault.mod(func(data *Data) { + return vault.modSafe(func(data *Data) { data.Settings.LastHeartbeatSent = timestamp }) } diff --git a/internal/vault/user.go b/internal/vault/user.go index f97358dc..1f997984 100644 --- a/internal/vault/user.go +++ b/internal/vault/user.go @@ -122,6 +122,14 @@ func (user *User) SetAuth(authUID, authRef string) error { }) } +func (user *User) setAuthAndKeyPassUnsafe(authUID, authRef string, keyPass []byte) error { + return user.vault.modUserUnsafe(user.userID, func(userData *UserData) { + userData.AuthRef = authRef + userData.AuthUID = authUID + userData.KeyPass = keyPass + }) +} + // KeyPass returns the user's (salted) key password. func (user *User) KeyPass() []byte { return user.vault.getUser(user.userID).KeyPass diff --git a/internal/vault/vault.go b/internal/vault/vault.go index e57c623f..e8cbc3b1 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -40,11 +40,11 @@ type Vault struct { path string gcm cipher.AEAD - enc []byte - encLock sync.RWMutex + enc []byte - ref map[string]int - refLock sync.Mutex + ref map[string]int + + lock sync.RWMutex panicHandler async.PanicHandler } @@ -79,14 +79,46 @@ func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHan // GetUserIDs returns the user IDs and usernames of all users in the vault. func (vault *Vault) GetUserIDs() []string { - return xslices.Map(vault.get().Users, func(user UserData) string { + vault.lock.RLock() + defer vault.lock.RUnlock() + + return xslices.Map(vault.getUnsafe().Users, func(user UserData) string { return user.UserID }) } +func (vault *Vault) getUsers() ([]*User, error) { + vault.lock.Lock() + defer vault.lock.Unlock() + + users := vault.getUnsafe().Users + + result := make([]*User, 0, len(users)) + + for _, user := range users { + u, err := vault.newUserUnsafe(user.UserID) + if err != nil { + for _, v := range result { + if err := v.Close(); err != nil { + logrus.WithError(err).Error("Fait to close user after failed get") + } + } + + return nil, err + } + + result = append(result, u) + } + + return result, nil +} + // HasUser returns true if the vault contains a user with the given ID. func (vault *Vault) HasUser(userID string) bool { - return xslices.IndexFunc(vault.get().Users, func(user UserData) bool { + vault.lock.RLock() + defer vault.lock.RUnlock() + + return xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool { return user.UserID == userID }) >= 0 } @@ -106,41 +138,61 @@ func (vault *Vault) GetUser(userID string, fn func(*User)) error { // NewUser returns a new vault user. It must be closed before it can be deleted. func (vault *Vault) NewUser(userID string) (*User, error) { - if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool { + vault.lock.Lock() + defer vault.lock.Unlock() + + return vault.newUserUnsafe(userID) +} + +func (vault *Vault) newUserUnsafe(userID string) (*User, error) { + if idx := xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool { return user.UserID == userID }); idx < 0 { return nil, errors.New("no such user") } - return vault.attachUser(userID), nil + return vault.attachUserUnsafe(userID), nil } // ForUser executes a callback for each user in the vault. func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error { - userIDs := vault.GetUserIDs() + users, err := vault.getUsers() + if err != nil { + return err + } - return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error { + r := parallel.DoContext(context.Background(), parallelism, len(users), func(_ context.Context, idx int) error { defer async.HandlePanic(vault.panicHandler) - user, err := vault.NewUser(userIDs[idx]) - if err != nil { - return err - } - defer func() { _ = user.Close() }() - + user := users[idx] return fn(user) }) + + for _, u := range users { + if err := u.Close(); err != nil { + logrus.WithError(err).Error("Failed to close user after ForUser") + } + } + + return r } // AddUser creates a new user in the vault with the given ID, username and password. // A gluon key is generated using the package's token generator. If a password is found in the password archive for this user, // it is restored, otherwise a new bridge password is generated using the package's token generator. func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) { + vault.lock.Lock() + defer vault.lock.Unlock() + + return vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass) +} + +func (vault *Vault) addUserUnsafe(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) { logrus.WithField("userID", userID).Info("Adding vault user") var exists bool - if err := vault.mod(func(data *Data) { + if err := vault.modUnsafe(func(data *Data) { if idx := xslices.IndexFunc(data.Users, func(user UserData) bool { return user.UserID == userID }); idx >= 0 { @@ -161,13 +213,42 @@ func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef str return nil, errors.New("user already exists") } - return vault.NewUser(userID) + return vault.attachUserUnsafe(userID), nil +} + +// GetOrAddUser retrieves an existing user and updates the authRef and keyPass or creates a new user. Returns +// the user and whether the user did not exist before. +func (vault *Vault) GetOrAddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, bool, error) { + vault.lock.Lock() + defer vault.lock.Unlock() + + { + users := vault.getUnsafe().Users + + idx := xslices.IndexFunc(users, func(user UserData) bool { + return user.UserID == userID + }) + + if idx >= 0 { + user := vault.attachUserUnsafe(userID) + + if err := user.setAuthAndKeyPassUnsafe(authUID, authRef, keyPass); err != nil { + return nil, false, err + } + + return user, false, nil + } + } + + u, err := vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass) + + return u, true, err } // DeleteUser removes the given user from the vault. func (vault *Vault) DeleteUser(userID string) error { - vault.refLock.Lock() - defer vault.refLock.Unlock() + vault.lock.Lock() + defer vault.lock.Unlock() logrus.WithField("userID", userID).Info("Deleting vault user") @@ -175,7 +256,7 @@ func (vault *Vault) DeleteUser(userID string) error { return fmt.Errorf("user %s is currently in use", userID) } - return vault.mod(func(data *Data) { + return vault.modUnsafe(func(data *Data) { idx := xslices.IndexFunc(data.Users, func(user UserData) bool { return user.UserID == userID }) @@ -189,17 +270,26 @@ func (vault *Vault) DeleteUser(userID string) error { } func (vault *Vault) Migrated() bool { - return vault.get().Migrated + vault.lock.RLock() + defer vault.lock.RUnlock() + + return vault.getUnsafe().Migrated } func (vault *Vault) SetMigrated() error { - return vault.mod(func(data *Data) { + vault.lock.Lock() + defer vault.lock.Unlock() + + return vault.modUnsafe(func(data *Data) { data.Migrated = true }) } func (vault *Vault) Reset(gluonDir string) error { - return vault.mod(func(data *Data) { + vault.lock.Lock() + defer vault.lock.Unlock() + + return vault.modUnsafe(func(data *Data) { *data = newDefaultData(gluonDir) }) } @@ -209,8 +299,8 @@ func (vault *Vault) Path() string { } func (vault *Vault) Close() error { - vault.refLock.Lock() - defer vault.refLock.Unlock() + vault.lock.Lock() + defer vault.lock.Unlock() if len(vault.ref) > 0 { return errors.New("vault is still in use") @@ -221,10 +311,7 @@ func (vault *Vault) Close() error { return nil } -func (vault *Vault) attachUser(userID string) *User { - vault.refLock.Lock() - defer vault.refLock.Unlock() - +func (vault *Vault) attachUserUnsafe(userID string) *User { logrus.WithField("userID", userID).Trace("Attaching vault user") vault.ref[userID]++ @@ -236,8 +323,8 @@ func (vault *Vault) attachUser(userID string) *User { } func (vault *Vault) detachUser(userID string) error { - vault.refLock.Lock() - defer vault.refLock.Unlock() + vault.lock.Lock() + defer vault.lock.Unlock() logrus.WithField("userID", userID).Trace("Detaching vault user") @@ -289,10 +376,14 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) { }, corrupt, nil } -func (vault *Vault) get() Data { - vault.encLock.RLock() - defer vault.encLock.RUnlock() +func (vault *Vault) getSafe() Data { + vault.lock.RLock() + defer vault.lock.RUnlock() + return vault.getUnsafe() +} + +func (vault *Vault) getUnsafe() Data { var data Data if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil { @@ -302,10 +393,14 @@ func (vault *Vault) get() Data { return data } -func (vault *Vault) mod(fn func(data *Data)) error { - vault.encLock.Lock() - defer vault.encLock.Unlock() +func (vault *Vault) modSafe(fn func(data *Data)) error { + vault.lock.Lock() + defer vault.lock.Unlock() + return vault.modUnsafe(fn) +} + +func (vault *Vault) modUnsafe(fn func(data *Data)) error { var data Data if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil { @@ -325,13 +420,31 @@ func (vault *Vault) mod(fn func(data *Data)) error { } func (vault *Vault) getUser(userID string) UserData { - return vault.get().Users[xslices.IndexFunc(vault.get().Users, func(user UserData) bool { + vault.lock.RLock() + defer vault.lock.RUnlock() + + users := vault.getUnsafe().Users + + idx := xslices.IndexFunc(users, func(user UserData) bool { return user.UserID == userID - })] + }) + + if idx < 0 { + panic("Unknown user") + } + + return users[idx] } func (vault *Vault) modUser(userID string, fn func(userData *UserData)) error { - return vault.mod(func(data *Data) { + vault.lock.Lock() + defer vault.lock.Unlock() + + return vault.modUserUnsafe(userID, fn) +} + +func (vault *Vault) modUserUnsafe(userID string, fn func(userData *UserData)) error { + return vault.modUnsafe(func(data *Data) { idx := xslices.IndexFunc(data.Users, func(user UserData) bool { return user.UserID == userID }) diff --git a/internal/vault/vault_debug.go b/internal/vault/vault_debug.go index 17ece7f8..aef5f5cf 100644 --- a/internal/vault/vault_debug.go +++ b/internal/vault/vault_debug.go @@ -24,7 +24,7 @@ import ( ) func (vault *Vault) ImportJSON(dec []byte) { - vault.mod(func(data *Data) { + vault.modSafe(func(data *Data) { if err := json.Unmarshal(dec, data); err != nil { panic(err) } @@ -32,7 +32,7 @@ func (vault *Vault) ImportJSON(dec []byte) { } func (vault *Vault) ExportJSON() []byte { - enc, err := json.MarshalIndent(vault.get(), "", " ") + enc, err := json.MarshalIndent(vault.getSafe(), "", " ") if err != nil { panic(err) }