feat(GODT-2800): User Event Service

This patch adds the User Event Service which is meant to replace the
current event polling flow.

Each user interested in receiving events should register a new
subscriber using the `Service.Subscribe` function and then react on
the incoming events.

The current patch does not hook this up Bridge user as there are no
existing consumers, but it does provide extensive testing for the
expected behavior.
This commit is contained in:
Leander Beernaert
2023-07-19 12:48:09 +02:00
parent 110286b81c
commit 82efa16d65
13 changed files with 2224 additions and 1 deletions

View File

@ -281,6 +281,13 @@ mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/updater Downloader,Installer > internal/updater/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go
cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \
EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go
mockgen --package userevents github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \
MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber > tmp
mv tmp internal/services/userevents/mocks_test.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/events EventPublisher \
> internal/events/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog

View File

@ -17,7 +17,10 @@
package events
import "fmt"
import (
"context"
"fmt"
)
type Event interface {
fmt.Stringer
@ -28,3 +31,7 @@ type Event interface {
type eventBase struct{}
func (eventBase) _isEvent() {}
type EventPublisher interface {
PublishEvent(ctx context.Context, event Event)
}

View File

@ -0,0 +1,48 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v3/internal/events (interfaces: EventPublisher)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
reflect "reflect"
events "github.com/ProtonMail/proton-bridge/v3/internal/events"
gomock "github.com/golang/mock/gomock"
)
// MockEventPublisher is a mock of EventPublisher interface.
type MockEventPublisher struct {
ctrl *gomock.Controller
recorder *MockEventPublisherMockRecorder
}
// MockEventPublisherMockRecorder is the mock recorder for MockEventPublisher.
type MockEventPublisherMockRecorder struct {
mock *MockEventPublisher
}
// NewMockEventPublisher creates a new mock instance.
func NewMockEventPublisher(ctrl *gomock.Controller) *MockEventPublisher {
mock := &MockEventPublisher{ctrl: ctrl}
mock.recorder = &MockEventPublisherMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEventPublisher) EXPECT() *MockEventPublisherMockRecorder {
return m.recorder
}
// PublishEvent mocks base method.
func (m *MockEventPublisher) PublishEvent(arg0 context.Context, arg1 events.Event) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "PublishEvent", arg0, arg1)
}
// PublishEvent indicates an expected call of PublishEvent.
func (mr *MockEventPublisherMockRecorder) PublishEvent(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishEvent", reflect.TypeOf((*MockEventPublisher)(nil).PublishEvent), arg0, arg1)
}

View File

@ -0,0 +1,40 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"github.com/ProtonMail/go-proton-api"
)
// EventSource represents a type which produces proton events.
type EventSource interface {
GetLatestEventID(ctx context.Context) (string, error)
GetEvent(ctx context.Context, id string) ([]proton.Event, bool, error)
}
type NullEventSource struct{}
func (n NullEventSource) GetLatestEventID(_ context.Context) (string, error) {
return "", nil
}
func (n NullEventSource) GetEvent(_ context.Context, _ string) ([]proton.Event, bool, error) {
return nil, false, nil
}

View File

@ -0,0 +1,75 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"sync"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
)
// EventIDStore exposes behavior expected of a type which allows us to store and retrieve event Ids.
// Note: this may be accessed from multiple go-routines.
type EventIDStore interface {
// Load the last stored event, return "" for empty.
Load(ctx context.Context) (string, error)
// Store the new id.
Store(ctx context.Context, id string) error
}
type InMemoryEventIDStore struct {
lock sync.Mutex
id string
}
func NewInMemoryEventIDStore() *InMemoryEventIDStore {
return &InMemoryEventIDStore{}
}
func (i *InMemoryEventIDStore) Load(_ context.Context) (string, error) {
i.lock.Lock()
defer i.lock.Unlock()
return i.id, nil
}
func (i *InMemoryEventIDStore) Store(_ context.Context, id string) error {
i.lock.Lock()
defer i.lock.Unlock()
i.id = id
return nil
}
type VaultEventIDStore struct {
vault *vault.User
}
func NewVaultEventIDStore(vault *VaultEventIDStore) *VaultEventIDStore {
return &VaultEventIDStore{vault: vault.vault}
}
func (v VaultEventIDStore) Load(_ context.Context) (string, error) {
return v.vault.EventID(), nil
}
func (v VaultEventIDStore) Store(_ context.Context, id string) error {
return v.vault.SetEventID(id)
}

View File

@ -0,0 +1,119 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: EventSource,EventIDStore)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
reflect "reflect"
proton "github.com/ProtonMail/go-proton-api"
gomock "github.com/golang/mock/gomock"
)
// MockEventSource is a mock of EventSource interface.
type MockEventSource struct {
ctrl *gomock.Controller
recorder *MockEventSourceMockRecorder
}
// MockEventSourceMockRecorder is the mock recorder for MockEventSource.
type MockEventSourceMockRecorder struct {
mock *MockEventSource
}
// NewMockEventSource creates a new mock instance.
func NewMockEventSource(ctrl *gomock.Controller) *MockEventSource {
mock := &MockEventSource{ctrl: ctrl}
mock.recorder = &MockEventSourceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEventSource) EXPECT() *MockEventSourceMockRecorder {
return m.recorder
}
// GetEvent mocks base method.
func (m *MockEventSource) GetEvent(arg0 context.Context, arg1 string) ([]proton.Event, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEvent", arg0, arg1)
ret0, _ := ret[0].([]proton.Event)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetEvent indicates an expected call of GetEvent.
func (mr *MockEventSourceMockRecorder) GetEvent(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvent", reflect.TypeOf((*MockEventSource)(nil).GetEvent), arg0, arg1)
}
// GetLatestEventID mocks base method.
func (m *MockEventSource) GetLatestEventID(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetLatestEventID", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetLatestEventID indicates an expected call of GetLatestEventID.
func (mr *MockEventSourceMockRecorder) GetLatestEventID(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestEventID", reflect.TypeOf((*MockEventSource)(nil).GetLatestEventID), arg0)
}
// MockEventIDStore is a mock of EventIDStore interface.
type MockEventIDStore struct {
ctrl *gomock.Controller
recorder *MockEventIDStoreMockRecorder
}
// MockEventIDStoreMockRecorder is the mock recorder for MockEventIDStore.
type MockEventIDStoreMockRecorder struct {
mock *MockEventIDStore
}
// NewMockEventIDStore creates a new mock instance.
func NewMockEventIDStore(ctrl *gomock.Controller) *MockEventIDStore {
mock := &MockEventIDStore{ctrl: ctrl}
mock.recorder = &MockEventIDStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEventIDStore) EXPECT() *MockEventIDStoreMockRecorder {
return m.recorder
}
// Load mocks base method.
func (m *MockEventIDStore) Load(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Load", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Load indicates an expected call of Load.
func (mr *MockEventIDStoreMockRecorder) Load(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockEventIDStore)(nil).Load), arg0)
}
// Store mocks base method.
func (m *MockEventIDStore) Store(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Store", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Store indicates an expected call of Store.
func (mr *MockEventIDStoreMockRecorder) Store(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockEventIDStore)(nil).Store), arg0, arg1)
}

