From 247e676b41d56a70d595bd4b398e41344f2e28a0 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Mon, 24 Oct 2022 22:08:58 +0200 Subject: [PATCH] Other: Fix race conditions in integration tests --- tests/ctx_test.go | 58 +++++++++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 6d18fae4..89f32040 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -67,10 +67,12 @@ type testCtx struct { smtpClients map[string]*smtpClient // calls holds calls made to the API during each step of the test. - calls [][]server.Call + calls [][]server.Call + callsLock sync.RWMutex // errors holds test-related errors encountered while running test steps. - errors [][]error + errors [][]error + errorsLock sync.RWMutex } type imapClient struct { @@ -86,7 +88,7 @@ type smtpClient struct { func newTestCtx(tb testing.TB) *testCtx { dir := tb.TempDir() - ctx := &testCtx{ + t := &testCtx{ dir: dir, api: newFakeAPI(), netCtl: liteapi.NewNetCtl(), @@ -105,14 +107,23 @@ func newTestCtx(tb testing.TB) *testCtx { smtpClients: make(map[string]*smtpClient), } - ctx.api.AddCallWatcher(func(call server.Call) { - ctx.calls[len(ctx.calls)-1] = append(ctx.calls[len(ctx.calls)-1], call) + t.api.AddCallWatcher(func(call server.Call) { + t.callsLock.Lock() + defer t.callsLock.Unlock() + + t.calls[len(t.calls)-1] = append(t.calls[len(t.calls)-1], call) }) - return ctx + return t } func (t *testCtx) beforeStep() { + t.callsLock.Lock() + defer t.callsLock.Unlock() + + t.errorsLock.Lock() + defer t.errorsLock.Unlock() + t.calls = append(t.calls, nil) t.errors = append(t.errors, nil) } @@ -205,38 +216,35 @@ func (t *testCtx) getMBoxID(userID string, name string) string { return labelID } -func (t *testCtx) getLastCall(method, path string) (server.Call, error) { - var allCalls []server.Call +func (t *testCtx) getLastCall(method, pathExp string) (server.Call, error) { + t.callsLock.RLock() + defer t.callsLock.RUnlock() - for _, calls := range t.calls { - allCalls = append(allCalls, calls...) + if matches := xslices.Filter(xslices.Join(t.calls...), func(call server.Call) bool { + return call.Method == method && regexp.MustCompile("^"+pathExp+"$").MatchString(call.URL.Path) + }); len(matches) > 0 { + return matches[len(matches)-1], nil } - if len(allCalls) == 0 { - return server.Call{}, fmt.Errorf("no calls made") - } - - for idx := len(allCalls) - 1; idx >= 0; idx-- { - if call := allCalls[idx]; call.Method == method && regexp.MustCompile("^"+path+"$").MatchString(call.URL.Path) { - return call, nil - } - } - - return server.Call{}, fmt.Errorf("no call with method %q and path %q was made", method, path) + return server.Call{}, fmt.Errorf("no call with method %q and path %q was made", method, pathExp) } func (t *testCtx) pushError(err error) { + t.errorsLock.Lock() + defer t.errorsLock.Unlock() + t.errors[len(t.errors)-1] = append(t.errors[len(t.errors)-1], err) } func (t *testCtx) getLastError() error { - errors := t.errors[len(t.errors)-2] + t.errorsLock.RLock() + defer t.errorsLock.RUnlock() - if len(errors) == 0 { - return nil + if lastStep := t.errors[len(t.errors)-2]; len(lastStep) > 0 { + return lastStep[len(lastStep)-1] } - return errors[len(errors)-1] + return nil } func (t *testCtx) close(ctx context.Context) error {