Other: Fix race conditions in integration tests

This commit is contained in:
James Houlahan
2022-10-24 22:08:58 +02:00
parent cb04dabea8
commit 247e676b41

View File

@ -67,10 +67,12 @@ type testCtx struct {
smtpClients map[string]*smtpClient smtpClients map[string]*smtpClient
// calls holds calls made to the API during each step of the test. // calls holds calls made to the API during each step of the test.
calls [][]server.Call calls [][]server.Call
callsLock sync.RWMutex
// errors holds test-related errors encountered while running test steps. // errors holds test-related errors encountered while running test steps.
errors [][]error errors [][]error
errorsLock sync.RWMutex
} }
type imapClient struct { type imapClient struct {
@ -86,7 +88,7 @@ type smtpClient struct {
func newTestCtx(tb testing.TB) *testCtx { func newTestCtx(tb testing.TB) *testCtx {
dir := tb.TempDir() dir := tb.TempDir()
ctx := &testCtx{ t := &testCtx{
dir: dir, dir: dir,
api: newFakeAPI(), api: newFakeAPI(),
netCtl: liteapi.NewNetCtl(), netCtl: liteapi.NewNetCtl(),
@ -105,14 +107,23 @@ func newTestCtx(tb testing.TB) *testCtx {
smtpClients: make(map[string]*smtpClient), smtpClients: make(map[string]*smtpClient),
} }
ctx.api.AddCallWatcher(func(call server.Call) { t.api.AddCallWatcher(func(call server.Call) {
ctx.calls[len(ctx.calls)-1] = append(ctx.calls[len(ctx.calls)-1], call) t.callsLock.Lock()
defer t.callsLock.Unlock()
t.calls[len(t.calls)-1] = append(t.calls[len(t.calls)-1], call)
}) })
return ctx return t
} }
func (t *testCtx) beforeStep() { func (t *testCtx) beforeStep() {
t.callsLock.Lock()
defer t.callsLock.Unlock()
t.errorsLock.Lock()
defer t.errorsLock.Unlock()
t.calls = append(t.calls, nil) t.calls = append(t.calls, nil)
t.errors = append(t.errors, nil) t.errors = append(t.errors, nil)
} }
@ -205,38 +216,35 @@ func (t *testCtx) getMBoxID(userID string, name string) string {
return labelID return labelID
} }
func (t *testCtx) getLastCall(method, path string) (server.Call, error) { func (t *testCtx) getLastCall(method, pathExp string) (server.Call, error) {
var allCalls []server.Call t.callsLock.RLock()
defer t.callsLock.RUnlock()
for _, calls := range t.calls { if matches := xslices.Filter(xslices.Join(t.calls...), func(call server.Call) bool {
allCalls = append(allCalls, calls...) return call.Method == method && regexp.MustCompile("^"+pathExp+"$").MatchString(call.URL.Path)
}); len(matches) > 0 {
return matches[len(matches)-1], nil
} }
if len(allCalls) == 0 { return server.Call{}, fmt.Errorf("no call with method %q and path %q was made", method, pathExp)
return server.Call{}, fmt.Errorf("no calls made")
}
for idx := len(allCalls) - 1; idx >= 0; idx-- {
if call := allCalls[idx]; call.Method == method && regexp.MustCompile("^"+path+"$").MatchString(call.URL.Path) {
return call, nil
}
}
return server.Call{}, fmt.Errorf("no call with method %q and path %q was made", method, path)
} }
func (t *testCtx) pushError(err error) { func (t *testCtx) pushError(err error) {
t.errorsLock.Lock()
defer t.errorsLock.Unlock()
t.errors[len(t.errors)-1] = append(t.errors[len(t.errors)-1], err) t.errors[len(t.errors)-1] = append(t.errors[len(t.errors)-1], err)
} }
func (t *testCtx) getLastError() error { func (t *testCtx) getLastError() error {
errors := t.errors[len(t.errors)-2] t.errorsLock.RLock()
defer t.errorsLock.RUnlock()
if len(errors) == 0 { if lastStep := t.errors[len(t.errors)-2]; len(lastStep) > 0 {
return nil return lastStep[len(lastStep)-1]
} }
return errors[len(errors)-1] return nil
} }
func (t *testCtx) close(ctx context.Context) error { func (t *testCtx) close(ctx context.Context) error {