View File

@ -0,0 +1,463 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/userevents (interfaces: MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber)
// Package userevents is a generated GoMock package.
package userevents
import (
context "context"
reflect "reflect"
proton "github.com/ProtonMail/go-proton-api"
gomock "github.com/golang/mock/gomock"
)
// MockMessageSubscriber is a mock of MessageSubscriber interface.
type MockMessageSubscriber struct {
ctrl *gomock.Controller
recorder *MockMessageSubscriberMockRecorder
}
// MockMessageSubscriberMockRecorder is the mock recorder for MockMessageSubscriber.
type MockMessageSubscriberMockRecorder struct {
mock *MockMessageSubscriber
}
// NewMockMessageSubscriber creates a new mock instance.
func NewMockMessageSubscriber(ctrl *gomock.Controller) *MockMessageSubscriber {
mock := &MockMessageSubscriber{ctrl: ctrl}
mock.recorder = &MockMessageSubscriberMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMessageSubscriber) EXPECT() *MockMessageSubscriberMockRecorder {
return m.recorder
}
// cancel mocks base method.
func (m *MockMessageSubscriber) cancel() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "cancel")
}
// cancel indicates an expected call of cancel.
func (mr *MockMessageSubscriberMockRecorder) cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockMessageSubscriber)(nil).cancel))
}
// close mocks base method.
func (m *MockMessageSubscriber) close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "close")
}
// close indicates an expected call of close.
func (mr *MockMessageSubscriberMockRecorder) close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockMessageSubscriber)(nil).close))
}
// handle mocks base method.
func (m *MockMessageSubscriber) handle(arg0 context.Context, arg1 []proton.MessageEvent) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "handle", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// handle indicates an expected call of handle.
func (mr *MockMessageSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockMessageSubscriber)(nil).handle), arg0, arg1)
}
// name mocks base method.
func (m *MockMessageSubscriber) name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "name")
ret0, _ := ret[0].(string)
return ret0
}
// name indicates an expected call of name.
func (mr *MockMessageSubscriberMockRecorder) name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockMessageSubscriber)(nil).name))
}
// MockLabelSubscriber is a mock of LabelSubscriber interface.
type MockLabelSubscriber struct {
ctrl *gomock.Controller
recorder *MockLabelSubscriberMockRecorder
}
// MockLabelSubscriberMockRecorder is the mock recorder for MockLabelSubscriber.
type MockLabelSubscriberMockRecorder struct {
mock *MockLabelSubscriber
}
// NewMockLabelSubscriber creates a new mock instance.
func NewMockLabelSubscriber(ctrl *gomock.Controller) *MockLabelSubscriber {
mock := &MockLabelSubscriber{ctrl: ctrl}
mock.recorder = &MockLabelSubscriberMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLabelSubscriber) EXPECT() *MockLabelSubscriberMockRecorder {
return m.recorder
}
// cancel mocks base method.
func (m *MockLabelSubscriber) cancel() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "cancel")
}
// cancel indicates an expected call of cancel.
func (mr *MockLabelSubscriberMockRecorder) cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockLabelSubscriber)(nil).cancel))
}
// close mocks base method.
func (m *MockLabelSubscriber) close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "close")
}
// close indicates an expected call of close.
func (mr *MockLabelSubscriberMockRecorder) close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockLabelSubscriber)(nil).close))
}
// handle mocks base method.
func (m *MockLabelSubscriber) handle(arg0 context.Context, arg1 []proton.LabelEvent) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "handle", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// handle indicates an expected call of handle.
func (mr *MockLabelSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockLabelSubscriber)(nil).handle), arg0, arg1)
}
// name mocks base method.
func (m *MockLabelSubscriber) name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "name")
ret0, _ := ret[0].(string)
return ret0
}
// name indicates an expected call of name.
func (mr *MockLabelSubscriberMockRecorder) name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockLabelSubscriber)(nil).name))
}
// MockAddressSubscriber is a mock of AddressSubscriber interface.
type MockAddressSubscriber struct {
ctrl *gomock.Controller
recorder *MockAddressSubscriberMockRecorder
}
// MockAddressSubscriberMockRecorder is the mock recorder for MockAddressSubscriber.
type MockAddressSubscriberMockRecorder struct {
mock *MockAddressSubscriber
}
// NewMockAddressSubscriber creates a new mock instance.
func NewMockAddressSubscriber(ctrl *gomock.Controller) *MockAddressSubscriber {
mock := &MockAddressSubscriber{ctrl: ctrl}
mock.recorder = &MockAddressSubscriberMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAddressSubscriber) EXPECT() *MockAddressSubscriberMockRecorder {
return m.recorder
}
// cancel mocks base method.
func (m *MockAddressSubscriber) cancel() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "cancel")
}
// cancel indicates an expected call of cancel.
func (mr *MockAddressSubscriberMockRecorder) cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockAddressSubscriber)(nil).cancel))
}
// close mocks base method.
func (m *MockAddressSubscriber) close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "close")
}
// close indicates an expected call of close.
func (mr *MockAddressSubscriberMockRecorder) close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockAddressSubscriber)(nil).close))
}
// handle mocks base method.
func (m *MockAddressSubscriber) handle(arg0 context.Context, arg1 []proton.AddressEvent) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "handle", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// handle indicates an expected call of handle.
func (mr *MockAddressSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockAddressSubscriber)(nil).handle), arg0, arg1)
}
// name mocks base method.
func (m *MockAddressSubscriber) name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "name")
ret0, _ := ret[0].(string)
return ret0
}
// name indicates an expected call of name.
func (mr *MockAddressSubscriberMockRecorder) name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockAddressSubscriber)(nil).name))
}
// MockRefreshSubscriber is a mock of RefreshSubscriber interface.
type MockRefreshSubscriber struct {
ctrl *gomock.Controller
recorder *MockRefreshSubscriberMockRecorder
}
// MockRefreshSubscriberMockRecorder is the mock recorder for MockRefreshSubscriber.
type MockRefreshSubscriberMockRecorder struct {
mock *MockRefreshSubscriber
}
// NewMockRefreshSubscriber creates a new mock instance.
func NewMockRefreshSubscriber(ctrl *gomock.Controller) *MockRefreshSubscriber {
mock := &MockRefreshSubscriber{ctrl: ctrl}
mock.recorder = &MockRefreshSubscriberMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRefreshSubscriber) EXPECT() *MockRefreshSubscriberMockRecorder {
return m.recorder
}
// cancel mocks base method.
func (m *MockRefreshSubscriber) cancel() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "cancel")
}
// cancel indicates an expected call of cancel.
func (mr *MockRefreshSubscriberMockRecorder) cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockRefreshSubscriber)(nil).cancel))
}
// close mocks base method.
func (m *MockRefreshSubscriber) close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "close")
}
// close indicates an expected call of close.
func (mr *MockRefreshSubscriberMockRecorder) close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockRefreshSubscriber)(nil).close))
}
// handle mocks base method.
func (m *MockRefreshSubscriber) handle(arg0 context.Context, arg1 proton.RefreshFlag) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "handle", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// handle indicates an expected call of handle.
func (mr *MockRefreshSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockRefreshSubscriber)(nil).handle), arg0, arg1)
}
// name mocks base method.
func (m *MockRefreshSubscriber) name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "name")
ret0, _ := ret[0].(string)
return ret0
}
// name indicates an expected call of name.
func (mr *MockRefreshSubscriberMockRecorder) name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockRefreshSubscriber)(nil).name))
}
// MockUserSubscriber is a mock of UserSubscriber interface.
type MockUserSubscriber struct {
ctrl *gomock.Controller
recorder *MockUserSubscriberMockRecorder
}
// MockUserSubscriberMockRecorder is the mock recorder for MockUserSubscriber.
type MockUserSubscriberMockRecorder struct {
mock *MockUserSubscriber
}
// NewMockUserSubscriber creates a new mock instance.
func NewMockUserSubscriber(ctrl *gomock.Controller) *MockUserSubscriber {
mock := &MockUserSubscriber{ctrl: ctrl}
mock.recorder = &MockUserSubscriberMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUserSubscriber) EXPECT() *MockUserSubscriberMockRecorder {
return m.recorder
}
// cancel mocks base method.
func (m *MockUserSubscriber) cancel() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "cancel")
}
// cancel indicates an expected call of cancel.
func (mr *MockUserSubscriberMockRecorder) cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserSubscriber)(nil).cancel))
}
// close mocks base method.
func (m *MockUserSubscriber) close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "close")
}
// close indicates an expected call of close.
func (mr *MockUserSubscriberMockRecorder) close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserSubscriber)(nil).close))
}
// handle mocks base method.
func (m *MockUserSubscriber) handle(arg0 context.Context, arg1 proton.User) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "handle", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// handle indicates an expected call of handle.
func (mr *MockUserSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserSubscriber)(nil).handle), arg0, arg1)
}
// name mocks base method.
func (m *MockUserSubscriber) name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "name")
ret0, _ := ret[0].(string)
return ret0
}
// name indicates an expected call of name.
func (mr *MockUserSubscriberMockRecorder) name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserSubscriber)(nil).name))
}
// MockUserUsedSpaceSubscriber is a mock of UserUsedSpaceSubscriber interface.
type MockUserUsedSpaceSubscriber struct {
ctrl *gomock.Controller
recorder *MockUserUsedSpaceSubscriberMockRecorder
}
// MockUserUsedSpaceSubscriberMockRecorder is the mock recorder for MockUserUsedSpaceSubscriber.
type MockUserUsedSpaceSubscriberMockRecorder struct {
mock *MockUserUsedSpaceSubscriber
}
// NewMockUserUsedSpaceSubscriber creates a new mock instance.
func NewMockUserUsedSpaceSubscriber(ctrl *gomock.Controller) *MockUserUsedSpaceSubscriber {
mock := &MockUserUsedSpaceSubscriber{ctrl: ctrl}
mock.recorder = &MockUserUsedSpaceSubscriberMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUserUsedSpaceSubscriber) EXPECT() *MockUserUsedSpaceSubscriberMockRecorder {
return m.recorder
}
// cancel mocks base method.
func (m *MockUserUsedSpaceSubscriber) cancel() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "cancel")
}
// cancel indicates an expected call of cancel.
func (mr *MockUserUsedSpaceSubscriberMockRecorder) cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cancel", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).cancel))
}
// close mocks base method.
func (m *MockUserUsedSpaceSubscriber) close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "close")
}
// close indicates an expected call of close.
func (mr *MockUserUsedSpaceSubscriberMockRecorder) close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).close))
}
// handle mocks base method.
func (m *MockUserUsedSpaceSubscriber) handle(arg0 context.Context, arg1 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "handle", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// handle indicates an expected call of handle.
func (mr *MockUserUsedSpaceSubscriberMockRecorder) handle(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).handle), arg0, arg1)
}
// name mocks base method.
func (m *MockUserUsedSpaceSubscriber) name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "name")
ret0, _ := ret[0].(string)
return ret0
}
// name indicates an expected call of name.
func (mr *MockUserUsedSpaceSubscriberMockRecorder) name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "name", reflect.TypeOf((*MockUserUsedSpaceSubscriber)(nil).name))
}

