diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go index 8640c135..8e26beb4 100644 --- a/internal/services/userevents/service.go +++ b/internal/services/userevents/service.go @@ -33,9 +33,8 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal" "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/services/orderedtasks" - "github.com/bradenaw/juniper/xslices" + "github.com/bradenaw/juniper/xmaps" "github.com/sirupsen/logrus" - "golang.org/x/exp/slices" ) // Service polls from the given event source and ensures that all the respective subscribers get notified @@ -65,9 +64,8 @@ type Service struct { refreshSubscribers refreshSubscriberList userUsedSpaceSubscriber userUsedSpaceSubscriberList - pendingSubscriptionsLock sync.Mutex - pendingSubscriptionsAdd []Subscription - pendingSubscriptionsRemove []Subscription + pendingSubscriptionsLock sync.Mutex + pendingSubscriptions []pendingSubscription } func NewService( @@ -153,7 +151,7 @@ func (s *Service) Subscribe(subscription Subscription) { s.pendingSubscriptionsLock.Lock() defer s.pendingSubscriptionsLock.Unlock() - s.pendingSubscriptionsAdd = append(s.pendingSubscriptionsAdd, subscription) + s.pendingSubscriptions = append(s.pendingSubscriptions, pendingSubscription{op: pendingOpAdd, sub: subscription}) } // Unsubscribe removes subscribers from the service. @@ -164,7 +162,7 @@ func (s *Service) Unsubscribe(subscription Subscription) { s.pendingSubscriptionsLock.Lock() defer s.pendingSubscriptionsLock.Unlock() - s.pendingSubscriptionsRemove = append(s.pendingSubscriptionsRemove, subscription) + s.pendingSubscriptions = append(s.pendingSubscriptions, pendingSubscription{op: pendingOpRemove, sub: subscription}) } // Pause pauses the event polling. @@ -230,16 +228,15 @@ func (s *Service) run(ctx context.Context, lastEventID string) { s.pendingSubscriptionsLock.Lock() defer s.pendingSubscriptionsLock.Unlock() - for _, subscription := range s.pendingSubscriptionsRemove { - s.removeSubscription(subscription) + for _, p := range s.pendingSubscriptions { + if p.op == pendingOpAdd { + s.addSubscription(p.sub) + } else { + s.removeSubscription(p.sub) + } } - for _, subscription := range s.pendingSubscriptionsAdd { - s.addSubscription(subscription) - } - - s.pendingSubscriptionsRemove = nil - s.pendingSubscriptionsAdd = nil + s.pendingSubscriptions = nil }() newEvents, _, err := s.eventSource.GetEvent(ctx, lastEventID) @@ -290,21 +287,22 @@ func (s *Service) Close() { s.pendingSubscriptionsLock.Lock() defer s.pendingSubscriptionsLock.Unlock() + processed := xmaps.Set[Subscription]{} + // Cleanup pending removes. - for _, subscription := range s.pendingSubscriptionsRemove { - subscription.close() + for _, s := range s.pendingSubscriptions { + if s.op == pendingOpRemove { + if !processed.Contains(s.sub) { + s.sub.close() + } + } else { + s.sub.cancel() + s.sub.close() + processed.Add(s.sub) + } } - // Cleanup pending adds. - for _, subscription := range xslices.Filter(s.pendingSubscriptionsAdd, func(sub Subscription) bool { - return !slices.Contains(s.pendingSubscriptionsRemove, sub) - }) { - subscription.cancel() - subscription.close() - } - - s.pendingSubscriptionsRemove = nil - s.pendingSubscriptionsAdd = nil + s.pendingSubscriptions = nil } func (s *Service) handleEvent(ctx context.Context, lastEventID string, event proton.Event) error { @@ -493,3 +491,15 @@ func (s *Service) removeSubscription(subscription Subscription) { func (s *Service) close() { s.timer.Stop() } + +type pendingOp int + +const ( + pendingOpAdd pendingOp = iota + pendingOpRemove +) + +type pendingSubscription struct { + op pendingOp + sub Subscription +}