diff --git a/internal/services/useridentity/auth.go b/internal/services/useridentity/auth.go index 7c101e9c..52b0e0aa 100644 --- a/internal/services/useridentity/auth.go +++ b/internal/services/useridentity/auth.go @@ -24,3 +24,15 @@ type KeyPassProvider interface { type BridgePassProvider interface { BridgePass() []byte } + +type FixedBridgePassProvider struct { + pass []byte +} + +func (f FixedBridgePassProvider) BridgePass() []byte { + return f.pass +} + +func NewFixedBridgePassProvider(pass []byte) *FixedBridgePassProvider { + return &FixedBridgePassProvider{pass: pass} +} diff --git a/internal/services/useridentity/mocks/mocks.go b/internal/services/useridentity/mocks/mocks.go index ce105d33..7c342639 100644 --- a/internal/services/useridentity/mocks/mocks.go +++ b/internal/services/useridentity/mocks/mocks.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity (interfaces: IdentityProvider) +// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity (interfaces: IdentityProvider,Telemetry) // Package mocks is a generated GoMock package. package mocks @@ -64,3 +64,38 @@ func (mr *MockIdentityProviderMockRecorder) GetUser(arg0 interface{}) *gomock.Ca mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockIdentityProvider)(nil).GetUser), arg0) } + +// MockTelemetry is a mock of Telemetry interface. +type MockTelemetry struct { + ctrl *gomock.Controller + recorder *MockTelemetryMockRecorder +} + +// MockTelemetryMockRecorder is the mock recorder for MockTelemetry. +type MockTelemetryMockRecorder struct { + mock *MockTelemetry +} + +// NewMockTelemetry creates a new mock instance. +func NewMockTelemetry(ctrl *gomock.Controller) *MockTelemetry { + mock := &MockTelemetry{ctrl: ctrl} + mock.recorder = &MockTelemetryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTelemetry) EXPECT() *MockTelemetryMockRecorder { + return m.recorder +} + +// ReportConfigStatusFailure mocks base method. +func (m *MockTelemetry) ReportConfigStatusFailure(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReportConfigStatusFailure", arg0) +} + +// ReportConfigStatusFailure indicates an expected call of ReportConfigStatusFailure. +func (mr *MockTelemetryMockRecorder) ReportConfigStatusFailure(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportConfigStatusFailure", reflect.TypeOf((*MockTelemetry)(nil).ReportConfigStatusFailure), arg0) +} diff --git a/internal/services/useridentity/service.go b/internal/services/useridentity/service.go index fb56a003..e2a042e2 100644 --- a/internal/services/useridentity/service.go +++ b/internal/services/useridentity/service.go @@ -21,13 +21,15 @@ import ( "context" "fmt" - "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/logging" + "github.com/ProtonMail/proton-bridge/v3/internal/services/orderedtasks" "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents" "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" + "github.com/ProtonMail/proton-bridge/v3/pkg/cpc" "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) @@ -39,6 +41,7 @@ type IdentityProvider interface { // Service contains all the data required to establish the user identity. This // includes all the user's information as well as mail addresses and keys. type Service struct { + cpc *cpc.CPC eventService userevents.Subscribable eventPublisher events.EventPublisher log *logrus.Entry @@ -48,16 +51,22 @@ type Service struct { addressSubscriber *userevents.AddressChanneledSubscriber usedSpaceSubscriber *userevents.UserUsedSpaceChanneledSubscriber refreshSubscriber *userevents.RefreshChanneledSubscriber + + bridgePassProvider BridgePassProvider + telemetry Telemetry } func NewService( service userevents.Subscribable, eventPublisher events.EventPublisher, state *State, + bridgePassProvider BridgePassProvider, + telemetry Telemetry, ) *Service { subscriberName := fmt.Sprintf("identity-%v", state.User.ID) return &Service{ + cpc: cpc.NewCPC(), eventService: service, identity: *state, eventPublisher: eventPublisher, @@ -69,12 +78,33 @@ func NewService( refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName), addressSubscriber: userevents.NewAddressSubscriber(subscriberName), usedSpaceSubscriber: userevents.NewUserUsedSpaceSubscriber(subscriberName), + bridgePassProvider: bridgePassProvider, + telemetry: telemetry, } } -func (s *Service) Start(group *async.Group) { - group.Once(func(ctx context.Context) { - s.run(ctx) +func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGroup) { + group.Go(ctx, s.identity.User.ID, "identity-service", s.run) +} + +func (s *Service) Resync(ctx context.Context) error { + _, err := s.cpc.Send(ctx, &resyncReq{}) + + return err +} + +func (s *Service) GetAPIUser(ctx context.Context) (proton.User, error) { + return cpc.SendTyped[proton.User](ctx, s.cpc, &getUserReq{}) +} + +func (s *Service) GetAddresses(ctx context.Context) (map[string]proton.Address, error) { + return cpc.SendTyped[map[string]proton.Address](ctx, s.cpc, &getAddressesReq{}) +} + +func (s *Service) CheckAuth(ctx context.Context, email string, password []byte) (string, error) { + return cpc.SendTyped[string](ctx, s.cpc, &checkAuthReq{ + email: email, + password: password, }) } @@ -82,14 +112,40 @@ func (s *Service) run(ctx context.Context) { s.log.WithFields(logrus.Fields{ "numAddr": len(s.identity.Addresses), }).Info("Starting user identity service") + defer s.log.Info("Exiting Service") s.registerSubscription() defer s.unregisterSubscription() + defer s.cpc.Close() + for { select { case <-ctx.Done(): return + case r, ok := <-s.cpc.ReceiveCh(): + if !ok { + continue + } + switch req := r.Value().(type) { + case *resyncReq: + err := s.identity.OnRefreshEvent(ctx) + r.Reply(ctx, nil, err) + + case *getUserReq: + r.Reply(ctx, s.identity.User, nil) + + case *getAddressesReq: + r.Reply(ctx, maps.Clone(s.identity.Addresses), nil) + + case *checkAuthReq: + id, err := s.identity.CheckAuth(req.email, req.password, s.bridgePassProvider, s.telemetry) + r.Reply(ctx, id, err) + + default: + s.log.Error("Invalid request") + } + case evt, ok := <-s.userSubscriber.OnEventCh(): if !ok { continue @@ -260,3 +316,14 @@ func sortAddresses(addr []proton.Address) []proton.Address { func buildAddressMapFromSlice(addr []proton.Address) map[string]proton.Address { return usertypes.GroupBy(addr, func(addr proton.Address) string { return addr.ID }) } + +type resyncReq struct{} + +type getUserReq struct{} + +type checkAuthReq struct { + email string + password []byte +} + +type getAddressesReq struct{} diff --git a/internal/services/useridentity/service_test.go b/internal/services/useridentity/service_test.go index 33d1bdeb..7c627b1b 100644 --- a/internal/services/useridentity/service_test.go +++ b/internal/services/useridentity/service_test.go @@ -361,8 +361,10 @@ func newTestService(_ *testing.T, mockCtrl *gomock.Controller) (*Service, *mocks eventPublisher := mocks2.NewMockEventPublisher(mockCtrl) provider := mocks.NewMockIdentityProvider(mockCtrl) user := newTestUser() + telemetry := mocks.NewMockTelemetry(mockCtrl) + bridgePassProvider := NewFixedBridgePassProvider([]byte("hello")) - service := NewService(subscribable, eventPublisher, NewState(user, newTestAddresses(), provider)) + service := NewService(subscribable, eventPublisher, NewState(user, newTestAddresses(), provider), bridgePassProvider, telemetry) return service, eventPublisher, provider } diff --git a/internal/services/useridentity/state.go b/internal/services/useridentity/state.go index 6d5bbdb4..5bb5ca44 100644 --- a/internal/services/useridentity/state.go +++ b/internal/services/useridentity/state.go @@ -21,14 +21,13 @@ import ( "context" "crypto/subtle" "fmt" - "github.com/ProtonMail/gopenpgp/v2/crypto" "strings" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/usertypes" "github.com/ProtonMail/proton-bridge/v3/pkg/algo" "golang.org/x/exp/maps" - "golang.org/x/exp/slices" ) // State holds all the required user identity state. The idea of this type is that @@ -82,6 +81,13 @@ func (s *State) GetAddr(email string) (proton.Address, error) { return proton.Address{}, fmt.Errorf("address %s not found", email) } +// GetAddrByID returns the address for the given addressID. +func (s *State) GetAddrByID(id string) (proton.Address, bool) { + v, ok := s.Addresses[id] + + return v, ok +} + // GetPrimaryAddr returns the primary address for this user. func (s *State) GetPrimaryAddr() (proton.Address, error) { if len(s.AddressesSorted) == 0 { @@ -204,9 +210,15 @@ func (s *State) OnAddressEvents(events []proton.AddressEvent) { } func (s *State) Clone() *State { + mapCopy := make(map[string]proton.Address, len(s.Addresses)) + sliceCopy := make([]proton.Address, len(s.AddressesSorted)) + + copy(sliceCopy, s.AddressesSorted) + maps.Copy(mapCopy, s.Addresses) + return &State{ - AddressesSorted: slices.Clone(s.AddressesSorted), - Addresses: maps.Clone(s.Addresses), + AddressesSorted: sliceCopy, + Addresses: mapCopy, User: s.User, provider: s.provider, }