diff --git a/pkg/confirmer/confirmer_test.go b/pkg/confirmer/confirmer_test.go index 7e1e87f9..60b64fc9 100644 --- a/pkg/confirmer/confirmer_test.go +++ b/pkg/confirmer/confirmer_test.go @@ -66,7 +66,7 @@ func TestConfirmerTimeout(t *testing.T) { assert.Error(t, err) } -func TestConfirmerMultipleRequestCalls(t *testing.T) { +func TestConfirmerMultipleResultCalls(t *testing.T) { c := New() req := c.NewRequest(1 * time.Second) @@ -83,6 +83,25 @@ func TestConfirmerMultipleRequestCalls(t *testing.T) { assert.Error(t, errAgain) } +func TestConfirmerMultipleSimultaneousResultCalls(t *testing.T) { + c := New() + + req := c.NewRequest(1 * time.Second) + + go func() { + time.Sleep(1 * time.Second) + assert.NoError(t, c.SetResult(req.ID(), true)) + }() + + // We just check that nothing panics. We can't know which Result() will get the result though. + + go func() { _, _ = req.Result() }() + go func() { _, _ = req.Result() }() + go func() { _, _ = req.Result() }() + + _, _ = req.Result() +} + func TestConfirmerMultipleSetResultCalls(t *testing.T) { c := New() diff --git a/pkg/confirmer/request.go b/pkg/confirmer/request.go index 26217842..b7591fad 100644 --- a/pkg/confirmer/request.go +++ b/pkg/confirmer/request.go @@ -51,11 +51,14 @@ func (r *Request) ID() string { // Result returns the result or an error if it is not available within the request timeout. func (r *Request) Result() (bool, error) { - if r.hasExpired() { + r.locker.Lock() + defer r.locker.Unlock() + + if r.expired { return false, errors.New("this result has expired") } - defer r.done() + defer func() { r.expired = true }() select { case res := <-r.ch: @@ -65,17 +68,3 @@ func (r *Request) Result() (bool, error) { return false, errors.New("timed out waiting for result") } } - -func (r *Request) hasExpired() bool { - r.locker.Lock() - defer r.locker.Unlock() - - return r.expired -} - -func (r *Request) done() { - r.locker.Lock() - defer r.locker.Unlock() - - r.expired = true -}