mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
Other: Fix goroutine leaks in integration tests
We were closing the event QueuedChannel objects in the wrong place; they should have been closed on test teardown, not on stopBridge (which was just a test action and wasn't always called). In order to make the events more scalable, i replace all the QueuedChannel objects with a single event collector, which would create QueuedChannels on demand when it receives an event of a new type.
This commit is contained in:
@ -25,7 +25,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
)
|
||||
@ -125,146 +124,143 @@ func (s *scenario) theUserReportsABug() error {
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAConnectionUpEvent() error {
|
||||
return try(s.t.connStatusCh, 5*time.Second, func(event events.Event) error {
|
||||
if event, ok := event.(events.ConnStatusUp); !ok {
|
||||
return fmt.Errorf("expected connection up event, got %T", event)
|
||||
}
|
||||
if event := s.t.events.await(events.ConnStatusUp{}, 5*time.Second); event == nil {
|
||||
return errors.New("expected connection up event, got none")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAConnectionDownEvent() error {
|
||||
return try(s.t.connStatusCh, 5*time.Second, func(event events.Event) error {
|
||||
if event, ok := event.(events.ConnStatusDown); !ok {
|
||||
return fmt.Errorf("expected connection down event, got %T", event)
|
||||
}
|
||||
if event := s.t.events.await(events.ConnStatusDown{}, 5*time.Second); event == nil {
|
||||
return errors.New("expected connection down event, got none")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsADeauthEventForUser(username string) error {
|
||||
return try(s.t.deauthCh, 5*time.Second, func(event events.UserDeauth) error {
|
||||
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
|
||||
return fmt.Errorf("expected deauth event for user with ID %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
event, ok := awaitType(s.t.events, events.UserDeauth{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected deauth event, got none")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if wantUserID := s.t.getUserID(username); event.UserID != wantUserID {
|
||||
return fmt.Errorf("expected deauth event for user %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAnAddressCreatedEventForUser(username string) error {
|
||||
return try(s.t.addrCreatedCh, 60*time.Second, func(event events.UserAddressCreated) error {
|
||||
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
|
||||
return fmt.Errorf("expected user address created event for user with ID %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
event, ok := awaitType(s.t.events, events.UserAddressCreated{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected address created event, got none")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if wantUserID := s.t.getUserID(username); event.UserID != wantUserID {
|
||||
return fmt.Errorf("expected address created event for user %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAnAddressDeletedEventForUser(username string) error {
|
||||
return try(s.t.addrDeletedCh, 60*time.Second, func(event events.UserAddressDeleted) error {
|
||||
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
|
||||
return fmt.Errorf("expected user address deleted event for user with ID %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
event, ok := awaitType(s.t.events, events.UserAddressDeleted{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected address deleted event, got none")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if wantUserID := s.t.getUserID(username); event.UserID != wantUserID {
|
||||
return fmt.Errorf("expected address deleted event for user %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsSyncStartedAndFinishedEventsForUser(username string) error {
|
||||
if err := get(s.t.syncStartedCh, func(event events.SyncStarted) error {
|
||||
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
|
||||
return fmt.Errorf("expected sync started event for user with ID %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to get sync started event: %w", err)
|
||||
startEvent, ok := awaitType(s.t.events, events.SyncStarted{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected sync started event, got none")
|
||||
}
|
||||
|
||||
if err := get(s.t.syncFinishedCh, func(event events.SyncFinished) error {
|
||||
if wantUserID := s.t.getUserID(username); wantUserID != event.UserID {
|
||||
return fmt.Errorf("expected sync finished event for user with ID %s, got %s", wantUserID, event.UserID)
|
||||
}
|
||||
if wantUserID := s.t.getUserID(username); startEvent.UserID != wantUserID {
|
||||
return fmt.Errorf("expected sync started event for user %s, got %s", wantUserID, startEvent.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to get sync finished event: %w", err)
|
||||
finishEvent, ok := awaitType(s.t.events, events.SyncFinished{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected sync finished event, got none")
|
||||
}
|
||||
|
||||
if wantUserID := s.t.getUserID(username); finishEvent.UserID != wantUserID {
|
||||
return fmt.Errorf("expected sync finished event for user %s, got %s", wantUserID, finishEvent.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAnUpdateNotAvailableEvent() error {
|
||||
return try(s.t.updateCh, 5*time.Second, func(event events.Event) error {
|
||||
if event, ok := event.(events.UpdateNotAvailable); !ok {
|
||||
return fmt.Errorf("expected update not available event, got %T", event)
|
||||
}
|
||||
if event := s.t.events.await(events.UpdateNotAvailable{}, 5*time.Second); event == nil {
|
||||
return errors.New("expected update not available event, got none")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAnUpdateAvailableEventForVersion(version string) error {
|
||||
return try(s.t.updateCh, 5*time.Second, func(event events.Event) error {
|
||||
updateEvent, ok := event.(events.UpdateAvailable)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected update available event, got %T", event)
|
||||
}
|
||||
event, ok := awaitType(s.t.events, events.UpdateAvailable{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected update available event, got none")
|
||||
}
|
||||
|
||||
if !updateEvent.CanInstall {
|
||||
return errors.New("expected update event to be installable")
|
||||
}
|
||||
if !event.CanInstall {
|
||||
return errors.New("expected update event to be installable")
|
||||
}
|
||||
|
||||
if !updateEvent.Version.Version.Equal(semver.MustParse(version)) {
|
||||
return fmt.Errorf("expected update event for version %s, got %s", version, updateEvent.Version.Version)
|
||||
}
|
||||
if !event.Version.Version.Equal(semver.MustParse(version)) {
|
||||
return fmt.Errorf("expected update event for version %s, got %s", version, event.Version.Version)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAManualUpdateEventForVersion(version string) error {
|
||||
return try(s.t.updateCh, 5*time.Second, func(event events.Event) error {
|
||||
updateEvent, ok := event.(events.UpdateAvailable)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected manual update event, got %T", event)
|
||||
}
|
||||
event, ok := awaitType(s.t.events, events.UpdateAvailable{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected update available event, got none")
|
||||
}
|
||||
|
||||
if updateEvent.CanInstall {
|
||||
return errors.New("expected update event to not be installable")
|
||||
}
|
||||
if event.CanInstall {
|
||||
return errors.New("expected update event to not be installable")
|
||||
}
|
||||
|
||||
if !updateEvent.Version.Version.Equal(semver.MustParse(version)) {
|
||||
return fmt.Errorf("expected update event for version %s, got %s", version, updateEvent.Version.Version)
|
||||
}
|
||||
if !event.Version.Version.Equal(semver.MustParse(version)) {
|
||||
return fmt.Errorf("expected update event for version %s, got %s", version, event.Version.Version)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAnUpdateInstalledEventForVersion(version string) error {
|
||||
return try(s.t.updateCh, 5*time.Second, func(event events.Event) error {
|
||||
updateEvent, ok := event.(events.UpdateInstalled)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected update installed event, got %T", event)
|
||||
}
|
||||
event, ok := awaitType(s.t.events, events.UpdateInstalled{}, 5*time.Second)
|
||||
if !ok {
|
||||
return errors.New("expected update installed event, got none")
|
||||
}
|
||||
|
||||
if !updateEvent.Version.Version.Equal(semver.MustParse(version)) {
|
||||
return fmt.Errorf("expected update event for version %s, got %s", version, updateEvent.Version.Version)
|
||||
}
|
||||
if !event.Version.Version.Equal(semver.MustParse(version)) {
|
||||
return fmt.Errorf("expected update installed event for version %s, got %s", version, event.Version.Version)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeSendsAForcedUpdateEvent() error {
|
||||
return try(s.t.forcedUpdateCh, 5*time.Second, func(event events.UpdateForced) error {
|
||||
return nil
|
||||
})
|
||||
if event := s.t.events.await(events.UpdateForced{}, 5*time.Second); event == nil {
|
||||
return errors.New("expected update forced event, got none")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scenario) bridgeHidesAllMail() error {
|
||||
@ -274,17 +270,3 @@ func (s *scenario) bridgeHidesAllMail() error {
|
||||
func (s *scenario) bridgeShowsAllMail() error {
|
||||
return s.t.bridge.SetShowAllMail(true)
|
||||
}
|
||||
|
||||
func try[T any](inCh *queue.QueuedChannel[T], wait time.Duration, fn func(T) error) error {
|
||||
select {
|
||||
case event := <-inCh.GetChannel():
|
||||
return fn(event)
|
||||
|
||||
case <-time.After(wait):
|
||||
return errors.New("timeout waiting for event")
|
||||
}
|
||||
}
|
||||
|
||||
func get[T any](inCh *queue.QueuedChannel[T], fn func(T) error) error {
|
||||
return fn(<-inCh.GetChannel())
|
||||
}
|
||||
|
||||
@ -22,8 +22,8 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http/cookiejar"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
@ -94,60 +94,10 @@ func (t *testCtx) startBridge() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the event channels for use in the test.
|
||||
t.loginCh = queue.NewQueuedChannel[events.UserLoggedIn](0, 0)
|
||||
t.logoutCh = queue.NewQueuedChannel[events.UserLoggedOut](0, 0)
|
||||
t.loadedCh = queue.NewQueuedChannel[events.AllUsersLoaded](0, 0)
|
||||
t.deletedCh = queue.NewQueuedChannel[events.UserDeleted](0, 0)
|
||||
t.deauthCh = queue.NewQueuedChannel[events.UserDeauth](0, 0)
|
||||
t.addrCreatedCh = queue.NewQueuedChannel[events.UserAddressCreated](0, 0)
|
||||
t.addrDeletedCh = queue.NewQueuedChannel[events.UserAddressDeleted](0, 0)
|
||||
t.syncStartedCh = queue.NewQueuedChannel[events.SyncStarted](0, 0)
|
||||
t.syncFinishedCh = queue.NewQueuedChannel[events.SyncFinished](0, 0)
|
||||
t.forcedUpdateCh = queue.NewQueuedChannel[events.UpdateForced](0, 0)
|
||||
t.connStatusCh = queue.NewQueuedChannel[events.Event](0, 0)
|
||||
t.updateCh = queue.NewQueuedChannel[events.Event](0, 0)
|
||||
|
||||
// Push the updates to the appropriate channels.
|
||||
go func() {
|
||||
for event := range eventCh {
|
||||
switch event := event.(type) {
|
||||
case events.UserLoggedIn:
|
||||
t.loginCh.Enqueue(event)
|
||||
case events.UserLoggedOut:
|
||||
t.logoutCh.Enqueue(event)
|
||||
case events.AllUsersLoaded:
|
||||
t.loadedCh.Enqueue(event)
|
||||
case events.UserDeleted:
|
||||
t.deletedCh.Enqueue(event)
|
||||
case events.UserDeauth:
|
||||
t.deauthCh.Enqueue(event)
|
||||
case events.UserAddressCreated:
|
||||
t.addrCreatedCh.Enqueue(event)
|
||||
case events.UserAddressDeleted:
|
||||
t.addrDeletedCh.Enqueue(event)
|
||||
case events.SyncStarted:
|
||||
t.syncStartedCh.Enqueue(event)
|
||||
case events.SyncFinished:
|
||||
t.syncFinishedCh.Enqueue(event)
|
||||
case events.ConnStatusUp:
|
||||
t.connStatusCh.Enqueue(event)
|
||||
case events.ConnStatusDown:
|
||||
t.connStatusCh.Enqueue(event)
|
||||
case events.UpdateAvailable:
|
||||
t.updateCh.Enqueue(event)
|
||||
case events.UpdateNotAvailable:
|
||||
t.updateCh.Enqueue(event)
|
||||
case events.UpdateInstalled:
|
||||
t.updateCh.Enqueue(event)
|
||||
case events.UpdateForced:
|
||||
t.forcedUpdateCh.Enqueue(event)
|
||||
}
|
||||
}
|
||||
}()
|
||||
t.events.collectFrom(eventCh)
|
||||
|
||||
// Wait for the users to be loaded.
|
||||
<-t.loadedCh.GetChannel()
|
||||
t.events.await(events.AllUsersLoaded{}, 10*time.Second)
|
||||
|
||||
// Save the bridge to the context.
|
||||
t.bridge = bridge
|
||||
@ -161,18 +111,6 @@ func (t *testCtx) stopBridge() error {
|
||||
}
|
||||
|
||||
t.bridge = nil
|
||||
t.loginCh.CloseAndDiscardQueued()
|
||||
t.logoutCh.CloseAndDiscardQueued()
|
||||
t.loadedCh.CloseAndDiscardQueued()
|
||||
t.deletedCh.CloseAndDiscardQueued()
|
||||
t.deauthCh.CloseAndDiscardQueued()
|
||||
t.addrCreatedCh.CloseAndDiscardQueued()
|
||||
t.addrDeletedCh.CloseAndDiscardQueued()
|
||||
t.syncStartedCh.CloseAndDiscardQueued()
|
||||
t.syncFinishedCh.CloseAndDiscardQueued()
|
||||
t.forcedUpdateCh.CloseAndDiscardQueued()
|
||||
t.connStatusCh.CloseAndDiscardQueued()
|
||||
t.updateCh.CloseAndDiscardQueued()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -21,8 +21,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
@ -31,6 +34,7 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-imap/client"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
"golang.org/x/exp/maps"
|
||||
@ -47,24 +51,11 @@ type testCtx struct {
|
||||
storeKey []byte
|
||||
version *semver.Version
|
||||
mocks *bridge.Mocks
|
||||
events *eventCollector
|
||||
|
||||
// bridge holds the bridge app under test.
|
||||
bridge *bridge.Bridge
|
||||
|
||||
// These channels hold events of various types coming from bridge.
|
||||
loginCh *queue.QueuedChannel[events.UserLoggedIn]
|
||||
logoutCh *queue.QueuedChannel[events.UserLoggedOut]
|
||||
loadedCh *queue.QueuedChannel[events.AllUsersLoaded]
|
||||
deletedCh *queue.QueuedChannel[events.UserDeleted]
|
||||
deauthCh *queue.QueuedChannel[events.UserDeauth]
|
||||
addrCreatedCh *queue.QueuedChannel[events.UserAddressCreated]
|
||||
addrDeletedCh *queue.QueuedChannel[events.UserAddressDeleted]
|
||||
syncStartedCh *queue.QueuedChannel[events.SyncStarted]
|
||||
syncFinishedCh *queue.QueuedChannel[events.SyncFinished]
|
||||
forcedUpdateCh *queue.QueuedChannel[events.UpdateForced]
|
||||
connStatusCh *queue.QueuedChannel[events.Event]
|
||||
updateCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
// These maps hold expected userIDByName, their primary addresses and bridge passwords.
|
||||
userIDByName map[string]string
|
||||
userAddrByEmail map[string]map[string]string
|
||||
@ -101,8 +92,9 @@ func newTestCtx(tb testing.TB) *testCtx {
|
||||
netCtl: liteapi.NewNetCtl(),
|
||||
locator: locations.New(bridge.NewTestLocationsProvider(dir), "config-name"),
|
||||
storeKey: []byte("super-secret-store-key"),
|
||||
mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion),
|
||||
version: defaultVersion,
|
||||
mocks: bridge.NewMocks(tb, defaultVersion, defaultVersion),
|
||||
events: newEventCollector(),
|
||||
|
||||
userIDByName: make(map[string]string),
|
||||
userAddrByEmail: make(map[string]map[string]string),
|
||||
@ -250,7 +242,13 @@ func (t *testCtx) getLastError() error {
|
||||
func (t *testCtx) close(ctx context.Context) error {
|
||||
for _, client := range t.imapClients {
|
||||
if err := client.client.Logout(); err != nil {
|
||||
return err
|
||||
logrus.WithError(err).Error("Failed to logout IMAP client")
|
||||
}
|
||||
}
|
||||
|
||||
for _, client := range t.smtpClients {
|
||||
if err := client.client.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close SMTP client")
|
||||
}
|
||||
}
|
||||
|
||||
@ -262,5 +260,81 @@ func (t *testCtx) close(ctx context.Context) error {
|
||||
|
||||
t.api.Close()
|
||||
|
||||
t.events.close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type eventCollector struct {
|
||||
events map[reflect.Type]*queue.QueuedChannel[events.Event]
|
||||
lock sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newEventCollector() *eventCollector {
|
||||
return &eventCollector{
|
||||
events: make(map[reflect.Type]*queue.QueuedChannel[events.Event]),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *eventCollector) collectFrom(eventCh <-chan events.Event) {
|
||||
c.wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
|
||||
for event := range eventCh {
|
||||
c.push(event)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func awaitType[T events.Event](c *eventCollector, ofType T, timeout time.Duration) (T, bool) {
|
||||
if event := c.await(ofType, timeout); event == nil {
|
||||
return *new(T), false //nolint:gocritic
|
||||
} else if event, ok := event.(T); !ok {
|
||||
panic(fmt.Errorf("unexpected event type %T", event))
|
||||
} else {
|
||||
return event, true
|
||||
}
|
||||
}
|
||||
|
||||
func (c *eventCollector) await(ofType events.Event, timeout time.Duration) events.Event {
|
||||
select {
|
||||
case event := <-c.getEventCh(ofType):
|
||||
return event
|
||||
|
||||
case <-time.After(timeout):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *eventCollector) push(event events.Event) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if _, ok := c.events[reflect.TypeOf(event)]; !ok {
|
||||
c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0)
|
||||
}
|
||||
|
||||
c.events[reflect.TypeOf(event)].Enqueue(event)
|
||||
}
|
||||
|
||||
func (c *eventCollector) getEventCh(ofType events.Event) <-chan events.Event {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if _, ok := c.events[reflect.TypeOf(ofType)]; !ok {
|
||||
c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0)
|
||||
}
|
||||
|
||||
return c.events[reflect.TypeOf(ofType)].GetChannel()
|
||||
}
|
||||
|
||||
func (c *eventCollector) close() {
|
||||
c.wg.Wait()
|
||||
|
||||
for _, eventCh := range c.events {
|
||||
eventCh.CloseAndDiscardQueued()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user