Other: Safer user types

This commit is contained in:
James Houlahan
2022-10-12 00:20:04 +02:00
parent 4dc32dc7f2
commit fd63611b41
35 changed files with 1253 additions and 771 deletions

View File

@ -1,19 +1,18 @@
package user
import (
"bytes"
"context"
"encoding/hex"
"crypto/subtle"
"fmt"
"strings"
"time"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gluon/wait"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/try"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
@ -32,15 +31,11 @@ type User struct {
eventCh *queue.QueuedChannel[events.Event]
apiUser *safe.Value[liteapi.User]
apiAddrs *safe.Slice[liteapi.Address]
settings *safe.Value[liteapi.MailSettings]
apiAddrs *safe.Map[string, liteapi.Address]
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
userKR *safe.Value[*crypto.KeyRing]
addrKRs *safe.Map[string, *crypto.KeyRing]
updateCh map[string]*queue.QueuedChannel[imap.Update]
syncStopCh chan struct{}
syncWG wait.Group
syncLock try.Group
}
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User) (*User, error) {
@ -50,9 +45,8 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
return nil, fmt.Errorf("failed to get addresses: %w", err)
}
// Unlock the user's keyrings.
userKR, addrKRs, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass())
if err != nil {
// Check we can unlock the keyrings.
if _, _, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass()); err != nil {
return nil, fmt.Errorf("failed to unlock user: %w", err)
}
@ -68,20 +62,21 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
}
}
// Get the user's mail settings.
settings, err := client.GetMailSettings(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get mail settings: %w", err)
}
// Create update channels for each of the user's addresses (if in combined mode, just the primary).
// Create update channels for each of the user's addresses.
// In combined mode, the addresses all share the same update channel.
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
for _, addr := range apiAddrs {
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
switch encVault.AddressMode() {
case vault.CombinedMode:
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
if encVault.AddressMode() == vault.CombinedMode {
break
for _, addr := range apiAddrs {
updateCh[addr.ID] = primaryUpdateCh
}
case vault.SplitMode:
for _, addr := range apiAddrs {
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
}
}
@ -91,19 +86,15 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
apiUser: safe.NewValue(apiUser),
apiAddrs: safe.NewSlice(apiAddrs),
settings: safe.NewValue(settings),
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
updateCh: safe.NewMapFrom(updateCh, nil),
userKR: safe.NewValue(userKR),
addrKRs: safe.NewMap(addrKRs),
updateCh: updateCh,
syncStopCh: make(chan struct{}),
}
// When we receive an auth object, we update it in the vault.
// This will be used to authorize the user on the next run.
client.AddAuthHandler(func(auth liteapi.Auth) {
user.client.AddAuthHandler(func(auth liteapi.Auth) {
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
logrus.WithError(err).Error("Failed to update auth in vault")
}
@ -111,23 +102,24 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
// When we are deauthorized, we send a deauth event to the event channel.
// Bridge will react to this event by logging out the user.
client.AddDeauthHandler(func() {
user.client.AddDeauthHandler(func() {
user.eventCh.Enqueue(events.UserDeauth{
UserID: user.ID(),
})
})
// TODO: Don't start the event loop until the initial sync has finished!
eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID())
// If we haven't synced yet, do it first.
// If it fails, we don't start the event loop.
// Otherwise, begin processing API events, logging any errors that occur.
go func() {
if status := user.vault.SyncStatus(); !status.HasMessages {
if err := <-user.startSync(); err != nil {
return
}
if err := <-user.startSync(); err != nil {
return
}
for err := range user.streamEvents() {
for err := range user.streamEvents(eventCh) {
logrus.WithError(err).Error("Error while streaming events")
}
}()
@ -137,40 +129,34 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
// ID returns the user's ID.
func (user *User) ID() string {
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string {
return apiUser.ID
})
}
// Name returns the user's username.
func (user *User) Name() string {
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string {
return apiUser.Name
})
}
// Match matches the given query against the user's username and email addresses.
func (user *User) Match(query string) bool {
return safe.GetType(user.apiUser, func(apiUser liteapi.User) bool {
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) bool {
if query == apiUser.Name {
return true
}
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) bool {
if query == apiUser.Name {
return true
}
for _, addr := range apiAddrs {
if addr.Email == query {
return true
}
}
return false
return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool {
return addr.Email == query
})
})
}
// Emails returns all the user's email addresses.
// Emails returns all the user's email addresses via the callback.
func (user *User) Emails() []string {
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
return safe.MapValuesRet(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
return addr.Email
})
@ -184,28 +170,38 @@ func (user *User) GetAddressMode() vault.AddressMode {
// SetAddressMode sets the user's address mode.
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
for _, updateCh := range user.updateCh {
updateCh.Close()
}
user.stopSync()
user.lockSync()
defer user.unlockSync()
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
for _, addr := range apiAddrs {
user.updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
if mode == vault.CombinedMode {
break
}
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
for _, updateCh := range xslices.Unique(updateCh) {
updateCh.Close()
}
})
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
switch mode {
case vault.CombinedMode:
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
user.apiAddrs.IterKeys(func(addrID string) {
updateCh[addrID] = primaryUpdateCh
})
case vault.SplitMode:
user.apiAddrs.IterKeys(func(addrID string) {
updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
})
}
user.updateCh = safe.NewMapFrom(updateCh, nil)
if err := user.vault.SetAddressMode(mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
user.stopSync()
if err := user.vault.ClearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
}
@ -246,25 +242,19 @@ func (user *User) GluonKey() []byte {
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
func (user *User) BridgePass() []byte {
buf := new(bytes.Buffer)
if _, err := hex.NewEncoder(buf).Write(user.vault.BridgePass()); err != nil {
panic(err)
}
return buf.Bytes()
return hexEncode(user.vault.BridgePass())
}
// UsedSpace returns the total space used by the user on the API.
func (user *User) UsedSpace() int {
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int {
return apiUser.UsedSpace
})
}
// MaxSpace returns the amount of space the user can use on the API.
func (user *User) MaxSpace() int {
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int {
return apiUser.MaxSpace
})
}
@ -275,37 +265,9 @@ func (user *User) GetEventCh() <-chan events.Event {
}
// NewIMAPConnector returns an IMAP connector for the given address.
// If not in split mode, this function returns an error.
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
return safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (connector.Connector, error) {
var emails []string
switch user.vault.AddressMode() {
case vault.CombinedMode:
if addrID != apiAddrs[0].ID {
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
}
emails = xslices.Map(apiAddrs, func(addr liteapi.Address) string {
return addr.Email
})
case vault.SplitMode:
email, err := getAddrEmail(apiAddrs, addrID)
if err != nil {
return nil, err
}
emails = []string{email}
}
return newIMAPConnector(
user.client,
user.updateCh[addrID].GetChannel(),
user.BridgePass(),
emails...,
), nil
})
// If not in split mode, this must be the primary address.
func (user *User) NewIMAPConnector(addrID string) connector.Connector {
return newIMAPConnector(user, addrID)
}
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
@ -314,23 +276,48 @@ func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
imapConn := make(map[string]connector.Connector)
for addrID := range user.updateCh {
conn, err := user.NewIMAPConnector(addrID)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP connector: %w", err)
}
switch user.vault.AddressMode() {
case vault.CombinedMode:
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
imapConn[addrID] = newIMAPConnector(user, addrID)
})
imapConn[addrID] = conn
case vault.SplitMode:
user.apiAddrs.IterKeys(func(addrID string) {
imapConn[addrID] = newIMAPConnector(user, addrID)
})
}
return imapConn, nil
}
// NewSMTPSession returns an SMTP session for the user.
func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
func (user *User) NewSMTPSession(email string, password []byte) (smtp.Session, error) {
if _, err := user.checkAuth(email, password); err != nil {
return nil, err
}
return newSMTPSession(user, email)
}
// OnStatusUp is called when the connection goes up.
func (user *User) OnStatusUp() {
go func() {
logrus.Info("Connection up, checking if sync is needed")
if err := <-user.startSync(); err != nil {
logrus.WithError(err).Error("Failed to sync on status up")
}
}()
}
// OnStatusDown is called when the connection goes down.
func (user *User) OnStatusDown() {
logrus.Info("Connection down, aborting any ongoing syncs")
user.stopSync()
}
// Logout logs the user out from the API.
// If withVault is true, the user's vault is also cleared.
func (user *User) Logout(ctx context.Context) error {
@ -350,13 +337,18 @@ func (user *User) Close() error {
// Cancel ongoing syncs.
user.stopSync()
// Wait for ongoing syncs to stop.
user.waitSync()
// Close the user's API client.
user.client.Close()
// Close the user's update channels.
for _, updateCh := range user.updateCh {
updateCh.Close()
}
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
for _, updateCh := range xslices.Unique(updateCh) {
updateCh.Close()
}
})
// Close the user's notify channel.
user.eventCh.Close()
@ -364,16 +356,37 @@ func (user *User) Close() error {
return nil
}
func (user *User) checkAuth(email string, password []byte) (string, error) {
dec, err := hexDecode(password)
if err != nil {
return "", fmt.Errorf("failed to decode password: %w", err)
}
if subtle.ConstantTimeCompare(user.vault.BridgePass(), dec) != 1 {
return "", fmt.Errorf("invalid password")
}
return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
for _, addr := range apiAddrs {
if addr.Email == strings.ToLower(email) {
return addr.ID, nil
}
}
return "", fmt.Errorf("invalid email")
})
}
// streamEvents begins streaming API events for the user.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
func (user *User) streamEvents() <-chan error {
func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
errCh := make(chan error)
go func() {
defer close(errCh)
for event := range user.client.NewEventStreamer(EventPeriod, EventJitter, user.vault.EventID()).Subscribe() {
for event := range eventCh {
if err := user.handleAPIEvent(context.Background(), event); err != nil {
errCh <- fmt.Errorf("failed to handle API event: %w", err)
} else if err := user.vault.SetEventID(event.EventID); err != nil {
@ -387,11 +400,21 @@ func (user *User) streamEvents() <-chan error {
// startSync begins a startSync for the user.
func (user *User) startSync() <-chan error {
if user.vault.SyncStatus().IsComplete() {
logrus.Debug("Already synced, skipping")
return nil
}
errCh := make(chan error)
user.syncWG.Go(func() {
user.syncLock.GoTry(func(ok bool) {
defer close(errCh)
if !ok {
logrus.Debug("Sync already in progress, skipping")
return
}
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
defer cancel()
@ -421,46 +444,24 @@ func (user *User) startSync() <-chan error {
func (user *User) stopSync() {
select {
case user.syncStopCh <- struct{}{}:
user.syncWG.Wait()
logrus.Debug("Sent sync abort signal")
default:
// ...
logrus.Debug("No sync to abort")
}
}
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
for _, addr := range apiAddrs {
if addr.Email == email {
return addr.ID, nil
}
}
return "", fmt.Errorf("address %s not found", email)
// lockSync prevents a new sync from starting.
func (user *User) lockSync() {
user.syncLock.Lock()
}
func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
for _, addr := range apiAddrs {
if addr.ID == addrID {
return addr.Email, nil
}
}
return "", fmt.Errorf("address %s not found", addrID)
// unlockSync allows a new sync to start.
func (user *User) unlockSync() {
user.syncLock.Unlock()
}
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
go func() {
select {
case <-stopCh:
cancel()
case <-ctx.Done():
// ...
}
}()
return ctx, cancel
// waitSync waits for any ongoing sync to finish.
func (user *User) waitSync() {
user.syncLock.Wait()
}