diff --git a/go.mod b/go.mod index 9b125df8..6956de98 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/Masterminds/semver/v3 v3.1.1 github.com/ProtonMail/gluon v0.14.2-0.20221207071431-0faa318d3c9f github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a - github.com/ProtonMail/go-proton-api v0.2.2-0.20221212093343-0afe67dc1c50 + github.com/ProtonMail/go-proton-api v0.2.2-0.20221213042823-5bfe853434e7 github.com/ProtonMail/go-rfc5322 v0.11.0 github.com/ProtonMail/gopenpgp/v2 v2.4.10 github.com/PuerkitoBio/goquery v1.8.0 diff --git a/go.sum b/go.sum index 58848305..4e0d80d6 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,6 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= -github.com/ProtonMail/gluon v0.14.2-0.20221206104410-725ddb9db68a h1:BwWVZcvvf9Pw353+wZGD3X433kPFT4SjQVnYKD0YBRY= -github.com/ProtonMail/gluon v0.14.2-0.20221206104410-725ddb9db68a/go.mod h1:z2AxLIiBCT1K+0OBHyaDI7AEaO5qI6/BEC2TE42vs4Q= github.com/ProtonMail/gluon v0.14.2-0.20221207071431-0faa318d3c9f h1:73b28jayIkYr1cJPHSHFMGyFgk1h6iJ127kuYX8UaRo= github.com/ProtonMail/gluon v0.14.2-0.20221207071431-0faa318d3c9f/go.mod h1:z2AxLIiBCT1K+0OBHyaDI7AEaO5qI6/BEC2TE42vs4Q= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= @@ -45,10 +43,8 @@ github.com/ProtonMail/go-message v0.0.0-20210611055058-fabeff2ec753/go.mod h1:NB github.com/ProtonMail/go-mime v0.0.0-20220302105931-303f85f7fe0f/go.mod h1:NYt+V3/4rEeDuaev/zw1zCq8uqVEuPHzDPo3OZrlGJ4= github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f h1:4IWzKjHzZxdrW9k4zl/qCwenOVHDbVDADPPHFLjs0Oc= github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f/go.mod h1:qRZgbeASl2a9OwmsV85aWwRqic0NHPh+9ewGAzb4cgM= -github.com/ProtonMail/go-proton-api v0.2.1 h1:M15/zzfx6EPiskv2+gogUkmvx7Y1SmRRtLT6GiBh5T0= -github.com/ProtonMail/go-proton-api v0.2.1/go.mod h1:jqvJ2HqLHqiPJoEb+BTIB1IF7wvr6p+8ZfA6PO2NRNk= -github.com/ProtonMail/go-proton-api v0.2.2-0.20221212093343-0afe67dc1c50 h1:DXcvmx1sx20YsFmP40kCRYD+jzBy6OLjvFQFCbht4ZI= -github.com/ProtonMail/go-proton-api v0.2.2-0.20221212093343-0afe67dc1c50/go.mod h1:jqvJ2HqLHqiPJoEb+BTIB1IF7wvr6p+8ZfA6PO2NRNk= +github.com/ProtonMail/go-proton-api v0.2.2-0.20221213042823-5bfe853434e7 h1:jreVsSvIlslQpDks/OhEL1YyI1XBmhgzYDdoQ9u68UA= +github.com/ProtonMail/go-proton-api v0.2.2-0.20221213042823-5bfe853434e7/go.mod h1:O7ZTIDOhJRkfQgtW8dB0ZSCq8OZsShjMQ3ahzpDheOk= github.com/ProtonMail/go-rfc5322 v0.11.0 h1:o5Obrm4DpmQEffvgsVqG6S4BKwC1Wat+hYwjIp2YcCY= github.com/ProtonMail/go-rfc5322 v0.11.0/go.mod h1:6oOKr0jXvpoE6pwTx/HukigQpX2J9WUf6h0auplrFTw= github.com/ProtonMail/go-srp v0.0.5 h1:xhUioxZgDbCnpo9JehyFhwwsn9JLWkUGfB0oiKXgiGg= diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 2c26fecb..2ae4f9d6 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -59,7 +59,7 @@ var ( func init() { user.EventPeriod = 100 * time.Millisecond user.EventJitter = 0 - backend.GenerateKey = tests.FastGenerateKey + backend.GenerateKey = backend.FastGenerateKey certs.GenerateCert = tests.FastGenerateCert } @@ -384,7 +384,7 @@ func TestBridge_AddressWithoutKeys(t *testing.T) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Create a user which will have an address without keys. - userID, _, err := s.CreateUser("nokeys", "nokeys@pm.me", []byte("password")) + userID, _, err := s.CreateUser("nokeys", []byte("password")) require.NoError(t, err) // Create an additional address for the user; it will not have keys. @@ -501,7 +501,7 @@ func withEnv(t *testing.T, tests func(context.Context, *server.Server, *proton.N defer server.Close() // Add test user. - _, _, err := server.CreateUser(username, username+"@pm.me", password) + _, _, err := server.CreateUser(username, password) require.NoError(t, err) // Generate a random vault key. @@ -566,7 +566,7 @@ func withBridge( cookieJar, useragent.New(), mocks.TLSReporter, - proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + netCtl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}), mocks.ProxyCtl, mocks.CrashHandler, mocks.Reporter, diff --git a/internal/bridge/refresh_test.go b/internal/bridge/refresh_test.go index eec9a7fa..4933065c 100644 --- a/internal/bridge/refresh_test.go +++ b/internal/bridge/refresh_test.go @@ -35,7 +35,7 @@ import ( func TestBridge_Refresh(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { - userID, _, err := s.CreateUser("imap", "imap@pm.me", password) + userID, _, err := s.CreateUser("imap", password) require.NoError(t, err) names := iterator.Collect(iterator.Map(iterator.Counter(10), func(i int) string { @@ -67,7 +67,7 @@ func TestBridge_Refresh(t *testing.T) { client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) defer func() { _ = client.Logout() }() for _, name := range names { @@ -100,7 +100,7 @@ func TestBridge_Refresh(t *testing.T) { client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) defer func() { _ = client.Logout() }() for _, name := range names { diff --git a/internal/bridge/send_test.go b/internal/bridge/send_test.go index e514a34c..7025a58f 100644 --- a/internal/bridge/send_test.go +++ b/internal/bridge/send_test.go @@ -39,7 +39,7 @@ import ( func TestBridge_Send(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { - _, _, err := s.CreateUser("recipient", "recipient@pm.me", password) + _, _, err := s.CreateUser("recipient", password) require.NoError(t, err) withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index 869de9f0..776a8919 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -46,7 +46,7 @@ func TestBridge_Sync(t *testing.T) { numMsg := 1 << 8 withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { - userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password) + userID, addrID, err := s.CreateUser("imap", password) require.NoError(t, err) labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder) @@ -59,7 +59,7 @@ func TestBridge_Sync(t *testing.T) { var total uint64 // The initial user should be fully synced. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) defer done() @@ -73,14 +73,14 @@ func TestBridge_Sync(t *testing.T) { }) // If we then connect an IMAP client, it should see all the messages. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, _ *bridge.Mocks) { info, err := b.GetUserInfo(userID) require.NoError(t, err) require.True(t, info.State == bridge.Connected) client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) defer func() { _ = client.Logout() }() status, err := client.Select(`Folders/folder`, false) @@ -89,7 +89,7 @@ func TestBridge_Sync(t *testing.T) { }) // Now let's remove the user and simulate a network error. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) { require.NoError(t, bridge.DeleteUser(ctx, userID)) }) @@ -97,7 +97,7 @@ func TestBridge_Sync(t *testing.T) { netCtl.SetReadLimit(2 * total / 3) // Login the user; its sync should fail. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, _ *bridge.Mocks) { { syncCh, done := chToType[events.Event, events.SyncFailed](b.GetEvents(events.SyncFailed{})) defer done() @@ -113,7 +113,7 @@ func TestBridge_Sync(t *testing.T) { client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) defer func() { _ = client.Logout() }() status, err := client.Select(`Folders/folder`, false) @@ -136,7 +136,7 @@ func TestBridge_Sync(t *testing.T) { client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) defer func() { _ = client.Logout() }() status, err := client.Select(`Folders/folder`, false) @@ -149,7 +149,7 @@ func TestBridge_Sync(t *testing.T) { func TestBridge_Sync_BadMessage(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { - userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password) + userID, addrID, err := s.CreateUser("imap", password) require.NoError(t, err) labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder) @@ -179,14 +179,14 @@ func TestBridge_Sync_BadMessage(t *testing.T) { }) // If we then connect an IMAP client, it should see the good message but not the bad one. - withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, mocks *bridge.Mocks) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, _ *bridge.Mocks) { info, err := b.GetUserInfo(userID) require.NoError(t, err) require.True(t, info.State == bridge.Connected) client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort())) require.NoError(t, err) - require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass))) + require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass))) defer func() { _ = client.Logout() }() status, err := client.Select(`Folders/folder`, false) diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index b57e73c7..444e55ae 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -592,7 +592,7 @@ func TestBridge_UserInfo_Alias(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { // Create a new user. - userID, _, err := s.CreateUser("primary", "primary@pm.me", []byte("password")) + userID, _, err := s.CreateUser("primary", []byte("password")) require.NoError(t, err) // Give the new user an alias. @@ -606,7 +606,7 @@ func TestBridge_UserInfo_Alias(t *testing.T) { require.NoError(t, err) // The user should have two addresses, the primary should be first. - require.Equal(t, []string{"primary@pm.me", "alias@pm.me"}, info.Addresses) + require.Equal(t, []string{"primary@" + s.GetDomain(), "alias@pm.me"}, info.Addresses) }) }) } diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 66bba5e0..18b05757 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -39,7 +39,7 @@ import ( func init() { EventPeriod = 100 * time.Millisecond EventJitter = 0 - backend.GenerateKey = tests.FastGenerateKey + backend.GenerateKey = backend.FastGenerateKey certs.GenerateCert = tests.FastGenerateCert } @@ -49,7 +49,7 @@ func TestMain(m *testing.M) { func TestUser_Info(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *proton.Manager) { - withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, _ []string) { + withAccount(t, s, "username", "password", []string{"alias@pm.me"}, func(userID string, _ []string) { withUser(t, ctx, s, m, "username", "password", func(user *User) { // User's ID should be correct. require.Equal(t, userID, user.ID()) @@ -58,7 +58,7 @@ func TestUser_Info(t *testing.T) { require.Equal(t, "username", user.Name()) // User's email should be correct. - require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails()) + require.ElementsMatch(t, []string{"username@" + s.GetDomain(), "alias@pm.me"}, user.Emails()) // By default, user should be in combined mode. require.Equal(t, vault.CombinedMode, user.GetAddressMode()) @@ -72,7 +72,7 @@ func TestUser_Info(t *testing.T) { func TestUser_Sync(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *proton.Manager) { - withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(string, []string) { + withAccount(t, s, "username", "password", []string{}, func(string, []string) { withUser(t, ctx, s, m, "username", "password", func(user *User) { // User starts a sync at startup. require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) @@ -89,7 +89,7 @@ func TestUser_Sync(t *testing.T) { func TestUser_AddressMode(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *proton.Manager) { - withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(string, []string) { + withAccount(t, s, "username", "password", []string{}, func(string, []string) { withUser(t, ctx, s, m, "username", "password", func(user *User) { // User finishes syncing at startup. require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) @@ -126,7 +126,7 @@ func TestUser_AddressMode(t *testing.T) { func TestUser_Deauth(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *proton.Manager) { - withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(string, []string) { + withAccount(t, s, "username", "password", []string{}, func(string, []string) { withUser(t, ctx, s, m, "username", "password", func(user *User) { require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) require.IsType(t, events.SyncProgress{}, <-user.GetEventCh()) @@ -147,7 +147,7 @@ func TestUser_Refresh(t *testing.T) { mockReporter := mocks.NewMockReporter(ctl) withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *proton.Manager) { - withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(string, []string) { + withAccount(t, s, "username", "password", []string{}, func(string, []string) { withUser(t, ctx, s, m, "username", "password", func(user *User) { require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) require.IsType(t, events.SyncProgress{}, <-user.GetEventCh()) @@ -180,15 +180,13 @@ func withAPI(_ testing.TB, ctx context.Context, fn func(context.Context, *server )) } -func withAccount(tb testing.TB, s *server.Server, username, password string, emails []string, fn func(string, []string)) { //nolint:unparam - userID, addrID, err := s.CreateUser(username, emails[0], []byte(password)) +func withAccount(tb testing.TB, s *server.Server, username, password string, aliases []string, fn func(string, []string)) { //nolint:unparam + userID, addrID, err := s.CreateUser(username, []byte(password)) require.NoError(tb, err) - addrIDs := make([]string, 0, len(emails)) + addrIDs := []string{addrID} - addrIDs = append(addrIDs, addrID) - - for _, email := range emails[1:] { + for _, email := range aliases { addrID, err := s.CreateAddress(userID, email, []byte(password)) require.NoError(tb, err) diff --git a/tests/api_test.go b/tests/api_test.go index da2d26fd..1920062e 100644 --- a/tests/api_test.go +++ b/tests/api_test.go @@ -27,10 +27,6 @@ type API interface { GetHostURL() string AddCallWatcher(func(server.Call), ...string) - - CreateUser(username, address string, password []byte) (string, string, error) - CreateAddress(userID, address string, password []byte) (string, error) - RemoveAddress(userID, addrID string) error RemoveAddressKey(userID, addrID, keyID string) error Close() diff --git a/tests/bdd_test.go b/tests/bdd_test.go index 25ad2228..e1281352 100644 --- a/tests/bdd_test.go +++ b/tests/bdd_test.go @@ -41,21 +41,26 @@ func (s *scenario) close(_ testing.TB) { } func TestFeatures(testingT *testing.T) { - paths := []string{"features"} - if features := os.Getenv("FEATURES"); features != "" { - paths = strings.Split(features, " ") - } + var s scenario suite := godog.TestSuite{ - ScenarioInitializer: func(ctx *godog.ScenarioContext) { - var s scenario + TestSuiteInitializer: func(ctx *godog.TestSuiteContext) { + ctx.BeforeSuite(func() { + // Global setup. + }) - ctx.Before(func(ctx context.Context, sc *godog.Scenario) (context.Context, error) { + ctx.AfterSuite(func() { + // Global teardown. + }) + }, + + ScenarioInitializer: func(ctx *godog.ScenarioContext) { + ctx.Before(func(ctx context.Context, _ *godog.Scenario) (context.Context, error) { s.reset(testingT) return ctx, nil }) - ctx.After(func(ctx context.Context, sc *godog.Scenario, err error) (context.Context, error) { + ctx.After(func(ctx context.Context, _ *godog.Scenario, _ error) (context.Context, error) { s.close(testingT) return ctx, nil }) @@ -72,7 +77,7 @@ func TestFeatures(testingT *testing.T) { return ctx, nil }) - ctx.StepContext().After(func(ctx context.Context, st *godog.Step, status godog.StepResultStatus, err error) (context.Context, error) { + ctx.StepContext().After(func(ctx context.Context, st *godog.Step, status godog.StepResultStatus, _ error) (context.Context, error) { logrus.Debugf("Finished step (%v): %s", status, st.Text) return ctx, nil }) @@ -197,7 +202,7 @@ func TestFeatures(testingT *testing.T) { }, Options: &godog.Options{ Format: "pretty", - Paths: paths, + Paths: getFeaturePaths(), TestingT: testingT, }, } @@ -206,3 +211,15 @@ func TestFeatures(testingT *testing.T) { testingT.Fatal("non-zero status returned, failed to run feature tests") } } + +func getFeaturePaths() []string { + var paths []string + + if features := os.Getenv("FEATURES"); features != "" { + paths = strings.Split(features, " ") + } else { + paths = []string{"features"} + } + + return paths +} diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index a37ccff5..e68ac08e 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -30,7 +30,6 @@ import ( "time" "github.com/ProtonMail/gluon/queue" - "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/cookies" @@ -157,7 +156,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) { persister, useragent.New(), t.mocks.TLSReporter, - proton.NewDialer(t.netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + t.netCtl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}), t.mocks.ProxyCtl, t.mocks.CrashHandler, t.reporter, diff --git a/tests/ctx_helper_test.go b/tests/ctx_helper_test.go index a7e63307..0e7fc6e4 100644 --- a/tests/ctx_helper_test.go +++ b/tests/ctx_helper_test.go @@ -21,39 +21,54 @@ import ( "context" "fmt" "runtime" - "sync" - "testing" - "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/stream" - "github.com/bradenaw/juniper/xslices" - "github.com/golang/mock/gomock" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" ) -func (t *testCtx) withClient(ctx context.Context, username string, fn func(context.Context, *proton.Client) error) error { - c, _, err := proton.New( +// withProton executes the given function with a proton manager configured to use the test API. +func (t *testCtx) withProton(fn func(*proton.Manager) error) error { + m := proton.New( proton.WithHostURL(t.api.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), - ).NewClientWithLogin(ctx, username, []byte(t.getUserPass(t.getUserID(username)))) - if err != nil { - return err - } + ) + defer m.Close() - defer c.Close() + return fn(m) +} - if err := fn(ctx, c); err != nil { - return fmt.Errorf("failed to execute with client: %w", err) - } +// withClient executes the given function with a client that is logged in as the given (known) user. +func (t *testCtx) withClient(ctx context.Context, username string, fn func(context.Context, *proton.Client) error) error { + return t.withClientPass(ctx, username, t.getUserPass(t.getUserID(username)), fn) +} - if err := c.AuthDelete(ctx); err != nil { - return fmt.Errorf("failed to delete auth: %w", err) - } +// withClient executes the given function with a client that is logged in with the given username and password. +func (t *testCtx) withClientPass(ctx context.Context, username, password string, fn func(context.Context, *proton.Client) error) error { + return t.withProton(func(m *proton.Manager) error { + c, _, err := m.NewClientWithLogin(ctx, username, []byte(password)) + if err != nil { + return fmt.Errorf("failed to create client: %w", err) + } + defer c.Close() - return nil + if err := fn(ctx, c); err != nil { + return fmt.Errorf("failed to execute with client: %w", err) + } + + if err := c.AuthDelete(ctx); err != nil { + return fmt.Errorf("failed to delete auth: %w", err) + } + + return nil + }) +} + +// runQuarkCmd runs the given quark command with the given arguments. +func (t *testCtx) runQuarkCmd(ctx context.Context, command string, args ...string) error { + return t.withProton(func(m *proton.Manager) error { + return m.Quark(ctx, command, args...) + }) } func (t *testCtx) withAddrKR( @@ -107,123 +122,3 @@ func (t *testCtx) createMessages(ctx context.Context, username, addrID string, r }) }) } - -type reportRecord struct { - isException bool - message string - context reporter.Context -} - -type reportRecorder struct { - assert *assert.Assertions - reports []reportRecord - - lock sync.Locker - isClosed bool -} - -func newReportRecorder(tb testing.TB) *reportRecorder { - return &reportRecorder{ - assert: assert.New(tb), - reports: []reportRecord{}, - lock: &sync.Mutex{}, - isClosed: false, - } -} - -func (r *reportRecorder) add(isException bool, message string, context reporter.Context) { - r.lock.Lock() - defer r.lock.Unlock() - - l := logrus.WithFields(logrus.Fields{ - "isException": isException, - "message": message, - "context": context, - "pkg": "test/reportRecorder", - }) - - if r.isClosed { - l.Warn("Reporter closed, report skipped") - return - } - - r.reports = append(r.reports, reportRecord{ - isException: isException, - message: message, - context: context, - }) - - l.Warn("Report recorded") -} - -func (r *reportRecorder) close() { - r.lock.Lock() - defer r.lock.Unlock() - - r.isClosed = true -} - -func (r *reportRecorder) assertEmpty() { - r.assert.Empty(r.reports) -} - -func (r *reportRecorder) removeMatchingRecords(isException, message, context gomock.Matcher, n int) { - if n == 0 { - n = len(r.reports) - } - - r.reports = xslices.Filter(r.reports, func(rec reportRecord) bool { - if n <= 0 { - return true - } - - l := logrus.WithFields(logrus.Fields{ - "rec": rec, - }) - if !isException.Matches(rec.isException) { - l.WithField("matcher", isException).Debug("Not matching") - return true - } - - if !message.Matches(rec.message) { - l.WithField("matcher", message).Debug("Not matching") - return true - } - - if !context.Matches(rec.context) { - l.WithField("matcher", context).Debug("Not matching") - return true - } - - n-- - - return false - }) -} - -func (r *reportRecorder) ReportException(data any) error { - r.add(true, "exception", reporter.Context{"data": data}) - return nil -} - -func (r *reportRecorder) ReportMessage(message string) error { - r.add(false, message, reporter.Context{}) - return nil -} - -func (r *reportRecorder) ReportMessageWithContext(message string, context reporter.Context) error { - r.add(false, message, context) - return nil -} - -func (r *reportRecorder) ReportExceptionWithContext(data any, context reporter.Context) error { - if context == nil { - context = reporter.Context{} - } - - context["data"] = data - - r.add(true, "exception", context) - - return nil -} diff --git a/tests/ctx_reporter_test.go b/tests/ctx_reporter_test.go new file mode 100644 index 00000000..7a248489 --- /dev/null +++ b/tests/ctx_reporter_test.go @@ -0,0 +1,132 @@ +package tests + +import ( + "sync" + "testing" + + "github.com/ProtonMail/gluon/reporter" + "github.com/bradenaw/juniper/xslices" + "github.com/golang/mock/gomock" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +type reportRecord struct { + isException bool + message string + context reporter.Context +} + +type reportRecorder struct { + assert *assert.Assertions + reports []reportRecord + + lock sync.Locker + isClosed bool +} + +func newReportRecorder(tb testing.TB) *reportRecorder { + return &reportRecorder{ + assert: assert.New(tb), + reports: []reportRecord{}, + lock: &sync.Mutex{}, + isClosed: false, + } +} + +func (r *reportRecorder) add(isException bool, message string, context reporter.Context) { + r.lock.Lock() + defer r.lock.Unlock() + + l := logrus.WithFields(logrus.Fields{ + "isException": isException, + "message": message, + "context": context, + "pkg": "test/reportRecorder", + }) + + if r.isClosed { + l.Warn("Reporter closed, report skipped") + return + } + + r.reports = append(r.reports, reportRecord{ + isException: isException, + message: message, + context: context, + }) + + l.Warn("Report recorded") +} + +func (r *reportRecorder) close() { + r.lock.Lock() + defer r.lock.Unlock() + + r.isClosed = true +} + +func (r *reportRecorder) assertEmpty() { + r.assert.Empty(r.reports) +} + +func (r *reportRecorder) removeMatchingRecords(isException, message, context gomock.Matcher, n int) { + if n == 0 { + n = len(r.reports) + } + + r.reports = xslices.Filter(r.reports, func(rec reportRecord) bool { + if n <= 0 { + return true + } + + l := logrus.WithFields(logrus.Fields{ + "rec": rec, + }) + if !isException.Matches(rec.isException) { + l.WithField("matcher", isException).Debug("Not matching") + return true + } + + if !message.Matches(rec.message) { + l.WithField("matcher", message).Debug("Not matching") + return true + } + + if !context.Matches(rec.context) { + l.WithField("matcher", context).Debug("Not matching") + return true + } + + n-- + + return false + }) +} + +func (r *reportRecorder) ReportException(data any) error { + r.add(true, "exception", reporter.Context{"data": data}) + return nil +} + +func (r *reportRecorder) ReportMessage(message string) error { + r.add(false, message, reporter.Context{}) + return nil +} + +func (r *reportRecorder) ReportMessageWithContext(message string, context reporter.Context) error { + r.add(false, message, context) + return nil +} + +func (r *reportRecorder) ReportExceptionWithContext(data any, context reporter.Context) error { + if context == nil { + context = reporter.Context{} + } + + context["data"] = data + + r.add(true, "exception", context) + + return nil +} diff --git a/tests/ctx_test.go b/tests/ctx_test.go index e489714c..aef37f69 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -34,7 +34,6 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/bradenaw/juniper/xslices" "github.com/emersion/go-imap/client" - "github.com/golang/mock/gomock" "github.com/sirupsen/logrus" "golang.org/x/exp/maps" "google.golang.org/grpc" @@ -317,18 +316,7 @@ func (t *testCtx) close(ctx context.Context) { } t.api.Close() - t.events.close() - t.reporter.close() - - // Closed connection can happen in the end of scenario - t.reporter.removeMatchingRecords( - gomock.Eq(false), - gomock.Eq("Failed to parse imap command"), - gomock.Any(), // mocks.NewClosedConnectionMatcher(), - 0, - ) - t.reporter.assertEmpty() } diff --git a/tests/fast.go b/tests/fast.go index 1805b339..508d3ceb 100644 --- a/tests/fast.go +++ b/tests/fast.go @@ -20,35 +20,19 @@ package tests import ( "crypto/x509" - "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/certs" ) var ( - preCompPGPKey *crypto.Key preCompCertPEM []byte preCompKeyPEM []byte ) -func FastGenerateKey(name, email string, passphrase []byte, keyType string, bits int) (string, error) { - encKey, err := preCompPGPKey.Lock(passphrase) - if err != nil { - return "", err - } - - return encKey.Armor() -} - func FastGenerateCert(template *x509.Certificate) ([]byte, []byte, error) { return preCompCertPEM, preCompKeyPEM, nil } func init() { - key, err := crypto.GenerateKey("name", "email", "rsa", 1024) - if err != nil { - panic(err) - } - template, err := certs.NewTLSTemplate() if err != nil { panic(err) @@ -59,7 +43,6 @@ func init() { panic(err) } - preCompPGPKey = key preCompCertPEM = certPEM preCompKeyPEM = keyPEM } diff --git a/tests/init_test.go b/tests/init_test.go index 7aca5ad5..498438cf 100644 --- a/tests/init_test.go +++ b/tests/init_test.go @@ -27,7 +27,7 @@ import ( func init() { // Use the fast key generation for tests. - backend.GenerateKey = FastGenerateKey + backend.GenerateKey = backend.FastGenerateKey // Use the fast cert generation for tests. certs.GenerateCert = FastGenerateCert diff --git a/tests/user_test.go b/tests/user_test.go index 786427ad..61c6d9ef 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -34,42 +34,80 @@ import ( ) func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error { - // Create the user. - userID, addrID, err := s.t.api.CreateUser(username, username, []byte(password)) - if err != nil { + // Create the user and generate its default address (with keys). + if err := s.t.runQuarkCmd( + context.Background(), + "user:create", + "--name", username, + "--password", password, + "--gen-keys", "RSA2048", + ); err != nil { return err } - // Set the ID of this user. - s.t.setUserID(username, userID) + return s.t.withClientPass(context.Background(), username, password, func(ctx context.Context, c *proton.Client) error { + user, err := c.GetUser(ctx) + if err != nil { + return err + } - // Set the password of this user. - s.t.setUserPass(userID, password) + addr, err := c.GetAddresses(ctx) + if err != nil { + return err + } - // Set the address of this user (right now just the same as the username, but let's stay flexible). - s.t.setUserAddr(userID, addrID, username) + // Set the ID of the user. + s.t.setUserID(username, user.ID) - return nil + // Set the password of the user. + s.t.setUserPass(user.ID, password) + + // Set the address of the user. + s.t.setUserAddr(user.ID, addr[0].ID, addr[0].Email) + + return nil + }) } func (s *scenario) theAccountHasAdditionalAddress(username, address string) error { userID := s.t.getUserID(username) - addrID, err := s.t.api.CreateAddress(userID, address, []byte(s.t.getUserPass(userID))) - if err != nil { + // Create the user's additional address. + if err := s.t.runQuarkCmd( + context.Background(), + "user:create:address", + userID, + s.t.getUserPass(userID), + address, + "--gen-keys", "RSA2048", + ); err != nil { return err } - s.t.setUserAddr(userID, addrID, address) + return s.t.withClient(context.Background(), username, func(ctx context.Context, c *proton.Client) error { + addr, err := c.GetAddresses(ctx) + if err != nil { + return err + } - return nil + // Set the new address of the user. + s.t.setUserAddr(userID, addr[len(addr)-1].ID, address) + + return nil + }) } func (s *scenario) theAccountNoLongerHasAdditionalAddress(username, address string) error { userID := s.t.getUserID(username) addrID := s.t.getUserAddrID(userID, address) - if err := s.t.api.RemoveAddress(userID, addrID); err != nil { + if err := s.t.withClient(context.Background(), username, func(ctx context.Context, c *proton.Client) error { + if err := c.DisableAddress(ctx, addrID); err != nil { + return err + } + + return c.DeleteAddress(ctx, addrID) + }); err != nil { return err }