Files
proton-bridge/internal/services/userevents/service.go
Leander Beernaert 82efa16d65 feat(GODT-2800): User Event Service
This patch adds the User Event Service which is meant to replace the
current event polling flow.

Each user interested in receiving events should register a new
subscriber using the `Service.Subscribe` function and then react to
the incoming events.

The current patch does not hook this up to the Bridge user as there are
no existing consumers, but it does provide extensive testing for the
expected behavior.
2023-07-21 11:47:43 +02:00

487 lines
14 KiB
Go

// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/sirupsen/logrus"
)
// Service polls the given event source and ensures that all the respective subscribers get notified
// before proceeding to the next event. The events are published in the following order:
// * Refresh
// * User
// * Address
// * Label
// * Message
// * UserUsedSpace
// When an event carries the mail refresh flag, only the refresh subscribers are notified for that event.
// By default this service starts paused; call `Service.Resume` at least once to begin event polling.
type Service struct {
// userID is attached to the service's log entries and to the error events it publishes.
userID string
// cpc carries the pause/resume/is-paused requests that are served by the run loop.
cpc *cpc.CPC
// eventSource provides the latest event ID and the events that follow a given event ID.
eventSource EventSource
// eventIDStore loads and stores the ID of the last event that was fully handled.
eventIDStore EventIDStore
log *logrus.Entry
// eventPublisher receives the bad-event and uncategorized-error notifications.
eventPublisher events.EventPublisher
// timer ticks once per poll period; each tick triggers a poll unless the service is paused.
timer *time.Ticker
// eventTimeout is forwarded to every subscriber list publish call.
eventTimeout time.Duration
// paused starts as true, is toggled by pause/resume requests and is set again by onBadEvent.
paused bool
// panicHandler is forwarded to the parallel publish calls.
panicHandler async.PanicHandler
userSubscriberList userSubscriberList
addressSubscribers addressSubscriberList
labelSubscribers labelSubscriberList
messageSubscribers messageSubscriberList
refreshSubscribers refreshSubscriberList
userUsedSpaceSubscriber userUsedSpaceSubscriberList
// pendingSubscriptionsLock guards pendingSubscriptionsAdd and pendingSubscriptionsRemove.
pendingSubscriptionsLock sync.Mutex
// pendingSubscriptionsAdd holds subscriptions queued by Subscribe until the run loop applies them.
pendingSubscriptionsAdd []Subscription
// pendingSubscriptionsRemove holds subscriptions queued by Unsubscribe until the run loop applies them.
pendingSubscriptionsRemove []Subscription
}
// NewService builds a user event service for the given user.
//
// The returned service is paused and not yet running: call `Service.Start` to launch the
// run loop and `Service.Resume` to begin polling. pollPeriod is the interval of the internal
// poll ticker, eventTimeout is handed to the subscriber lists on every publish, and
// panicHandler is handed to the parallel publish calls.
func NewService(
	userID string,
	eventSource EventSource,
	store EventIDStore,
	eventPublisher events.EventPublisher,
	pollPeriod time.Duration,
	eventTimeout time.Duration,
	panicHandler async.PanicHandler,
) *Service {
	// Plain values first; the service always starts paused.
	svc := &Service{
		userID:         userID,
		eventSource:    eventSource,
		eventIDStore:   store,
		eventPublisher: eventPublisher,
		eventTimeout:   eventTimeout,
		panicHandler:   panicHandler,
		paused:         true,
	}

	// Constructed members, created in the same order as before.
	svc.cpc = cpc.NewCPC()
	svc.log = logrus.WithFields(logrus.Fields{
		"service": "user-events",
		"user":    userID,
	})
	svc.timer = time.NewTicker(pollPeriod)

	return svc
}
// Subscription bundles the subscribers a consumer wants to register with the Service.
// Every field is optional: nil fields are skipped when the subscription is added, removed or cancelled.
type Subscription struct {
User UserSubscriber
Refresh RefreshSubscriber
Address AddressSubscriber
Labels LabelSubscriber
Messages MessageSubscriber
UserUsedSpace UserUsedSpaceSubscriber
}
// cancel invokes cancel on every subscriber that is set in the subscription and skips the
// nil ones. See `subscriber.cancel` for more information.
func (s Subscription) cancel() {
	if user := s.User; user != nil {
		user.cancel()
	}

	if refresh := s.Refresh; refresh != nil {
		refresh.cancel()
	}

	if address := s.Address; address != nil {
		address.cancel()
	}

	if labels := s.Labels; labels != nil {
		labels.cancel()
	}

	if messages := s.Messages; messages != nil {
		messages.cancel()
	}

	if usedSpace := s.UserUsedSpace; usedSpace != nil {
		usedSpace.cancel()
	}
}
// Subscribe queues new subscribers for registration with the service.
// The subscription is only recorded here; the run loop applies queued subscriptions
// right before it polls for events.
// This method can safely be called during event handling.
func (s *Service) Subscribe(subscription Subscription) {
	s.pendingSubscriptionsLock.Lock()
	s.pendingSubscriptionsAdd = append(s.pendingSubscriptionsAdd, subscription)
	s.pendingSubscriptionsLock.Unlock()
}
// Unsubscribe cancels the subscription's subscribers immediately and queues them for removal.
// The removal itself is applied by the run loop right before it polls for events.
// This method can safely be called during event handling.
func (s *Service) Unsubscribe(subscription Subscription) {
	// Cancel before taking the lock, exactly as the queueing below expects.
	subscription.cancel()

	s.pendingSubscriptionsLock.Lock()
	s.pendingSubscriptionsRemove = append(s.pendingSubscriptionsRemove, subscription)
	s.pendingSubscriptionsLock.Unlock()
}
// Pause asks the run loop to stop polling for events and returns any error reported
// while sending the request.
// DO NOT CALL THIS DURING EVENT HANDLING.
// NOTE(review): presumably because the request is served by the run loop, which is busy
// while an event is being handled - confirm against the cpc package.
func (s *Service) Pause(ctx context.Context) error {
	if _, err := s.cpc.Send(ctx, &pauseRequest{}); err != nil {
		return err
	}

	return nil
}
// Resume asks the run loop to start (or continue) polling for events and returns any
// error reported while sending the request.
// DO NOT CALL THIS DURING EVENT HANDLING.
// NOTE(review): presumably because the request is served by the run loop, which is busy
// while an event is being handled - confirm against the cpc package.
func (s *Service) Resume(ctx context.Context) error {
	if _, err := s.cpc.Send(ctx, &resumeRequest{}); err != nil {
		return err
	}

	return nil
}
// IsPaused returns true if the service is paused. The answer is obtained by sending a
// request to the run loop.
// DO NOT CALL THIS DURING EVENT HANDLING.
func (s *Service) IsPaused(ctx context.Context) (bool, error) {
	paused, err := cpc.SendTyped[bool](ctx, s.cpc, &isPausedRequest{})

	return paused, err
}
// Start resolves the event ID to resume from and launches the run loop on the given group.
//
// The last handled event ID is loaded from the event ID store. If none is stored yet, the
// latest event ID is fetched from the event source and persisted before the loop starts.
// An error is returned (and the loop is not started) if loading, fetching or storing fails.
// The service still starts paused; call `Service.Resume` to begin polling.
func (s *Service) Start(ctx context.Context, group *async.Group) error {
	lastEventID, err := s.eventIDStore.Load(ctx)
	if err != nil {
		return fmt.Errorf("failed to load last event id: %w", err)
	}

	if lastEventID == "" {
		s.log.Debugf("No event ID present in storage, retrieving latest")

		eventID, err := s.eventSource.GetLatestEventID(ctx)
		if err != nil {
			return fmt.Errorf("failed to get latest event id: %w", err)
		}

		// Wrap with %w (not %v) so callers can inspect the cause, like the other error paths here.
		if err := s.eventIDStore.Store(ctx, eventID); err != nil {
			return fmt.Errorf("failed to store event in event id store: %w", err)
		}

		lastEventID = eventID
	}

	group.Once(func(ctx context.Context) {
		s.run(ctx, lastEventID)
	})

	return nil
}
// run is the service's main loop. It returns when ctx is done or the CPC receive channel is
// closed, and releases the service's resources on exit.
//
// Each iteration either serves one CPC request or, on a poll tick while not paused:
//  1. applies the queued subscription removals and then the queued additions,
//  2. fetches the events that follow lastEventID,
//  3. hands every event to handleEvent in order,
//  4. stores the ID of the last event and uses it as the new lastEventID.
//
// If any step fails, lastEventID is left unchanged, so the same events are fetched again on a
// later tick.
func (s *Service) run(ctx context.Context, lastEventID string) {
	s.log.Debugf("Starting service Last EventID=%v", lastEventID)

	defer s.close()
	defer s.log.Debug("Exiting service")

	for {
		select {
		case <-ctx.Done():
			return

		case req, ok := <-s.cpc.ReceiveCh():
			if !ok {
				return
			}

			s.handleRequest(ctx, req)

			continue

		case <-s.timer.C:
			if s.paused {
				continue
			}
		}

		// Apply any pending subscription changes.
		func() {
			s.pendingSubscriptionsLock.Lock()
			defer s.pendingSubscriptionsLock.Unlock()

			for _, subscription := range s.pendingSubscriptionsRemove {
				s.removeSubscription(subscription)
			}

			for _, subscription := range s.pendingSubscriptionsAdd {
				s.addSubscription(subscription)
			}

			s.pendingSubscriptionsRemove = nil
			s.pendingSubscriptionsAdd = nil
		}()

		newEvents, _, err := s.eventSource.GetEvent(ctx, lastEventID)
		if err != nil {
			s.log.WithError(err).Errorf("Failed to get event (caused by %T)", internal.ErrCause(err))
			continue
		}

		// An empty batch or an unchanged event ID both mean there are no new events.
		// The length check also protects the index expression below from panicking.
		if len(newEvents) == 0 || newEvents[len(newEvents)-1].EventID == lastEventID {
			s.log.Debugf("No new API Events")
			continue
		}

		// Handle the events in order and stop at the first one that fails.
		if event, eventErr := func() (proton.Event, error) {
			for _, event := range newEvents {
				if err := s.handleEvent(ctx, lastEventID, event); err != nil {
					return event, err
				}
			}

			return proton.Event{}, nil
		}(); eventErr != nil {
			subscriberName, err := s.handleEventError(ctx, lastEventID, event, eventErr)
			if subscriberName == "" {
				subscriberName = "?"
			}

			s.log.WithField("subscriber", subscriberName).WithError(err).Errorf("Failed to apply event")

			continue
		}

		newEventID := newEvents[len(newEvents)-1].EventID
		if err := s.eventIDStore.Store(ctx, newEventID); err != nil {
			s.log.WithError(err).Errorf("Failed to store new event ID: %v", err)
			s.onBadEvent(ctx, events.UserBadEvent{
				Error:  fmt.Errorf("failed to store new event ID: %w", err),
				UserID: s.userID,
			})

			continue
		}

		lastEventID = newEventID
	}
}
// handleEvent publishes a single API event to the registered subscribers.
//
// If the event carries the mail refresh flag, only the refresh subscribers are notified and
// the rest of the event is skipped. Otherwise the event parts are published in order: user,
// addresses, labels, messages and finally used space. The first failing publish aborts the
// event and its error is returned wrapped with the name of the failing stage.
// lastEventID is only used for logging.
func (s *Service) handleEvent(ctx context.Context, lastEventID string, event proton.Event) error {
	s.log.WithFields(logrus.Fields{
		"old": lastEventID,
		"new": event,
	}).Info("Received new API event")

	if event.Refresh&proton.RefreshMail != 0 {
		s.log.Info("Handling refresh event")

		if err := s.refreshSubscribers.Publish(ctx, event.Refresh, s.eventTimeout); err != nil {
			return fmt.Errorf("failed to apply refresh event: %w", err)
		}

		return nil
	}

	// Start with user events.
	if event.User != nil {
		if err := s.userSubscriberList.PublishParallel(ctx, *event.User, s.panicHandler, s.eventTimeout); err != nil {
			return fmt.Errorf("failed to apply user event: %w", err)
		}
	}

	// Next Address events
	if err := s.addressSubscribers.PublishParallel(ctx, event.Addresses, s.panicHandler, s.eventTimeout); err != nil {
		return fmt.Errorf("failed to apply address events: %w", err)
	}

	// Next label events
	if err := s.labelSubscribers.PublishParallel(ctx, event.Labels, s.panicHandler, s.eventTimeout); err != nil {
		return fmt.Errorf("failed to apply label events: %w", err)
	}

	// Next message events
	if err := s.messageSubscribers.PublishParallel(ctx, event.Messages, s.panicHandler, s.eventTimeout); err != nil {
		return fmt.Errorf("failed to apply message events: %w", err)
	}

	// Finally user used space events
	if event.UsedSpace != nil {
		if err := s.userUsedSpaceSubscriber.PublishParallel(ctx, *event.UsedSpace, s.panicHandler, s.eventTimeout); err != nil {
			// Name the right stage; this used to be reported as a message event failure.
			return fmt.Errorf("failed to apply user used space event: %w", err)
		}
	}

	return nil
}
func unpackPublisherError(err error) (string, error) {
var addressErr *addressPublishError
var labelErr *labelPublishError
var messageErr *messagePublishError
var refreshErr *refreshPublishError
var userErr *userPublishError
var usedSpaceErr *userUsedEventPublishError
switch {
case errors.As(err, &userErr):
return userErr.subscriber.name(), userErr.error
case errors.As(err, &addressErr):
return addressErr.subscriber.name(), addressErr.error
case errors.As(err, &labelErr):
return labelErr.subscriber.name(), labelErr.error
case errors.As(err, &messageErr):
return messageErr.subscriber.name(), messageErr.error
case errors.As(err, &refreshErr):
return refreshErr.subscriber.name(), refreshErr.error
case errors.As(err, &usedSpaceErr):
return usedSpaceErr.subscriber.name(), usedSpaceErr.error
default:
return "", err
}
}
// handleEventError classifies an error returned by handleEvent and returns the name of the
// failing subscriber (empty when unknown) together with a wrapped error describing the category.
//
// Context cancellation, network errors, unexpected EOF and API errors with status >= 500 are
// only wrapped and returned, so the event is retried later. A JSON type error additionally
// publishes an UncategorizedEventError. Anything else is treated as a client-side problem:
// a UserBadEvent is published through onBadEvent, which also pauses the service.
func (s *Service) handleEventError(ctx context.Context, lastEventID string, event proton.Event, err error) (string, error) {
	// Unpack the error so we can proceed to handle the real issue.
	subscriberName, cause := unpackPublisherError(err)

	var (
		protonNetErr *proton.NetError
		opErr        *net.OpError
		jsonTypeErr  *json.UnmarshalTypeError
		apiErr       *proton.APIError
	)

	switch {
	case errors.Is(cause, context.Canceled):
		// Cancelled context: retry later.
		return subscriberName, fmt.Errorf("failed to handle event due to context cancellation: %w", cause)

	case errors.As(cause, &protonNetErr):
		// Network error: retry later.
		return subscriberName, fmt.Errorf("failed to handle event due to network issue: %w", cause)

	case errors.As(cause, &opErr):
		// Catch all for uncategorized net errors that may slip through.
		return subscriberName, fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", cause)

	case errors.As(cause, &jsonTypeErr):
		// In case a json decode error slips through.
		s.eventPublisher.PublishEvent(ctx, events.UncategorizedEventError{
			UserID: s.userID,
			Error:  cause,
		})

		return subscriberName, fmt.Errorf("failed to handle event due to JSON issue: %w", cause)

	case errors.Is(cause, io.ErrUnexpectedEOF):
		// Unexpected EOF: retry later.
		return subscriberName, fmt.Errorf("failed to handle event due to EOF: %w", cause)

	case errors.As(cause, &apiErr) && apiErr.Status >= 500:
		// Server-side issue: retry later.
		return subscriberName, fmt.Errorf("failed to handle event due to server error: %w", cause)
	}

	// Otherwise, the error is a client-side issue; notify bridge to handle it.
	s.log.WithField("event", event).Warn("Failed to handle API event")

	s.onBadEvent(ctx, events.UserBadEvent{
		UserID:     s.userID,
		OldEventID: lastEventID,
		NewEventID: event.EventID,
		EventInfo:  event.String(),
		Error:      cause,
	})

	return subscriberName, fmt.Errorf("failed to handle event due to client error: %w", cause)
}
// onBadEvent pauses event polling and then publishes the given bad event through the
// event publisher. The paused flag is only cleared again when a resume request is handled.
func (s *Service) onBadEvent(ctx context.Context, event events.UserBadEvent) {
// Pause first so the run loop skips the following poll ticks.
s.paused = true
s.eventPublisher.PublishEvent(ctx, event)
}
// handleRequest serves a single CPC request from the run loop.
// Pause and resume requests update the paused flag and are answered with a nil value;
// an is-paused request is answered with the current flag. A request of any other type
// is logged and left unanswered.
func (s *Service) handleRequest(ctx context.Context, request *cpc.Request) {
	// reply stays nil for pause/resume, which only acknowledge the request.
	var reply any

	switch request.Value().(type) {
	case *pauseRequest:
		s.paused = true

	case *resumeRequest:
		s.paused = false

	case *isPausedRequest:
		reply = s.paused

	default:
		s.log.Errorf("Unknown request")
		return
	}

	request.Reply(ctx, reply, nil)
}
// addSubscription registers every non-nil subscriber of the subscription with the
// matching subscriber list. Nil subscribers are skipped.
func (s *Service) addSubscription(subscription Subscription) {
	if user := subscription.User; user != nil {
		s.userSubscriberList.Add(user)
	}

	if refresh := subscription.Refresh; refresh != nil {
		s.refreshSubscribers.Add(refresh)
	}

	if address := subscription.Address; address != nil {
		s.addressSubscribers.Add(address)
	}

	if labels := subscription.Labels; labels != nil {
		s.labelSubscribers.Add(labels)
	}

	if messages := subscription.Messages; messages != nil {
		s.messageSubscribers.Add(messages)
	}

	if usedSpace := subscription.UserUsedSpace; usedSpace != nil {
		s.userUsedSpaceSubscriber.Add(usedSpace)
	}
}
// removeSubscription removes every non-nil subscriber of the subscription from the
// matching subscriber list. Nil subscribers are skipped.
func (s *Service) removeSubscription(subscription Subscription) {
	if user := subscription.User; user != nil {
		s.userSubscriberList.Remove(user)
	}

	if refresh := subscription.Refresh; refresh != nil {
		s.refreshSubscribers.Remove(refresh)
	}

	if address := subscription.Address; address != nil {
		s.addressSubscribers.Remove(address)
	}

	if labels := subscription.Labels; labels != nil {
		s.labelSubscribers.Remove(labels)
	}

	if messages := subscription.Messages; messages != nil {
		s.messageSubscribers.Remove(messages)
	}

	if usedSpace := subscription.UserUsedSpace; usedSpace != nil {
		s.userUsedSpaceSubscriber.Remove(usedSpace)
	}
}
// close is deferred by the run loop; it closes the CPC and then stops the poll ticker.
func (s *Service) close() {
s.cpc.Close()
s.timer.Stop()
}
// Request payloads sent over the CPC and dispatched on by type in handleRequest.
type (
	// pauseRequest asks the run loop to pause event polling.
	pauseRequest struct{}

	// resumeRequest asks the run loop to resume event polling.
	resumeRequest struct{}

	// isPausedRequest asks the run loop whether polling is currently paused.
	isPausedRequest struct{}
)