Other: Fix user sync leaks/race conditions

This fixes various race conditions and leaks related to the user's sync
and API event stream. It was possible for a sync/stream to begin after a
user was already closed; this change prevents that by managing the
goroutines related to sync/stream within cancellable groups.
This commit is contained in:
James Houlahan
2022-10-24 12:54:01 +02:00
parent 6bbaf03f1f
commit 828385b049
14 changed files with 282 additions and 253 deletions

View File

@ -27,7 +27,6 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
@ -92,7 +91,7 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea
}
case liteapi.EventUpdateFlags:
logrus.Warn("Not implemented yet.")
user.log.Warn("Not implemented yet.")
}
}
@ -100,7 +99,9 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea
}
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
user.apiAddrs.Set(event.Address.ID, event.Address)
if had := user.apiAddrs.Set(event.Address.ID, event.Address); had {
return fmt.Errorf("address %q already exists", event.Address.ID)
}
switch user.vault.AddressMode() {
case vault.CombinedMode:
@ -132,7 +133,9 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad
}
func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.AddressEvent) error { //nolint:unparam
user.apiAddrs.Set(event.Address.ID, event.Address)
if had := user.apiAddrs.Set(event.Address.ID, event.Address); !had {
return fmt.Errorf("address %q does not exist", event.Address.ID)
}
user.eventCh.Enqueue(events.UserAddressUpdated{
UserID: user.ID(),
@ -219,9 +222,6 @@ func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelE
// handleMessageEvents handles the given message events.
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for _, event := range messageEvents {
switch event.Action {
case liteapi.EventCreate:
@ -249,7 +249,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me
}
return user.withAddrKR(event.Message.AddressID, func(_, addrKR *crypto.KeyRing) error {
buildRes, err := buildRFC822(ctx, full, addrKR)
buildRes, err := buildRFC822(full, addrKR)
if err != nil {
return fmt.Errorf("failed to build RFC822 message: %w", err)
}

View File

@ -20,6 +20,7 @@ package user
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/ProtonMail/gluon/imap"
@ -79,9 +80,9 @@ func (conn *imapConnector) Authorize(username string, password []byte) bool {
return true
}
// GetMailbox returns information about the label with the given ID.
func (conn *imapConnector) GetMailbox(ctx context.Context, labelID imap.MailboxID) (imap.Mailbox, error) {
label, err := conn.client.GetLabel(ctx, string(labelID), liteapi.LabelTypeLabel, liteapi.LabelTypeFolder)
// GetMailbox returns information about the mailbox with the given ID.
func (conn *imapConnector) GetMailbox(ctx context.Context, mailboxID imap.MailboxID) (imap.Mailbox, error) {
label, err := conn.client.GetLabel(ctx, string(mailboxID), liteapi.LabelTypeLabel, liteapi.LabelTypeFolder)
if err != nil {
return imap.Mailbox{}, err
}
@ -362,10 +363,7 @@ func (conn *imapConnector) Close(ctx context.Context) error {
return nil
}
func (conn *imapConnector) IsMailboxVisible(_ context.Context, id imap.MailboxID) bool {
if !conn.GetShowAllMail() && id == liteapi.AllMailLabel {
return false
}
return true
// IsMailboxVisible returns whether this mailbox should be visible over IMAP.
func (conn *imapConnector) IsMailboxVisible(_ context.Context, mailboxID imap.MailboxID) bool {
return atomic.LoadUint32(&conn.showAllMail) != 0 || mailboxID != liteapi.AllMailLabel
}

View File

@ -35,7 +35,6 @@ import (
"github.com/ProtonMail/proton-bridge/v2/pkg/message/parser"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
@ -104,7 +103,7 @@ func (user *User) sendMail(authID string, emails []string, from string, to []str
return fmt.Errorf("failed to send message: %w", err)
}
logrus.WithField("messageID", sent.ID).Info("Message sent")
user.log.WithField("messageID", sent.ID).Info("Message sent")
return nil
})

View File

@ -33,7 +33,6 @@ import (
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
@ -42,12 +41,40 @@ const (
maxBatchSize = 1 << 8
)
// doSync runs a full sync of the user's data and reports its outcome.
//
// A SyncStarted event is enqueued before the sync begins. Once the sync
// returns, either a SyncFinished event (on success) or a SyncFailed event
// carrying the error (on failure) is enqueued. On failure the underlying
// error is also returned to the caller, wrapped with additional context.
func (user *User) doSync(ctx context.Context) error {
	user.log.Debug("Beginning user sync")

	user.eventCh.Enqueue(events.SyncStarted{UserID: user.ID()})

	err := user.sync(ctx)
	if err == nil {
		user.log.Debug("Finished user sync")

		user.eventCh.Enqueue(events.SyncFinished{UserID: user.ID()})

		return nil
	}

	user.log.WithError(err).Debug("Failed to sync user")

	user.eventCh.Enqueue(events.SyncFailed{UserID: user.ID(), Err: err})

	return fmt.Errorf("failed to sync: %w", err)
}
func (user *User) sync(ctx context.Context) error {
return user.withAddrKRs(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
logrus.Info("Beginning sync")
if !user.vault.SyncStatus().HasLabels {
logrus.Info("Syncing labels")
user.log.Debug("Syncing labels")
if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
@ -59,13 +86,13 @@ func (user *User) sync(ctx context.Context) error {
return fmt.Errorf("failed to set has labels: %w", err)
}
logrus.Info("Synced labels")
user.log.Debug("Synced labels")
} else {
logrus.Info("Labels are already synced, skipping")
user.log.Debug("Labels are already synced, skipping")
}
if !user.vault.SyncStatus().HasMessages {
logrus.Info("Syncing messages")
user.log.Debug("Syncing messages")
if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
@ -77,9 +104,9 @@ func (user *User) sync(ctx context.Context) error {
return fmt.Errorf("failed to set has messages: %w", err)
}
logrus.Info("Synced messages")
user.log.Debug("Synced messages")
} else {
logrus.Info("Messages are already synced, skipping")
user.log.Debug("Messages are already synced, skipping")
}
return nil
@ -169,11 +196,10 @@ func syncMessages( //nolint:funlen
// Fetch and build each message.
buildCh := stream.Map(
client.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...),
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
return buildRFC822(ctx, full, addrKRs[full.AddressID])
func(_ context.Context, full liteapi.FullMessage) (*buildRes, error) {
return buildRFC822(full, addrKRs[full.AddressID])
},
)
defer buildCh.Close()
// Create the flushers, one per update channel.
flushers := make(map[string]*flusher)
@ -254,6 +280,8 @@ func wantLabelID(labelID string) bool {
}
func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error {
defer streamer.Close()
for {
res, err := streamer.Next(ctx)
if errors.Is(err, stream.End) {

View File

@ -18,7 +18,6 @@
package user
import (
"context"
"fmt"
"time"
@ -47,7 +46,7 @@ func defaultJobOpts() message.JobOptions {
}
}
func buildRFC822(_ context.Context, full liteapi.FullMessage, addrKR *crypto.KeyRing) (*buildRes, error) {
func buildRFC822(full liteapi.FullMessage, addrKR *crypto.KeyRing) (*buildRes, error) {
literal, err := message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts())
if err != nil {
return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err)

View File

@ -18,7 +18,6 @@
package user
import (
"context"
"encoding/hex"
"fmt"
"reflect"
@ -92,22 +91,3 @@ func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
return "", fmt.Errorf("address %s not found", email)
}
// contextWithStopCh derives a cancellable context from ctx that is also
// cancelled as soon as any of the given stop channels is closed or receives
// a value.
//
// One watcher goroutine is started per channel. Each watcher exits either
// after triggering the cancellation or once the derived context is done, so
// callers must invoke the returned cancel function to release them.
func contextWithStopCh(ctx context.Context, channels ...<-chan struct{}) (context.Context, context.CancelFunc) {
	stopCtx, stop := context.WithCancel(ctx)

	watch := func(stopCh <-chan struct{}) {
		select {
		case <-stopCtx.Done():
			// The context finished first; nothing left to watch.
		case <-stopCh:
			stop()
		}
	}

	for i := range channels {
		go watch(channels[i])
	}

	return stopCtx, stop
}

View File

@ -30,11 +30,12 @@ import (
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/async"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/try"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/bradenaw/juniper/xsync"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
@ -45,23 +46,35 @@ var (
)
type User struct {
log *logrus.Entry
vault *vault.User
client *liteapi.Client
eventCh *queue.QueuedChannel[events.Event]
stopCh chan struct{}
apiUser *safe.Value[liteapi.User]
apiAddrs *safe.Map[string, liteapi.Address]
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
syncStopCh chan struct{}
syncLock try.Group
syncWG sync.WaitGroup
tasks *xsync.Group
abortable async.Abortable
goSync func()
showAllMail int32
showAllMail uint32
}
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User, showAllMail bool) (*User, error) { //nolint:funlen
// New returns a new user.
//
// nolint:funlen
func New(
ctx context.Context,
encVault *vault.User,
client *liteapi.Client,
apiUser liteapi.User,
showAllMail bool,
) (*User, error) { //nolint:funlen
logrus.WithField("userID", apiUser.ID).Debug("Creating new user")
// Get the user's API addresses.
apiAddrs, err := client.GetAddresses(ctx)
if err != nil {
@ -104,25 +117,26 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
}
user := &User{
log: logrus.WithField("userID", apiUser.ID),
vault: encVault,
client: client,
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
stopCh: make(chan struct{}),
apiUser: safe.NewValue(apiUser),
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
updateCh: safe.NewMapFrom(updateCh, nil),
syncStopCh: make(chan struct{}),
}
tasks: xsync.NewGroup(context.Background()),
user.SetShowAllMail(showAllMail)
showAllMail: b32(showAllMail),
}
// When we receive an auth object, we update it in the vault.
// This will be used to authorize the user on the next run.
user.client.AddAuthHandler(func(auth liteapi.Auth) {
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
logrus.WithError(err).Error("Failed to update auth in vault")
user.log.WithError(err).Error("Failed to update auth in vault")
}
})
@ -134,24 +148,38 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
})
})
// GODT-1946 - Don't start the event loop until the initial sync has finished.
eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID())
user.syncWG.Add(1)
// If we haven't synced yet, do it first.
// If it fails, we don't start the event loop.
// Otherwise, begin processing API events, logging any errors that occur.
go func() {
defer user.syncWG.Done()
if err := <-user.startSync(); err != nil {
return
// Stream events from the API, logging any errors that occur.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
goStream := user.tasks.Trigger(func(ctx context.Context) {
for event := range user.client.NewEventStream(ctx, EventPeriod, EventJitter, user.vault.EventID()) {
if err := user.handleAPIEvent(ctx, event); err != nil {
user.log.WithError(err).Error("Failed to handle API event")
} else if err := user.vault.SetEventID(event.EventID); err != nil {
user.log.WithError(err).Error("Failed to update event ID in vault")
}
}
})
for err := range user.streamEvents(eventCh) {
logrus.WithError(err).Error("Error while streaming events")
}
}()
// We only ever want to start one event streamer.
var once sync.Once
// When triggered, attempt to sync the user.
// If successful, we start the event streamer if we haven't already.
user.goSync = user.tasks.Trigger(func(ctx context.Context) {
user.abortable.Do(ctx, func(ctx context.Context) {
if !user.vault.SyncStatus().IsComplete() {
if err := user.doSync(ctx); err != nil {
return
}
}
once.Do(goStream)
})
})
// Trigger an initial sync (if necessary) and start the event stream.
user.goSync()
return user, nil
}
@ -199,9 +227,8 @@ func (user *User) GetAddressMode() vault.AddressMode {
// SetAddressMode sets the user's address mode.
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
user.stopSync()
user.lockSync()
defer user.unlockSync()
user.abortable.Abort()
defer user.goSync()
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
for _, updateCh := range xslices.Unique(updateCh) {
@ -235,12 +262,6 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
return fmt.Errorf("failed to clear sync status: %w", err)
}
go func() {
if err := <-user.startSync(); err != nil {
logrus.WithError(err).Error("Failed to sync after setting address mode")
}
}()
return nil
}
@ -364,26 +385,17 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) {
// OnStatusUp is called when the connection goes up.
func (user *User) OnStatusUp() {
go func() {
logrus.Info("Connection up, checking if sync is needed")
if err := <-user.startSync(); err != nil {
logrus.WithError(err).Error("Failed to sync on status up")
}
}()
user.goSync()
}
// OnStatusDown is called when the connection goes down.
func (user *User) OnStatusDown() {
logrus.Info("Connection down, aborting any ongoing syncs")
user.stopSync()
user.abortable.Abort()
}
// Logout logs the user out from the API.
func (user *User) Logout(ctx context.Context) error {
// Cancel ongoing syncs.
user.stopSync()
user.tasks.Wait()
if err := user.client.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
@ -397,14 +409,9 @@ func (user *User) Logout(ctx context.Context) error {
}
// Close closes ongoing connections and cleans up resources.
func (user *User) Close() error {
defer user.syncWG.Wait()
// Close any ongoing operations.
close(user.stopCh)
// Cancel ongoing syncs.
user.stopSync()
func (user *User) Close() {
// Stop any ongoing background tasks.
user.tasks.Wait()
// Close the user's API client.
user.client.Close()
@ -421,113 +428,20 @@ func (user *User) Close() error {
// Close the user's vault.
if err := user.vault.Close(); err != nil {
logrus.WithError(err).Error("Failed to close vault")
user.log.WithError(err).Error("Failed to close vault")
}
return nil
}
// SetShowAllMail sets whether to show the All Mail mailbox.
func (user *User) SetShowAllMail(show bool) {
var value int32
atomic.StoreUint32(&user.showAllMail, b32(show))
}
if show {
value = 1
} else {
value = 0
// b32 returns a uint32 0 or 1 representing b.
func b32(b bool) uint32 {
if b {
return 1
}
atomic.StoreInt32(&user.showAllMail, value)
}
func (user *User) GetShowAllMail() bool {
return atomic.LoadInt32(&user.showAllMail) == 1
}
// streamEvents consumes API events from eventCh in a background goroutine.
//
// Every event is passed to handleAPIEvent; only when that succeeds is the
// event ID persisted to the vault. A failure of either step is reported on
// the returned channel and processing continues with the next event.
//
// The returned channel is unbuffered, so the caller has to keep receiving
// from it for the loop to make progress. It is closed once eventCh is closed
// and drained. The context handed to the handler is cancelled when
// user.stopCh fires.
func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
	failures := make(chan error)

	go func() {
		defer close(failures)

		ctx, cancel := contextWithStopCh(context.Background(), user.stopCh)
		defer cancel()

		for event := range eventCh {
			if err := user.handleAPIEvent(ctx, event); err != nil {
				failures <- fmt.Errorf("failed to handle API event: %w", err)
				continue
			}

			if err := user.vault.SetEventID(event.EventID); err != nil {
				failures <- fmt.Errorf("failed to update event ID: %w", err)
			}
		}
	}()

	return failures
}
// startSync attempts to start a sync for the user via syncLock.GoTry.
//
// Nothing is synced when the vault already reports the sync as complete, or
// when GoTry reports that the lock could not be acquired (a sync is already
// in progress). Otherwise a SyncStarted event is enqueued, the sync runs with
// a context that is cancelled by user.stopCh or user.syncStopCh, and then
// either SyncFinished or SyncFailed is enqueued.
//
// The returned channel is unbuffered. It delivers the sync error, if any,
// and is closed when the attempt is over; callers must receive from it or a
// failing sync blocks on the send.
func (user *User) startSync() <-chan error {
	result := make(chan error)

	user.syncLock.GoTry(func(acquired bool) {
		defer close(result)

		switch {
		case user.vault.SyncStatus().IsComplete():
			logrus.Debug("Already synced, skipping")
			return

		case !acquired:
			logrus.Debug("Sync already in progress, skipping")
			return
		}

		ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh)
		defer cancel()

		user.eventCh.Enqueue(events.SyncStarted{UserID: user.ID()})

		err := user.sync(ctx)
		if err == nil {
			user.eventCh.Enqueue(events.SyncFinished{UserID: user.ID()})
			return
		}

		user.eventCh.Enqueue(events.SyncFailed{UserID: user.ID(), Err: err})

		result <- err
	})

	return result
}
// stopSync aborts any ongoing sync.
//
// It makes a single non-blocking attempt to send an abort signal on
// syncStopCh: the send only succeeds if something is currently receiving
// from that channel, otherwise the default branch runs and nothing is sent.
// Before returning it calls syncLock.Wait (deferred), which presumably
// blocks until the in-flight sync has exited — confirm against try.Group.
//
// GODT-1947: Should probably be done automatically when one of the user's IMAP connectors is closed.
func (user *User) stopSync() {
// Deferred so the wait happens after the abort signal has been attempted.
defer user.syncLock.Wait()
select {
case user.syncStopCh <- struct{}{}:
logrus.Debug("Sent sync abort signal")
default:
// Nobody is listening on syncStopCh, so there is nothing to abort.
logrus.Debug("No sync to abort")
}
}
// lockSync prevents a new sync from starting by acquiring syncLock.
// Every call must be paired with a later call to unlockSync.
func (user *User) lockSync() {
user.syncLock.Lock()
}
// unlockSync allows a new sync to start.
func (user *User) unlockSync() {
user.syncLock.Unlock()
return 0
}

View File

@ -22,6 +22,7 @@ import (
"testing"
"time"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
@ -30,6 +31,7 @@ import (
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/backend"
"go.uber.org/goleak"
)
func init() {
@ -39,7 +41,11 @@ func init() {
certs.GenerateCert = tests.FastGenerateCert
}
func TestUser_Data(t *testing.T) {
// TestMain runs the package's tests under goleak so that goroutines leaked
// by the tests are reported. Goroutines that already exist when TestMain
// starts are excluded from the check.
func TestMain(m *testing.M) {
	ignoreExisting := goleak.IgnoreCurrent()

	goleak.VerifyTestMain(m, ignoreExisting)
}
func TestUser_Info(t *testing.T) {
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) {
withUser(t, ctx, s, m, "username", "password", func(user *User) {
@ -66,6 +72,9 @@ func TestUser_Sync(t *testing.T) {
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
withUser(t, ctx, s, m, "username", "password", func(user *User) {
// Process the IMAP updates as if we were gluon.
handleUpdates(t, user)
// User starts a sync at startup.
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
@ -79,6 +88,36 @@ func TestUser_Sync(t *testing.T) {
})
}
// TestUser_AddressMode checks that a user starts out in combined address
// mode, can be switched to split mode, and that a complete sync
// (started, progress, finished) is reported both at startup and again after
// the address mode has been changed.
func TestUser_AddressMode(t *testing.T) {
	withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
		withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) {
			withUser(t, ctx, s, m, "username", "password", func(user *User) {
				// requireSync consumes the next three user events and checks
				// that together they describe one complete sync.
				requireSync := func() {
					for _, want := range []any{
						events.SyncStarted{},
						events.SyncProgress{},
						events.SyncFinished{},
					} {
						require.IsType(t, want, <-user.GetEventCh())
					}
				}

				// Process the IMAP updates as if we were gluon.
				handleUpdates(t, user)

				// User finishes syncing at startup.
				requireSync()

				// By default, user should be in combined mode.
				require.Equal(t, vault.CombinedMode, user.GetAddressMode())

				// User should be able to switch to split mode.
				require.NoError(t, user.SetAddressMode(ctx, vault.SplitMode))

				// Process the IMAP updates as if we were gluon.
				handleUpdates(t, user)

				// User finishes syncing after switching to split mode.
				requireSync()
			})
		})
	})
}
func TestUser_Deauth(t *testing.T) {
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
@ -126,7 +165,6 @@ func withAccount(tb testing.TB, s *server.Server, username, password string, ema
func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *liteapi.Manager, username, password string, fn func(*User)) { //nolint:unparam,revive
client, apiAuth, err := m.NewClientWithLogin(ctx, username, []byte(password))
require.NoError(tb, err)
defer client.Close()
apiUser, err := client.GetUser(ctx)
require.NoError(tb, err)
@ -146,18 +184,20 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *liteapi.M
user, err := New(ctx, vaultUser, client, apiUser, true)
require.NoError(tb, err)
defer func() { require.NoError(tb, user.Close()) }()
imapConn, err := user.NewIMAPConnectors()
require.NoError(tb, err)
go func() {
for _, imapConn := range imapConn {
for update := range imapConn.GetUpdates() {
update.Done()
}
}
}()
defer user.Close()
fn(user)
}
// handleUpdates consumes the user's IMAP updates in the background, marking
// each one as done, as if gluon were processing them.
//
// It creates a new set of IMAP connectors for the user and starts one
// goroutine per connector. Each goroutine ranges over its connector's update
// channel and therefore exits once that channel is closed.
func handleUpdates(t *testing.T, user *User) {
	// Attribute failures below to the calling test rather than this helper.
	t.Helper()

	connectors, err := user.NewIMAPConnectors()
	require.NoError(t, err)

	for _, conn := range connectors {
		go func(conn connector.Connector) {
			for update := range conn.GetUpdates() {
				update.Done()
			}
		}(conn)
	}
}