diff --git a/internal/services/useridentity/auth.go b/internal/services/useridentity/auth.go
new file mode 100644
index 00000000..7c101e9c
--- /dev/null
+++ b/internal/services/useridentity/auth.go
@@ -0,0 +1,26 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package useridentity
+
+type KeyPassProvider interface {
+ KeyPass() []byte
+}
+
+type BridgePassProvider interface {
+ BridgePass() []byte
+}
diff --git a/internal/services/useridentity/service.go b/internal/services/useridentity/service.go
index 4d90b20f..fb56a003 100644
--- a/internal/services/useridentity/service.go
+++ b/internal/services/useridentity/service.go
@@ -51,33 +51,25 @@ type Service struct {
}
func NewService(
- ctx context.Context,
service userevents.Subscribable,
- user proton.User,
eventPublisher events.EventPublisher,
- provider IdentityProvider,
-) (*Service, error) {
- addresses, err := provider.GetAddresses(ctx)
- if err != nil {
- return nil, fmt.Errorf("failed to get addresses: %w", err)
- }
-
- subscriberName := fmt.Sprintf("identity-%v", user.ID)
+ state *State,
+) *Service {
+ subscriberName := fmt.Sprintf("identity-%v", state.User.ID)
return &Service{
eventService: service,
- identity: NewState(user, addresses, provider),
+ identity: *state,
eventPublisher: eventPublisher,
log: logrus.WithFields(logrus.Fields{
"service": "user-identity",
- "user": user.ID,
+ "user": state.User.ID,
}),
-
userSubscriber: userevents.NewUserSubscriber(subscriberName),
refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName),
addressSubscriber: userevents.NewAddressSubscriber(subscriberName),
usedSpaceSubscriber: userevents.NewUserUsedSpaceSubscriber(subscriberName),
- }, nil
+ }
}
func (s *Service) Start(group *async.Group) {
diff --git a/internal/services/useridentity/service_test.go b/internal/services/useridentity/service_test.go
index 1d3021a5..33d1bdeb 100644
--- a/internal/services/useridentity/service_test.go
+++ b/internal/services/useridentity/service_test.go
@@ -356,17 +356,13 @@ func TestService_OnAddressDeletedUnknownDoesNotProduceEvent(t *testing.T) {
require.NoError(t, err)
}
-func newTestService(t *testing.T, mockCtrl *gomock.Controller) (*Service, *mocks2.MockEventPublisher, *mocks.MockIdentityProvider) {
+func newTestService(_ *testing.T, mockCtrl *gomock.Controller) (*Service, *mocks2.MockEventPublisher, *mocks.MockIdentityProvider) {
subscribable := &userevents.NoOpSubscribable{}
eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
provider := mocks.NewMockIdentityProvider(mockCtrl)
user := newTestUser()
- provider.EXPECT().GetAddresses(gomock.Any()).Times(1).Return(newTestAddresses(), nil)
-
- service, err := NewService(context.Background(), subscribable, user, eventPublisher, provider)
- require.NoError(t, err)
-
+ service := NewService(subscribable, eventPublisher, NewState(user, newTestAddresses(), provider))
return service, eventPublisher, provider
}
diff --git a/internal/services/useridentity/state.go b/internal/services/useridentity/state.go
index 178c0d23..ac8e90d9 100644
--- a/internal/services/useridentity/state.go
+++ b/internal/services/useridentity/state.go
@@ -19,12 +19,15 @@ package useridentity
import (
"context"
+ "crypto/subtle"
"fmt"
"strings"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
+ "github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"golang.org/x/exp/maps"
+ "golang.org/x/exp/slices"
)
// State holds all the required user identity state. The idea of this type is that
@@ -42,9 +45,9 @@ func NewState(
user proton.User,
addresses []proton.Address,
provider IdentityProvider,
-) State {
+) *State {
addressMap := buildAddressMapFromSlice(addresses)
- return State{
+ return &State{
AddressesSorted: sortAddresses(maps.Values(addressMap)),
Addresses: addressMap,
User: user,
@@ -53,15 +56,15 @@ func NewState(
}
}
-func NewStateFromProvider(ctx context.Context, provider IdentityProvider) (State, error) {
+func NewStateFromProvider(ctx context.Context, provider IdentityProvider) (*State, error) {
user, err := provider.GetUser(ctx)
if err != nil {
- return State{}, fmt.Errorf("failed to get user: %w", err)
+ return nil, fmt.Errorf("failed to get user: %w", err)
}
addresses, err := provider.GetAddresses(ctx)
if err != nil {
- return State{}, fmt.Errorf("failed to get user addresses: %w", err)
+ return nil, fmt.Errorf("failed to get user addresses: %w", err)
}
return NewState(user, addresses, provider), nil
@@ -185,3 +188,58 @@ func (s *State) OnAddressDeleted(event proton.AddressEvent) (proton.Address, Add
return addr, AddressUpdateDeleted
}
+
+func (s *State) OnAddressEvents(events []proton.AddressEvent) {
+ for _, evt := range events {
+ switch evt.Action {
+ case proton.EventCreate:
+ s.OnAddressCreated(evt)
+ case proton.EventUpdate, proton.EventUpdateFlags:
+ s.OnAddressUpdated(evt)
+ case proton.EventDelete:
+ s.OnAddressDeleted(evt)
+ }
+ }
+}
+
+func (s *State) Clone() *State {
+ return &State{
+ AddressesSorted: slices.Clone(s.AddressesSorted),
+ Addresses: maps.Clone(s.Addresses),
+ User: s.User,
+ provider: s.provider,
+ }
+}
+
+// CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user.
+// It returns the address ID of the authenticated address.
+func (s *State) CheckAuth(email string, password []byte, bridgePassProvider BridgePassProvider, telemetry Telemetry) (string, error) {
+ if email == "crash@bandicoot" {
+ panic("your wish is my command.. I crash")
+ }
+
+ dec, err := algo.B64RawDecode(password)
+ if err != nil {
+ return "", fmt.Errorf("failed to decode password: %w", err)
+ }
+
+ if subtle.ConstantTimeCompare(bridgePassProvider.BridgePass(), dec) != 1 {
+ err := fmt.Errorf("invalid password")
+ if telemetry != nil {
+ telemetry.ReportConfigStatusFailure(err.Error())
+ }
+ return "", err
+ }
+
+ for _, addr := range s.AddressesSorted {
+ if addr.Status != proton.AddressStatusEnabled {
+ continue
+ }
+
+ if strings.EqualFold(addr.Email, email) {
+ return addr.ID, nil
+ }
+ }
+
+ return "", fmt.Errorf("invalid email")
+}
diff --git a/internal/services/useridentity/telemetry.go b/internal/services/useridentity/telemetry.go
new file mode 100644
index 00000000..0400e1f2
--- /dev/null
+++ b/internal/services/useridentity/telemetry.go
@@ -0,0 +1,22 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package useridentity
+
+type Telemetry interface {
+ ReportConfigStatusFailure(errDetails string)
+}