diff --git a/internal/services/userevents/event_controller.go b/internal/services/userevents/event_controller.go new file mode 100644 index 00000000..b7fc01f2 --- /dev/null +++ b/internal/services/userevents/event_controller.go @@ -0,0 +1,23 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +type EventController interface { + Pause() + Resume() +} diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go index cd52030e..8f048136 100644 --- a/internal/services/userevents/service.go +++ b/internal/services/userevents/service.go @@ -25,13 +25,13 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/proton-bridge/v3/internal" "github.com/ProtonMail/proton-bridge/v3/internal/events" - "github.com/ProtonMail/proton-bridge/v3/pkg/cpc" "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" "golang.org/x/exp/slices" @@ -48,14 +48,13 @@ import ( // By default this service starts paused, you need to call `Service.Resume` at least one time to begin event polling. type Service struct { userID string - cpc *cpc.CPC eventSource EventSource eventIDStore EventIDStore log *logrus.Entry eventPublisher events.EventPublisher timer *time.Ticker eventTimeout time.Duration - paused bool + paused uint32 panicHandler async.PanicHandler userSubscriberList userSubscriberList @@ -81,7 +80,6 @@ func NewService( ) *Service { return &Service{ userID: userID, - cpc: cpc.NewCPC(), eventSource: eventSource, eventIDStore: store, log: logrus.WithFields(logrus.Fields{ @@ -90,7 +88,7 @@ func NewService( }), eventPublisher: eventPublisher, timer: time.NewTicker(pollPeriod), - paused: true, + paused: 1, eventTimeout: eventTimeout, panicHandler: panicHandler, } @@ -169,25 +167,18 @@ func (s *Service) Unsubscribe(subscription Subscription) { } // Pause pauses the event polling. -// DO NOT CALL THIS DURING EVENT HANDLING. -func (s *Service) Pause(ctx context.Context) error { - _, err := s.cpc.Send(ctx, &pauseRequest{}) - - return err +func (s *Service) Pause() { + atomic.StoreUint32(&s.paused, 1) } // Resume resumes the event polling. -// DO NOT CALL THIS DURING EVENT HANDLING. -func (s *Service) Resume(ctx context.Context) error { - _, err := s.cpc.Send(ctx, &resumeRequest{}) - - return err +func (s *Service) Resume() { + atomic.StoreUint32(&s.paused, 0) } // IsPaused return true if the service is paused -// DO NOT CALL THIS DURING EVENT HANDLING. -func (s *Service) IsPaused(ctx context.Context) (bool, error) { - return cpc.SendTyped[bool](ctx, s.cpc, &isPausedRequest{}) +func (s *Service) IsPaused() bool { + return atomic.LoadUint32(&s.paused) == 1 } func (s *Service) Start(ctx context.Context, group *async.Group) error { @@ -226,15 +217,8 @@ func (s *Service) run(ctx context.Context, lastEventID string) { select { case <-ctx.Done(): return - case req, ok := <-s.cpc.ReceiveCh(): - if !ok { - return - } - - s.handleRequest(ctx, req) - continue case <-s.timer.C: - if s.paused { + if s.IsPaused() { continue } } @@ -448,25 +432,10 @@ func (s *Service) handleEventError(ctx context.Context, lastEventID string, even } func (s *Service) onBadEvent(ctx context.Context, event events.UserBadEvent) { - s.paused = true + s.Pause() s.eventPublisher.PublishEvent(ctx, event) } -func (s *Service) handleRequest(ctx context.Context, request *cpc.Request) { - switch request.Value().(type) { - case *pauseRequest: - s.paused = true - request.Reply(ctx, nil, nil) - case *resumeRequest: - s.paused = false - request.Reply(ctx, nil, nil) - case *isPausedRequest: - request.Reply(ctx, s.paused, nil) - default: - s.log.Errorf("Unknown request") - } -} - func (s *Service) addSubscription(subscription Subscription) { if subscription.User != nil { s.userSubscriberList.Add(subscription.User) @@ -520,7 +489,6 @@ func (s *Service) removeSubscription(subscription Subscription) { } func (s *Service) close() { - s.cpc.Close() s.timer.Stop() } diff --git a/internal/services/userevents/service_handle_event_error_test.go b/internal/services/userevents/service_handle_event_error_test.go index 0ea206f4..28372dc3 100644 --- a/internal/services/userevents/service_handle_event_error_test.go +++ b/internal/services/userevents/service_handle_event_error_test.go @@ -68,7 +68,7 @@ func TestServiceHandleEventError_BadEventPutsServiceOnPause(t *testing.T) { eventIDStore := NewInMemoryEventIDStore() service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) - service.paused = false + service.Resume() lastEventID := "PrevEvent" event := proton.Event{EventID: "MyEvent"} @@ -83,7 +83,7 @@ func TestServiceHandleEventError_BadEventPutsServiceOnPause(t *testing.T) { })).Times(1) _, _ = service.handleEventError(context.Background(), lastEventID, event, err) - require.True(t, service.paused) + require.True(t, service.IsPaused()) } func TestServiceHandleEventError_BadEventFromPublishTimeout(t *testing.T) { diff --git a/internal/services/userevents/service_test.go b/internal/services/userevents/service_test.go index c858321e..32d67131 100644 --- a/internal/services/userevents/service_test.go +++ b/internal/services/userevents/service_test.go @@ -67,7 +67,7 @@ func TestService_EventIDLoadStore(t *testing.T) { service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{}) require.NoError(t, service.Start(context.Background(), group)) - require.NoError(t, service.Resume(context.Background())) + service.Resume() group.WaitToFinish() } @@ -113,7 +113,7 @@ func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) { service.Subscribe(Subscription{Messages: subscriber}) require.NoError(t, service.Start(context.Background(), group)) - require.NoError(t, service.Resume(context.Background())) + service.Resume() group.WaitToFinish() } @@ -160,16 +160,14 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) { }).Do(func(_ context.Context, event events.Event) { group.Once(func(_ context.Context) { // Use background context to avoid having the request cancelled - paused, err := service.IsPaused(context.Background()) - require.NoError(t, err) - require.True(t, paused) + require.True(t, service.IsPaused()) group.Cancel() }) }) service.Subscribe(Subscription{Messages: subscriber}) require.NoError(t, service.Start(context.Background(), group)) - require.NoError(t, service.Resume(context.Background())) + service.Resume() group.WaitToFinish() } @@ -216,7 +214,7 @@ func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T service.Subscribe(Subscription{Messages: subscriber}) require.NoError(t, service.Start(context.Background(), group)) - require.NoError(t, service.Resume(context.Background())) + service.Resume() group.WaitToFinish() } @@ -264,6 +262,6 @@ func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T service.Subscribe(Subscription{Messages: subscriber}) require.NoError(t, service.Start(context.Background(), group)) - require.NoError(t, service.Resume(context.Background())) + service.Resume() group.WaitToFinish() }