feat: add expiry

This commit is contained in:
James Houlahan
2020-07-23 14:55:55 +02:00
parent 369c6ebf85
commit 5ad307868e
4 changed files with 48 additions and 6 deletions

View File

@ -105,7 +105,7 @@ func (sb *smtpBackend) shouldReportOutgoingNoEnc() bool {
} }
func (sb *smtpBackend) ConfirmNoEncryption(messageID string, shouldSend bool) { func (sb *smtpBackend) ConfirmNoEncryption(messageID string, shouldSend bool) {
if err := sb.confirmer.SetResponse(messageID, shouldSend); err != nil { if err := sb.confirmer.SetResult(messageID, shouldSend); err != nil {
logrus.WithError(err).Error("Failed to set confirmation value") logrus.WithError(err).Error("Failed to set confirmation value")
} }
} }

View File

@ -61,9 +61,9 @@ func (c *Confirmer) SetResult(id string, value bool) error {
return errors.New("no such request") return errors.New("no such request")
} }
req.value <- value req.ch <- value
close(req.value) close(req.ch)
delete(c.requests, id) delete(c.requests, id)
return nil return nil

View File

@ -65,3 +65,20 @@ func TestConfirmerTimeout(t *testing.T) {
_, err := req.Result() _, err := req.Result()
assert.Error(t, err) assert.Error(t, err)
} }
func TestConfirmerMultipleRequestCalls(t *testing.T) {
c := New()
req := c.NewRequest(1 * time.Second)
go func() {
assert.NoError(t, c.SetResult(req.ID(), true))
}()
res, err := req.Result()
assert.NoError(t, err)
assert.True(t, res)
_, errAgain := req.Result()
assert.Error(t, errAgain)
}

View File

@ -19,6 +19,7 @@ package confirmer
import ( import (
"errors" "errors"
"sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -27,15 +28,19 @@ import (
// Request provides a result when it becomes available. // Request provides a result when it becomes available.
type Request struct { type Request struct {
uuid string uuid string
value chan bool ch chan bool
timeout time.Duration timeout time.Duration
expired bool
locker sync.Locker
} }
func newRequest(timeout time.Duration) *Request { func newRequest(timeout time.Duration) *Request {
return &Request{ return &Request{
uuid: uuid.New().String(), uuid: uuid.New().String(),
value: make(chan bool), ch: make(chan bool),
timeout: timeout, timeout: timeout,
locker: &sync.Mutex{},
} }
} }
@ -46,11 +51,31 @@ func (r *Request) ID() string {
// Result returns the result or an error if it is not available within the request timeout. // Result returns the result or an error if it is not available within the request timeout.
func (r *Request) Result() (bool, error) { func (r *Request) Result() (bool, error) {
if r.hasExpired() {
return false, errors.New("this result has expired")
}
defer r.done()
select { select {
case res := <-r.value: case res := <-r.ch:
return res, nil return res, nil
case <-time.After(r.timeout): case <-time.After(r.timeout):
return false, errors.New("timed out waiting for result") return false, errors.New("timed out waiting for result")
} }
} }
func (r *Request) hasExpired() bool {
r.locker.Lock()
defer r.locker.Unlock()
return r.expired
}
func (r *Request) done() {
r.locker.Lock()
defer r.locker.Unlock()
r.expired = true
}