diff --git a/go.mod b/go.mod index b08536fe..7c2f61e0 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/stretchr/testify v1.8.0 github.com/urfave/cli/v2 v2.20.3 github.com/vmihailenco/msgpack/v5 v5.3.5 - gitlab.protontech.ch/go/liteapi v0.41.2 + gitlab.protontech.ch/go/liteapi v0.41.3-0.20221111021557-10de395a8f9f go.uber.org/goleak v1.2.0 golang.org/x/exp v0.0.0-20221023144134-a1e5550cf13e golang.org/x/net v0.1.0 diff --git a/go.sum b/go.sum index ca74a7fb..cf497abf 100644 --- a/go.sum +++ b/go.sum @@ -403,8 +403,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.41.2 h1:IJ/KmzJ5WzyofeME5UA+ib0sLLN3WkQctLZXLmI29xQ= -gitlab.protontech.ch/go/liteapi v0.41.2/go.mod h1:IM7ADWjgIL2hXopzx0WNamizEuMgM2QZl7QH12FNflk= +gitlab.protontech.ch/go/liteapi v0.41.3-0.20221111021557-10de395a8f9f h1:Vk8CdHAQTxYWhmvLHWbQSpTLW0Dj9SxqWdSWUr4fInA= +gitlab.protontech.ch/go/liteapi v0.41.3-0.20221111021557-10de395a8f9f/go.mod h1:IM7ADWjgIL2hXopzx0WNamizEuMgM2QZl7QH12FNflk= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 2b1754cc..bd14e05a 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -20,6 +20,7 @@ package bridge_test import ( "context" "crypto/tls" + "fmt" "net/http" "os" "sync" @@ -536,3 +537,22 @@ func getConnectedUserIDs(t *testing.T, bridge *bridge.Bridge) []string { return info.Connected }) } + +func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) { + outCh := make(chan Out) + + go func() { + defer close(outCh) + + for in := range inCh { + out, ok := any(in).(Out) + if !ok { + panic(fmt.Sprintf("unexpected type %T", in)) + } + + outCh <- out + } + }() + + return outCh, done +} diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index 5b06de25..66226557 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -203,22 +203,3 @@ func countBytesRead(ctl *liteapi.NetCtl, fn func()) uint64 { return read } - -func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) { - outCh := make(chan Out) - - go func() { - defer close(outCh) - - for in := range inCh { - out, ok := any(in).(Out) - if !ok { - panic(fmt.Sprintf("unexpected type %T", in)) - } - - outCh <- out - } - }() - - return outCh, done -} diff --git a/internal/bridge/user.go b/internal/bridge/user.go index aba223c5..71a75db7 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -257,14 +257,8 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va return fmt.Errorf("address mode is already %q", mode) } - for addrID, gluonID := range user.GetGluonIDs() { - if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil { - return fmt.Errorf("failed to remove user from IMAP server: %w", err) - } - - if err := user.RemoveGluonID(addrID, gluonID); err != nil { - return fmt.Errorf("failed to remove gluon ID from user: %w", err) - } + if err := bridge.removeIMAPUser(ctx, user, true); err != nil { + return fmt.Errorf("failed to remove IMAP user: %w", err) } if err := user.SetAddressMode(ctx, mode); err != nil { diff --git a/internal/bridge/user_events.go b/internal/bridge/user_events.go index f91a54d8..4fdf0d3a 100644 --- a/internal/bridge/user_events.go +++ b/internal/bridge/user_events.go @@ -44,11 +44,13 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even return fmt.Errorf("failed to handle user address deleted event: %w", err) } + case events.UserRefreshed: + if err := bridge.handleUserRefreshed(ctx, user); err != nil { + return fmt.Errorf("failed to handle user refreshed event: %w", err) + } + case events.UserDeauth: - safe.Lock(func() { - defer delete(bridge.users, user.ID()) - bridge.logoutUser(ctx, user, false, false) - }, bridge.usersLock) + bridge.handleUserDeauth(ctx, user) } return nil @@ -100,3 +102,24 @@ func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.U return nil } + +func (bridge *Bridge) handleUserRefreshed(ctx context.Context, user *user.User) error { + return safe.RLockRet(func() error { + if err := bridge.removeIMAPUser(ctx, user, true); err != nil { + return fmt.Errorf("failed to remove IMAP user: %w", err) + } + + if err := bridge.addIMAPUser(ctx, user); err != nil { + return fmt.Errorf("failed to add IMAP user: %w", err) + } + + return nil + }, bridge.usersLock) +} + +func (bridge *Bridge) handleUserDeauth(ctx context.Context, user *user.User) { + safe.Lock(func() { + defer delete(bridge.users, user.ID()) + bridge.logoutUser(ctx, user, false, false) + }, bridge.usersLock) +} diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index 5d712d4d..e4781e6c 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -609,6 +609,34 @@ func TestBridge_UserInfo_Alias(t *testing.T) { }) } +func TestBridge_User_Refresh(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Get a channel of sync started events. + syncStartCh, done := chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{})) + defer done() + + // Get a channel of sync finished events. + syncFinishCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() + + // Log the user in. + userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) + + // The sync should start and finish. + require.Equal(t, userID, (<-syncStartCh).UserID) + require.Equal(t, userID, (<-syncFinishCh).UserID) + + // Trigger a refresh. + require.NoError(t, s.RefreshUser(userID, liteapi.RefreshAll)) + + // The sync should start and finish again. + require.Equal(t, userID, (<-syncStartCh).UserID) + require.Equal(t, userID, (<-syncFinishCh).UserID) + }) + }) +} + // getErr returns the error that was passed to it. func getErr[T any](val T, err error) error { return err diff --git a/internal/events/user.go b/internal/events/user.go index 2dc0f98f..e039cd50 100644 --- a/internal/events/user.go +++ b/internal/events/user.go @@ -112,6 +112,16 @@ func (event UserChanged) String() string { return fmt.Sprintf("UserChanged: UserID: %s", event.UserID) } +type UserRefreshed struct { + eventBase + + UserID string +} + +func (event UserRefreshed) String() string { + return fmt.Sprintf("UserRefreshed: UserID: %s", event.UserID) +} + type AddressModeChanged struct { eventBase diff --git a/internal/user/events.go b/internal/user/events.go index b1da6fb6..c6c2bfa6 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -35,6 +35,10 @@ import ( // handleAPIEvent handles the given liteapi.Event. func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error { + if event.Refresh&liteapi.RefreshMail != 0 { + return user.handleRefreshEvent(ctx) + } + if event.User != nil { if err := user.handleUserEvent(ctx, *event.User); err != nil { return err @@ -62,6 +66,54 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error return nil } +func (user *User) handleRefreshEvent(ctx context.Context) error { + user.log.Info("Handling refresh event") + + // Cancel and restart ongoing syncs. + user.abortable.Abort() + defer user.goSync() + + return safe.LockRet(func() error { + // Fetch latest user info. + apiUser, err := user.client.GetUser(ctx) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + + // Fetch latest address info. + apiAddrs, err := user.client.GetAddresses(ctx) + if err != nil { + return fmt.Errorf("failed to get addresses: %w", err) + } + + // Fetch latest label info. + apiLabels, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem, liteapi.LabelTypeFolder, liteapi.LabelTypeLabel) + if err != nil { + return fmt.Errorf("failed to get labels: %w", err) + } + + // Update the API info in the user. + user.apiUser = apiUser + user.apiAddrs = groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }) + user.apiLabels = groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }) + + // Reinitialize the update channels. + user.initUpdateCh(user.vault.AddressMode()) + + // Clear sync status; we want to sync everything again. + if err := user.vault.ClearSyncStatus(); err != nil { + return fmt.Errorf("failed to clear sync status: %w", err) + } + + // The user was refreshed. + user.eventCh.Enqueue(events.UserRefreshed{ + UserID: user.apiUser.ID, + }) + + return nil + }, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock) +} + // handleUserEvent handles the given user event. func (user *User) handleUserEvent(_ context.Context, userEvent liteapi.User) error { return safe.LockRet(func() error { diff --git a/internal/user/imap.go b/internal/user/imap.go index 90b8f73a..015bd329 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -23,6 +23,7 @@ import ( "sync/atomic" "time" + "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" @@ -35,6 +36,9 @@ import ( "golang.org/x/exp/slices" ) +// Verify that *imapConnector implements connector.Connector. +var _ connector.Connector = (*imapConnector)(nil) + var ( defaultFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals defaultPermanentFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals diff --git a/internal/user/user.go b/internal/user/user.go index 88ad8e20..cae7fa6f 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -100,11 +100,6 @@ func New( return nil, fmt.Errorf("failed to get addresses: %w", err) } - // Check we can unlock the keyrings. - if _, _, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass()); err != nil { - return nil, fmt.Errorf("failed to unlock user: %w", err) - } - // Get the user's API labels. apiLabels, err := client.GetLabels(ctx, liteapi.LabelTypeSystem, liteapi.LabelTypeFolder, liteapi.LabelTypeLabel) if err != nil { @@ -231,12 +226,10 @@ func New( user.log.Debug("Sync triggered") user.abortable.Do(ctx, func(ctx context.Context) { - if !user.vault.SyncStatus().IsComplete() { - if err := user.doSync(ctx); err != nil { - return - } - } else { + if user.vault.SyncStatus().IsComplete() { user.log.Debug("Sync is already complete, skipping") + } else if err := user.doSync(ctx); err != nil { + user.log.WithError(err).Error("Failed to sync") } }) }) @@ -300,30 +293,12 @@ func (user *User) GetAddressMode() vault.AddressMode { } // SetAddressMode sets the user's address mode. -func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error { +func (user *User) SetAddressMode(_ context.Context, mode vault.AddressMode) error { user.abortable.Abort() defer user.goSync() return safe.LockRet(func() error { - for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) { - updateCh.CloseAndDiscardQueued() - } - - user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update]) - - switch mode { - case vault.CombinedMode: - primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) - - for addrID := range user.apiAddrs { - user.updateCh[addrID] = primaryUpdateCh - } - - case vault.SplitMode: - for addrID := range user.apiAddrs { - user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) - } - } + user.initUpdateCh(mode) if err := user.vault.SetAddressMode(mode); err != nil { return fmt.Errorf("failed to set address mode: %w", err) @@ -620,6 +595,30 @@ func (user *User) SetShowAllMail(show bool) { atomic.StoreUint32(&user.showAllMail, b32(show)) } +// initUpdateCh initializes the user's update channels in the given address mode. +// It is assumed that user.apiAddrs and user.updateCh are already locked. +func (user *User) initUpdateCh(mode vault.AddressMode) { + for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) { + updateCh.CloseAndDiscardQueued() + } + + user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update]) + + switch mode { + case vault.CombinedMode: + primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) + + for addrID := range user.apiAddrs { + user.updateCh[addrID] = primaryUpdateCh + } + + case vault.SplitMode: + for addrID := range user.apiAddrs { + user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) + } + } +} + // b32 returns a uint32 0 or 1 representing b. func b32(b bool) uint32 { if b { diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 9f561411..cb95e904 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -138,6 +138,23 @@ func TestUser_Deauth(t *testing.T) { }) } +func TestUser_Refresh(t *testing.T) { + withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { + withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { + withUser(t, ctx, s, m, "username", "password", func(user *User) { + // Get the event channel. + eventCh := user.GetEventCh() + + // Revoke the user's auth token. + require.NoError(t, s.RefreshUser(user.ID(), liteapi.RefreshAll)) + + // The user should eventually be logged out. + require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserRefreshed); return ok }, 5*time.Second, 100*time.Millisecond) + }) + }) + }) +} + func withAPI(_ testing.TB, ctx context.Context, fn func(context.Context, *server.Server, *liteapi.Manager)) { //nolint:revive server := server.New() defer server.Close()