View File

@ -0,0 +1,486 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/sirupsen/logrus"
)
// Service polls from the given event source and ensures that all the respective subscribers get notified
// before proceeding to the next event. The events are published in the following order:
// * Refresh
// * User
// * Address
// * Label
// * Message
// * UserUsedSpace
// By default this service starts paused, you need to call `Service.Resume` at least one time to begin event polling.
type Service struct {
userID string
cpc *cpc.CPC
eventSource EventSource
eventIDStore EventIDStore
log *logrus.Entry
eventPublisher events.EventPublisher
timer *time.Ticker
eventTimeout time.Duration
paused bool
panicHandler async.PanicHandler
userSubscriberList userSubscriberList
addressSubscribers addressSubscriberList
labelSubscribers labelSubscriberList
messageSubscribers messageSubscriberList
refreshSubscribers refreshSubscriberList
userUsedSpaceSubscriber userUsedSpaceSubscriberList
pendingSubscriptionsLock sync.Mutex
pendingSubscriptionsAdd []Subscription
pendingSubscriptionsRemove []Subscription
}
func NewService(
userID string,
eventSource EventSource,
store EventIDStore,
eventPublisher events.EventPublisher,
pollPeriod time.Duration,
eventTimeout time.Duration,
panicHandler async.PanicHandler,
) *Service {
return &Service{
userID: userID,
cpc: cpc.NewCPC(),
eventSource: eventSource,
eventIDStore: store,
log: logrus.WithFields(logrus.Fields{
"service": "user-events",
"user": userID,
}),
eventPublisher: eventPublisher,
timer: time.NewTicker(pollPeriod),
paused: true,
eventTimeout: eventTimeout,
panicHandler: panicHandler,
}
}
type Subscription struct {
User UserSubscriber
Refresh RefreshSubscriber
Address AddressSubscriber
Labels LabelSubscriber
Messages MessageSubscriber
UserUsedSpace UserUsedSpaceSubscriber
}
// cancel subscription subscribers if applicable, see `subscriber.cancel` for more information.
func (s Subscription) cancel() {
if s.User != nil {
s.User.cancel()
}
if s.Refresh != nil {
s.Refresh.cancel()
}
if s.Address != nil {
s.Address.cancel()
}
if s.Labels != nil {
s.Labels.cancel()
}
if s.Messages != nil {
s.Messages.cancel()
}
if s.UserUsedSpace != nil {
s.UserUsedSpace.cancel()
}
}
// Subscribe adds new subscribers to the service.
// This method can safely be called during event handling.
func (s *Service) Subscribe(subscription Subscription) {
s.pendingSubscriptionsLock.Lock()
defer s.pendingSubscriptionsLock.Unlock()
s.pendingSubscriptionsAdd = append(s.pendingSubscriptionsAdd, subscription)
}
// Unsubscribe removes subscribers from the service.
// This method can safely be called during event handling.
func (s *Service) Unsubscribe(subscription Subscription) {
subscription.cancel()
s.pendingSubscriptionsLock.Lock()
defer s.pendingSubscriptionsLock.Unlock()
s.pendingSubscriptionsRemove = append(s.pendingSubscriptionsRemove, subscription)
}
// Pause pauses the event polling.
// DO NOT CALL THIS DURING EVENT HANDLING.
func (s *Service) Pause(ctx context.Context) error {
_, err := s.cpc.Send(ctx, &pauseRequest{})
return err
}
// Resume resumes the event polling.
// DO NOT CALL THIS DURING EVENT HANDLING.
func (s *Service) Resume(ctx context.Context) error {
_, err := s.cpc.Send(ctx, &resumeRequest{})
return err
}
// IsPaused return true if the service is paused
// DO NOT CALL THIS DURING EVENT HANDLING.
func (s *Service) IsPaused(ctx context.Context) (bool, error) {
return cpc.SendTyped[bool](ctx, s.cpc, &isPausedRequest{})
}
func (s *Service) Start(ctx context.Context, group *async.Group) error {
lastEventID, err := s.eventIDStore.Load(ctx)
if err != nil {
return fmt.Errorf("failed to load last event id: %w", err)
}
if lastEventID == "" {
s.log.Debugf("No event ID present in storage, retrieving latest")
eventID, err := s.eventSource.GetLatestEventID(ctx)
if err != nil {
return fmt.Errorf("failed to get latest event id: %w", err)
}
if err := s.eventIDStore.Store(ctx, eventID); err != nil {
return fmt.Errorf("failed to store event in event id store: %v", err)
}
lastEventID = eventID
}
group.Once(func(ctx context.Context) {
s.run(ctx, lastEventID)
})
return nil
}
func (s *Service) run(ctx context.Context, lastEventID string) {
s.log.Debugf("Starting service Last EventID=%v", lastEventID)
defer s.close()
defer s.log.Debug("Exiting service")
for {
select {
case <-ctx.Done():
return
case req, ok := <-s.cpc.ReceiveCh():
if !ok {
return
}
s.handleRequest(ctx, req)
continue
case <-s.timer.C:
if s.paused {
continue
}
}
// Apply any pending subscription changes.
func() {
s.pendingSubscriptionsLock.Lock()
defer s.pendingSubscriptionsLock.Unlock()
for _, subscription := range s.pendingSubscriptionsRemove {
s.removeSubscription(subscription)
}
for _, subscription := range s.pendingSubscriptionsAdd {
s.addSubscription(subscription)
}
s.pendingSubscriptionsRemove = nil
s.pendingSubscriptionsAdd = nil
}()
newEvents, _, err := s.eventSource.GetEvent(ctx, lastEventID)
if err != nil {
s.log.WithError(err).Errorf("Failed to get event (caused by %T)", internal.ErrCause(err))
continue
}
// If the event ID hasn't changed, there are no new events.
if newEvents[len(newEvents)-1].EventID == lastEventID {
s.log.Debugf("No new API Events")
continue
}
if event, eventErr := func() (proton.Event, error) {
for _, event := range newEvents {
if err := s.handleEvent(ctx, lastEventID, event); err != nil {
return event, err
}
}
return proton.Event{}, nil
}(); eventErr != nil {
subscriberName, err := s.handleEventError(ctx, lastEventID, event, eventErr)
if subscriberName == "" {
subscriberName = "?"
}
s.log.WithField("subscriber", subscriberName).WithError(err).Errorf("Failed to apply event")
continue
}
newEventID := newEvents[len(newEvents)-1].EventID
if err := s.eventIDStore.Store(ctx, newEventID); err != nil {
s.log.WithError(err).Errorf("Failed to store new event ID: %v", err)
s.onBadEvent(ctx, events.UserBadEvent{
Error: fmt.Errorf("failed to store new event ID: %w", err),
UserID: s.userID,
})
continue
}
lastEventID = newEventID
}
}
func (s *Service) handleEvent(ctx context.Context, lastEventID string, event proton.Event) error {
s.log.WithFields(logrus.Fields{
"old": lastEventID,
"new": event,
}).Info("Received new API event")
if event.Refresh&proton.RefreshMail != 0 {
s.log.Info("Handling refresh event")
if err := s.refreshSubscribers.Publish(ctx, event.Refresh, s.eventTimeout); err != nil {
return fmt.Errorf("failed to apply refresh event: %w", err)
}
return nil
}
// Start with user events.
if event.User != nil {
if err := s.userSubscriberList.PublishParallel(ctx, *event.User, s.panicHandler, s.eventTimeout); err != nil {
return fmt.Errorf("failed to apply user event: %w", err)
}
}
// Next Address events
if err := s.addressSubscribers.PublishParallel(ctx, event.Addresses, s.panicHandler, s.eventTimeout); err != nil {
return fmt.Errorf("failed to apply address events: %w", err)
}
// Next label events
if err := s.labelSubscribers.PublishParallel(ctx, event.Labels, s.panicHandler, s.eventTimeout); err != nil {
return fmt.Errorf("failed to apply label events: %w", err)
}
// Next message events
if err := s.messageSubscribers.PublishParallel(ctx, event.Messages, s.panicHandler, s.eventTimeout); err != nil {
return fmt.Errorf("failed to apply message events: %w", err)
}
// Finally user used space events
if event.UsedSpace != nil {
if err := s.userUsedSpaceSubscriber.PublishParallel(ctx, *event.UsedSpace, s.panicHandler, s.eventTimeout); err != nil {
return fmt.Errorf("failed to apply message events: %w", err)
}
}
return nil
}
func unpackPublisherError(err error) (string, error) {
var addressErr *addressPublishError
var labelErr *labelPublishError
var messageErr *messagePublishError
var refreshErr *refreshPublishError
var userErr *userPublishError
var usedSpaceErr *userUsedEventPublishError
switch {
case errors.As(err, &userErr):
return userErr.subscriber.name(), userErr.error
case errors.As(err, &addressErr):
return addressErr.subscriber.name(), addressErr.error
case errors.As(err, &labelErr):
return labelErr.subscriber.name(), labelErr.error
case errors.As(err, &messageErr):
return messageErr.subscriber.name(), messageErr.error
case errors.As(err, &refreshErr):
return refreshErr.subscriber.name(), refreshErr.error
case errors.As(err, &usedSpaceErr):
return usedSpaceErr.subscriber.name(), usedSpaceErr.error
default:
return "", err
}
}
func (s *Service) handleEventError(ctx context.Context, lastEventID string, event proton.Event, err error) (string, error) {
// Unpack the error so we can proceed to handle the real issue.
subscriberName, err := unpackPublisherError(err)
// If the error is a context cancellation, return error to retry later.
if errors.Is(err, context.Canceled) {
return subscriberName, fmt.Errorf("failed to handle event due to context cancellation: %w", err)
}
// If the error is a network error, return error to retry later.
if netErr := new(proton.NetError); errors.As(err, &netErr) {
return subscriberName, fmt.Errorf("failed to handle event due to network issue: %w", err)
}
// Catch all for uncategorized net errors that may slip through.
if netErr := new(net.OpError); errors.As(err, &netErr) {
return subscriberName, fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
}
// In case a json decode error slips through.
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
s.eventPublisher.PublishEvent(ctx, events.UncategorizedEventError{
UserID: s.userID,
Error: err,
})
return subscriberName, fmt.Errorf("failed to handle event due to JSON issue: %w", err)
}
// If the error is an unexpected EOF, return error to retry later.
if errors.Is(err, io.ErrUnexpectedEOF) {
return subscriberName, fmt.Errorf("failed to handle event due to EOF: %w", err)
}
// If the error is a server-side issue, return error to retry later.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
return subscriberName, fmt.Errorf("failed to handle event due to server error: %w", err)
}
// Otherwise, the error is a client-side issue; notify bridge to handle it.
s.log.WithField("event", event).Warn("Failed to handle API event")
s.onBadEvent(ctx, events.UserBadEvent{
UserID: s.userID,
OldEventID: lastEventID,
NewEventID: event.EventID,
EventInfo: event.String(),
Error: err,
})
return subscriberName, fmt.Errorf("failed to handle event due to client error: %w", err)
}
func (s *Service) onBadEvent(ctx context.Context, event events.UserBadEvent) {
s.paused = true
s.eventPublisher.PublishEvent(ctx, event)
}
func (s *Service) handleRequest(ctx context.Context, request *cpc.Request) {
switch request.Value().(type) {
case *pauseRequest:
s.paused = true
request.Reply(ctx, nil, nil)
case *resumeRequest:
s.paused = false
request.Reply(ctx, nil, nil)
case *isPausedRequest:
request.Reply(ctx, s.paused, nil)
default:
s.log.Errorf("Unknown request")
}
}
func (s *Service) addSubscription(subscription Subscription) {
if subscription.User != nil {
s.userSubscriberList.Add(subscription.User)
}
if subscription.Refresh != nil {
s.refreshSubscribers.Add(subscription.Refresh)
}
if subscription.Address != nil {
s.addressSubscribers.Add(subscription.Address)
}
if subscription.Labels != nil {
s.labelSubscribers.Add(subscription.Labels)
}
if subscription.Messages != nil {
s.messageSubscribers.Add(subscription.Messages)
}
if subscription.UserUsedSpace != nil {
s.userUsedSpaceSubscriber.Add(subscription.UserUsedSpace)
}
}
func (s *Service) removeSubscription(subscription Subscription) {
if subscription.User != nil {
s.userSubscriberList.Remove(subscription.User)
}
if subscription.Refresh != nil {
s.refreshSubscribers.Remove(subscription.Refresh)
}
if subscription.Address != nil {
s.addressSubscribers.Remove(subscription.Address)
}
if subscription.Labels != nil {
s.labelSubscribers.Remove(subscription.Labels)
}
if subscription.Messages != nil {
s.messageSubscribers.Remove(subscription.Messages)
}
if subscription.UserUsedSpace != nil {
s.userUsedSpaceSubscriber.Remove(subscription.UserUsedSpace)
}
}
func (s *Service) close() {
s.cpc.Close()
s.timer.Stop()
}
type pauseRequest struct{}
type resumeRequest struct{}
type isPausedRequest struct{}

