diff --git a/internal/app/base/base.go b/internal/app/base/base.go index 3d91191e..8357d5c7 100644 --- a/internal/app/base/base.go +++ b/internal/app/base/base.go @@ -68,6 +68,7 @@ type Base struct { Listener listener.Listener Creds *credentials.Store CM *pmapi.ClientManager + CookieJar *cookies.Jar Updater *updater.Updater Versioner *versioner.Versioner TLS *tls.TLS @@ -209,6 +210,7 @@ func New( // nolint[funlen] Listener: listener, Creds: creds, CM: cm, + CookieJar: jar, Updater: updater, Versioner: versioner, TLS: tls, diff --git a/internal/app/bridge/bridge.go b/internal/app/bridge/bridge.go index 00145724..ec4e9d06 100644 --- a/internal/app/bridge/bridge.go +++ b/internal/app/bridge/bridge.go @@ -111,6 +111,9 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen] // We want to remove old versions if the app exits successfully. b.AddTeardownAction(b.Versioner.RemoveOldVersions) + // We want cookies to be saved to disk so they are loaded the next time. + b.AddTeardownAction(b.CookieJar.PersistCookies) + f := frontend.New( constants.Version, constants.BuildVersion, diff --git a/internal/app/ie/ie.go b/internal/app/ie/ie.go index 7f13e576..67e70348 100644 --- a/internal/app/ie/ie.go +++ b/internal/app/ie/ie.go @@ -57,6 +57,9 @@ func run(b *base.Base, c *cli.Context) error { // We want to remove old versions if the app exits successfully. b.AddTeardownAction(b.Versioner.RemoveOldVersions) + // We want cookies to be saved to disk so they are loaded the next time. + b.AddTeardownAction(b.CookieJar.PersistCookies) + f := frontend.NewImportExport( constants.Version, constants.BuildVersion, diff --git a/internal/cookies/jar.go b/internal/cookies/jar.go index 0e14d5ac..2633132d 100644 --- a/internal/cookies/jar.go +++ b/internal/cookies/jar.go @@ -19,46 +19,41 @@ package cookies import ( + "encoding/json" + "fmt" "net/http" "net/http/cookiejar" "net/url" "sync" + "time" - "github.com/sirupsen/logrus" + "github.com/ProtonMail/proton-bridge/internal/config/settings" ) +type cookiesByHost map[string][]*http.Cookie + // Jar implements http.CookieJar by wrapping the standard library's cookiejar.Jar. // The jar uses a pantry to load cookies at startup and save cookies when set. type Jar struct { - jar *cookiejar.Jar - pantry *pantry - locker sync.Locker + jar *cookiejar.Jar + settings *settings.Settings + cookies cookiesByHost + locker sync.Locker } -type GetterSetter interface { - Get(string) string - Set(string, string) -} - -func NewCookieJar(gs GetterSetter) (*Jar, error) { - pantry := &pantry{gs: gs} - - if err := pantry.discardExpiredCookies(); err != nil { - return nil, err - } - - cookies, err := pantry.loadFromJSON() - if err != nil { - return nil, err - } - +func NewCookieJar(s *settings.Settings) (*Jar, error) { jar, err := cookiejar.New(nil) if err != nil { return nil, err } - for rawURL, cookies := range cookies { - url, err := url.Parse(rawURL) + cookiesByHost, err := loadCookies(s) + if err != nil { + return nil, err + } + + for host, cookies := range cookiesByHost { + url, err := url.Parse(host) if err != nil { continue } @@ -67,9 +62,10 @@ func NewCookieJar(gs GetterSetter) (*Jar, error) { } return &Jar{ - jar: jar, - pantry: pantry, - locker: &sync.Mutex{}, + jar: jar, + settings: s, + cookies: cookiesByHost, + locker: &sync.Mutex{}, }, nil } @@ -79,9 +75,13 @@ func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { j.jar.SetCookies(u, cookies) - if err := j.pantry.persistCookies(u.Scheme+"://"+u.Host, cookies); err != nil { - logrus.WithError(err).Warn("Failed to persist cookie") + for _, cookie := range cookies { + if cookie.MaxAge > 0 { + cookie.Expires = time.Now().Add(time.Duration(cookie.MaxAge) * time.Second) + } } + + j.cookies[fmt.Sprintf("%v://%v", u.Scheme, u.Host)] = cookies } func (j *Jar) Cookies(u *url.URL) []*http.Cookie { @@ -90,3 +90,54 @@ func (j *Jar) Cookies(u *url.URL) []*http.Cookie { return j.jar.Cookies(u) } + +// PersistCookies persists the cookies to disk. +func (j *Jar) PersistCookies() error { + j.locker.Lock() + defer j.locker.Unlock() + + rawCookies, err := json.Marshal(j.cookies) + if err != nil { + return err + } + + j.settings.Set(settings.CookiesKey, string(rawCookies)) + + return nil +} + +// loadCookies loads all non-expired cookies from disk. +func loadCookies(s *settings.Settings) (cookiesByHost, error) { + rawCookies := s.Get(settings.CookiesKey) + + if rawCookies == "" { + return make(cookiesByHost), nil + } + + var cookiesByHost cookiesByHost + + if err := json.Unmarshal([]byte(rawCookies), &cookiesByHost); err != nil { + return nil, err + } + + for host, cookies := range cookiesByHost { + if validCookies := discardExpiredCookies(cookies); len(validCookies) > 0 { + cookiesByHost[host] = validCookies + } + } + + return cookiesByHost, nil +} + +// discardExpiredCookies returns all the given cookies which aren't expired. +func discardExpiredCookies(cookies []*http.Cookie) []*http.Cookie { + var validCookies []*http.Cookie + + for _, cookie := range cookies { + if cookie.Expires.After(time.Now()) { + validCookies = append(validCookies, cookie) + } + } + + return validCookies +} diff --git a/internal/cookies/jar_test.go b/internal/cookies/jar_test.go index 00ebb808..79aeca7b 100644 --- a/internal/cookies/jar_test.go +++ b/internal/cookies/jar_test.go @@ -18,11 +18,13 @@ package cookies import ( + "io/ioutil" "net/http" "net/http/httptest" "testing" "time" + "github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -35,7 +37,7 @@ func TestJarGetSet(t *testing.T) { }) defer ts.Close() - client := getClientWithJar(t, make(testGetterSetter)) + client, _ := getClientWithJar(t, newFakeSettings()) // Hit a server that sets some cookies. setRes, err := client.Get(ts.URL + "/set") @@ -61,10 +63,10 @@ func TestJarLoad(t *testing.T) { defer ts.Close() // This will be our "persistent storage" from which the cookie jar should load cookies. - gs := make(testGetterSetter) + s := newFakeSettings() // This client saves cookies to persistent storage. - oldClient := getClientWithJar(t, gs) + oldClient, jar := getClientWithJar(t, s) // Hit a server that sets some cookies. setRes, err := oldClient.Get(ts.URL + "/set") @@ -73,8 +75,11 @@ func TestJarLoad(t *testing.T) { } require.NoError(t, setRes.Body.Close()) + // Save the cookies. + require.NoError(t, jar.PersistCookies()) + // This client loads cookies from persistent storage. - newClient := getClientWithJar(t, gs) + newClient, _ := getClientWithJar(t, s) // Hit a server that checks the cookies are there. getRes, err := newClient.Get(ts.URL + "/get") @@ -93,10 +98,10 @@ func TestJarExpiry(t *testing.T) { defer ts.Close() // This will be our "persistent storage" from which the cookie jar should load cookies. - gs := make(testGetterSetter) + s := newFakeSettings() // This client saves cookies to persistent storage. - oldClient := getClientWithJar(t, gs) + oldClient, jar1 := getClientWithJar(t, s) // Hit a server that sets some cookies. setRes, err := oldClient.Get(ts.URL + "/set") @@ -105,15 +110,21 @@ func TestJarExpiry(t *testing.T) { } require.NoError(t, setRes.Body.Close()) + // Save the cookies. + require.NoError(t, jar1.PersistCookies()) + // Wait until the second cookie expires. time.Sleep(2 * time.Second) // Load a client, which will clear out expired cookies. - _ = getClientWithJar(t, gs) + _, jar2 := getClientWithJar(t, s) - assert.Contains(t, gs["cookies"], "TestName1") - assert.NotContains(t, gs["cookies"], "TestName2") - assert.Contains(t, gs["cookies"], "TestName3") + // Save the cookies (expired ones were cleared out). + require.NoError(t, jar2.PersistCookies()) + + assert.Contains(t, s.Get(settings.CookiesKey), "TestName1") + assert.NotContains(t, s.Get(settings.CookiesKey), "TestName2") + assert.Contains(t, s.Get(settings.CookiesKey), "TestName3") } type testCookie struct { @@ -121,11 +132,11 @@ type testCookie struct { maxAge int } -func getClientWithJar(t *testing.T, gs GetterSetter) *http.Client { - jar, err := NewCookieJar(gs) +func getClientWithJar(t *testing.T, s *settings.Settings) (*http.Client, *Jar) { + jar, err := NewCookieJar(s) require.NoError(t, err) - return &http.Client{Jar: jar} + return &http.Client{Jar: jar}, jar } func getTestServer(t *testing.T, wantCookies []testCookie) *httptest.Server { @@ -157,12 +168,12 @@ func getTestServer(t *testing.T, wantCookies []testCookie) *httptest.Server { return httptest.NewServer(mux) } -type testGetterSetter map[string]string +// newFakeSettings creates a temporary folder for files. +func newFakeSettings() *settings.Settings { + dir, err := ioutil.TempDir("", "test-settings") + if err != nil { + panic(err) + } -func (p testGetterSetter) Set(key, value string) { - p[key] = value -} - -func (p testGetterSetter) Get(key string) string { - return p[key] + return settings.New(dir) } diff --git a/internal/cookies/pantry.go b/internal/cookies/pantry.go deleted file mode 100644 index 9ede968e..00000000 --- a/internal/cookies/pantry.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package cookies - -import ( - "encoding/json" - "net/http" - "time" - - "github.com/ProtonMail/proton-bridge/internal/config/settings" -) - -// pantry persists and loads cookies to some persistent storage location. -type pantry struct { - gs GetterSetter -} - -func (p *pantry) persistCookies(host string, cookies []*http.Cookie) error { - for _, cookie := range cookies { - if cookie.MaxAge > 0 { - cookie.Expires = time.Now().Add(time.Duration(cookie.MaxAge) * time.Second) - } - } - - cookiesByHost, err := p.loadFromJSON() - if err != nil { - return err - } - - cookiesByHost[host] = cookies - - return p.saveToJSON(cookiesByHost) -} - -func (p *pantry) discardExpiredCookies() error { - cookiesByHost, err := p.loadFromJSON() - if err != nil { - return err - } - - for host, cookies := range cookiesByHost { - cookiesByHost[host] = discardExpiredCookies(cookies) - } - - return p.saveToJSON(cookiesByHost) -} - -type cookiesByHost map[string][]*http.Cookie - -func (p *pantry) loadFromJSON() (cookiesByHost, error) { - b := p.gs.Get(settings.CookiesKey) - - if b == "" { - return make(cookiesByHost), nil - } - - var cookies cookiesByHost - - if err := json.Unmarshal([]byte(b), &cookies); err != nil { - return nil, err - } - - return cookies, nil -} - -func (p *pantry) saveToJSON(cookies cookiesByHost) error { - b, err := json.Marshal(cookies) - if err != nil { - return err - } - - p.gs.Set(settings.CookiesKey, string(b)) - - return nil -} - -func discardExpiredCookies(cookies []*http.Cookie) (validCookies []*http.Cookie) { - for _, cookie := range cookies { - if cookie.Expires.After(time.Now()) { - validCookies = append(validCookies, cookie) - } - } - - return -} diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index 66513092..d66ac48f 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -331,7 +331,6 @@ func mustMarshal(t *testing.T, v interface{}) []byte { type fakeSettings struct { *settings.Settings - dir string } // newFakeSettings creates a temporary folder for files. @@ -341,10 +340,7 @@ func newFakeSettings(rollout float64, earlyAccess bool) *fakeSettings { panic(err) } - s := &fakeSettings{ - Settings: settings.New(dir), - dir: dir, - } + s := &fakeSettings{Settings: settings.New(dir)} s.SetFloat64(settings.RolloutKey, rollout)