1
0

Other: Fix user sync leaks/race conditions

This fixes various race conditions and leaks related to the user's sync
and API event stream. It was possible for a sync/stream to begin after a
user was already closed; this change prevents that by managing the
goroutines related to sync/stream within cancellable groups.
This commit is contained in:
James Houlahan
2022-10-24 12:54:01 +02:00
parent 6bbaf03f1f
commit 828385b049
14 changed files with 282 additions and 253 deletions

View File

@ -30,11 +30,12 @@ import (
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/async"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/try"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/bradenaw/juniper/xsync"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
@ -45,23 +46,35 @@ var (
)
type User struct {
log *logrus.Entry
vault *vault.User
client *liteapi.Client
eventCh *queue.QueuedChannel[events.Event]
stopCh chan struct{}
apiUser *safe.Value[liteapi.User]
apiAddrs *safe.Map[string, liteapi.Address]
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
syncStopCh chan struct{}
syncLock try.Group
syncWG sync.WaitGroup
tasks *xsync.Group
abortable async.Abortable
goSync func()
showAllMail int32
showAllMail uint32
}
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User, showAllMail bool) (*User, error) { //nolint:funlen
// New returns a new user.
//
// nolint:funlen
func New(
ctx context.Context,
encVault *vault.User,
client *liteapi.Client,
apiUser liteapi.User,
showAllMail bool,
) (*User, error) { //nolint:funlen
logrus.WithField("userID", apiUser.ID).Debug("Creating new user")
// Get the user's API addresses.
apiAddrs, err := client.GetAddresses(ctx)
if err != nil {
@ -104,25 +117,26 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
}
user := &User{
log: logrus.WithField("userID", apiUser.ID),
vault: encVault,
client: client,
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
stopCh: make(chan struct{}),
apiUser: safe.NewValue(apiUser),
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
updateCh: safe.NewMapFrom(updateCh, nil),
syncStopCh: make(chan struct{}),
}
tasks: xsync.NewGroup(context.Background()),
user.SetShowAllMail(showAllMail)
showAllMail: b32(showAllMail),
}
// When we receive an auth object, we update it in the vault.
// This will be used to authorize the user on the next run.
user.client.AddAuthHandler(func(auth liteapi.Auth) {
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
logrus.WithError(err).Error("Failed to update auth in vault")
user.log.WithError(err).Error("Failed to update auth in vault")
}
})
@ -134,24 +148,38 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
})
})
// GODT-1946 - Don't start the event loop until the initial sync has finished.
eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID())
user.syncWG.Add(1)
// If we haven't synced yet, do it first.
// If it fails, we don't start the event loop.
// Otherwise, begin processing API events, logging any errors that occur.
go func() {
defer user.syncWG.Done()
if err := <-user.startSync(); err != nil {
return
// Stream events from the API, logging any errors that occur.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
goStream := user.tasks.Trigger(func(ctx context.Context) {
for event := range user.client.NewEventStream(ctx, EventPeriod, EventJitter, user.vault.EventID()) {
if err := user.handleAPIEvent(ctx, event); err != nil {
user.log.WithError(err).Error("Failed to handle API event")
} else if err := user.vault.SetEventID(event.EventID); err != nil {
user.log.WithError(err).Error("Failed to update event ID in vault")
}
}
})
for err := range user.streamEvents(eventCh) {
logrus.WithError(err).Error("Error while streaming events")
}
}()
// We only ever want to start one event streamer.
var once sync.Once
// When triggered, attempt to sync the user.
// If successful, we start the event streamer if we haven't already.
user.goSync = user.tasks.Trigger(func(ctx context.Context) {
user.abortable.Do(ctx, func(ctx context.Context) {
if !user.vault.SyncStatus().IsComplete() {
if err := user.doSync(ctx); err != nil {
return
}
}
once.Do(goStream)
})
})
// Trigger an initial sync (if necessary) and start the event stream.
user.goSync()
return user, nil
}
@ -199,9 +227,8 @@ func (user *User) GetAddressMode() vault.AddressMode {
// SetAddressMode sets the user's address mode.
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
user.stopSync()
user.lockSync()
defer user.unlockSync()
user.abortable.Abort()
defer user.goSync()
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
for _, updateCh := range xslices.Unique(updateCh) {
@ -235,12 +262,6 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
return fmt.Errorf("failed to clear sync status: %w", err)
}
go func() {
if err := <-user.startSync(); err != nil {
logrus.WithError(err).Error("Failed to sync after setting address mode")
}
}()
return nil
}
@ -364,26 +385,17 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) {
// OnStatusUp is called when the connection goes up.
func (user *User) OnStatusUp() {
go func() {
logrus.Info("Connection up, checking if sync is needed")
if err := <-user.startSync(); err != nil {
logrus.WithError(err).Error("Failed to sync on status up")
}
}()
user.goSync()
}
// OnStatusDown is called when the connection goes down.
func (user *User) OnStatusDown() {
logrus.Info("Connection down, aborting any ongoing syncs")
user.stopSync()
user.abortable.Abort()
}
// Logout logs the user out from the API.
func (user *User) Logout(ctx context.Context) error {
// Cancel ongoing syncs.
user.stopSync()
user.tasks.Wait()
if err := user.client.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
@ -397,14 +409,9 @@ func (user *User) Logout(ctx context.Context) error {
}
// Close closes ongoing connections and cleans up resources.
func (user *User) Close() error {
defer user.syncWG.Wait()
// Close any ongoing operations.
close(user.stopCh)
// Cancel ongoing syncs.
user.stopSync()
func (user *User) Close() {
// Stop any ongoing background tasks.
user.tasks.Wait()
// Close the user's API client.
user.client.Close()
@ -421,113 +428,20 @@ func (user *User) Close() error {
// Close the user's vault.
if err := user.vault.Close(); err != nil {
logrus.WithError(err).Error("Failed to close vault")
user.log.WithError(err).Error("Failed to close vault")
}
return nil
}
// SetShowAllMail sets whether to show the All Mail mailbox.
func (user *User) SetShowAllMail(show bool) {
var value int32
atomic.StoreUint32(&user.showAllMail, b32(show))
}
if show {
value = 1
} else {
value = 0
// b32 returns a uint32 0 or 1 representing b.
func b32(b bool) uint32 {
if b {
return 1
}
atomic.StoreInt32(&user.showAllMail, value)
}
func (user *User) GetShowAllMail() bool {
return atomic.LoadInt32(&user.showAllMail) == 1
}
// streamEvents begins streaming API events for the user.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
errCh := make(chan error)
go func() {
defer close(errCh)
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh)
defer cancel()
for event := range eventCh {
if err := user.handleAPIEvent(ctx, event); err != nil {
errCh <- fmt.Errorf("failed to handle API event: %w", err)
} else if err := user.vault.SetEventID(event.EventID); err != nil {
errCh <- fmt.Errorf("failed to update event ID: %w", err)
}
}
}()
return errCh
}
// startSync begins a startSync for the user.
func (user *User) startSync() <-chan error {
errCh := make(chan error)
user.syncLock.GoTry(func(ok bool) {
defer close(errCh)
if user.vault.SyncStatus().IsComplete() {
logrus.Debug("Already synced, skipping")
return
}
if !ok {
logrus.Debug("Sync already in progress, skipping")
return
}
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh)
defer cancel()
user.eventCh.Enqueue(events.SyncStarted{
UserID: user.ID(),
})
if err := user.sync(ctx); err != nil {
user.eventCh.Enqueue(events.SyncFailed{
UserID: user.ID(),
Err: err,
})
errCh <- err
} else {
user.eventCh.Enqueue(events.SyncFinished{
UserID: user.ID(),
})
}
})
return errCh
}
// AbortSync aborts any ongoing sync.
// GODT-1947: Should probably be done automatically when one of the user's IMAP connectors is closed.
func (user *User) stopSync() {
defer user.syncLock.Wait()
select {
case user.syncStopCh <- struct{}{}:
logrus.Debug("Sent sync abort signal")
default:
logrus.Debug("No sync to abort")
}
}
// lockSync prevents a new sync from starting.
func (user *User) lockSync() {
user.syncLock.Lock()
}
// unlockSync allows a new sync to start.
func (user *User) unlockSync() {
user.syncLock.Unlock()
return 0
}