diff --git a/tests/bridge_test.go b/tests/bridge_test.go index 80c7392f..a523e8b0 100644 --- a/tests/bridge_test.go +++ b/tests/bridge_test.go @@ -25,7 +25,6 @@ import ( "time" "github.com/Masterminds/semver/v3" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/vault" ) @@ -125,146 +124,143 @@ func (s *scenario) theUserReportsABug() error { } func (s *scenario) bridgeSendsAConnectionUpEvent() error { - return try(s.t.connStatusCh, 5*time.Second, func(event events.Event) error { - if event, ok := event.(events.ConnStatusUp); !ok { - return fmt.Errorf("expected connection up event, got %T", event) - } + if event := s.t.events.await(events.ConnStatusUp{}, 5*time.Second); event == nil { + return errors.New("expected connection up event, got none") + } - return nil - }) + return nil } func (s *scenario) bridgeSendsAConnectionDownEvent() error { - return try(s.t.connStatusCh, 5*time.Second, func(event events.Event) error { - if event, ok := event.(events.ConnStatusDown); !ok { - return fmt.Errorf("expected connection down event, got %T", event) - } + if event := s.t.events.await(events.ConnStatusDown{}, 5*time.Second); event == nil { + return errors.New("expected connection down event, got none") + } - return nil - }) + return nil } func (s *scenario) bridgeSendsADeauthEventForUser(username string) error { - return try(s.t.deauthCh, 5*time.Second, func(event events.UserDeauth) error { - if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { - return fmt.Errorf("expected deauth event for user with ID %s, got %s", wantUserID, event.UserID) - } + event, ok := awaitType(s.t.events, events.UserDeauth{}, 5*time.Second) + if !ok { + return errors.New("expected deauth event, got none") + } - return nil - }) + if wantUserID := s.t.getUserID(username); event.UserID != wantUserID { + return fmt.Errorf("expected deauth event for user %s, got %s", wantUserID, event.UserID) + } + + return nil } func (s *scenario) bridgeSendsAnAddressCreatedEventForUser(username string) error { - return try(s.t.addrCreatedCh, 60*time.Second, func(event events.UserAddressCreated) error { - if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { - return fmt.Errorf("expected user address created event for user with ID %s, got %s", wantUserID, event.UserID) - } + event, ok := awaitType(s.t.events, events.UserAddressCreated{}, 5*time.Second) + if !ok { + return errors.New("expected address created event, got none") + } - return nil - }) + if wantUserID := s.t.getUserID(username); event.UserID != wantUserID { + return fmt.Errorf("expected address created event for user %s, got %s", wantUserID, event.UserID) + } + + return nil } func (s *scenario) bridgeSendsAnAddressDeletedEventForUser(username string) error { - return try(s.t.addrDeletedCh, 60*time.Second, func(event events.UserAddressDeleted) error { - if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { - return fmt.Errorf("expected user address deleted event for user with ID %s, got %s", wantUserID, event.UserID) - } + event, ok := awaitType(s.t.events, events.UserAddressDeleted{}, 5*time.Second) + if !ok { + return errors.New("expected address deleted event, got none") + } - return nil - }) + if wantUserID := s.t.getUserID(username); event.UserID != wantUserID { + return fmt.Errorf("expected address deleted event for user %s, got %s", wantUserID, event.UserID) + } + + return nil } func (s *scenario) bridgeSendsSyncStartedAndFinishedEventsForUser(username string) error { - if err := get(s.t.syncStartedCh, func(event events.SyncStarted) error { - if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { - return fmt.Errorf("expected sync started event for user with ID %s, got %s", wantUserID, event.UserID) - } - - return nil - }); err != nil { - return fmt.Errorf("failed to get sync started event: %w", err) + startEvent, ok := awaitType(s.t.events, events.SyncStarted{}, 5*time.Second) + if !ok { + return errors.New("expected sync started event, got none") } - if err := get(s.t.syncFinishedCh, func(event events.SyncFinished) error { - if wantUserID := s.t.getUserID(username); wantUserID != event.UserID { - return fmt.Errorf("expected sync finished event for user with ID %s, got %s", wantUserID, event.UserID) - } + if wantUserID := s.t.getUserID(username); startEvent.UserID != wantUserID { + return fmt.Errorf("expected sync started event for user %s, got %s", wantUserID, startEvent.UserID) + } - return nil - }); err != nil { - return fmt.Errorf("failed to get sync finished event: %w", err) + finishEvent, ok := awaitType(s.t.events, events.SyncFinished{}, 5*time.Second) + if !ok { + return errors.New("expected sync finished event, got none") + } + + if wantUserID := s.t.getUserID(username); finishEvent.UserID != wantUserID { + return fmt.Errorf("expected sync finished event for user %s, got %s", wantUserID, finishEvent.UserID) } return nil } func (s *scenario) bridgeSendsAnUpdateNotAvailableEvent() error { - return try(s.t.updateCh, 5*time.Second, func(event events.Event) error { - if event, ok := event.(events.UpdateNotAvailable); !ok { - return fmt.Errorf("expected update not available event, got %T", event) - } + if event := s.t.events.await(events.UpdateNotAvailable{}, 5*time.Second); event == nil { + return errors.New("expected update not available event, got none") + } - return nil - }) + return nil } func (s *scenario) bridgeSendsAnUpdateAvailableEventForVersion(version string) error { - return try(s.t.updateCh, 5*time.Second, func(event events.Event) error { - updateEvent, ok := event.(events.UpdateAvailable) - if !ok { - return fmt.Errorf("expected update available event, got %T", event) - } + event, ok := awaitType(s.t.events, events.UpdateAvailable{}, 5*time.Second) + if !ok { + return errors.New("expected update available event, got none") + } - if !updateEvent.CanInstall { - return errors.New("expected update event to be installable") - } + if !event.CanInstall { + return errors.New("expected update event to be installable") + } - if !updateEvent.Version.Version.Equal(semver.MustParse(version)) { - return fmt.Errorf("expected update event for version %s, got %s", version, updateEvent.Version.Version) - } + if !event.Version.Version.Equal(semver.MustParse(version)) { + return fmt.Errorf("expected update event for version %s, got %s", version, event.Version.Version) + } - return nil - }) + return nil } func (s *scenario) bridgeSendsAManualUpdateEventForVersion(version string) error { - return try(s.t.updateCh, 5*time.Second, func(event events.Event) error { - updateEvent, ok := event.(events.UpdateAvailable) - if !ok { - return fmt.Errorf("expected manual update event, got %T", event) - } + event, ok := awaitType(s.t.events, events.UpdateAvailable{}, 5*time.Second) + if !ok { + return errors.New("expected update available event, got none") + } - if updateEvent.CanInstall { - return errors.New("expected update event to not be installable") - } + if event.CanInstall { + return errors.New("expected update event to not be installable") + } - if !updateEvent.Version.Version.Equal(semver.MustParse(version)) { - return fmt.Errorf("expected update event for version %s, got %s", version, updateEvent.Version.Version) - } + if !event.Version.Version.Equal(semver.MustParse(version)) { + return fmt.Errorf("expected update event for version %s, got %s", version, event.Version.Version) + } - return nil - }) + return nil } func (s *scenario) bridgeSendsAnUpdateInstalledEventForVersion(version string) error { - return try(s.t.updateCh, 5*time.Second, func(event events.Event) error { - updateEvent, ok := event.(events.UpdateInstalled) - if !ok { - return fmt.Errorf("expected update installed event, got %T", event) - } + event, ok := awaitType(s.t.events, events.UpdateInstalled{}, 5*time.Second) + if !ok { + return errors.New("expected update installed event, got none") + } - if !updateEvent.Version.Version.Equal(semver.MustParse(version)) { - return fmt.Errorf("expected update event for version %s, got %s", version, updateEvent.Version.Version) - } + if !event.Version.Version.Equal(semver.MustParse(version)) { + return fmt.Errorf("expected update installed event for version %s, got %s", version, event.Version.Version) + } - return nil - }) + return nil } func (s *scenario) bridgeSendsAForcedUpdateEvent() error { - return try(s.t.forcedUpdateCh, 5*time.Second, func(event events.UpdateForced) error { - return nil - }) + if event := s.t.events.await(events.UpdateForced{}, 5*time.Second); event == nil { + return errors.New("expected update forced event, got none") + } + + return nil } func (s *scenario) bridgeHidesAllMail() error { @@ -274,17 +270,3 @@ func (s *scenario) bridgeHidesAllMail() error { func (s *scenario) bridgeShowsAllMail() error { return s.t.bridge.SetShowAllMail(true) } - -func try[T any](inCh *queue.QueuedChannel[T], wait time.Duration, fn func(T) error) error { - select { - case event := <-inCh.GetChannel(): - return fn(event) - - case <-time.After(wait): - return errors.New("timeout waiting for event") - } -} - -func get[T any](inCh *queue.QueuedChannel[T], fn func(T) error) error { - return fn(<-inCh.GetChannel()) -} diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index f74b3ad4..1f1594ae 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -22,8 +22,8 @@ import ( "crypto/tls" "fmt" "net/http/cookiejar" + "time" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/cookies" "github.com/ProtonMail/proton-bridge/v2/internal/events" @@ -94,60 +94,10 @@ func (t *testCtx) startBridge() error { return err } - // Create the event channels for use in the test. - t.loginCh = queue.NewQueuedChannel[events.UserLoggedIn](0, 0) - t.logoutCh = queue.NewQueuedChannel[events.UserLoggedOut](0, 0) - t.loadedCh = queue.NewQueuedChannel[events.AllUsersLoaded](0, 0) - t.deletedCh = queue.NewQueuedChannel[events.UserDeleted](0, 0) - t.deauthCh = queue.NewQueuedChannel[events.UserDeauth](0, 0) - t.addrCreatedCh = queue.NewQueuedChannel[events.UserAddressCreated](0, 0) - t.addrDeletedCh = queue.NewQueuedChannel[events.UserAddressDeleted](0, 0) - t.syncStartedCh = queue.NewQueuedChannel[events.SyncStarted](0, 0) - t.syncFinishedCh = queue.NewQueuedChannel[events.SyncFinished](0, 0) - t.forcedUpdateCh = queue.NewQueuedChannel[events.UpdateForced](0, 0) - t.connStatusCh = queue.NewQueuedChannel[events.Event](0, 0) - t.updateCh = queue.NewQueuedChannel[events.Event](0, 0) - - // Push the updates to the appropriate channels. - go func() { - for event := range eventCh { - switch event := event.(type) { - case events.UserLoggedIn: - t.loginCh.Enqueue(event) - case events.UserLoggedOut: - t.logoutCh.Enqueue(event) - case events.AllUsersLoaded: - t.loadedCh.Enqueue(event) - case events.UserDeleted: - t.deletedCh.Enqueue(event) - case events.UserDeauth: - t.deauthCh.Enqueue(event) - case events.UserAddressCreated: - t.addrCreatedCh.Enqueue(event) - case events.UserAddressDeleted: - t.addrDeletedCh.Enqueue(event) - case events.SyncStarted: - t.syncStartedCh.Enqueue(event) - case events.SyncFinished: - t.syncFinishedCh.Enqueue(event) - case events.ConnStatusUp: - t.connStatusCh.Enqueue(event) - case events.ConnStatusDown: - t.connStatusCh.Enqueue(event) - case events.UpdateAvailable: - t.updateCh.Enqueue(event) - case events.UpdateNotAvailable: - t.updateCh.Enqueue(event) - case events.UpdateInstalled: - t.updateCh.Enqueue(event) - case events.UpdateForced: - t.forcedUpdateCh.Enqueue(event) - } - } - }() + t.events.collectFrom(eventCh) // Wait for the users to be loaded. - <-t.loadedCh.GetChannel() + t.events.await(events.AllUsersLoaded{}, 10*time.Second) // Save the bridge to the context. t.bridge = bridge @@ -161,18 +111,6 @@ func (t *testCtx) stopBridge() error { } t.bridge = nil - t.loginCh.CloseAndDiscardQueued() - t.logoutCh.CloseAndDiscardQueued() - t.loadedCh.CloseAndDiscardQueued() - t.deletedCh.CloseAndDiscardQueued() - t.deauthCh.CloseAndDiscardQueued() - t.addrCreatedCh.CloseAndDiscardQueued() - t.addrDeletedCh.CloseAndDiscardQueued() - t.syncStartedCh.CloseAndDiscardQueued() - t.syncFinishedCh.CloseAndDiscardQueued() - t.forcedUpdateCh.CloseAndDiscardQueued() - t.connStatusCh.CloseAndDiscardQueued() - t.updateCh.CloseAndDiscardQueued() return nil } diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 79c52c30..6d18fae4 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -21,8 +21,11 @@ import ( "context" "fmt" "net/smtp" + "reflect" "regexp" + "sync" "testing" + "time" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon/queue" @@ -31,6 +34,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/bradenaw/juniper/xslices" "github.com/emersion/go-imap/client" + "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" "golang.org/x/exp/maps" @@ -47,24 +51,11 @@ type testCtx struct { storeKey []byte version *semver.Version mocks *bridge.Mocks + events *eventCollector // bridge holds the bridge app under test. bridge *bridge.Bridge - // These channels hold events of various types coming from bridge. - loginCh *queue.QueuedChannel[events.UserLoggedIn] - logoutCh *queue.QueuedChannel[events.UserLoggedOut] - loadedCh *queue.QueuedChannel[events.AllUsersLoaded] - deletedCh *queue.QueuedChannel[events.UserDeleted] - deauthCh *queue.QueuedChannel[events.UserDeauth] - addrCreatedCh *queue.QueuedChannel[events.UserAddressCreated] - addrDeletedCh *queue.QueuedChannel[events.UserAddressDeleted] - syncStartedCh *queue.QueuedChannel[events.SyncStarted] - syncFinishedCh *queue.QueuedChannel[events.SyncFinished] - forcedUpdateCh *queue.QueuedChannel[events.UpdateForced] - connStatusCh *queue.QueuedChannel[events.Event] - updateCh *queue.QueuedChannel[events.Event] - // These maps hold expected userIDByName, their primary addresses and bridge passwords. userIDByName map[string]string userAddrByEmail map[string]map[string]string @@ -101,8 +92,9 @@ func newTestCtx(tb testing.TB) *testCtx { netCtl: liteapi.NewNetCtl(), locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"), storeKey: []byte("super-secret-store-key"), - mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion), version: defaultVersion, + mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion), + events: newEventCollector(), userIDByName: make(map[string]string), userAddrByEmail: make(map[string]map[string]string), @@ -250,7 +242,13 @@ func (t *testCtx) getLastError() error { func (t *testCtx) close(ctx context.Context) error { for _, client := range t.imapClients { if err := client.client.Logout(); err != nil { - return err + logrus.WithError(err).Error("Failed to logout IMAP client") + } + } + + for _, client := range t.smtpClients { + if err := client.client.Close(); err != nil { + logrus.WithError(err).Error("Failed to close SMTP client") } } @@ -262,5 +260,81 @@ func (t *testCtx) close(ctx context.Context) error { t.api.Close() + t.events.close() + return nil } + +type eventCollector struct { + events map[reflect.Type]*queue.QueuedChannel[events.Event] + lock sync.RWMutex + wg sync.WaitGroup +} + +func newEventCollector() *eventCollector { + return &eventCollector{ + events: make(map[reflect.Type]*queue.QueuedChannel[events.Event]), + } +} + +func (c *eventCollector) collectFrom(eventCh <-chan events.Event) { + c.wg.Add(1) + + go func() { + defer c.wg.Done() + + for event := range eventCh { + c.push(event) + } + }() +} + +func awaitType[T events.Event](c *eventCollector, ofType T, timeout time.Duration) (T, bool) { + if event := c.await(ofType, timeout); event == nil { + return *new(T), false //nolint:gocritic + } else if event, ok := event.(T); !ok { + panic(fmt.Errorf("unexpected event type %T", event)) + } else { + return event, true + } +} + +func (c *eventCollector) await(ofType events.Event, timeout time.Duration) events.Event { + select { + case event := <-c.getEventCh(ofType): + return event + + case <-time.After(timeout): + return nil + } +} + +func (c *eventCollector) push(event events.Event) { + c.lock.Lock() + defer c.lock.Unlock() + + if _, ok := c.events[reflect.TypeOf(event)]; !ok { + c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0) + } + + c.events[reflect.TypeOf(event)].Enqueue(event) +} + +func (c *eventCollector) getEventCh(ofType events.Event) <-chan events.Event { + c.lock.Lock() + defer c.lock.Unlock() + + if _, ok := c.events[reflect.TypeOf(ofType)]; !ok { + c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0) + } + + return c.events[reflect.TypeOf(ofType)].GetChannel() +} + +func (c *eventCollector) close() { + c.wg.Wait() + + for _, eventCh := range c.events { + eventCh.CloseAndDiscardQueued() + } +}