diff --git a/Makefile b/Makefile index 5f43fe5a..0e01fe2a 100644 --- a/Makefile +++ b/Makefile @@ -281,6 +281,13 @@ mocks: mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \ +EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go + mockgen --package userevents github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \ +MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber > tmp + mv tmp internal/services/userevents/mocks_test.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/events EventPublisher \ +> internal/events/mocks/mocks.go lint: gofiles lint-golang lint-license lint-dependencies lint-changelog diff --git a/internal/events/events.go b/internal/events/events.go index ce27e2ea..84933cc9 100644 --- a/internal/events/events.go +++ b/internal/events/events.go @@ -17,7 +17,10 @@ package events -import "fmt" +import ( + "context" + "fmt" +) type Event interface { fmt.Stringer @@ -28,3 +31,7 @@ type Event interface { type eventBase struct{} func (eventBase) _isEvent() {} + +type EventPublisher interface { + PublishEvent(ctx context.Context, event Event) +} diff --git a/internal/events/mocks/mocks.go b/internal/events/mocks/mocks.go new file mode 100644 index 00000000..2da49cdf --- /dev/null +++ b/internal/events/mocks/mocks.go @@ -0,0 +1,48 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ProtonMail/proton-bridge/v3/internal/events (interfaces: EventPublisher) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + events "github.com/ProtonMail/proton-bridge/v3/internal/events" + gomock "github.com/golang/mock/gomock" +) + +// MockEventPublisher is a mock of EventPublisher interface. +type MockEventPublisher struct { + ctrl *gomock.Controller + recorder *MockEventPublisherMockRecorder +} + +// MockEventPublisherMockRecorder is the mock recorder for MockEventPublisher. +type MockEventPublisherMockRecorder struct { + mock *MockEventPublisher +} + +// NewMockEventPublisher creates a new mock instance. +func NewMockEventPublisher(ctrl *gomock.Controller) *MockEventPublisher { + mock := &MockEventPublisher{ctrl: ctrl} + mock.recorder = &MockEventPublisherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEventPublisher) EXPECT() *MockEventPublisherMockRecorder { + return m.recorder +} + +// PublishEvent mocks base method. +func (m *MockEventPublisher) PublishEvent(arg0 context.Context, arg1 events.Event) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PublishEvent", arg0, arg1) +} + +// PublishEvent indicates an expected call of PublishEvent. +func (mr *MockEventPublisherMockRecorder) PublishEvent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishEvent", reflect.TypeOf((*MockEventPublisher)(nil).PublishEvent), arg0, arg1) +} diff --git a/internal/services/userevents/event_source.go b/internal/services/userevents/event_source.go new file mode 100644 index 00000000..098c7605 --- /dev/null +++ b/internal/services/userevents/event_source.go @@ -0,0 +1,40 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + + "github.com/ProtonMail/go-proton-api" +) + +// EventSource represents a type which produces proton events. +type EventSource interface { + GetLatestEventID(ctx context.Context) (string, error) + GetEvent(ctx context.Context, id string) ([]proton.Event, bool, error) +} + +type NullEventSource struct{} + +func (n NullEventSource) GetLatestEventID(_ context.Context) (string, error) { + return "", nil +} + +func (n NullEventSource) GetEvent(_ context.Context, _ string) ([]proton.Event, bool, error) { + return nil, false, nil +} diff --git a/internal/services/userevents/eventid_store.go b/internal/services/userevents/eventid_store.go new file mode 100644 index 00000000..ce4e3852 --- /dev/null +++ b/internal/services/userevents/eventid_store.go @@ -0,0 +1,75 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + "sync" + + "github.com/ProtonMail/proton-bridge/v3/internal/vault" +) + +// EventIDStore exposes behavior expected of a type which allows us to store and retrieve event Ids. +// Note: this may be accessed from multiple go-routines. +type EventIDStore interface { + // Load the last stored event, return "" for empty. + Load(ctx context.Context) (string, error) + // Store the new id. + Store(ctx context.Context, id string) error +} + +type InMemoryEventIDStore struct { + lock sync.Mutex + id string +} + +func NewInMemoryEventIDStore() *InMemoryEventIDStore { + return &InMemoryEventIDStore{} +} + +func (i *InMemoryEventIDStore) Load(_ context.Context) (string, error) { + i.lock.Lock() + defer i.lock.Unlock() + + return i.id, nil +} + +func (i *InMemoryEventIDStore) Store(_ context.Context, id string) error { + i.lock.Lock() + defer i.lock.Unlock() + + i.id = id + + return nil +} + +type VaultEventIDStore struct { + vault *vault.User +} + +func NewVaultEventIDStore(vault *VaultEventIDStore) *VaultEventIDStore { + return &VaultEventIDStore{vault: vault.vault} +} + +func (v VaultEventIDStore) Load(_ context.Context) (string, error) { + return v.vault.EventID(), nil +} + +func (v VaultEventIDStore) Store(_ context.Context, id string) error { + return v.vault.SetEventID(id) +} diff --git a/internal/services/userevents/mocks/mocks.go b/internal/services/userevents/mocks/mocks.go new file mode 100644 index 00000000..b0dd345d --- /dev/null +++ b/internal/services/userevents/mocks/mocks.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: EventSource,EventIDStore) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + proton "github.com/ProtonMail/go-proton-api" + gomock "github.com/golang/mock/gomock" +) + +// MockEventSource is a mock of EventSource interface. +type MockEventSource struct { + ctrl *gomock.Controller + recorder *MockEventSourceMockRecorder +} + +// MockEventSourceMockRecorder is the mock recorder for MockEventSource. +type MockEventSourceMockRecorder struct { + mock *MockEventSource +} + +// NewMockEventSource creates a new mock instance. +func NewMockEventSource(ctrl *gomock.Controller) *MockEventSource { + mock := &MockEventSource{ctrl: ctrl} + mock.recorder = &MockEventSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEventSource) EXPECT() *MockEventSourceMockRecorder { + return m.recorder +} + +// GetEvent mocks base method. +func (m *MockEventSource) GetEvent(arg0 context.Context, arg1 string) ([]proton.Event, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEvent", arg0, arg1) + ret0, _ := ret[0].([]proton.Event) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetEvent indicates an expected call of GetEvent. +func (mr *MockEventSourceMockRecorder) GetEvent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvent", reflect.TypeOf((*MockEventSource)(nil).GetEvent), arg0, arg1) +} + +// GetLatestEventID mocks base method. +func (m *MockEventSource) GetLatestEventID(arg0 context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLatestEventID", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLatestEventID indicates an expected call of GetLatestEventID. +func (mr *MockEventSourceMockRecorder) GetLatestEventID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestEventID", reflect.TypeOf((*MockEventSource)(nil).GetLatestEventID), arg0) +} + +// MockEventIDStore is a mock of EventIDStore interface. +type MockEventIDStore struct { + ctrl *gomock.Controller + recorder *MockEventIDStoreMockRecorder +} + +// MockEventIDStoreMockRecorder is the mock recorder for MockEventIDStore. +type MockEventIDStoreMockRecorder struct { + mock *MockEventIDStore +} + +// NewMockEventIDStore creates a new mock instance. +func NewMockEventIDStore(ctrl *gomock.Controller) *MockEventIDStore { + mock := &MockEventIDStore{ctrl: ctrl} + mock.recorder = &MockEventIDStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEventIDStore) EXPECT() *MockEventIDStoreMockRecorder { + return m.recorder +} + +// Load mocks base method. +func (m *MockEventIDStore) Load(arg0 context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Load", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Load indicates an expected call of Load. +func (mr *MockEventIDStoreMockRecorder) Load(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockEventIDStore)(nil).Load), arg0) +} + +// Store mocks base method. +func (m *MockEventIDStore) Store(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Store", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Store indicates an expected call of Store. +func (mr *MockEventIDStoreMockRecorder) Store(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockEventIDStore)(nil).Store), arg0, arg1) +} diff --git a/internal/services/userevents/mocks_test.go b/internal/services/userevents/mocks_test.go new file mode 100644 index 00000000..dbf7a736 --- /dev/null +++ b/internal/services/userevents/mocks_test.go @@ -0,0 +1,463 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber) + +// Package userevents is a generated GoMock package. +package userevents + +import ( + context "context" + reflect "reflect" + + proton "github.com/ProtonMail/go-proton-api" + gomock "github.com/golang/mock/gomock" +) + +// MockMessageSubscriber is a mock of MessageSubscriber interface. +type MockMessageSubscriber struct { + ctrl *gomock.Controller + recorder *MockMessageSubscriberMockRecorder +} + +// MockMessageSubscriberMockRecorder is the mock recorder for MockMessageSubscriber. +type MockMessageSubscriberMockRecorder struct { + mock *MockMessageSubscriber +} + +// NewMockMessageSubscriber creates a new mock instance. +func NewMockMessageSubscriber(ctrl *gomock.Controller) *MockMessageSubscriber { + mock := &MockMessageSubscriber{ctrl: ctrl} + mock.recorder = &MockMessageSubscriberMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMessageSubscriber) EXPECT() *MockMessageSubscriberMockRecorder { + return m.recorder +} + +// cancel mocks base method. +func (m *MockMessageSubscriber) cancel() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "cancel") +} + +// cancel indicates an expected call of cancel. +func (mr *MockMessageSubscriberMockRecorder) cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockMessageSubscriber)(nil).cancel)) +} + +// close mocks base method. +func (m *MockMessageSubscriber) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockMessageSubscriberMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockMessageSubscriber)(nil).close)) +} + +// handle mocks base method. +func (m *MockMessageSubscriber) handle(arg0 context.Context, arg1 []proton.MessageEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// handle indicates an expected call of handle. +func (mr *MockMessageSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockMessageSubscriber)(nil).handle), arg0, arg1) +} + +// name mocks base method. +func (m *MockMessageSubscriber) name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "name") + ret0, _ := ret[0].(string) + return ret0 +} + +// name indicates an expected call of name. +func (mr *MockMessageSubscriberMockRecorder) name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockMessageSubscriber)(nil).name)) +} + +// MockLabelSubscriber is a mock of LabelSubscriber interface. +type MockLabelSubscriber struct { + ctrl *gomock.Controller + recorder *MockLabelSubscriberMockRecorder +} + +// MockLabelSubscriberMockRecorder is the mock recorder for MockLabelSubscriber. +type MockLabelSubscriberMockRecorder struct { + mock *MockLabelSubscriber +} + +// NewMockLabelSubscriber creates a new mock instance. +func NewMockLabelSubscriber(ctrl *gomock.Controller) *MockLabelSubscriber { + mock := &MockLabelSubscriber{ctrl: ctrl} + mock.recorder = &MockLabelSubscriberMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLabelSubscriber) EXPECT() *MockLabelSubscriberMockRecorder { + return m.recorder +} + +// cancel mocks base method. +func (m *MockLabelSubscriber) cancel() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "cancel") +} + +// cancel indicates an expected call of cancel. +func (mr *MockLabelSubscriberMockRecorder) cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockLabelSubscriber)(nil).cancel)) +} + +// close mocks base method. +func (m *MockLabelSubscriber) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockLabelSubscriberMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockLabelSubscriber)(nil).close)) +} + +// handle mocks base method. +func (m *MockLabelSubscriber) handle(arg0 context.Context, arg1 []proton.LabelEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// handle indicates an expected call of handle. +func (mr *MockLabelSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockLabelSubscriber)(nil).handle), arg0, arg1) +} + +// name mocks base method. +func (m *MockLabelSubscriber) name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "name") + ret0, _ := ret[0].(string) + return ret0 +} + +// name indicates an expected call of name. +func (mr *MockLabelSubscriberMockRecorder) name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockLabelSubscriber)(nil).name)) +} + +// MockAddressSubscriber is a mock of AddressSubscriber interface. +type MockAddressSubscriber struct { + ctrl *gomock.Controller + recorder *MockAddressSubscriberMockRecorder +} + +// MockAddressSubscriberMockRecorder is the mock recorder for MockAddressSubscriber. +type MockAddressSubscriberMockRecorder struct { + mock *MockAddressSubscriber +} + +// NewMockAddressSubscriber creates a new mock instance. +func NewMockAddressSubscriber(ctrl *gomock.Controller) *MockAddressSubscriber { + mock := &MockAddressSubscriber{ctrl: ctrl} + mock.recorder = &MockAddressSubscriberMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAddressSubscriber) EXPECT() *MockAddressSubscriberMockRecorder { + return m.recorder +} + +// cancel mocks base method. +func (m *MockAddressSubscriber) cancel() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "cancel") +} + +// cancel indicates an expected call of cancel. +func (mr *MockAddressSubscriberMockRecorder) cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockAddressSubscriber)(nil).cancel)) +} + +// close mocks base method. +func (m *MockAddressSubscriber) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockAddressSubscriberMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockAddressSubscriber)(nil).close)) +} + +// handle mocks base method. +func (m *MockAddressSubscriber) handle(arg0 context.Context, arg1 []proton.AddressEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// handle indicates an expected call of handle. +func (mr *MockAddressSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockAddressSubscriber)(nil).handle), arg0, arg1) +} + +// name mocks base method. +func (m *MockAddressSubscriber) name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "name") + ret0, _ := ret[0].(string) + return ret0 +} + +// name indicates an expected call of name. +func (mr *MockAddressSubscriberMockRecorder) name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockAddressSubscriber)(nil).name)) +} + +// MockRefreshSubscriber is a mock of RefreshSubscriber interface. +type MockRefreshSubscriber struct { + ctrl *gomock.Controller + recorder *MockRefreshSubscriberMockRecorder +} + +// MockRefreshSubscriberMockRecorder is the mock recorder for MockRefreshSubscriber. +type MockRefreshSubscriberMockRecorder struct { + mock *MockRefreshSubscriber +} + +// NewMockRefreshSubscriber creates a new mock instance. +func NewMockRefreshSubscriber(ctrl *gomock.Controller) *MockRefreshSubscriber { + mock := &MockRefreshSubscriber{ctrl: ctrl} + mock.recorder = &MockRefreshSubscriberMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRefreshSubscriber) EXPECT() *MockRefreshSubscriberMockRecorder { + return m.recorder +} + +// cancel mocks base method. +func (m *MockRefreshSubscriber) cancel() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "cancel") +} + +// cancel indicates an expected call of cancel. +func (mr *MockRefreshSubscriberMockRecorder) cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockRefreshSubscriber)(nil).cancel)) +} + +// close mocks base method. +func (m *MockRefreshSubscriber) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockRefreshSubscriberMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockRefreshSubscriber)(nil).close)) +} + +// handle mocks base method. +func (m *MockRefreshSubscriber) handle(arg0 context.Context, arg1 proton.RefreshFlag) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// handle indicates an expected call of handle. +func (mr *MockRefreshSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockRefreshSubscriber)(nil).handle), arg0, arg1) +} + +// name mocks base method. +func (m *MockRefreshSubscriber) name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "name") + ret0, _ := ret[0].(string) + return ret0 +} + +// name indicates an expected call of name. +func (mr *MockRefreshSubscriberMockRecorder) name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockRefreshSubscriber)(nil).name)) +} + +// MockUserSubscriber is a mock of UserSubscriber interface. +type MockUserSubscriber struct { + ctrl *gomock.Controller + recorder *MockUserSubscriberMockRecorder +} + +// MockUserSubscriberMockRecorder is the mock recorder for MockUserSubscriber. +type MockUserSubscriberMockRecorder struct { + mock *MockUserSubscriber +} + +// NewMockUserSubscriber creates a new mock instance. +func NewMockUserSubscriber(ctrl *gomock.Controller) *MockUserSubscriber { + mock := &MockUserSubscriber{ctrl: ctrl} + mock.recorder = &MockUserSubscriberMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUserSubscriber) EXPECT() *MockUserSubscriberMockRecorder { + return m.recorder +} + +// cancel mocks base method. +func (m *MockUserSubscriber) cancel() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "cancel") +} + +// cancel indicates an expected call of cancel. +func (mr *MockUserSubscriberMockRecorder) cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserSubscriber)(nil).cancel)) +} + +// close mocks base method. +func (m *MockUserSubscriber) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockUserSubscriberMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserSubscriber)(nil).close)) +} + +// handle mocks base method. +func (m *MockUserSubscriber) handle(arg0 context.Context, arg1 proton.User) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// handle indicates an expected call of handle. +func (mr *MockUserSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserSubscriber)(nil).handle), arg0, arg1) +} + +// name mocks base method. +func (m *MockUserSubscriber) name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "name") + ret0, _ := ret[0].(string) + return ret0 +} + +// name indicates an expected call of name. +func (mr *MockUserSubscriberMockRecorder) name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserSubscriber)(nil).name)) +} + +// MockUserUsedSpaceSubscriber is a mock of UserUsedSpaceSubscriber interface. +type MockUserUsedSpaceSubscriber struct { + ctrl *gomock.Controller + recorder *MockUserUsedSpaceSubscriberMockRecorder +} + +// MockUserUsedSpaceSubscriberMockRecorder is the mock recorder for MockUserUsedSpaceSubscriber. +type MockUserUsedSpaceSubscriberMockRecorder struct { + mock *MockUserUsedSpaceSubscriber +} + +// NewMockUserUsedSpaceSubscriber creates a new mock instance. +func NewMockUserUsedSpaceSubscriber(ctrl *gomock.Controller) *MockUserUsedSpaceSubscriber { + mock := &MockUserUsedSpaceSubscriber{ctrl: ctrl} + mock.recorder = &MockUserUsedSpaceSubscriberMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUserUsedSpaceSubscriber) EXPECT() *MockUserUsedSpaceSubscriberMockRecorder { + return m.recorder +} + +// cancel mocks base method. +func (m *MockUserUsedSpaceSubscriber) cancel() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "cancel") +} + +// cancel indicates an expected call of cancel. +func (mr *MockUserUsedSpaceSubscriberMockRecorder) cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).cancel)) +} + +// close mocks base method. +func (m *MockUserUsedSpaceSubscriber) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockUserUsedSpaceSubscriberMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).close)) +} + +// handle mocks base method. +func (m *MockUserUsedSpaceSubscriber) handle(arg0 context.Context, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handle", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// handle indicates an expected call of handle. +func (mr *MockUserUsedSpaceSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).handle), arg0, arg1) +} + +// name mocks base method. +func (m *MockUserUsedSpaceSubscriber) name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "name") + ret0, _ := ret[0].(string) + return ret0 +} + +// name indicates an expected call of name. +func (mr *MockUserUsedSpaceSubscriberMockRecorder) name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).name)) +} diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go new file mode 100644 index 00000000..560f3575 --- /dev/null +++ b/internal/services/userevents/service.go @@ -0,0 +1,486 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal" + "github.com/ProtonMail/proton-bridge/v3/internal/events" + "github.com/ProtonMail/proton-bridge/v3/pkg/cpc" + "github.com/sirupsen/logrus" +) + +// Service polls from the given event source and ensures that all the respective subscribers get notified +// before proceeding to the next event. The events are published in the following order: +// * Refresh +// * User +// * Address +// * Label +// * Message +// * UserUsedSpace +// By default this service starts paused, you need to call `Service.Resume` at least one time to begin event polling. +type Service struct { + userID string + cpc *cpc.CPC + eventSource EventSource + eventIDStore EventIDStore + log *logrus.Entry + eventPublisher events.EventPublisher + timer *time.Ticker + eventTimeout time.Duration + paused bool + panicHandler async.PanicHandler + + userSubscriberList userSubscriberList + addressSubscribers addressSubscriberList + labelSubscribers labelSubscriberList + messageSubscribers messageSubscriberList + refreshSubscribers refreshSubscriberList + userUsedSpaceSubscriber userUsedSpaceSubscriberList + + pendingSubscriptionsLock sync.Mutex + pendingSubscriptionsAdd []Subscription + pendingSubscriptionsRemove []Subscription +} + +func NewService( + userID string, + eventSource EventSource, + store EventIDStore, + eventPublisher events.EventPublisher, + pollPeriod time.Duration, + eventTimeout time.Duration, + panicHandler async.PanicHandler, +) *Service { + return &Service{ + userID: userID, + cpc: cpc.NewCPC(), + eventSource: eventSource, + eventIDStore: store, + log: logrus.WithFields(logrus.Fields{ + "service": "user-events", + "user": userID, + }), + eventPublisher: eventPublisher, + timer: time.NewTicker(pollPeriod), + paused: true, + eventTimeout: eventTimeout, + panicHandler: panicHandler, + } +} + +type Subscription struct { + User UserSubscriber + Refresh RefreshSubscriber + Address AddressSubscriber + Labels LabelSubscriber + Messages MessageSubscriber + UserUsedSpace UserUsedSpaceSubscriber +} + +// cancel subscription subscribers if applicable, see `subscriber.cancel` for more information. +func (s Subscription) cancel() { + if s.User != nil { + s.User.cancel() + } + if s.Refresh != nil { + s.Refresh.cancel() + } + if s.Address != nil { + s.Address.cancel() + } + if s.Labels != nil { + s.Labels.cancel() + } + if s.Messages != nil { + s.Messages.cancel() + } + if s.UserUsedSpace != nil { + s.UserUsedSpace.cancel() + } +} + +// Subscribe adds new subscribers to the service. +// This method can safely be called during event handling. +func (s *Service) Subscribe(subscription Subscription) { + s.pendingSubscriptionsLock.Lock() + defer s.pendingSubscriptionsLock.Unlock() + + s.pendingSubscriptionsAdd = append(s.pendingSubscriptionsAdd, subscription) +} + +// Unsubscribe removes subscribers from the service. +// This method can safely be called during event handling. +func (s *Service) Unsubscribe(subscription Subscription) { + subscription.cancel() + + s.pendingSubscriptionsLock.Lock() + defer s.pendingSubscriptionsLock.Unlock() + + s.pendingSubscriptionsRemove = append(s.pendingSubscriptionsRemove, subscription) +} + +// Pause pauses the event polling. +// DO NOT CALL THIS DURING EVENT HANDLING. +func (s *Service) Pause(ctx context.Context) error { + _, err := s.cpc.Send(ctx, &pauseRequest{}) + + return err +} + +// Resume resumes the event polling. +// DO NOT CALL THIS DURING EVENT HANDLING. +func (s *Service) Resume(ctx context.Context) error { + _, err := s.cpc.Send(ctx, &resumeRequest{}) + + return err +} + +// IsPaused return true if the service is paused +// DO NOT CALL THIS DURING EVENT HANDLING. +func (s *Service) IsPaused(ctx context.Context) (bool, error) { + return cpc.SendTyped[bool](ctx, s.cpc, &isPausedRequest{}) +} + +func (s *Service) Start(ctx context.Context, group *async.Group) error { + lastEventID, err := s.eventIDStore.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load last event id: %w", err) + } + + if lastEventID == "" { + s.log.Debugf("No event ID present in storage, retrieving latest") + eventID, err := s.eventSource.GetLatestEventID(ctx) + if err != nil { + return fmt.Errorf("failed to get latest event id: %w", err) + } + + if err := s.eventIDStore.Store(ctx, eventID); err != nil { + return fmt.Errorf("failed to store event in event id store: %v", err) + } + + lastEventID = eventID + } + + group.Once(func(ctx context.Context) { + s.run(ctx, lastEventID) + }) + + return nil +} + +func (s *Service) run(ctx context.Context, lastEventID string) { + s.log.Debugf("Starting service Last EventID=%v", lastEventID) + defer s.close() + defer s.log.Debug("Exiting service") + + for { + select { + case <-ctx.Done(): + return + case req, ok := <-s.cpc.ReceiveCh(): + if !ok { + return + } + + s.handleRequest(ctx, req) + continue + case <-s.timer.C: + if s.paused { + continue + } + } + + // Apply any pending subscription changes. + func() { + s.pendingSubscriptionsLock.Lock() + defer s.pendingSubscriptionsLock.Unlock() + + for _, subscription := range s.pendingSubscriptionsRemove { + s.removeSubscription(subscription) + } + + for _, subscription := range s.pendingSubscriptionsAdd { + s.addSubscription(subscription) + } + + s.pendingSubscriptionsRemove = nil + s.pendingSubscriptionsAdd = nil + }() + + newEvents, _, err := s.eventSource.GetEvent(ctx, lastEventID) + if err != nil { + s.log.WithError(err).Errorf("Failed to get event (caused by %T)", internal.ErrCause(err)) + continue + } + + // If the event ID hasn't changed, there are no new events. + if newEvents[len(newEvents)-1].EventID == lastEventID { + s.log.Debugf("No new API Events") + continue + } + + if event, eventErr := func() (proton.Event, error) { + for _, event := range newEvents { + if err := s.handleEvent(ctx, lastEventID, event); err != nil { + return event, err + } + } + + return proton.Event{}, nil + }(); eventErr != nil { + subscriberName, err := s.handleEventError(ctx, lastEventID, event, eventErr) + if subscriberName == "" { + subscriberName = "?" + } + s.log.WithField("subscriber", subscriberName).WithError(err).Errorf("Failed to apply event") + continue + } + + newEventID := newEvents[len(newEvents)-1].EventID + if err := s.eventIDStore.Store(ctx, newEventID); err != nil { + s.log.WithError(err).Errorf("Failed to store new event ID: %v", err) + s.onBadEvent(ctx, events.UserBadEvent{ + Error: fmt.Errorf("failed to store new event ID: %w", err), + UserID: s.userID, + }) + continue + } + + lastEventID = newEventID + } +} + +func (s *Service) handleEvent(ctx context.Context, lastEventID string, event proton.Event) error { + s.log.WithFields(logrus.Fields{ + "old": lastEventID, + "new": event, + }).Info("Received new API event") + + if event.Refresh&proton.RefreshMail != 0 { + s.log.Info("Handling refresh event") + if err := s.refreshSubscribers.Publish(ctx, event.Refresh, s.eventTimeout); err != nil { + return fmt.Errorf("failed to apply refresh event: %w", err) + } + + return nil + } + + // Start with user events. + if event.User != nil { + if err := s.userSubscriberList.PublishParallel(ctx, *event.User, s.panicHandler, s.eventTimeout); err != nil { + return fmt.Errorf("failed to apply user event: %w", err) + } + } + + // Next Address events + if err := s.addressSubscribers.PublishParallel(ctx, event.Addresses, s.panicHandler, s.eventTimeout); err != nil { + return fmt.Errorf("failed to apply address events: %w", err) + } + + // Next label events + if err := s.labelSubscribers.PublishParallel(ctx, event.Labels, s.panicHandler, s.eventTimeout); err != nil { + return fmt.Errorf("failed to apply label events: %w", err) + } + + // Next message events + if err := s.messageSubscribers.PublishParallel(ctx, event.Messages, s.panicHandler, s.eventTimeout); err != nil { + return fmt.Errorf("failed to apply message events: %w", err) + } + + // Finally user used space events + if event.UsedSpace != nil { + if err := s.userUsedSpaceSubscriber.PublishParallel(ctx, *event.UsedSpace, s.panicHandler, s.eventTimeout); err != nil { + return fmt.Errorf("failed to apply message events: %w", err) + } + } + + return nil +} + +func unpackPublisherError(err error) (string, error) { + var addressErr *addressPublishError + var labelErr *labelPublishError + var messageErr *messagePublishError + var refreshErr *refreshPublishError + var userErr *userPublishError + var usedSpaceErr *userUsedEventPublishError + + switch { + case errors.As(err, &userErr): + return userErr.subscriber.name(), userErr.error + case errors.As(err, &addressErr): + return addressErr.subscriber.name(), addressErr.error + case errors.As(err, &labelErr): + return labelErr.subscriber.name(), labelErr.error + case errors.As(err, &messageErr): + return messageErr.subscriber.name(), messageErr.error + case errors.As(err, &refreshErr): + return refreshErr.subscriber.name(), refreshErr.error + case errors.As(err, &usedSpaceErr): + return usedSpaceErr.subscriber.name(), usedSpaceErr.error + default: + return "", err + } +} + +func (s *Service) handleEventError(ctx context.Context, lastEventID string, event proton.Event, err error) (string, error) { + // Unpack the error so we can proceed to handle the real issue. + subscriberName, err := unpackPublisherError(err) + + // If the error is a context cancellation, return error to retry later. + if errors.Is(err, context.Canceled) { + return subscriberName, fmt.Errorf("failed to handle event due to context cancellation: %w", err) + } + + // If the error is a network error, return error to retry later. + if netErr := new(proton.NetError); errors.As(err, &netErr) { + return subscriberName, fmt.Errorf("failed to handle event due to network issue: %w", err) + } + + // Catch all for uncategorized net errors that may slip through. + if netErr := new(net.OpError); errors.As(err, &netErr) { + return subscriberName, fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err) + } + + // In case a json decode error slips through. + if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) { + s.eventPublisher.PublishEvent(ctx, events.UncategorizedEventError{ + UserID: s.userID, + Error: err, + }) + + return subscriberName, fmt.Errorf("failed to handle event due to JSON issue: %w", err) + } + + // If the error is an unexpected EOF, return error to retry later. + if errors.Is(err, io.ErrUnexpectedEOF) { + return subscriberName, fmt.Errorf("failed to handle event due to EOF: %w", err) + } + + // If the error is a server-side issue, return error to retry later. + if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 { + return subscriberName, fmt.Errorf("failed to handle event due to server error: %w", err) + } + + // Otherwise, the error is a client-side issue; notify bridge to handle it. + s.log.WithField("event", event).Warn("Failed to handle API event") + + s.onBadEvent(ctx, events.UserBadEvent{ + UserID: s.userID, + OldEventID: lastEventID, + NewEventID: event.EventID, + EventInfo: event.String(), + Error: err, + }) + + return subscriberName, fmt.Errorf("failed to handle event due to client error: %w", err) +} + +func (s *Service) onBadEvent(ctx context.Context, event events.UserBadEvent) { + s.paused = true + s.eventPublisher.PublishEvent(ctx, event) +} + +func (s *Service) handleRequest(ctx context.Context, request *cpc.Request) { + switch request.Value().(type) { + case *pauseRequest: + s.paused = true + request.Reply(ctx, nil, nil) + case *resumeRequest: + s.paused = false + request.Reply(ctx, nil, nil) + case *isPausedRequest: + request.Reply(ctx, s.paused, nil) + default: + s.log.Errorf("Unknown request") + } +} + +func (s *Service) addSubscription(subscription Subscription) { + if subscription.User != nil { + s.userSubscriberList.Add(subscription.User) + } + + if subscription.Refresh != nil { + s.refreshSubscribers.Add(subscription.Refresh) + } + + if subscription.Address != nil { + s.addressSubscribers.Add(subscription.Address) + } + + if subscription.Labels != nil { + s.labelSubscribers.Add(subscription.Labels) + } + + if subscription.Messages != nil { + s.messageSubscribers.Add(subscription.Messages) + } + + if subscription.UserUsedSpace != nil { + s.userUsedSpaceSubscriber.Add(subscription.UserUsedSpace) + } +} + +func (s *Service) removeSubscription(subscription Subscription) { + if subscription.User != nil { + s.userSubscriberList.Remove(subscription.User) + } + + if subscription.Refresh != nil { + s.refreshSubscribers.Remove(subscription.Refresh) + } + + if subscription.Address != nil { + s.addressSubscribers.Remove(subscription.Address) + } + + if subscription.Labels != nil { + s.labelSubscribers.Remove(subscription.Labels) + } + + if subscription.Messages != nil { + s.messageSubscribers.Remove(subscription.Messages) + } + + if subscription.UserUsedSpace != nil { + s.userUsedSpaceSubscriber.Remove(subscription.UserUsedSpace) + } +} + +func (s *Service) close() { + s.cpc.Close() + s.timer.Stop() +} + +type pauseRequest struct{} + +type resumeRequest struct{} + +type isPausedRequest struct{} diff --git a/internal/services/userevents/service_handle_event_error_test.go b/internal/services/userevents/service_handle_event_error_test.go new file mode 100644 index 00000000..0ea206f4 --- /dev/null +++ b/internal/services/userevents/service_handle_event_error_test.go @@ -0,0 +1,157 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/events" + "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +func TestServiceHandleEventError_SubscriberEventUnwrapping(t *testing.T) { + mockCtrl := gomock.NewController(t) + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + + lastEventID := "PrevEvent" + event := proton.Event{EventID: "MyEvent"} + subscriber := &noOpSubscriber[proton.AddressEvent]{} + + err := &addressPublishError{ + subscriber: subscriber, + error: &proton.NetError{}, + } + + subscriberName, unpackedErr := service.handleEventError(context.Background(), lastEventID, event, err) + require.Equal(t, subscriber.name(), subscriberName) + protonNetErr := new(proton.NetError) + require.True(t, errors.As(unpackedErr, &protonNetErr)) + + err2 := &proton.APIError{Status: 500} + subscriberName2, unpackedErr2 := service.handleEventError(context.Background(), lastEventID, event, err2) + require.Equal(t, "", subscriberName2) + protonAPIErr := new(proton.APIError) + require.True(t, errors.As(unpackedErr2, &protonAPIErr)) +} + +func TestServiceHandleEventError_BadEventPutsServiceOnPause(t *testing.T) { + mockCtrl := gomock.NewController(t) + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + service.paused = false + lastEventID := "PrevEvent" + event := proton.Event{EventID: "MyEvent"} + + err := &proton.APIError{} + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserBadEvent{ + UserID: service.userID, + OldEventID: lastEventID, + NewEventID: event.EventID, + EventInfo: event.String(), + Error: err, + })).Times(1) + + _, _ = service.handleEventError(context.Background(), lastEventID, event, err) + require.True(t, service.paused) +} + +func TestServiceHandleEventError_BadEventFromPublishTimeout(t *testing.T) { + mockCtrl := gomock.NewController(t) + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + lastEventID := "PrevEvent" + event := proton.Event{EventID: "MyEvent"} + err := ErrPublishTimeoutExceeded + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserBadEvent{ + UserID: service.userID, + OldEventID: lastEventID, + NewEventID: event.EventID, + EventInfo: event.String(), + Error: err, + })).Times(1) + + _, _ = service.handleEventError(context.Background(), lastEventID, event, err) +} + +func TestServiceHandleEventError_NoBadEventCheck(t *testing.T) { + mockCtrl := gomock.NewController(t) + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + lastEventID := "PrevEvent" + event := proton.Event{EventID: "MyEvent"} + _, _ = service.handleEventError(context.Background(), lastEventID, event, context.Canceled) + _, _ = service.handleEventError(context.Background(), lastEventID, event, &proton.NetError{}) + _, _ = service.handleEventError(context.Background(), lastEventID, event, &net.OpError{}) + _, _ = service.handleEventError(context.Background(), lastEventID, event, io.ErrUnexpectedEOF) + _, _ = service.handleEventError(context.Background(), lastEventID, event, &proton.APIError{Status: 500}) +} + +func TestServiceHandleEventError_JsonUnmarshalEventProducesUncategorizedErrorEvent(t *testing.T) { + mockCtrl := gomock.NewController(t) + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + lastEventID := "PrevEvent" + event := proton.Event{EventID: "MyEvent"} + err := &json.UnmarshalTypeError{} + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UncategorizedEventError{ + UserID: service.userID, + Error: err, + })).Times(1) + + _, _ = service.handleEventError(context.Background(), lastEventID, event, err) +} + +type noOpSubscriber[T any] struct{} + +func (n noOpSubscriber[T]) name() string { //nolint:unused + return "NoopSubscriber" +} + +func (n noOpSubscriber[T]) handle(_ context.Context, _ []T) error { //nolint:unused + return nil +} + +//nolint:unused +func (n noOpSubscriber[T]) close() {} // + +//nolint:unused +func (n noOpSubscriber[T]) cancel() {} diff --git a/internal/services/userevents/service_handle_event_test.go b/internal/services/userevents/service_handle_event_test.go new file mode 100644 index 00000000..d64d9da2 --- /dev/null +++ b/internal/services/userevents/service_handle_event_test.go @@ -0,0 +1,195 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) { + mockCtrl := gomock.NewController(t) + + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + refreshHandler := NewMockRefreshSubscriber(mockCtrl) + refreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(2).Return(nil) + + userHandler := NewMockUserSubscriber(mockCtrl) + userCall := userHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(nil) + + addressHandler := NewMockAddressSubscriber(mockCtrl) + addressCall := addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userCall).Times(1).Return(nil) + + labelHandler := NewMockLabelSubscriber(mockCtrl) + labelCall := labelHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(addressCall).Times(1).Return(nil) + + messageHandler := NewMockMessageSubscriber(mockCtrl) + messageCall := messageHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(labelCall).Times(1).Return(nil) + + userSpaceHandler := NewMockUserUsedSpaceSubscriber(mockCtrl) + userSpaceCall := userSpaceHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(messageCall).Times(1).Return(nil) + + secondRefreshHandler := NewMockRefreshSubscriber(mockCtrl) + secondRefreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userSpaceCall).Times(1).Return(nil) + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + + service.addSubscription(Subscription{ + User: userHandler, + Refresh: refreshHandler, + Address: addressHandler, + Labels: labelHandler, + Messages: messageHandler, + UserUsedSpace: userSpaceHandler, + }) + + // Simulate 1st refresh. + require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail})) + + // Simulate Regular event. + usedSpace := 20 + require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{ + User: new(proton.User), + Addresses: []proton.AddressEvent{}, + Labels: []proton.LabelEvent{ + {}, + }, + Messages: []proton.MessageEvent{ + {}, + }, + UsedSpace: &usedSpace, + })) + + service.addSubscription(Subscription{ + Refresh: secondRefreshHandler, + }) + + // Simulate 2nd refresh. + require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail})) +} + +func TestServiceHandleEvent_CheckEventFailureCausesError(t *testing.T) { + mockCtrl := gomock.NewController(t) + + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + addressHandler := NewMockAddressSubscriber(mockCtrl) + addressHandler.EXPECT().name().MinTimes(1).Return("Hello") + addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed")) + + messageHandler := NewMockMessageSubscriber(mockCtrl) + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + + service.addSubscription(Subscription{ + Address: addressHandler, + Messages: messageHandler, + }) + + err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}}) + require.Error(t, err) + publisherErr := new(addressPublishError) + require.True(t, errors.As(err, &publisherErr)) + require.Equal(t, publisherErr.subscriber, addressHandler) +} + +func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) { + mockCtrl := gomock.NewController(t) + + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + addressHandler := NewMockAddressSubscriber(mockCtrl) + addressHandler.EXPECT().name().MinTimes(1).Return("Hello") + addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed")) + + addressHandler2 := NewMockAddressSubscriber(mockCtrl) + addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil) + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{}) + + service.addSubscription(Subscription{ + Address: addressHandler, + }) + + service.addSubscription(Subscription{ + Address: addressHandler2, + }) + + err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}}) + require.Error(t, err) + publisherErr := new(addressPublishError) + require.True(t, errors.As(err, &publisherErr)) + require.Equal(t, publisherErr.subscriber, addressHandler) +} + +func TestServiceHandleEvent_SubscriberTimeout(t *testing.T) { + mockCtrl := gomock.NewController(t) + + eventPublisher := mocks.NewMockEventPublisher(mockCtrl) + eventIDStore := NewInMemoryEventIDStore() + + addressHandler := NewMockAddressSubscriber(mockCtrl) + addressHandler.EXPECT().name().AnyTimes().Return("Ok") + addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil) + + addressHandler2 := NewMockAddressSubscriber(mockCtrl) + addressHandler2.EXPECT().name().AnyTimes().Return("Timeout") + addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ []proton.AddressEvent) error { + timer := time.NewTimer(time.Second) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + }).MaxTimes(1) + + service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, 500*time.Millisecond, async.NoopPanicHandler{}) + + service.addSubscription(Subscription{ + Address: addressHandler, + }) + + service.addSubscription(Subscription{ + Address: addressHandler2, + }) + + // Simulate 1st refresh. + err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}}) + require.Error(t, err) + if publisherErr := new(addressPublishError); errors.As(err, &publisherErr) { + require.Equal(t, publisherErr.subscriber, addressHandler) + require.True(t, errors.Is(publisherErr.error, ErrPublishTimeoutExceeded)) + } else { + require.True(t, errors.Is(err, ErrPublishTimeoutExceeded)) + } +} diff --git a/internal/services/userevents/service_test.go b/internal/services/userevents/service_test.go new file mode 100644 index 00000000..fc4afc25 --- /dev/null +++ b/internal/services/userevents/service_test.go @@ -0,0 +1,269 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . +package userevents + +import ( + "context" + "fmt" + "io" + "testing" + "time" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/events" + mocks2 "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks" + "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +func TestService_EventIDLoadStore(t *testing.T) { + // Simulate the flow of data when we start without any event id in the event id store: + // * Load event id from store + // * Get latest event id since + // * Store that latest event id + // * Start event poll loop + // * Get new event id, store it in vault + // * Try to poll new event it, but context is cancelled + group := async.NewGroup(context.Background(), &async.NoopPanicHandler{}) + mockCtrl := gomock.NewController(t) + eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) + eventIDStore := mocks.NewMockEventIDStore(mockCtrl) + eventSource := mocks.NewMockEventSource(mockCtrl) + + firstEventID := "EVENT01" + secondEventID := "EVENT02" + secondEvent := []proton.Event{{ + EventID: secondEventID, + }} + + // Event id store expectations. + eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return("", nil) + eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(firstEventID)).Times(1).Return(nil) + eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error { + // Force exit, we have finished executing what we expected. + group.Cancel() + return nil + }) + + // Event Source expectations. + eventSource.EXPECT().GetLatestEventID(gomock.Any()).Times(1).Return(firstEventID, nil) + eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil) + + service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{}) + require.NoError(t, service.Start(context.Background(), group)) + require.NoError(t, service.Resume(context.Background())) + group.WaitToFinish() +} + +func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) { + group := async.NewGroup(context.Background(), &async.NoopPanicHandler{}) + mockCtrl := gomock.NewController(t) + eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) + eventIDStore := mocks.NewMockEventIDStore(mockCtrl) + eventSource := mocks.NewMockEventSource(mockCtrl) + subscriber := NewMockMessageSubscriber(mockCtrl) + + firstEventID := "EVENT01" + secondEventID := "EVENT02" + messageEvents := []proton.MessageEvent{ + { + EventItem: proton.EventItem{ID: "Message"}, + }, + } + secondEvent := []proton.Event{{ + EventID: secondEventID, + Messages: messageEvents, + }} + + // Event id store expectations. + eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil) + eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error { + // Force exit, we have finished executing what we expected. + group.Cancel() + return nil + }) + + // Event Source expectations. + eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil) + + // Subscriber expectations. + subscriber.EXPECT().name().AnyTimes().Return("Foo") + { + firstCall := subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(io.ErrUnexpectedEOF) + subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).After(firstCall).Times(1).Return(nil) + } + + service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{}) + service.Subscribe(Subscription{Messages: subscriber}) + + require.NoError(t, service.Start(context.Background(), group)) + require.NoError(t, service.Resume(context.Background())) + group.WaitToFinish() +} + +func TestService_OnBadEventServiceIsPaused(t *testing.T) { + group := async.NewGroup(context.Background(), &async.NoopPanicHandler{}) + mockCtrl := gomock.NewController(t) + eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) + eventIDStore := mocks.NewMockEventIDStore(mockCtrl) + eventSource := mocks.NewMockEventSource(mockCtrl) + subscriber := NewMockMessageSubscriber(mockCtrl) + + firstEventID := "EVENT01" + secondEventID := "EVENT02" + messageEvents := []proton.MessageEvent{ + { + EventItem: proton.EventItem{ID: "Message"}, + }, + } + secondEvent := []proton.Event{{ + EventID: secondEventID, + Messages: messageEvents, + }} + + // Event id store expectations. + eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil) + + // Event Source expectations. + eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil) + + // Subscriber expectations. + badEventErr := fmt.Errorf("I will cause bad event") + subscriber.EXPECT().name().AnyTimes().Return("Foo") + subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(badEventErr) + + service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{}) + + // Event publisher expectations. + eventPublisher.EXPECT().PublishEvent(gomock.Any(), events.UserBadEvent{ + UserID: "foo", + OldEventID: firstEventID, + NewEventID: secondEventID, + EventInfo: secondEvent[0].String(), + Error: badEventErr, + }).Do(func(_ context.Context, event events.Event) { + group.Once(func(_ context.Context) { + // Use background context to avoid having the request cancelled + paused, err := service.IsPaused(context.Background()) + require.NoError(t, err) + require.True(t, paused) + group.Cancel() + }) + }) + + service.Subscribe(Subscription{Messages: subscriber}) + require.NoError(t, service.Start(context.Background(), group)) + require.NoError(t, service.Resume(context.Background())) + group.WaitToFinish() +} + +func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T) { + group := async.NewGroup(context.Background(), &async.NoopPanicHandler{}) + mockCtrl := gomock.NewController(t) + eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) + eventIDStore := mocks.NewMockEventIDStore(mockCtrl) + eventSource := mocks.NewMockEventSource(mockCtrl) + subscriber := NewMockMessageSubscriber(mockCtrl) + + firstEventID := "EVENT01" + secondEventID := "EVENT02" + messageEvents := []proton.MessageEvent{ + { + EventItem: proton.EventItem{ID: "Message"}, + }, + } + secondEvent := []proton.Event{{ + EventID: secondEventID, + Messages: messageEvents, + }} + + // Event id store expectations. + eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil) + eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error { + // Force exit, we have finished executing what we expected. + group.Cancel() + return nil + }) + + // Event Source expectations. + eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil) + + service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{}) + + // Subscriber expectations. + subscriber.EXPECT().name().AnyTimes().Return("Foo") + subscriber.EXPECT().cancel().Times(1) + subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).DoAndReturn(func(_ context.Context, _ []proton.MessageEvent) error { + service.Unsubscribe(Subscription{Messages: subscriber}) + return nil + }) + + service.Subscribe(Subscription{Messages: subscriber}) + require.NoError(t, service.Start(context.Background(), group)) + require.NoError(t, service.Resume(context.Background())) + group.WaitToFinish() +} + +func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T) { + group := async.NewGroup(context.Background(), &async.NoopPanicHandler{}) + mockCtrl := gomock.NewController(t) + eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) + eventIDStore := mocks.NewMockEventIDStore(mockCtrl) + eventSource := mocks.NewMockEventSource(mockCtrl) + subscriber := newChanneledSubscriber[proton.MessageEvent]("My subscriber") + + firstEventID := "EVENT01" + secondEventID := "EVENT02" + messageEvents := []proton.MessageEvent{ + { + EventItem: proton.EventItem{ID: "Message"}, + }, + } + secondEvent := []proton.Event{{ + EventID: secondEventID, + Messages: messageEvents, + }} + + // Event id store expectations. + eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil) + eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error { + // Force exit, we have finished executing what we expected. + group.Cancel() + return nil + }) + + // Event Source expectations. + eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil) + eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(secondEventID)).AnyTimes().Return(secondEvent, false, nil) + + service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{}) + + // start subscriber + group.Once(func(_ context.Context) { + defer service.Unsubscribe(Subscription{Messages: subscriber}) + + // Simulate the reception of an event, but it is never handled due to unexpected exit + <-time.NewTicker(500 * time.Millisecond).C + }) + + service.Subscribe(Subscription{Messages: subscriber}) + require.NoError(t, service.Start(context.Background(), group)) + require.NoError(t, service.Resume(context.Background())) + group.WaitToFinish() +} diff --git a/internal/services/userevents/subscriber.go b/internal/services/userevents/subscriber.go new file mode 100644 index 00000000..16e92f92 --- /dev/null +++ b/internal/services/userevents/subscriber.go @@ -0,0 +1,252 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + "errors" + "fmt" + "runtime" + "time" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" + "github.com/bradenaw/juniper/parallel" + "github.com/bradenaw/juniper/xslices" + "golang.org/x/exp/slices" +) + +func NewMessageSubscriber(name string) *ChanneledSubscriber[proton.MessageEvent] { + return newChanneledSubscriber[proton.MessageEvent](name) +} + +func NewAddressSubscriber(name string) *ChanneledSubscriber[proton.AddressEvent] { + return newChanneledSubscriber[proton.AddressEvent](name) +} + +func NewLabelSubscriber(name string) *ChanneledSubscriber[proton.LabelEvent] { + return newChanneledSubscriber[proton.LabelEvent](name) +} + +func NewRefreshSubscriber(name string) *ChanneledSubscriber[struct{}] { + return newChanneledSubscriber[struct{}](name) +} + +func NewUserSubscriber(name string) *ChanneledSubscriber[proton.User] { + return newChanneledSubscriber[proton.User](name) +} + +func NewUserUsedSpaceSubscriber(name string) *ChanneledSubscriber[int] { + return newChanneledSubscriber[int](name) +} + +type AddressSubscriber = subscriber[[]proton.AddressEvent] +type LabelSubscriber = subscriber[[]proton.LabelEvent] +type MessageSubscriber = subscriber[[]proton.MessageEvent] +type RefreshSubscriber = subscriber[proton.RefreshFlag] +type UserSubscriber = subscriber[proton.User] +type UserUsedSpaceSubscriber = subscriber[int] + +// Subscriber is the main entry point of interacting with user generated events. +type subscriber[T any] interface { + // Name returns the identifier for this subscriber + name() string + // Handle the event list. + handle(context.Context, T) error + // cancel is behavior extension for channel based subscribers so that they can ensure that + // if a subscriber unsubscribes, it doesn't cause pending events on the channel to time-out as there is no one to handle + // them. + cancel() + // close release all associated resources + close() +} + +type subscriberList[T any] struct { + subscribers []subscriber[T] +} + +type addressSubscriberList = subscriberList[[]proton.AddressEvent] +type labelSubscriberList = subscriberList[[]proton.LabelEvent] +type messageSubscriberList = subscriberList[[]proton.MessageEvent] +type refreshSubscriberList = subscriberList[proton.RefreshFlag] +type userSubscriberList = subscriberList[proton.User] +type userUsedSpaceSubscriberList = subscriberList[int] + +func (s *subscriberList[T]) Add(subscriber subscriber[T]) { + if !slices.Contains(s.subscribers, subscriber) { + s.subscribers = append(s.subscribers, subscriber) + } +} + +func (s *subscriberList[T]) Remove(subscriber subscriber[T]) { + index := slices.Index(s.subscribers, subscriber) + if index < 0 { + return + } + + s.subscribers[index].close() + s.subscribers = xslices.Remove(s.subscribers, index, 1) +} + +type publishError[T any] struct { + subscriber subscriber[T] + error error +} + +var ErrPublishTimeoutExceeded = errors.New("event publish timed out") + +type addressPublishError = publishError[[]proton.AddressEvent] +type labelPublishError = publishError[[]proton.LabelEvent] +type messagePublishError = publishError[[]proton.MessageEvent] +type refreshPublishError = publishError[proton.RefreshFlag] +type userPublishError = publishError[proton.User] +type userUsedEventPublishError = publishError[int] + +func (p publishError[T]) Error() string { + return fmt.Sprintf("Event publish failed on (%v): %v", p.subscriber.name(), p.error.Error()) +} + +func (s *subscriberList[T]) Publish(ctx context.Context, event T, timeout time.Duration) error { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + for _, subscriber := range s.subscribers { + if err := subscriber.handle(ctx, event); err != nil { + return &publishError[T]{ + subscriber: subscriber, + error: mapContextTimeoutError(err), + } + } + + if err := ctx.Err(); err != nil { + return &publishError[T]{ + subscriber: subscriber, + error: mapContextTimeoutError(err), + } + } + } + + return nil +} + +func mapContextTimeoutError(err error) error { + if errors.Is(err, context.DeadlineExceeded) { + return ErrPublishTimeoutExceeded + } + + return err +} + +func (s *subscriberList[T]) PublishParallel( + ctx context.Context, + event T, + panicHandler async.PanicHandler, + timeout time.Duration, +) error { + if len(s.subscribers) <= 1 { + return s.Publish(ctx, event, timeout) + } + + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + err := parallel.DoContext(ctx, runtime.NumCPU()/2, len(s.subscribers), func(ctx context.Context, index int) error { + defer async.HandlePanic(panicHandler) + if err := s.subscribers[index].handle(ctx, event); err != nil { + return &publishError[T]{ + subscriber: s.subscribers[index], + error: mapContextTimeoutError(err), + } + } + + return nil + }) + + return mapContextTimeoutError(err) +} + +type ChanneledSubscriber[T any] struct { + id string + sender chan *ChanneledSubscriberEvent[T] +} + +func newChanneledSubscriber[T any](name string) *ChanneledSubscriber[T] { + return &ChanneledSubscriber[T]{ + id: name, + sender: make(chan *ChanneledSubscriberEvent[T]), + } +} + +type ChanneledSubscriberEvent[T any] struct { + data []T + response chan error +} + +func (c ChanneledSubscriberEvent[T]) Consume(f func([]T) error) { + if err := f(c.data); err != nil { + c.response <- err + } + close(c.response) +} + +func (c *ChanneledSubscriber[T]) name() string { //nolint:unused + return c.id +} + +func (c *ChanneledSubscriber[T]) handle(ctx context.Context, event []T) error { //nolint:unused + data := &ChanneledSubscriberEvent[T]{ + data: event, + response: make(chan error), + } + // Send Event + select { + case <-ctx.Done(): + return fmt.Errorf("failed to send event: %w", ctx.Err()) + case c.sender <- data: + // + } + + // Wait on Reply + select { + case <-ctx.Done(): + return fmt.Errorf("failed to receive event reply: %w", ctx.Err()) + case reply := <-data.response: + return reply + } +} + +func (c *ChanneledSubscriber[T]) OnEventCh() <-chan *ChanneledSubscriberEvent[T] { + return c.sender +} + +func (c *ChanneledSubscriber[T]) close() { //nolint:unused + close(c.sender) +} + +func (c *ChanneledSubscriber[T]) cancel() { //nolint:unused + go func() { + for { + e, ok := <-c.sender + if !ok { + return + } + + e.Consume(func(_ []T) error { return nil }) + } + }() +} diff --git a/internal/services/userevents/subscriber_test.go b/internal/services/userevents/subscriber_test.go new file mode 100644 index 00000000..2bb0b806 --- /dev/null +++ b/internal/services/userevents/subscriber_test.go @@ -0,0 +1,105 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestChanneledSubscriber_CtxTimeoutDoesNotBlockFutureEvents(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(1) + + subscriber := newChanneledSubscriber[int]("test") + defer subscriber.close() + + go func() { + defer wg.Done() + + // Send one event, that succeeds. + require.NoError(t, subscriber.handle(context.Background(), []int{30})) + + // Add an impossible deadline that fails immediately to simulate on event taking too long to send. + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Microsecond)) + defer cancel() + + err := subscriber.handle(ctx, []int{20}) + require.Error(t, err) + require.True(t, errors.Is(err, context.DeadlineExceeded)) + }() + + // Receive first event. Notify success. + event, ok := <-subscriber.OnEventCh() + require.True(t, ok) + event.Consume(func(event []int) error { + require.Equal(t, []int{30}, event) + return nil + }) + wg.Wait() + + // Simulate reception of another event + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, subscriber.handle(context.Background(), []int{40})) + }() + + event, ok = <-subscriber.OnEventCh() + require.True(t, ok) + event.Consume(func(event []int) error { + require.Equal(t, []int{40}, event) + return nil + }) + + wg.Wait() +} + +func TestChanneledSubscriber_ErrorReported(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(1) + + subscriber := newChanneledSubscriber[int]("test") + defer subscriber.close() + reportedErr := fmt.Errorf("request failed") + + go func() { + defer wg.Done() + + // Send one event, that succeeds. + err := subscriber.handle(context.Background(), []int{30}) + require.Error(t, err) + require.Equal(t, reportedErr, err) + }() + + // Receive first event. Notify success. + event, ok := <-subscriber.OnEventCh() + require.True(t, ok) + event.Consume(func(event []int) error { + require.Equal(t, []int{30}, event) + return reportedErr + }) + + wg.Wait() +}