View File

@ -0,0 +1,157 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"encoding/json"
"errors"
"io"
"net"
"testing"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestServiceHandleEventError_SubscriberEventUnwrapping(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}
subscriber := &noOpSubscriber[proton.AddressEvent]{}
err := &addressPublishError{
subscriber: subscriber,
error: &proton.NetError{},
}
subscriberName, unpackedErr := service.handleEventError(context.Background(), lastEventID, event, err)
require.Equal(t, subscriber.name(), subscriberName)
protonNetErr := new(proton.NetError)
require.True(t, errors.As(unpackedErr, &protonNetErr))
err2 := &proton.APIError{Status: 500}
subscriberName2, unpackedErr2 := service.handleEventError(context.Background(), lastEventID, event, err2)
require.Equal(t, "", subscriberName2)
protonAPIErr := new(proton.APIError)
require.True(t, errors.As(unpackedErr2, &protonAPIErr))
}
func TestServiceHandleEventError_BadEventPutsServiceOnPause(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
service.paused = false
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}
err := &proton.APIError{}
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserBadEvent{
UserID: service.userID,
OldEventID: lastEventID,
NewEventID: event.EventID,
EventInfo: event.String(),
Error: err,
})).Times(1)
_, _ = service.handleEventError(context.Background(), lastEventID, event, err)
require.True(t, service.paused)
}
func TestServiceHandleEventError_BadEventFromPublishTimeout(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}
err := ErrPublishTimeoutExceeded
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserBadEvent{
UserID: service.userID,
OldEventID: lastEventID,
NewEventID: event.EventID,
EventInfo: event.String(),
Error: err,
})).Times(1)
_, _ = service.handleEventError(context.Background(), lastEventID, event, err)
}
func TestServiceHandleEventError_NoBadEventCheck(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}
_, _ = service.handleEventError(context.Background(), lastEventID, event, context.Canceled)
_, _ = service.handleEventError(context.Background(), lastEventID, event, &proton.NetError{})
_, _ = service.handleEventError(context.Background(), lastEventID, event, &net.OpError{})
_, _ = service.handleEventError(context.Background(), lastEventID, event, io.ErrUnexpectedEOF)
_, _ = service.handleEventError(context.Background(), lastEventID, event, &proton.APIError{Status: 500})
}
func TestServiceHandleEventError_JsonUnmarshalEventProducesUncategorizedErrorEvent(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}
err := &json.UnmarshalTypeError{}
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UncategorizedEventError{
UserID: service.userID,
Error: err,
})).Times(1)
_, _ = service.handleEventError(context.Background(), lastEventID, event, err)
}
type noOpSubscriber[T any] struct{}
func (n noOpSubscriber[T]) name() string { //nolint:unused
return "NoopSubscriber"
}
func (n noOpSubscriber[T]) handle(_ context.Context, _ []T) error { //nolint:unused
return nil
}
//nolint:unused
func (n noOpSubscriber[T]) close() {} //
//nolint:unused
func (n noOpSubscriber[T]) cancel() {}

