Other: Simple gRPC client/server under test

This commit is contained in:
James Houlahan
2022-11-03 10:43:25 +01:00
parent d093488522
commit 78fc5ec458
12 changed files with 477 additions and 121 deletions

View File

@ -248,6 +248,10 @@ func (bridge *Bridge) SetUpdateChannel(channel updater.Channel) error {
return nil return nil
} }
func (bridge *Bridge) GetCurrentVersion() *semver.Version {
return bridge.curVersion
}
func (bridge *Bridge) GetLastVersion() *semver.Version { func (bridge *Bridge) GetLastVersion() *semver.Version {
return bridge.vault.GetLastVersion() return bridge.vault.GetLastVersion()
} }

View File

@ -22,15 +22,15 @@ import (
"os" "os"
) )
// config is a structure containing the service configuration data that are exchanged by the gRPC server and client. // Config is a structure containing the service configuration data that are exchanged by the gRPC server and client.
type config struct { type Config struct {
Port int `json:"port"` Port int `json:"port"`
Cert string `json:"cert"` Cert string `json:"cert"`
Token string `json:"token"` Token string `json:"token"`
} }
// save saves a gRPC service configuration to file. // save saves a gRPC service configuration to file.
func (s *config) save(path string) error { func (s *Config) save(path string) error {
// Another process may be waiting for this file to be available. In order to prevent this process to open // Another process may be waiting for this file to be available. In order to prevent this process to open
// the file while we are writing in it, we write it with a temp file name, then rename it. // the file while we are writing in it, we write it with a temp file name, then rename it.
tempPath := path + "_" tempPath := path + "_"
@ -41,7 +41,7 @@ func (s *config) save(path string) error {
return os.Rename(tempPath, path) return os.Rename(tempPath, path)
} }
func (s *config) _save(path string) error { func (s *Config) _save(path string) error {
f, err := os.Create(path) //nolint:errcheck,gosec f, err := os.Create(path) //nolint:errcheck,gosec
if err != nil { if err != nil {
return err return err
@ -53,7 +53,7 @@ func (s *config) _save(path string) error {
} }
// load loads a gRPC service configuration from file. // load loads a gRPC service configuration from file.
func (s *config) load(path string) error { func (s *Config) load(path string) error {
f, err := os.Open(path) //nolint:errcheck,gosec f, err := os.Open(path) //nolint:errcheck,gosec
if err != nil { if err != nil {
return err return err

View File

@ -32,7 +32,7 @@ const (
) )
func TestConfig(t *testing.T) { func TestConfig(t *testing.T) {
conf1 := config{ conf1 := Config{
Port: dummyPort, Port: dummyPort,
Cert: dummyCert, Cert: dummyCert,
Token: dummyToken, Token: dummyToken,
@ -43,7 +43,7 @@ func TestConfig(t *testing.T) {
tempFilePath := filepath.Join(tempDir, tempFileName) tempFilePath := filepath.Join(tempDir, tempFileName)
require.NoError(t, conf1.save(tempFilePath)) require.NoError(t, conf1.save(tempFilePath))
conf2 := config{} conf2 := Config{}
require.NoError(t, conf2.load(tempFilePath)) require.NoError(t, conf2.load(tempFilePath))
require.Equal(t, conf1, conf2) require.Equal(t, conf1, conf2)

View File

@ -31,12 +31,9 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/safe"
"github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/pkg/restarter"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
@ -64,8 +61,8 @@ type Service struct { // nolint:structcheck
eventQueue []*StreamEvent eventQueue []*StreamEvent
eventQueueMutex sync.Mutex eventQueueMutex sync.Mutex
panicHandler *crash.Handler panicHandler CrashHandler
restarter *restarter.Restarter restarter Restarter
bridge *bridge.Bridge bridge *bridge.Bridge
eventCh <-chan events.Event eventCh <-chan events.Event
@ -91,9 +88,9 @@ type Service struct { // nolint:structcheck
// //
// nolint:funlen // nolint:funlen
func NewService( func NewService(
panicHandler *crash.Handler, panicHandler CrashHandler,
restarter *restarter.Restarter, restarter Restarter,
locations *locations.Locations, locations Locator,
bridge *bridge.Bridge, bridge *bridge.Bridge,
eventCh <-chan events.Event, eventCh <-chan events.Event,
showOnStartup bool, showOnStartup bool,
@ -394,13 +391,13 @@ func newTLSConfig() (*tls.Config, []byte, error) {
}, certPEM, nil }, certPEM, nil
} }
func saveGRPCServerConfigFile(locations *locations.Locations, listener net.Listener, token string, certPEM []byte) (string, error) { func saveGRPCServerConfigFile(locations Locator, listener net.Listener, token string, certPEM []byte) (string, error) {
address, ok := listener.Addr().(*net.TCPAddr) address, ok := listener.Addr().(*net.TCPAddr)
if !ok { if !ok {
return "", fmt.Errorf("could not retrieve gRPC service listener address") return "", fmt.Errorf("could not retrieve gRPC service listener address")
} }
sc := config{ sc := Config{
Port: address.Port, Port: address.Port,
Cert: string(certPEM), Cert: string(certPEM),
Token: token, Token: token,

View File

@ -48,7 +48,7 @@ func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.
path := clientConfigPath.Value path := clientConfigPath.Value
logEntry := s.log.WithField("path", path) logEntry := s.log.WithField("path", path)
var clientConfig config var clientConfig Config
if err := clientConfig.load(path); err != nil { if err := clientConfig.load(path); err != nil {
logEntry.WithError(err).Error("Could not read gRPC client config file") logEntry.WithError(err).Error("Could not read gRPC client config file")
@ -234,7 +234,7 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.
func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("Version") s.log.Debug("Version")
return wrapperspb.String(constants.Version), nil return wrapperspb.String(s.bridge.GetCurrentVersion().Original()), nil
} }
func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {

View File

@ -0,0 +1,32 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package grpc
type CrashHandler interface {
HandlePanic()
}
type Restarter interface {
Set(restart, crash bool)
AddFlags(flags ...string)
Override(exe string)
}
type Locator interface {
ProvideSettingsPath() (string, error)
}

View File

@ -132,6 +132,9 @@ func TestFeatures(testingT *testing.T) {
ctx.Step(`^bridge sends an update not available event$`, s.bridgeSendsAnUpdateNotAvailableEvent) ctx.Step(`^bridge sends an update not available event$`, s.bridgeSendsAnUpdateNotAvailableEvent)
ctx.Step(`^bridge sends a forced update event$`, s.bridgeSendsAForcedUpdateEvent) ctx.Step(`^bridge sends a forced update event$`, s.bridgeSendsAForcedUpdateEvent)
// ==== FRONTEND ====
ctx.Step(`^frontend sees that bridge is version "([^"]*)"$`, s.frontendSeesThatBridgeIsVersion)
// ==== USER ==== // ==== USER ====
ctx.Step(`^the user logs in with username "([^"]*)" and password "([^"]*)"$`, s.userLogsInWithUsernameAndPassword) ctx.Step(`^the user logs in with username "([^"]*)" and password "([^"]*)"$`, s.userLogsInWithUsernameAndPassword)
ctx.Step(`^user "([^"]*)" logs out$`, s.userLogsOut) ctx.Step(`^user "([^"]*)" logs out$`, s.userLogsOut)

117
tests/collector_test.go Normal file
View File

@ -0,0 +1,117 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tests
import (
"fmt"
"reflect"
"sync"
"time"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
)
type eventCollector struct {
events map[reflect.Type]*queue.QueuedChannel[events.Event]
fwdCh []*queue.QueuedChannel[events.Event]
lock sync.Mutex
wg sync.WaitGroup
}
func newEventCollector() *eventCollector {
return &eventCollector{
events: make(map[reflect.Type]*queue.QueuedChannel[events.Event]),
}
}
func (c *eventCollector) collectFrom(eventCh <-chan events.Event) <-chan events.Event {
c.lock.Lock()
defer c.lock.Unlock()
fwdCh := queue.NewQueuedChannel[events.Event](0, 0)
c.fwdCh = append(c.fwdCh, fwdCh)
c.wg.Add(1)
go func() {
defer fwdCh.CloseAndDiscardQueued()
defer c.wg.Done()
for event := range eventCh {
c.push(event)
}
}()
return fwdCh.GetChannel()
}
func awaitType[T events.Event](c *eventCollector, ofType T, timeout time.Duration) (T, bool) {
if event := c.await(ofType, timeout); event == nil {
return *new(T), false //nolint:gocritic
} else if event, ok := event.(T); !ok {
panic(fmt.Errorf("unexpected event type %T", event))
} else {
return event, true
}
}
func (c *eventCollector) await(ofType events.Event, timeout time.Duration) events.Event {
select {
case event := <-c.getEventCh(ofType):
return event
case <-time.After(timeout):
return nil
}
}
func (c *eventCollector) push(event events.Event) {
c.lock.Lock()
defer c.lock.Unlock()
if _, ok := c.events[reflect.TypeOf(event)]; !ok {
c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0)
}
c.events[reflect.TypeOf(event)].Enqueue(event)
for _, eventCh := range c.fwdCh {
eventCh.Enqueue(event)
}
}
func (c *eventCollector) getEventCh(ofType events.Event) <-chan events.Event {
c.lock.Lock()
defer c.lock.Unlock()
if _, ok := c.events[reflect.TypeOf(ofType)]; !ok {
c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0)
}
return c.events[reflect.TypeOf(ofType)].GetChannel()
}
func (c *eventCollector) close() {
c.wg.Wait()
for _, eventCh := range c.events {
eventCh.CloseAndDiscardQueued()
}
}

View File

@ -20,64 +20,129 @@ package tests
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/json"
"fmt" "fmt"
"net/http/cookiejar" "net/http/cookiejar"
"os" "os"
"path/filepath"
"runtime"
"time" "time"
"github.com/sirupsen/logrus" "github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/cookies" "github.com/ProtonMail/proton-bridge/v2/internal/cookies"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
frontend "github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/emptypb"
) )
func (t *testCtx) startBridge() error { func (t *testCtx) startBridge() error {
logrus.Info("Starting bridge")
eventCh, err := t.initBridge()
if err != nil {
return fmt.Errorf("could not create bridge: %w", err)
}
logrus.Info("Starting frontend service")
if err := t.initFrontendService(eventCh); err != nil {
return fmt.Errorf("could not create frontend service: %w", err)
}
logrus.Info("Starting frontend client")
if err := t.initFrontendClient(); err != nil {
return fmt.Errorf("could not create frontend client: %w", err)
}
t.events.await(events.AllUsersLoaded{}, 30*time.Second)
return nil
}
func (t *testCtx) stopBridge() error {
if err := t.closeFrontendService(context.Background()); err != nil {
return fmt.Errorf("could not close frontend: %w", err)
}
if err := t.closeFrontendClient(); err != nil {
return fmt.Errorf("could not close frontend client: %w", err)
}
if err := t.closeBridge(context.Background()); err != nil {
return fmt.Errorf("could not close bridge: %w", err)
}
return nil
}
func (t *testCtx) initBridge() (<-chan events.Event, error) {
if t.bridge != nil {
return nil, fmt.Errorf("bridge is already started")
}
// Bridge will enable the proxy by default at startup. // Bridge will enable the proxy by default at startup.
t.mocks.ProxyCtl.EXPECT().AllowProxy() t.mocks.ProxyCtl.EXPECT().AllowProxy()
// Get the path to the vault. // Get the path to the vault.
vaultDir, err := t.locator.ProvideSettingsPath() vaultDir, err := t.locator.ProvideSettingsPath()
if err != nil { if err != nil {
return err return nil, fmt.Errorf("could not get vault dir: %w", err)
} }
// Get the default gluon path. // Get the default gluon path.
gluonDir, err := t.locator.ProvideGluonPath() gluonDir, err := t.locator.ProvideGluonPath()
if err != nil { if err != nil {
return err return nil, fmt.Errorf("could not get gluon dir: %w", err)
} }
// Create the vault. // Create the vault.
vault, corrupt, err := vault.New(vaultDir, gluonDir, t.storeKey) vault, corrupt, err := vault.New(vaultDir, gluonDir, t.storeKey)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("could not create vault: %w", err)
} else if corrupt { } else if corrupt {
return fmt.Errorf("vault is corrupt") return nil, fmt.Errorf("vault is corrupt")
} }
// Create the underlying cookie jar. // Create the underlying cookie jar.
jar, err := cookiejar.New(nil) jar, err := cookiejar.New(nil)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("could not create cookie jar: %w", err)
} }
// Create the persisting cookie jar. // Create the persisting cookie jar.
persister, err := cookies.NewCookieJar(jar, vault) persister, err := cookies.NewCookieJar(jar, vault)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("could not create cookie persister: %w", err)
} }
var logIMAP bool var (
logIMAP bool
logSMTP bool
)
if len(os.Getenv("FEATURE_TEST_LOG_IMAP")) != 0 { if len(os.Getenv("FEATURE_TEST_LOG_IMAP")) != 0 {
logrus.SetLevel(logrus.TraceLevel)
logIMAP = true logIMAP = true
} }
if len(os.Getenv("FEATURE_TEST_LOG_SMTP")) != 0 {
logSMTP = true
}
if logIMAP || logSMTP {
logrus.SetLevel(logrus.TraceLevel)
}
// Create the bridge. // Create the bridge.
bridge, eventCh, err := bridge.New( bridge, eventCh, err := bridge.New(
// App stuff // App stuff
@ -98,31 +163,178 @@ func (t *testCtx) startBridge() error {
// Logging stuff // Logging stuff
logIMAP, logIMAP,
logIMAP, logIMAP,
false, logSMTP,
) )
if err != nil { if err != nil {
return err return nil, fmt.Errorf("could not create bridge: %w", err)
} }
t.events.collectFrom(eventCh)
// Wait for the users to be loaded.
t.events.await(events.AllUsersLoaded{}, 30*time.Second)
// Save the bridge to the context.
t.bridge = bridge t.bridge = bridge
return nil return t.events.collectFrom(eventCh), nil
} }
func (t *testCtx) stopBridge() error { func (t *testCtx) closeBridge(ctx context.Context) error {
if t.bridge == nil { if t.bridge == nil {
return fmt.Errorf("bridge is not running") return fmt.Errorf("bridge is not started")
} }
t.bridge.Close(context.Background()) t.bridge.Close(ctx)
t.bridge = nil
return nil return nil
} }
func (t *testCtx) initFrontendService(eventCh <-chan events.Event) error {
if t.service != nil {
return fmt.Errorf("frontend service is already started")
}
// When starting the frontend, we might enable autostart on bridge if it isn't already.
t.mocks.Autostarter.EXPECT().Enable().AnyTimes()
service, err := frontend.NewService(
new(mockCrashHandler),
new(mockRestarter),
t.locator,
t.bridge,
eventCh,
true,
)
if err != nil {
return fmt.Errorf("could not create service: %w", err)
}
logrus.Info("Frontend service started")
t.service = service
t.serviceWG.Add(1)
go func() {
defer t.serviceWG.Done()
if err := service.Loop(); err != nil {
panic(err)
}
}()
return nil
}
func (t *testCtx) closeFrontendService(ctx context.Context) error {
if t.service == nil {
return fmt.Errorf("frontend service is not started")
}
if _, err := t.client.Quit(ctx, &emptypb.Empty{}); err != nil {
return fmt.Errorf("could not quit frontend: %w", err)
}
t.serviceWG.Wait()
logrus.Info("Frontend service stopped")
t.service = nil
return nil
}
func (t *testCtx) initFrontendClient() error {
if t.client != nil {
return fmt.Errorf("frontend client is already started")
}
settings, err := t.locator.ProvideSettingsPath()
if err != nil {
return fmt.Errorf("could not get settings path: %w", err)
}
b, err := os.ReadFile(filepath.Join(settings, "grpcServerConfig.json"))
if err != nil {
return fmt.Errorf("could not read grpcServerConfig.json: %w", err)
}
var cfg frontend.Config
if err := json.Unmarshal(b, &cfg); err != nil {
return fmt.Errorf("could not unmarshal grpcServerConfig.json: %w", err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM([]byte(cfg.Cert)) {
return fmt.Errorf("failed to append certificates to pool")
}
conn, err := grpc.DialContext(
context.Background(),
fmt.Sprintf("%v:%d", constants.Host, cfg.Port),
grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: cp})),
grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return invoker(metadata.AppendToOutgoingContext(ctx, "server-token", cfg.Token), method, req, reply, cc, opts...)
}),
)
if err != nil {
return fmt.Errorf("could not dial grpc server: %w", err)
}
client := frontend.NewBridgeClient(conn)
stream, err := client.RunEventStream(context.Background(), &frontend.EventStreamRequest{ClientPlatform: runtime.GOOS})
if err != nil {
return fmt.Errorf("could not start event stream: %w", err)
}
eventCh := queue.NewQueuedChannel[*frontend.StreamEvent](0, 0)
go func() {
defer eventCh.CloseAndDiscardQueued()
for {
event, err := stream.Recv()
if err != nil {
return
}
eventCh.Enqueue(event)
}
}()
logrus.Info("Frontend client started")
t.client = client
t.clientConn = conn
t.clientEventCh = eventCh
return nil
}
func (t *testCtx) closeFrontendClient() error {
if t.client == nil {
return fmt.Errorf("frontend client is not started")
}
if err := t.clientConn.Close(); err != nil {
return fmt.Errorf("could not close frontend client connection: %w", err)
}
logrus.Info("Frontend client stopped")
t.client = nil
t.clientConn = nil
t.clientEventCh = nil
return nil
}
type mockCrashHandler struct{}
func (m *mockCrashHandler) HandlePanic() {}
type mockRestarter struct{}
func (m *mockRestarter) Set(restart, crash bool) {}
func (m *mockRestarter) AddFlags(flags ...string) {}
func (m *mockRestarter) Override(exe string) {}

View File

@ -21,16 +21,14 @@ import (
"context" "context"
"fmt" "fmt"
"net/smtp" "net/smtp"
"reflect"
"regexp" "regexp"
"sync" "sync"
"testing" "testing"
"time"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events" frontend "github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
"github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-imap/client" "github.com/emersion/go-imap/client"
@ -38,6 +36,7 @@ import (
"gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"google.golang.org/grpc"
) )
var defaultVersion = semver.MustParse("1.0.0") var defaultVersion = semver.MustParse("1.0.0")
@ -56,6 +55,15 @@ type testCtx struct {
// bridge holds the bridge app under test. // bridge holds the bridge app under test.
bridge *bridge.Bridge bridge *bridge.Bridge
// service holds the gRPC frontend service under test.
service *frontend.Service
serviceWG sync.WaitGroup
// client holds the gRPC frontend client under test.
client frontend.BridgeClient
clientConn *grpc.ClientConn
clientEventCh *queue.QueuedChannel[*frontend.StreamEvent]
// These maps hold expected userIDByName, their primary addresses and bridge passwords. // These maps hold expected userIDByName, their primary addresses and bridge passwords.
userIDByName map[string]string userIDByName map[string]string
userAddrByEmail map[string]map[string]string userAddrByEmail map[string]map[string]string
@ -295,85 +303,25 @@ func (t *testCtx) close(ctx context.Context) {
} }
} }
if t.service != nil {
if err := t.closeFrontendService(ctx); err != nil {
logrus.WithError(err).Error("Failed to close frontend service")
}
}
if t.client != nil {
if err := t.closeFrontendClient(); err != nil {
logrus.WithError(err).Error("Failed to close frontend client")
}
}
if t.bridge != nil { if t.bridge != nil {
t.bridge.Close(ctx) if err := t.closeBridge(ctx); err != nil {
logrus.WithError(err).Error("Failed to close bridge")
}
} }
t.api.Close() t.api.Close()
t.events.close() t.events.close()
} }
type eventCollector struct {
events map[reflect.Type]*queue.QueuedChannel[events.Event]
lock sync.RWMutex
wg sync.WaitGroup
}
func newEventCollector() *eventCollector {
return &eventCollector{
events: make(map[reflect.Type]*queue.QueuedChannel[events.Event]),
}
}
func (c *eventCollector) collectFrom(eventCh <-chan events.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for event := range eventCh {
c.push(event)
}
}()
}
func awaitType[T events.Event](c *eventCollector, ofType T, timeout time.Duration) (T, bool) {
if event := c.await(ofType, timeout); event == nil {
return *new(T), false //nolint:gocritic
} else if event, ok := event.(T); !ok {
panic(fmt.Errorf("unexpected event type %T", event))
} else {
return event, true
}
}
func (c *eventCollector) await(ofType events.Event, timeout time.Duration) events.Event {
select {
case event := <-c.getEventCh(ofType):
return event
case <-time.After(timeout):
return nil
}
}
func (c *eventCollector) push(event events.Event) {
c.lock.Lock()
defer c.lock.Unlock()
if _, ok := c.events[reflect.TypeOf(event)]; !ok {
c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0)
}
c.events[reflect.TypeOf(event)].Enqueue(event)
}
func (c *eventCollector) getEventCh(ofType events.Event) <-chan events.Event {
c.lock.Lock()
defer c.lock.Unlock()
if _, ok := c.events[reflect.TypeOf(ofType)]; !ok {
c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0)
}
return c.events[reflect.TypeOf(ofType)].GetChannel()
}
func (c *eventCollector) close() {
c.wg.Wait()
for _, eventCh := range c.events {
eventCh.CloseAndDiscardQueued()
}
}

View File

@ -0,0 +1,5 @@
Feature: Frontend events
Scenario: Frontend starts and stops
Given bridge is version "2.3.0" and the latest available version is "2.3.0" reachable from "2.3.0"
When bridge starts
Then frontend sees that bridge is version "2.3.0"

38
tests/frontend_test.go Normal file
View File

@ -0,0 +1,38 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tests
import (
"context"
"fmt"
"google.golang.org/protobuf/types/known/emptypb"
)
func (s *scenario) frontendSeesThatBridgeIsVersion(version string) error {
res, err := s.t.client.Version(context.Background(), &emptypb.Empty{})
if err != nil {
return err
}
if version != res.GetValue() {
return fmt.Errorf("expected version %s, got %s", version, res.GetValue())
}
return nil
}