diff --git a/go.mod b/go.mod index 893846e0..6abe7051 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/Masterminds/semver/v3 v3.1.1 github.com/ProtonMail/gluon v0.14.2-0.20230118120413-542c2bf244a0 github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a - github.com/ProtonMail/go-proton-api v0.3.0 + github.com/ProtonMail/go-proton-api v0.3.1-0.20230118091111-93ad9245e8ee github.com/ProtonMail/go-rfc5322 v0.11.0 github.com/ProtonMail/gopenpgp/v2 v2.4.10 github.com/PuerkitoBio/goquery v1.8.0 diff --git a/go.sum b/go.sum index aa67805f..d1de9bfd 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f h1:4IWzKjHzZxdr github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f/go.mod h1:qRZgbeASl2a9OwmsV85aWwRqic0NHPh+9ewGAzb4cgM= github.com/ProtonMail/go-proton-api v0.3.0 h1:0lRWSp4bGSwWcpVFWMk++z11ZzRHhXAC9k5L6BQ4KQA= github.com/ProtonMail/go-proton-api v0.3.0/go.mod h1:JUo5IQG0hNuPRuDpOUsCOvtee6UjTEHHF1QN2i8RSos= +github.com/ProtonMail/go-proton-api v0.3.1-0.20230118091111-93ad9245e8ee h1:kXz09BKBbVLyzXgHztxIAMuvmSF0g8FgGpDaqi2IPiM= +github.com/ProtonMail/go-proton-api v0.3.1-0.20230118091111-93ad9245e8ee/go.mod h1:JUo5IQG0hNuPRuDpOUsCOvtee6UjTEHHF1QN2i8RSos= github.com/ProtonMail/go-rfc5322 v0.11.0 h1:o5Obrm4DpmQEffvgsVqG6S4BKwC1Wat+hYwjIp2YcCY= github.com/ProtonMail/go-rfc5322 v0.11.0/go.mod h1:6oOKr0jXvpoE6pwTx/HukigQpX2J9WUf6h0auplrFTw= github.com/ProtonMail/go-srp v0.0.5 h1:xhUioxZgDbCnpo9JehyFhwwsn9JLWkUGfB0oiKXgiGg= diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 1958957e..5a099279 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -586,20 +586,25 @@ func withEnv(t *testing.T, tests func(context.Context, *server.Server, *proton.N tests(ctx, server, netCtl, locations, vaultKey) } +// withMocks creates the mock objects used in the tests. +func withMocks(t *testing.T, tests func(*bridge.Mocks)) { + mocks := bridge.NewMocks(t, v2_3_0, v2_3_0) + defer mocks.Close() + + tests(mocks) +} + // withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. -func withBridge( +func withBridgeNoMocks( ctx context.Context, t *testing.T, + mocks *bridge.Mocks, apiURL string, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte, - tests func(*bridge.Bridge, *bridge.Mocks), + tests func(*bridge.Bridge), ) { - // Create the mock objects used in the tests. - mocks := bridge.NewMocks(t, v2_3_0, v2_3_0) - defer mocks.Close() - // Bridge will disable the proxy by default at startup. mocks.ProxyCtl.EXPECT().DisallowProxy() @@ -654,7 +659,24 @@ func withBridge( defer bridge.Close(ctx) // Use the bridge. - tests(bridge, mocks) + tests(bridge) +} + +// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done. +func withBridge( + ctx context.Context, + t *testing.T, + apiURL string, + netCtl *proton.NetCtl, + locator bridge.Locator, + vaultKey []byte, + tests func(*bridge.Bridge, *bridge.Mocks), +) { + withMocks(t, func(mocks *bridge.Mocks) { + withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) { + tests(bridge, mocks) + }) + }) } func waitForEvent[T any](t *testing.T, eventCh <-chan events.Event, wantEvent T) { diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index 4c70f3a9..7fb317d3 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -20,6 +20,7 @@ package bridge_test import ( "context" "fmt" + "net/http" "testing" "time" @@ -29,6 +30,7 @@ import ( mocksPkg "github.com/ProtonMail/proton-bridge/v3/internal/bridge/mocks" "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/vault" + "github.com/bradenaw/juniper/xslices" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) @@ -662,6 +664,63 @@ func TestBridge_User_Refresh(t *testing.T) { }) } +func TestBridge_User_BadEvents(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + // Create a user. + userID, addrID, err := s.CreateUser("user", password) + require.NoError(t, err) + + labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder) + require.NoError(t, err) + + // Create 10 messages for the user. + withClient(ctx, t, s, "user", password, func(ctx context.Context, c *proton.Client) { + createNumMessages(ctx, t, c, addrID, labelID, 10) + }) + + // The initial user should be fully synced. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { + syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() + + userID, err := bridge.LoginFull(ctx, "user", password, nil, nil) + require.NoError(t, err) + + require.Equal(t, userID, (<-syncCh).UserID) + }) + + var messageIDs []string + + // Create 10 more messages for the user, generating events. + withClient(ctx, t, s, "user", password, func(ctx context.Context, c *proton.Client) { + messageIDs = createNumMessages(ctx, t, c, addrID, labelID, 10) + }) + + // If bridge attempts to sync the new messages, it should get a BadRequest error. + s.AddStatusHook(func(req *http.Request) (int, bool) { + if xslices.Index(xslices.Map(messageIDs, func(messageID string) string { + return "/mail/v4/messages/" + messageID + }), req.URL.Path) < 0 { + return 0, false + } + + return http.StatusBadRequest, true + }) + + // The user will continue to process events and will receive bad request errors. + withMocks(t, func(mocks *bridge.Mocks) { + mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).MinTimes(1) + + // The user will eventually be logged out due to the bad request errors. + withBridgeNoMocks(ctx, t, mocks, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge) { + require.Eventually(t, func() bool { + return len(bridge.GetUserIDs()) == 1 && len(getConnectedUserIDs(t, bridge)) == 0 + }, 10*time.Second, 100*time.Millisecond) + }) + }) + }) +} + // getErr returns the error that was passed to it. func getErr[T any](val T, err error) error { return err