View File

@ -0,0 +1,195 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
refreshHandler := NewMockRefreshSubscriber(mockCtrl)
refreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(2).Return(nil)
userHandler := NewMockUserSubscriber(mockCtrl)
userCall := userHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(nil)
addressHandler := NewMockAddressSubscriber(mockCtrl)
addressCall := addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userCall).Times(1).Return(nil)
labelHandler := NewMockLabelSubscriber(mockCtrl)
labelCall := labelHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(addressCall).Times(1).Return(nil)
messageHandler := NewMockMessageSubscriber(mockCtrl)
messageCall := messageHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(labelCall).Times(1).Return(nil)
userSpaceHandler := NewMockUserUsedSpaceSubscriber(mockCtrl)
userSpaceCall := userSpaceHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(messageCall).Times(1).Return(nil)
secondRefreshHandler := NewMockRefreshSubscriber(mockCtrl)
secondRefreshHandler.EXPECT().handle(gomock.Any(), gomock.Any()).After(userSpaceCall).Times(1).Return(nil)
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
service.addSubscription(Subscription{
User: userHandler,
Refresh: refreshHandler,
Address: addressHandler,
Labels: labelHandler,
Messages: messageHandler,
UserUsedSpace: userSpaceHandler,
})
// Simulate 1st refresh.
require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail}))
// Simulate Regular event.
usedSpace := 20
require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{
User: new(proton.User),
Addresses: []proton.AddressEvent{},
Labels: []proton.LabelEvent{
{},
},
Messages: []proton.MessageEvent{
{},
},
UsedSpace: &usedSpace,
}))
service.addSubscription(Subscription{
Refresh: secondRefreshHandler,
})
// Simulate 2nd refresh.
require.NoError(t, service.handleEvent(context.Background(), "", proton.Event{Refresh: proton.RefreshMail}))
}
func TestServiceHandleEvent_CheckEventFailureCausesError(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
addressHandler := NewMockAddressSubscriber(mockCtrl)
addressHandler.EXPECT().name().MinTimes(1).Return("Hello")
addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed"))
messageHandler := NewMockMessageSubscriber(mockCtrl)
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
service.addSubscription(Subscription{
Address: addressHandler,
Messages: messageHandler,
})
err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}})
require.Error(t, err)
publisherErr := new(addressPublishError)
require.True(t, errors.As(err, &publisherErr))
require.Equal(t, publisherErr.subscriber, addressHandler)
}
func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
addressHandler := NewMockAddressSubscriber(mockCtrl)
addressHandler.EXPECT().name().MinTimes(1).Return("Hello")
addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).Times(1).Return(fmt.Errorf("failed"))
addressHandler2 := NewMockAddressSubscriber(mockCtrl)
addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil)
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, time.Second, async.NoopPanicHandler{})
service.addSubscription(Subscription{
Address: addressHandler,
})
service.addSubscription(Subscription{
Address: addressHandler2,
})
err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}})
require.Error(t, err)
publisherErr := new(addressPublishError)
require.True(t, errors.As(err, &publisherErr))
require.Equal(t, publisherErr.subscriber, addressHandler)
}
func TestServiceHandleEvent_SubscriberTimeout(t *testing.T) {
mockCtrl := gomock.NewController(t)
eventPublisher := mocks.NewMockEventPublisher(mockCtrl)
eventIDStore := NewInMemoryEventIDStore()
addressHandler := NewMockAddressSubscriber(mockCtrl)
addressHandler.EXPECT().name().AnyTimes().Return("Ok")
addressHandler.EXPECT().handle(gomock.Any(), gomock.Any()).MaxTimes(1).Return(nil)
addressHandler2 := NewMockAddressSubscriber(mockCtrl)
addressHandler2.EXPECT().name().AnyTimes().Return("Timeout")
addressHandler2.EXPECT().handle(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, _ []proton.AddressEvent) error {
timer := time.NewTimer(time.Second)
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}).MaxTimes(1)
service := NewService("foo", &NullEventSource{}, eventIDStore, eventPublisher, 100*time.Millisecond, 500*time.Millisecond, async.NoopPanicHandler{})
service.addSubscription(Subscription{
Address: addressHandler,
})
service.addSubscription(Subscription{
Address: addressHandler2,
})
// Simulate 1st refresh.
err := service.handleEvent(context.Background(), "", proton.Event{Addresses: []proton.AddressEvent{{}}})
require.Error(t, err)
if publisherErr := new(addressPublishError); errors.As(err, &publisherErr) {
require.Equal(t, publisherErr.subscriber, addressHandler)
require.True(t, errors.Is(publisherErr.error, ErrPublishTimeoutExceeded))
} else {
require.True(t, errors.Is(err, ErrPublishTimeoutExceeded))
}
}

