From c9d496956ce1250d4475d99f8f9ec5f0432b32ee Mon Sep 17 00:00:00 2001 From: James Houlahan Date: Tue, 7 Feb 2023 17:32:44 +0100 Subject: [PATCH] test: Refactor account management, fix map-random-order race condition Some SMTP tests made use of disabled addresses. We stored addresses in a map, meaning the order was randomized. This lead to tests sometimes attempting to authenticate over SMTP using a disabled address, failing. --- tests/bridge_test.go | 14 +-- tests/ctx_helper_test.go | 4 +- tests/ctx_test.go | 161 ++++++++++++++++++------------ tests/features/user/login.feature | 4 +- tests/imap_test.go | 26 ++--- tests/smtp_test.go | 18 ++-- tests/user_test.go | 55 +++++----- 7 files changed, 154 insertions(+), 128 deletions(-) diff --git a/tests/bridge_test.go b/tests/bridge_test.go index 9e57773a..45f6e67a 100644 --- a/tests/bridge_test.go +++ b/tests/bridge_test.go @@ -68,10 +68,10 @@ func (s *scenario) theUserChangesTheSMTPPortTo(port int) error { func (s *scenario) theUserSetsTheAddressModeOfUserTo(user, mode string) error { switch mode { case "split": - return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserID(user), vault.SplitMode) + return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserByName(user).getUserID(), vault.SplitMode) case "combined": - return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserID(user), vault.CombinedMode) + return s.t.bridge.SetAddressMode(context.Background(), s.t.getUserByName(user).getUserID(), vault.CombinedMode) default: return fmt.Errorf("unknown address mode %q", mode) @@ -156,7 +156,7 @@ func (s *scenario) bridgeSendsADeauthEventForUser(username string) error { return errors.New("expected deauth event, got none") } - if wantUserID := s.t.getUserID(username); event.UserID != wantUserID { + if wantUserID := s.t.getUserByName(username).getUserID(); event.UserID != wantUserID { return fmt.Errorf("expected deauth event for user %s, got %s", wantUserID, event.UserID) } @@ -169,7 +169,7 @@ func (s *scenario) bridgeSendsAnAddressCreatedEventForUser(username string) erro return errors.New("expected address created event, got none") } - if wantUserID := s.t.getUserID(username); event.UserID != wantUserID { + if wantUserID := s.t.getUserByName(username).getUserID(); event.UserID != wantUserID { return fmt.Errorf("expected address created event for user %s, got %s", wantUserID, event.UserID) } @@ -182,7 +182,7 @@ func (s *scenario) bridgeSendsAnAddressDeletedEventForUser(username string) erro return errors.New("expected address deleted event, got none") } - if wantUserID := s.t.getUserID(username); event.UserID != wantUserID { + if wantUserID := s.t.getUserByName(username).getUserID(); event.UserID != wantUserID { return fmt.Errorf("expected address deleted event for user %s, got %s", wantUserID, event.UserID) } @@ -195,7 +195,7 @@ func (s *scenario) bridgeSendsSyncStartedAndFinishedEventsForUser(username strin return errors.New("expected sync started event, got none") } - if wantUserID := s.t.getUserID(username); startEvent.UserID != wantUserID { + if wantUserID := s.t.getUserByName(username).getUserID(); startEvent.UserID != wantUserID { return fmt.Errorf("expected sync started event for user %s, got %s", wantUserID, startEvent.UserID) } @@ -204,7 +204,7 @@ func (s *scenario) bridgeSendsSyncStartedAndFinishedEventsForUser(username strin return errors.New("expected sync finished event, got none") } - if wantUserID := s.t.getUserID(username); finishEvent.UserID != wantUserID { + if wantUserID := s.t.getUserByName(username).getUserID(); finishEvent.UserID != wantUserID { return fmt.Errorf("expected sync finished event for user %s, got %s", wantUserID, finishEvent.UserID) } diff --git a/tests/ctx_helper_test.go b/tests/ctx_helper_test.go index 4bf0103c..3c803b5e 100644 --- a/tests/ctx_helper_test.go +++ b/tests/ctx_helper_test.go @@ -43,7 +43,7 @@ func (t *testCtx) withProton(fn func(*proton.Manager) error) error { // withClient executes the given function with a client that is logged in as the given (known) user. func (t *testCtx) withClient(ctx context.Context, username string, fn func(context.Context, *proton.Client) error) error { - return t.withClientPass(ctx, username, t.getUserPass(t.getUserID(username)), fn) + return t.withClientPass(ctx, username, t.getUserByName(username).getUserPass(), fn) } // withClient executes the given function with a client that is logged in with the given username and password. @@ -108,7 +108,7 @@ func (t *testCtx) withAddrKR( return err } - keyPass, err := salt.SaltForKey([]byte(t.getUserPass(t.getUserID(username))), user.Keys.Primary().ID) + keyPass, err := salt.SaltForKey([]byte(t.getUserByName(username).getUserPass()), user.Keys.Primary().ID) if err != nil { return err } diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 0c9db06c..2089ea89 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -40,12 +40,85 @@ import ( "github.com/emersion/go-imap/client" "github.com/google/uuid" "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "google.golang.org/grpc" ) var defaultVersion = semver.MustParse("3.0.6") +type testUser struct { + name string // the test user name + userID string // the user's account ID + addresses []*testAddr // the user's addresses + userPass string // the user's account password + bridgePass string // the user's bridge password +} + +func newTestUser(userID, name, userPass string) *testUser { + return &testUser{ + userID: userID, + name: name, + userPass: userPass, + } +} + +func (user *testUser) getName() string { + return user.name +} + +func (user *testUser) getUserID() string { + return user.userID +} + +func (user *testUser) getEmails() []string { + return xslices.Map(user.addresses, func(addr *testAddr) string { + return addr.email + }) +} + +func (user *testUser) getAddrID(email string) string { + for _, addr := range user.addresses { + if addr.email == email { + return addr.addrID + } + } + + panic(fmt.Sprintf("unknown email %q", email)) +} + +func (user *testUser) addAddress(addrID, email string) { + user.addresses = append(user.addresses, newTestAddr(addrID, email)) +} + +func (user *testUser) remAddress(addrID string) { + user.addresses = xslices.Filter(user.addresses, func(addr *testAddr) bool { + return addr.addrID != addrID + }) +} + +func (user *testUser) getUserPass() string { + return user.userPass +} + +func (user *testUser) getBridgePass() string { + return user.bridgePass +} + +func (user *testUser) setBridgePass(pass string) { + user.bridgePass = pass +} + +type testAddr struct { + addrID string // the remote address ID + email string // the test address email +} + +func newTestAddr(addrID, email string) *testAddr { + return &testAddr{ + addrID: addrID, + email: email, + } +} + type testCtx struct { // These are the objects supporting the test. dir string @@ -70,13 +143,11 @@ type testCtx struct { clientConn *grpc.ClientConn clientEventCh *queue.QueuedChannel[*frontend.StreamEvent] - // These maps hold expected userIDByName, their primary addresses and bridge passwords. - userUUIDByName map[string]string - addrUUIDByName map[string]string - userIDByName map[string]string - userAddrByEmail map[string]map[string]string - userPassByID map[string]string - userBridgePassByID map[string][]byte + // These maps hold test objects created during the test. + userByID map[string]*testUser + userUUIDByName map[string]string + addrByID map[string]*testAddr + addrUUIDByName map[string]string // These are the IMAP and SMTP clients used to connect to bridge. imapClients map[string]*imapClient @@ -115,12 +186,10 @@ func newTestCtx(tb testing.TB) *testCtx { events: newEventCollector(), reporter: newReportRecorder(tb), - userUUIDByName: make(map[string]string), - addrUUIDByName: make(map[string]string), - userIDByName: make(map[string]string), - userAddrByEmail: make(map[string]map[string]string), - userPassByID: make(map[string]string), - userBridgePassByID: make(map[string][]byte), + userByID: make(map[string]*testUser), + userUUIDByName: make(map[string]string), + addrByID: make(map[string]*testAddr), + addrUUIDByName: make(map[string]string), imapClients: make(map[string]*imapClient), smtpClients: make(map[string]*smtpClient), @@ -192,62 +261,22 @@ func (t *testCtx) afterStep(st *godog.Step, status godog.StepResultStatus) { logrus.Debugf("Finished step (%v): %s", status, st.Text) } -func (t *testCtx) getName(wantUserID string) string { - for name, userID := range t.userIDByName { - if userID == wantUserID { - return name +func (t *testCtx) addUser(userID, name, userPass string) { + t.userByID[userID] = newTestUser(userID, name, userPass) +} + +func (t *testCtx) getUserByName(name string) *testUser { + for _, user := range t.userByID { + if user.name == name { + return user } } - panic(fmt.Sprintf("unknown user ID %q", wantUserID)) + panic(fmt.Sprintf("user %q not found", name)) } -func (t *testCtx) getUserID(username string) string { - return t.userIDByName[username] -} - -func (t *testCtx) setUserID(username, userID string) { - t.userIDByName[username] = userID -} - -func (t *testCtx) getUserAddrID(userID, email string) string { - return t.userAddrByEmail[userID][email] -} - -func (t *testCtx) getUserAddrs(userID string) []string { - return maps.Keys(t.userAddrByEmail[userID]) -} - -func (t *testCtx) setUserAddr(userID, addrID, email string) { - if _, ok := t.userAddrByEmail[userID]; !ok { - t.userAddrByEmail[userID] = make(map[string]string) - } - - t.userAddrByEmail[userID][email] = addrID -} - -func (t *testCtx) unsetUserAddr(userID, wantAddrID string) { - for email, addrID := range t.userAddrByEmail[userID] { - if addrID == wantAddrID { - delete(t.userAddrByEmail[userID], email) - } - } -} - -func (t *testCtx) getUserPass(userID string) string { - return t.userPassByID[userID] -} - -func (t *testCtx) setUserPass(userID, pass string) { - t.userPassByID[userID] = pass -} - -func (t *testCtx) getUserBridgePass(userID string) string { - return string(t.userBridgePassByID[userID]) -} - -func (t *testCtx) setUserBridgePass(userID string, pass []byte) { - t.userBridgePassByID[userID] = pass +func (t *testCtx) getUserByID(userID string) *testUser { + return t.userByID[userID] } func (t *testCtx) getMBoxID(userID string, name string) string { @@ -256,7 +285,7 @@ func (t *testCtx) getMBoxID(userID string, name string) string { var labelID string - if err := t.withClient(ctx, t.getName(userID), func(ctx context.Context, client *proton.Client) error { + if err := t.withClient(ctx, t.getUserByID(userID).getName(), func(ctx context.Context, client *proton.Client) error { labels, err := client.GetLabels(ctx, proton.LabelTypeLabel, proton.LabelTypeFolder, proton.LabelTypeSystem) if err != nil { panic(err) diff --git a/tests/features/user/login.feature b/tests/features/user/login.feature index 385a7984..3edfdb79 100644 --- a/tests/features/user/login.feature +++ b/tests/features/user/login.feature @@ -14,8 +14,8 @@ Feature: A user can login Then user "[user:user]" is not listed Scenario: Login to nonexistent account - When the user logs in with username "[user:other]" and password "unknown" - Then user "[user:other]" is not listed + When the user logs in with username "nonexistent" and password "unknown" + Then user "nonexistent" is not listed Scenario: Login to account without internet Given the internet is turned off diff --git a/tests/imap_test.go b/tests/imap_test.go index 2f2cf448..8cae5707 100644 --- a/tests/imap_test.go +++ b/tests/imap_test.go @@ -38,35 +38,35 @@ import ( ) func (s *scenario) userConnectsIMAPClient(username, clientID string) error { - return s.t.newIMAPClient(s.t.getUserID(username), clientID) + return s.t.newIMAPClient(s.t.getUserByName(username).getUserID(), clientID) } func (s *scenario) userConnectsIMAPClientOnPort(username, clientID string, port int) error { - return s.t.newIMAPClientOnPort(s.t.getUserID(username), clientID, port) + return s.t.newIMAPClientOnPort(s.t.getUserByName(username).getUserID(), clientID, port) } func (s *scenario) userConnectsAndAuthenticatesIMAPClient(username, clientID string) error { - return s.userConnectsAndAuthenticatesIMAPClientWithAddress(username, clientID, s.t.getUserAddrs(s.t.getUserID(username))[0]) + return s.userConnectsAndAuthenticatesIMAPClientWithAddress(username, clientID, s.t.getUserByName(username).getEmails()[0]) } func (s *scenario) userConnectsAndAuthenticatesIMAPClientWithAddress(username, clientID, address string) error { - if err := s.t.newIMAPClient(s.t.getUserID(username), clientID); err != nil { + if err := s.t.newIMAPClient(s.t.getUserByName(username).getUserID(), clientID); err != nil { return err } userID, client := s.t.getIMAPClient(clientID) - return client.Login(address, s.t.getUserBridgePass(userID)) + return client.Login(address, s.t.getUserByID(userID).getBridgePass()) } func (s *scenario) userConnectsAndCanNotAuthenticateIMAPClientWithAddress(username, clientID, address string) error { - if err := s.t.newIMAPClient(s.t.getUserID(username), clientID); err != nil { + if err := s.t.newIMAPClient(s.t.getUserByName(username).getUserID(), clientID); err != nil { return err } userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(address, s.t.getUserBridgePass(userID)); err == nil { + if err := client.Login(address, s.t.getUserByID(userID).getBridgePass()); err == nil { return fmt.Errorf("expected error, got nil") } @@ -76,19 +76,19 @@ func (s *scenario) userConnectsAndCanNotAuthenticateIMAPClientWithAddress(userna func (s *scenario) imapClientCanAuthenticate(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - return client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)) + return client.Login(s.t.getUserByID(userID).getEmails()[0], s.t.getUserByID(userID).getBridgePass()) } func (s *scenario) imapClientCanAuthenticateWithAddress(clientID string, address string) error { userID, client := s.t.getIMAPClient(clientID) - return client.Login(address, s.t.getUserBridgePass(userID)) + return client.Login(address, s.t.getUserByID(userID).getBridgePass()) } func (s *scenario) imapClientCannotAuthenticate(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)); err == nil { + if err := client.Login(s.t.getUserByID(userID).getEmails()[0], s.t.getUserByID(userID).getBridgePass()); err == nil { return fmt.Errorf("expected error, got nil") } @@ -98,7 +98,7 @@ func (s *scenario) imapClientCannotAuthenticate(clientID string) error { func (s *scenario) imapClientCannotAuthenticateWithAddress(clientID, address string) error { userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(address, s.t.getUserBridgePass(userID)); err == nil { + if err := client.Login(address, s.t.getUserByID(userID).getBridgePass()); err == nil { return fmt.Errorf("expected error, got nil") } @@ -108,7 +108,7 @@ func (s *scenario) imapClientCannotAuthenticateWithAddress(clientID, address str func (s *scenario) imapClientCannotAuthenticateWithIncorrectUsername(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(s.t.getUserAddrs(userID)[0]+"bad", s.t.getUserBridgePass(userID)); err == nil { + if err := client.Login(s.t.getUserByID(userID).getEmails()[0]+"bad", s.t.getUserByID(userID).getBridgePass()); err == nil { return fmt.Errorf("expected error, got nil") } @@ -118,7 +118,7 @@ func (s *scenario) imapClientCannotAuthenticateWithIncorrectUsername(clientID st func (s *scenario) imapClientCannotAuthenticateWithIncorrectPassword(clientID string) error { userID, client := s.t.getIMAPClient(clientID) - if err := client.Login(s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)+"bad"); err == nil { + if err := client.Login(s.t.getUserByID(userID).getEmails()[0], s.t.getUserByID(userID).getBridgePass()+"bad"); err == nil { return fmt.Errorf("expected error, got nil") } diff --git a/tests/smtp_test.go b/tests/smtp_test.go index 26c1fe29..430b83af 100644 --- a/tests/smtp_test.go +++ b/tests/smtp_test.go @@ -27,25 +27,25 @@ import ( ) func (s *scenario) userConnectsSMTPClient(username, clientID string) error { - return s.t.newSMTPClient(s.t.getUserID(username), clientID) + return s.t.newSMTPClient(s.t.getUserByName(username).getUserID(), clientID) } func (s *scenario) userConnectsSMTPClientOnPort(username, clientID string, port int) error { - return s.t.newSMTPClientOnPort(s.t.getUserID(username), clientID, port) + return s.t.newSMTPClientOnPort(s.t.getUserByName(username).getUserID(), clientID, port) } func (s *scenario) userConnectsAndAuthenticatesSMTPClient(username, clientID string) error { - return s.userConnectsAndAuthenticatesSMTPClientWithAddress(username, clientID, s.t.getUserAddrs(s.t.getUserID(username))[0]) + return s.userConnectsAndAuthenticatesSMTPClientWithAddress(username, clientID, s.t.getUserByName(username).getEmails()[0]) } func (s *scenario) userConnectsAndAuthenticatesSMTPClientWithAddress(username, clientID, address string) error { - if err := s.t.newSMTPClient(s.t.getUserID(username), clientID); err != nil { + if err := s.t.newSMTPClient(s.t.getUserByName(username).getUserID(), clientID); err != nil { return err } userID, client := s.t.getSMTPClient(clientID) - s.t.pushError(client.Auth(smtp.PlainAuth("", address, s.t.getUserBridgePass(userID), constants.Host))) + s.t.pushError(client.Auth(smtp.PlainAuth("", address, s.t.getUserByID(userID).getBridgePass(), constants.Host))) return nil } @@ -53,7 +53,7 @@ func (s *scenario) userConnectsAndAuthenticatesSMTPClientWithAddress(username, c func (s *scenario) smtpClientCanAuthenticate(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID), constants.Host)); err != nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserByID(userID).getEmails()[0], s.t.getUserByID(userID).getBridgePass(), constants.Host)); err != nil { return fmt.Errorf("expected no error, got %v", err) } @@ -63,7 +63,7 @@ func (s *scenario) smtpClientCanAuthenticate(clientID string) error { func (s *scenario) smtpClientCannotAuthenticate(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID), constants.Host)); err == nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserByID(userID).getEmails()[0], s.t.getUserByID(userID).getBridgePass(), constants.Host)); err == nil { return fmt.Errorf("expected error, got nil") } @@ -73,7 +73,7 @@ func (s *scenario) smtpClientCannotAuthenticate(clientID string) error { func (s *scenario) smtpClientCannotAuthenticateWithIncorrectUsername(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0]+"bad", s.t.getUserBridgePass(userID), constants.Host)); err == nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserByID(userID).getEmails()[0]+"bad", s.t.getUserByID(userID).getBridgePass(), constants.Host)); err == nil { return fmt.Errorf("expected error, got nil") } @@ -83,7 +83,7 @@ func (s *scenario) smtpClientCannotAuthenticateWithIncorrectUsername(clientID st func (s *scenario) smtpClientCannotAuthenticateWithIncorrectPassword(clientID string) error { userID, client := s.t.getSMTPClient(clientID) - if err := client.Auth(smtp.PlainAuth("", s.t.getUserAddrs(userID)[0], s.t.getUserBridgePass(userID)+"bad", constants.Host)); err == nil { + if err := client.Auth(smtp.PlainAuth("", s.t.getUserByID(userID).getEmails()[0], s.t.getUserByID(userID).getBridgePass()+"bad", constants.Host)); err == nil { return fmt.Errorf("expected error, got nil") } diff --git a/tests/user_test.go b/tests/user_test.go index 044d52f4..989a914c 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -32,7 +32,6 @@ import ( "github.com/bradenaw/juniper/xslices" "github.com/cucumber/godog" "github.com/google/uuid" - "golang.org/x/exp/slices" ) func (s *scenario) thereExistsAnAccountWithUsernameAndPassword(username, password string) error { @@ -52,7 +51,7 @@ func (s *scenario) theAccountHasAdditionalDisabledAddress(username, address stri } func (s *scenario) theAccountHasAdditionalAddressWithoutKeys(username, address string) error { - userID := s.t.getUserID(username) + userID := s.t.getUserByName(username).getUserID() // Decrypt the user's encrypted ID for use with quark. userDecID, err := s.t.runQuarkCmd(context.Background(), "encryption:id", "--decrypt", userID) @@ -65,7 +64,8 @@ func (s *scenario) theAccountHasAdditionalAddressWithoutKeys(username, address s context.Background(), "user:create:address", string(userDecID), - s.t.getUserPass(userID), + s.t.getUserByID(userID).getUserPass(), + address, ); err != nil { return err @@ -78,15 +78,15 @@ func (s *scenario) theAccountHasAdditionalAddressWithoutKeys(username, address s } // Set the new address of the user. - s.t.setUserAddr(userID, addr[len(addr)-1].ID, address) + s.t.getUserByID(userID).addAddress(addr[len(addr)-1].ID, address) return nil }) } func (s *scenario) theAccountNoLongerHasAdditionalAddress(username, address string) error { - userID := s.t.getUserID(username) - addrID := s.t.getUserAddrID(userID, address) + userID := s.t.getUserByName(username).getUserID() + addrID := s.t.getUserByName(username).getAddrID(address) if err := s.t.withClient(context.Background(), username, func(ctx context.Context, c *proton.Client) error { if err := c.DisableAddress(ctx, addrID); err != nil { @@ -98,7 +98,7 @@ func (s *scenario) theAccountNoLongerHasAdditionalAddress(username, address stri return err } - s.t.unsetUserAddr(userID, addrID) + s.t.getUserByID(userID).remAddress(addrID) return nil } @@ -184,8 +184,8 @@ func (s *scenario) theAddressOfAccountHasTheFollowingMessagesInMailbox(address, ctx, cancel := context.WithCancel(context.Background()) defer cancel() - userID := s.t.getUserID(username) - addrID := s.t.getUserAddrID(userID, address) + userID := s.t.getUserByName(username).getUserID() + addrID := s.t.getUserByName(username).getAddrID(address) mboxID := s.t.getMBoxID(userID, mailbox) wantMessages, err := unmarshalTable[Message](table) @@ -219,8 +219,8 @@ func (s *scenario) theAddressOfAccountHasMessagesInMailbox(address, username str ctx, cancel := context.WithCancel(context.Background()) defer cancel() - userID := s.t.getUserID(username) - addrID := s.t.getUserAddrID(userID, address) + userID := s.t.getUserByName(username).getUserID() + addrID := s.t.getUserByName(username).getAddrID(address) mboxID := s.t.getMBoxID(userID, mailbox) return s.t.createMessages(ctx, username, addrID, iterator.Collect(iterator.Map(iterator.Counter(count), func(idx int) proton.ImportReq { @@ -262,7 +262,7 @@ func (s *scenario) theFollowingFieldsWereChangedInDraftForAddressOfAccount(draft defer cancel() return s.t.withClient(ctx, username, func(ctx context.Context, c *proton.Client) error { - return s.t.withAddrKR(ctx, c, username, s.t.getUserAddrID(s.t.getUserID(username), address), func(_ context.Context, addrKR *crypto.KeyRing) error { + return s.t.withAddrKR(ctx, c, username, s.t.getUserByName(username).getAddrID(address), func(_ context.Context, addrKR *crypto.KeyRing) error { var changes proton.DraftTemplate if wantMessages[0].From != "" { @@ -311,7 +311,7 @@ func (s *scenario) drafAtIndexWasMovedToTrashForAddressOfAccount(draftIndex int, defer cancel() return s.t.withClient(ctx, username, func(ctx context.Context, c *proton.Client) error { - return s.t.withAddrKR(ctx, c, username, s.t.getUserAddrID(s.t.getUserID(username), address), func(_ context.Context, addrKR *crypto.KeyRing) error { + return s.t.withAddrKR(ctx, c, username, s.t.getUserByName(username).getAddrID(address), func(_ context.Context, addrKR *crypto.KeyRing) error { if err := c.UnlabelMessages(ctx, []string{draftID}, proton.DraftsLabel); err != nil { return fmt.Errorf("failed to unlabel draft") } @@ -329,7 +329,7 @@ func (s *scenario) userLogsInWithUsernameAndPassword(username, password string) if err != nil { s.t.pushError(err) } else { - if userID != s.t.getUserID(username) { + if userID != s.t.getUserByName(username).getUserID() { return errors.New("user ID mismatch") } @@ -338,18 +338,18 @@ func (s *scenario) userLogsInWithUsernameAndPassword(username, password string) return err } - s.t.setUserBridgePass(userID, info.BridgePass) + s.t.getUserByID(userID).setBridgePass(string(info.BridgePass)) } return nil } func (s *scenario) userLogsOut(username string) error { - return s.t.bridge.LogoutUser(context.Background(), s.t.getUserID(username)) + return s.t.bridge.LogoutUser(context.Background(), s.t.getUserByName(username).getUserID()) } func (s *scenario) userIsDeleted(username string) error { - return s.t.bridge.DeleteUser(context.Background(), s.t.getUserID(username)) + return s.t.bridge.DeleteUser(context.Background(), s.t.getUserByName(username).getUserID()) } func (s *scenario) theAuthOfUserIsRevoked(username string) error { @@ -359,7 +359,7 @@ func (s *scenario) theAuthOfUserIsRevoked(username string) error { } func (s *scenario) userIsListedAndConnected(username string) error { - user, err := s.t.bridge.GetUserInfo(s.t.getUserID(username)) + user, err := s.t.bridge.GetUserInfo(s.t.getUserByName(username).getUserID()) if err != nil { return err } @@ -382,7 +382,7 @@ func (s *scenario) userIsEventuallyListedAndConnected(username string) error { } func (s *scenario) userIsListedButNotConnected(username string) error { - user, err := s.t.bridge.GetUserInfo(s.t.getUserID(username)) + user, err := s.t.bridge.GetUserInfo(s.t.getUserByName(username).getUserID()) if err != nil { return err } @@ -399,7 +399,7 @@ func (s *scenario) userIsListedButNotConnected(username string) error { } func (s *scenario) userIsNotListed(username string) error { - if slices.Contains(s.t.bridge.GetUserIDs(), s.t.getUserID(username)) { + if _, err := s.t.bridge.QueryUserInfo(username); !errors.Is(err, bridge.ErrNoSuchUser) { return errors.New("user listed") } @@ -411,7 +411,7 @@ func (s *scenario) userFinishesSyncing(username string) error { } func (s *scenario) addAdditionalAddressToAccount(username, address string, disabled bool) error { - userID := s.t.getUserID(username) + userID := s.t.getUserByName(username).getUserID() // Decrypt the user's encrypted ID for use with quark. userDecID, err := s.t.runQuarkCmd(context.Background(), "encryption:id", "--decrypt", userID) @@ -429,7 +429,7 @@ func (s *scenario) addAdditionalAddressToAccount(username, address string, disab args = append(args, string(userDecID), - s.t.getUserPass(userID), + s.t.getUserByID(userID).getUserPass(), address, ) @@ -449,7 +449,7 @@ func (s *scenario) addAdditionalAddressToAccount(username, address string, disab } // Set the new address of the user. - s.t.setUserAddr(userID, addr[len(addr)-1].ID, address) + s.t.getUserByID(userID).addAddress(addr[len(addr)-1].ID, address) return nil }) @@ -503,14 +503,11 @@ func (s *scenario) createUserAccount(username, password string, disabled bool) e return err } - // Set the ID of the user. - s.t.setUserID(username, user.ID) - - // Set the password of the user. - s.t.setUserPass(user.ID, password) + // Add the test user. + s.t.addUser(user.ID, username, password) // Set the address of the user. - s.t.setUserAddr(user.ID, addr[0].ID, addr[0].Email) + s.t.getUserByID(user.ID).addAddress(addr[0].ID, addr[0].Email) return nil })