Files
proton-bridge/internal/services/useridentity/service.go
Leander Beernaert 7be1a8ae8a feat(GODT-2801): Identity State Cloning & Auth Check
This patch moves the `user.User.CheckAuth` function into the `State`
type as it contains all the necessary information to perform this check.

A couple more interfaces are added to abstract the retrieval of the
Bridge and the User's key passwords, as well as telemetry reports.

Finally, this patch adds `State.OnAddressEvents`, which other services should
use to update the state.
2023-07-24 17:10:05 +02:00

263 lines
7.4 KiB
Go

// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package useridentity
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
// IdentityProvider abstracts the retrieval of a user's account record and
// mail addresses.
// NOTE(review): neither an implementation nor a call site is visible in this
// file; presumably this is satisfied by the Proton API client — confirm.
type IdentityProvider interface {
// GetUser returns the user record, or an error if it could not be retrieved.
GetUser(ctx context.Context) (proton.User, error)
// GetAddresses returns the user's addresses, or an error if they could not
// be retrieved.
GetAddresses(ctx context.Context) ([]proton.Address, error)
}
// Service contains all the data required to establish the user identity. This
// includes all the user's information as well as mail addresses and keys.
//
// The service listens for user, address, used-space and refresh events, applies
// them to its identity State, and publishes the resulting bridge events.
type Service struct {
// eventService is where the subscribers below are registered while the
// service runs, and unregistered when it stops.
eventService userevents.Subscribable
// eventPublisher receives the bridge events emitted by the handlers.
eventPublisher events.EventPublisher
// log is tagged with the service name and the user ID.
log *logrus.Entry
// identity is the service's own copy of the identity state (stored by value).
identity State
// One channel-backed subscriber per kind of event the service handles.
userSubscriber *userevents.UserChanneledSubscriber
addressSubscriber *userevents.AddressChanneledSubscriber
usedSpaceSubscriber *userevents.UserUsedSpaceChanneledSubscriber
refreshSubscriber *userevents.RefreshChanneledSubscriber
}
// NewService builds an identity service for the user described by state.
//
// The state is dereferenced and stored by value, so state must not be nil and
// the service works on its own copy of the State struct. All subscribers share
// the name "identity-<user ID>". Nothing is subscribed until the service is
// started.
func NewService(
	service userevents.Subscribable,
	eventPublisher events.EventPublisher,
	state *State,
) *Service {
	name := fmt.Sprintf("identity-%v", state.User.ID)

	svc := &Service{
		eventService:   service,
		identity:       *state,
		eventPublisher: eventPublisher,
	}

	svc.log = logrus.WithFields(logrus.Fields{
		"service": "user-identity",
		"user":    state.User.ID,
	})

	// Subscribers are created in the same order as before.
	svc.userSubscriber = userevents.NewUserSubscriber(name)
	svc.refreshSubscriber = userevents.NewRefreshSubscriber(name)
	svc.addressSubscriber = userevents.NewAddressSubscriber(name)
	svc.usedSpaceSubscriber = userevents.NewUserUsedSpaceSubscriber(name)

	return svc
}
// Start hands the service's event loop to group, which runs it via Once with
// the context the group supplies.
func (s *Service) Start(group *async.Group) {
	// The method value s.run has the func(context.Context) shape Once expects.
	group.Once(s.run)
}
// run is the service's event loop. It registers the service's subscribers,
// dispatches every incoming user, refresh, used-space and address event to the
// matching handler, and unregisters the subscribers when ctx is cancelled.
//
// When a subscriber channel is closed, its local variable is set to nil. A
// receive from a closed channel never blocks, so retrying it would make the
// loop spin at full CPU until ctx is cancelled; a nil channel is never ready,
// which simply disables that select case.
func (s *Service) run(ctx context.Context) {
	s.log.WithFields(logrus.Fields{
		"numAddr": len(s.identity.Addresses),
	}).Info("Starting user identity service")

	s.registerSubscription()
	defer s.unregisterSubscription()

	// Resolve the channels once, after subscribing.
	// NOTE(review): assumes OnEventCh returns the same channel on every call —
	// confirm against the subscriber implementation.
	userCh := s.userSubscriber.OnEventCh()
	refreshCh := s.refreshSubscriber.OnEventCh()
	usedSpaceCh := s.usedSpaceSubscriber.OnEventCh()
	addressCh := s.addressSubscriber.OnEventCh()

	for {
		select {
		case <-ctx.Done():
			return

		case evt, ok := <-userCh:
			if !ok {
				userCh = nil // closed: stop selecting on it

				continue
			}

			evt.Consume(func(user proton.User) error {
				s.onUserEvent(ctx, user)

				return nil
			})

		case evt, ok := <-refreshCh:
			if !ok {
				refreshCh = nil // closed: stop selecting on it

				continue
			}

			evt.Consume(func(_ proton.RefreshFlag) error {
				return s.onRefreshEvent(ctx)
			})

		case evt, ok := <-usedSpaceCh:
			if !ok {
				usedSpaceCh = nil // closed: stop selecting on it

				continue
			}

			evt.Consume(func(usedSpace int) error {
				s.onUserSpaceChanged(ctx, usedSpace)

				return nil
			})

		case evt, ok := <-addressCh:
			if !ok {
				addressCh = nil // closed: stop selecting on it

				continue
			}

			evt.Consume(func(events []proton.AddressEvent) error {
				return s.onAddressEvent(ctx, events)
			})
		}
	}
}
// registerSubscription subscribes the service's refresh, user, address and
// used-space subscribers with the event service.
func (s *Service) registerSubscription() {
	subscription := userevents.Subscription{
		Refresh:       s.refreshSubscriber,
		User:          s.userSubscriber,
		Address:       s.addressSubscriber,
		UserUsedSpace: s.usedSpaceSubscriber,
	}

	s.eventService.Subscribe(subscription)
}
// unregisterSubscription removes the service's refresh, user, address and
// used-space subscribers from the event service.
func (s *Service) unregisterSubscription() {
	subscription := userevents.Subscription{
		Refresh:       s.refreshSubscriber,
		User:          s.userSubscriber,
		Address:       s.addressSubscriber,
		UserUsedSpace: s.usedSpaceSubscriber,
	}

	s.eventService.Unsubscribe(subscription)
}
// onUserEvent logs the event, hands the new user record to the identity state
// and publishes a UserChanged event carrying the ID of the received user.
func (s *Service) onUserEvent(ctx context.Context, user proton.User) {
	// The user name is wrapped in logging.Sensitive before being logged.
	entry := s.log.WithField("username", logging.Sensitive(user.Name))
	entry.Info("Handling user event")

	s.identity.OnUserEvent(user)

	changed := events.UserChanged{UserID: user.ID}
	s.eventPublisher.PublishEvent(ctx, changed)
}
// onRefreshEvent asks the identity state to handle a refresh.
//
// If the state reports an error, it is logged and returned and nothing is
// published. Otherwise a UserRefreshed event (with CancelEventPool set to
// false) is published for the current user and nil is returned.
func (s *Service) onRefreshEvent(ctx context.Context) error {
	s.log.Info("Handling refresh event")

	err := s.identity.OnRefreshEvent(ctx)
	if err != nil {
		s.log.WithError(err).Error("Failed to handle refresh event")

		return err
	}

	refreshed := events.UserRefreshed{
		UserID:          s.identity.User.ID,
		CancelEventPool: false,
	}
	s.eventPublisher.PublishEvent(ctx, refreshed)

	return nil
}
// onUserSpaceChanged forwards the new used-space value to the identity state
// and publishes a UsedSpaceChanged event only when the state returns true.
func (s *Service) onUserSpaceChanged(ctx context.Context, value int) {
	s.log.Info("Handling User Space Changed event")

	if changed := s.identity.OnUserSpaceChanged(value); changed {
		s.eventPublisher.PublishEvent(ctx, events.UsedSpaceChanged{
			UserID:    s.identity.User.ID,
			UsedSpace: value,
		})
	}
}
// onAddressEvent applies a batch of address events to the identity state, in
// slice order, and publishes one bridge event for every change the state
// reports:
//
//   - create: UserAddressCreated when the state reports AddressUpdateCreated;
//   - update / update-flags: UserAddressCreated, UserAddressUpdated,
//     UserAddressDisabled or UserAddressEnabled depending on the reported
//     status; a no-op publishes nothing and a "deleted" status is only logged
//     as unexpected;
//   - delete: UserAddressDeleted when the state reports AddressUpdateDeleted.
//
// Events with any other action are ignored. The function always returns nil.
func (s *Service) onAddressEvent(ctx context.Context, addressEvents []proton.AddressEvent) error {
	s.log.Infof("Handling Address Events (%v)", len(addressEvents))

	for i := range addressEvents {
		evt := addressEvents[i]

		switch evt.Action {
		case proton.EventCreate:
			s.log.WithFields(logrus.Fields{
				"index":     i,
				"addressID": evt.ID,
				"email":     logging.Sensitive(evt.Address.Email),
			}).Info("Handling address created event")

			if s.identity.OnAddressCreated(evt) != AddressUpdateCreated {
				continue
			}

			s.eventPublisher.PublishEvent(ctx, events.UserAddressCreated{
				UserID:    s.identity.User.ID,
				AddressID: evt.Address.ID,
				Email:     evt.Address.Email,
			})

		case proton.EventUpdate, proton.EventUpdateFlags:
			addr, status := s.identity.OnAddressUpdated(evt)

			switch status {
			case AddressUpdateNoop:
				// Nothing changed, nothing to publish.

			case AddressUpdateDeleted:
				// An update should never be reported as a deletion.
				s.log.Warnf("Unexpected address update status after update event %v", status)

			case AddressUpdateCreated:
				s.eventPublisher.PublishEvent(ctx, events.UserAddressCreated{
					UserID:    s.identity.User.ID,
					AddressID: addr.ID,
					Email:     addr.Email,
				})

			case AddressUpdateUpdated:
				s.eventPublisher.PublishEvent(ctx, events.UserAddressUpdated{
					UserID:    s.identity.User.ID,
					AddressID: addr.ID,
					Email:     addr.Email,
				})

			case AddressUpdateDisabled:
				s.eventPublisher.PublishEvent(ctx, events.UserAddressDisabled{
					UserID:    s.identity.User.ID,
					AddressID: addr.ID,
					Email:     addr.Email,
				})

			case AddressUpdateEnabled:
				s.eventPublisher.PublishEvent(ctx, events.UserAddressEnabled{
					UserID:    s.identity.User.ID,
					AddressID: addr.ID,
					Email:     addr.Email,
				})
			}

		case proton.EventDelete:
			addr, status := s.identity.OnAddressDeleted(evt)
			if status != AddressUpdateDeleted {
				continue
			}

			// The ID comes from the event; the email comes from the address
			// returned by the state.
			s.eventPublisher.PublishEvent(ctx, events.UserAddressDeleted{
				UserID:    s.identity.User.ID,
				AddressID: evt.ID,
				Email:     addr.Email,
			})
		}
	}

	return nil
}
// sortAddresses orders addr by ascending Order using slices.SortFunc and
// returns the same slice, so the caller's slice is reordered as well.
func sortAddresses(addr []proton.Address) []proton.Address {
	byOrder := func(lhs, rhs proton.Address) bool {
		return lhs.Order < rhs.Order
	}

	slices.SortFunc(addr, byOrder)

	return addr
}
// buildAddressMapFromSlice indexes the given addresses by their ID using
// usertypes.GroupBy.
// NOTE(review): how GroupBy treats two addresses with the same ID is not
// visible from this file — confirm if duplicates are possible.
func buildAddressMapFromSlice(addr []proton.Address) map[string]proton.Address {
	addressID := func(a proton.Address) string {
		return a.ID
	}

	return usertypes.GroupBy(addr, addressID)
}