Other: Fix user sync leaks/race conditions
This fixes various race conditions and leaks related to the user's sync and API event stream. It was possible for a sync/stream to begin after a user was already closed; this change prevents that by managing the goroutines related to sync/stream within cancellable groups.
This commit is contained in:
@ -30,11 +30,12 @@ import (
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/async"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/bradenaw/juniper/xsync"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
@ -45,23 +46,35 @@ var (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
log *logrus.Entry
|
||||
|
||||
vault *vault.User
|
||||
client *liteapi.Client
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
stopCh chan struct{}
|
||||
|
||||
apiUser *safe.Value[liteapi.User]
|
||||
apiAddrs *safe.Map[string, liteapi.Address]
|
||||
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
|
||||
|
||||
syncStopCh chan struct{}
|
||||
syncLock try.Group
|
||||
syncWG sync.WaitGroup
|
||||
tasks *xsync.Group
|
||||
abortable async.Abortable
|
||||
goSync func()
|
||||
|
||||
showAllMail int32
|
||||
showAllMail uint32
|
||||
}
|
||||
|
||||
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User, showAllMail bool) (*User, error) { //nolint:funlen
|
||||
// New returns a new user.
|
||||
//
|
||||
// nolint:funlen
|
||||
func New(
|
||||
ctx context.Context,
|
||||
encVault *vault.User,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
showAllMail bool,
|
||||
) (*User, error) { //nolint:funlen
|
||||
logrus.WithField("userID", apiUser.ID).Debug("Creating new user")
|
||||
|
||||
// Get the user's API addresses.
|
||||
apiAddrs, err := client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
@ -104,25 +117,26 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
}
|
||||
|
||||
user := &User{
|
||||
log: logrus.WithField("userID", apiUser.ID),
|
||||
|
||||
vault: encVault,
|
||||
client: client,
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
stopCh: make(chan struct{}),
|
||||
|
||||
apiUser: safe.NewValue(apiUser),
|
||||
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||
updateCh: safe.NewMapFrom(updateCh, nil),
|
||||
|
||||
syncStopCh: make(chan struct{}),
|
||||
}
|
||||
tasks: xsync.NewGroup(context.Background()),
|
||||
|
||||
user.SetShowAllMail(showAllMail)
|
||||
showAllMail: b32(showAllMail),
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the vault.
|
||||
// This will be used to authorize the user on the next run.
|
||||
user.client.AddAuthHandler(func(auth liteapi.Auth) {
|
||||
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update auth in vault")
|
||||
user.log.WithError(err).Error("Failed to update auth in vault")
|
||||
}
|
||||
})
|
||||
|
||||
@ -134,24 +148,38 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
})
|
||||
})
|
||||
|
||||
// GODT-1946 - Don't start the event loop until the initial sync has finished.
|
||||
eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID())
|
||||
|
||||
user.syncWG.Add(1)
|
||||
// If we haven't synced yet, do it first.
|
||||
// If it fails, we don't start the event loop.
|
||||
// Otherwise, begin processing API events, logging any errors that occur.
|
||||
go func() {
|
||||
defer user.syncWG.Done()
|
||||
|
||||
if err := <-user.startSync(); err != nil {
|
||||
return
|
||||
// Stream events from the API, logging any errors that occur.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
goStream := user.tasks.Trigger(func(ctx context.Context) {
|
||||
for event := range user.client.NewEventStream(ctx, EventPeriod, EventJitter, user.vault.EventID()) {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
user.log.WithError(err).Error("Failed to handle API event")
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
user.log.WithError(err).Error("Failed to update event ID in vault")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for err := range user.streamEvents(eventCh) {
|
||||
logrus.WithError(err).Error("Error while streaming events")
|
||||
}
|
||||
}()
|
||||
// We only ever want to start one event streamer.
|
||||
var once sync.Once
|
||||
|
||||
// When triggered, attempt to sync the user.
|
||||
// If successful, we start the event streamer if we haven't already.
|
||||
user.goSync = user.tasks.Trigger(func(ctx context.Context) {
|
||||
user.abortable.Do(ctx, func(ctx context.Context) {
|
||||
if !user.vault.SyncStatus().IsComplete() {
|
||||
if err := user.doSync(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
once.Do(goStream)
|
||||
})
|
||||
})
|
||||
|
||||
// Trigger an initial sync (if necessary) and start the event stream.
|
||||
user.goSync()
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@ -199,9 +227,8 @@ func (user *User) GetAddressMode() vault.AddressMode {
|
||||
|
||||
// SetAddressMode sets the user's address mode.
|
||||
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
|
||||
user.stopSync()
|
||||
user.lockSync()
|
||||
defer user.unlockSync()
|
||||
user.abortable.Abort()
|
||||
defer user.goSync()
|
||||
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
@ -235,12 +262,6 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to sync after setting address mode")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -364,26 +385,17 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) {
|
||||
|
||||
// OnStatusUp is called when the connection goes up.
|
||||
func (user *User) OnStatusUp() {
|
||||
go func() {
|
||||
logrus.Info("Connection up, checking if sync is needed")
|
||||
|
||||
if err := <-user.startSync(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to sync on status up")
|
||||
}
|
||||
}()
|
||||
user.goSync()
|
||||
}
|
||||
|
||||
// OnStatusDown is called when the connection goes down.
|
||||
func (user *User) OnStatusDown() {
|
||||
logrus.Info("Connection down, aborting any ongoing syncs")
|
||||
|
||||
user.stopSync()
|
||||
user.abortable.Abort()
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
func (user *User) Logout(ctx context.Context) error {
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
user.tasks.Wait()
|
||||
|
||||
if err := user.client.AuthDelete(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete auth: %w", err)
|
||||
@ -397,14 +409,9 @@ func (user *User) Logout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Close closes ongoing connections and cleans up resources.
|
||||
func (user *User) Close() error {
|
||||
defer user.syncWG.Wait()
|
||||
|
||||
// Close any ongoing operations.
|
||||
close(user.stopCh)
|
||||
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
func (user *User) Close() {
|
||||
// Stop any ongoing background tasks.
|
||||
user.tasks.Wait()
|
||||
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
@ -421,113 +428,20 @@ func (user *User) Close() error {
|
||||
|
||||
// Close the user's vault.
|
||||
if err := user.vault.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close vault")
|
||||
user.log.WithError(err).Error("Failed to close vault")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetShowAllMail sets whether to show the All Mail mailbox.
|
||||
func (user *User) SetShowAllMail(show bool) {
|
||||
var value int32
|
||||
atomic.StoreUint32(&user.showAllMail, b32(show))
|
||||
}
|
||||
|
||||
if show {
|
||||
value = 1
|
||||
} else {
|
||||
value = 0
|
||||
// b32 returns a uint32 0 or 1 representing b.
|
||||
func b32(b bool) uint32 {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
|
||||
atomic.StoreInt32(&user.showAllMail, value)
|
||||
}
|
||||
|
||||
func (user *User) GetShowAllMail() bool {
|
||||
return atomic.LoadInt32(&user.showAllMail) == 1
|
||||
}
|
||||
|
||||
// streamEvents begins streaming API events for the user.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh)
|
||||
defer cancel()
|
||||
|
||||
for event := range eventCh {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
errCh <- fmt.Errorf("failed to update event ID: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// startSync begins a sync for the user.
|
||||
func (user *User) startSync() <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
user.syncLock.GoTry(func(ok bool) {
|
||||
defer close(errCh)
|
||||
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
logrus.Debug("Already synced, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
logrus.Debug("Sync already in progress, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh)
|
||||
defer cancel()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
if err := user.sync(ctx); err != nil {
|
||||
user.eventCh.Enqueue(events.SyncFailed{
|
||||
UserID: user.ID(),
|
||||
Err: err,
|
||||
})
|
||||
|
||||
errCh <- err
|
||||
} else {
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// stopSync aborts any ongoing sync.
|
||||
// GODT-1947: Should probably be done automatically when one of the user's IMAP connectors is closed.
|
||||
func (user *User) stopSync() {
|
||||
defer user.syncLock.Wait()
|
||||
|
||||
select {
|
||||
case user.syncStopCh <- struct{}{}:
|
||||
logrus.Debug("Sent sync abort signal")
|
||||
|
||||
default:
|
||||
logrus.Debug("No sync to abort")
|
||||
}
|
||||
}
|
||||
|
||||
// lockSync prevents a new sync from starting.
|
||||
func (user *User) lockSync() {
|
||||
user.syncLock.Lock()
|
||||
}
|
||||
|
||||
// unlockSync allows a new sync to start.
|
||||
func (user *User) unlockSync() {
|
||||
user.syncLock.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user