diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go
index adca3418..24849089 100644
--- a/internal/bridge/sync_test.go
+++ b/internal/bridge/sync_test.go
@@ -21,9 +21,11 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"net/http"
 	"os"
 	"path/filepath"
 	"runtime"
+	"strings"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -314,6 +316,194 @@ func TestBridge_SyncWithOngoingEvents(t *testing.T) {
 	}, server.WithTLS(false))
 }
 
+func TestBridge_CanProcessEventsDuringSync(t *testing.T) {
+	numMsg := 1 << 8
+
+	withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
+		userID, addrID, err := s.CreateUser("imap", password)
+		require.NoError(t, err)
+
+		labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
+		require.NoError(t, err)
+
+		withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) {
+			createNumMessages(ctx, t, c, addrID, labelID, numMsg)
+		})
+
+		// Simulate 429 to prevent sync from progressing.
+		s.AddStatusHook(func(request *http.Request) (int, bool) {
+			if strings.Contains(request.URL.Path, "/mail/v4/messages/") {
+				return http.StatusTooManyRequests, true
+			}
+
+			return 0, false
+		})
+
+		// Sync is stalled by the 429 hook; API events must still be processed meanwhile.
+		withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
+			syncStartedCh, syncStartedDone := chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{}))
+			defer syncStartedDone()
+
+			addressCreatedCh, addressCreatedDone := chToType[events.Event, events.UserAddressCreated](bridge.GetEvents(events.UserAddressCreated{}))
+			defer addressCreatedDone()
+
+			userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
+			require.NoError(t, err)
+
+			require.Equal(t, userID, (<-syncStartedCh).UserID)
+
+			// Create a new address
+			newAddress := "foo@proton.ch"
+			addrID, err := s.CreateAddress(userID, newAddress, password)
+			require.NoError(t, err)
+
+			event := <-addressCreatedCh
+			require.Equal(t, userID, event.UserID)
+			require.Equal(t, newAddress, event.Email)
+			require.Equal(t, addrID, event.AddressID)
+		})
+	}, server.WithTLS(false))
+}
+
+func TestBridge_RefreshDuringSyncRestartSync(t *testing.T) {
+	numMsg := 1 << 8
+
+	withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
+		userID, addrID, err := s.CreateUser("imap", password)
+		require.NoError(t, err)
+
+		labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
+		require.NoError(t, err)
+
+		withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) {
+			createNumMessages(ctx, t, c, addrID, labelID, numMsg)
+		})
+
+		var refreshPerformed atomic.Bool
+		refreshPerformed.Store(false)
+
+		// Simulate 429 to prevent sync from progressing.
+		s.AddStatusHook(func(request *http.Request) (int, bool) {
+			if strings.Contains(request.URL.Path, "/mail/v4/messages/") {
+				if !refreshPerformed.Load() {
+					return http.StatusTooManyRequests, true
+				}
+			}
+
+			return 0, false
+		})
+
+		// Sync stays stalled until the refresh has been observed to restart it.
+		withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
+			syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
+			defer done()
+
+			syncStartedCh, syncStartedDone := chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{}))
+			defer syncStartedDone()
+
+			userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
+			require.NoError(t, err)
+
+			require.Equal(t, userID, (<-syncStartedCh).UserID)
+
+			require.NoError(t, s.RefreshUser(userID, proton.RefreshMail))
+			require.Equal(t, userID, (<-syncStartedCh).UserID)
+			refreshPerformed.Store(true)
+
+			require.Equal(t, userID, (<-syncCh).UserID)
+		})
+	}, server.WithTLS(false))
+}
+
+func TestBridge_EventReplayAfterSyncHasFinished(t *testing.T) {
+	numMsg := 1 << 8
+
+	withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
+		userID, addrID, err := s.CreateUser("imap", password)
+		require.NoError(t, err)
+
+		labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
+		require.NoError(t, err)
+
+		withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) {
+			createNumMessages(ctx, t, c, addrID, labelID, numMsg)
+		})
+
+		addrID1, err := s.CreateAddress(userID, "foo@proton.ch", password)
+		require.NoError(t, err)
+
+		var allowSyncToProgress atomic.Bool
+		allowSyncToProgress.Store(false)
+
+		// Simulate 429 to prevent sync from progressing.
+		s.AddStatusHook(func(request *http.Request) (int, bool) {
+			if request.Method == "GET" && strings.Contains(request.URL.Path, "/mail/v4/messages/") {
+				if !allowSyncToProgress.Load() {
+					return http.StatusTooManyRequests, true
+				}
+			}
+
+			return 0, false
+		})
+
+		// Sync is held back until allowSyncToProgress is set.
+		withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
+			syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
+			defer done()
+
+			syncStartedCh, syncStartedDone := chToType[events.Event, events.SyncStarted](bridge.GetEvents(events.SyncStarted{}))
+			defer syncStartedDone()
+
+			addressCreatedCh, addressCreatedDone := chToType[events.Event, events.UserAddressCreated](bridge.GetEvents(events.UserAddressCreated{}))
+			defer addressCreatedDone()
+
+			userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
+			require.NoError(t, err)
+
+			require.Equal(t, userID, (<-syncStartedCh).UserID)
+
+			// Create 20 more messages directly in the inbox.
+			withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) {
+				createNumMessages(ctx, t, c, addrID, proton.InboxLabel, 20)
+			})
+
+			// Use the AddrID2 event as a checkpoint to see when the new address was created.
+			addrID2, err := s.CreateAddress(userID, "bar@proton.ch", password)
+			require.NoError(t, err)
+
+			allowSyncToProgress.Store(true)
+			require.Equal(t, userID, (<-syncCh).UserID)
+
+			// At most two events can be published, one for the first address, then for the second.
+			// If the second event is not `addrID2` then something went wrong.
+			event := <-addressCreatedCh
+			if event.AddressID == addrID1 {
+				event = <-addressCreatedCh
+			}
+
+			require.Equal(t, addrID2, event.AddressID)
+
+			info, err := bridge.GetUserInfo(userID)
+			require.NoError(t, err)
+
+			client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
+			require.NoError(t, err)
+			require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
+			defer func() { _ = client.Logout() }()
+
+			// Check that the 20 replayed messages are in INBOX.
+			status, err := client.Status("INBOX", []imap.StatusItem{imap.StatusMessages})
+			require.NoError(t, err)
+			require.Equal(t, uint32(20), status.Messages)
+
+			// Finally check that all numMsg messages are in the folder.
+			status, err = client.Status("Folders/folder", []imap.StatusItem{imap.StatusMessages})
+			require.NoError(t, err)
+			require.Equal(t, uint32(numMsg), status.Messages)
+		})
+	}, server.WithTLS(false))
+}
+
 func withClient(ctx context.Context, t *testing.T, s *server.Server, username string, password []byte, fn func(context.Context, *proton.Client)) { //nolint:unparam
 	m := proton.New(
 		proton.WithHostURL(s.GetHostURL()),
diff --git a/internal/services/imapservice/service.go b/internal/services/imapservice/service.go
index af40fcd2..18498073 100644
--- a/internal/services/imapservice/service.go
+++ b/internal/services/imapservice/service.go
@@ -42,8 +42,7 @@ import (
 
 type EventProvider interface {
 	userevents.Subscribable
-	PauseIMAP()
-	ResumeIMAP()
+	RewindEventID(ctx context.Context, eventID string) error
 }
 
 type Telemetry interface {
@@ -92,7 +91,9 @@ type Service struct {
 	syncStateProvider  *SyncState
 	syncReporter       *syncReporter
 
-	syncConfigPath string
+	syncConfigPath     string
+	lastHandledEventID string
+	isSyncing          bool
 }
 
 func NewService(
@@ -164,8 +165,9 @@ func (s *Service) Start(
 	ctx context.Context,
 	group *orderedtasks.OrderedCancelGroup,
 	syncRegulator syncservice.Regulator,
-
+	lastEventID string,
 ) error {
+	s.lastHandledEventID = lastEventID
 	{
 		syncStateProvider, err := NewSyncState(s.syncConfigPath)
 		if err != nil {
@@ -276,7 +278,7 @@ func (s *Service) HandleRefreshEvent(ctx context.Context, _ proton.RefreshFlag)
 		return err
 	}
 
-	s.syncHandler.CancelAndWait()
+	s.cancelSync()
 
 	if err := s.removeConnectorsFromServer(ctx, s.connectors, true); err != nil {
 		return err
@@ -353,7 +355,7 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo
 			case *resumeSyncReq:
 				s.log.Info("Resuming sync")
 				// Cancel previous run, if any, just in case.
-				s.syncHandler.CancelAndWait()
+				s.cancelSync()
 				s.startSyncing()
 				req.Reply(ctx, nil, nil)
 			case *getLabelsReq:
@@ -401,7 +403,22 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo
 				}
 
 				s.log.Info("Sync complete, starting API event stream")
-				s.eventProvider.ResumeIMAP()
+				if err := s.eventProvider.RewindEventID(ctx, s.lastHandledEventID); err != nil {
+					if errors.Is(err, context.Canceled) {
+						continue
+					}
+
+					s.log.WithError(err).Error("Failed to rewind event service")
+					s.eventPublisher.PublishEvent(ctx, events.UserBadEvent{
+						UserID:     s.identityState.UserID(),
+						OldEventID: "",
+						NewEventID: "",
+						EventInfo:  "",
+						Error:      fmt.Errorf("failed to rewind event loop: %w", err),
+					})
+				}
+
+				s.isSyncing = false
 			}
 
 		case request, ok := <-s.syncUpdateApplier.requestCh:
@@ -423,7 +440,26 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo
 				continue
 			}
 			e.Consume(func(event proton.Event) error {
-				return eventHandler.OnEvent(ctx, event)
+				if s.isSyncing {
+					// We need to reset the sync if we receive a refresh event during a sync and update
+					// the last event id to avoid problems.
+					if event.Refresh&proton.RefreshMail != 0 {
+						if err := s.HandleRefreshEvent(ctx, 0); err != nil {
+							return err
+						}
+						s.lastHandledEventID = event.EventID
+					}
+
+					return nil
+				}
+
+				if err := eventHandler.OnEvent(ctx, event); err != nil {
+					return err
+				}
+
+				s.lastHandledEventID = event.EventID
+
+				return nil
 			})
 		case e, ok := <-s.eventWatcher.GetChannel():
 			if !ok {
@@ -537,7 +573,7 @@ func (s *Service) setAddressMode(ctx context.Context, mode usertypes.AddressMode
 		s.log.Info("Setting Combined Address Mode")
 	}
 
-	s.syncHandler.CancelAndWait()
+	s.cancelSync()
 
 	if err := s.removeConnectorsFromServer(ctx, s.connectors, true); err != nil {
 		return err
@@ -573,9 +609,16 @@ func (s *Service) setShowAllMail(v bool) {
 }
 
 func (s *Service) startSyncing() {
+	s.isSyncing = true
 	s.syncHandler.Execute(s.syncReporter, s.labels.GetLabelMap(), s.syncUpdateApplier, s.syncMessageBuilder, syncservice.DefaultRetryCoolDown)
 }
 
+// cancelSync stops any running sync, waits for it to finish and clears the isSyncing flag.
+func (s *Service) cancelSync() {
+	s.syncHandler.CancelAndWait()
+	s.isSyncing = false
+}
+
 type resyncReq struct{}
 
 type cancelSyncReq struct{}
diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go
index 08155bea..dc1bb9c6 100644
--- a/internal/services/userevents/service.go
+++ b/internal/services/userevents/service.go
@@ -34,6 +34,7 @@ import (
 	"github.com/ProtonMail/proton-bridge/v3/internal/events"
 	"github.com/ProtonMail/proton-bridge/v3/internal/network"
 	"github.com/ProtonMail/proton-bridge/v3/internal/services/orderedtasks"
+	"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
 	"github.com/bradenaw/juniper/xmaps"
 	"github.com/sirupsen/logrus"
 )
@@ -48,6 +49,7 @@ import (
 // * UserUsedSpace
 // By default this service starts paused, you need to call `Service.Resume` at least one time to begin event polling.
 type Service struct {
+	cpc            *cpc.CPC
 	userID         string
 	eventSource    EventSource
 	eventIDStore   EventIDStore
@@ -56,7 +58,6 @@ type Service struct {
 	timer          *proton.Ticker
 	eventTimeout   time.Duration
 	paused         uint32
-	pausedIMAP     uint32
 	panicHandler   async.PanicHandler
 
 	subscriberList eventSubscriberList
@@ -79,6 +80,7 @@ func NewService(
 	panicHandler async.PanicHandler,
 ) *Service {
 	return &Service{
+		cpc:            cpc.NewCPC(),
 		userID:         userID,
 		eventSource:    eventSource,
 		eventIDStore:   store,
@@ -89,7 +91,6 @@ func NewService(
 		eventPublisher: eventPublisher,
 		timer:          proton.NewTicker(pollPeriod, jitter, panicHandler),
 		paused:         1,
-		pausedIMAP:     1,
 		eventTimeout:   eventTimeout,
 		panicHandler:   panicHandler,
 	}
@@ -121,12 +122,6 @@ func (s *Service) Pause() {
 	atomic.StoreUint32(&s.paused, 1)
 }
 
-// PauseIMAP special handler for the IMAP Service - Do Not Use.
-func (s *Service) PauseIMAP() {
-	s.log.Info("Pausing from IMAP")
-	atomic.StoreUint32(&s.pausedIMAP, 1)
-}
-
 // PauseWithWaiter pauses the event polling and returns a waiter to notify when the last event has been published
 // after the pause request.
 func (s *Service) PauseWithWaiter() *EventPollWaiter {
@@ -148,24 +143,23 @@ func (s *Service) Resume() {
 	atomic.StoreUint32(&s.paused, 0)
 }
 
-// ResumeIMAP special handler for the IMAP Service - Do Not Use.
-func (s *Service) ResumeIMAP() {
-	s.log.Info("Resuming from IMAP")
-	atomic.StoreUint32(&s.pausedIMAP, 0)
-}
-
 // IsPaused return true if the service is paused.
 func (s *Service) IsPaused() bool {
-	// We need to check both IMAP and service paused conditions here to determine if the service is
-	// paused. There can be instances where the sync from IMAP can overwrite a previous request to pause the loop once
-	// it is finished. To be addressed as part of GODT-2848.
-	return atomic.LoadUint32(&s.paused) == 1 || atomic.LoadUint32(&s.pausedIMAP) == 1
+	return atomic.LoadUint32(&s.paused) == 1
 }
 
-func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGroup) error {
+// RewindEventID sets the event id as the next event to be polled.
+func (s *Service) RewindEventID(ctx context.Context, id string) error {
+	_, err := s.cpc.Send(ctx, &rewindEventIDReq{eventID: id})
+
+	return err
+}
+
+// Start the event service and return the last EventID that was processed.
+func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGroup) (string, error) {
 	lastEventID, err := s.eventIDStore.Load(ctx)
 	if err != nil {
-		return fmt.Errorf("failed to load last event id: %w", err)
+		return "", fmt.Errorf("failed to load last event id: %w", err)
 	}
 
 	if lastEventID == "" {
@@ -176,11 +170,11 @@ func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGr
 			return eventSource.GetLatestEventID(ctx)
 		})
 		if err != nil {
-			return fmt.Errorf("failed to get latest event id: %w", err)
+			return "", fmt.Errorf("failed to get latest event id: %w", err)
 		}
 
 		if err := s.eventIDStore.Store(ctx, eventID); err != nil {
-			return fmt.Errorf("failed to store event in event id store: %v", err)
+			return "", fmt.Errorf("failed to store event in event id store: %w", err)
 		}
 
 		lastEventID = eventID
@@ -190,11 +184,12 @@ func (s *Service) Start(ctx context.Context, group *orderedtasks.OrderedCancelGr
 		s.run(ctx, lastEventID)
 	})
 
-	return nil
+	return lastEventID, nil
 }
 
 func (s *Service) run(ctx context.Context, lastEventID string) {
 	s.log.Infof("Starting service Last EventID=%v", lastEventID)
+	defer s.cpc.Close()
 	defer s.timer.Stop()
 	defer s.log.Info("Exiting service")
 	defer s.Close()
@@ -210,6 +205,26 @@ func (s *Service) run(ctx context.Context, lastEventID string) {
 				s.closePollWaiters()
 				continue
 			}
+
+		case r, ok := <-s.cpc.ReceiveCh():
+			if !ok {
+				return
+			}
+
+			rewind, ok := r.Value().(*rewindEventIDReq)
+			if !ok {
+				s.log.Error("Received unknown request")
+				continue
+			}
+
+			err := s.rewindEventLoop(ctx, rewind.eventID)
+			r.Reply(ctx, nil, err)
+
+			if err == nil {
+				lastEventID = rewind.eventID
+			}
+
+			continue
 		}
 
 		// Apply any pending subscription changes.
@@ -402,6 +417,12 @@ func (s *Service) removeSubscription(subscription EventSubscriber) {
 	s.subscriberList.Remove(subscription)
 }
 
+// rewindEventLoop logs the reset and persists id in the event id store.
+func (s *Service) rewindEventLoop(ctx context.Context, id string) error {
+	s.log.WithField("eventID", id).Info("Event loop reset")
+	return s.eventIDStore.Store(ctx, id)
+}
+
 type pendingOp int
 
 const (
@@ -413,3 +434,8 @@ type pendingSubscription struct {
 	op  pendingOp
 	sub EventSubscriber
 }
+
+// rewindEventIDReq is the request sent by RewindEventID to the run loop.
+type rewindEventIDReq struct {
+	eventID string
+}
diff --git a/internal/services/userevents/service_test.go b/internal/services/userevents/service_test.go
index 6f95b1e8..2948125c 100644
--- a/internal/services/userevents/service_test.go
+++ b/internal/services/userevents/service_test.go
@@ -76,9 +76,11 @@ func TestService_EventIDLoadStore(t *testing.T) {
 		time.Second,
 		async.NoopPanicHandler{},
 	)
-	require.NoError(t, service.Start(context.Background(), group))
+
+	_, err := service.Start(context.Background(), group)
+	require.NoError(t, err)
+
 	service.Resume()
-	service.ResumeIMAP()
 	group.Wait()
 }
 
@@ -131,9 +133,10 @@ func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) {
 	)
 
 	service.Subscribe(NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber}))
-	require.NoError(t, service.Start(context.Background(), group))
+	_, err := service.Start(context.Background(), group)
+	require.NoError(t, err)
+
 	service.Resume()
-	service.ResumeIMAP()
 	group.Wait()
 }
 
@@ -194,9 +197,11 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) {
 	})
 
 	service.Subscribe(NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber}))
-	require.NoError(t, service.Start(context.Background(), group))
+
+	_, err := service.Start(context.Background(), group)
+	require.NoError(t, err)
+
 	service.Resume()
-	service.ResumeIMAP()
 	group.Wait()
 }
 
@@ -251,9 +256,11 @@ func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T
 	})
 
 	service.Subscribe(subscription)