View File

@ -0,0 +1,269 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"fmt"
"io"
"testing"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
mocks2 "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestService_EventIDLoadStore(t *testing.T) {
// Simulate the flow of data when we start without any event id in the event id store:
// * Load event id from store
// * Get latest event id since
// * Store that latest event id
// * Start event poll loop
// * Get new event id, store it in vault
// * Try to poll new event it, but context is cancelled
group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
mockCtrl := gomock.NewController(t)
eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
eventSource := mocks.NewMockEventSource(mockCtrl)
firstEventID := "EVENT01"
secondEventID := "EVENT02"
secondEvent := []proton.Event{{
EventID: secondEventID,
}}
// Event id store expectations.
eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return("", nil)
eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(firstEventID)).Times(1).Return(nil)
eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
// Force exit, we have finished executing what we expected.
group.Cancel()
return nil
})
// Event Source expectations.
eventSource.EXPECT().GetLatestEventID(gomock.Any()).Times(1).Return(firstEventID, nil)
eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
require.NoError(t, service.Start(context.Background(), group))
require.NoError(t, service.Resume(context.Background()))
group.WaitToFinish()
}
func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) {
group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
mockCtrl := gomock.NewController(t)
eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
eventSource := mocks.NewMockEventSource(mockCtrl)
subscriber := NewMockMessageSubscriber(mockCtrl)
firstEventID := "EVENT01"
secondEventID := "EVENT02"
messageEvents := []proton.MessageEvent{
{
EventItem: proton.EventItem{ID: "Message"},
},
}
secondEvent := []proton.Event{{
EventID: secondEventID,
Messages: messageEvents,
}}
// Event id store expectations.
eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
// Force exit, we have finished executing what we expected.
group.Cancel()
return nil
})
// Event Source expectations.
eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
// Subscriber expectations.
subscriber.EXPECT().name().AnyTimes().Return("Foo")
{
firstCall := subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(io.ErrUnexpectedEOF)
subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).After(firstCall).Times(1).Return(nil)
}
service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
service.Subscribe(Subscription{Messages: subscriber})
require.NoError(t, service.Start(context.Background(), group))
require.NoError(t, service.Resume(context.Background()))
group.WaitToFinish()
}
func TestService_OnBadEventServiceIsPaused(t *testing.T) {
group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
mockCtrl := gomock.NewController(t)
eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
eventSource := mocks.NewMockEventSource(mockCtrl)
subscriber := NewMockMessageSubscriber(mockCtrl)
firstEventID := "EVENT01"
secondEventID := "EVENT02"
messageEvents := []proton.MessageEvent{
{
EventItem: proton.EventItem{ID: "Message"},
},
}
secondEvent := []proton.Event{{
EventID: secondEventID,
Messages: messageEvents,
}}
// Event id store expectations.
eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
// Event Source expectations.
eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
// Subscriber expectations.
badEventErr := fmt.Errorf("I will cause bad event")
subscriber.EXPECT().name().AnyTimes().Return("Foo")
subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).Return(badEventErr)
service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
// Event publisher expectations.
eventPublisher.EXPECT().PublishEvent(gomock.Any(), events.UserBadEvent{
UserID: "foo",
OldEventID: firstEventID,
NewEventID: secondEventID,
EventInfo: secondEvent[0].String(),
Error: badEventErr,
}).Do(func(_ context.Context, event events.Event) {
group.Once(func(_ context.Context) {
// Use background context to avoid having the request cancelled
paused, err := service.IsPaused(context.Background())
require.NoError(t, err)
require.True(t, paused)
group.Cancel()
})
})
service.Subscribe(Subscription{Messages: subscriber})
require.NoError(t, service.Start(context.Background(), group))
require.NoError(t, service.Resume(context.Background()))
group.WaitToFinish()
}
func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T) {
group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
mockCtrl := gomock.NewController(t)
eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
eventSource := mocks.NewMockEventSource(mockCtrl)
subscriber := NewMockMessageSubscriber(mockCtrl)
firstEventID := "EVENT01"
secondEventID := "EVENT02"
messageEvents := []proton.MessageEvent{
{
EventItem: proton.EventItem{ID: "Message"},
},
}
secondEvent := []proton.Event{{
EventID: secondEventID,
Messages: messageEvents,
}}
// Event id store expectations.
eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
// Force exit, we have finished executing what we expected.
group.Cancel()
return nil
})
// Event Source expectations.
eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
// Subscriber expectations.
subscriber.EXPECT().name().AnyTimes().Return("Foo")
subscriber.EXPECT().cancel().Times(1)
subscriber.EXPECT().handle(gomock.Any(), gomock.Eq(messageEvents)).Times(1).DoAndReturn(func(_ context.Context, _ []proton.MessageEvent) error {
service.Unsubscribe(Subscription{Messages: subscriber})
return nil
})
service.Subscribe(Subscription{Messages: subscriber})
require.NoError(t, service.Start(context.Background(), group))
require.NoError(t, service.Resume(context.Background()))
group.WaitToFinish()
}
func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T) {
group := async.NewGroup(context.Background(), &async.NoopPanicHandler{})
mockCtrl := gomock.NewController(t)
eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
eventSource := mocks.NewMockEventSource(mockCtrl)
subscriber := newChanneledSubscriber[proton.MessageEvent]("My subscriber")
firstEventID := "EVENT01"
secondEventID := "EVENT02"
messageEvents := []proton.MessageEvent{
{
EventItem: proton.EventItem{ID: "Message"},
},
}
secondEvent := []proton.Event{{
EventID: secondEventID,
Messages: messageEvents,
}}
// Event id store expectations.
eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(firstEventID, nil)
eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).DoAndReturn(func(_ context.Context, _ string) error {
// Force exit, we have finished executing what we expected.
group.Cancel()
return nil
})
// Event Source expectations.
eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).Return(secondEvent, false, nil)
eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(secondEventID)).AnyTimes().Return(secondEvent, false, nil)
service := NewService("foo", eventSource, eventIDStore, eventPublisher, 1*time.Millisecond, time.Second, async.NoopPanicHandler{})
// start subscriber
group.Once(func(_ context.Context) {
defer service.Unsubscribe(Subscription{Messages: subscriber})
// Simulate the reception of an event, but it is never handled due to unexpected exit
<-time.NewTicker(500 * time.Millisecond).C
})
service.Subscribe(Subscription{Messages: subscriber})
require.NoError(t, service.Start(context.Background(), group))
require.NoError(t, service.Resume(context.Background()))
group.WaitToFinish()
}

