diff --git a/internal/user/imap.go b/internal/user/imap.go index 146073fb..8a695cd4 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -86,7 +86,7 @@ func (conn *imapConnector) Authorize(username string, password []byte) bool { // CreateMailbox creates a label with the given name. func (conn *imapConnector) CreateMailbox(ctx context.Context, name []string) (imap.Mailbox, error) { - defer conn.goPoll() + defer conn.goPoll(false) if len(name) < 2 { return imap.Mailbox{}, fmt.Errorf("invalid mailbox name %q", name) @@ -157,7 +157,7 @@ func (conn *imapConnector) createFolder(ctx context.Context, name []string) (ima // UpdateMailboxName sets the name of the label with the given ID. func (conn *imapConnector) UpdateMailboxName(ctx context.Context, labelID imap.MailboxID, name []string) error { - defer conn.goPoll() + defer conn.goPoll(false) if len(name) < 2 { return fmt.Errorf("invalid mailbox name %q", name) @@ -234,7 +234,7 @@ func (conn *imapConnector) updateFolder(ctx context.Context, labelID imap.Mailbo // DeleteMailbox deletes the label with the given ID. func (conn *imapConnector) DeleteMailbox(ctx context.Context, labelID imap.MailboxID) error { - defer conn.goPoll() + defer conn.goPoll(false) return conn.client.DeleteLabel(ctx, string(labelID)) } @@ -249,7 +249,7 @@ func (conn *imapConnector) CreateMessage( flags imap.FlagSet, date time.Time, ) (imap.Message, []byte, error) { - defer conn.goPoll() + defer conn.goPoll(false) // Compute the hash of the message (to match it against SMTP messages). hash, err := getMessageHash(literal) @@ -313,14 +313,14 @@ func (conn *imapConnector) CreateMessage( // AddMessagesToMailbox labels the given messages with the given label ID. func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error { - defer conn.goPoll() + defer conn.goPoll(false) return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(mailboxID)) } // RemoveMessagesFromMailbox unlabels the given messages with the given label ID. func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error { - defer conn.goPoll() + defer conn.goPoll(false) if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(mailboxID)); err != nil { return err @@ -363,7 +363,7 @@ func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messag // MoveMessages removes the given messages from one label and adds them to the other label. func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.MailboxID, labelToID imap.MailboxID) error { - defer conn.goPoll() + defer conn.goPoll(false) if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil { return fmt.Errorf("labeling messages: %w", err) @@ -378,7 +378,7 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M // MarkMessagesSeen sets the seen value of the given messages. func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error { - defer conn.goPoll() + defer conn.goPoll(false) if seen { return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...) @@ -389,7 +389,7 @@ func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []im // MarkMessagesFlagged sets the flagged value of the given messages. func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error { - defer conn.goPoll() + defer conn.goPoll(false) if flagged { return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel) diff --git a/internal/user/user.go b/internal/user/user.go index 56643438..f0e98b9c 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -72,8 +72,9 @@ type User struct { tasks *async.Group abortable async.Abortable + pollCh chan chan struct{} + goPoll func(bool) goSync func() - goPoll func() syncWorkers int syncBuffer int @@ -130,7 +131,8 @@ func New( reporter: reporter, - tasks: async.NewGroup(context.Background(), crashHandler), + tasks: async.NewGroup(context.Background(), crashHandler), + pollCh: make(chan chan struct{}), syncWorkers: syncWorkers, syncBuffer: syncBuffer, @@ -166,16 +168,49 @@ func New( // This does nothing until the sync has been marked as complete. // When we receive an API event, we attempt to handle it. // If successful, we update the event ID in the vault. - user.goPoll = user.tasks.PeriodicOrTrigger(EventPeriod, EventJitter, func(ctx context.Context) { - user.log.Debug("Event poll triggered") + user.tasks.Once(func(ctx context.Context) { + ticker := liteapi.NewTicker(EventPeriod, EventJitter) + defer ticker.Stop() - if !user.vault.SyncStatus().IsComplete() { - user.log.Debug("Sync is incomplete, skipping event poll") - } else if err := user.doEventPoll(ctx); err != nil { - user.log.WithError(err).Error("Failed to poll events") + for { + var doneCh chan struct{} + + select { + case <-ctx.Done(): + return + + case doneCh = <-user.pollCh: + // ... + + case <-ticker.C: + // ... + } + + user.log.Debug("Event poll triggered") + + if !user.vault.SyncStatus().IsComplete() { + user.log.Debug("Sync is incomplete, skipping event poll") + } else if err := user.doEventPoll(ctx); err != nil { + user.log.WithError(err).Error("Failed to poll events") + } + + if doneCh != nil { + close(doneCh) + } } }) + // When triggered, poll the API for events, optionally blocking until the poll is complete. + user.goPoll = func(wait bool) { + doneCh := make(chan struct{}) + + go func() { user.pollCh <- doneCh }() + + if wait { + <-doneCh + } + } + // When triggered, attempt to sync the user. user.goSync = user.tasks.Trigger(func(ctx context.Context) { user.log.Debug("Sync triggered") @@ -376,7 +411,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) { // // nolint:funlen func (user *User) SendMail(authID string, from string, to []string, r io.Reader) error { - defer user.goPoll() + defer user.goPoll(true) if len(to) == 0 { return ErrInvalidRecipient