forked from Silverfish/proton-bridge
fix(GODT-2606): Improve Vault concurrency scopes
Rewrite the vault to have one RWlock rather then two separate locks for data and reference counts. In certain circumstance, it could be possible that that different requests could end up in undefined states if a user got deleted successfully while at he same time another goroutine/thread is loading the given user. While I have not been able to reproduce this in a test, restricting the access scope to one lock rather than two, should avoid corner cases where logic code is executing outside of the lock scope.
This commit is contained in:
@ -584,29 +584,7 @@ func (bridge *Bridge) newVaultUser(
|
|||||||
authUID, authRef string,
|
authUID, authRef string,
|
||||||
saltedKeyPass []byte,
|
saltedKeyPass []byte,
|
||||||
) (*vault.User, bool, error) {
|
) (*vault.User, bool, error) {
|
||||||
if !bridge.vault.HasUser(apiUser.ID) {
|
return bridge.vault.GetOrAddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass)
|
||||||
user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, fmt.Errorf("failed to add user to vault: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := bridge.vault.NewUser(apiUser.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.SetAuth(authUID, authRef); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.SetKeyPass(saltedKeyPass); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, false, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// logout logs out the given user, optionally logging them out from the API too.
|
// logout logs out the given user, optionally logging them out from the API too.
|
||||||
|
|||||||
@ -30,7 +30,12 @@ import (
|
|||||||
// If CertPEMPath is set, it will attempt to read the certificate from the file.
|
// If CertPEMPath is set, it will attempt to read the certificate from the file.
|
||||||
// Otherwise, or on read/validation failure, it will return the certificate from the vault.
|
// Otherwise, or on read/validation failure, it will return the certificate from the vault.
|
||||||
func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) {
|
func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) {
|
||||||
if certPath, keyPath := vault.get().Certs.CustomCertPath, vault.get().Certs.CustomKeyPath; certPath != "" && keyPath != "" {
|
vault.lock.RLock()
|
||||||
|
defer vault.lock.RUnlock()
|
||||||
|
|
||||||
|
certs := vault.getUnsafe().Certs
|
||||||
|
|
||||||
|
if certPath, keyPath := certs.CustomCertPath, certs.CustomKeyPath; certPath != "" && keyPath != "" {
|
||||||
if certPEM, keyPEM, err := readPEMCert(certPath, keyPath); err == nil {
|
if certPEM, keyPEM, err := readPEMCert(certPath, keyPath); err == nil {
|
||||||
return certPEM, keyPEM
|
return certPEM, keyPEM
|
||||||
}
|
}
|
||||||
@ -38,7 +43,7 @@ func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) {
|
|||||||
logrus.Error("Failed to read certificate from file, using default")
|
logrus.Error("Failed to read certificate from file, using default")
|
||||||
}
|
}
|
||||||
|
|
||||||
return vault.get().Certs.Bridge.Cert, vault.get().Certs.Bridge.Key
|
return certs.Bridge.Cert, certs.Bridge.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetBridgeTLSCertPath sets the path to PEM-encoded certificates for the bridge.
|
// SetBridgeTLSCertPath sets the path to PEM-encoded certificates for the bridge.
|
||||||
@ -47,7 +52,7 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error {
|
|||||||
return fmt.Errorf("invalid certificate: %w", err)
|
return fmt.Errorf("invalid certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Certs.CustomCertPath = certPath
|
data.Certs.CustomCertPath = certPath
|
||||||
data.Certs.CustomKeyPath = keyPath
|
data.Certs.CustomKeyPath = keyPath
|
||||||
})
|
})
|
||||||
@ -55,18 +60,18 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error {
|
|||||||
|
|
||||||
// SetBridgeTLSCertKey sets the path to PEM-encoded certificates for the bridge.
|
// SetBridgeTLSCertKey sets the path to PEM-encoded certificates for the bridge.
|
||||||
func (vault *Vault) SetBridgeTLSCertKey(cert, key []byte) error {
|
func (vault *Vault) SetBridgeTLSCertKey(cert, key []byte) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Certs.Bridge.Cert = cert
|
data.Certs.Bridge.Cert = cert
|
||||||
data.Certs.Bridge.Key = key
|
data.Certs.Bridge.Key = key
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) GetCertsInstalled() bool {
|
func (vault *Vault) GetCertsInstalled() bool {
|
||||||
return vault.get().Certs.Installed
|
return vault.getSafe().Certs.Installed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) SetCertsInstalled(installed bool) error {
|
func (vault *Vault) SetCertsInstalled(installed bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Certs.Installed = installed
|
data.Certs.Installed = installed
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,11 +18,11 @@
|
|||||||
package vault
|
package vault
|
||||||
|
|
||||||
func (vault *Vault) GetCookies() ([]byte, error) {
|
func (vault *Vault) GetCookies() ([]byte, error) {
|
||||||
return vault.get().Cookies, nil
|
return vault.getSafe().Cookies, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) SetCookies(cookies []byte) error {
|
func (vault *Vault) SetCookies(cookies []byte) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Cookies = cookies
|
data.Cookies = cookies
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,72 +33,72 @@ const (
|
|||||||
|
|
||||||
// GetIMAPPort sets the port that the IMAP server should listen on.
|
// GetIMAPPort sets the port that the IMAP server should listen on.
|
||||||
func (vault *Vault) GetIMAPPort() int {
|
func (vault *Vault) GetIMAPPort() int {
|
||||||
return vault.get().Settings.IMAPPort
|
return vault.getSafe().Settings.IMAPPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetIMAPPort sets the port that the IMAP server should listen on.
|
// SetIMAPPort sets the port that the IMAP server should listen on.
|
||||||
func (vault *Vault) SetIMAPPort(port int) error {
|
func (vault *Vault) SetIMAPPort(port int) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.IMAPPort = port
|
data.Settings.IMAPPort = port
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSMTPPort sets the port that the SMTP server should listen on.
|
// GetSMTPPort sets the port that the SMTP server should listen on.
|
||||||
func (vault *Vault) GetSMTPPort() int {
|
func (vault *Vault) GetSMTPPort() int {
|
||||||
return vault.get().Settings.SMTPPort
|
return vault.getSafe().Settings.SMTPPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSMTPPort sets the port that the SMTP server should listen on.
|
// SetSMTPPort sets the port that the SMTP server should listen on.
|
||||||
func (vault *Vault) SetSMTPPort(port int) error {
|
func (vault *Vault) SetSMTPPort(port int) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.SMTPPort = port
|
data.Settings.SMTPPort = port
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIMAPSSL sets whether the IMAP server should use SSL.
|
// GetIMAPSSL sets whether the IMAP server should use SSL.
|
||||||
func (vault *Vault) GetIMAPSSL() bool {
|
func (vault *Vault) GetIMAPSSL() bool {
|
||||||
return vault.get().Settings.IMAPSSL
|
return vault.getSafe().Settings.IMAPSSL
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetIMAPSSL sets whether the IMAP server should use SSL.
|
// SetIMAPSSL sets whether the IMAP server should use SSL.
|
||||||
func (vault *Vault) SetIMAPSSL(ssl bool) error {
|
func (vault *Vault) SetIMAPSSL(ssl bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.IMAPSSL = ssl
|
data.Settings.IMAPSSL = ssl
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSMTPSSL sets whether the SMTP server should use SSL.
|
// GetSMTPSSL sets whether the SMTP server should use SSL.
|
||||||
func (vault *Vault) GetSMTPSSL() bool {
|
func (vault *Vault) GetSMTPSSL() bool {
|
||||||
return vault.get().Settings.SMTPSSL
|
return vault.getSafe().Settings.SMTPSSL
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSMTPSSL sets whether the SMTP server should use SSL.
|
// SetSMTPSSL sets whether the SMTP server should use SSL.
|
||||||
func (vault *Vault) SetSMTPSSL(ssl bool) error {
|
func (vault *Vault) SetSMTPSSL(ssl bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.SMTPSSL = ssl
|
data.Settings.SMTPSSL = ssl
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGluonCacheDir sets the directory where the gluon should store its data.
|
// GetGluonCacheDir sets the directory where the gluon should store its data.
|
||||||
func (vault *Vault) GetGluonCacheDir() string {
|
func (vault *Vault) GetGluonCacheDir() string {
|
||||||
return vault.get().Settings.GluonDir
|
return vault.getSafe().Settings.GluonDir
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetGluonDir sets the directory where the gluon should store its data.
|
// SetGluonDir sets the directory where the gluon should store its data.
|
||||||
func (vault *Vault) SetGluonDir(dir string) error {
|
func (vault *Vault) SetGluonDir(dir string) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.GluonDir = dir
|
data.Settings.GluonDir = dir
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUpdateChannel sets the update channel.
|
// GetUpdateChannel sets the update channel.
|
||||||
func (vault *Vault) GetUpdateChannel() updater.Channel {
|
func (vault *Vault) GetUpdateChannel() updater.Channel {
|
||||||
return vault.get().Settings.UpdateChannel
|
return vault.getSafe().Settings.UpdateChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUpdateChannel sets the update channel.
|
// SetUpdateChannel sets the update channel.
|
||||||
func (vault *Vault) SetUpdateChannel(channel updater.Channel) error {
|
func (vault *Vault) SetUpdateChannel(channel updater.Channel) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.UpdateChannel = channel
|
data.Settings.UpdateChannel = channel
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -106,7 +106,7 @@ func (vault *Vault) SetUpdateChannel(channel updater.Channel) error {
|
|||||||
// GetUpdateRollout sets the update rollout.
|
// GetUpdateRollout sets the update rollout.
|
||||||
func (vault *Vault) GetUpdateRollout() float64 {
|
func (vault *Vault) GetUpdateRollout() float64 {
|
||||||
// The rollout value 0.6046602879796196 is forbidden. The RNG was not seeded when it was picked (GODT-2319).
|
// The rollout value 0.6046602879796196 is forbidden. The RNG was not seeded when it was picked (GODT-2319).
|
||||||
rollout := vault.get().Settings.UpdateRollout
|
rollout := vault.getSafe().Settings.UpdateRollout
|
||||||
if math.Abs(rollout-ForbiddenRollout) >= 0.00000001 {
|
if math.Abs(rollout-ForbiddenRollout) >= 0.00000001 {
|
||||||
return rollout
|
return rollout
|
||||||
}
|
}
|
||||||
@ -120,110 +120,110 @@ func (vault *Vault) GetUpdateRollout() float64 {
|
|||||||
|
|
||||||
// SetUpdateRollout sets the update rollout.
|
// SetUpdateRollout sets the update rollout.
|
||||||
func (vault *Vault) SetUpdateRollout(rollout float64) error {
|
func (vault *Vault) SetUpdateRollout(rollout float64) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.UpdateRollout = rollout
|
data.Settings.UpdateRollout = rollout
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetColorScheme sets the color scheme to be used by the bridge GUI.
|
// GetColorScheme sets the color scheme to be used by the bridge GUI.
|
||||||
func (vault *Vault) GetColorScheme() string {
|
func (vault *Vault) GetColorScheme() string {
|
||||||
return vault.get().Settings.ColorScheme
|
return vault.getSafe().Settings.ColorScheme
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetColorScheme sets the color scheme to be used by the bridge GUI.
|
// SetColorScheme sets the color scheme to be used by the bridge GUI.
|
||||||
func (vault *Vault) SetColorScheme(colorScheme string) error {
|
func (vault *Vault) SetColorScheme(colorScheme string) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.ColorScheme = colorScheme
|
data.Settings.ColorScheme = colorScheme
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProxyAllowed sets whether the bridge is allowed to use alternative routing.
|
// GetProxyAllowed sets whether the bridge is allowed to use alternative routing.
|
||||||
func (vault *Vault) GetProxyAllowed() bool {
|
func (vault *Vault) GetProxyAllowed() bool {
|
||||||
return vault.get().Settings.ProxyAllowed
|
return vault.getSafe().Settings.ProxyAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetProxyAllowed sets whether the bridge is allowed to use alternative routing.
|
// SetProxyAllowed sets whether the bridge is allowed to use alternative routing.
|
||||||
func (vault *Vault) SetProxyAllowed(allowed bool) error {
|
func (vault *Vault) SetProxyAllowed(allowed bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.ProxyAllowed = allowed
|
data.Settings.ProxyAllowed = allowed
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetShowAllMail sets whether the bridge should show the All Mail folder.
|
// GetShowAllMail sets whether the bridge should show the All Mail folder.
|
||||||
func (vault *Vault) GetShowAllMail() bool {
|
func (vault *Vault) GetShowAllMail() bool {
|
||||||
return vault.get().Settings.ShowAllMail
|
return vault.getSafe().Settings.ShowAllMail
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetShowAllMail sets whether the bridge should show the All Mail folder.
|
// SetShowAllMail sets whether the bridge should show the All Mail folder.
|
||||||
func (vault *Vault) SetShowAllMail(showAllMail bool) error {
|
func (vault *Vault) SetShowAllMail(showAllMail bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.ShowAllMail = showAllMail
|
data.Settings.ShowAllMail = showAllMail
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAutostart sets whether the bridge should autostart.
|
// GetAutostart sets whether the bridge should autostart.
|
||||||
func (vault *Vault) GetAutostart() bool {
|
func (vault *Vault) GetAutostart() bool {
|
||||||
return vault.get().Settings.Autostart
|
return vault.getSafe().Settings.Autostart
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAutostart sets whether the bridge should autostart.
|
// SetAutostart sets whether the bridge should autostart.
|
||||||
func (vault *Vault) SetAutostart(autostart bool) error {
|
func (vault *Vault) SetAutostart(autostart bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.Autostart = autostart
|
data.Settings.Autostart = autostart
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAutoUpdate sets whether the bridge should automatically update.
|
// GetAutoUpdate sets whether the bridge should automatically update.
|
||||||
func (vault *Vault) GetAutoUpdate() bool {
|
func (vault *Vault) GetAutoUpdate() bool {
|
||||||
return vault.get().Settings.AutoUpdate
|
return vault.getSafe().Settings.AutoUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAutoUpdate sets whether the bridge should automatically update.
|
// SetAutoUpdate sets whether the bridge should automatically update.
|
||||||
func (vault *Vault) SetAutoUpdate(autoUpdate bool) error {
|
func (vault *Vault) SetAutoUpdate(autoUpdate bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.AutoUpdate = autoUpdate
|
data.Settings.AutoUpdate = autoUpdate
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTelemetryDisabled checks whether telemetry is disabled.
|
// GetTelemetryDisabled checks whether telemetry is disabled.
|
||||||
func (vault *Vault) GetTelemetryDisabled() bool {
|
func (vault *Vault) GetTelemetryDisabled() bool {
|
||||||
return vault.get().Settings.TelemetryDisabled
|
return vault.getSafe().Settings.TelemetryDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTelemetryDisabled sets whether telemetry is disabled.
|
// SetTelemetryDisabled sets whether telemetry is disabled.
|
||||||
func (vault *Vault) SetTelemetryDisabled(telemetryDisabled bool) error {
|
func (vault *Vault) SetTelemetryDisabled(telemetryDisabled bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.TelemetryDisabled = telemetryDisabled
|
data.Settings.TelemetryDisabled = telemetryDisabled
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLastVersion returns the last version of the bridge that was run.
|
// GetLastVersion returns the last version of the bridge that was run.
|
||||||
func (vault *Vault) GetLastVersion() *semver.Version {
|
func (vault *Vault) GetLastVersion() *semver.Version {
|
||||||
return semver.MustParse(vault.get().Settings.LastVersion)
|
return semver.MustParse(vault.getSafe().Settings.LastVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLastVersion sets the last version of the bridge that was run.
|
// SetLastVersion sets the last version of the bridge that was run.
|
||||||
func (vault *Vault) SetLastVersion(version *semver.Version) error {
|
func (vault *Vault) SetLastVersion(version *semver.Version) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.LastVersion = version.String()
|
data.Settings.LastVersion = version.String()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFirstStart returns whether this is the first time the bridge has been started.
|
// GetFirstStart returns whether this is the first time the bridge has been started.
|
||||||
func (vault *Vault) GetFirstStart() bool {
|
func (vault *Vault) GetFirstStart() bool {
|
||||||
return vault.get().Settings.FirstStart
|
return vault.getSafe().Settings.FirstStart
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFirstStart sets whether this is the first time the bridge has been started.
|
// SetFirstStart sets whether this is the first time the bridge has been started.
|
||||||
func (vault *Vault) SetFirstStart(firstStart bool) error {
|
func (vault *Vault) SetFirstStart(firstStart bool) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.FirstStart = firstStart
|
data.Settings.FirstStart = firstStart
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMaxSyncMemory returns the maximum amount of memory the sync process should use.
|
// GetMaxSyncMemory returns the maximum amount of memory the sync process should use.
|
||||||
func (vault *Vault) GetMaxSyncMemory() uint64 {
|
func (vault *Vault) GetMaxSyncMemory() uint64 {
|
||||||
v := vault.get().Settings.MaxSyncMemory
|
v := vault.getSafe().Settings.MaxSyncMemory
|
||||||
// can be zero if never written to vault before.
|
// can be zero if never written to vault before.
|
||||||
if v == 0 {
|
if v == 0 {
|
||||||
return DefaultMaxSyncMemory
|
return DefaultMaxSyncMemory
|
||||||
@ -234,14 +234,14 @@ func (vault *Vault) GetMaxSyncMemory() uint64 {
|
|||||||
|
|
||||||
// SetMaxSyncMemory sets the maximum amount of memory the sync process should use.
|
// SetMaxSyncMemory sets the maximum amount of memory the sync process should use.
|
||||||
func (vault *Vault) SetMaxSyncMemory(maxMemory uint64) error {
|
func (vault *Vault) SetMaxSyncMemory(maxMemory uint64) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.MaxSyncMemory = maxMemory
|
data.Settings.MaxSyncMemory = maxMemory
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLastUserAgent returns the last user agent recorded by bridge.
|
// GetLastUserAgent returns the last user agent recorded by bridge.
|
||||||
func (vault *Vault) GetLastUserAgent() string {
|
func (vault *Vault) GetLastUserAgent() string {
|
||||||
v := vault.get().Settings.LastUserAgent
|
v := vault.getSafe().Settings.LastUserAgent
|
||||||
|
|
||||||
// Handle case where there may be no value.
|
// Handle case where there may be no value.
|
||||||
if len(v) == 0 {
|
if len(v) == 0 {
|
||||||
@ -253,19 +253,19 @@ func (vault *Vault) GetLastUserAgent() string {
|
|||||||
|
|
||||||
// SetLastUserAgent store the last user agent recorded by bridge.
|
// SetLastUserAgent store the last user agent recorded by bridge.
|
||||||
func (vault *Vault) SetLastUserAgent(userAgent string) error {
|
func (vault *Vault) SetLastUserAgent(userAgent string) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.LastUserAgent = userAgent
|
data.Settings.LastUserAgent = userAgent
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLastHeartbeatSent returns the last time heartbeat was sent.
|
// GetLastHeartbeatSent returns the last time heartbeat was sent.
|
||||||
func (vault *Vault) GetLastHeartbeatSent() time.Time {
|
func (vault *Vault) GetLastHeartbeatSent() time.Time {
|
||||||
return vault.get().Settings.LastHeartbeatSent
|
return vault.getSafe().Settings.LastHeartbeatSent
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLastHeartbeatSent store the last time heartbeat was sent.
|
// SetLastHeartbeatSent store the last time heartbeat was sent.
|
||||||
func (vault *Vault) SetLastHeartbeatSent(timestamp time.Time) error {
|
func (vault *Vault) SetLastHeartbeatSent(timestamp time.Time) error {
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modSafe(func(data *Data) {
|
||||||
data.Settings.LastHeartbeatSent = timestamp
|
data.Settings.LastHeartbeatSent = timestamp
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -122,6 +122,14 @@ func (user *User) SetAuth(authUID, authRef string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) setAuthAndKeyPassUnsafe(authUID, authRef string, keyPass []byte) error {
|
||||||
|
return user.vault.modUserUnsafe(user.userID, func(userData *UserData) {
|
||||||
|
userData.AuthRef = authRef
|
||||||
|
userData.AuthUID = authUID
|
||||||
|
userData.KeyPass = keyPass
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// KeyPass returns the user's (salted) key password.
|
// KeyPass returns the user's (salted) key password.
|
||||||
func (user *User) KeyPass() []byte {
|
func (user *User) KeyPass() []byte {
|
||||||
return user.vault.getUser(user.userID).KeyPass
|
return user.vault.getUser(user.userID).KeyPass
|
||||||
|
|||||||
@ -40,11 +40,11 @@ type Vault struct {
|
|||||||
path string
|
path string
|
||||||
gcm cipher.AEAD
|
gcm cipher.AEAD
|
||||||
|
|
||||||
enc []byte
|
enc []byte
|
||||||
encLock sync.RWMutex
|
|
||||||
|
|
||||||
ref map[string]int
|
ref map[string]int
|
||||||
refLock sync.Mutex
|
|
||||||
|
lock sync.RWMutex
|
||||||
|
|
||||||
panicHandler async.PanicHandler
|
panicHandler async.PanicHandler
|
||||||
}
|
}
|
||||||
@ -79,14 +79,46 @@ func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHan
|
|||||||
|
|
||||||
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
||||||
func (vault *Vault) GetUserIDs() []string {
|
func (vault *Vault) GetUserIDs() []string {
|
||||||
return xslices.Map(vault.get().Users, func(user UserData) string {
|
vault.lock.RLock()
|
||||||
|
defer vault.lock.RUnlock()
|
||||||
|
|
||||||
|
return xslices.Map(vault.getUnsafe().Users, func(user UserData) string {
|
||||||
return user.UserID
|
return user.UserID
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) getUsers() ([]*User, error) {
|
||||||
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
users := vault.getUnsafe().Users
|
||||||
|
|
||||||
|
result := make([]*User, 0, len(users))
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
u, err := vault.newUserUnsafe(user.UserID)
|
||||||
|
if err != nil {
|
||||||
|
for _, v := range result {
|
||||||
|
if err := v.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Fait to close user after failed get")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// HasUser returns true if the vault contains a user with the given ID.
|
// HasUser returns true if the vault contains a user with the given ID.
|
||||||
func (vault *Vault) HasUser(userID string) bool {
|
func (vault *Vault) HasUser(userID string) bool {
|
||||||
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
vault.lock.RLock()
|
||||||
|
defer vault.lock.RUnlock()
|
||||||
|
|
||||||
|
return xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
}) >= 0
|
}) >= 0
|
||||||
}
|
}
|
||||||
@ -106,41 +138,61 @@ func (vault *Vault) GetUser(userID string, fn func(*User)) error {
|
|||||||
|
|
||||||
// NewUser returns a new vault user. It must be closed before it can be deleted.
|
// NewUser returns a new vault user. It must be closed before it can be deleted.
|
||||||
func (vault *Vault) NewUser(userID string) (*User, error) {
|
func (vault *Vault) NewUser(userID string) (*User, error) {
|
||||||
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
return vault.newUserUnsafe(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) newUserUnsafe(userID string) (*User, error) {
|
||||||
|
if idx := xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
}); idx < 0 {
|
}); idx < 0 {
|
||||||
return nil, errors.New("no such user")
|
return nil, errors.New("no such user")
|
||||||
}
|
}
|
||||||
|
|
||||||
return vault.attachUser(userID), nil
|
return vault.attachUserUnsafe(userID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForUser executes a callback for each user in the vault.
|
// ForUser executes a callback for each user in the vault.
|
||||||
func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
|
func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
|
||||||
userIDs := vault.GetUserIDs()
|
users, err := vault.getUsers()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error {
|
r := parallel.DoContext(context.Background(), parallelism, len(users), func(_ context.Context, idx int) error {
|
||||||
defer async.HandlePanic(vault.panicHandler)
|
defer async.HandlePanic(vault.panicHandler)
|
||||||
|
|
||||||
user, err := vault.NewUser(userIDs[idx])
|
user := users[idx]
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = user.Close() }()
|
|
||||||
|
|
||||||
return fn(user)
|
return fn(user)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
for _, u := range users {
|
||||||
|
if err := u.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to close user after ForUser")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddUser creates a new user in the vault with the given ID, username and password.
|
// AddUser creates a new user in the vault with the given ID, username and password.
|
||||||
// A gluon key is generated using the package's token generator. If a password is found in the password archive for this user,
|
// A gluon key is generated using the package's token generator. If a password is found in the password archive for this user,
|
||||||
// it is restored, otherwise a new bridge password is generated using the package's token generator.
|
// it is restored, otherwise a new bridge password is generated using the package's token generator.
|
||||||
func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
|
func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
|
||||||
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
return vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) addUserUnsafe(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
|
||||||
logrus.WithField("userID", userID).Info("Adding vault user")
|
logrus.WithField("userID", userID).Info("Adding vault user")
|
||||||
|
|
||||||
var exists bool
|
var exists bool
|
||||||
|
|
||||||
if err := vault.mod(func(data *Data) {
|
if err := vault.modUnsafe(func(data *Data) {
|
||||||
if idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
if idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
}); idx >= 0 {
|
}); idx >= 0 {
|
||||||
@ -161,13 +213,42 @@ func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef str
|
|||||||
return nil, errors.New("user already exists")
|
return nil, errors.New("user already exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
return vault.NewUser(userID)
|
return vault.attachUserUnsafe(userID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrAddUser retrieves an existing user and updates the authRef and keyPass or creates a new user. Returns
|
||||||
|
// the user and whether the user did not exist before.
|
||||||
|
func (vault *Vault) GetOrAddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, bool, error) {
|
||||||
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
{
|
||||||
|
users := vault.getUnsafe().Users
|
||||||
|
|
||||||
|
idx := xslices.IndexFunc(users, func(user UserData) bool {
|
||||||
|
return user.UserID == userID
|
||||||
|
})
|
||||||
|
|
||||||
|
if idx >= 0 {
|
||||||
|
user := vault.attachUserUnsafe(userID)
|
||||||
|
|
||||||
|
if err := user.setAuthAndKeyPassUnsafe(authUID, authRef, keyPass); err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
|
||||||
|
|
||||||
|
return u, true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser removes the given user from the vault.
|
// DeleteUser removes the given user from the vault.
|
||||||
func (vault *Vault) DeleteUser(userID string) error {
|
func (vault *Vault) DeleteUser(userID string) error {
|
||||||
vault.refLock.Lock()
|
vault.lock.Lock()
|
||||||
defer vault.refLock.Unlock()
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
logrus.WithField("userID", userID).Info("Deleting vault user")
|
logrus.WithField("userID", userID).Info("Deleting vault user")
|
||||||
|
|
||||||
@ -175,7 +256,7 @@ func (vault *Vault) DeleteUser(userID string) error {
|
|||||||
return fmt.Errorf("user %s is currently in use", userID)
|
return fmt.Errorf("user %s is currently in use", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return vault.mod(func(data *Data) {
|
return vault.modUnsafe(func(data *Data) {
|
||||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
})
|
})
|
||||||
@ -189,17 +270,26 @@ func (vault *Vault) DeleteUser(userID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) Migrated() bool {
|
func (vault *Vault) Migrated() bool {
|
||||||
return vault.get().Migrated
|
vault.lock.RLock()
|
||||||
|
defer vault.lock.RUnlock()
|
||||||
|
|
||||||
|
return vault.getUnsafe().Migrated
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) SetMigrated() error {
|
func (vault *Vault) SetMigrated() error {
|
||||||
return vault.mod(func(data *Data) {
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
return vault.modUnsafe(func(data *Data) {
|
||||||
data.Migrated = true
|
data.Migrated = true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) Reset(gluonDir string) error {
|
func (vault *Vault) Reset(gluonDir string) error {
|
||||||
return vault.mod(func(data *Data) {
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
return vault.modUnsafe(func(data *Data) {
|
||||||
*data = newDefaultData(gluonDir)
|
*data = newDefaultData(gluonDir)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -209,8 +299,8 @@ func (vault *Vault) Path() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) Close() error {
|
func (vault *Vault) Close() error {
|
||||||
vault.refLock.Lock()
|
vault.lock.Lock()
|
||||||
defer vault.refLock.Unlock()
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
if len(vault.ref) > 0 {
|
if len(vault.ref) > 0 {
|
||||||
return errors.New("vault is still in use")
|
return errors.New("vault is still in use")
|
||||||
@ -221,10 +311,7 @@ func (vault *Vault) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) attachUser(userID string) *User {
|
func (vault *Vault) attachUserUnsafe(userID string) *User {
|
||||||
vault.refLock.Lock()
|
|
||||||
defer vault.refLock.Unlock()
|
|
||||||
|
|
||||||
logrus.WithField("userID", userID).Trace("Attaching vault user")
|
logrus.WithField("userID", userID).Trace("Attaching vault user")
|
||||||
|
|
||||||
vault.ref[userID]++
|
vault.ref[userID]++
|
||||||
@ -236,8 +323,8 @@ func (vault *Vault) attachUser(userID string) *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) detachUser(userID string) error {
|
func (vault *Vault) detachUser(userID string) error {
|
||||||
vault.refLock.Lock()
|
vault.lock.Lock()
|
||||||
defer vault.refLock.Unlock()
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
logrus.WithField("userID", userID).Trace("Detaching vault user")
|
logrus.WithField("userID", userID).Trace("Detaching vault user")
|
||||||
|
|
||||||
@ -289,10 +376,14 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
|||||||
}, corrupt, nil
|
}, corrupt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) get() Data {
|
func (vault *Vault) getSafe() Data {
|
||||||
vault.encLock.RLock()
|
vault.lock.RLock()
|
||||||
defer vault.encLock.RUnlock()
|
defer vault.lock.RUnlock()
|
||||||
|
|
||||||
|
return vault.getUnsafe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) getUnsafe() Data {
|
||||||
var data Data
|
var data Data
|
||||||
|
|
||||||
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
||||||
@ -302,10 +393,14 @@ func (vault *Vault) get() Data {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) mod(fn func(data *Data)) error {
|
func (vault *Vault) modSafe(fn func(data *Data)) error {
|
||||||
vault.encLock.Lock()
|
vault.lock.Lock()
|
||||||
defer vault.encLock.Unlock()
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
return vault.modUnsafe(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) modUnsafe(fn func(data *Data)) error {
|
||||||
var data Data
|
var data Data
|
||||||
|
|
||||||
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
||||||
@ -325,13 +420,31 @@ func (vault *Vault) mod(fn func(data *Data)) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) getUser(userID string) UserData {
|
func (vault *Vault) getUser(userID string) UserData {
|
||||||
return vault.get().Users[xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
vault.lock.RLock()
|
||||||
|
defer vault.lock.RUnlock()
|
||||||
|
|
||||||
|
users := vault.getUnsafe().Users
|
||||||
|
|
||||||
|
idx := xslices.IndexFunc(users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
})]
|
})
|
||||||
|
|
||||||
|
if idx < 0 {
|
||||||
|
panic("Unknown user")
|
||||||
|
}
|
||||||
|
|
||||||
|
return users[idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) modUser(userID string, fn func(userData *UserData)) error {
|
func (vault *Vault) modUser(userID string, fn func(userData *UserData)) error {
|
||||||
return vault.mod(func(data *Data) {
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
|
return vault.modUserUnsafe(userID, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) modUserUnsafe(userID string, fn func(userData *UserData)) error {
|
||||||
|
return vault.modUnsafe(func(data *Data) {
|
||||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
})
|
})
|
||||||
|
|||||||
@ -24,7 +24,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (vault *Vault) ImportJSON(dec []byte) {
|
func (vault *Vault) ImportJSON(dec []byte) {
|
||||||
vault.mod(func(data *Data) {
|
vault.modSafe(func(data *Data) {
|
||||||
if err := json.Unmarshal(dec, data); err != nil {
|
if err := json.Unmarshal(dec, data); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -32,7 +32,7 @@ func (vault *Vault) ImportJSON(dec []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) ExportJSON() []byte {
|
func (vault *Vault) ExportJSON() []byte {
|
||||||
enc, err := json.MarshalIndent(vault.get(), "", " ")
|
enc, err := json.MarshalIndent(vault.getSafe(), "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user