refactor: make cookie architecture less crazy

This commit is contained in:
James Houlahan
2020-08-13 11:27:42 +02:00
parent 9f24c666b9
commit 209af59232
4 changed files with 31 additions and 37 deletions

View File

@ -275,7 +275,7 @@ func run(context *cli.Context) (contextError error) { // nolint[funlen]
cm.SetRoundTripper(cfg.GetRoundTripper(cm, eventListener)) cm.SetRoundTripper(cfg.GetRoundTripper(cm, eventListener))
// Cookies must be persisted across restarts. // Cookies must be persisted across restarts.
jar, err := cookies.NewCookieJar(cookies.NewPersister(pref)) jar, err := cookies.NewCookieJar(pref)
if err != nil { if err != nil {
logrus.WithError(err).Warn("Could not create cookie jar") logrus.WithError(err).Warn("Could not create cookie jar")
} else { } else {

View File

@ -30,18 +30,25 @@ import (
// Jar implements http.CookieJar by wrapping the standard library's cookiejar.Jar. // Jar implements http.CookieJar by wrapping the standard library's cookiejar.Jar.
// The jar uses a Persister to load cookies at startup and save cookies when set. // The jar uses a Persister to load cookies at startup and save cookies when set.
type Jar struct { type Jar struct {
jar *cookiejar.Jar jar *cookiejar.Jar
persister *Persister pantry *pantry
locker sync.Locker locker sync.Locker
} }
func NewCookieJar(persister *Persister) (*Jar, error) { type GetterSetter interface {
jar, err := cookiejar.New(nil) Get(string) string
Set(string, string)
}
func NewCookieJar(getterSetter GetterSetter) (*Jar, error) {
pantry := &pantry{prefs: getterSetter}
cookies, err := pantry.loadCookies()
if err != nil { if err != nil {
return nil, err return nil, err
} }
cookies, err := persister.Load() jar, err := cookiejar.New(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -56,9 +63,9 @@ func NewCookieJar(persister *Persister) (*Jar, error) {
} }
return &Jar{ return &Jar{
jar: jar, jar: jar,
persister: persister, pantry: pantry,
locker: &sync.Mutex{}, locker: &sync.Mutex{},
}, nil }, nil
} }
@ -68,7 +75,7 @@ func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
j.jar.SetCookies(u, cookies) j.jar.SetCookies(u, cookies)
if err := j.persister.Persist(u.String(), cookies); err != nil { if err := j.pantry.persistCookies(u.String(), cookies); err != nil {
logrus.WithError(err).Warn("Failed to persist cookie") logrus.WithError(err).Warn("Failed to persist cookie")
} }
} }

View File

@ -36,7 +36,7 @@ func TestJar(t *testing.T) {
ts := getTestServer(t, testCookies...) ts := getTestServer(t, testCookies...)
defer ts.Close() defer ts.Close()
jar, err := NewCookieJar(NewPersister(make(testPersister))) jar, err := NewCookieJar(make(testGetterSetter))
require.NoError(t, err) require.NoError(t, err)
client := &http.Client{Jar: jar} client := &http.Client{Jar: jar}
@ -86,12 +86,12 @@ func getTestServer(t *testing.T, wantCookies ...testCookie) *httptest.Server {
return httptest.NewServer(mux) return httptest.NewServer(mux)
} }
type testPersister map[string]string type testGetterSetter map[string]string
func (p testPersister) Set(key, value string) { func (p testGetterSetter) Set(key, value string) {
p[key] = value p[key] = value
} }
func (p testPersister) Get(key string) string { func (p testGetterSetter) Get(key string) string {
return p[key] return p[key]
} }

View File

@ -24,39 +24,30 @@ import (
"github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/internal/preferences"
) )
type Persister struct { type pantry struct {
prefs GetterSetter prefs GetterSetter
} }
type GetterSetter interface { func (p *pantry) persistCookies(url string, cookies []*http.Cookie) error {
Get(string) string
Set(string, string)
}
func NewPersister(prefs GetterSetter) *Persister {
return &Persister{prefs: prefs}
}
func (p *Persister) Persist(url string, cookies []*http.Cookie) error {
b, err := json.Marshal(cookies) b, err := json.Marshal(cookies)
if err != nil { if err != nil {
return err return err
} }
val, err := p.load() val, err := p.loadFromJSON()
if err != nil { if err != nil {
return err return err
} }
val[url] = string(b) val[url] = string(b)
return p.save(val) return p.saveToJSON(val)
} }
func (p *Persister) Load() (map[string][]*http.Cookie, error) { func (p *pantry) loadCookies() (map[string][]*http.Cookie, error) {
res := make(map[string][]*http.Cookie) res := make(map[string][]*http.Cookie)
val, err := p.load() val, err := p.loadFromJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -76,15 +67,11 @@ func (p *Persister) Load() (map[string][]*http.Cookie, error) {
type dataStructure map[string]string type dataStructure map[string]string
func (p *Persister) load() (dataStructure, error) { func (p *pantry) loadFromJSON() (dataStructure, error) {
b := p.prefs.Get(preferences.CookiesKey) b := p.prefs.Get(preferences.CookiesKey)
if b == "" { if b == "" {
if err := p.save(make(dataStructure)); err != nil { return make(dataStructure), nil
return nil, err
}
return p.load()
} }
var val dataStructure var val dataStructure
@ -96,7 +83,7 @@ func (p *Persister) load() (dataStructure, error) {
return val, nil return val, nil
} }
func (p *Persister) save(val dataStructure) error { func (p *pantry) saveToJSON(val dataStructure) error {
b, err := json.Marshal(val) b, err := json.Marshal(val)
if err != nil { if err != nil {
return err return err