View File

@ -0,0 +1,252 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"errors"
"fmt"
"runtime"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices"
"golang.org/x/exp/slices"
)
func NewMessageSubscriber(name string) *ChanneledSubscriber[proton.MessageEvent] {
return newChanneledSubscriber[proton.MessageEvent](name)
}
func NewAddressSubscriber(name string) *ChanneledSubscriber[proton.AddressEvent] {
return newChanneledSubscriber[proton.AddressEvent](name)
}
func NewLabelSubscriber(name string) *ChanneledSubscriber[proton.LabelEvent] {
return newChanneledSubscriber[proton.LabelEvent](name)
}
func NewRefreshSubscriber(name string) *ChanneledSubscriber[struct{}] {
return newChanneledSubscriber[struct{}](name)
}
func NewUserSubscriber(name string) *ChanneledSubscriber[proton.User] {
return newChanneledSubscriber[proton.User](name)
}
func NewUserUsedSpaceSubscriber(name string) *ChanneledSubscriber[int] {
return newChanneledSubscriber[int](name)
}
type AddressSubscriber = subscriber[[]proton.AddressEvent]
type LabelSubscriber = subscriber[[]proton.LabelEvent]
type MessageSubscriber = subscriber[[]proton.MessageEvent]
type RefreshSubscriber = subscriber[proton.RefreshFlag]
type UserSubscriber = subscriber[proton.User]
type UserUsedSpaceSubscriber = subscriber[int]
// Subscriber is the main entry point of interacting with user generated events.
type subscriber[T any] interface {
// Name returns the identifier for this subscriber
name() string
// Handle the event list.
handle(context.Context, T) error
// cancel is behavior extension for channel based subscribers so that they can ensure that
// if a subscriber unsubscribes, it doesn't cause pending events on the channel to time-out as there is no one to handle
// them.
cancel()
// close release all associated resources
close()
}
type subscriberList[T any] struct {
subscribers []subscriber[T]
}
type addressSubscriberList = subscriberList[[]proton.AddressEvent]
type labelSubscriberList = subscriberList[[]proton.LabelEvent]
type messageSubscriberList = subscriberList[[]proton.MessageEvent]
type refreshSubscriberList = subscriberList[proton.RefreshFlag]
type userSubscriberList = subscriberList[proton.User]
type userUsedSpaceSubscriberList = subscriberList[int]
func (s *subscriberList[T]) Add(subscriber subscriber[T]) {
if !slices.Contains(s.subscribers, subscriber) {
s.subscribers = append(s.subscribers, subscriber)
}
}
func (s *subscriberList[T]) Remove(subscriber subscriber[T]) {
index := slices.Index(s.subscribers, subscriber)
if index < 0 {
return
}
s.subscribers[index].close()
s.subscribers = xslices.Remove(s.subscribers, index, 1)
}
type publishError[T any] struct {
subscriber subscriber[T]
error error
}
var ErrPublishTimeoutExceeded = errors.New("event publish timed out")
type addressPublishError = publishError[[]proton.AddressEvent]
type labelPublishError = publishError[[]proton.LabelEvent]
type messagePublishError = publishError[[]proton.MessageEvent]
type refreshPublishError = publishError[proton.RefreshFlag]
type userPublishError = publishError[proton.User]
type userUsedEventPublishError = publishError[int]
func (p publishError[T]) Error() string {
return fmt.Sprintf("Event publish failed on (%v): %v", p.subscriber.name(), p.error.Error())
}
func (s *subscriberList[T]) Publish(ctx context.Context, event T, timeout time.Duration) error {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
for _, subscriber := range s.subscribers {
if err := subscriber.handle(ctx, event); err != nil {
return &publishError[T]{
subscriber: subscriber,
error: mapContextTimeoutError(err),
}
}
if err := ctx.Err(); err != nil {
return &publishError[T]{
subscriber: subscriber,
error: mapContextTimeoutError(err),
}
}
}
return nil
}
func mapContextTimeoutError(err error) error {
if errors.Is(err, context.DeadlineExceeded) {
return ErrPublishTimeoutExceeded
}
return err
}
func (s *subscriberList[T]) PublishParallel(
ctx context.Context,
event T,
panicHandler async.PanicHandler,
timeout time.Duration,
) error {
if len(s.subscribers) <= 1 {
return s.Publish(ctx, event, timeout)
}
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
err := parallel.DoContext(ctx, runtime.NumCPU()/2, len(s.subscribers), func(ctx context.Context, index int) error {
defer async.HandlePanic(panicHandler)
if err := s.subscribers[index].handle(ctx, event); err != nil {
return &publishError[T]{
subscriber: s.subscribers[index],
error: mapContextTimeoutError(err),
}
}
return nil
})
return mapContextTimeoutError(err)
}
type ChanneledSubscriber[T any] struct {
id string
sender chan *ChanneledSubscriberEvent[T]
}
func newChanneledSubscriber[T any](name string) *ChanneledSubscriber[T] {
return &ChanneledSubscriber[T]{
id: name,
sender: make(chan *ChanneledSubscriberEvent[T]),
}
}
type ChanneledSubscriberEvent[T any] struct {
data []T
response chan error
}
func (c ChanneledSubscriberEvent[T]) Consume(f func([]T) error) {
if err := f(c.data); err != nil {
c.response <- err
}
close(c.response)
}
func (c *ChanneledSubscriber[T]) name() string { //nolint:unused
return c.id
}
func (c *ChanneledSubscriber[T]) handle(ctx context.Context, event []T) error { //nolint:unused
data := &ChanneledSubscriberEvent[T]{
data: event,
response: make(chan error),
}
// Send Event
select {
case <-ctx.Done():
return fmt.Errorf("failed to send event: %w", ctx.Err())
case c.sender <- data:
//
}
// Wait on Reply
select {
case <-ctx.Done():
return fmt.Errorf("failed to receive event reply: %w", ctx.Err())
case reply := <-data.response:
return reply
}
}
func (c *ChanneledSubscriber[T]) OnEventCh() <-chan *ChanneledSubscriberEvent[T] {
return c.sender
}
func (c *ChanneledSubscriber[T]) close() { //nolint:unused
close(c.sender)
}
func (c *ChanneledSubscriber[T]) cancel() { //nolint:unused
go func() {
for {
e, ok := <-c.sender
if !ok {
return
}
e.Consume(func(_ []T) error { return nil })
}
}()
}

