mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-26 11:46:44 +00:00
GODT-2037: Handle and log API refresh event
This commit is contained in:
@ -35,6 +35,10 @@ import (
|
||||
|
||||
// handleAPIEvent handles the given liteapi.Event.
|
||||
func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error {
|
||||
if event.Refresh&liteapi.RefreshMail != 0 {
|
||||
return user.handleRefreshEvent(ctx)
|
||||
}
|
||||
|
||||
if event.User != nil {
|
||||
if err := user.handleUserEvent(ctx, *event.User); err != nil {
|
||||
return err
|
||||
@ -62,6 +66,54 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleRefreshEvent(ctx context.Context) error {
|
||||
user.log.Info("Handling refresh event")
|
||||
|
||||
// Cancel and restart ongoing syncs.
|
||||
user.abortable.Abort()
|
||||
defer user.goSync()
|
||||
|
||||
return safe.LockRet(func() error {
|
||||
// Fetch latest user info.
|
||||
apiUser, err := user.client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Fetch latest address info.
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
// Fetch latest label info.
|
||||
apiLabels, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem, liteapi.LabelTypeFolder, liteapi.LabelTypeLabel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get labels: %w", err)
|
||||
}
|
||||
|
||||
// Update the API info in the user.
|
||||
user.apiUser = apiUser
|
||||
user.apiAddrs = groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID })
|
||||
user.apiLabels = groupBy(apiLabels, func(label liteapi.Label) string { return label.ID })
|
||||
|
||||
// Reinitialize the update channels.
|
||||
user.initUpdateCh(user.vault.AddressMode())
|
||||
|
||||
// Clear sync status; we want to sync everything again.
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
// The user was refreshed.
|
||||
user.eventCh.Enqueue(events.UserRefreshed{
|
||||
UserID: user.apiUser.ID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
// handleUserEvent handles the given user event.
|
||||
func (user *User) handleUserEvent(_ context.Context, userEvent liteapi.User) error {
|
||||
return safe.LockRet(func() error {
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
@ -35,6 +36,9 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Verify that *imapConnector implements connector.Connector.
|
||||
var _ connector.Connector = (*imapConnector)(nil)
|
||||
|
||||
var (
|
||||
defaultFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals
|
||||
defaultPermanentFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals
|
||||
|
||||
@ -100,11 +100,6 @@ func New(
|
||||
return nil, fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
// Check we can unlock the keyrings.
|
||||
if _, _, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass()); err != nil {
|
||||
return nil, fmt.Errorf("failed to unlock user: %w", err)
|
||||
}
|
||||
|
||||
// Get the user's API labels.
|
||||
apiLabels, err := client.GetLabels(ctx, liteapi.LabelTypeSystem, liteapi.LabelTypeFolder, liteapi.LabelTypeLabel)
|
||||
if err != nil {
|
||||
@ -231,12 +226,10 @@ func New(
|
||||
user.log.Debug("Sync triggered")
|
||||
|
||||
user.abortable.Do(ctx, func(ctx context.Context) {
|
||||
if !user.vault.SyncStatus().IsComplete() {
|
||||
if err := user.doSync(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
user.log.Debug("Sync is already complete, skipping")
|
||||
} else if err := user.doSync(ctx); err != nil {
|
||||
user.log.WithError(err).Error("Failed to sync")
|
||||
}
|
||||
})
|
||||
})
|
||||
@ -300,30 +293,12 @@ func (user *User) GetAddressMode() vault.AddressMode {
|
||||
}
|
||||
|
||||
// SetAddressMode sets the user's address mode.
|
||||
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
|
||||
func (user *User) SetAddressMode(_ context.Context, mode vault.AddressMode) error {
|
||||
user.abortable.Abort()
|
||||
defer user.goSync()
|
||||
|
||||
return safe.LockRet(func() error {
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
}
|
||||
|
||||
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = primaryUpdateCh
|
||||
}
|
||||
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
}
|
||||
}
|
||||
user.initUpdateCh(mode)
|
||||
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
@ -620,6 +595,30 @@ func (user *User) SetShowAllMail(show bool) {
|
||||
atomic.StoreUint32(&user.showAllMail, b32(show))
|
||||
}
|
||||
|
||||
// initUpdateCh initializes the user's update channels in the given address mode.
|
||||
// It is assumed that user.apiAddrs and user.updateCh are already locked.
|
||||
func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
}
|
||||
|
||||
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = primaryUpdateCh
|
||||
}
|
||||
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// b32 returns a uint32 0 or 1 representing b.
|
||||
func b32(b bool) uint32 {
|
||||
if b {
|
||||
|
||||
@ -138,6 +138,23 @@ func TestUser_Deauth(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_Refresh(t *testing.T) {
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s, m, "username", "password", func(user *User) {
|
||||
// Get the event channel.
|
||||
eventCh := user.GetEventCh()
|
||||
|
||||
// Revoke the user's auth token.
|
||||
require.NoError(t, s.RefreshUser(user.ID(), liteapi.RefreshAll))
|
||||
|
||||
// The user should eventually be logged out.
|
||||
require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserRefreshed); return ok }, 5*time.Second, 100*time.Millisecond)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func withAPI(_ testing.TB, ctx context.Context, fn func(context.Context, *server.Server, *liteapi.Manager)) { //nolint:revive
|
||||
server := server.New()
|
||||
defer server.Close()
|
||||
|
||||
Reference in New Issue
Block a user