From 040ddadb7a66ac56b9ea8b3414a5307124aec698 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Fri, 21 Jul 2023 16:33:35 +0200 Subject: [PATCH] feat(GODT-2801): Identity Service Identity Service contains all the information related to user state, addresses and keys. This patch also introduces the `State` type which can be used by other services to maintain their own copy of this state to avoid lock contention. Finally, there are currently no external facing methods via a CPC interface. Those will added as needed once the refactoring of the architecture is complete. --- Makefile | 4 +- internal/services/userevents/subscribable.go | 34 ++ internal/services/useridentity/mocks/mocks.go | 66 +++ internal/services/useridentity/service.go | 270 +++++++++++ .../services/useridentity/service_test.go | 447 ++++++++++++++++++ internal/services/useridentity/state.go | 187 ++++++++ 6 files changed, 1007 insertions(+), 1 deletion(-) create mode 100644 internal/services/userevents/subscribable.go create mode 100644 internal/services/useridentity/mocks/mocks.go create mode 100644 internal/services/useridentity/service.go create mode 100644 internal/services/useridentity/service_test.go create mode 100644 internal/services/useridentity/state.go diff --git a/Makefile b/Makefile index 0e01fe2a..276a0fcc 100644 --- a/Makefile +++ b/Makefile @@ -282,12 +282,14 @@ mocks: mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \ -EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go +EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go mockgen --package userevents github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \ MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber > tmp mv tmp internal/services/userevents/mocks_test.go mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/events EventPublisher \ > internal/events/mocks/mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider \ +> internal/services/useridentity/mocks/mocks.go lint: gofiles lint-golang lint-license lint-dependencies lint-changelog diff --git a/internal/services/userevents/subscribable.go b/internal/services/userevents/subscribable.go new file mode 100644 index 00000000..4a1526c6 --- /dev/null +++ b/internal/services/userevents/subscribable.go @@ -0,0 +1,34 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package userevents + +// Subscribable represents a type that allows the registration of event subscribers. +type Subscribable interface { + Subscribe(subscription Subscription) + Unsubscribe(subscription Subscription) +} + +type NoOpSubscribable struct{} + +func (n NoOpSubscribable) Subscribe(_ Subscription) { + // Does nothing +} + +func (n NoOpSubscribable) Unsubscribe(_ Subscription) { + // Does nothing +} diff --git a/internal/services/useridentity/mocks/mocks.go b/internal/services/useridentity/mocks/mocks.go new file mode 100644 index 00000000..ce105d33 --- /dev/null +++ b/internal/services/useridentity/mocks/mocks.go @@ -0,0 +1,66 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity (interfaces: IdentityProvider) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + proton "github.com/ProtonMail/go-proton-api" + gomock "github.com/golang/mock/gomock" +) + +// MockIdentityProvider is a mock of IdentityProvider interface. +type MockIdentityProvider struct { + ctrl *gomock.Controller + recorder *MockIdentityProviderMockRecorder +} + +// MockIdentityProviderMockRecorder is the mock recorder for MockIdentityProvider. +type MockIdentityProviderMockRecorder struct { + mock *MockIdentityProvider +} + +// NewMockIdentityProvider creates a new mock instance. +func NewMockIdentityProvider(ctrl *gomock.Controller) *MockIdentityProvider { + mock := &MockIdentityProvider{ctrl: ctrl} + mock.recorder = &MockIdentityProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIdentityProvider) EXPECT() *MockIdentityProviderMockRecorder { + return m.recorder +} + +// GetAddresses mocks base method. +func (m *MockIdentityProvider) GetAddresses(arg0 context.Context) ([]proton.Address, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAddresses", arg0) + ret0, _ := ret[0].([]proton.Address) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAddresses indicates an expected call of GetAddresses. +func (mr *MockIdentityProviderMockRecorder) GetAddresses(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockIdentityProvider)(nil).GetAddresses), arg0) +} + +// GetUser mocks base method. +func (m *MockIdentityProvider) GetUser(arg0 context.Context) (proton.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUser", arg0) + ret0, _ := ret[0].(proton.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUser indicates an expected call of GetUser. +func (mr *MockIdentityProviderMockRecorder) GetUser(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockIdentityProvider)(nil).GetUser), arg0) +} diff --git a/internal/services/useridentity/service.go b/internal/services/useridentity/service.go new file mode 100644 index 00000000..4d90b20f --- /dev/null +++ b/internal/services/useridentity/service.go @@ -0,0 +1,270 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package useridentity + +import ( + "context" + "fmt" + + "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/events" + "github.com/ProtonMail/proton-bridge/v3/internal/logging" + "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" + "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" +) + +type IdentityProvider interface { + GetUser(ctx context.Context) (proton.User, error) + GetAddresses(ctx context.Context) ([]proton.Address, error) +} + +// Service contains all the data required to establish the user identity. This +// includes all the user's information as well as mail addresses and keys. +type Service struct { + eventService userevents.Subscribable + eventPublisher events.EventPublisher + log *logrus.Entry + identity State + + userSubscriber *userevents.UserChanneledSubscriber + addressSubscriber *userevents.AddressChanneledSubscriber + usedSpaceSubscriber *userevents.UserUsedSpaceChanneledSubscriber + refreshSubscriber *userevents.RefreshChanneledSubscriber +} + +func NewService( + ctx context.Context, + service userevents.Subscribable, + user proton.User, + eventPublisher events.EventPublisher, + provider IdentityProvider, +) (*Service, error) { + addresses, err := provider.GetAddresses(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get addresses: %w", err) + } + + subscriberName := fmt.Sprintf("identity-%v", user.ID) + + return &Service{ + eventService: service, + identity: NewState(user, addresses, provider), + eventPublisher: eventPublisher, + log: logrus.WithFields(logrus.Fields{ + "service": "user-identity", + "user": user.ID, + }), + + userSubscriber: userevents.NewUserSubscriber(subscriberName), + refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName), + addressSubscriber: userevents.NewAddressSubscriber(subscriberName), + usedSpaceSubscriber: userevents.NewUserUsedSpaceSubscriber(subscriberName), + }, nil +} + +func (s *Service) Start(group *async.Group) { + group.Once(func(ctx context.Context) { + s.run(ctx) + }) +} + +func (s *Service) run(ctx context.Context) { + s.log.WithFields(logrus.Fields{ + "numAddr": len(s.identity.Addresses), + }).Info("Starting user identity service") + + s.registerSubscription() + defer s.unregisterSubscription() + + for { + select { + case <-ctx.Done(): + return + case evt, ok := <-s.userSubscriber.OnEventCh(): + if !ok { + continue + } + evt.Consume(func(user proton.User) error { + s.onUserEvent(ctx, user) + return nil + }) + case evt, ok := <-s.refreshSubscriber.OnEventCh(): + if !ok { + continue + } + evt.Consume(func(_ proton.RefreshFlag) error { + return s.onRefreshEvent(ctx) + }) + case evt, ok := <-s.usedSpaceSubscriber.OnEventCh(): + if !ok { + continue + } + evt.Consume(func(usedSpace int) error { + s.onUserSpaceChanged(ctx, usedSpace) + + return nil + }) + case evt, ok := <-s.addressSubscriber.OnEventCh(): + if !ok { + continue + } + evt.Consume(func(events []proton.AddressEvent) error { + return s.onAddressEvent(ctx, events) + }) + } + } +} + +func (s *Service) registerSubscription() { + s.eventService.Subscribe(userevents.Subscription{ + Refresh: s.refreshSubscriber, + User: s.userSubscriber, + Address: s.addressSubscriber, + UserUsedSpace: s.usedSpaceSubscriber, + }) +} + +func (s *Service) unregisterSubscription() { + s.eventService.Unsubscribe(userevents.Subscription{ + Refresh: s.refreshSubscriber, + User: s.userSubscriber, + Address: s.addressSubscriber, + UserUsedSpace: s.usedSpaceSubscriber, + }) +} + +func (s *Service) onUserEvent(ctx context.Context, user proton.User) { + s.log.WithField("username", logging.Sensitive(user.Name)).Info("Handling user event") + s.identity.OnUserEvent(user) + s.eventPublisher.PublishEvent(ctx, events.UserChanged{ + UserID: user.ID, + }) +} + +func (s *Service) onRefreshEvent(ctx context.Context) error { + s.log.Info("Handling refresh event") + + if err := s.identity.OnRefreshEvent(ctx); err != nil { + s.log.WithError(err).Error("Failed to handle refresh event") + return err + } + + s.eventPublisher.PublishEvent(ctx, events.UserRefreshed{ + UserID: s.identity.User.ID, + CancelEventPool: false, + }) + + return nil +} + +func (s *Service) onUserSpaceChanged(ctx context.Context, value int) { + s.log.Info("Handling User Space Changed event") + if !s.identity.OnUserSpaceChanged(value) { + return + } + + s.eventPublisher.PublishEvent(ctx, events.UsedSpaceChanged{ + UserID: s.identity.User.ID, + UsedSpace: value, + }) +} + +func (s *Service) onAddressEvent(ctx context.Context, addressEvents []proton.AddressEvent) error { + s.log.Infof("Handling Address Events (%v)", len(addressEvents)) + + for idx, event := range addressEvents { + switch event.Action { + case proton.EventCreate: + s.log.WithFields(logrus.Fields{ + "index": idx, + "addressID": event.ID, + "email": logging.Sensitive(event.Address.Email), + }).Info("Handling address created event") + if s.identity.OnAddressCreated(event) == AddressUpdateCreated { + s.eventPublisher.PublishEvent(ctx, events.UserAddressCreated{ + UserID: s.identity.User.ID, + AddressID: event.Address.ID, + Email: event.Address.Email, + }) + } + + case proton.EventUpdate, proton.EventUpdateFlags: + addr, status := s.identity.OnAddressUpdated(event) + switch status { + case AddressUpdateCreated: + s.eventPublisher.PublishEvent(ctx, events.UserAddressCreated{ + UserID: s.identity.User.ID, + AddressID: addr.ID, + Email: addr.Email, + }) + case AddressUpdateUpdated: + s.eventPublisher.PublishEvent(ctx, events.UserAddressUpdated{ + UserID: s.identity.User.ID, + AddressID: addr.ID, + Email: addr.Email, + }) + + case AddressUpdateDisabled: + s.eventPublisher.PublishEvent(ctx, events.UserAddressDisabled{ + UserID: s.identity.User.ID, + AddressID: addr.ID, + Email: addr.Email, + }) + + case AddressUpdateEnabled: + s.eventPublisher.PublishEvent(ctx, events.UserAddressEnabled{ + UserID: s.identity.User.ID, + AddressID: addr.ID, + Email: addr.Email, + }) + + case AddressUpdateNoop: + continue + + case AddressUpdateDeleted: + s.log.Warnf("Unexpected address update status after update event %v", status) + continue + } + + case proton.EventDelete: + if addr, status := s.identity.OnAddressDeleted(event); status == AddressUpdateDeleted { + s.eventPublisher.PublishEvent(ctx, events.UserAddressDeleted{ + UserID: s.identity.User.ID, + AddressID: event.ID, + Email: addr.Email, + }) + } + } + } + return nil +} + +func sortAddresses(addr []proton.Address) []proton.Address { + slices.SortFunc(addr, func(a, b proton.Address) bool { + return a.Order < b.Order + }) + + return addr +} + +func buildAddressMapFromSlice(addr []proton.Address) map[string]proton.Address { + return usertypes.GroupBy(addr, func(addr proton.Address) string { return addr.ID }) +} diff --git a/internal/services/useridentity/service_test.go b/internal/services/useridentity/service_test.go new file mode 100644 index 00000000..1d3021a5 --- /dev/null +++ b/internal/services/useridentity/service_test.go @@ -0,0 +1,447 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package useridentity + +import ( + "context" + "testing" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/events" + mocks2 "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks" + "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents" + "github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +const TestUserID = "MyUserID" + +func TestService_OnUserEvent(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, _ := newTestService(t, mockCtrl) + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserChanged{UserID: TestUserID})).Times(1) + + service.onUserEvent(context.Background(), newTestUser()) +} + +func TestService_OnUserSpaceChanged(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, _ := newTestService(t, mockCtrl) + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UsedSpaceChanged{UserID: TestUserID, UsedSpace: 1024})).Times(1) + + // Original value, no changes. + service.onUserSpaceChanged(context.Background(), 0) + + // New value, event should be published. + service.onUserSpaceChanged(context.Background(), 1024) + require.Equal(t, 1024, service.identity.User.UsedSpace) +} + +func TestService_OnRefreshEvent(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, provider := newTestService(t, mockCtrl) + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserRefreshed{UserID: TestUserID, CancelEventPool: false})).Times(1) + + newUser := newTestUserRefreshed() + newAddresses := newTestAddressesRefreshed() + + { + getUserCall := provider.EXPECT().GetUser(gomock.Any()).Times(1).Return(newUser, nil) + provider.EXPECT().GetAddresses(gomock.Any()).After(getUserCall).Times(1).Return(newAddresses, nil) + } + + // Original value, no changes. + require.NoError(t, service.onRefreshEvent(context.Background())) + + require.Equal(t, newUser, service.identity.User) + require.Equal(t, newAddresses, service.identity.AddressesSorted) +} + +func TestService_OnAddressCreated(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, _ := newTestService(t, mockCtrl) + + newAddress := proton.Address{ + ID: "NewAddrID", + Email: "new@bar.com", + Status: proton.AddressStatusEnabled, + } + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressCreated{ + UserID: TestUserID, + AddressID: newAddress.ID, + Email: newAddress.Email, + })).Times(1) + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: "", + Action: proton.EventCreate, + }, + Address: newAddress, + }, + }) + require.NoError(t, err) + + require.Contains(t, service.identity.Addresses, newAddress.ID) +} + +func TestService_OnAddressCreatedDisabledDoesNotProduceEvent(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, _, _ := newTestService(t, mockCtrl) + + newAddress := proton.Address{ + ID: "Address1", + Email: "new@bar.com", + Status: proton.AddressStatusEnabled, + } + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: "", + Action: proton.EventCreate, + }, + Address: newAddress, + }, + }) + require.NoError(t, err) + + require.Contains(t, service.identity.Addresses, newAddress.ID) +} + +func TestService_OnAddressCreatedDuplicateDoesNotProduceEvent(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, _, _ := newTestService(t, mockCtrl) + + newAddress := proton.Address{ + ID: "NewAddrID", + Email: "new@bar.com", + Status: proton.AddressStatusDisabled, + } + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: "", + Action: proton.EventCreate, + }, + Address: newAddress, + }, + }) + require.NoError(t, err) + + require.Contains(t, service.identity.Addresses, newAddress.ID) +} + +func TestService_OnAddressUpdated(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, _ := newTestService(t, mockCtrl) + + newAddress := proton.Address{ + ID: "Address1", + Email: "new@bar.com", + Status: proton.AddressStatusEnabled, + } + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressUpdated{ + UserID: TestUserID, + AddressID: newAddress.ID, + Email: newAddress.Email, + })).Times(1) + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: "", + Action: proton.EventUpdate, + }, + Address: newAddress, + }, + }) + require.NoError(t, err) + + require.Equal(t, newAddress, service.identity.Addresses[newAddress.ID]) +} + +func TestService_OnAddressUpdatedDisableFollowedByEnable(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, _ := newTestService(t, mockCtrl) + + newAddressDisabled := proton.Address{ + ID: "Address1", + Email: "new@bar.com", + Status: proton.AddressStatusDisabled, + } + newAddressEnabled := proton.Address{ + ID: "Address1", + Email: "new@bar.com", + Status: proton.AddressStatusEnabled, + } + + { + disabledCall := eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressDisabled{ + UserID: TestUserID, + AddressID: newAddressDisabled.ID, + Email: newAddressDisabled.Email, + })).Times(1) + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressEnabled{ + UserID: TestUserID, + AddressID: newAddressEnabled.ID, + Email: newAddressEnabled.Email, + })).Times(1).After(disabledCall) + } + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: "", + Action: proton.EventUpdate, + }, + Address: newAddressDisabled, + }, + }) + require.NoError(t, err) + + require.Equal(t, newAddressDisabled, service.identity.Addresses[newAddressEnabled.ID]) + + err = service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: "", + Action: proton.EventUpdate, + }, + Address: newAddressEnabled, + }, + }) + require.NoError(t, err) + + require.Equal(t, newAddressEnabled, service.identity.Addresses[newAddressEnabled.ID]) +} + +func TestService_OnAddressUpdateCreatedIfNotExists(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, _ := newTestService(t, mockCtrl) + + newAddress := proton.Address{ + ID: "NewAddrID", + Email: "new@bar.com", + Status: proton.AddressStatusEnabled, + } + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressCreated{ + UserID: TestUserID, + AddressID: newAddress.ID, + Email: newAddress.Email, + })).Times(1) + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: "", + Action: proton.EventUpdate, + }, + Address: newAddress, + }, + }) + require.NoError(t, err) + + require.Contains(t, service.identity.Addresses, newAddress.ID) +} + +func TestService_OnAddressDeleted(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, eventPublisher, _ := newTestService(t, mockCtrl) + + address := proton.Address{ + ID: "Address1", + Email: "foo@bar.com", + Status: proton.AddressStatusEnabled, + } + + eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressDeleted{ + UserID: TestUserID, + AddressID: address.ID, + Email: address.Email, + })).Times(1) + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: address.ID, + Action: proton.EventDelete, + }, + }, + }) + require.NoError(t, err) + + require.NotContains(t, service.identity.Addresses, address.ID) +} + +func TestService_OnAddressDeleteDisabledDoesNotProduceEvent(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, _, _ := newTestService(t, mockCtrl) + + address := proton.Address{ + ID: "Address2", + Email: "foo2@bar.com", + Status: proton.AddressStatusDisabled, + } + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: address.ID, + Action: proton.EventDelete, + }, + }, + }) + require.NoError(t, err) + + require.NotContains(t, service.identity.Addresses, address.ID) +} + +func TestService_OnAddressDeletedUnknownDoesNotProduceEvent(t *testing.T) { + mockCtrl := gomock.NewController(t) + + service, _, _ := newTestService(t, mockCtrl) + + address := proton.Address{ + ID: "UnknownID", + Email: "new@bar.com", + Status: proton.AddressStatusEnabled, + } + + err := service.onAddressEvent(context.Background(), []proton.AddressEvent{ + { + EventItem: proton.EventItem{ + ID: address.ID, + Action: proton.EventDelete, + }, + Address: address, + }, + }) + require.NoError(t, err) +} + +func newTestService(t *testing.T, mockCtrl *gomock.Controller) (*Service, *mocks2.MockEventPublisher, *mocks.MockIdentityProvider) { + subscribable := &userevents.NoOpSubscribable{} + eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) + provider := mocks.NewMockIdentityProvider(mockCtrl) + user := newTestUser() + + provider.EXPECT().GetAddresses(gomock.Any()).Times(1).Return(newTestAddresses(), nil) + + service, err := NewService(context.Background(), subscribable, user, eventPublisher, provider) + require.NoError(t, err) + + return service, eventPublisher, provider +} + +func newTestUser() proton.User { + return proton.User{ + ID: TestUserID, + Name: "Foo", + DisplayName: "Foo", + Email: "foo@bar", + Keys: nil, + UsedSpace: 0, + MaxSpace: 0, + MaxUpload: 0, + Credit: 0, + Currency: "", + } +} + +func newTestUserRefreshed() proton.User { + return proton.User{ + ID: TestUserID, + Name: "Alternate", + DisplayName: "Universe", + Email: "foo2@bar", + Keys: nil, + UsedSpace: 0, + MaxSpace: 0, + MaxUpload: 0, + Credit: 0, + Currency: "USD", + } +} + +func newTestAddresses() []proton.Address { + return []proton.Address{ + { + ID: "Address1", + Email: "foo@bar.com", + Status: proton.AddressStatusEnabled, + Type: 0, + Order: 0, + DisplayName: "", + Keys: nil, + }, + { + ID: "Address2", + Email: "foo2@bar.com", + Status: proton.AddressStatusDisabled, + Type: 0, + Order: 1, + DisplayName: "", + Keys: nil, + }, + } +} + +func newTestAddressesRefreshed() []proton.Address { + return []proton.Address{ + { + ID: "Address1", + Email: "foo@bar.com", + Status: proton.AddressStatusEnabled, + Type: 0, + Order: 2, + DisplayName: "FOo barish", + Keys: nil, + }, + { + ID: "Address2", + Email: "foo2@bar.com", + Status: proton.AddressStatusDisabled, + Type: 0, + Order: 4, + DisplayName: "New display name", + Keys: nil, + }, + } +} diff --git a/internal/services/useridentity/state.go b/internal/services/useridentity/state.go new file mode 100644 index 00000000..178c0d23 --- /dev/null +++ b/internal/services/useridentity/state.go @@ -0,0 +1,187 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package useridentity + +import ( + "context" + "fmt" + "strings" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" + "golang.org/x/exp/maps" +) + +// State holds all the required user identity state. The idea of this type is that +// it can be replicated across all services to avoid lock contention. The only +// requirement is that the service with the respective events. +type State struct { + AddressesSorted []proton.Address + Addresses map[string]proton.Address + User proton.User + + provider IdentityProvider +} + +func NewState( + user proton.User, + addresses []proton.Address, + provider IdentityProvider, +) State { + addressMap := buildAddressMapFromSlice(addresses) + return State{ + AddressesSorted: sortAddresses(maps.Values(addressMap)), + Addresses: addressMap, + User: user, + + provider: provider, + } +} + +func NewStateFromProvider(ctx context.Context, provider IdentityProvider) (State, error) { + user, err := provider.GetUser(ctx) + if err != nil { + return State{}, fmt.Errorf("failed to get user: %w", err) + } + + addresses, err := provider.GetAddresses(ctx) + if err != nil { + return State{}, fmt.Errorf("failed to get user addresses: %w", err) + } + + return NewState(user, addresses, provider), nil +} + +// GetAddr returns the address for the given email address. +func (s *State) GetAddr(email string) (proton.Address, error) { + for _, addr := range s.AddressesSorted { + if strings.EqualFold(addr.Email, usertypes.SanitizeEmail(email)) { + return addr, nil + } + } + + return proton.Address{}, fmt.Errorf("address %s not found", email) +} + +// GetPrimaryAddr returns the primary address for this user. +func (s *State) GetPrimaryAddr() (proton.Address, error) { + if len(s.AddressesSorted) == 0 { + return proton.Address{}, fmt.Errorf("no addresses available") + } + + return s.AddressesSorted[0], nil +} + +func (s *State) OnUserEvent(user proton.User) { + s.User = user +} + +func (s *State) OnRefreshEvent(ctx context.Context) error { + user, err := s.provider.GetUser(ctx) + if err != nil { + return fmt.Errorf("failed to get user:%w", err) + } + + addresses, err := s.provider.GetAddresses(ctx) + if err != nil { + return fmt.Errorf("failed to get addresses:%w", err) + } + + s.User = user + s.Addresses = buildAddressMapFromSlice(addresses) + s.AddressesSorted = sortAddresses(maps.Values(s.Addresses)) + + return nil +} + +func (s *State) OnUserSpaceChanged(value int) bool { + if s.User.UsedSpace == value { + return false + } + + s.User.UsedSpace = value + + return true +} + +type AddressUpdate int + +const ( + AddressUpdateNoop AddressUpdate = iota + AddressUpdateCreated + AddressUpdateEnabled + AddressUpdateDisabled + AddressUpdateUpdated + AddressUpdateDeleted +) + +func (s *State) OnAddressCreated(event proton.AddressEvent) AddressUpdate { + if _, ok := s.Addresses[event.Address.ID]; ok { + return AddressUpdateNoop + } + + s.Addresses[event.Address.ID] = event.Address + s.AddressesSorted = sortAddresses(maps.Values(s.Addresses)) + + if event.Address.Status != proton.AddressStatusEnabled { + return AddressUpdateNoop + } + + return AddressUpdateCreated +} + +func (s *State) OnAddressUpdated(event proton.AddressEvent) (proton.Address, AddressUpdate) { + // Address does not exist create it. + oldAddr, ok := s.Addresses[event.Address.ID] + if !ok { + return event.Address, s.OnAddressCreated(event) + } + + s.Addresses[event.Address.ID] = event.Address + s.AddressesSorted = sortAddresses(maps.Values(s.Addresses)) + + switch { + // If the address was newly enabled: + case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled: + return event.Address, AddressUpdateEnabled + + // If the address was newly disabled: + case oldAddr.Status == proton.AddressStatusEnabled && event.Address.Status != proton.AddressStatusEnabled: + return event.Address, AddressUpdateDisabled + + // Otherwise it's just an update: + default: + return event.Address, AddressUpdateUpdated + } +} + +func (s *State) OnAddressDeleted(event proton.AddressEvent) (proton.Address, AddressUpdate) { + addr, ok := s.Addresses[event.ID] + if !ok { + return proton.Address{}, AddressUpdateNoop + } + + delete(s.Addresses, event.ID) + s.AddressesSorted = sortAddresses(maps.Values(s.Addresses)) + + if addr.Status != proton.AddressStatusEnabled { + return proton.Address{}, AddressUpdateNoop + } + + return addr, AddressUpdateDeleted +}