-	require.NoError(t, service.Start(context.Background(), group))
+
+	_, err := service.Start(context.Background(), group)
+	require.NoError(t, err)
+
 	service.Resume()
-	service.ResumeIMAP()
 	group.Wait()
 }
 
@@ -310,9 +317,11 @@ func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T
 	})
 
 	service.Subscribe(subscription)
-	require.NoError(t, service.Start(context.Background(), group))
+
+	_, err := service.Start(context.Background(), group)
+	require.NoError(t, err)
+
 	service.Resume()
-	service.ResumeIMAP()
 	group.Wait()
 }
 
@@ -373,9 +382,72 @@ func TestService_WaitOnEventPublishAfterPause(t *testing.T) {
 	})
 
 	service.Subscribe(NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber}))
-	require.NoError(t, service.Start(context.Background(), group))
+
+	_, err := service.Start(context.Background(), group)
+	require.NoError(t, err)
+
+	service.Resume()
+	group.Wait()
+}
+
+func TestService_EventRewind(t *testing.T) {
+	group := orderedtasks.NewOrderedCancelGroup(async.NoopPanicHandler{})
+	mockCtrl := gomock.NewController(t)
+	eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
+	eventIDStore := mocks.NewMockEventIDStore(mockCtrl)
+	eventSource := mocks.NewMockEventSource(mockCtrl)
+
+	firstEventID := "EVENT01"
+	secondEventID := "EVENT02"
+	messageEvents := []proton.MessageEvent{
+		{
+			EventItem: proton.EventItem{ID: "Message"},
+		},
+	}
+	secondEvent := []proton.Event{{
+		EventID:  secondEventID,
+		Messages: messageEvents,
+	}}
+
+	// Return Second Event from id store, but then reset to event 1
+
+	// Event id store expectations.
+	store1 := eventIDStore.EXPECT().Load(gomock.Any()).Times(1).Return(secondEventID, nil)
+	eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(firstEventID)).Times(1).Return(nil).After(store1)
+	eventIDStore.EXPECT().Store(gomock.Any(), gomock.Eq(secondEventID)).Times(1).Return(nil)
+
+	// Event Source expectations.
+	eventSource.EXPECT().GetEvent(gomock.Any(), gomock.Eq(firstEventID)).MinTimes(1).DoAndReturn(
+		func(_ context.Context, _ string) ([]proton.Event, bool, error) {
+			group.Cancel()
+			return secondEvent, false, nil
+		},
+	)
+
+	// Subscriber expectations.
+
+	service := NewService(
+		"foo",
+		eventSource,
+		eventIDStore,
+		eventPublisher,
+		time.Millisecond,
+		time.Millisecond,
+		time.Second,
+		async.NoopPanicHandler{},
+	)
+
+	_, err := service.Start(context.Background(), group)
+	require.NoError(t, err)
+
+	go func() {
+		// require.* calls t.FailNow, which must only run on the test goroutine.
+		if err := service.RewindEventID(context.Background(), firstEventID); err != nil {
+			t.Errorf("failed to rewind event id: %v", err)
+		}
+	}()
+
 	service.Resume()
-	service.ResumeIMAP()
 	group.Wait()
 }
 
diff --git a/internal/user/user.go b/internal/user/user.go
index a28783dd..c983dcce 100644
--- a/internal/user/user.go
+++ b/internal/user/user.go
@@ -295,7 +295,8 @@ func newImpl(
 	})
 
 	// Start Event Service
-	if err := user.eventService.Start(ctx, user.serviceGroup); err != nil {
+	lastEventID, err := user.eventService.Start(ctx, user.serviceGroup)
+	if err != nil {
 		return user, fmt.Errorf("failed to start event service: %w", err)
 	}
 
@@ -311,7 +312,7 @@ func newImpl(
 	}
 
 	// Start IMAP Service
-	if err := user.imapService.Start(ctx, user.serviceGroup, syncService); err != nil {
+	if err := user.imapService.Start(ctx, user.serviceGroup, syncService, lastEventID); err != nil {
 		return user, fmt.Errorf("failed to start imap service: %w", err)
 	}
 