diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go index dc1bb9c6..52e6d83d 100644 --- a/internal/services/userevents/service.go +++ b/internal/services/userevents/service.go @@ -338,7 +338,7 @@ func (s *Service) handleEvent(ctx context.Context, lastEventID string, event pro s.log.Info("Received refresh event") } - return s.subscriberList.PublishParallel(ctx, event, s.panicHandler, s.eventTimeout) + return s.subscriberList.PublishParallel(ctx, event, s.panicHandler) } func unpackPublisherError(err error) (string, error) { diff --git a/internal/services/userevents/service_handle_event_test.go b/internal/services/userevents/service_handle_event_test.go index 4e023d26..318fbb2d 100644 --- a/internal/services/userevents/service_handle_event_test.go +++ b/internal/services/userevents/service_handle_event_test.go @@ -182,56 +182,3 @@ func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) { require.True(t, errors.As(err, &publisherErr)) require.Equal(t, publisherErr.subscriber, subscription) } - -func TestServiceHandleEvent_SubscriberTimeout(t *testing.T) { - mockCtrl := gomock.NewController(t) - - eventPublisher := mocks.NewMockEventPublisher(mockCtrl) - eventIDStore := NewInMemoryEventIDStore() - - addressHandler := NewMockAddressEventHandler(mockCtrl) - addressHandler.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil) - - addressHandler2 := NewMockAddressEventHandler(mockCtrl) - addressHandler2.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ []proton.AddressEvent) error { - timer := time.NewTimer(time.Second) - - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } - }).MaxTimes(1) - - service := NewService( - "foo", - &NullEventSource{}, - eventIDStore, - eventPublisher, - 100*time.Millisecond, - time.Millisecond, - 500*time.Millisecond, - async.NoopPanicHandler{}, - ) - - subscription := NewCallbackSubscriber("test", EventHandler{ - AddressHandler: addressHandler2, - }) - - service.addSubscription(subscription) - - service.addSubscription(NewCallbackSubscriber("test2", EventHandler{ - AddressHandler: addressHandler, - })) - - // Simulate 1st refresh. - err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}}) - require.Error(t, err) - if publisherErr := new(eventPublishError); errors.As(err, &publisherErr) { - require.Equal(t, publisherErr.subscriber, subscription) - require.True(t, errors.Is(publisherErr.error, ErrPublishTimeoutExceeded)) - } else { - require.True(t, errors.Is(err, ErrPublishTimeoutExceeded)) - } -} diff --git a/internal/services/userevents/subscriber.go b/internal/services/userevents/subscriber.go index 535435ac..3e66573d 100644 --- a/internal/services/userevents/subscriber.go +++ b/internal/services/userevents/subscriber.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" "runtime" - "time" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/go-proton-api" @@ -88,22 +87,19 @@ func (p publishError[T]) Error() string { return fmt.Sprintf("Event publish failed on (%v): %v", p.subscriber.name(), p.error.Error()) } -func (s *subscriberList[T]) Publish(ctx context.Context, event T, timeout time.Duration) error { - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - +func (s *subscriberList[T]) Publish(ctx context.Context, event T) error { for _, subscriber := range s.subscribers { if err := subscriber.handle(ctx, event); err != nil { return &publishError[T]{ subscriber: subscriber, - error: mapContextTimeoutError(err), + error: err, } } if err := ctx.Err(); err != nil { return &publishError[T]{ subscriber: subscriber, - error: mapContextTimeoutError(err), + error: err, } } } @@ -111,40 +107,28 @@ func (s *subscriberList[T]) Publish(ctx context.Context, event T, timeout time.D return nil } -func mapContextTimeoutError(err error) error { - if errors.Is(err, context.DeadlineExceeded) { - return ErrPublishTimeoutExceeded - } - - return err -} - func (s *subscriberList[T]) PublishParallel( ctx context.Context, event T, panicHandler async.PanicHandler, - timeout time.Duration, ) error { if len(s.subscribers) <= 1 { - return s.Publish(ctx, event, timeout) + return s.Publish(ctx, event) } - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - err := parallel.DoContext(ctx, runtime.NumCPU()/2, len(s.subscribers), func(ctx context.Context, index int) error { defer async.HandlePanic(panicHandler) if err := s.subscribers[index].handle(ctx, event); err != nil { return &publishError[T]{ subscriber: s.subscribers[index], - error: mapContextTimeoutError(err), + error: err, } } return nil }) - return mapContextTimeoutError(err) + return err } type ChanneledSubscriber[T any] struct {