diff --git a/Makefile b/Makefile index 15b69e4a..5addbd4e 100644 --- a/Makefile +++ b/Makefile @@ -284,7 +284,7 @@ mocks: mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \ EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go mockgen --package userevents github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \ -MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber > tmp +EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,RefreshEventHandler,UserEventHandler,UserUsedSpaceEventHandler > tmp mv tmp internal/services/userevents/mocks_test.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/events EventPublisher \ > internal/events/mocks/mocks.go diff --git a/internal/services/imapservice/service.go b/internal/services/imapservice/service.go index 43fd368e..0b14f132 100644 --- a/internal/services/imapservice/service.go +++ b/internal/services/imapservice/service.go @@ -73,11 +73,7 @@ type Service struct { labels *rwLabels addressMode usertypes.AddressMode - refreshSubscriber *userevents.RefreshChanneledSubscriber - addressSubscriber *userevents.AddressChanneledSubscriber - userSubscriber *userevents.UserChanneledSubscriber - messageSubscriber *userevents.MessageChanneledSubscriber - labelSubscriber *userevents.LabelChanneledSubscriber + subscription *userevents.EventChanneledSubscriber gluonIDProvider GluonIDProvider syncStateProvider SyncStateProvider @@ -95,6 +91,7 @@ type Service struct { connectors map[string]*Connector maxSyncMemory uint64 showAllMail bool + syncHandler *syncHandler } func NewService( @@ -135,11 +132,7 @@ func NewService( eventProvider: eventProvider, eventPublisher: eventPublisher, - refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName), - addressSubscriber: userevents.NewAddressSubscriber(subscriberName), - userSubscriber: userevents.NewUserSubscriber(subscriberName), - messageSubscriber: userevents.NewMessageSubscriber(subscriberName), - labelSubscriber: userevents.NewLabelSubscriber(subscriberName), + subscription: userevents.NewEventSubscriber(subscriberName), panicHandler: panicHandler, sendRecorder: sendRecorder, @@ -241,6 +234,44 @@ func (s *Service) Close() { s.connectors = make(map[string]*Connector) } +func (s *Service) HandleRefreshEvent(ctx context.Context, _ proton.RefreshFlag) error { + s.log.Debug("handling refresh event") + + if err := s.identityState.Write(func(identity *useridentity.State) error { + return identity.OnRefreshEvent(ctx) + }); err != nil { + s.log.WithError(err).Error("Failed to apply refresh event to identity state") + return err + } + + s.syncHandler.CancelAndWait() + + if err := s.removeConnectorsFromServer(ctx, s.connectors, true); err != nil { + return err + } + + if err := s.syncStateProvider.ClearSyncStatus(); err != nil { + return fmt.Errorf("failed to clear sync status:%w", err) + } + + if err := s.addConnectorsToServer(ctx, s.connectors); err != nil { + return err + } + + s.syncHandler.launch(s) + + return nil +} + +func (s *Service) HandleUserEvent(_ context.Context, user *proton.User) error { + s.log.Debug("handling user event") + + return s.identityState.Write(func(identity *useridentity.State) error { + identity.OnUserEvent(*user) + return nil + }) +} + func (s *Service) run(ctx context.Context) { //nolint gocyclo s.log.Info("Starting IMAP Service") defer s.log.Info("Exiting IMAP Service") @@ -248,21 +279,21 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo defer s.cpc.Close() defer s.eventSubscription.Remove(s.eventWatcher) - syncHandler := newSyncHandler(ctx, s.panicHandler) - defer syncHandler.Close() + s.syncHandler = newSyncHandler(ctx, s.panicHandler) + defer s.syncHandler.Close() - syncHandler.launch(s) + s.syncHandler.launch(s) - subscription := userevents.Subscription{ - User: s.userSubscriber, - Refresh: s.refreshSubscriber, - Address: s.addressSubscriber, - Labels: s.labelSubscriber, - Messages: s.messageSubscriber, + eventHandler := userevents.EventHandler{ + UserHandler: s, + AddressHandler: s, + RefreshHandler: s, + LabelHandler: s, + MessageHandler: s, } - s.eventProvider.Subscribe(subscription) - defer s.eventProvider.Unsubscribe(subscription) + s.eventProvider.Subscribe(s.subscription) + defer s.eventProvider.Unsubscribe(s.subscription) for { select { @@ -275,25 +306,25 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo } switch r := req.Value().(type) { case *setAddressModeReq: - err := s.setAddressMode(ctx, syncHandler, r.mode) + err := s.setAddressMode(ctx, s.syncHandler, r.mode) req.Reply(ctx, nil, err) case *resyncReq: s.log.Info("Received resync request, handling as refresh event") - err := s.onRefreshEvent(ctx, syncHandler) + err := s.HandleRefreshEvent(ctx, 0) req.Reply(ctx, nil, err) s.log.Info("Resync reply sent, handling as refresh event") case *cancelSyncReq: s.log.Info("Cancelling sync") - syncHandler.Cancel() + s.syncHandler.Cancel() req.Reply(ctx, nil, nil) case *resumeSyncReq: s.log.Info("Resuming sync") // Cancel previous run, if any, just in case. - syncHandler.CancelAndWait() - syncHandler.launch(s) + s.syncHandler.CancelAndWait() + s.syncHandler.launch(s) req.Reply(ctx, nil, nil) case *getLabelsReq: labels := s.labels.GetLabelMap() @@ -319,7 +350,7 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo s.log.Error("Received unknown request") } - case err, ok := <-syncHandler.OnSyncFinishedCH(): + case err, ok := <-s.syncHandler.OnSyncFinishedCH(): { if !ok { continue @@ -334,46 +365,18 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo s.eventProvider.Resume() } - case update, ok := <-syncHandler.updater.ch: + case update, ok := <-s.syncHandler.updater.ch: if !ok { continue } s.onSyncUpdate(ctx, update) - case e, ok := <-s.userSubscriber.OnEventCh(): + case e, ok := <-s.subscription.OnEventCh(): if !ok { continue } - e.Consume(func(user proton.User) error { - return s.onUserEvent(user) - }) - case e, ok := <-s.addressSubscriber.OnEventCh(): - if !ok { - continue - } - e.Consume(func(events []proton.AddressEvent) error { - return s.onAddressEvent(ctx, events) - }) - case e, ok := <-s.labelSubscriber.OnEventCh(): - if !ok { - continue - } - e.Consume(func(events []proton.LabelEvent) error { - return s.onLabelEvent(ctx, events) - }) - case e, ok := <-s.messageSubscriber.OnEventCh(): - if !ok { - continue - } - e.Consume(func(events []proton.MessageEvent) error { - return s.onMessageEvent(ctx, events) - }) - case e, ok := <-s.refreshSubscriber.OnEventCh(): - if !ok { - continue - } - e.Consume(func(_ proton.RefreshFlag) error { - return s.onRefreshEvent(ctx, syncHandler) + e.Consume(func(event proton.Event) error { + return eventHandler.OnEvent(ctx, event) }) case e, ok := <-s.eventWatcher.GetChannel(): if !ok { @@ -389,43 +392,6 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo } } -func (s *Service) onRefreshEvent(ctx context.Context, handler *syncHandler) error { - s.log.Debug("handling refresh event") - - if err := s.identityState.Write(func(identity *useridentity.State) error { - return identity.OnRefreshEvent(ctx) - }); err != nil { - s.log.WithError(err).Error("Failed to apply refresh event to identity state") - return err - } - - handler.CancelAndWait() - - if err := s.removeConnectorsFromServer(ctx, s.connectors, true); err != nil { - return err - } - - if err := s.syncStateProvider.ClearSyncStatus(); err != nil { - return fmt.Errorf("failed to clear sync status:%w", err) - } - - if err := s.addConnectorsToServer(ctx, s.connectors); err != nil { - return err - } - - handler.launch(s) - - return nil -} - -func (s *Service) onUserEvent(user proton.User) error { - s.log.Debug("handling user event") - return s.identityState.Write(func(identity *useridentity.State) error { - identity.OnUserEvent(user) - return nil - }) -} - func (s *Service) buildConnectors() (map[string]*Connector, error) { connectors := make(map[string]*Connector) diff --git a/internal/services/imapservice/service_address_events.go b/internal/services/imapservice/service_address_events.go index f919fbbf..5f673243 100644 --- a/internal/services/imapservice/service_address_events.go +++ b/internal/services/imapservice/service_address_events.go @@ -26,7 +26,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" ) -func (s *Service) onAddressEvent(ctx context.Context, events []proton.AddressEvent) error { +func (s *Service) HandleAddressEvents(ctx context.Context, events []proton.AddressEvent) error { s.log.Debug("handling address event") if s.addressMode == usertypes.AddressModeCombined { diff --git a/internal/services/imapservice/service_label_events.go b/internal/services/imapservice/service_label_events.go index b9ed2889..472a7a29 100644 --- a/internal/services/imapservice/service_label_events.go +++ b/internal/services/imapservice/service_label_events.go @@ -32,7 +32,7 @@ import ( "golang.org/x/exp/maps" ) -func (s *Service) onLabelEvent(ctx context.Context, events []proton.LabelEvent) error { +func (s *Service) HandleLabelEvents(ctx context.Context, events []proton.LabelEvent) error { s.log.Debug("handling label event") for _, event := range events { diff --git a/internal/services/imapservice/service_message_events.go b/internal/services/imapservice/service_message_events.go index 9257ffe5..32391462 100644 --- a/internal/services/imapservice/service_message_events.go +++ b/internal/services/imapservice/service_message_events.go @@ -36,7 +36,7 @@ import ( "golang.org/x/exp/maps" ) -func (s *Service) onMessageEvent(ctx context.Context, events []proton.MessageEvent) error { +func (s *Service) HandleMessageEvents(ctx context.Context, events []proton.MessageEvent) error { s.log.Debug("handling message event") for _, event := range events { diff --git a/internal/services/smtp/service.go b/internal/services/smtp/service.go index 10ba8a8a..ede91e57 100644 --- a/internal/services/smtp/service.go +++ b/internal/services/smtp/service.go @@ -57,10 +57,8 @@ type Service struct { identityState *useridentity.State telemetry Telemetry - eventService userevents.Subscribable - refreshSubscriber *userevents.RefreshChanneledSubscriber - addressSubscriber *userevents.AddressChanneledSubscriber - userSubscriber *userevents.UserChanneledSubscriber + eventService userevents.Subscribable + subscription *userevents.EventChanneledSubscriber addressMode usertypes.AddressMode serverManager ServerManager @@ -100,9 +98,7 @@ func NewService( identityState: identityState, eventService: eventService, - refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName), - userSubscriber: userevents.NewUserSubscriber(subscriberName), - addressSubscriber: userevents.NewAddressSubscriber(subscriberName), + subscription: userevents.NewEventSubscriber(subscriberName), addressMode: mode, serverManager: serverManager, @@ -168,19 +164,38 @@ func (s *Service) UserID() string { return s.userID } +func (s *Service) HandleRefreshEvent(ctx context.Context, _ proton.RefreshFlag) error { + s.log.Debug("Handling refresh event") + return s.identityState.OnRefreshEvent(ctx) +} + +func (s *Service) HandleAddressEvents(_ context.Context, events []proton.AddressEvent) error { + s.log.Debug("Handling Address Event") + s.identityState.OnAddressEvents(events) + + return nil +} + +func (s *Service) HandleUserEvent(_ context.Context, user *proton.User) error { + s.log.Debug("Handling user event") + s.identityState.OnUserEvent(*user) + + return nil +} + func (s *Service) run(ctx context.Context) { s.log.Info("Starting service main loop") defer s.log.Info("Exiting service main loop") defer s.cpc.Close() - subscription := userevents.Subscription{ - User: s.userSubscriber, - Refresh: s.refreshSubscriber, - Address: s.addressSubscriber, + eventHandler := userevents.EventHandler{ + AddressHandler: s, + RefreshHandler: s, + UserHandler: s, } - s.eventService.Subscribe(subscription) - defer s.eventService.Unsubscribe(subscription) + s.eventService.Subscribe(s.subscription) + defer s.eventService.Unsubscribe(s.subscription) for { select { @@ -219,34 +234,12 @@ func (s *Service) run(ctx context.Context) { default: s.log.Error("Received unknown request") } - case e, ok := <-s.userSubscriber.OnEventCh(): + case e, ok := <-s.subscription.OnEventCh(): if !ok { continue } - - s.log.Debug("Handling user event") - e.Consume(func(user proton.User) error { - s.identityState.OnUserEvent(user) - return nil - }) - case e, ok := <-s.refreshSubscriber.OnEventCh(): - if !ok { - continue - } - - s.log.Debug("Handling refresh event") - e.Consume(func(_ proton.RefreshFlag) error { - return s.identityState.OnRefreshEvent(ctx) - }) - case e, ok := <-s.addressSubscriber.OnEventCh(): - if !ok { - continue - } - - s.log.Debug("Handling Address Event") - e.Consume(func(evt []proton.AddressEvent) error { - s.identityState.OnAddressEvents(evt) - return nil + e.Consume(func(event proton.Event) error { + return eventHandler.OnEvent(ctx, event) }) } } diff --git a/internal/services/userevents/mocks_test.go b/internal/services/userevents/mocks_test.go index dbf7a736..15852e0b 100644 --- a/internal/services/userevents/mocks_test.go +++ b/internal/services/userevents/mocks_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber) +// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,RefreshEventHandler,UserEventHandler,UserUsedSpaceEventHandler) // Package userevents is a generated GoMock package. package userevents @@ -12,55 +12,55 @@ import ( gomock "github.com/golang/mock/gomock" ) -// MockMessageSubscriber is a mock of MessageSubscriber interface. -type MockMessageSubscriber struct { +// MockEventSubscriber is a mock of EventSubscriber interface. +type MockEventSubscriber struct { ctrl *gomock.Controller - recorder *MockMessageSubscriberMockRecorder + recorder *MockEventSubscriberMockRecorder } -// MockMessageSubscriberMockRecorder is the mock recorder for MockMessageSubscriber. -type MockMessageSubscriberMockRecorder struct { - mock *MockMessageSubscriber +// MockEventSubscriberMockRecorder is the mock recorder for MockEventSubscriber. +type MockEventSubscriberMockRecorder struct { + mock *MockEventSubscriber } -// NewMockMessageSubscriber creates a new mock instance. -func NewMockMessageSubscriber(ctrl *gomock.Controller) *MockMessageSubscriber { - mock := &MockMessageSubscriber{ctrl: ctrl} - mock.recorder = &MockMessageSubscriberMockRecorder{mock} +// NewMockEventSubscriber creates a new mock instance. +func NewMockEventSubscriber(ctrl *gomock.Controller) *MockEventSubscriber { + mock := &MockEventSubscriber{ctrl: ctrl} + mock.recorder = &MockEventSubscriberMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMessageSubscriber) EXPECT() *MockMessageSubscriberMockRecorder { +func (m *MockEventSubscriber) EXPECT() *MockEventSubscriberMockRecorder { return m.recorder } // cancel mocks base method. -func (m *MockMessageSubscriber) cancel() { +func (m *MockEventSubscriber) cancel() { m.ctrl.T.Helper() m.ctrl.Call(m, "cancel") } // cancel indicates an expected call of cancel. -func (mr *MockMessageSubscriberMockRecorder) cancel() *gomock.Call { +func (mr *MockEventSubscriberMockRecorder) cancel() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockMessageSubscriber)(nil).cancel)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockEventSubscriber)(nil).cancel)) } // close mocks base method. -func (m *MockMessageSubscriber) close() { +func (m *MockEventSubscriber) close() { m.ctrl.T.Helper() m.ctrl.Call(m, "close") } // close indicates an expected call of close. -func (mr *MockMessageSubscriberMockRecorder) close() *gomock.Call { +func (mr *MockEventSubscriberMockRecorder) close() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockMessageSubscriber)(nil).close)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockEventSubscriber)(nil).close)) } // handle mocks base method. -func (m *MockMessageSubscriber) handle(arg0 context.Context, arg1 []proton.MessageEvent) error { +func (m *MockEventSubscriber) handle(arg0 context.Context, arg1 proton.Event) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "handle", arg0, arg1) ret0, _ := ret[0].(error) @@ -68,13 +68,13 @@ func (m *MockMessageSubscriber) handle(arg0 context.Context, arg1 []proton.Messa } // handle indicates an expected call of handle. -func (mr *MockMessageSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockEventSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockMessageSubscriber)(nil).handle), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockEventSubscriber)(nil).handle), arg0, arg1) } // name mocks base method. -func (m *MockMessageSubscriber) name() string { +func (m *MockEventSubscriber) name() string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "name") ret0, _ := ret[0].(string) @@ -82,382 +82,229 @@ func (m *MockMessageSubscriber) name() string { } // name indicates an expected call of name. -func (mr *MockMessageSubscriberMockRecorder) name() *gomock.Call { +func (mr *MockEventSubscriberMockRecorder) name() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockMessageSubscriber)(nil).name)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockEventSubscriber)(nil).name)) } -// MockLabelSubscriber is a mock of LabelSubscriber interface. -type MockLabelSubscriber struct { +// MockMessageEventHandler is a mock of MessageEventHandler interface. +type MockMessageEventHandler struct { ctrl *gomock.Controller - recorder *MockLabelSubscriberMockRecorder + recorder *MockMessageEventHandlerMockRecorder } -// MockLabelSubscriberMockRecorder is the mock recorder for MockLabelSubscriber. -type MockLabelSubscriberMockRecorder struct { - mock *MockLabelSubscriber +// MockMessageEventHandlerMockRecorder is the mock recorder for MockMessageEventHandler. +type MockMessageEventHandlerMockRecorder struct { + mock *MockMessageEventHandler } -// NewMockLabelSubscriber creates a new mock instance. -func NewMockLabelSubscriber(ctrl *gomock.Controller) *MockLabelSubscriber { - mock := &MockLabelSubscriber{ctrl: ctrl} - mock.recorder = &MockLabelSubscriberMockRecorder{mock} +// NewMockMessageEventHandler creates a new mock instance. +func NewMockMessageEventHandler(ctrl *gomock.Controller) *MockMessageEventHandler { + mock := &MockMessageEventHandler{ctrl: ctrl} + mock.recorder = &MockMessageEventHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockLabelSubscriber) EXPECT() *MockLabelSubscriberMockRecorder { +func (m *MockMessageEventHandler) EXPECT() *MockMessageEventHandlerMockRecorder { return m.recorder } -// cancel mocks base method. -func (m *MockLabelSubscriber) cancel() { +// HandleMessageEvents mocks base method. +func (m *MockMessageEventHandler) HandleMessageEvents(arg0 context.Context, arg1 []proton.MessageEvent) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "cancel") -} - -// cancel indicates an expected call of cancel. -func (mr *MockLabelSubscriberMockRecorder) cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockLabelSubscriber)(nil).cancel)) -} - -// close mocks base method. -func (m *MockLabelSubscriber) close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "close") -} - -// close indicates an expected call of close. -func (mr *MockLabelSubscriberMockRecorder) close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockLabelSubscriber)(nil).close)) -} - -// handle mocks base method. -func (m *MockLabelSubscriber) handle(arg0 context.Context, arg1 []proton.LabelEvent) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret := m.ctrl.Call(m, "HandleMessageEvents", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// handle indicates an expected call of handle. -func (mr *MockLabelSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { +// HandleMessageEvents indicates an expected call of HandleMessageEvents. +func (mr *MockMessageEventHandlerMockRecorder) HandleMessageEvents(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockLabelSubscriber)(nil).handle), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessageEvents", reflect.TypeOf((*MockMessageEventHandler)(nil).HandleMessageEvents), arg0, arg1) } -// name mocks base method. -func (m *MockLabelSubscriber) name() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "name") - ret0, _ := ret[0].(string) - return ret0 -} - -// name indicates an expected call of name. -func (mr *MockLabelSubscriberMockRecorder) name() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockLabelSubscriber)(nil).name)) -} - -// MockAddressSubscriber is a mock of AddressSubscriber interface. -type MockAddressSubscriber struct { +// MockLabelEventHandler is a mock of LabelEventHandler interface. +type MockLabelEventHandler struct { ctrl *gomock.Controller - recorder *MockAddressSubscriberMockRecorder + recorder *MockLabelEventHandlerMockRecorder } -// MockAddressSubscriberMockRecorder is the mock recorder for MockAddressSubscriber. -type MockAddressSubscriberMockRecorder struct { - mock *MockAddressSubscriber +// MockLabelEventHandlerMockRecorder is the mock recorder for MockLabelEventHandler. +type MockLabelEventHandlerMockRecorder struct { + mock *MockLabelEventHandler } -// NewMockAddressSubscriber creates a new mock instance. -func NewMockAddressSubscriber(ctrl *gomock.Controller) *MockAddressSubscriber { - mock := &MockAddressSubscriber{ctrl: ctrl} - mock.recorder = &MockAddressSubscriberMockRecorder{mock} +// NewMockLabelEventHandler creates a new mock instance. +func NewMockLabelEventHandler(ctrl *gomock.Controller) *MockLabelEventHandler { + mock := &MockLabelEventHandler{ctrl: ctrl} + mock.recorder = &MockLabelEventHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockAddressSubscriber) EXPECT() *MockAddressSubscriberMockRecorder { +func (m *MockLabelEventHandler) EXPECT() *MockLabelEventHandlerMockRecorder { return m.recorder } -// cancel mocks base method. -func (m *MockAddressSubscriber) cancel() { +// HandleLabelEvents mocks base method. +func (m *MockLabelEventHandler) HandleLabelEvents(arg0 context.Context, arg1 []proton.LabelEvent) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "cancel") -} - -// cancel indicates an expected call of cancel. -func (mr *MockAddressSubscriberMockRecorder) cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockAddressSubscriber)(nil).cancel)) -} - -// close mocks base method. -func (m *MockAddressSubscriber) close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "close") -} - -// close indicates an expected call of close. -func (mr *MockAddressSubscriberMockRecorder) close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockAddressSubscriber)(nil).close)) -} - -// handle mocks base method. -func (m *MockAddressSubscriber) handle(arg0 context.Context, arg1 []proton.AddressEvent) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret := m.ctrl.Call(m, "HandleLabelEvents", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// handle indicates an expected call of handle. -func (mr *MockAddressSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { +// HandleLabelEvents indicates an expected call of HandleLabelEvents. +func (mr *MockLabelEventHandlerMockRecorder) HandleLabelEvents(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockAddressSubscriber)(nil).handle), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleLabelEvents", reflect.TypeOf((*MockLabelEventHandler)(nil).HandleLabelEvents), arg0, arg1) } -// name mocks base method. -func (m *MockAddressSubscriber) name() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "name") - ret0, _ := ret[0].(string) - return ret0 -} - -// name indicates an expected call of name. -func (mr *MockAddressSubscriberMockRecorder) name() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockAddressSubscriber)(nil).name)) -} - -// MockRefreshSubscriber is a mock of RefreshSubscriber interface. -type MockRefreshSubscriber struct { +// MockAddressEventHandler is a mock of AddressEventHandler interface. +type MockAddressEventHandler struct { ctrl *gomock.Controller - recorder *MockRefreshSubscriberMockRecorder + recorder *MockAddressEventHandlerMockRecorder } -// MockRefreshSubscriberMockRecorder is the mock recorder for MockRefreshSubscriber. -type MockRefreshSubscriberMockRecorder struct { - mock *MockRefreshSubscriber +// MockAddressEventHandlerMockRecorder is the mock recorder for MockAddressEventHandler. +type MockAddressEventHandlerMockRecorder struct { + mock *MockAddressEventHandler } -// NewMockRefreshSubscriber creates a new mock instance. -func NewMockRefreshSubscriber(ctrl *gomock.Controller) *MockRefreshSubscriber { - mock := &MockRefreshSubscriber{ctrl: ctrl} - mock.recorder = &MockRefreshSubscriberMockRecorder{mock} +// NewMockAddressEventHandler creates a new mock instance. +func NewMockAddressEventHandler(ctrl *gomock.Controller) *MockAddressEventHandler { + mock := &MockAddressEventHandler{ctrl: ctrl} + mock.recorder = &MockAddressEventHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockRefreshSubscriber) EXPECT() *MockRefreshSubscriberMockRecorder { +func (m *MockAddressEventHandler) EXPECT() *MockAddressEventHandlerMockRecorder { return m.recorder } -// cancel mocks base method. -func (m *MockRefreshSubscriber) cancel() { +// HandleAddressEvents mocks base method. +func (m *MockAddressEventHandler) HandleAddressEvents(arg0 context.Context, arg1 []proton.AddressEvent) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "cancel") -} - -// cancel indicates an expected call of cancel. -func (mr *MockRefreshSubscriberMockRecorder) cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockRefreshSubscriber)(nil).cancel)) -} - -// close mocks base method. -func (m *MockRefreshSubscriber) close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "close") -} - -// close indicates an expected call of close. -func (mr *MockRefreshSubscriberMockRecorder) close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockRefreshSubscriber)(nil).close)) -} - -// handle mocks base method. -func (m *MockRefreshSubscriber) handle(arg0 context.Context, arg1 proton.RefreshFlag) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret := m.ctrl.Call(m, "HandleAddressEvents", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// handle indicates an expected call of handle. -func (mr *MockRefreshSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { +// HandleAddressEvents indicates an expected call of HandleAddressEvents. +func (mr *MockAddressEventHandlerMockRecorder) HandleAddressEvents(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockRefreshSubscriber)(nil).handle), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleAddressEvents", reflect.TypeOf((*MockAddressEventHandler)(nil).HandleAddressEvents), arg0, arg1) } -// name mocks base method. -func (m *MockRefreshSubscriber) name() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "name") - ret0, _ := ret[0].(string) - return ret0 -} - -// name indicates an expected call of name. -func (mr *MockRefreshSubscriberMockRecorder) name() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockRefreshSubscriber)(nil).name)) -} - -// MockUserSubscriber is a mock of UserSubscriber interface. -type MockUserSubscriber struct { +// MockRefreshEventHandler is a mock of RefreshEventHandler interface. +type MockRefreshEventHandler struct { ctrl *gomock.Controller - recorder *MockUserSubscriberMockRecorder + recorder *MockRefreshEventHandlerMockRecorder } -// MockUserSubscriberMockRecorder is the mock recorder for MockUserSubscriber. -type MockUserSubscriberMockRecorder struct { - mock *MockUserSubscriber +// MockRefreshEventHandlerMockRecorder is the mock recorder for MockRefreshEventHandler. +type MockRefreshEventHandlerMockRecorder struct { + mock *MockRefreshEventHandler } -// NewMockUserSubscriber creates a new mock instance. -func NewMockUserSubscriber(ctrl *gomock.Controller) *MockUserSubscriber { - mock := &MockUserSubscriber{ctrl: ctrl} - mock.recorder = &MockUserSubscriberMockRecorder{mock} +// NewMockRefreshEventHandler creates a new mock instance. +func NewMockRefreshEventHandler(ctrl *gomock.Controller) *MockRefreshEventHandler { + mock := &MockRefreshEventHandler{ctrl: ctrl} + mock.recorder = &MockRefreshEventHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUserSubscriber) EXPECT() *MockUserSubscriberMockRecorder { +func (m *MockRefreshEventHandler) EXPECT() *MockRefreshEventHandlerMockRecorder { return m.recorder } -// cancel mocks base method. -func (m *MockUserSubscriber) cancel() { +// HandleRefreshEvent mocks base method. +func (m *MockRefreshEventHandler) HandleRefreshEvent(arg0 context.Context, arg1 proton.RefreshFlag) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "cancel") -} - -// cancel indicates an expected call of cancel. -func (mr *MockUserSubscriberMockRecorder) cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserSubscriber)(nil).cancel)) -} - -// close mocks base method. -func (m *MockUserSubscriber) close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "close") -} - -// close indicates an expected call of close. -func (mr *MockUserSubscriberMockRecorder) close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserSubscriber)(nil).close)) -} - -// handle mocks base method. -func (m *MockUserSubscriber) handle(arg0 context.Context, arg1 proton.User) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret := m.ctrl.Call(m, "HandleRefreshEvent", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// handle indicates an expected call of handle. -func (mr *MockUserSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { +// HandleRefreshEvent indicates an expected call of HandleRefreshEvent. +func (mr *MockRefreshEventHandlerMockRecorder) HandleRefreshEvent(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserSubscriber)(nil).handle), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRefreshEvent", reflect.TypeOf((*MockRefreshEventHandler)(nil).HandleRefreshEvent), arg0, arg1) } -// name mocks base method. -func (m *MockUserSubscriber) name() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "name") - ret0, _ := ret[0].(string) - return ret0 -} - -// name indicates an expected call of name. -func (mr *MockUserSubscriberMockRecorder) name() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserSubscriber)(nil).name)) -} - -// MockUserUsedSpaceSubscriber is a mock of UserUsedSpaceSubscriber interface. -type MockUserUsedSpaceSubscriber struct { +// MockUserEventHandler is a mock of UserEventHandler interface. +type MockUserEventHandler struct { ctrl *gomock.Controller - recorder *MockUserUsedSpaceSubscriberMockRecorder + recorder *MockUserEventHandlerMockRecorder } -// MockUserUsedSpaceSubscriberMockRecorder is the mock recorder for MockUserUsedSpaceSubscriber. -type MockUserUsedSpaceSubscriberMockRecorder struct { - mock *MockUserUsedSpaceSubscriber +// MockUserEventHandlerMockRecorder is the mock recorder for MockUserEventHandler. +type MockUserEventHandlerMockRecorder struct { + mock *MockUserEventHandler } -// NewMockUserUsedSpaceSubscriber creates a new mock instance. -func NewMockUserUsedSpaceSubscriber(ctrl *gomock.Controller) *MockUserUsedSpaceSubscriber { - mock := &MockUserUsedSpaceSubscriber{ctrl: ctrl} - mock.recorder = &MockUserUsedSpaceSubscriberMockRecorder{mock} +// NewMockUserEventHandler creates a new mock instance. +func NewMockUserEventHandler(ctrl *gomock.Controller) *MockUserEventHandler { + mock := &MockUserEventHandler{ctrl: ctrl} + mock.recorder = &MockUserEventHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUserUsedSpaceSubscriber) EXPECT() *MockUserUsedSpaceSubscriberMockRecorder { +func (m *MockUserEventHandler) EXPECT() *MockUserEventHandlerMockRecorder { return m.recorder } -// cancel mocks base method. -func (m *MockUserUsedSpaceSubscriber) cancel() { +// HandleUserEvent mocks base method. +func (m *MockUserEventHandler) HandleUserEvent(arg0 context.Context, arg1 *proton.User) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "cancel") -} - -// cancel indicates an expected call of cancel. -func (mr *MockUserUsedSpaceSubscriberMockRecorder) cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).cancel)) -} - -// close mocks base method. -func (m *MockUserUsedSpaceSubscriber) close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "close") -} - -// close indicates an expected call of close. -func (mr *MockUserUsedSpaceSubscriberMockRecorder) close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).close)) -} - -// handle mocks base method. -func (m *MockUserUsedSpaceSubscriber) handle(arg0 context.Context, arg1 int) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret := m.ctrl.Call(m, "HandleUserEvent", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// handle indicates an expected call of handle. -func (mr *MockUserUsedSpaceSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { +// HandleUserEvent indicates an expected call of HandleUserEvent. +func (mr *MockUserEventHandlerMockRecorder) HandleUserEvent(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).handle), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleUserEvent", reflect.TypeOf((*MockUserEventHandler)(nil).HandleUserEvent), arg0, arg1) } -// name mocks base method. -func (m *MockUserUsedSpaceSubscriber) name() string { +// MockUserUsedSpaceEventHandler is a mock of UserUsedSpaceEventHandler interface. +type MockUserUsedSpaceEventHandler struct { + ctrl *gomock.Controller + recorder *MockUserUsedSpaceEventHandlerMockRecorder +} + +// MockUserUsedSpaceEventHandlerMockRecorder is the mock recorder for MockUserUsedSpaceEventHandler. +type MockUserUsedSpaceEventHandlerMockRecorder struct { + mock *MockUserUsedSpaceEventHandler +} + +// NewMockUserUsedSpaceEventHandler creates a new mock instance. +func NewMockUserUsedSpaceEventHandler(ctrl *gomock.Controller) *MockUserUsedSpaceEventHandler { + mock := &MockUserUsedSpaceEventHandler{ctrl: ctrl} + mock.recorder = &MockUserUsedSpaceEventHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUserUsedSpaceEventHandler) EXPECT() *MockUserUsedSpaceEventHandlerMockRecorder { + return m.recorder +} + +// HandleUsedSpaceEvent mocks base method. +func (m *MockUserUsedSpaceEventHandler) HandleUsedSpaceEvent(arg0 context.Context, arg1 int) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "name") - ret0, _ := ret[0].(string) + ret := m.ctrl.Call(m, "HandleUsedSpaceEvent", arg0, arg1) + ret0, _ := ret[0].(error) return ret0 } -// name indicates an expected call of name. -func (mr *MockUserUsedSpaceSubscriberMockRecorder) name() *gomock.Call { +// HandleUsedSpaceEvent indicates an expected call of HandleUsedSpaceEvent. +func (mr *MockUserUsedSpaceEventHandlerMockRecorder) HandleUsedSpaceEvent(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).name)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleUsedSpaceEvent", reflect.TypeOf((*MockUserUsedSpaceEventHandler)(nil).HandleUsedSpaceEvent), arg0, arg1) } diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go index 72a7612f..471a4130 100644 --- a/internal/services/userevents/service.go +++ b/internal/services/userevents/service.go @@ -57,12 +57,7 @@ type Service struct { paused uint32 panicHandler async.PanicHandler - userSubscriberList userSubscriberList - addressSubscribers addressSubscriberList - labelSubscribers labelSubscriberList - messageSubscribers messageSubscriberList - refreshSubscribers refreshSubscriberList - userUsedSpaceSubscriber userUsedSpaceSubscriberList + subscriberList eventSubscriberList pendingSubscriptionsLock sync.Mutex pendingSubscriptions []pendingSubscription @@ -94,61 +89,9 @@ func NewService( } } -type Subscription struct { - User UserSubscriber - Refresh RefreshSubscriber - Address AddressSubscriber - Labels LabelSubscriber - Messages MessageSubscriber - UserUsedSpace UserUsedSpaceSubscriber -} - -// cancel subscription subscribers if applicable, see `subscriber.cancel` for more information. -func (s Subscription) cancel() { - if s.User != nil { - s.User.cancel() - } - if s.Refresh != nil { - s.Refresh.cancel() - } - if s.Address != nil { - s.Address.cancel() - } - if s.Labels != nil { - s.Labels.cancel() - } - if s.Messages != nil { - s.Messages.cancel() - } - if s.UserUsedSpace != nil { - s.UserUsedSpace.cancel() - } -} - -func (s Subscription) close() { - if s.User != nil { - s.User.close() - } - if s.Refresh != nil { - s.Refresh.close() - } - if s.Address != nil { - s.Address.close() - } - if s.Labels != nil { - s.Labels.close() - } - if s.Messages != nil { - s.Messages.close() - } - if s.UserUsedSpace != nil { - s.UserUsedSpace.close() - } -} - // Subscribe adds new subscribers to the service. // This method can safely be called during event handling. -func (s *Service) Subscribe(subscription Subscription) { +func (s *Service) Subscribe(subscription EventSubscriber) { s.pendingSubscriptionsLock.Lock() defer s.pendingSubscriptionsLock.Unlock() @@ -157,7 +100,7 @@ func (s *Service) Subscribe(subscription Subscription) { // Unsubscribe removes subscribers from the service. // This method can safely be called during event handling. -func (s *Service) Unsubscribe(subscription Subscription) { +func (s *Service) Unsubscribe(subscription EventSubscriber) { subscription.cancel() s.pendingSubscriptionsLock.Lock() @@ -288,7 +231,7 @@ func (s *Service) Close() { s.pendingSubscriptionsLock.Lock() defer s.pendingSubscriptionsLock.Unlock() - processed := xmaps.Set[Subscription]{} + processed := xmaps.Set[EventSubscriber]{} // Cleanup pending removes. for _, s := range s.pendingSubscriptions { @@ -313,70 +256,20 @@ func (s *Service) handleEvent(ctx context.Context, lastEventID string, event pro }).Info("Received new API event") if event.Refresh&proton.RefreshMail != 0 { - s.log.Info("Handling refresh event") - if err := s.refreshSubscribers.Publish(ctx, event.Refresh, s.eventTimeout); err != nil { - return fmt.Errorf("failed to apply refresh event: %w", err) - } - - return nil + s.log.Info("Received refresh event") } - // Start with user events. - if event.User != nil { - if err := s.userSubscriberList.PublishParallel(ctx, *event.User, s.panicHandler, s.eventTimeout); err != nil { - return fmt.Errorf("failed to apply user event: %w", err) - } - } - - // Next Address events - if err := s.addressSubscribers.PublishParallel(ctx, event.Addresses, s.panicHandler, s.eventTimeout); err != nil { - return fmt.Errorf("failed to apply address events: %w", err) - } - - // Next label events - if err := s.labelSubscribers.PublishParallel(ctx, event.Labels, s.panicHandler, s.eventTimeout); err != nil { - return fmt.Errorf("failed to apply label events: %w", err) - } - - // Next message events - if err := s.messageSubscribers.PublishParallel(ctx, event.Messages, s.panicHandler, s.eventTimeout); err != nil { - return fmt.Errorf("failed to apply message events: %w", err) - } - - // Finally user used space events - if event.UsedSpace != nil { - if err := s.userUsedSpaceSubscriber.PublishParallel(ctx, *event.UsedSpace, s.panicHandler, s.eventTimeout); err != nil { - return fmt.Errorf("failed to apply message events: %w", err) - } - } - - return nil + return s.subscriberList.PublishParallel(ctx, event, s.panicHandler, s.eventTimeout) } func unpackPublisherError(err error) (string, error) { - var addressErr *addressPublishError - var labelErr *labelPublishError - var messageErr *messagePublishError - var refreshErr *refreshPublishError - var userErr *userPublishError - var usedSpaceErr *userUsedEventPublishError + var publishErr *eventPublishError - switch { - case errors.As(err, &userErr): - return userErr.subscriber.name(), userErr.error - case errors.As(err, &addressErr): - return addressErr.subscriber.name(), addressErr.error - case errors.As(err, &labelErr): - return labelErr.subscriber.name(), labelErr.error - case errors.As(err, &messageErr): - return messageErr.subscriber.name(), messageErr.error - case errors.As(err, &refreshErr): - return refreshErr.subscriber.name(), refreshErr.error - case errors.As(err, &usedSpaceErr): - return usedSpaceErr.subscriber.name(), usedSpaceErr.error - default: - return "", err + if errors.As(err, &publishErr) { + return publishErr.subscriber.name(), publishErr.error } + + return "", err } func (s *Service) handleEventError(ctx context.Context, lastEventID string, event proton.Event, err error) (string, error) { @@ -437,56 +330,12 @@ func (s *Service) onBadEvent(ctx context.Context, event events.UserBadEvent) { s.eventPublisher.PublishEvent(ctx, event) } -func (s *Service) addSubscription(subscription Subscription) { - if subscription.User != nil { - s.userSubscriberList.Add(subscription.User) - } - - if subscription.Refresh != nil { - s.refreshSubscribers.Add(subscription.Refresh) - } - - if subscription.Address != nil { - s.addressSubscribers.Add(subscription.Address) - } - - if subscription.Labels != nil { - s.labelSubscribers.Add(subscription.Labels) - } - - if subscription.Messages != nil { - s.messageSubscribers.Add(subscription.Messages) - } - - if subscription.UserUsedSpace != nil { - s.userUsedSpaceSubscriber.Add(subscription.UserUsedSpace) - } +func (s *Service) addSubscription(subscription EventSubscriber) { + s.subscriberList.Add(subscription) } -func (s *Service) removeSubscription(subscription Subscription) { - if subscription.User != nil { - s.userSubscriberList.Remove(subscription.User) - } - - if subscription.Refresh != nil { - s.refreshSubscribers.Remove(subscription.Refresh) - } - - if subscription.Address != nil { - s.addressSubscribers.Remove(subscription.Address) - } - - if subscription.Labels != nil { - s.labelSubscribers.Remove(subscription.Labels) - } - - if subscription.Messages != nil { - s.messageSubscribers.Remove(subscription.Messages) - } - - if subscription.UserUsedSpace != nil { - s.userUsedSpaceSubscriber.Remove(subscription.UserUsedSpace) - } +func (s *Service) removeSubscription(subscription EventSubscriber) { + s.subscriberList.Remove(subscription) } type pendingOp int @@ -498,5 +347,5 @@ const ( type pendingSubscription struct { op pendingOp - sub Subscription + sub EventSubscriber } diff --git a/internal/services/userevents/service_handle_event_error_test.go b/internal/services/userevents/service_handle_event_error_test.go index 563cefdf..0f3ddefa 100644 --- a/internal/services/userevents/service_handle_event_error_test.go +++ b/internal/services/userevents/service_handle_event_error_test.go @@ -52,9 +52,9 @@ func TestServiceHandleEventError_SubscriberEventUnwrapping(t *testing.T) { lastEventID := "PrevEvent" event := proton.Event{EventID: "MyEvent"} - subscriber := &noOpSubscriber[proton.AddressEvent]{} + subscriber := &noOpSubscriber[proton.Event]{} - err := &addressPublishError{ + err := &eventPublishError{ subscriber: subscriber, error: &proton.NetError{}, } @@ -192,7 +192,7 @@ func (n noOpSubscriber[T]) name() string { //nolint:unused return "NoopSubscriber" } -func (n noOpSubscriber[T]) handle(_ context.Context, _ []T) error { //nolint:unused +func (n noOpSubscriber[T]) handle(_ context.Context, _ T) error { //nolint:unused return nil } diff --git a/internal/services/userevents/service_handle_event_test.go b/internal/services/userevents/service_handle_event_test.go index be966817..4e023d26 100644 --- a/internal/services/userevents/service_handle_event_test.go +++ b/internal/services/userevents/service_handle_event_test.go @@ -37,26 +37,26 @@ func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) { eventPublisher := mocks.NewMockEventPublisher(mockCtrl) eventIDStore := NewInMemoryEventIDStore() - refreshHandler := NewMockRefreshSubscriber(mockCtrl) - refreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(2).Return(nil) + refreshHandler := NewMockRefreshEventHandler(mockCtrl) + refreshHandler.EXPECT().HandleRefreshEvent(gomock.Any(), gomock.Any()).Times(2).Return(nil) - userHandler := NewMockUserSubscriber(mockCtrl) - userCall := userHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(nil) + userHandler := NewMockUserEventHandler(mockCtrl) + userCall := userHandler.EXPECT().HandleUserEvent(gomock.Any(), gomock.Any()).Times(1).Return(nil) - addressHandler := NewMockAddressSubscriber(mockCtrl) - addressCall := addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userCall).Times(1).Return(nil) + addressHandler := NewMockAddressEventHandler(mockCtrl) + addressCall := addressHandler.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).After(userCall).Times(1).Return(nil) - labelHandler := NewMockLabelSubscriber(mockCtrl) - labelCall := labelHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(addressCall).Times(1).Return(nil) + labelHandler := NewMockLabelEventHandler(mockCtrl) + labelCall := labelHandler.EXPECT().HandleLabelEvents(gomock.Any(), gomock.Any()).After(addressCall).Times(1).Return(nil) - messageHandler := NewMockMessageSubscriber(mockCtrl) - messageCall := messageHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(labelCall).Times(1).Return(nil) + messageHandler := NewMockMessageEventHandler(mockCtrl) + messageCall := messageHandler.EXPECT().HandleMessageEvents(gomock.Any(), gomock.Any()).After(labelCall).Times(1).Return(nil) - userSpaceHandler := NewMockUserUsedSpaceSubscriber(mockCtrl) - userSpaceCall := userSpaceHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(messageCall).Times(1).Return(nil) + userSpaceHandler := NewMockUserUsedSpaceEventHandler(mockCtrl) + userSpaceCall := userSpaceHandler.EXPECT().HandleUsedSpaceEvent(gomock.Any(), gomock.Any()).After(messageCall).Times(1).Return(nil) - secondRefreshHandler := NewMockRefreshSubscriber(mockCtrl) - secondRefreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userSpaceCall).Times(1).Return(nil) + secondRefreshHandler := NewMockRefreshEventHandler(mockCtrl) + secondRefreshHandler.EXPECT().HandleRefreshEvent(gomock.Any(), gomock.Any()).After(userSpaceCall).Times(1).Return(nil) service := NewService( "foo", @@ -65,27 +65,31 @@ func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) { eventPublisher, 100*time.Millisecond, time.Millisecond, - time.Second, + 10*time.Second, async.NoopPanicHandler{}, ) - service.addSubscription(Subscription{ - User: userHandler, - Refresh: refreshHandler, - Address: addressHandler, - Labels: labelHandler, - Messages: messageHandler, - UserUsedSpace: userSpaceHandler, + subscription := NewCallbackSubscriber("test", EventHandler{ + UserHandler: userHandler, + RefreshHandler: refreshHandler, + AddressHandler: addressHandler, + LabelHandler: labelHandler, + MessageHandler: messageHandler, + UsedSpaceHandler: userSpaceHandler, }) + service.addSubscription(subscription) + // Simulate 1st refresh. require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail})) // Simulate Regular event. usedSpace := 20 require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{ - User: new(proton.User), - Addresses: []proton.AddressEvent{}, + User: new(proton.User), + Addresses: []proton.AddressEvent{ + {}, + }, Labels: []proton.LabelEvent{ {}, }, @@ -95,9 +99,9 @@ func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) { UsedSpace: &usedSpace, })) - service.addSubscription(Subscription{ - Refresh: secondRefreshHandler, - }) + service.addSubscription(NewCallbackSubscriber("test", EventHandler{ + RefreshHandler: secondRefreshHandler, + })) // Simulate 2nd refresh. require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail})) @@ -109,11 +113,10 @@ func TestServiceHandleEvent_CheckEventFailureCausesError(t *testing.T) { eventPublisher := mocks.NewMockEventPublisher(mockCtrl) eventIDStore := NewInMemoryEventIDStore() - addressHandler := NewMockAddressSubscriber(mockCtrl) - addressHandler.EXPECT().name().MinTimes(1).Return("Hello") - addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed")) + addressHandler := NewMockAddressEventHandler(mockCtrl) + addressHandler.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed")) - messageHandler := NewMockMessageSubscriber(mockCtrl) + messageHandler := NewMockMessageEventHandler(mockCtrl) service := NewService( "foo", @@ -126,16 +129,18 @@ func TestServiceHandleEvent_CheckEventFailureCausesError(t *testing.T) { async.NoopPanicHandler{}, ) - service.addSubscription(Subscription{ - Address: addressHandler, - Messages: messageHandler, + subscription := NewCallbackSubscriber("test", EventHandler{ + AddressHandler: addressHandler, + MessageHandler: messageHandler, }) + service.addSubscription(subscription) + err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}}) require.Error(t, err) - publisherErr := new(addressPublishError) + publisherErr := new(eventPublishError) require.True(t, errors.As(err, &publisherErr)) - require.Equal(t, publisherErr.subscriber, addressHandler) + require.Equal(t, publisherErr.subscriber, subscription) } func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) { @@ -144,12 +149,11 @@ func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) { eventPublisher := mocks.NewMockEventPublisher(mockCtrl) eventIDStore := NewInMemoryEventIDStore() - addressHandler := NewMockAddressSubscriber(mockCtrl) - addressHandler.EXPECT().name().MinTimes(1).Return("Hello") - addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed")) + addressHandler := NewMockAddressEventHandler(mockCtrl) + addressHandler.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed")) - addressHandler2 := NewMockAddressSubscriber(mockCtrl) - addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil) + addressHandler2 := NewMockAddressEventHandler(mockCtrl) + addressHandler2.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil) service := NewService( "foo", @@ -162,19 +166,21 @@ func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) { async.NoopPanicHandler{}, ) - service.addSubscription(Subscription{ - Address: addressHandler, + subscription := NewCallbackSubscriber("test", EventHandler{ + AddressHandler: addressHandler, }) - service.addSubscription(Subscription{ - Address: addressHandler2, - }) + service.addSubscription(subscription) + + service.addSubscription(NewCallbackSubscriber("test2", EventHandler{ + AddressHandler: addressHandler2, + })) err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}}) require.Error(t, err) - publisherErr := new(addressPublishError) + publisherErr := new(eventPublishError) require.True(t, errors.As(err, &publisherErr)) - require.Equal(t, publisherErr.subscriber, addressHandler) + require.Equal(t, publisherErr.subscriber, subscription) } func TestServiceHandleEvent_SubscriberTimeout(t *testing.T) { @@ -183,13 +189,11 @@ func TestServiceHandleEvent_SubscriberTimeout(t *testing.T) { eventPublisher := mocks.NewMockEventPublisher(mockCtrl) eventIDStore := NewInMemoryEventIDStore() - addressHandler := NewMockAddressSubscriber(mockCtrl) - addressHandler.EXPECT().name().AnyTimes().Return("Ok") - addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil) + addressHandler := NewMockAddressEventHandler(mockCtrl) + addressHandler.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil) - addressHandler2 := NewMockAddressSubscriber(mockCtrl) - addressHandler2.EXPECT().name().AnyTimes().Return("Timeout") - addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ []proton.AddressEvent) error { + addressHandler2 := NewMockAddressEventHandler(mockCtrl) + addressHandler2.EXPECT().HandleAddressEvents(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ []proton.AddressEvent) error { timer := time.NewTimer(time.Second) select { @@ -211,19 +215,21 @@ func TestServiceHandleEvent_SubscriberTimeout(t *testing.T) { async.NoopPanicHandler{}, ) - service.addSubscription(Subscription{ - Address: addressHandler, + subscription := NewCallbackSubscriber("test", EventHandler{ + AddressHandler: addressHandler2, }) - service.addSubscription(Subscription{ - Address: addressHandler2, - }) + service.addSubscription(subscription) + + service.addSubscription(NewCallbackSubscriber("test2", EventHandler{ + AddressHandler: addressHandler, + })) // Simulate 1st refresh. err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}}) require.Error(t, err) - if publisherErr := new(addressPublishError); errors.As(err, &publisherErr) { - require.Equal(t, publisherErr.subscriber, addressHandler) + if publisherErr := new(eventPublishError); errors.As(err, &publisherErr) { + require.Equal(t, publisherErr.subscriber, subscription) require.True(t, errors.Is(publisherErr.error, ErrPublishTimeoutExceeded)) } else { require.True(t, errors.Is(err, ErrPublishTimeoutExceeded)) diff --git a/internal/services/userevents/service_test.go b/internal/services/userevents/service_test.go index f9702714..802e7281 100644 --- a/internal/services/userevents/service_test.go +++ b/internal/services/userevents/service_test.go @@ -87,7 +87,7 @@ func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) { eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) eventIDStore := mocks.NewMockEventIDStore(mockCtrl) eventSource := mocks.NewMockEventSource(mockCtrl) - subscriber := NewMockMessageSubscriber(mockCtrl) + subscriber := NewMockMessageEventHandler(mockCtrl) firstEventID := "EVENT01" secondEventID := "EVENT02" @@ -113,10 +113,9 @@ func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) { eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil) // Subscriber expectations. - subscriber.EXPECT().name().AnyTimes().Return("Foo") { - firstCall := subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(io.ErrUnexpectedEOF) - subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).After(firstCall).Times(1).Return(nil) + firstCall := subscriber.EXPECT().HandleMessageEvents(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(io.ErrUnexpectedEOF) + subscriber.EXPECT().HandleMessageEvents(gomock.Any(), gomock.Eq(messageEvents)).After(firstCall).Times(1).Return(nil) } service := NewService( @@ -129,7 +128,7 @@ func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) { time.Second, async.NoopPanicHandler{}, ) - service.Subscribe(Subscription{Messages: subscriber}) + service.Subscribe(NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber})) require.NoError(t, service.Start(context.Background(), group)) service.Resume() @@ -142,7 +141,7 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) { eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) eventIDStore := mocks.NewMockEventIDStore(mockCtrl) eventSource := mocks.NewMockEventSource(mockCtrl) - subscriber := NewMockMessageSubscriber(mockCtrl) + subscriber := NewMockMessageEventHandler(mockCtrl) firstEventID := "EVENT01" secondEventID := "EVENT02" @@ -164,8 +163,7 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) { // Subscriber expectations. badEventErr := fmt.Errorf("I will cause bad event") - subscriber.EXPECT().name().AnyTimes().Return("Foo") - subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(badEventErr) + subscriber.EXPECT().HandleMessageEvents(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(badEventErr) service := NewService( "foo", @@ -184,7 +182,7 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) { OldEventID: firstEventID, NewEventID: secondEventID, EventInfo: secondEvent[0].String(), - Error: badEventErr, + Error: fmt.Errorf("failed to apply message events: %w", badEventErr), }).Do(func(_ context.Context, event events.Event) { group.Go(context.Background(), "", "", func(_ context.Context) { // Use background context to avoid having the request cancelled @@ -193,7 +191,7 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) { }) }) - service.Subscribe(Subscription{Messages: subscriber}) + service.Subscribe(NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber})) require.NoError(t, service.Start(context.Background(), group)) service.Resume() group.Wait() @@ -205,7 +203,7 @@ func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) eventIDStore := mocks.NewMockEventIDStore(mockCtrl) eventSource := mocks.NewMockEventSource(mockCtrl) - subscriber := NewMockMessageSubscriber(mockCtrl) + subscriber := NewMockMessageEventHandler(mockCtrl) firstEventID := "EVENT01" secondEventID := "EVENT02" @@ -241,16 +239,15 @@ func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T async.NoopPanicHandler{}, ) + subscription := NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber}) + // Subscriber expectations. - subscriber.EXPECT().name().AnyTimes().Return("Foo") - subscriber.EXPECT().cancel().Times(1) - subscriber.EXPECT().close().Times(1) - subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).DoAndReturn(func(_ context.Context, _ []proton.MessageEvent) error { - service.Unsubscribe(Subscription{Messages: subscriber}) + subscriber.EXPECT().HandleMessageEvents(gomock.Any(), gomock.Eq(messageEvents)).Times(1).DoAndReturn(func(_ context.Context, _ []proton.MessageEvent) error { + service.Unsubscribe(subscription) return nil }) - service.Subscribe(Subscription{Messages: subscriber}) + service.Subscribe(subscription) require.NoError(t, service.Start(context.Background(), group)) service.Resume() group.Wait() @@ -262,7 +259,6 @@ func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) eventIDStore := mocks.NewMockEventIDStore(mockCtrl) eventSource := mocks.NewMockEventSource(mockCtrl) - subscriber := NewMessageSubscriber("My subscriber") firstEventID := "EVENT01" secondEventID := "EVENT02" @@ -299,16 +295,42 @@ func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T async.NoopPanicHandler{}, ) + subscription := NewEventSubscriber("Foo") + // start subscriber group.Go(context.Background(), "", "", func(_ context.Context) { - defer service.Unsubscribe(Subscription{Messages: subscriber}) + defer service.Unsubscribe(subscription) // Simulate the reception of an event, but it is never handled due to unexpected exit <-time.NewTicker(500 * time.Millisecond).C }) - service.Subscribe(Subscription{Messages: subscriber}) + service.Subscribe(subscription) require.NoError(t, service.Start(context.Background(), group)) service.Resume() group.Wait() } + +type CallbackSubscriber struct { + handler EventHandler + n string +} + +func NewCallbackSubscriber(name string, handler EventHandler) *CallbackSubscriber { + return &CallbackSubscriber{handler: handler, n: name} +} + +func (c CallbackSubscriber) name() string { //nolint: unused + return c.n +} +func (c CallbackSubscriber) handle(ctx context.Context, t proton.Event) error { //nolint: unused + return c.handler.OnEvent(ctx, t) +} + +func (c CallbackSubscriber) cancel() { //nolint: unused + // Nothing to do. +} + +func (c CallbackSubscriber) close() { //nolint: unused + // Nothing to do. +} diff --git a/internal/services/userevents/subscribable.go b/internal/services/userevents/subscribable.go index 4a1526c6..aecf8265 100644 --- a/internal/services/userevents/subscribable.go +++ b/internal/services/userevents/subscribable.go @@ -19,16 +19,16 @@ package userevents // Subscribable represents a type that allows the registration of event subscribers. type Subscribable interface { - Subscribe(subscription Subscription) - Unsubscribe(subscription Subscription) + Subscribe(subscription EventSubscriber) + Unsubscribe(subscription EventSubscriber) } type NoOpSubscribable struct{} -func (n NoOpSubscribable) Subscribe(_ Subscription) { +func (n NoOpSubscribable) Subscribe(_ EventSubscriber) { // Does nothing } -func (n NoOpSubscribable) Unsubscribe(_ Subscription) { +func (n NoOpSubscribable) Unsubscribe(_ EventSubscriber) { // Does nothing } diff --git a/internal/services/userevents/subscriber.go b/internal/services/userevents/subscriber.go index 2dc3aea5..535435ac 100644 --- a/internal/services/userevents/subscriber.go +++ b/internal/services/userevents/subscriber.go @@ -31,43 +31,13 @@ import ( "golang.org/x/exp/slices" ) -type AddressChanneledSubscriber = ChanneledSubscriber[[]proton.AddressEvent] -type LabelChanneledSubscriber = ChanneledSubscriber[[]proton.LabelEvent] -type MessageChanneledSubscriber = ChanneledSubscriber[[]proton.MessageEvent] -type UserChanneledSubscriber = ChanneledSubscriber[proton.User] -type RefreshChanneledSubscriber = ChanneledSubscriber[proton.RefreshFlag] -type UserUsedSpaceChanneledSubscriber = ChanneledSubscriber[int] +type EventChanneledSubscriber = ChanneledSubscriber[proton.Event] -func NewMessageSubscriber(name string) *MessageChanneledSubscriber { - return newChanneledSubscriber[[]proton.MessageEvent](name) +func newSubscriber(name string) *EventChanneledSubscriber { + return newChanneledSubscriber[proton.Event](name) } -func NewAddressSubscriber(name string) *AddressChanneledSubscriber { - return newChanneledSubscriber[[]proton.AddressEvent](name) -} - -func NewLabelSubscriber(name string) *LabelChanneledSubscriber { - return newChanneledSubscriber[[]proton.LabelEvent](name) -} - -func NewRefreshSubscriber(name string) *RefreshChanneledSubscriber { - return newChanneledSubscriber[proton.RefreshFlag](name) -} - -func NewUserSubscriber(name string) *UserChanneledSubscriber { - return newChanneledSubscriber[proton.User](name) -} - -func NewUserUsedSpaceSubscriber(name string) *UserUsedSpaceChanneledSubscriber { - return newChanneledSubscriber[int](name) -} - -type AddressSubscriber = subscriber[[]proton.AddressEvent] -type LabelSubscriber = subscriber[[]proton.LabelEvent] -type MessageSubscriber = subscriber[[]proton.MessageEvent] -type RefreshSubscriber = subscriber[proton.RefreshFlag] -type UserSubscriber = subscriber[proton.User] -type UserUsedSpaceSubscriber = subscriber[int] +type EventSubscriber = subscriber[proton.Event] // Subscriber is the main entry point of interacting with user generated events. type subscriber[T any] interface { @@ -87,12 +57,7 @@ type subscriberList[T any] struct { subscribers []subscriber[T] } -type addressSubscriberList = subscriberList[[]proton.AddressEvent] -type labelSubscriberList = subscriberList[[]proton.LabelEvent] -type messageSubscriberList = subscriberList[[]proton.MessageEvent] -type refreshSubscriberList = subscriberList[proton.RefreshFlag] -type userSubscriberList = subscriberList[proton.User] -type userUsedSpaceSubscriberList = subscriberList[int] +type eventSubscriberList = subscriberList[proton.Event] func (s *subscriberList[T]) Add(subscriber subscriber[T]) { if !slices.Contains(s.subscribers, subscriber) { @@ -117,12 +82,7 @@ type publishError[T any] struct { var ErrPublishTimeoutExceeded = errors.New("event publish timed out") -type addressPublishError = publishError[[]proton.AddressEvent] -type labelPublishError = publishError[[]proton.LabelEvent] -type messagePublishError = publishError[[]proton.MessageEvent] -type refreshPublishError = publishError[proton.RefreshFlag] -type userPublishError = publishError[proton.User] -type userUsedEventPublishError = publishError[int] +type eventPublishError = publishError[proton.Event] func (p publishError[T]) Error() string { return fmt.Sprintf("Event publish failed on (%v): %v", p.subscriber.name(), p.error.Error()) diff --git a/internal/services/userevents/subscription.go b/internal/services/userevents/subscription.go new file mode 100644 index 00000000..b5f4ef85 --- /dev/null +++ b/internal/services/userevents/subscription.go @@ -0,0 +1,106 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + "fmt" + + "github.com/ProtonMail/go-proton-api" +) + +type Subscription = EventSubscriber + +func NewEventSubscriber(name string) *EventChanneledSubscriber { + return newSubscriber(name) +} + +type EventHandler struct { + RefreshHandler RefreshEventHandler + AddressHandler AddressEventHandler + UserHandler UserEventHandler + LabelHandler LabelEventHandler + MessageHandler MessageEventHandler + UsedSpaceHandler UserUsedSpaceEventHandler +} + +func (e EventHandler) OnEvent(ctx context.Context, event proton.Event) error { + if event.Refresh&proton.RefreshMail != 0 && e.RefreshHandler != nil { + return e.RefreshHandler.HandleRefreshEvent(ctx, event.Refresh) + } + + // Start with user events. + if event.User != nil && e.UserHandler != nil { + if err := e.UserHandler.HandleUserEvent(ctx, event.User); err != nil { + return fmt.Errorf("failed to apply user event: %w", err) + } + } + + // Next Address events + if len(event.Addresses) != 0 && e.AddressHandler != nil { + if err := e.AddressHandler.HandleAddressEvents(ctx, event.Addresses); err != nil { + return fmt.Errorf("failed to apply address events: %w", err) + } + } + + // Next label events + if len(event.Labels) != 0 && e.LabelHandler != nil { + if err := e.LabelHandler.HandleLabelEvents(ctx, event.Labels); err != nil { + return fmt.Errorf("failed to apply label events: %w", err) + } + } + + // Next message events + if len(event.Messages) != 0 && e.MessageHandler != nil { + if err := e.MessageHandler.HandleMessageEvents(ctx, event.Messages); err != nil { + return fmt.Errorf("failed to apply message events: %w", err) + } + } + + // Finally user used space events + if event.UsedSpace != nil && e.UsedSpaceHandler != nil { + if err := e.UsedSpaceHandler.HandleUsedSpaceEvent(ctx, *event.UsedSpace); err != nil { + return fmt.Errorf("failed to apply message events: %w", err) + } + } + + return nil +} + +type RefreshEventHandler interface { + HandleRefreshEvent(ctx context.Context, flag proton.RefreshFlag) error +} +type UserEventHandler interface { + HandleUserEvent(ctx context.Context, user *proton.User) error +} + +type UserUsedSpaceEventHandler interface { + HandleUsedSpaceEvent(ctx context.Context, newSpace int) error +} + +type AddressEventHandler interface { + HandleAddressEvents(ctx context.Context, events []proton.AddressEvent) error +} + +type LabelEventHandler interface { + HandleLabelEvents(ctx context.Context, events []proton.LabelEvent) error +} + +type MessageEventHandler interface { + HandleMessageEvents(ctx context.Context, events []proton.MessageEvent) error +} diff --git a/internal/services/useridentity/service.go b/internal/services/useridentity/service.go index e2a042e2..a57c0a9a 100644 --- a/internal/services/useridentity/service.go +++ b/internal/services/useridentity/service.go @@ -47,10 +47,7 @@ type Service struct { log *logrus.Entry identity State - userSubscriber *userevents.UserChanneledSubscriber - addressSubscriber *userevents.AddressChanneledSubscriber - usedSpaceSubscriber *userevents.UserUsedSpaceChanneledSubscriber - refreshSubscriber *userevents.RefreshChanneledSubscriber + subscription *userevents.EventChanneledSubscriber bridgePassProvider BridgePassProvider telemetry Telemetry @@ -74,12 +71,9 @@ func NewService( "service": "user-identity", "user": state.User.ID, }), - userSubscriber: userevents.NewUserSubscriber(subscriberName), - refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName), - addressSubscriber: userevents.NewAddressSubscriber(subscriberName), - usedSpaceSubscriber: userevents.NewUserUsedSpaceSubscriber(subscriberName), - bridgePassProvider: bridgePassProvider, - telemetry: telemetry, + subscription: userevents.NewEventSubscriber(subscriberName), + bridgePassProvider: bridgePassProvider, + telemetry: telemetry, } } @@ -108,134 +102,30 @@ func (s *Service) CheckAuth(ctx context.Context, email string, password []byte) }) } -func (s *Service) run(ctx context.Context) { - s.log.WithFields(logrus.Fields{ - "numAddr": len(s.identity.Addresses), - }).Info("Starting user identity service") - defer s.log.Info("Exiting Service") +func (s *Service) HandleUsedSpaceEvent(ctx context.Context, newSpace int) error { + s.log.Info("Handling User Space Changed event") - s.registerSubscription() - defer s.unregisterSubscription() - - defer s.cpc.Close() - - for { - select { - case <-ctx.Done(): - return - case r, ok := <-s.cpc.ReceiveCh(): - if !ok { - continue - } - switch req := r.Value().(type) { - case *resyncReq: - err := s.identity.OnRefreshEvent(ctx) - r.Reply(ctx, nil, err) - - case *getUserReq: - r.Reply(ctx, s.identity.User, nil) - - case *getAddressesReq: - r.Reply(ctx, maps.Clone(s.identity.Addresses), nil) - - case *checkAuthReq: - id, err := s.identity.CheckAuth(req.email, req.password, s.bridgePassProvider, s.telemetry) - r.Reply(ctx, id, err) - - default: - s.log.Error("Invalid request") - } - - case evt, ok := <-s.userSubscriber.OnEventCh(): - if !ok { - continue - } - evt.Consume(func(user proton.User) error { - s.onUserEvent(ctx, user) - return nil - }) - case evt, ok := <-s.refreshSubscriber.OnEventCh(): - if !ok { - continue - } - evt.Consume(func(_ proton.RefreshFlag) error { - return s.onRefreshEvent(ctx) - }) - case evt, ok := <-s.usedSpaceSubscriber.OnEventCh(): - if !ok { - continue - } - evt.Consume(func(usedSpace int) error { - s.onUserSpaceChanged(ctx, usedSpace) - - return nil - }) - case evt, ok := <-s.addressSubscriber.OnEventCh(): - if !ok { - continue - } - evt.Consume(func(events []proton.AddressEvent) error { - return s.onAddressEvent(ctx, events) - }) - } + if s.identity.OnUserSpaceChanged(newSpace) { + s.eventPublisher.PublishEvent(ctx, events.UsedSpaceChanged{ + UserID: s.identity.User.ID, + UsedSpace: newSpace, + }) } + + return nil } -func (s *Service) registerSubscription() { - s.eventService.Subscribe(userevents.Subscription{ - Refresh: s.refreshSubscriber, - User: s.userSubscriber, - Address: s.addressSubscriber, - UserUsedSpace: s.usedSpaceSubscriber, - }) -} - -func (s *Service) unregisterSubscription() { - s.eventService.Unsubscribe(userevents.Subscription{ - Refresh: s.refreshSubscriber, - User: s.userSubscriber, - Address: s.addressSubscriber, - UserUsedSpace: s.usedSpaceSubscriber, - }) -} - -func (s *Service) onUserEvent(ctx context.Context, user proton.User) { +func (s *Service) HandleUserEvent(ctx context.Context, user *proton.User) error { s.log.WithField("username", logging.Sensitive(user.Name)).Info("Handling user event") - s.identity.OnUserEvent(user) + s.identity.OnUserEvent(*user) s.eventPublisher.PublishEvent(ctx, events.UserChanged{ UserID: user.ID, }) -} - -func (s *Service) onRefreshEvent(ctx context.Context) error { - s.log.Info("Handling refresh event") - - if err := s.identity.OnRefreshEvent(ctx); err != nil { - s.log.WithError(err).Error("Failed to handle refresh event") - return err - } - - s.eventPublisher.PublishEvent(ctx, events.UserRefreshed{ - UserID: s.identity.User.ID, - CancelEventPool: false, - }) return nil } -func (s *Service) onUserSpaceChanged(ctx context.Context, value int) { - s.log.Info("Handling User Space Changed event") - if !s.identity.OnUserSpaceChanged(value) { - return - } - - s.eventPublisher.PublishEvent(ctx, events.UsedSpaceChanged{ - UserID: s.identity.User.ID, - UsedSpace: value, - }) -} - -func (s *Service) onAddressEvent(ctx context.Context, addressEvents []proton.AddressEvent) error { +func (s *Service) HandleAddressEvents(ctx context.Context, addressEvents []proton.AddressEvent) error { s.log.Infof("Handling Address Events (%v)", len(addressEvents)) for idx, event := range addressEvents { @@ -305,6 +195,86 @@ func (s *Service) onAddressEvent(ctx context.Context, addressEvents []proton.Add return nil } +func (s *Service) HandleRefreshEvent(ctx context.Context, _ proton.RefreshFlag) error { + s.log.Info("Handling refresh event") + + if err := s.identity.OnRefreshEvent(ctx); err != nil { + s.log.WithError(err).Error("Failed to handle refresh event") + return err + } + + s.eventPublisher.PublishEvent(ctx, events.UserRefreshed{ + UserID: s.identity.User.ID, + CancelEventPool: false, + }) + + return nil +} + +func (s *Service) run(ctx context.Context) { + s.log.WithFields(logrus.Fields{ + "numAddr": len(s.identity.Addresses), + }).Info("Starting user identity service") + defer s.log.Info("Exiting Service") + + eventHandler := userevents.EventHandler{ + UserHandler: s, + AddressHandler: s, + UsedSpaceHandler: s, + RefreshHandler: s, + } + + s.registerSubscription() + defer s.unregisterSubscription() + + defer s.cpc.Close() + + for { + select { + case <-ctx.Done(): + return + case r, ok := <-s.cpc.ReceiveCh(): + if !ok { + continue + } + switch req := r.Value().(type) { + case *resyncReq: + err := s.identity.OnRefreshEvent(ctx) + r.Reply(ctx, nil, err) + + case *getUserReq: + r.Reply(ctx, s.identity.User, nil) + + case *getAddressesReq: + r.Reply(ctx, maps.Clone(s.identity.Addresses), nil) + + case *checkAuthReq: + id, err := s.identity.CheckAuth(req.email, req.password, s.bridgePassProvider, s.telemetry) + r.Reply(ctx, id, err) + + default: + s.log.Error("Invalid request") + } + + case evt, ok := <-s.subscription.OnEventCh(): + if !ok { + continue + } + evt.Consume(func(event proton.Event) error { + return eventHandler.OnEvent(ctx, event) + }) + } + } +} + +func (s *Service) registerSubscription() { + s.eventService.Subscribe(s.subscription) +} + +func (s *Service) unregisterSubscription() { + s.eventService.Unsubscribe(s.subscription) +} + func sortAddresses(addr []proton.Address) []proton.Address { slices.SortFunc(addr, func(a, b proton.Address) bool { return a.Order < b.Order diff --git a/internal/services/useridentity/service_test.go b/internal/services/useridentity/service_test.go index 7c627b1b..8e9a5a92 100644 --- a/internal/services/useridentity/service_test.go +++ b/internal/services/useridentity/service_test.go @@ -39,7 +39,7 @@ func TestService_OnUserEvent(t *testing.T) { eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserChanged{UserID: TestUserID})).Times(1) - service.onUserEvent(context.Background(), newTestUser()) + require.NoError(t, service.HandleUserEvent(context.Background(), newTestUser())) } func TestService_OnUserSpaceChanged(t *testing.T) { @@ -50,10 +50,10 @@ func TestService_OnUserSpaceChanged(t *testing.T) { eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UsedSpaceChanged{UserID: TestUserID, UsedSpace: 1024})).Times(1) // Original value, no changes. - service.onUserSpaceChanged(context.Background(), 0) + require.NoError(t, service.HandleUsedSpaceEvent(context.Background(), 0)) // New value, event should be published. - service.onUserSpaceChanged(context.Background(), 1024) + require.NoError(t, service.HandleUsedSpaceEvent(context.Background(), 1024)) require.Equal(t, 1024, service.identity.User.UsedSpace) } @@ -68,14 +68,14 @@ func TestService_OnRefreshEvent(t *testing.T) { newAddresses := newTestAddressesRefreshed() { - getUserCall := provider.EXPECT().GetUser(gomock.Any()).Times(1).Return(newUser, nil) + getUserCall := provider.EXPECT().GetUser(gomock.Any()).Times(1).Return(*newUser, nil) provider.EXPECT().GetAddresses(gomock.Any()).After(getUserCall).Times(1).Return(newAddresses, nil) } // Original value, no changes. - require.NoError(t, service.onRefreshEvent(context.Background())) + require.NoError(t, service.HandleRefreshEvent(context.Background(), 0)) - require.Equal(t, newUser, service.identity.User) + require.Equal(t, *newUser, service.identity.User) require.Equal(t, newAddresses, service.identity.AddressesSorted) } @@ -96,7 +96,7 @@ func TestService_OnAddressCreated(t *testing.T) { Email: newAddress.Email, })).Times(1) - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: "", @@ -121,7 +121,7 @@ func TestService_OnAddressCreatedDisabledDoesNotProduceEvent(t *testing.T) { Status: proton.AddressStatusEnabled, } - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: "", @@ -146,7 +146,7 @@ func TestService_OnAddressCreatedDuplicateDoesNotProduceEvent(t *testing.T) { Status: proton.AddressStatusDisabled, } - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: "", @@ -177,7 +177,7 @@ func TestService_OnAddressUpdated(t *testing.T) { Email: newAddress.Email, })).Times(1) - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: "", @@ -221,7 +221,7 @@ func TestService_OnAddressUpdatedDisableFollowedByEnable(t *testing.T) { })).Times(1).After(disabledCall) } - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: "", @@ -234,7 +234,7 @@ func TestService_OnAddressUpdatedDisableFollowedByEnable(t *testing.T) { require.Equal(t, newAddressDisabled, service.identity.Addresses[newAddressEnabled.ID]) - err = service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err = service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: "", @@ -265,7 +265,7 @@ func TestService_OnAddressUpdateCreatedIfNotExists(t *testing.T) { Email: newAddress.Email, })).Times(1) - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: "", @@ -296,7 +296,7 @@ func TestService_OnAddressDeleted(t *testing.T) { Email: address.Email, })).Times(1) - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: address.ID, @@ -320,7 +320,7 @@ func TestService_OnAddressDeleteDisabledDoesNotProduceEvent(t *testing.T) { Status: proton.AddressStatusDisabled, } - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: address.ID, @@ -344,7 +344,7 @@ func TestService_OnAddressDeletedUnknownDoesNotProduceEvent(t *testing.T) { Status: proton.AddressStatusEnabled, } - err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + err := service.HandleAddressEvents(context.Background(), []proton.AddressEvent{ { EventItem: proton.EventItem{ ID: address.ID, @@ -364,12 +364,12 @@ func newTestService(_ *testing.T, mockCtrl *gomock.Controller) (*Service, *mocks telemetry := mocks.NewMockTelemetry(mockCtrl) bridgePassProvider := NewFixedBridgePassProvider([]byte("hello")) - service := NewService(subscribable, eventPublisher, NewState(user, newTestAddresses(), provider), bridgePassProvider, telemetry) + service := NewService(subscribable, eventPublisher, NewState(*user, newTestAddresses(), provider), bridgePassProvider, telemetry) return service, eventPublisher, provider } -func newTestUser() proton.User { - return proton.User{ +func newTestUser() *proton.User { + return &proton.User{ ID: TestUserID, Name: "Foo", DisplayName: "Foo", @@ -383,8 +383,8 @@ func newTestUser() proton.User { } } -func newTestUserRefreshed() proton.User { - return proton.User{ +func newTestUserRefreshed() *proton.User { + return &proton.User{ ID: TestUserID, Name: "Alternate", DisplayName: "Universe",