mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-11 13:16:53 +00:00
This patch adds the User Event Service, which is meant to replace the current event polling flow. Each user interested in receiving events should register a new subscriber using the `Service.Subscribe` function and then react to the incoming events. The current patch does not hook this up to the Bridge user, as there are no existing consumers, but it does provide extensive testing for the expected behavior.
487 lines
14 KiB
Go
487 lines
14 KiB
Go
// Copyright (c) 2023 Proton AG
|
|
//
|
|
// This file is part of Proton Mail Bridge.
|
|
//
|
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package userevents
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ProtonMail/gluon/async"
|
|
"github.com/ProtonMail/go-proton-api"
|
|
"github.com/ProtonMail/proton-bridge/v3/internal"
|
|
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
|
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Service polls from the given event source and ensures that all the respective subscribers get notified
// before proceeding to the next event. The events are published in the following order:
// * Refresh
// * User
// * Address
// * Label
// * Message
// * UserUsedSpace
// An event carrying a mail refresh is only delivered to the refresh subscribers; the remaining
// categories of that event are not published.
// By default this service starts paused, you need to call `Service.Resume` at least one time to begin event polling.
type Service struct {
	// userID is the user whose events are polled; it is attached to the bridge events this service publishes.
	userID string
	// cpc carries Pause/Resume/IsPaused requests from callers to the polling loop.
	cpc *cpc.CPC
	// eventSource provides the latest event ID and the events following a given ID.
	eventSource EventSource
	// eventIDStore persists the ID of the last event that was applied successfully.
	eventIDStore EventIDStore
	log          *logrus.Entry
	// eventPublisher receives bridge-level notifications such as UserBadEvent.
	eventPublisher events.EventPublisher
	// timer drives the polling period (note: it is a *time.Ticker, not a *time.Timer).
	timer *time.Ticker
	// eventTimeout is passed to every subscriber publish call.
	eventTimeout time.Duration
	// paused is read and written only from the polling loop (handleRequest and onBadEvent run there),
	// so it needs no lock; other goroutines go through Pause/Resume/IsPaused.
	paused bool
	// panicHandler is forwarded to the parallel publishers.
	panicHandler async.PanicHandler

	// Subscriber lists, one per event category. They are only mutated by the polling loop when it
	// applies the pending subscription changes below.
	userSubscriberList      userSubscriberList
	addressSubscribers      addressSubscriberList
	labelSubscribers        labelSubscriberList
	messageSubscribers      messageSubscriberList
	refreshSubscribers      refreshSubscriberList
	userUsedSpaceSubscriber userUsedSpaceSubscriberList

	// pendingSubscriptionsLock guards pendingSubscriptionsAdd and pendingSubscriptionsRemove, which queue
	// Subscribe/Unsubscribe calls until the polling loop applies them before its next fetch.
	pendingSubscriptionsLock   sync.Mutex
	pendingSubscriptionsAdd    []Subscription
	pendingSubscriptionsRemove []Subscription
}
|
|
|
|
func NewService(
|
|
userID string,
|
|
eventSource EventSource,
|
|
store EventIDStore,
|
|
eventPublisher events.EventPublisher,
|
|
pollPeriod time.Duration,
|
|
eventTimeout time.Duration,
|
|
panicHandler async.PanicHandler,
|
|
) *Service {
|
|
return &Service{
|
|
userID: userID,
|
|
cpc: cpc.NewCPC(),
|
|
eventSource: eventSource,
|
|
eventIDStore: store,
|
|
log: logrus.WithFields(logrus.Fields{
|
|
"service": "user-events",
|
|
"user": userID,
|
|
}),
|
|
eventPublisher: eventPublisher,
|
|
timer: time.NewTicker(pollPeriod),
|
|
paused: true,
|
|
eventTimeout: eventTimeout,
|
|
panicHandler: panicHandler,
|
|
}
|
|
}
|
|
|
|
// Subscription groups the per-category subscribers a consumer registers with Service.Subscribe
// and removes with Service.Unsubscribe. Every field is optional: nil subscribers are skipped when
// the subscription is added, removed or cancelled.
type Subscription struct {
	// User receives user events.
	User UserSubscriber
	// Refresh receives refresh events.
	Refresh RefreshSubscriber
	// Address receives address events.
	Address AddressSubscriber
	// Labels receives label events.
	Labels LabelSubscriber
	// Messages receives message events.
	Messages MessageSubscriber
	// UserUsedSpace receives used-space updates.
	UserUsedSpace UserUsedSpaceSubscriber
}
|
|
|
|
// cancel subscription subscribers if applicable, see `subscriber.cancel` for more information.
|
|
func (s Subscription) cancel() {
|
|
if s.User != nil {
|
|
s.User.cancel()
|
|
}
|
|
if s.Refresh != nil {
|
|
s.Refresh.cancel()
|
|
}
|
|
if s.Address != nil {
|
|
s.Address.cancel()
|
|
}
|
|
if s.Labels != nil {
|
|
s.Labels.cancel()
|
|
}
|
|
if s.Messages != nil {
|
|
s.Messages.cancel()
|
|
}
|
|
if s.UserUsedSpace != nil {
|
|
s.UserUsedSpace.cancel()
|
|
}
|
|
}
|
|
|
|
// Subscribe adds new subscribers to the service.
|
|
// This method can safely be called during event handling.
|
|
func (s *Service) Subscribe(subscription Subscription) {
|
|
s.pendingSubscriptionsLock.Lock()
|
|
defer s.pendingSubscriptionsLock.Unlock()
|
|
|
|
s.pendingSubscriptionsAdd = append(s.pendingSubscriptionsAdd, subscription)
|
|
}
|
|
|
|
// Unsubscribe removes subscribers from the service.
|
|
// This method can safely be called during event handling.
|
|
func (s *Service) Unsubscribe(subscription Subscription) {
|
|
subscription.cancel()
|
|
|
|
s.pendingSubscriptionsLock.Lock()
|
|
defer s.pendingSubscriptionsLock.Unlock()
|
|
|
|
s.pendingSubscriptionsRemove = append(s.pendingSubscriptionsRemove, subscription)
|
|
}
|
|
|
|
// Pause pauses the event polling.
|
|
// DO NOT CALL THIS DURING EVENT HANDLING.
|
|
func (s *Service) Pause(ctx context.Context) error {
|
|
_, err := s.cpc.Send(ctx, &pauseRequest{})
|
|
|
|
return err
|
|
}
|
|
|
|
// Resume resumes the event polling.
|
|
// DO NOT CALL THIS DURING EVENT HANDLING.
|
|
func (s *Service) Resume(ctx context.Context) error {
|
|
_, err := s.cpc.Send(ctx, &resumeRequest{})
|
|
|
|
return err
|
|
}
|
|
|
|
// IsPaused return true if the service is paused
|
|
// DO NOT CALL THIS DURING EVENT HANDLING.
|
|
func (s *Service) IsPaused(ctx context.Context) (bool, error) {
|
|
return cpc.SendTyped[bool](ctx, s.cpc, &isPausedRequest{})
|
|
}
|
|
|
|
func (s *Service) Start(ctx context.Context, group *async.Group) error {
|
|
lastEventID, err := s.eventIDStore.Load(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load last event id: %w", err)
|
|
}
|
|
|
|
if lastEventID == "" {
|
|
s.log.Debugf("No event ID present in storage, retrieving latest")
|
|
eventID, err := s.eventSource.GetLatestEventID(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get latest event id: %w", err)
|
|
}
|
|
|
|
if err := s.eventIDStore.Store(ctx, eventID); err != nil {
|
|
return fmt.Errorf("failed to store event in event id store: %v", err)
|
|
}
|
|
|
|
lastEventID = eventID
|
|
}
|
|
|
|
group.Once(func(ctx context.Context) {
|
|
s.run(ctx, lastEventID)
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) run(ctx context.Context, lastEventID string) {
|
|
s.log.Debugf("Starting service Last EventID=%v", lastEventID)
|
|
defer s.close()
|
|
defer s.log.Debug("Exiting service")
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case req, ok := <-s.cpc.ReceiveCh():
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
s.handleRequest(ctx, req)
|
|
continue
|
|
case <-s.timer.C:
|
|
if s.paused {
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Apply any pending subscription changes.
|
|
func() {
|
|
s.pendingSubscriptionsLock.Lock()
|
|
defer s.pendingSubscriptionsLock.Unlock()
|
|
|
|
for _, subscription := range s.pendingSubscriptionsRemove {
|
|
s.removeSubscription(subscription)
|
|
}
|
|
|
|
for _, subscription := range s.pendingSubscriptionsAdd {
|
|
s.addSubscription(subscription)
|
|
}
|
|
|
|
s.pendingSubscriptionsRemove = nil
|
|
s.pendingSubscriptionsAdd = nil
|
|
}()
|
|
|
|
newEvents, _, err := s.eventSource.GetEvent(ctx, lastEventID)
|
|
if err != nil {
|
|
s.log.WithError(err).Errorf("Failed to get event (caused by %T)", internal.ErrCause(err))
|
|
continue
|
|
}
|
|
|
|
// If the event ID hasn't changed, there are no new events.
|
|
if newEvents[len(newEvents)-1].EventID == lastEventID {
|
|
s.log.Debugf("No new API Events")
|
|
continue
|
|
}
|
|
|
|
if event, eventErr := func() (proton.Event, error) {
|
|
for _, event := range newEvents {
|
|
if err := s.handleEvent(ctx, lastEventID, event); err != nil {
|
|
return event, err
|
|
}
|
|
}
|
|
|
|
return proton.Event{}, nil
|
|
}(); eventErr != nil {
|
|
subscriberName, err := s.handleEventError(ctx, lastEventID, event, eventErr)
|
|
if subscriberName == "" {
|
|
subscriberName = "?"
|
|
}
|
|
s.log.WithField("subscriber", subscriberName).WithError(err).Errorf("Failed to apply event")
|
|
continue
|
|
}
|
|
|
|
newEventID := newEvents[len(newEvents)-1].EventID
|
|
if err := s.eventIDStore.Store(ctx, newEventID); err != nil {
|
|
s.log.WithError(err).Errorf("Failed to store new event ID: %v", err)
|
|
s.onBadEvent(ctx, events.UserBadEvent{
|
|
Error: fmt.Errorf("failed to store new event ID: %w", err),
|
|
UserID: s.userID,
|
|
})
|
|
continue
|
|
}
|
|
|
|
lastEventID = newEventID
|
|
}
|
|
}
|
|
|
|
func (s *Service) handleEvent(ctx context.Context, lastEventID string, event proton.Event) error {
|
|
s.log.WithFields(logrus.Fields{
|
|
"old": lastEventID,
|
|
"new": event,
|
|
}).Info("Received new API event")
|
|
|
|
if event.Refresh&proton.RefreshMail != 0 {
|
|
s.log.Info("Handling refresh event")
|
|
if err := s.refreshSubscribers.Publish(ctx, event.Refresh, s.eventTimeout); err != nil {
|
|
return fmt.Errorf("failed to apply refresh event: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Start with user events.
|
|
if event.User != nil {
|
|
if err := s.userSubscriberList.PublishParallel(ctx, *event.User, s.panicHandler, s.eventTimeout); err != nil {
|
|
return fmt.Errorf("failed to apply user event: %w", err)
|
|
}
|
|
}
|
|
|
|
// Next Address events
|
|
if err := s.addressSubscribers.PublishParallel(ctx, event.Addresses, s.panicHandler, s.eventTimeout); err != nil {
|
|
return fmt.Errorf("failed to apply address events: %w", err)
|
|
}
|
|
|
|
// Next label events
|
|
if err := s.labelSubscribers.PublishParallel(ctx, event.Labels, s.panicHandler, s.eventTimeout); err != nil {
|
|
return fmt.Errorf("failed to apply label events: %w", err)
|
|
}
|
|
|
|
// Next message events
|
|
if err := s.messageSubscribers.PublishParallel(ctx, event.Messages, s.panicHandler, s.eventTimeout); err != nil {
|
|
return fmt.Errorf("failed to apply message events: %w", err)
|
|
}
|
|
|
|
// Finally user used space events
|
|
if event.UsedSpace != nil {
|
|
if err := s.userUsedSpaceSubscriber.PublishParallel(ctx, *event.UsedSpace, s.panicHandler, s.eventTimeout); err != nil {
|
|
return fmt.Errorf("failed to apply message events: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func unpackPublisherError(err error) (string, error) {
|
|
var addressErr *addressPublishError
|
|
var labelErr *labelPublishError
|
|
var messageErr *messagePublishError
|
|
var refreshErr *refreshPublishError
|
|
var userErr *userPublishError
|
|
var usedSpaceErr *userUsedEventPublishError
|
|
|
|
switch {
|
|
case errors.As(err, &userErr):
|
|
return userErr.subscriber.name(), userErr.error
|
|
case errors.As(err, &addressErr):
|
|
return addressErr.subscriber.name(), addressErr.error
|
|
case errors.As(err, &labelErr):
|
|
return labelErr.subscriber.name(), labelErr.error
|
|
case errors.As(err, &messageErr):
|
|
return messageErr.subscriber.name(), messageErr.error
|
|
case errors.As(err, &refreshErr):
|
|
return refreshErr.subscriber.name(), refreshErr.error
|
|
case errors.As(err, &usedSpaceErr):
|
|
return usedSpaceErr.subscriber.name(), usedSpaceErr.error
|
|
default:
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
func (s *Service) handleEventError(ctx context.Context, lastEventID string, event proton.Event, err error) (string, error) {
|
|
// Unpack the error so we can proceed to handle the real issue.
|
|
subscriberName, err := unpackPublisherError(err)
|
|
|
|
// If the error is a context cancellation, return error to retry later.
|
|
if errors.Is(err, context.Canceled) {
|
|
return subscriberName, fmt.Errorf("failed to handle event due to context cancellation: %w", err)
|
|
}
|
|
|
|
// If the error is a network error, return error to retry later.
|
|
if netErr := new(proton.NetError); errors.As(err, &netErr) {
|
|
return subscriberName, fmt.Errorf("failed to handle event due to network issue: %w", err)
|
|
}
|
|
|
|
// Catch all for uncategorized net errors that may slip through.
|
|
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
|
return subscriberName, fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
|
|
}
|
|
|
|
// In case a json decode error slips through.
|
|
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
|
|
s.eventPublisher.PublishEvent(ctx, events.UncategorizedEventError{
|
|
UserID: s.userID,
|
|
Error: err,
|
|
})
|
|
|
|
return subscriberName, fmt.Errorf("failed to handle event due to JSON issue: %w", err)
|
|
}
|
|
|
|
// If the error is an unexpected EOF, return error to retry later.
|
|
if errors.Is(err, io.ErrUnexpectedEOF) {
|
|
return subscriberName, fmt.Errorf("failed to handle event due to EOF: %w", err)
|
|
}
|
|
|
|
// If the error is a server-side issue, return error to retry later.
|
|
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
|
|
return subscriberName, fmt.Errorf("failed to handle event due to server error: %w", err)
|
|
}
|
|
|
|
// Otherwise, the error is a client-side issue; notify bridge to handle it.
|
|
s.log.WithField("event", event).Warn("Failed to handle API event")
|
|
|
|
s.onBadEvent(ctx, events.UserBadEvent{
|
|
UserID: s.userID,
|
|
OldEventID: lastEventID,
|
|
NewEventID: event.EventID,
|
|
EventInfo: event.String(),
|
|
Error: err,
|
|
})
|
|
|
|
return subscriberName, fmt.Errorf("failed to handle event due to client error: %w", err)
|
|
}
|
|
|
|
func (s *Service) onBadEvent(ctx context.Context, event events.UserBadEvent) {
|
|
s.paused = true
|
|
s.eventPublisher.PublishEvent(ctx, event)
|
|
}
|
|
|
|
func (s *Service) handleRequest(ctx context.Context, request *cpc.Request) {
|
|
switch request.Value().(type) {
|
|
case *pauseRequest:
|
|
s.paused = true
|
|
request.Reply(ctx, nil, nil)
|
|
case *resumeRequest:
|
|
s.paused = false
|
|
request.Reply(ctx, nil, nil)
|
|
case *isPausedRequest:
|
|
request.Reply(ctx, s.paused, nil)
|
|
default:
|
|
s.log.Errorf("Unknown request")
|
|
}
|
|
}
|
|
|
|
func (s *Service) addSubscription(subscription Subscription) {
|
|
if subscription.User != nil {
|
|
s.userSubscriberList.Add(subscription.User)
|
|
}
|
|
|
|
if subscription.Refresh != nil {
|
|
s.refreshSubscribers.Add(subscription.Refresh)
|
|
}
|
|
|
|
if subscription.Address != nil {
|
|
s.addressSubscribers.Add(subscription.Address)
|
|
}
|
|
|
|
if subscription.Labels != nil {
|
|
s.labelSubscribers.Add(subscription.Labels)
|
|
}
|
|
|
|
if subscription.Messages != nil {
|
|
s.messageSubscribers.Add(subscription.Messages)
|
|
}
|
|
|
|
if subscription.UserUsedSpace != nil {
|
|
s.userUsedSpaceSubscriber.Add(subscription.UserUsedSpace)
|
|
}
|
|
}
|
|
|
|
func (s *Service) removeSubscription(subscription Subscription) {
|
|
if subscription.User != nil {
|
|
s.userSubscriberList.Remove(subscription.User)
|
|
}
|
|
|
|
if subscription.Refresh != nil {
|
|
s.refreshSubscribers.Remove(subscription.Refresh)
|
|
}
|
|
|
|
if subscription.Address != nil {
|
|
s.addressSubscribers.Remove(subscription.Address)
|
|
}
|
|
|
|
if subscription.Labels != nil {
|
|
s.labelSubscribers.Remove(subscription.Labels)
|
|
}
|
|
|
|
if subscription.Messages != nil {
|
|
s.messageSubscribers.Remove(subscription.Messages)
|
|
}
|
|
|
|
if subscription.UserUsedSpace != nil {
|
|
s.userUsedSpaceSubscriber.Remove(subscription.UserUsedSpace)
|
|
}
|
|
}
|
|
|
|
// close releases the resources owned by the service: it closes the cpc request channel and
// then stops the poll ticker. run defers it, so it executes when the polling loop exits.
func (s *Service) close() {
	s.cpc.Close()
	s.timer.Stop()
}
|
|
|
|
// Marker values sent over the cpc channel to the polling loop; handleRequest dispatches on their type.
type (
	// pauseRequest is sent by Pause.
	pauseRequest struct{}
	// resumeRequest is sent by Resume.
	resumeRequest struct{}
	// isPausedRequest is sent by IsPaused; the reply carries the current paused flag.
	isPausedRequest struct{}
)
|