View File

@ -0,0 +1,105 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestChanneledSubscriber_CtxTimeoutDoesNotBlockFutureEvents(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(1)
subscriber := newChanneledSubscriber[int]("test")
defer subscriber.close()
go func() {
defer wg.Done()
// Send one event, that succeeds.
require.NoError(t, subscriber.handle(context.Background(), []int{30}))
// Add an impossible deadline that fails immediately to simulate on event taking too long to send.
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Microsecond))
defer cancel()
err := subscriber.handle(ctx, []int{20})
require.Error(t, err)
require.True(t, errors.Is(err, context.DeadlineExceeded))
}()
// Receive first event. Notify success.
event, ok := <-subscriber.OnEventCh()
require.True(t, ok)
event.Consume(func(event []int) error {
require.Equal(t, []int{30}, event)
return nil
})
wg.Wait()
// Simulate reception of another event
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, subscriber.handle(context.Background(), []int{40}))
}()
event, ok = <-subscriber.OnEventCh()
require.True(t, ok)
event.Consume(func(event []int) error {
require.Equal(t, []int{40}, event)
return nil
})
wg.Wait()
}
func TestChanneledSubscriber_ErrorReported(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(1)
subscriber := newChanneledSubscriber[int]("test")
defer subscriber.close()
reportedErr := fmt.Errorf("request failed")
go func() {
defer wg.Done()
// Send one event, that succeeds.
err := subscriber.handle(context.Background(), []int{30})
require.Error(t, err)
require.Equal(t, reportedErr, err)
}()
// Receive first event. Notify success.
event, ok := <-subscriber.OnEventCh()
require.True(t, ok)
event.Consume(func(event []int) error {
require.Equal(t, []int{30}, event)
return reportedErr
})
wg.Wait()
}