GODT-2223: Handle bad events by logging user out

This commit is contained in:
James Houlahan
2023-01-17 15:11:13 +01:00
committed by Leander Beernaert
parent 70f0384cc3
commit 849c8bee78
4 changed files with 80 additions and 23 deletions

View File

@ -51,6 +51,9 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
case events.UserDeauth:
bridge.handleUserDeauth(ctx, user)
case events.UserBadEvent:
bridge.handleUserBadEvent(ctx, user)
}
return nil
@ -130,3 +133,9 @@ func (bridge *Bridge) handleUserDeauth(ctx context.Context, user *user.User) {
bridge.logoutUser(ctx, user, false, false)
}, bridge.usersLock)
}
func (bridge *Bridge) handleUserBadEvent(ctx context.Context, user *user.User) {
safe.Lock(func() {
bridge.logoutUser(ctx, user, true, false)
}, bridge.usersLock)
}

View File

@ -23,6 +23,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
)
// AllUsersLoaded is emitted when all users have been loaded.
type AllUsersLoaded struct {
eventBase
}
@ -31,6 +32,7 @@ func (event AllUsersLoaded) String() string {
return "AllUsersLoaded"
}
// UserLoading is emitted when a user is being loaded.
type UserLoading struct {
eventBase
@ -41,6 +43,7 @@ func (event UserLoading) String() string {
return fmt.Sprintf("UserLoading: UserID: %s", event.UserID)
}
// UserLoadSuccess is emitted when a user has been loaded successfully.
type UserLoadSuccess struct {
eventBase
@ -51,6 +54,7 @@ func (event UserLoadSuccess) String() string {
return fmt.Sprintf("UserLoadSuccess: UserID: %s", event.UserID)
}
// UserLoadFail is emitted when a user has failed to load.
type UserLoadFail struct {
eventBase
@ -62,6 +66,7 @@ func (event UserLoadFail) String() string {
return fmt.Sprintf("UserLoadFail: UserID: %s, Error: %s", event.UserID, event.Error)
}
// UserLoggedIn is emitted when a user has logged in.
type UserLoggedIn struct {
eventBase
@ -72,6 +77,7 @@ func (event UserLoggedIn) String() string {
return fmt.Sprintf("UserLoggedIn: UserID: %s", event.UserID)
}
// UserLoggedOut is emitted when a user has logged out.
type UserLoggedOut struct {
eventBase
@ -82,6 +88,7 @@ func (event UserLoggedOut) String() string {
return fmt.Sprintf("UserLoggedOut: UserID: %s", event.UserID)
}
// UserDeauth is emitted when a user has lost its API authentication.
type UserDeauth struct {
eventBase
@ -92,6 +99,19 @@ func (event UserDeauth) String() string {
return fmt.Sprintf("UserDeauth: UserID: %s", event.UserID)
}
// UserBadEvent is emitted when a user cannot apply an event.
type UserBadEvent struct {
eventBase
UserID string
Error error
}
func (event UserBadEvent) String() string {
return fmt.Sprintf("UserBadEvent: UserID: %s, Error: %s", event.UserID, event.Error)
}
// UserDeleted is emitted when a user has been deleted.
type UserDeleted struct {
eventBase
@ -102,6 +122,7 @@ func (event UserDeleted) String() string {
return fmt.Sprintf("UserDeleted: UserID: %s", event.UserID)
}
// UserChanged is emitted when a user's data has changed (name, email, etc.).
type UserChanged struct {
eventBase
@ -112,6 +133,7 @@ func (event UserChanged) String() string {
return fmt.Sprintf("UserChanged: UserID: %s", event.UserID)
}
// UserRefreshed is emitted when an API refresh was issued for a user.
type UserRefreshed struct {
eventBase
@ -122,6 +144,7 @@ func (event UserRefreshed) String() string {
return fmt.Sprintf("UserRefreshed: UserID: %s", event.UserID)
}
// AddressModeChanged is emitted when a user's address mode has changed.
type AddressModeChanged struct {
eventBase

View File

@ -670,9 +670,8 @@ func getMailboxName(label proton.Label) []string {
func waitOnIMAPUpdates(ctx context.Context, updates []imap.Update) error {
for _, update := range updates {
err, ok := update.WaitContext(ctx)
if ok && err != nil {
return fmt.Errorf("failed to apply gluon update %v :%w", update.String(), err)
if err, ok := update.WaitContext(ctx); ok && err != nil {
return fmt.Errorf("failed to apply gluon update %v: %w", update.String(), err)
}
}

View File

@ -20,6 +20,7 @@ package user
import (
"context"
"crypto/subtle"
"errors"
"fmt"
"io"
"strings"
@ -552,29 +553,54 @@ func (user *User) doEventPoll(ctx context.Context) error {
return fmt.Errorf("failed to get event: %w", err)
}
if event.EventID != user.vault.EventID() {
user.log.WithFields(logrus.Fields{
"old": user.vault.EventID(),
"new": event,
}).Info("Received new API event")
// Handle the event.
if err := user.handleAPIEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle event: %w", err)
}
user.log.WithField("event", event).Debug("Handled API event")
// Update the event ID in the vault.
if err := user.vault.SetEventID(event.EventID); err != nil {
return fmt.Errorf("failed to update event ID: %w", err)
}
user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
} else {
// If the event ID hasn't changed, there are no new events.
if event.EventID == user.vault.EventID() {
user.log.Debug("No new API events")
return nil
}
user.log.WithFields(logrus.Fields{
"old": user.vault.EventID(),
"new": event,
}).Info("Received new API event")
// Handle the event.
if err := user.handleAPIEvent(ctx, event); err != nil {
// If the error is a network error, return error to retry later.
if netErr := new(proton.NetError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issue: %w", err)
}
// If the error is a server-side issue, return error to retry later.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
return fmt.Errorf("failed to handle event due to server error: %w", err)
}
// Otherwise, the error is a client-side issue; notify bridge to handle it.
user.log.WithField("event", event).Warn("Failed to handle API event")
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to handle event due to client error: %w", err)
}
user.log.WithField("event", event).Debug("Handled API event")
// Update the event ID in the vault. If this fails, notify bridge to handle it.
if err := user.vault.SetEventID(event.EventID); err != nil {
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to update event ID: %w", err)
}
user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
return nil
}