From 0de30afba1cfaad57d871cb86c77fd1883f663b7 Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Tue, 11 Oct 2022 18:18:08 +0200 Subject: [PATCH] Other: Better cookies test --- internal/bridge/bridge_test.go | 26 ++++++++++--------- internal/safe/set.go | 46 ++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 internal/safe/set.go diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index fc534d72..77fedc8b 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -15,6 +15,7 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" "github.com/ProtonMail/proton-bridge/v2/internal/locations" + "github.com/ProtonMail/proton-bridge/v2/internal/safe" "github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" @@ -134,31 +135,32 @@ func TestBridge_UserAgent(t *testing.T) { func TestBridge_Cookies(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { - var calls []server.Call + sessionIDs := safe.NewSet[string]() + // Save any session IDs the API returns. s.AddCallWatcher(func(call server.Call) { - calls = append(calls, call) - }) + cookie, err := (&http.Request{Header: call.Header}).Cookie("Session-Id") + if err != nil { + return + } - var sessionID string + sessionIDs.Insert(cookie.Value) + }) // Start bridge and add a user so that API assigns us a session ID via cookie. withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { _, err := bridge.LoginUser(context.Background(), username, password, nil, nil) require.NoError(t, err) - - cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id") - require.NoError(t, err) - - sessionID = cookie.Value }) // Start bridge again and check that it uses the same session ID. withBridge(t, ctx, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - cookie, err := (&http.Request{Header: calls[len(calls)-1].Header}).Cookie("Session-Id") - require.NoError(t, err) + // ... + }) - require.Equal(t, sessionID, cookie.Value) + // We should have used just one session ID. + sessionIDs.Values(func(sessionIDs []string) { + require.Len(t, sessionIDs, 1) }) }) } diff --git a/internal/safe/set.go b/internal/safe/set.go new file mode 100644 index 00000000..7d844c95 --- /dev/null +++ b/internal/safe/set.go @@ -0,0 +1,46 @@ +package safe + +import "golang.org/x/exp/maps" + +type Set[Val comparable] Map[Val, struct{}] + +func NewSet[Val comparable](vals ...Val) *Set[Val] { + set := (*Set[Val])(NewMap[Val, struct{}](nil)) + + for _, val := range vals { + set.Insert(val) + } + + return set +} + +func (m *Set[Val]) Has(key Val) bool { + m.lock.RLock() + defer m.lock.RUnlock() + + _, ok := m.data[key] + return ok +} + +func (m *Set[Val]) Insert(key Val) { + m.lock.Lock() + defer m.lock.Unlock() + + m.data[key] = struct{}{} +} + +func (m *Set[Val]) Iter(fn func(key Val)) { + m.lock.RLock() + defer m.lock.RUnlock() + + for key := range m.data { + fn(key) + } +} + +func (m *Set[Val]) Values(fn func(vals []Val)) { + m.lock.RLock() + defer m.lock.RUnlock() + + fn(maps.Keys(m.data)) +}