GODT-2181(test): Refactor integration test setup a bit

This commit is contained in:
James Houlahan
2022-12-12 13:49:34 +01:00
parent 49fa451cc3
commit 1aca2cde71
17 changed files with 285 additions and 243 deletions

View File

@ -27,10 +27,6 @@ type API interface {
GetHostURL() string
AddCallWatcher(func(server.Call), ...string)
CreateUser(username, address string, password []byte) (string, string, error)
CreateAddress(userID, address string, password []byte) (string, error)
RemoveAddress(userID, addrID string) error
RemoveAddressKey(userID, addrID, keyID string) error
Close()

View File

@ -41,21 +41,26 @@ func (s *scenario) close(_ testing.TB) {
}
func TestFeatures(testingT *testing.T) {
paths := []string{"features"}
if features := os.Getenv("FEATURES"); features != "" {
paths = strings.Split(features, " ")
}
var s scenario
suite := godog.TestSuite{
ScenarioInitializer: func(ctx *godog.ScenarioContext) {
var s scenario
TestSuiteInitializer: func(ctx *godog.TestSuiteContext) {
ctx.BeforeSuite(func() {
// Global setup.
})
ctx.Before(func(ctx context.Context, sc *godog.Scenario) (context.Context, error) {
ctx.AfterSuite(func() {
// Global teardown.
})
},
ScenarioInitializer: func(ctx *godog.ScenarioContext) {
ctx.Before(func(ctx context.Context, _ *godog.Scenario) (context.Context, error) {
s.reset(testingT)
return ctx, nil
})
ctx.After(func(ctx context.Context, sc *godog.Scenario, err error) (context.Context, error) {
ctx.After(func(ctx context.Context, _ *godog.Scenario, _ error) (context.Context, error) {
s.close(testingT)
return ctx, nil
})
@ -72,7 +77,7 @@ func TestFeatures(testingT *testing.T) {
return ctx, nil
})
ctx.StepContext().After(func(ctx context.Context, st *godog.Step, status godog.StepResultStatus, err error) (context.Context, error) {
ctx.StepContext().After(func(ctx context.Context, st *godog.Step, status godog.StepResultStatus, _ error) (context.Context, error) {
logrus.Debugf("Finished step (%v): %s", status, st.Text)
return ctx, nil
})
@ -197,7 +202,7 @@ func TestFeatures(testingT *testing.T) {
},
Options: &godog.Options{
Format: "pretty",
Paths: paths,
Paths: getFeaturePaths(),
TestingT: testingT,
},
}
@ -206,3 +211,15 @@ func TestFeatures(testingT *testing.T) {
testingT.Fatal("non-zero status returned, failed to run feature tests")
}
}
func getFeaturePaths() []string {
var paths []string
if features := os.Getenv("FEATURES"); features != "" {
paths = strings.Split(features, " ")
} else {
paths = []string{"features"}
}
return paths
}

View File

@ -30,7 +30,6 @@ import (
"time"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/cookies"
@ -157,7 +156,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
persister,
useragent.New(),
t.mocks.TLSReporter,
proton.NewDialer(t.netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
t.netCtl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}),
t.mocks.ProxyCtl,
t.mocks.CrashHandler,
t.reporter,

View File

@ -21,39 +21,54 @@ import (
"context"
"fmt"
"runtime"
"sync"
"testing"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func (t *testCtx) withClient(ctx context.Context, username string, fn func(context.Context, *proton.Client) error) error {
c, _, err := proton.New(
// withProton executes the given function with a proton manager configured to use the test API.
func (t *testCtx) withProton(fn func(*proton.Manager) error) error {
m := proton.New(
proton.WithHostURL(t.api.GetHostURL()),
proton.WithTransport(proton.InsecureTransport()),
).NewClientWithLogin(ctx, username, []byte(t.getUserPass(t.getUserID(username))))
if err != nil {
return err
}
)
defer m.Close()
defer c.Close()
return fn(m)
}
if err := fn(ctx, c); err != nil {
return fmt.Errorf("failed to execute with client: %w", err)
}
// withClient executes the given function with a client that is logged in as the given (known) user.
func (t *testCtx) withClient(ctx context.Context, username string, fn func(context.Context, *proton.Client) error) error {
return t.withClientPass(ctx, username, t.getUserPass(t.getUserID(username)), fn)
}
if err := c.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
}
// withClient executes the given function with a client that is logged in with the given username and password.
func (t *testCtx) withClientPass(ctx context.Context, username, password string, fn func(context.Context, *proton.Client) error) error {
return t.withProton(func(m *proton.Manager) error {
c, _, err := m.NewClientWithLogin(ctx, username, []byte(password))
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
}
defer c.Close()
return nil
if err := fn(ctx, c); err != nil {
return fmt.Errorf("failed to execute with client: %w", err)
}
if err := c.AuthDelete(ctx); err != nil {
return fmt.Errorf("failed to delete auth: %w", err)
}
return nil
})
}
// runQuarkCmd runs the given quark command with the given arguments.
func (t *testCtx) runQuarkCmd(ctx context.Context, command string, args ...string) error {
return t.withProton(func(m *proton.Manager) error {
return m.Quark(ctx, command, args...)
})
}
func (t *testCtx) withAddrKR(
@ -107,123 +122,3 @@ func (t *testCtx) createMessages(ctx context.Context, username, addrID string, r
})
})
}
type reportRecord struct {
isException bool
message string
context reporter.Context
}
type reportRecorder struct {
assert *assert.Assertions
reports []reportRecord
lock sync.Locker
isClosed bool
}
func newReportRecorder(tb testing.TB) *reportRecorder {
return &reportRecorder{
assert: assert.New(tb),
reports: []reportRecord{},
lock: &sync.Mutex{},
isClosed: false,
}
}
func (r *reportRecorder) add(isException bool, message string, context reporter.Context) {
r.lock.Lock()
defer r.lock.Unlock()
l := logrus.WithFields(logrus.Fields{
"isException": isException,
"message": message,
"context": context,
"pkg": "test/reportRecorder",
})
if r.isClosed {
l.Warn("Reporter closed, report skipped")
return
}
r.reports = append(r.reports, reportRecord{
isException: isException,
message: message,
context: context,
})
l.Warn("Report recorded")
}
func (r *reportRecorder) close() {
r.lock.Lock()
defer r.lock.Unlock()
r.isClosed = true
}
func (r *reportRecorder) assertEmpty() {
r.assert.Empty(r.reports)
}
func (r *reportRecorder) removeMatchingRecords(isException, message, context gomock.Matcher, n int) {
if n == 0 {
n = len(r.reports)
}
r.reports = xslices.Filter(r.reports, func(rec reportRecord) bool {
if n <= 0 {
return true
}
l := logrus.WithFields(logrus.Fields{
"rec": rec,
})
if !isException.Matches(rec.isException) {
l.WithField("matcher", isException).Debug("Not matching")
return true
}
if !message.Matches(rec.message) {
l.WithField("matcher", message).Debug("Not matching")
return true
}
if !context.Matches(rec.context) {
l.WithField("matcher", context).Debug("Not matching")
return true
}
n--
return false
})
}
func (r *reportRecorder) ReportException(data any) error {
r.add(true, "exception", reporter.Context{"data": data})
return nil
}
func (r *reportRecorder) ReportMessage(message string) error {
r.add(false, message, reporter.Context{})
return nil
}
func (r *reportRecorder) ReportMessageWithContext(message string, context reporter.Context) error {
r.add(false, message, context)
return nil
}
func (r *reportRecorder) ReportExceptionWithContext(data any, context reporter.Context) error {
if context == nil {
context = reporter.Context{}
}
context["data"] = data
r.add(true, "exception", context)
return nil
}

132
tests/ctx_reporter_test.go Normal file
View File

@ -0,0 +1,132 @@
package tests
import (
"sync"
"testing"
"github.com/ProtonMail/gluon/reporter"
"github.com/bradenaw/juniper/xslices"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
type reportRecord struct {
isException bool
message string
context reporter.Context
}
type reportRecorder struct {
assert *assert.Assertions
reports []reportRecord
lock sync.Locker
isClosed bool
}
func newReportRecorder(tb testing.TB) *reportRecorder {
return &reportRecorder{
assert: assert.New(tb),
reports: []reportRecord{},
lock: &sync.Mutex{},
isClosed: false,
}
}
func (r *reportRecorder) add(isException bool, message string, context reporter.Context) {
r.lock.Lock()
defer r.lock.Unlock()
l := logrus.WithFields(logrus.Fields{
"isException": isException,
"message": message,
"context": context,
"pkg": "test/reportRecorder",
})
if r.isClosed {
l.Warn("Reporter closed, report skipped")
return
}
r.reports = append(r.reports, reportRecord{
isException: isException,
message: message,
context: context,
})
l.Warn("Report recorded")
}
func (r *reportRecorder) close() {
r.lock.Lock()
defer r.lock.Unlock()
r.isClosed = true
}
func (r *reportRecorder) assertEmpty() {
r.assert.Empty(r.reports)
}
func (r *reportRecorder) removeMatchingRecords(isException, message, context gomock.Matcher, n int) {
if n == 0 {
n = len(r.reports)
}
r.reports = xslices.Filter(r.reports, func(rec reportRecord) bool {
if n <= 0 {
return true
}
l := logrus.WithFields(logrus.Fields{
"rec": rec,
})
if !isException.Matches(rec.isException) {
l.WithField("matcher", isException).Debug("Not matching")
return true
}
if !message.Matches(rec.message) {
l.WithField("matcher", message).Debug("Not matching")
return true
}
if !context.Matches(rec.context) {
l.WithField("matcher", context).Debug("Not matching")
return true
}
n--
return false
})
}
func (r *reportRecorder) ReportException(data any) error {
r.add(true, "exception", reporter.Context{"data": data})
return nil
}
func (r *reportRecorder) ReportMessage(message string) error {
r.add(false, message, reporter.Context{})
return nil
}
func (r *reportRecorder) ReportMessageWithContext(message string, context reporter.Context) error {
r.add(false, message, context)
return nil
}
func (r *reportRecorder) ReportExceptionWithContext(data any, context reporter.Context) error {
if context == nil {
context = reporter.Context{}
}
context["data"] = data
r.add(true, "exception", context)
return nil
}

View File

@ -34,7 +34,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-imap/client"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"google.golang.org/grpc"
@ -317,18 +316,7 @@ func (t *testCtx) close(ctx context.Context) {
}
t.api.Close()
t.events.close()
t.reporter.close()
// Closed connection can happen in the end of scenario
t.reporter.removeMatchingRecords(
gomock.Eq(false),
gomock.Eq("Failed to parse imap command"),
gomock.Any(), // mocks.NewClosedConnectionMatcher(),
0,
)
t.reporter.assertEmpty()
}

View File

@ -20,35 +20,19 @@ package tests
import (
"crypto/x509"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/certs"
)
var (
preCompPGPKey *crypto.Key
preCompCertPEM []byte
preCompKeyPEM []byte
)
func FastGenerateKey(name, email string, passphrase []byte, keyType string, bits int) (string, error) {
encKey, err := preCompPGPKey.Lock(passphrase)
if err != nil {
return "", err
}
return encKey.Armor()
}
func FastGenerateCert(template *x509.Certificate) ([]byte, []byte, error) {
return preCompCertPEM, preCompKeyPEM, nil
}
func init() {
key, err := crypto.GenerateKey("name", "email", "rsa", 1024)
if err != nil {
panic(err)
}
template, err := certs.NewTLSTemplate()
if err != nil {
panic(err)
@ -59,7 +43,6 @@ func init() {
panic(err)
}
preCompPGPKey = key
preCompCertPEM = certPEM
preCompKeyPEM = keyPEM
}

View File

@ -27,7 +27,7 @@ import (
func init() {
// Use the fast key generation for tests.
backend.GenerateKey = FastGenerateKey
backend.GenerateKey = backend.FastGenerateKey
// Use the fast cert generation for tests.
certs.GenerateCert = FastGenerateCert

View File

@ -34,42 +34,80 @@ import (
)
func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error {
// Create the user.
userID, addrID, err := s.t.api.CreateUser(username, username, []byte(password))
if err != nil {
// Create the user and generate its default address (with keys).
if err := s.t.runQuarkCmd(
context.Background(),
"user:create",
"--name", username,
"--password", password,
"--gen-keys", "RSA2048",
); err != nil {
return err
}
// Set the ID of this user.
s.t.setUserID(username, userID)
return s.t.withClientPass(context.Background(), username, password, func(ctx context.Context, c *proton.Client) error {
user, err := c.GetUser(ctx)
if err != nil {
return err
}
// Set the password of this user.
s.t.setUserPass(userID, password)
addr, err := c.GetAddresses(ctx)
if err != nil {
return err
}
// Set the address of this user (right now just the same as the username, but let's stay flexible).
s.t.setUserAddr(userID, addrID, username)
// Set the ID of the user.
s.t.setUserID(username, user.ID)
return nil
// Set the password of the user.
s.t.setUserPass(user.ID, password)
// Set the address of the user.
s.t.setUserAddr(user.ID, addr[0].ID, addr[0].Email)
return nil
})
}
func (s *scenario) theAccountHasAdditionalAddress(username, address string) error {
userID := s.t.getUserID(username)
addrID, err := s.t.api.CreateAddress(userID, address, []byte(s.t.getUserPass(userID)))
if err != nil {
// Create the user's additional address.
if err := s.t.runQuarkCmd(
context.Background(),
"user:create:address",
userID,
s.t.getUserPass(userID),
address,
"--gen-keys", "RSA2048",
); err != nil {
return err
}
s.t.setUserAddr(userID, addrID, address)
return s.t.withClient(context.Background(), username, func(ctx context.Context, c *proton.Client) error {
addr, err := c.GetAddresses(ctx)
if err != nil {
return err
}
return nil
// Set the new address of the user.
s.t.setUserAddr(userID, addr[len(addr)-1].ID, address)
return nil
})
}
func (s *scenario) theAccountNoLongerHasAdditionalAddress(username, address string) error {
userID := s.t.getUserID(username)
addrID := s.t.getUserAddrID(userID, address)
if err := s.t.api.RemoveAddress(userID, addrID); err != nil {
if err := s.t.withClient(context.Background(), username, func(ctx context.Context, c *proton.Client) error {
if err := c.DisableAddress(ctx, addrID); err != nil {
return err
}
return c.DeleteAddress(ctx, addrID)
}); err != nil {
return err
}