diff --git a/Makefile b/Makefile
index 5f43fe5a..0e01fe2a 100644
--- a/Makefile
+++ b/Makefile
@@ -281,6 +281,13 @@ mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go
cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go
+ mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \
+EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go
+ mockgen --package userevents github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \
+MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber > tmp
+ mv tmp internal/services/userevents/mocks_test.go
+ mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/events EventPublisher \
+> internal/events/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog
diff --git a/internal/events/events.go b/internal/events/events.go
index ce27e2ea..84933cc9 100644
--- a/internal/events/events.go
+++ b/internal/events/events.go
@@ -17,7 +17,10 @@
package events
-import "fmt"
+import (
+ "context"
+ "fmt"
+)
type Event interface {
fmt.Stringer
@@ -28,3 +31,7 @@ type Event interface {
type eventBase struct{}
func (eventBase) _isEvent() {}
+
+type EventPublisher interface {
+ PublishEvent(ctx context.Context, event Event)
+}
diff --git a/internal/events/mocks/mocks.go b/internal/events/mocks/mocks.go
new file mode 100644
index 00000000..2da49cdf
--- /dev/null
+++ b/internal/events/mocks/mocks.go
@@ -0,0 +1,48 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/ProtonMail/proton-bridge/v3/internal/events (interfaces: EventPublisher)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ reflect "reflect"
+
+ events "github.com/ProtonMail/proton-bridge/v3/internal/events"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockEventPublisher is a mock of EventPublisher interface.
+type MockEventPublisher struct {
+ ctrl *gomock.Controller
+ recorder *MockEventPublisherMockRecorder
+}
+
+// MockEventPublisherMockRecorder is the mock recorder for MockEventPublisher.
+type MockEventPublisherMockRecorder struct {
+ mock *MockEventPublisher
+}
+
+// NewMockEventPublisher creates a new mock instance.
+func NewMockEventPublisher(ctrl *gomock.Controller) *MockEventPublisher {
+ mock := &MockEventPublisher{ctrl: ctrl}
+ mock.recorder = &MockEventPublisherMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockEventPublisher) EXPECT() *MockEventPublisherMockRecorder {
+ return m.recorder
+}
+
+// PublishEvent mocks base method.
+func (m *MockEventPublisher) PublishEvent(arg0 context.Context, arg1 events.Event) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "PublishEvent", arg0, arg1)
+}
+
+// PublishEvent indicates an expected call of PublishEvent.
+func (mr *MockEventPublisherMockRecorder) PublishEvent(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishEvent", reflect.TypeOf((*MockEventPublisher)(nil).PublishEvent), arg0, arg1)
+}
diff --git a/internal/services/userevents/event_source.go b/internal/services/userevents/event_source.go
new file mode 100644
index 00000000..098c7605
--- /dev/null
+++ b/internal/services/userevents/event_source.go
@@ -0,0 +1,40 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package userevents
+
+import (
+ "context"
+
+ "github.com/ProtonMail/go-proton-api"
+)
+
+// EventSource represents a type which produces proton events.
+type EventSource interface {
+ GetLatestEventID(ctx context.Context) (string, error)
+ GetEvent(ctx context.Context, id string) ([]proton.Event, bool, error)
+}
+
+type NullEventSource struct{}
+
+func (n NullEventSource) GetLatestEventID(_ context.Context) (string, error) {
+ return "", nil
+}
+
+func (n NullEventSource) GetEvent(_ context.Context, _ string) ([]proton.Event, bool, error) {
+ return nil, false, nil
+}
diff --git a/internal/services/userevents/eventid_store.go b/internal/services/userevents/eventid_store.go
new file mode 100644
index 00000000..ce4e3852
--- /dev/null
+++ b/internal/services/userevents/eventid_store.go
@@ -0,0 +1,75 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package userevents
+
+import (
+ "context"
+ "sync"
+
+ "github.com/ProtonMail/proton-bridge/v3/internal/vault"
+)
+
+// EventIDStore exposes behavior expected of a type which allows us to store and retrieve event Ids.
+// Note: this may be accessed from multiple go-routines.
+type EventIDStore interface {
+ // Load the last stored event, return "" for empty.
+ Load(ctx context.Context) (string, error)
+ // Store the new id.
+ Store(ctx context.Context, id string) error
+}
+
+type InMemoryEventIDStore struct {
+ lock sync.Mutex
+ id string
+}
+
+func NewInMemoryEventIDStore() *InMemoryEventIDStore {
+ return &InMemoryEventIDStore{}
+}
+
+func (i *InMemoryEventIDStore) Load(_ context.Context) (string, error) {
+ i.lock.Lock()
+ defer i.lock.Unlock()
+
+ return i.id, nil
+}
+
+func (i *InMemoryEventIDStore) Store(_ context.Context, id string) error {
+ i.lock.Lock()
+ defer i.lock.Unlock()
+
+ i.id = id
+
+ return nil
+}
+
+type VaultEventIDStore struct {
+ vault *vault.User
+}
+
+func NewVaultEventIDStore(vault *VaultEventIDStore) *VaultEventIDStore {
+ return &VaultEventIDStore{vault: vault.vault}
+}
+
+func (v VaultEventIDStore) Load(_ context.Context) (string, error) {
+ return v.vault.EventID(), nil
+}
+
+func (v VaultEventIDStore) Store(_ context.Context, id string) error {
+ return v.vault.SetEventID(id)
+}
diff --git a/internal/services/userevents/mocks/mocks.go b/internal/services/userevents/mocks/mocks.go
new file mode 100644
index 00000000..b0dd345d
--- /dev/null
+++ b/internal/services/userevents/mocks/mocks.go
@@ -0,0 +1,119 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: EventSource,EventIDStore)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+ context "context"
+ reflect "reflect"
+
+ proton "github.com/ProtonMail/go-proton-api"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockEventSource is a mock of EventSource interface.
+type MockEventSource struct {
+ ctrl *gomock.Controller
+ recorder *MockEventSourceMockRecorder
+}
+
+// MockEventSourceMockRecorder is the mock recorder for MockEventSource.
+type MockEventSourceMockRecorder struct {
+ mock *MockEventSource
+}
+
+// NewMockEventSource creates a new mock instance.
+func NewMockEventSource(ctrl *gomock.Controller) *MockEventSource {
+ mock := &MockEventSource{ctrl: ctrl}
+ mock.recorder = &MockEventSourceMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockEventSource) EXPECT() *MockEventSourceMockRecorder {
+ return m.recorder
+}
+
+// GetEvent mocks base method.
+func (m *MockEventSource) GetEvent(arg0 context.Context, arg1 string) ([]proton.Event, bool, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetEvent", arg0, arg1)
+ ret0, _ := ret[0].([]proton.Event)
+ ret1, _ := ret[1].(bool)
+ ret2, _ := ret[2].(error)
+ return ret0, ret1, ret2
+}
+
+// GetEvent indicates an expected call of GetEvent.
+func (mr *MockEventSourceMockRecorder) GetEvent(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvent", reflect.TypeOf((*MockEventSource)(nil).GetEvent), arg0, arg1)
+}
+
+// GetLatestEventID mocks base method.
+func (m *MockEventSource) GetLatestEventID(arg0 context.Context) (string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetLatestEventID", arg0)
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetLatestEventID indicates an expected call of GetLatestEventID.
+func (mr *MockEventSourceMockRecorder) GetLatestEventID(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestEventID", reflect.TypeOf((*MockEventSource)(nil).GetLatestEventID), arg0)
+}
+
+// MockEventIDStore is a mock of EventIDStore interface.
+type MockEventIDStore struct {
+ ctrl *gomock.Controller
+ recorder *MockEventIDStoreMockRecorder
+}
+
+// MockEventIDStoreMockRecorder is the mock recorder for MockEventIDStore.
+type MockEventIDStoreMockRecorder struct {
+ mock *MockEventIDStore
+}
+
+// NewMockEventIDStore creates a new mock instance.
+func NewMockEventIDStore(ctrl *gomock.Controller) *MockEventIDStore {
+ mock := &MockEventIDStore{ctrl: ctrl}
+ mock.recorder = &MockEventIDStoreMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockEventIDStore) EXPECT() *MockEventIDStoreMockRecorder {
+ return m.recorder
+}
+
+// Load mocks base method.
+func (m *MockEventIDStore) Load(arg0 context.Context) (string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Load", arg0)
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Load indicates an expected call of Load.
+func (mr *MockEventIDStoreMockRecorder) Load(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockEventIDStore)(nil).Load), arg0)
+}
+
+// Store mocks base method.
+func (m *MockEventIDStore) Store(arg0 context.Context, arg1 string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Store", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Store indicates an expected call of Store.
+func (mr *MockEventIDStoreMockRecorder) Store(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockEventIDStore)(nil).Store), arg0, arg1)
+}
diff --git a/internal/services/userevents/mocks_test.go b/internal/services/userevents/mocks_test.go
new file mode 100644
index 00000000..dbf7a736
--- /dev/null
+++ b/internal/services/userevents/mocks_test.go
@@ -0,0 +1,463 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber)
+
+// Package userevents is a generated GoMock package.
+package userevents
+
+import (
+ context "context"
+ reflect "reflect"
+
+ proton "github.com/ProtonMail/go-proton-api"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockMessageSubscriber is a mock of MessageSubscriber interface.
+type MockMessageSubscriber struct {
+ ctrl *gomock.Controller
+ recorder *MockMessageSubscriberMockRecorder
+}
+
+// MockMessageSubscriberMockRecorder is the mock recorder for MockMessageSubscriber.
+type MockMessageSubscriberMockRecorder struct {
+ mock *MockMessageSubscriber
+}
+
+// NewMockMessageSubscriber creates a new mock instance.
+func NewMockMessageSubscriber(ctrl *gomock.Controller) *MockMessageSubscriber {
+ mock := &MockMessageSubscriber{ctrl: ctrl}
+ mock.recorder = &MockMessageSubscriberMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockMessageSubscriber) EXPECT() *MockMessageSubscriberMockRecorder {
+ return m.recorder
+}
+
+// cancel mocks base method.
+func (m *MockMessageSubscriber) cancel() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "cancel")
+}
+
+// cancel indicates an expected call of cancel.
+func (mr *MockMessageSubscriberMockRecorder) cancel() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockMessageSubscriber)(nil).cancel))
+}
+
+// close mocks base method.
+func (m *MockMessageSubscriber) close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "close")
+}
+
+// close indicates an expected call of close.
+func (mr *MockMessageSubscriberMockRecorder) close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockMessageSubscriber)(nil).close))
+}
+
+// handle mocks base method.
+func (m *MockMessageSubscriber) handle(arg0 context.Context, arg1 []proton.MessageEvent) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "handle", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// handle indicates an expected call of handle.
+func (mr *MockMessageSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockMessageSubscriber)(nil).handle), arg0, arg1)
+}
+
+// name mocks base method.
+func (m *MockMessageSubscriber) name() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "name")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// name indicates an expected call of name.
+func (mr *MockMessageSubscriberMockRecorder) name() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockMessageSubscriber)(nil).name))
+}
+
+// MockLabelSubscriber is a mock of LabelSubscriber interface.
+type MockLabelSubscriber struct {
+ ctrl *gomock.Controller
+ recorder *MockLabelSubscriberMockRecorder
+}
+
+// MockLabelSubscriberMockRecorder is the mock recorder for MockLabelSubscriber.
+type MockLabelSubscriberMockRecorder struct {
+ mock *MockLabelSubscriber
+}
+
+// NewMockLabelSubscriber creates a new mock instance.
+func NewMockLabelSubscriber(ctrl *gomock.Controller) *MockLabelSubscriber {
+ mock := &MockLabelSubscriber{ctrl: ctrl}
+ mock.recorder = &MockLabelSubscriberMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockLabelSubscriber) EXPECT() *MockLabelSubscriberMockRecorder {
+ return m.recorder
+}
+
+// cancel mocks base method.
+func (m *MockLabelSubscriber) cancel() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "cancel")
+}
+
+// cancel indicates an expected call of cancel.
+func (mr *MockLabelSubscriberMockRecorder) cancel() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockLabelSubscriber)(nil).cancel))
+}
+
+// close mocks base method.
+func (m *MockLabelSubscriber) close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "close")
+}
+
+// close indicates an expected call of close.
+func (mr *MockLabelSubscriberMockRecorder) close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockLabelSubscriber)(nil).close))
+}
+
+// handle mocks base method.
+func (m *MockLabelSubscriber) handle(arg0 context.Context, arg1 []proton.LabelEvent) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "handle", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// handle indicates an expected call of handle.
+func (mr *MockLabelSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockLabelSubscriber)(nil).handle), arg0, arg1)
+}
+
+// name mocks base method.
+func (m *MockLabelSubscriber) name() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "name")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// name indicates an expected call of name.
+func (mr *MockLabelSubscriberMockRecorder) name() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockLabelSubscriber)(nil).name))
+}
+
+// MockAddressSubscriber is a mock of AddressSubscriber interface.
+type MockAddressSubscriber struct {
+ ctrl *gomock.Controller
+ recorder *MockAddressSubscriberMockRecorder
+}
+
+// MockAddressSubscriberMockRecorder is the mock recorder for MockAddressSubscriber.
+type MockAddressSubscriberMockRecorder struct {
+ mock *MockAddressSubscriber
+}
+
+// NewMockAddressSubscriber creates a new mock instance.
+func NewMockAddressSubscriber(ctrl *gomock.Controller) *MockAddressSubscriber {
+ mock := &MockAddressSubscriber{ctrl: ctrl}
+ mock.recorder = &MockAddressSubscriberMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockAddressSubscriber) EXPECT() *MockAddressSubscriberMockRecorder {
+ return m.recorder
+}
+
+// cancel mocks base method.
+func (m *MockAddressSubscriber) cancel() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "cancel")
+}
+
+// cancel indicates an expected call of cancel.
+func (mr *MockAddressSubscriberMockRecorder) cancel() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockAddressSubscriber)(nil).cancel))
+}
+
+// close mocks base method.
+func (m *MockAddressSubscriber) close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "close")
+}
+
+// close indicates an expected call of close.
+func (mr *MockAddressSubscriberMockRecorder) close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockAddressSubscriber)(nil).close))
+}
+
+// handle mocks base method.
+func (m *MockAddressSubscriber) handle(arg0 context.Context, arg1 []proton.AddressEvent) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "handle", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// handle indicates an expected call of handle.
+func (mr *MockAddressSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockAddressSubscriber)(nil).handle), arg0, arg1)
+}
+
+// name mocks base method.
+func (m *MockAddressSubscriber) name() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "name")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// name indicates an expected call of name.
+func (mr *MockAddressSubscriberMockRecorder) name() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockAddressSubscriber)(nil).name))
+}
+
+// MockRefreshSubscriber is a mock of RefreshSubscriber interface.
+type MockRefreshSubscriber struct {
+ ctrl *gomock.Controller
+ recorder *MockRefreshSubscriberMockRecorder
+}
+
+// MockRefreshSubscriberMockRecorder is the mock recorder for MockRefreshSubscriber.
+type MockRefreshSubscriberMockRecorder struct {
+ mock *MockRefreshSubscriber
+}
+
+// NewMockRefreshSubscriber creates a new mock instance.
+func NewMockRefreshSubscriber(ctrl *gomock.Controller) *MockRefreshSubscriber {
+ mock := &MockRefreshSubscriber{ctrl: ctrl}
+ mock.recorder = &MockRefreshSubscriberMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockRefreshSubscriber) EXPECT() *MockRefreshSubscriberMockRecorder {
+ return m.recorder
+}
+
+// cancel mocks base method.
+func (m *MockRefreshSubscriber) cancel() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "cancel")
+}
+
+// cancel indicates an expected call of cancel.
+func (mr *MockRefreshSubscriberMockRecorder) cancel() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockRefreshSubscriber)(nil).cancel))
+}
+
+// close mocks base method.
+func (m *MockRefreshSubscriber) close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "close")
+}
+
+// close indicates an expected call of close.
+func (mr *MockRefreshSubscriberMockRecorder) close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockRefreshSubscriber)(nil).close))
+}
+
+// handle mocks base method.
+func (m *MockRefreshSubscriber) handle(arg0 context.Context, arg1 proton.RefreshFlag) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "handle", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// handle indicates an expected call of handle.
+func (mr *MockRefreshSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockRefreshSubscriber)(nil).handle), arg0, arg1)
+}
+
+// name mocks base method.
+func (m *MockRefreshSubscriber) name() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "name")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// name indicates an expected call of name.
+func (mr *MockRefreshSubscriberMockRecorder) name() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockRefreshSubscriber)(nil).name))
+}
+
+// MockUserSubscriber is a mock of UserSubscriber interface.
+type MockUserSubscriber struct {
+ ctrl *gomock.Controller
+ recorder *MockUserSubscriberMockRecorder
+}
+
+// MockUserSubscriberMockRecorder is the mock recorder for MockUserSubscriber.
+type MockUserSubscriberMockRecorder struct {
+ mock *MockUserSubscriber
+}
+
+// NewMockUserSubscriber creates a new mock instance.
+func NewMockUserSubscriber(ctrl *gomock.Controller) *MockUserSubscriber {
+ mock := &MockUserSubscriber{ctrl: ctrl}
+ mock.recorder = &MockUserSubscriberMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockUserSubscriber) EXPECT() *MockUserSubscriberMockRecorder {
+ return m.recorder
+}
+
+// cancel mocks base method.
+func (m *MockUserSubscriber) cancel() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "cancel")
+}
+
+// cancel indicates an expected call of cancel.
+func (mr *MockUserSubscriberMockRecorder) cancel() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserSubscriber)(nil).cancel))
+}
+
+// close mocks base method.
+func (m *MockUserSubscriber) close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "close")
+}
+
+// close indicates an expected call of close.
+func (mr *MockUserSubscriberMockRecorder) close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserSubscriber)(nil).close))
+}
+
+// handle mocks base method.
+func (m *MockUserSubscriber) handle(arg0 context.Context, arg1 proton.User) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "handle", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// handle indicates an expected call of handle.
+func (mr *MockUserSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserSubscriber)(nil).handle), arg0, arg1)
+}
+
+// name mocks base method.
+func (m *MockUserSubscriber) name() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "name")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// name indicates an expected call of name.
+func (mr *MockUserSubscriberMockRecorder) name() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserSubscriber)(nil).name))
+}
+
+// MockUserUsedSpaceSubscriber is a mock of UserUsedSpaceSubscriber interface.
+type MockUserUsedSpaceSubscriber struct {
+ ctrl *gomock.Controller
+ recorder *MockUserUsedSpaceSubscriberMockRecorder
+}
+
+// MockUserUsedSpaceSubscriberMockRecorder is the mock recorder for MockUserUsedSpaceSubscriber.
+type MockUserUsedSpaceSubscriberMockRecorder struct {
+ mock *MockUserUsedSpaceSubscriber
+}
+
+// NewMockUserUsedSpaceSubscriber creates a new mock instance.
+func NewMockUserUsedSpaceSubscriber(ctrl *gomock.Controller) *MockUserUsedSpaceSubscriber {
+ mock := &MockUserUsedSpaceSubscriber{ctrl: ctrl}
+ mock.recorder = &MockUserUsedSpaceSubscriberMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockUserUsedSpaceSubscriber) EXPECT() *MockUserUsedSpaceSubscriberMockRecorder {
+ return m.recorder
+}
+
+// cancel mocks base method.
+func (m *MockUserUsedSpaceSubscriber) cancel() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "cancel")
+}
+
+// cancel indicates an expected call of cancel.
+func (mr *MockUserUsedSpaceSubscriberMockRecorder) cancel() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).cancel))
+}
+
+// close mocks base method.
+func (m *MockUserUsedSpaceSubscriber) close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "close")
+}
+
+// close indicates an expected call of close.
+func (mr *MockUserUsedSpaceSubscriberMockRecorder) close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).close))
+}
+
+// handle mocks base method.
+func (m *MockUserUsedSpaceSubscriber) handle(arg0 context.Context, arg1 int) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "handle", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// handle indicates an expected call of handle.
+func (mr *MockUserUsedSpaceSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).handle), arg0, arg1)
+}
+
+// name mocks base method.
+func (m *MockUserUsedSpaceSubscriber) name() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "name")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// name indicates an expected call of name.
+func (mr *MockUserUsedSpaceSubscriberMockRecorder) name() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).name))
+}
diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go
new file mode 100644
index 00000000..560f3575
--- /dev/null
+++ b/internal/services/userevents/service.go
@@ -0,0 +1,486 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package userevents
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/ProtonMail/proton-bridge/v3/internal"
+ "github.com/ProtonMail/proton-bridge/v3/internal/events"
+ "github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
+ "github.com/sirupsen/logrus"
+)
+
+// Service polls from the given event source and ensures that all the respective subscribers get notified
+// before proceeding to the next event. The events are published in the following order:
+// * Refresh
+// * User
+// * Address
+// * Label
+// * Message
+// * UserUsedSpace
+// By default this service starts paused, you need to call `Service.Resume` at least one time to begin event polling.
+type Service struct {
+ userID string
+ cpc *cpc.CPC
+ eventSource EventSource
+ eventIDStore EventIDStore
+ log *logrus.Entry
+ eventPublisher events.EventPublisher
+ timer *time.Ticker
+ eventTimeout time.Duration
+ paused bool
+ panicHandler async.PanicHandler
+
+ userSubscriberList userSubscriberList
+ addressSubscribers addressSubscriberList
+ labelSubscribers labelSubscriberList
+ messageSubscribers messageSubscriberList
+ refreshSubscribers refreshSubscriberList
+ userUsedSpaceSubscriber userUsedSpaceSubscriberList
+
+ pendingSubscriptionsLock sync.Mutex
+ pendingSubscriptionsAdd []Subscription
+ pendingSubscriptionsRemove []Subscription
+}
+
+func NewService(
+ userID string,
+ eventSource EventSource,
+ store EventIDStore,
+ eventPublisher events.EventPublisher,
+ pollPeriod time.Duration,
+ eventTimeout time.Duration,
+ panicHandler async.PanicHandler,
+) *Service {
+ return &Service{
+ userID: userID,
+ cpc: cpc.NewCPC(),
+ eventSource: eventSource,
+ eventIDStore: store,
+ log: logrus.WithFields(logrus.Fields{
+ "service": "user-events",
+ "user": userID,
+ }),
+ eventPublisher: eventPublisher,
+ timer: time.NewTicker(pollPeriod),
+ paused: true,
+ eventTimeout: eventTimeout,
+ panicHandler: panicHandler,
+ }
+}
+
+type Subscription struct {
+ User UserSubscriber
+ Refresh RefreshSubscriber
+ Address AddressSubscriber
+ Labels LabelSubscriber
+ Messages MessageSubscriber
+ UserUsedSpace UserUsedSpaceSubscriber
+}
+
+// cancel subscription subscribers if applicable, see `subscriber.cancel` for more information.
+func (s Subscription) cancel() {
+ if s.User != nil {
+ s.User.cancel()
+ }
+ if s.Refresh != nil {
+ s.Refresh.cancel()
+ }
+ if s.Address != nil {
+ s.Address.cancel()
+ }
+ if s.Labels != nil {
+ s.Labels.cancel()
+ }
+ if s.Messages != nil {
+ s.Messages.cancel()
+ }
+ if s.UserUsedSpace != nil {
+ s.UserUsedSpace.cancel()
+ }
+}
+
+// Subscribe adds new subscribers to the service.
+// This method can safely be called during event handling.
+func (s *Service) Subscribe(subscription Subscription) {
+ s.pendingSubscriptionsLock.Lock()
+ defer s.pendingSubscriptionsLock.Unlock()
+
+ s.pendingSubscriptionsAdd = append(s.pendingSubscriptionsAdd, subscription)
+}
+
+// Unsubscribe removes subscribers from the service.
+// This method can safely be called during event handling.
+func (s *Service) Unsubscribe(subscription Subscription) {
+ subscription.cancel()
+
+ s.pendingSubscriptionsLock.Lock()
+ defer s.pendingSubscriptionsLock.Unlock()
+
+ s.pendingSubscriptionsRemove = append(s.pendingSubscriptionsRemove, subscription)
+}
+
+// Pause pauses the event polling.
+// DO NOT CALL THIS DURING EVENT HANDLING.
+func (s *Service) Pause(ctx context.Context) error {
+ _, err := s.cpc.Send(ctx, &pauseRequest{})
+
+ return err
+}
+
+// Resume resumes the event polling.
+// DO NOT CALL THIS DURING EVENT HANDLING.
+func (s *Service) Resume(ctx context.Context) error {
+ _, err := s.cpc.Send(ctx, &resumeRequest{})
+
+ return err
+}
+
+// IsPaused return true if the service is paused
+// DO NOT CALL THIS DURING EVENT HANDLING.
+func (s *Service) IsPaused(ctx context.Context) (bool, error) {
+ return cpc.SendTyped[bool](ctx, s.cpc, &isPausedRequest{})
+}
+
+func (s *Service) Start(ctx context.Context, group *async.Group) error {
+ lastEventID, err := s.eventIDStore.Load(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to load last event id: %w", err)
+ }
+
+ if lastEventID == "" {
+ s.log.Debugf("No event ID present in storage, retrieving latest")
+ eventID, err := s.eventSource.GetLatestEventID(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get latest event id: %w", err)
+ }
+
+ if err := s.eventIDStore.Store(ctx, eventID); err != nil {
+ return fmt.Errorf("failed to store event in event id store: %v", err)
+ }
+
+ lastEventID = eventID
+ }
+
+ group.Once(func(ctx context.Context) {
+ s.run(ctx, lastEventID)
+ })
+
+ return nil
+}
+
+func (s *Service) run(ctx context.Context, lastEventID string) {
+ s.log.Debugf("Starting service Last EventID=%v", lastEventID)
+ defer s.close()
+ defer s.log.Debug("Exiting service")
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case req, ok := <-s.cpc.ReceiveCh():
+ if !ok {
+ return
+ }
+
+ s.handleRequest(ctx, req)
+ continue
+ case <-s.timer.C:
+ if s.paused {
+ continue
+ }
+ }
+
+ // Apply any pending subscription changes.
+ func() {
+ s.pendingSubscriptionsLock.Lock()
+ defer s.pendingSubscriptionsLock.Unlock()
+
+ for _, subscription := range s.pendingSubscriptionsRemove {
+ s.removeSubscription(subscription)
+ }
+
+ for _, subscription := range s.pendingSubscriptionsAdd {
+ s.addSubscription(subscription)
+ }
+
+ s.pendingSubscriptionsRemove = nil
+ s.pendingSubscriptionsAdd = nil
+ }()
+
+ newEvents, _, err := s.eventSource.GetEvent(ctx, lastEventID)
+ if err != nil {
+ s.log.WithError(err).Errorf("Failed to get event (caused by %T)", internal.ErrCause(err))
+ continue
+ }
+
+ // If the event ID hasn't changed, there are no new events.
+ if newEvents[len(newEvents)-1].EventID == lastEventID {
+ s.log.Debugf("No new API Events")
+ continue
+ }
+
+ if event, eventErr := func() (proton.Event, error) {
+ for _, event := range newEvents {
+ if err := s.handleEvent(ctx, lastEventID, event); err != nil {
+ return event, err
+ }
+ }
+
+ return proton.Event{}, nil
+ }(); eventErr != nil {
+ subscriberName, err := s.handleEventError(ctx, lastEventID, event, eventErr)
+ if subscriberName == "" {
+ subscriberName = "?"
+ }
+ s.log.WithField("subscriber", subscriberName).WithError(err).Errorf("Failed to apply event")
+ continue
+ }
+
+ newEventID := newEvents[len(newEvents)-1].EventID
+ if err := s.eventIDStore.Store(ctx, newEventID); err != nil {
+ s.log.WithError(err).Errorf("Failed to store new event ID: %v", err)
+ s.onBadEvent(ctx, events.UserBadEvent{
+ Error: fmt.Errorf("failed to store new event ID: %w", err),
+ UserID: s.userID,
+ })
+ continue
+ }
+
+ lastEventID = newEventID
+ }
+}
+
+func (s *Service) handleEvent(ctx context.Context, lastEventID string, event proton.Event) error {
+ s.log.WithFields(logrus.Fields{
+ "old": lastEventID,
+ "new": event,
+ }).Info("Received new API event")
+
+ if event.Refresh&proton.RefreshMail != 0 {
+ s.log.Info("Handling refresh event")
+ if err := s.refreshSubscribers.Publish(ctx, event.Refresh, s.eventTimeout); err != nil {
+ return fmt.Errorf("failed to apply refresh event: %w", err)
+ }
+
+ return nil
+ }
+
+ // Start with user events.
+ if event.User != nil {
+ if err := s.userSubscriberList.PublishParallel(ctx, *event.User, s.panicHandler, s.eventTimeout); err != nil {
+ return fmt.Errorf("failed to apply user event: %w", err)
+ }
+ }
+
+ // Next Address events
+ if err := s.addressSubscribers.PublishParallel(ctx, event.Addresses, s.panicHandler, s.eventTimeout); err != nil {
+ return fmt.Errorf("failed to apply address events: %w", err)
+ }
+
+ // Next label events
+ if err := s.labelSubscribers.PublishParallel(ctx, event.Labels, s.panicHandler, s.eventTimeout); err != nil {
+ return fmt.Errorf("failed to apply label events: %w", err)
+ }
+
+ // Next message events
+ if err := s.messageSubscribers.PublishParallel(ctx, event.Messages, s.panicHandler, s.eventTimeout); err != nil {
+ return fmt.Errorf("failed to apply message events: %w", err)
+ }
+
+ // Finally user used space events
+ if event.UsedSpace != nil {
+ if err := s.userUsedSpaceSubscriber.PublishParallel(ctx, *event.UsedSpace, s.panicHandler, s.eventTimeout); err != nil {
+ return fmt.Errorf("failed to apply message events: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func unpackPublisherError(err error) (string, error) {
+ var addressErr *addressPublishError
+ var labelErr *labelPublishError
+ var messageErr *messagePublishError
+ var refreshErr *refreshPublishError
+ var userErr *userPublishError
+ var usedSpaceErr *userUsedEventPublishError
+
+ switch {
+ case errors.As(err, &userErr):
+ return userErr.subscriber.name(), userErr.error
+ case errors.As(err, &addressErr):
+ return addressErr.subscriber.name(), addressErr.error
+ case errors.As(err, &labelErr):
+ return labelErr.subscriber.name(), labelErr.error
+ case errors.As(err, &messageErr):
+ return messageErr.subscriber.name(), messageErr.error
+ case errors.As(err, &refreshErr):
+ return refreshErr.subscriber.name(), refreshErr.error
+ case errors.As(err, &usedSpaceErr):
+ return usedSpaceErr.subscriber.name(), usedSpaceErr.error
+ default:
+ return "", err
+ }
+}
+
+func (s *Service) handleEventError(ctx context.Context, lastEventID string, event proton.Event, err error) (string, error) {
+ // Unpack the error so we can proceed to handle the real issue.
+ subscriberName, err := unpackPublisherError(err)
+
+ // If the error is a context cancellation, return error to retry later.
+ if errors.Is(err, context.Canceled) {
+ return subscriberName, fmt.Errorf("failed to handle event due to context cancellation: %w", err)
+ }
+
+ // If the error is a network error, return error to retry later.
+ if netErr := new(proton.NetError); errors.As(err, &netErr) {
+ return subscriberName, fmt.Errorf("failed to handle event due to network issue: %w", err)
+ }
+
+ // Catch all for uncategorized net errors that may slip through.
+ if netErr := new(net.OpError); errors.As(err, &netErr) {
+ return subscriberName, fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
+ }
+
+ // In case a json decode error slips through.
+ if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
+ s.eventPublisher.PublishEvent(ctx, events.UncategorizedEventError{
+ UserID: s.userID,
+ Error: err,
+ })
+
+ return subscriberName, fmt.Errorf("failed to handle event due to JSON issue: %w", err)
+ }
+
+ // If the error is an unexpected EOF, return error to retry later.
+ if errors.Is(err, io.ErrUnexpectedEOF) {
+ return subscriberName, fmt.Errorf("failed to handle event due to EOF: %w", err)
+ }
+
+ // If the error is a server-side issue, return error to retry later.
+ if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
+ return subscriberName, fmt.Errorf("failed to handle event due to server error: %w", err)
+ }
+
+ // Otherwise, the error is a client-side issue; notify bridge to handle it.
+ s.log.WithField("event", event).Warn("Failed to handle API event")
+
+ s.onBadEvent(ctx, events.UserBadEvent{
+ UserID: s.userID,
+ OldEventID: lastEventID,
+ NewEventID: event.EventID,
+ EventInfo: event.String(),
+ Error: err,
+ })
+
+ return subscriberName, fmt.Errorf("failed to handle event due to client error: %w", err)
+}
+
+func (s *Service) onBadEvent(ctx context.Context, event events.UserBadEvent) {
+ s.paused = true
+ s.eventPublisher.PublishEvent(ctx, event)
+}
+
+func (s *Service) handleRequest(ctx context.Context, request *cpc.Request) {
+ switch request.Value().(type) {
+ case *pauseRequest:
+ s.paused = true
+ request.Reply(ctx, nil, nil)
+ case *resumeRequest:
+ s.paused = false
+ request.Reply(ctx, nil, nil)
+ case *isPausedRequest:
+ request.Reply(ctx, s.paused, nil)
+ default:
+ s.log.Errorf("Unknown request")
+ }
+}
+
+func (s *Service) addSubscription(subscription Subscription) {
+ if subscription.User != nil {
+ s.userSubscriberList.Add(subscription.User)
+ }
+
+ if subscription.Refresh != nil {
+ s.refreshSubscribers.Add(subscription.Refresh)
+ }
+
+ if subscription.Address != nil {
+ s.addressSubscribers.Add(subscription.Address)
+ }
+
+ if subscription.Labels != nil {
+ s.labelSubscribers.Add(subscription.Labels)
+ }
+
+ if subscription.Messages != nil {
+ s.messageSubscribers.Add(subscription.Messages)
+ }
+
+ if subscription.UserUsedSpace != nil {
+ s.userUsedSpaceSubscriber.Add(subscription.UserUsedSpace)
+ }
+}
+
+func (s *Service) removeSubscription(subscription Subscription) {
+ if subscription.User != nil {
+ s.userSubscriberList.Remove(subscription.User)
+ }
+
+ if subscription.Refresh != nil {
+ s.refreshSubscribers.Remove(subscription.Refresh)
+ }
+
+ if subscription.Address != nil {
+ s.addressSubscribers.Remove(subscription.Address)
+ }
+
+ if subscription.Labels != nil {
+ s.labelSubscribers.Remove(subscription.Labels)
+ }
+
+ if subscription.Messages != nil {
+ s.messageSubscribers.Remove(subscription.Messages)
+ }
+
+ if subscription.UserUsedSpace != nil {
+ s.userUsedSpaceSubscriber.Remove(subscription.UserUsedSpace)
+ }
+}
+
+func (s *Service) close() {
+ s.cpc.Close()
+ s.timer.Stop()
+}
+
+type pauseRequest struct{}
+
+type resumeRequest struct{}
+
+type isPausedRequest struct{}
diff --git a/internal/services/userevents/service_handle_event_error_test.go b/internal/services/userevents/service_handle_event_error_test.go
new file mode 100644
index 00000000..0ea206f4
--- /dev/null
+++ b/internal/services/userevents/service_handle_event_error_test.go
@@ -0,0 +1,157 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package userevents
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/ProtonMail/proton-bridge/v3/internal/events"
+ "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/require"
+)
+
+func TestServiceHandleEventError_SubscriberEventUnwrapping(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+
+ lastEventID := "PrevEvent"
+ event := proton.Event{EventID: "MyEvent"}
+ subscriber := &noOpSubscriber[proton.AddressEvent]{}
+
+ err := &addressPublishError{
+ subscriber: subscriber,
+ error: &proton.NetError{},
+ }
+
+ subscriberName, unpackedErr := service.handleEventError(context.Background(), lastEventID, event, err)
+ require.Equal(t, subscriber.name(), subscriberName)
+ protonNetErr := new(proton.NetError)
+ require.True(t, errors.As(unpackedErr, &protonNetErr))
+
+ err2 := &proton.APIError{Status: 500}
+ subscriberName2, unpackedErr2 := service.handleEventError(context.Background(), lastEventID, event, err2)
+ require.Equal(t, "", subscriberName2)
+ protonAPIErr := new(proton.APIError)
+ require.True(t, errors.As(unpackedErr2, &protonAPIErr))
+}
+
+func TestServiceHandleEventError_BadEventPutsServiceOnPause(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+ service.paused = false
+ lastEventID := "PrevEvent"
+ event := proton.Event{EventID: "MyEvent"}
+
+ err := &proton.APIError{}
+
+ eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserBadEvent{
+ UserID: service.userID,
+ OldEventID: lastEventID,
+ NewEventID: event.EventID,
+ EventInfo: event.String(),
+ Error: err,
+ })).Times(1)
+
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, err)
+ require.True(t, service.paused)
+}
+
+func TestServiceHandleEventError_BadEventFromPublishTimeout(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+ lastEventID := "PrevEvent"
+ event := proton.Event{EventID: "MyEvent"}
+ err := ErrPublishTimeoutExceeded
+
+ eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserBadEvent{
+ UserID: service.userID,
+ OldEventID: lastEventID,
+ NewEventID: event.EventID,
+ EventInfo: event.String(),
+ Error: err,
+ })).Times(1)
+
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, err)
+}
+
+func TestServiceHandleEventError_NoBadEventCheck(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+ lastEventID := "PrevEvent"
+ event := proton.Event{EventID: "MyEvent"}
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, context.Canceled)
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, &proton.NetError{})
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, &net.OpError{})
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, io.ErrUnexpectedEOF)
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, &proton.APIError{Status: 500})
+}
+
+func TestServiceHandleEventError_JsonUnmarshalEventProducesUncategorizedErrorEvent(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+ lastEventID := "PrevEvent"
+ event := proton.Event{EventID: "MyEvent"}
+ err := &json.UnmarshalTypeError{}
+
+ eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UncategorizedEventError{
+ UserID: service.userID,
+ Error: err,
+ })).Times(1)
+
+ _, _ = service.handleEventError(context.Background(), lastEventID, event, err)
+}
+
+type noOpSubscriber[T any] struct{}
+
+func (n noOpSubscriber[T]) name() string { //nolint:unused
+ return "NoopSubscriber"
+}
+
+func (n noOpSubscriber[T]) handle(_ context.Context, _ []T) error { //nolint:unused
+ return nil
+}
+
+//nolint:unused
+func (n noOpSubscriber[T]) close() {} //
+
+//nolint:unused
+func (n noOpSubscriber[T]) cancel() {}
diff --git a/internal/services/userevents/service_handle_event_test.go b/internal/services/userevents/service_handle_event_test.go
new file mode 100644
index 00000000..d64d9da2
--- /dev/null
+++ b/internal/services/userevents/service_handle_event_test.go
@@ -0,0 +1,195 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package userevents
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/require"
+)
+
+func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ refreshHandler := NewMockRefreshSubscriber(mockCtrl)
+ refreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(2).Return(nil)
+
+ userHandler := NewMockUserSubscriber(mockCtrl)
+ userCall := userHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(nil)
+
+ addressHandler := NewMockAddressSubscriber(mockCtrl)
+ addressCall := addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userCall).Times(1).Return(nil)
+
+ labelHandler := NewMockLabelSubscriber(mockCtrl)
+ labelCall := labelHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(addressCall).Times(1).Return(nil)
+
+ messageHandler := NewMockMessageSubscriber(mockCtrl)
+ messageCall := messageHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(labelCall).Times(1).Return(nil)
+
+ userSpaceHandler := NewMockUserUsedSpaceSubscriber(mockCtrl)
+ userSpaceCall := userSpaceHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(messageCall).Times(1).Return(nil)
+
+ secondRefreshHandler := NewMockRefreshSubscriber(mockCtrl)
+ secondRefreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userSpaceCall).Times(1).Return(nil)
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+
+ service.addSubscription(Subscription{
+ User: userHandler,
+ Refresh: refreshHandler,
+ Address: addressHandler,
+ Labels: labelHandler,
+ Messages: messageHandler,
+ UserUsedSpace: userSpaceHandler,
+ })
+
+ // Simulate 1st refresh.
+ require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail}))
+
+ // Simulate Regular event.
+ usedSpace := 20
+ require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{
+ User: new(proton.User),
+ Addresses: []proton.AddressEvent{},
+ Labels: []proton.LabelEvent{
+ {},
+ },
+ Messages: []proton.MessageEvent{
+ {},
+ },
+ UsedSpace: &usedSpace,
+ }))
+
+ service.addSubscription(Subscription{
+ Refresh: secondRefreshHandler,
+ })
+
+ // Simulate 2nd refresh.
+ require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail}))
+}
+
+func TestServiceHandleEvent_CheckEventFailureCausesError(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ addressHandler := NewMockAddressSubscriber(mockCtrl)
+ addressHandler.EXPECT().name().MinTimes(1).Return("Hello")
+ addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed"))
+
+ messageHandler := NewMockMessageSubscriber(mockCtrl)
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+
+ service.addSubscription(Subscription{
+ Address: addressHandler,
+ Messages: messageHandler,
+ })
+
+ err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}})
+ require.Error(t, err)
+ publisherErr := new(addressPublishError)
+ require.True(t, errors.As(err, &publisherErr))
+ require.Equal(t, publisherErr.subscriber, addressHandler)
+}
+
+func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ addressHandler := NewMockAddressSubscriber(mockCtrl)
+ addressHandler.EXPECT().name().MinTimes(1).Return("Hello")
+ addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed"))
+
+ addressHandler2 := NewMockAddressSubscriber(mockCtrl)
+ addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil)
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
+
+ service.addSubscription(Subscription{
+ Address: addressHandler,
+ })
+
+ service.addSubscription(Subscription{
+ Address: addressHandler2,
+ })
+
+ err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}})
+ require.Error(t, err)
+ publisherErr := new(addressPublishError)
+ require.True(t, errors.As(err, &publisherErr))
+ require.Equal(t, publisherErr.subscriber, addressHandler)
+}
+
+func TestServiceHandleEvent_SubscriberTimeout(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+
+ eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
+ eventIDStore := NewInMemoryEventIDStore()
+
+ addressHandler := NewMockAddressSubscriber(mockCtrl)
+ addressHandler.EXPECT().name().AnyTimes().Return("Ok")
+ addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil)
+
+ addressHandler2 := NewMockAddressSubscriber(mockCtrl)
+ addressHandler2.EXPECT().name().AnyTimes().Return("Timeout")
+ addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ []proton.AddressEvent) error {
+ timer := time.NewTimer(time.Second)
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ return nil
+ }
+ }).MaxTimes(1)
+
+ service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, 500*time.Millisecond, async.NoopPanicHandler{})
+
+ service.addSubscription(Subscription{
+ Address: addressHandler,
+ })
+
+ service.addSubscription(Subscription{
+ Address: addressHandler2,
+ })
+
+ // Simulate 1st refresh.
+ err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}})
+ require.Error(t, err)
+ if publisherErr := new(addressPublishError); errors.As(err, &publisherErr) {
+ require.Equal(t, publisherErr.subscriber, addressHandler)
+ require.True(t, errors.Is(publisherErr.error, ErrPublishTimeoutExceeded))
+ } else {
+ require.True(t, errors.Is(err, ErrPublishTimeoutExceeded))
+ }
+}
diff --git a/internal/services/userevents/service_test.go b/internal/services/userevents/service_test.go
new file mode 100644
index 00000000..fc4afc25
--- /dev/null
+++ b/internal/services/userevents/service_test.go
@@ -0,0 +1,269 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+package userevents
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/ProtonMail/proton-bridge/v3/internal/events"
+ mocks2 "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
+ "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents/mocks"
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/require"
+)
+
+func TestService_EventIDLoadStore(t *testing.T) {
+ // Simulate the flow of data when we start without any event id in the event id store:
+ // * Load event id from store
+ // * Get latest event id since
+ // * Store that latest event id
+ // * Start event poll loop
+ // * Get new event id, store it in vault
+ // * Try to poll new event it, but context is cancelled
+ group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
+ eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
+ eventSource := mocks.NewMockEventSource(mockCtrl)
+
+ firstEventID := "EVENT01"
+ secondEventID := "EVENT02"
+ secondEvent := []proton.Event{{
+ EventID: secondEventID,
+ }}
+
+ // Event id store expectations.
+ eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return("", nil)
+ eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(firstEventID)).Times(1).Return(nil)
+ eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
+ // Force exit, we have finished executing what we expected.
+ group.Cancel()
+ return nil
+ })
+
+ // Event Source expectations.
+ eventSource.EXPECT().GetLatestEventID(gomock.Any()).Times(1).Return(firstEventID, nil)
+ eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
+
+ service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
+ require.NoError(t, service.Start(context.Background(), group))
+ require.NoError(t, service.Resume(context.Background()))
+ group.WaitToFinish()
+}
+
+func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) {
+ group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
+ eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
+ eventSource := mocks.NewMockEventSource(mockCtrl)
+ subscriber := NewMockMessageSubscriber(mockCtrl)
+
+ firstEventID := "EVENT01"
+ secondEventID := "EVENT02"
+ messageEvents := []proton.MessageEvent{
+ {
+ EventItem: proton.EventItem{ID: "Message"},
+ },
+ }
+ secondEvent := []proton.Event{{
+ EventID: secondEventID,
+ Messages: messageEvents,
+ }}
+
+ // Event id store expectations.
+ eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
+ eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
+ // Force exit, we have finished executing what we expected.
+ group.Cancel()
+ return nil
+ })
+
+ // Event Source expectations.
+ eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
+
+ // Subscriber expectations.
+ subscriber.EXPECT().name().AnyTimes().Return("Foo")
+ {
+ firstCall := subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(io.ErrUnexpectedEOF)
+ subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).After(firstCall).Times(1).Return(nil)
+ }
+
+ service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
+ service.Subscribe(Subscription{Messages: subscriber})
+
+ require.NoError(t, service.Start(context.Background(), group))
+ require.NoError(t, service.Resume(context.Background()))
+ group.WaitToFinish()
+}
+
+func TestService_OnBadEventServiceIsPaused(t *testing.T) {
+ group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
+ eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
+ eventSource := mocks.NewMockEventSource(mockCtrl)
+ subscriber := NewMockMessageSubscriber(mockCtrl)
+
+ firstEventID := "EVENT01"
+ secondEventID := "EVENT02"
+ messageEvents := []proton.MessageEvent{
+ {
+ EventItem: proton.EventItem{ID: "Message"},
+ },
+ }
+ secondEvent := []proton.Event{{
+ EventID: secondEventID,
+ Messages: messageEvents,
+ }}
+
+ // Event id store expectations.
+ eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
+
+ // Event Source expectations.
+ eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
+
+ // Subscriber expectations.
+ badEventErr := fmt.Errorf("I will cause bad event")
+ subscriber.EXPECT().name().AnyTimes().Return("Foo")
+ subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(badEventErr)
+
+ service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
+
+ // Event publisher expectations.
+ eventPublisher.EXPECT().PublishEvent(gomock.Any(), events.UserBadEvent{
+ UserID: "foo",
+ OldEventID: firstEventID,
+ NewEventID: secondEventID,
+ EventInfo: secondEvent[0].String(),
+ Error: badEventErr,
+ }).Do(func(_ context.Context, event events.Event) {
+ group.Once(func(_ context.Context) {
+ // Use background context to avoid having the request cancelled
+ paused, err := service.IsPaused(context.Background())
+ require.NoError(t, err)
+ require.True(t, paused)
+ group.Cancel()
+ })
+ })
+
+ service.Subscribe(Subscription{Messages: subscriber})
+ require.NoError(t, service.Start(context.Background(), group))
+ require.NoError(t, service.Resume(context.Background()))
+ group.WaitToFinish()
+}
+
+func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T) {
+ group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
+ eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
+ eventSource := mocks.NewMockEventSource(mockCtrl)
+ subscriber := NewMockMessageSubscriber(mockCtrl)
+
+ firstEventID := "EVENT01"
+ secondEventID := "EVENT02"
+ messageEvents := []proton.MessageEvent{
+ {
+ EventItem: proton.EventItem{ID: "Message"},
+ },
+ }
+ secondEvent := []proton.Event{{
+ EventID: secondEventID,
+ Messages: messageEvents,
+ }}
+
+ // Event id store expectations.
+ eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
+ eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
+ // Force exit, we have finished executing what we expected.
+ group.Cancel()
+ return nil
+ })
+
+ // Event Source expectations.
+ eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
+
+ service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
+
+ // Subscriber expectations.
+ subscriber.EXPECT().name().AnyTimes().Return("Foo")
+ subscriber.EXPECT().cancel().Times(1)
+ subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).DoAndReturn(func(_ context.Context, _ []proton.MessageEvent) error {
+ service.Unsubscribe(Subscription{Messages: subscriber})
+ return nil
+ })
+
+ service.Subscribe(Subscription{Messages: subscriber})
+ require.NoError(t, service.Start(context.Background(), group))
+ require.NoError(t, service.Resume(context.Background()))
+ group.WaitToFinish()
+}
+
+func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T) {
+ group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
+ mockCtrl := gomock.NewController(t)
+ eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
+ eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
+ eventSource := mocks.NewMockEventSource(mockCtrl)
+ subscriber := newChanneledSubscriber[proton.MessageEvent]("My subscriber")
+
+ firstEventID := "EVENT01"
+ secondEventID := "EVENT02"
+ messageEvents := []proton.MessageEvent{
+ {
+ EventItem: proton.EventItem{ID: "Message"},
+ },
+ }
+ secondEvent := []proton.Event{{
+ EventID: secondEventID,
+ Messages: messageEvents,
+ }}
+
+ // Event id store expectations.
+ eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
+ eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
+ // Force exit, we have finished executing what we expected.
+ group.Cancel()
+ return nil
+ })
+
+ // Event Source expectations.
+ eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
+ eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(secondEventID)).AnyTimes().Return(secondEvent, false, nil)
+
+ service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
+
+ // start subscriber
+ group.Once(func(_ context.Context) {
+ defer service.Unsubscribe(Subscription{Messages: subscriber})
+
+ // Simulate the reception of an event, but it is never handled due to unexpected exit
+ <-time.NewTicker(500 * time.Millisecond).C
+ })
+
+ service.Subscribe(Subscription{Messages: subscriber})
+ require.NoError(t, service.Start(context.Background(), group))
+ require.NoError(t, service.Resume(context.Background()))
+ group.WaitToFinish()
+}
diff --git a/internal/services/userevents/subscriber.go b/internal/services/userevents/subscriber.go
new file mode 100644
index 00000000..16e92f92
--- /dev/null
+++ b/internal/services/userevents/subscriber.go
@@ -0,0 +1,252 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package userevents
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "runtime"
+ "time"
+
+ "github.com/ProtonMail/gluon/async"
+ "github.com/ProtonMail/go-proton-api"
+ "github.com/bradenaw/juniper/parallel"
+ "github.com/bradenaw/juniper/xslices"
+ "golang.org/x/exp/slices"
+)
+
+func NewMessageSubscriber(name string) *ChanneledSubscriber[proton.MessageEvent] {
+ return newChanneledSubscriber[proton.MessageEvent](name)
+}
+
+func NewAddressSubscriber(name string) *ChanneledSubscriber[proton.AddressEvent] {
+ return newChanneledSubscriber[proton.AddressEvent](name)
+}
+
+func NewLabelSubscriber(name string) *ChanneledSubscriber[proton.LabelEvent] {
+ return newChanneledSubscriber[proton.LabelEvent](name)
+}
+
+func NewRefreshSubscriber(name string) *ChanneledSubscriber[struct{}] {
+ return newChanneledSubscriber[struct{}](name)
+}
+
+func NewUserSubscriber(name string) *ChanneledSubscriber[proton.User] {
+ return newChanneledSubscriber[proton.User](name)
+}
+
+func NewUserUsedSpaceSubscriber(name string) *ChanneledSubscriber[int] {
+ return newChanneledSubscriber[int](name)
+}
+
+type AddressSubscriber = subscriber[[]proton.AddressEvent]
+type LabelSubscriber = subscriber[[]proton.LabelEvent]
+type MessageSubscriber = subscriber[[]proton.MessageEvent]
+type RefreshSubscriber = subscriber[proton.RefreshFlag]
+type UserSubscriber = subscriber[proton.User]
+type UserUsedSpaceSubscriber = subscriber[int]
+
+// Subscriber is the main entry point of interacting with user generated events.
+type subscriber[T any] interface {
+ // Name returns the identifier for this subscriber
+ name() string
+ // Handle the event list.
+ handle(context.Context, T) error
+ // cancel is behavior extension for channel based subscribers so that they can ensure that
+ // if a subscriber unsubscribes, it doesn't cause pending events on the channel to time-out as there is no one to handle
+ // them.
+ cancel()
+ // close release all associated resources
+ close()
+}
+
+type subscriberList[T any] struct {
+ subscribers []subscriber[T]
+}
+
+type addressSubscriberList = subscriberList[[]proton.AddressEvent]
+type labelSubscriberList = subscriberList[[]proton.LabelEvent]
+type messageSubscriberList = subscriberList[[]proton.MessageEvent]
+type refreshSubscriberList = subscriberList[proton.RefreshFlag]
+type userSubscriberList = subscriberList[proton.User]
+type userUsedSpaceSubscriberList = subscriberList[int]
+
+func (s *subscriberList[T]) Add(subscriber subscriber[T]) {
+ if !slices.Contains(s.subscribers, subscriber) {
+ s.subscribers = append(s.subscribers, subscriber)
+ }
+}
+
+func (s *subscriberList[T]) Remove(subscriber subscriber[T]) {
+ index := slices.Index(s.subscribers, subscriber)
+ if index < 0 {
+ return
+ }
+
+ s.subscribers[index].close()
+ s.subscribers = xslices.Remove(s.subscribers, index, 1)
+}
+
+type publishError[T any] struct {
+ subscriber subscriber[T]
+ error error
+}
+
+var ErrPublishTimeoutExceeded = errors.New("event publish timed out")
+
+type addressPublishError = publishError[[]proton.AddressEvent]
+type labelPublishError = publishError[[]proton.LabelEvent]
+type messagePublishError = publishError[[]proton.MessageEvent]
+type refreshPublishError = publishError[proton.RefreshFlag]
+type userPublishError = publishError[proton.User]
+type userUsedEventPublishError = publishError[int]
+
+func (p publishError[T]) Error() string {
+ return fmt.Sprintf("Event publish failed on (%v): %v", p.subscriber.name(), p.error.Error())
+}
+
+func (s *subscriberList[T]) Publish(ctx context.Context, event T, timeout time.Duration) error {
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ for _, subscriber := range s.subscribers {
+ if err := subscriber.handle(ctx, event); err != nil {
+ return &publishError[T]{
+ subscriber: subscriber,
+ error: mapContextTimeoutError(err),
+ }
+ }
+
+ if err := ctx.Err(); err != nil {
+ return &publishError[T]{
+ subscriber: subscriber,
+ error: mapContextTimeoutError(err),
+ }
+ }
+ }
+
+ return nil
+}
+
+func mapContextTimeoutError(err error) error {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return ErrPublishTimeoutExceeded
+ }
+
+ return err
+}
+
+func (s *subscriberList[T]) PublishParallel(
+ ctx context.Context,
+ event T,
+ panicHandler async.PanicHandler,
+ timeout time.Duration,
+) error {
+ if len(s.subscribers) <= 1 {
+ return s.Publish(ctx, event, timeout)
+ }
+
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ err := parallel.DoContext(ctx, runtime.NumCPU()/2, len(s.subscribers), func(ctx context.Context, index int) error {
+ defer async.HandlePanic(panicHandler)
+ if err := s.subscribers[index].handle(ctx, event); err != nil {
+ return &publishError[T]{
+ subscriber: s.subscribers[index],
+ error: mapContextTimeoutError(err),
+ }
+ }
+
+ return nil
+ })
+
+ return mapContextTimeoutError(err)
+}
+
+type ChanneledSubscriber[T any] struct {
+ id string
+ sender chan *ChanneledSubscriberEvent[T]
+}
+
+func newChanneledSubscriber[T any](name string) *ChanneledSubscriber[T] {
+ return &ChanneledSubscriber[T]{
+ id: name,
+ sender: make(chan *ChanneledSubscriberEvent[T]),
+ }
+}
+
+type ChanneledSubscriberEvent[T any] struct {
+ data []T
+ response chan error
+}
+
+func (c ChanneledSubscriberEvent[T]) Consume(f func([]T) error) {
+ if err := f(c.data); err != nil {
+ c.response <- err
+ }
+ close(c.response)
+}
+
+func (c *ChanneledSubscriber[T]) name() string { //nolint:unused
+ return c.id
+}
+
+func (c *ChanneledSubscriber[T]) handle(ctx context.Context, event []T) error { //nolint:unused
+ data := &ChanneledSubscriberEvent[T]{
+ data: event,
+ response: make(chan error),
+ }
+ // Send Event
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("failed to send event: %w", ctx.Err())
+ case c.sender <- data:
+ //
+ }
+
+ // Wait on Reply
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("failed to receive event reply: %w", ctx.Err())
+ case reply := <-data.response:
+ return reply
+ }
+}
+
+func (c *ChanneledSubscriber[T]) OnEventCh() <-chan *ChanneledSubscriberEvent[T] {
+ return c.sender
+}
+
+func (c *ChanneledSubscriber[T]) close() { //nolint:unused
+ close(c.sender)
+}
+
+func (c *ChanneledSubscriber[T]) cancel() { //nolint:unused
+ go func() {
+ for {
+ e, ok := <-c.sender
+ if !ok {
+ return
+ }
+
+ e.Consume(func(_ []T) error { return nil })
+ }
+ }()
+}
diff --git a/internal/services/userevents/subscriber_test.go b/internal/services/userevents/subscriber_test.go
new file mode 100644
index 00000000..2bb0b806
--- /dev/null
+++ b/internal/services/userevents/subscriber_test.go
@@ -0,0 +1,105 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package userevents
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestChanneledSubscriber_CtxTimeoutDoesNotBlockFutureEvents(t *testing.T) {
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+
+ subscriber := newChanneledSubscriber[int]("test")
+ defer subscriber.close()
+
+ go func() {
+ defer wg.Done()
+
+ // Send one event, that succeeds.
+ require.NoError(t, subscriber.handle(context.Background(), []int{30}))
+
+ // Add an impossible deadline that fails immediately to simulate on event taking too long to send.
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Microsecond))
+ defer cancel()
+
+ err := subscriber.handle(ctx, []int{20})
+ require.Error(t, err)
+ require.True(t, errors.Is(err, context.DeadlineExceeded))
+ }()
+
+ // Receive first event. Notify success.
+ event, ok := <-subscriber.OnEventCh()
+ require.True(t, ok)
+ event.Consume(func(event []int) error {
+ require.Equal(t, []int{30}, event)
+ return nil
+ })
+ wg.Wait()
+
+ // Simulate reception of another event
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ require.NoError(t, subscriber.handle(context.Background(), []int{40}))
+ }()
+
+ event, ok = <-subscriber.OnEventCh()
+ require.True(t, ok)
+ event.Consume(func(event []int) error {
+ require.Equal(t, []int{40}, event)
+ return nil
+ })
+
+ wg.Wait()
+}
+
+func TestChanneledSubscriber_ErrorReported(t *testing.T) {
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+
+ subscriber := newChanneledSubscriber[int]("test")
+ defer subscriber.close()
+ reportedErr := fmt.Errorf("request failed")
+
+ go func() {
+ defer wg.Done()
+
+ // Send one event, that succeeds.
+ err := subscriber.handle(context.Background(), []int{30})
+ require.Error(t, err)
+ require.Equal(t, reportedErr, err)
+ }()
+
+ // Receive first event. Notify success.
+ event, ok := <-subscriber.OnEventCh()
+ require.True(t, ok)
+ event.Consume(func(event []int) error {
+ require.Equal(t, []int{30}, event)
+ return reportedErr
+ })
+
+ wg.Wait()
+}