diff --git a/internal/cookies/jar.go b/internal/cookies/jar.go index cd17eee6..d0acb4c4 100644 --- a/internal/cookies/jar.go +++ b/internal/cookies/jar.go @@ -40,10 +40,14 @@ type GetterSetter interface { Set(string, string) } -func NewCookieJar(getterSetter GetterSetter) (*Jar, error) { - pantry := &pantry{gs: getterSetter} +func NewCookieJar(gs GetterSetter) (*Jar, error) { + pantry := &pantry{gs: gs} - cookies, err := pantry.loadCookies() + if err := pantry.discardExpiredCookies(); err != nil { + return nil, err + } + + cookies, err := pantry.loadFromJSON() if err != nil { return nil, err } @@ -75,7 +79,7 @@ func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { j.jar.SetCookies(u, cookies) - if err := j.pantry.persistCookies(u.String(), cookies); err != nil { + if err := j.pantry.persistCookies(u.Scheme+"://"+u.Host, cookies); err != nil { logrus.WithError(err).Warn("Failed to persist cookie") } } diff --git a/internal/cookies/jar_test.go b/internal/cookies/jar_test.go index c26591f8..ed7375ba 100644 --- a/internal/cookies/jar_test.go +++ b/internal/cookies/jar_test.go @@ -21,6 +21,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,9 +29,9 @@ import ( func TestJarGetSet(t *testing.T) { ts := getTestServer(t, []testCookie{ - {"TestName1", "TestValue1"}, - {"TestName2", "TestValue2"}, - {"TestName3", "TestValue3"}, + {"TestName1", "TestValue1", 3600}, + {"TestName2", "TestValue2", 3600}, + {"TestName3", "TestValue3", 3600}, }) defer ts.Close() @@ -53,9 +54,9 @@ func TestJarGetSet(t *testing.T) { func TestJarLoad(t *testing.T) { ts := getTestServer(t, []testCookie{ - {"TestName1", "TestValue1"}, - {"TestName2", "TestValue2"}, - {"TestName3", "TestValue3"}, + {"TestName1", "TestValue1", 3600}, + {"TestName2", "TestValue2", 3600}, + {"TestName3", "TestValue3", 3600}, }) defer ts.Close() @@ -83,8 +84,41 @@ func TestJarLoad(t *testing.T) { require.NoError(t, getRes.Body.Close()) } +func TestJarExpiry(t *testing.T) { + ts := getTestServer(t, []testCookie{ + {"TestName1", "TestValue1", 3600}, + {"TestName2", "TestValue2", 1}, + {"TestName3", "TestValue3", 3600}, + }) + defer ts.Close() + + // This will be our "persistent storage" from which the cookie jar should load cookies. + gs := make(testGetterSetter) + + // This client saves cookies to persistent storage. + oldClient := getClientWithJar(t, gs) + + // Hit a server that sets some cookies. + setRes, err := oldClient.Get(ts.URL + "/set") + if err != nil { + t.FailNow() + } + require.NoError(t, setRes.Body.Close()) + + // Wait until the second cookie expires. + time.Sleep(2 * time.Second) + + // Load a client, which will clear out expired cookies. + _ = getClientWithJar(t, gs) + + assert.Contains(t, gs["cookies"], "TestName1") + assert.NotContains(t, gs["cookies"], "TestName2") + assert.Contains(t, gs["cookies"], "TestName3") +} + type testCookie struct { name, value string + maxAge int } func getClientWithJar(t *testing.T, gs GetterSetter) *http.Client { @@ -100,8 +134,9 @@ func getTestServer(t *testing.T, wantCookies []testCookie) *httptest.Server { mux.HandleFunc("/set", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for _, cookie := range wantCookies { http.SetCookie(w, &http.Cookie{ - Name: cookie.name, - Value: cookie.value, + Name: cookie.name, + Value: cookie.value, + MaxAge: cookie.maxAge, }) } diff --git a/internal/cookies/pantry.go b/internal/cookies/pantry.go index e7215b5c..fbd1b1fa 100644 --- a/internal/cookies/pantry.go +++ b/internal/cookies/pantry.go @@ -30,69 +30,56 @@ type pantry struct { gs GetterSetter } -func (p *pantry) persistCookies(url string, cookies []*http.Cookie) error { +func (p *pantry) persistCookies(host string, cookies []*http.Cookie) error { for _, cookie := range cookies { if cookie.MaxAge > 0 { cookie.Expires = time.Now().Add(time.Duration(cookie.MaxAge) * time.Second) } } - b, err := json.Marshal(cookies) + cookiesByHost, err := p.loadFromJSON() if err != nil { return err } - val, err := p.loadFromJSON() + cookiesByHost[host] = cookies + + return p.saveToJSON(cookiesByHost) +} + +func (p *pantry) discardExpiredCookies() error { + cookiesByHost, err := p.loadFromJSON() if err != nil { return err } - val[url] = string(b) - - return p.saveToJSON(val) -} - -func (p *pantry) loadCookies() (map[string][]*http.Cookie, error) { - res := make(map[string][]*http.Cookie) - - val, err := p.loadFromJSON() - if err != nil { - return nil, err + for host, cookies := range cookiesByHost { + cookiesByHost[host] = discardExpiredCookies(cookies) } - for url, rawCookies := range val { - var cookies []*http.Cookie - - if err := json.Unmarshal([]byte(rawCookies), &cookies); err != nil { - return nil, err - } - - res[url] = cookies - } - - return res, nil + return p.saveToJSON(cookiesByHost) } -type dataStructure map[string]string +type cookiesByHost map[string][]*http.Cookie -func (p *pantry) loadFromJSON() (dataStructure, error) { +func (p *pantry) loadFromJSON() (cookiesByHost, error) { b := p.gs.Get(preferences.CookiesKey) if b == "" { - return make(dataStructure), nil + return make(cookiesByHost), nil } - var val dataStructure + var cookies cookiesByHost - if err := json.Unmarshal([]byte(b), &val); err != nil { + if err := json.Unmarshal([]byte(b), &cookies); err != nil { return nil, err } - return val, nil + return cookies, nil } -func (p *pantry) saveToJSON(val dataStructure) error { - b, err := json.Marshal(val) +func (p *pantry) saveToJSON(cookies cookiesByHost) error { + b, err := json.Marshal(cookies) if err != nil { return err } @@ -101,3 +88,13 @@ func (p *pantry) saveToJSON(val dataStructure) error { return nil } + +func discardExpiredCookies(cookies []*http.Cookie) (validCookies []*http.Cookie) { + for _, cookie := range cookies { + if cookie.Expires.After(time.Now()) { + validCookies = append(validCookies, cookie) + } + } + + return +}