diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 30ce9633..8f44ee9a 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -487,27 +487,15 @@ func (bridge *Bridge) remWatcher(watcher *watcher.Watcher[events.Event]) { watcher.Close() } -func (bridge *Bridge) onStatusUp(ctx context.Context) { +func (bridge *Bridge) onStatusUp(_ context.Context) { logrus.Info("Handling API status up") - safe.RLock(func() { - for _, user := range bridge.users { - user.OnStatusUp(ctx) - } - }, bridge.usersLock) - bridge.goLoad() } func (bridge *Bridge) onStatusDown(ctx context.Context) { logrus.Info("Handling API status down") - safe.RLock(func() { - for _, user := range bridge.users { - user.OnStatusDown(ctx) - } - }, bridge.usersLock) - for backoff := time.Second; ; backoff = min(backoff*2, 30*time.Second) { select { case <-ctx.Done(): diff --git a/internal/bridge/config_status.go b/internal/bridge/config_status.go index 755e1627..a51fada2 100644 --- a/internal/bridge/config_status.go +++ b/internal/bridge/config_status.go @@ -22,7 +22,7 @@ import ( ) func (bridge *Bridge) ReportBugClicked() { - safe.Lock(func() { + safe.RLock(func() { for _, user := range bridge.users { user.ReportBugClicked() } @@ -30,7 +30,7 @@ func (bridge *Bridge) ReportBugClicked() { } func (bridge *Bridge) AutoconfigUsed(client string) { - safe.Lock(func() { + safe.RLock(func() { for _, user := range bridge.users { user.AutoconfigUsed(client) } @@ -38,7 +38,7 @@ func (bridge *Bridge) AutoconfigUsed(client string) { } func (bridge *Bridge) KBArticleOpened(article string) { - safe.Lock(func() { + safe.RLock(func() { for _, user := range bridge.users { user.KBArticleOpened(article) } diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 05497f39..ff4ccbbc 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -329,7 +329,7 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va func (bridge *Bridge) SendBadEventUserFeedback(_ context.Context, userID string, doResync bool) error { logrus.WithField("userID", userID).WithField("doResync", doResync).Info("Passing bad event feedback to user") - return safe.LockRet(func() error { + return safe.RLockRet(func() error { ctx := context.Background() user, ok := bridge.users[userID] diff --git a/internal/bridge/user_events.go b/internal/bridge/user_events.go index 8c288a02..0dda3c34 100644 --- a/internal/bridge/user_events.go +++ b/internal/bridge/user_events.go @@ -49,7 +49,7 @@ func (bridge *Bridge) handleUserDeauth(ctx context.Context, user *user.User) { } func (bridge *Bridge) handleUserBadEvent(ctx context.Context, user *user.User, event events.UserBadEvent) { - safe.Lock(func() { + safe.RLock(func() { if rerr := bridge.reporter.ReportMessageWithContext("Failed to handle event", reporter.Context{ "user_id": user.ID(), "old_event_id": event.OldEventID, diff --git a/internal/services/imapservice/service.go b/internal/services/imapservice/service.go index 79e4b9f0..fc178d95 100644 --- a/internal/services/imapservice/service.go +++ b/internal/services/imapservice/service.go @@ -151,7 +151,7 @@ func NewService( connectors: make(map[string]*Connector), maxSyncMemory: maxSyncMemory, - eventWatcher: subscription.Add(events.IMAPServerCreated{}), + eventWatcher: subscription.Add(events.IMAPServerCreated{}, events.ConnStatusUp{}, events.ConnStatusDown{}), eventSubscription: subscription, showAllMail: showAllMail, @@ -217,18 +217,6 @@ func (s *Service) Resync(ctx context.Context) error { return err } -func (s *Service) CancelSync(ctx context.Context) error { - _, err := s.cpc.Send(ctx, &cancelSyncReq{}) - - return err -} - -func (s *Service) ResumeSync(ctx context.Context) error { - _, err := s.cpc.Send(ctx, &resumeSyncReq{}) - - return err -} - func (s *Service) OnBadEvent(ctx context.Context) error { _, err := s.cpc.Send(ctx, &onBadEventReq{}) @@ -341,6 +329,7 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo } switch r := req.Value().(type) { case *setAddressModeReq: + s.log.Debug("Set Address Mode Request") err := s.setAddressMode(ctx, r.mode) req.Reply(ctx, nil, err) @@ -350,38 +339,33 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo req.Reply(ctx, nil, err) s.log.Info("Resync reply sent, handling as refresh event") - case *cancelSyncReq: - s.log.Info("Cancelling sync") - s.syncHandler.Cancel() - req.Reply(ctx, nil, nil) - - case *resumeSyncReq: - s.log.Info("Resuming sync") - // Cancel previous run, if any, just in case. - s.cancelSync() - s.startSyncing() - req.Reply(ctx, nil, nil) case *getLabelsReq: + s.log.Debug("Get labels Request") labels := s.labels.GetLabelMap() req.Reply(ctx, labels, nil) case *onBadEventReq: + s.log.Debug("Bad Event Request") err := s.removeConnectorsFromServer(ctx, s.connectors, false) req.Reply(ctx, nil, err) case *onBadEventResyncReq: + s.log.Debug("Bad Event Resync Request") err := s.addConnectorsToServer(ctx, s.connectors) req.Reply(ctx, nil, err) case *onLogoutReq: + s.log.Debug("Logout Request") err := s.removeConnectorsFromServer(ctx, s.connectors, false) req.Reply(ctx, nil, err) case *showAllMailReq: + s.log.Debug("Show all mail request") req.Reply(ctx, nil, nil) s.setShowAllMail(r.v) case *getSyncFailedMessagesReq: + s.log.Debug("Get sync failed messages Request") status, err := s.syncStateProvider.GetSyncStatus(ctx) if err != nil { req.Reply(ctx, nil, fmt.Errorf("failed to get sync status: %w", err)) @@ -470,10 +454,21 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo continue } - if _, ok := e.(events.IMAPServerCreated); ok { + switch e.(type) { + case events.IMAPServerCreated: + s.log.Debug("On IMAPServerCreated") if err := s.addConnectorsToServer(ctx, s.connectors); err != nil { s.log.WithError(err).Error("Failed to add connector to server after created") } + case events.ConnStatusUp: + s.log.Info("Connection Restored Resuming Sync (if any)") + // Cancel previous run, if any, just in case. + s.cancelSync() + s.startSyncing() + + case events.ConnStatusDown: + s.log.Info("Connection Lost cancelling sync") + s.cancelSync() } } } @@ -626,10 +621,6 @@ func (s *Service) cancelSync() { type resyncReq struct{} -type cancelSyncReq struct{} - -type resumeSyncReq struct{} - type getLabelsReq struct{} type onBadEventReq struct{} diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go index 104dfbe8..bd81f7d0 100644 --- a/internal/services/userevents/service.go +++ b/internal/services/userevents/service.go @@ -29,6 +29,7 @@ import ( "time" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/watcher" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/proton-bridge/v3/internal" "github.com/ProtonMail/proton-bridge/v3/internal/events" @@ -67,6 +68,8 @@ type Service struct { eventPollWaiters []*EventPollWaiter eventPollWaitersLock sync.Mutex + eventSubscription events.Subscription + eventWatcher *watcher.Watcher[events.Event] } func NewService( @@ -78,6 +81,7 @@ func NewService( jitter time.Duration, eventTimeout time.Duration, panicHandler async.PanicHandler, + eventSubscription events.Subscription, ) *Service { return &Service{ cpc: cpc.NewCPC(), @@ -88,11 +92,13 @@ func NewService( "service": "user-events", "user": userID, }), - eventPublisher: eventPublisher, - timer: proton.NewTicker(pollPeriod, jitter, panicHandler), - paused: 1, - eventTimeout: eventTimeout, - panicHandler: panicHandler, + eventPublisher: eventPublisher, + timer: proton.NewTicker(pollPeriod, jitter, panicHandler), + paused: 1, + eventTimeout: eventTimeout, + panicHandler: panicHandler, + eventSubscription: eventSubscription, + eventWatcher: eventSubscription.Add(events.ConnStatusDown{}, events.ConnStatusUp{}), } } @@ -224,6 +230,19 @@ func (s *Service) run(ctx context.Context, lastEventID string) { } continue + case e, ok := <-s.eventWatcher.GetChannel(): + if !ok { + continue + } + + switch e.(type) { + case events.ConnStatusDown: + s.log.Info("Connection Lost, pausing") + s.Pause() + case events.ConnStatusUp: + s.log.Info("Connection Restored, resuming") + s.Resume() + } } // Apply any pending subscription changes. @@ -295,6 +314,11 @@ func (s *Service) run(ctx context.Context, lastEventID string) { // Close should be called after the service has been cancelled to clean up any remaining pending operations. func (s *Service) Close() { + if s.eventSubscription != nil { + s.eventSubscription.Remove(s.eventWatcher) + s.eventSubscription = nil + } + s.pendingSubscriptionsLock.Lock() defer s.pendingSubscriptionsLock.Unlock() diff --git a/internal/services/userevents/service_handle_event_error_test.go b/internal/services/userevents/service_handle_event_error_test.go index 0f3ddefa..209f48ad 100644 --- a/internal/services/userevents/service_handle_event_error_test.go +++ b/internal/services/userevents/service_handle_event_error_test.go @@ -48,6 +48,7 @@ func TestServiceHandleEventError_SubscriberEventUnwrapping(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) lastEventID := "PrevEvent" @@ -85,6 +86,7 @@ func TestServiceHandleEventError_BadEventPutsServiceOnPause(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) service.Resume() lastEventID := "PrevEvent" @@ -118,6 +120,7 @@ func TestServiceHandleEventError_BadEventFromPublishTimeout(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) lastEventID := "PrevEvent" event := proton.Event{EventID: "MyEvent"} @@ -148,6 +151,7 @@ func TestServiceHandleEventError_NoBadEventCheck(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) lastEventID := "PrevEvent" event := proton.Event{EventID: "MyEvent"} @@ -173,6 +177,7 @@ func TestServiceHandleEventError_JsonUnmarshalEventProducesUncategorizedErrorEve time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) lastEventID := "PrevEvent" event := proton.Event{EventID: "MyEvent"} diff --git a/internal/services/userevents/service_handle_event_test.go b/internal/services/userevents/service_handle_event_test.go index 902ba031..35b2be68 100644 --- a/internal/services/userevents/service_handle_event_test.go +++ b/internal/services/userevents/service_handle_event_test.go @@ -26,6 +26,7 @@ import ( "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -67,6 +68,7 @@ func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) { time.Millisecond, 10*time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) subscription := NewCallbackSubscriber("test", EventHandler{ @@ -127,6 +129,7 @@ func TestServiceHandleEvent_CheckEventFailureCausesError(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) subscription := NewCallbackSubscriber("test", EventHandler{ @@ -164,6 +167,7 @@ func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) subscription := NewCallbackSubscriber("test", EventHandler{ diff --git a/internal/services/userevents/service_test.go b/internal/services/userevents/service_test.go index 2948125c..e7a2cbcf 100644 --- a/internal/services/userevents/service_test.go +++ b/internal/services/userevents/service_test.go @@ -75,6 +75,7 @@ func TestService_EventIDLoadStore(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) _, err := service.Start(context.Background(), group) @@ -130,6 +131,7 @@ func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) service.Subscribe(NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber})) @@ -179,6 +181,7 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) // Event publisher expectations. @@ -245,6 +248,7 @@ func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) subscription := NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber}) @@ -304,6 +308,7 @@ func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) subscription := NewEventSubscriber("Foo") @@ -363,6 +368,7 @@ func TestService_WaitOnEventPublishAfterPause(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) subscriber.EXPECT().HandleMessageEvents(gomock.Any(), gomock.Eq(messageEvents)).Times(1).DoAndReturn(func(_ context.Context, _ []proton.MessageEvent) error { @@ -435,6 +441,7 @@ func TestService_EventRewind(t *testing.T) { time.Millisecond, time.Second, async.NoopPanicHandler{}, + events.NewNullSubscription(), ) _, err := service.Start(context.Background(), group) diff --git a/internal/user/user.go b/internal/user/user.go index 37f29486..bb140ec4 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -223,6 +223,7 @@ func newImpl( EventJitter, 5*time.Minute, crashHandler, + eventSubscription, ) addressMode := usertypes.VaultToAddressMode(encVault.AddressMode()) @@ -554,27 +555,6 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) { return user.identityService.CheckAuth(ctx, email, password) } -// OnStatusUp is called when the connection goes up. -func (user *User) OnStatusUp(ctx context.Context) { - user.log.Info("Connection is up") - - user.eventService.Resume() - - if err := user.imapService.ResumeSync(ctx); err != nil { - user.log.WithError(err).Error("Failed to resume sync") - } -} - -// OnStatusDown is called when the connection goes down. -func (user *User) OnStatusDown(ctx context.Context) { - user.log.Info("Connection is down") - - user.eventService.Pause() - if err := user.imapService.CancelSync(ctx); err != nil { - user.log.WithError(err).Error("Failed to cancel sync") - } -} - // Logout logs the user out from the API. func (user *User) Logout(ctx context.Context, withAPI bool) error { user.log.WithField("withAPI", withAPI).Info("Logging out user")