test: fix most integration tests (live)

This commit is contained in:
James Houlahan
2020-04-09 10:24:58 +02:00
parent bafd4e714e
commit fec5f2d3c3
18 changed files with 124 additions and 104 deletions

View File

@ -24,6 +24,7 @@ import (
"github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
) )
// GetBridge returns bridge instance. // GetBridge returns bridge instance.
@ -59,7 +60,7 @@ func newBridgeInstance(
cfg *fakeConfig, cfg *fakeConfig,
credStore bridge.CredentialsStorer, credStore bridge.CredentialsStorer,
eventListener listener.Listener, eventListener listener.Listener,
clientManager bridge.ClientManager, clientManager *pmapi.ClientManager,
) *bridge.Bridge { ) *bridge.Bridge {
version := os.Getenv("VERSION") version := os.Getenv("VERSION")
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "") bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")

View File

@ -38,14 +38,17 @@ type TestContext struct {
t *bddT t *bddT
cfg *fakeConfig cfg *fakeConfig
listener listener.Listener listener listener.Listener
pmapiController PMAPIController // pmapiController is used to create pmapi clients (either real or fake) and control server state.
testAccounts *accounts.TestAccounts testAccounts *accounts.TestAccounts
// pmapiController is used to control real or fake pmapi clients.
// The clients are created by the clientManager.
pmapiController PMAPIController
clientManager *pmapi.ClientManager
// Bridge core related variables. // Bridge core related variables.
bridge *bridge.Bridge bridge *bridge.Bridge
bridgeLastError error bridgeLastError error
credStore bridge.CredentialsStorer credStore bridge.CredentialsStorer
clientManager *pmapi.ClientManager
// IMAP related variables. // IMAP related variables.
imapAddr string imapAddr string

View File

@ -46,10 +46,9 @@ func (api *FakePMAPI) CreateAttachment(attachment *pmapi.Attachment, data io.Rea
return attachment, nil return attachment, nil
} }
func (api *FakePMAPI) DeleteAttachment(attachmentID string) error { func (api *FakePMAPI) DeleteAttachment(attID string) error {
if err := api.checkAndRecordCall(GET, "/attachments/"+attachmentID, nil); err != nil { if err := api.checkAndRecordCall(DELETE, "/attachments/"+attID, nil); err != nil {
return err return err
} }
return nil return nil
} }

View File

@ -142,7 +142,7 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) {
} }
func (api *FakePMAPI) Logout() { func (api *FakePMAPI) Logout() {
_ = api.DeleteAuth() api.DeleteAuth()
api.ClearData() api.ClearData()
} }
@ -150,14 +150,12 @@ func (api *FakePMAPI) DeleteAuth() error {
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil { if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
return err return err
} }
// Logout will also emit change to auth channel // Logout will also emit change to auth channel
api.sendAuth(nil) api.sendAuth(nil)
api.controller.deleteSession(api.uid)
return nil return nil
} }
func (api *FakePMAPI) ClearData() { func (api *FakePMAPI) ClearData() {
api.controller.deleteSession(api.uid)
api.unsetUser() api.unsetUser()
} }

View File

@ -44,8 +44,8 @@ type Controller struct {
log *logrus.Entry log *logrus.Entry
} }
func NewController(cm *pmapi.ClientManager) (cntrl *Controller) { func NewController(cm *pmapi.ClientManager) *Controller {
cntrl = &Controller{ controller := &Controller{
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
fakeAPIs: []*FakePMAPI{}, fakeAPIs: []*FakePMAPI{},
calls: []*fakeCall{}, calls: []*fakeCall{},
@ -64,10 +64,10 @@ func NewController(cm *pmapi.ClientManager) (cntrl *Controller) {
} }
cm.SetClientConstructor(func(userID string) pmapi.Client { cm.SetClientConstructor(func(userID string) pmapi.Client {
fakeAPI := New(cntrl) fakeAPI := New(controller)
cntrl.fakeAPIs = append(cntrl.fakeAPIs, fakeAPI) controller.fakeAPIs = append(controller.fakeAPIs, fakeAPI)
return fakeAPI return fakeAPI
}) })
return return controller
} }

View File

@ -39,9 +39,9 @@ type fakeCall struct {
request []byte request []byte
} }
func (cntrl *Controller) recordCall(method method, path string, req interface{}) { func (ctl *Controller) recordCall(method method, path string, req interface{}) {
cntrl.lock.Lock() ctl.lock.Lock()
defer cntrl.lock.Unlock() defer ctl.lock.Unlock()
request := []byte{} request := []byte{}
if req != nil { if req != nil {
@ -51,16 +51,16 @@ func (cntrl *Controller) recordCall(method method, path string, req interface{})
panic(err) panic(err)
} }
} }
cntrl.calls = append(cntrl.calls, &fakeCall{ ctl.calls = append(ctl.calls, &fakeCall{
method: method, method: method,
path: path, path: path,
request: request, request: request,
}) })
} }
func (cntrl *Controller) PrintCalls() { func (ctl *Controller) PrintCalls() {
fmt.Println("API calls:") fmt.Println("API calls:")
for idx, call := range cntrl.calls { for idx, call := range ctl.calls {
fmt.Printf("%02d: [%s] %s\n", idx+1, call.method, call.path) fmt.Printf("%02d: [%s] %s\n", idx+1, call.method, call.path)
if call.request != nil && string(call.request) != "null" { if call.request != nil && string(call.request) != "null" {
fmt.Printf("\t%s\n", call.request) fmt.Printf("\t%s\n", call.request)
@ -68,8 +68,8 @@ func (cntrl *Controller) PrintCalls() {
} }
} }
func (cntrl *Controller) WasCalled(method, path string, expectedRequest []byte) bool { func (ctl *Controller) WasCalled(method, path string, expectedRequest []byte) bool {
for _, call := range cntrl.calls { for _, call := range ctl.calls {
if string(call.method) != method && call.path != path { if string(call.method) != method && call.path != path {
continue continue
} }
@ -82,9 +82,9 @@ func (cntrl *Controller) WasCalled(method, path string, expectedRequest []byte)
return false return false
} }
func (cntrl *Controller) GetCalls(method, path string) [][]byte { func (ctl *Controller) GetCalls(method, path string) [][]byte {
requests := [][]byte{} requests := [][]byte{}
for _, call := range cntrl.calls { for _, call := range ctl.calls {
if string(call.method) == method && call.path == path { if string(call.method) == method && call.path == path {
requests = append(requests, call.request) requests = append(requests, call.request)
} }

View File

@ -34,33 +34,33 @@ var systemLabelNameToID = map[string]string{ //nolint[gochecknoglobals]
"Drafts": pmapi.DraftLabel, "Drafts": pmapi.DraftLabel,
} }
func (cntrl *Controller) TurnInternetConnectionOff() { func (ctl *Controller) TurnInternetConnectionOff() {
cntrl.log.Warn("Turning OFF internet") ctl.log.Warn("Turning OFF internet")
cntrl.noInternetConnection = true ctl.noInternetConnection = true
} }
func (cntrl *Controller) TurnInternetConnectionOn() { func (ctl *Controller) TurnInternetConnectionOn() {
cntrl.log.Warn("Turning ON internet") ctl.log.Warn("Turning ON internet")
cntrl.noInternetConnection = false ctl.noInternetConnection = false
} }
func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error { func (ctl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error {
cntrl.usersByUsername[user.Name] = &fakeUser{ ctl.usersByUsername[user.Name] = &fakeUser{
user: user, user: user,
password: password, password: password,
has2FA: twoFAEnabled, has2FA: twoFAEnabled,
} }
cntrl.addressesByUsername[user.Name] = addresses ctl.addressesByUsername[user.Name] = addresses
return nil return nil
} }
func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error { func (ctl *Controller) AddUserLabel(username string, label *pmapi.Label) error {
if _, ok := cntrl.labelsByUsername[username]; !ok { if _, ok := ctl.labelsByUsername[username]; !ok {
cntrl.labelsByUsername[username] = []*pmapi.Label{} ctl.labelsByUsername[username] = []*pmapi.Label{}
} }
labelName := getLabelNameWithoutPrefix(label.Name) labelName := getLabelNameWithoutPrefix(label.Name)
for _, existingLabel := range cntrl.labelsByUsername[username] { for _, existingLabel := range ctl.labelsByUsername[username] {
if existingLabel.Name == labelName { if existingLabel.Name == labelName {
return fmt.Errorf("folder or label %s already exists", label.Name) return fmt.Errorf("folder or label %s already exists", label.Name)
} }
@ -71,17 +71,17 @@ func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error
if label.Exclusive == 1 { if label.Exclusive == 1 {
prefix = "folder" prefix = "folder"
} }
label.ID = cntrl.labelIDGenerator.next(prefix) label.ID = ctl.labelIDGenerator.next(prefix)
label.Name = labelName label.Name = labelName
cntrl.labelsByUsername[username] = append(cntrl.labelsByUsername[username], label) ctl.labelsByUsername[username] = append(ctl.labelsByUsername[username], label)
cntrl.resetUsers() ctl.resetUsers()
return nil return nil
} }
func (cntrl *Controller) GetLabelIDs(username string, labelNames []string) ([]string, error) { func (ctl *Controller) GetLabelIDs(username string, labelNames []string) ([]string, error) {
labelIDs := []string{} labelIDs := []string{}
for _, labelName := range labelNames { for _, labelName := range labelNames {
labelID, err := cntrl.getLabelID(username, labelName) labelID, err := ctl.getLabelID(username, labelName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -90,12 +90,12 @@ func (cntrl *Controller) GetLabelIDs(username string, labelNames []string) ([]st
return labelIDs, nil return labelIDs, nil
} }
func (cntrl *Controller) getLabelID(username, labelName string) (string, error) { func (ctl *Controller) getLabelID(username, labelName string) (string, error) {
if labelID, ok := systemLabelNameToID[labelName]; ok { if labelID, ok := systemLabelNameToID[labelName]; ok {
return labelID, nil return labelID, nil
} }
labelName = getLabelNameWithoutPrefix(labelName) labelName = getLabelNameWithoutPrefix(labelName)
for _, label := range cntrl.labelsByUsername[username] { for _, label := range ctl.labelsByUsername[username] {
if label.Name == labelName { if label.Name == labelName {
return label.ID, nil return label.ID, nil
} }
@ -120,23 +120,23 @@ func getLabelExclusive(name string) int {
return 0 return 0
} }
func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message) error { func (ctl *Controller) AddUserMessage(username string, message *pmapi.Message) error {
if _, ok := cntrl.messagesByUsername[username]; !ok { if _, ok := ctl.messagesByUsername[username]; !ok {
cntrl.messagesByUsername[username] = []*pmapi.Message{} ctl.messagesByUsername[username] = []*pmapi.Message{}
} }
message.ID = cntrl.messageIDGenerator.next("") message.ID = ctl.messageIDGenerator.next("")
message.LabelIDs = append(message.LabelIDs, pmapi.AllMailLabel) message.LabelIDs = append(message.LabelIDs, pmapi.AllMailLabel)
cntrl.messagesByUsername[username] = append(cntrl.messagesByUsername[username], message) ctl.messagesByUsername[username] = append(ctl.messagesByUsername[username], message)
cntrl.resetUsers() ctl.resetUsers()
return nil return nil
} }
func (cntrl *Controller) resetUsers() { func (ctl *Controller) resetUsers() {
for _, fakeAPI := range cntrl.fakeAPIs { for _, fakeAPI := range ctl.fakeAPIs {
_ = fakeAPI.setUser(fakeAPI.username) _ = fakeAPI.setUser(fakeAPI.username)
} }
} }
func (cntrl *Controller) GetMessageID(username, messageIndex string) string { func (ctl *Controller) GetMessageID(username, messageIndex string) string {
return messageIndex return messageIndex
} }

View File

@ -29,9 +29,9 @@ type fakeSession struct {
var errWrongNameOrPassword = errors.New("Incorrect login credentials. Please try again") //nolint[stylecheck] var errWrongNameOrPassword = errors.New("Incorrect login credentials. Please try again") //nolint[stylecheck]
func (cntrl *Controller) createSessionIfAuthorized(username, password string) (*fakeSession, error) { func (ctl *Controller) createSessionIfAuthorized(username, password string) (*fakeSession, error) {
// get user // get user
user, ok := cntrl.usersByUsername[username] user, ok := ctl.usersByUsername[username]
if !ok || user.password != password { if !ok || user.password != password {
return nil, errWrongNameOrPassword return nil, errWrongNameOrPassword
} }
@ -39,19 +39,19 @@ func (cntrl *Controller) createSessionIfAuthorized(username, password string) (*
// create session // create session
session := &fakeSession{ session := &fakeSession{
username: username, username: username,
uid: cntrl.tokenGenerator.next("uid"), uid: ctl.tokenGenerator.next("uid"),
hasFullScope: !user.has2FA, hasFullScope: !user.has2FA,
} }
cntrl.refreshTheTokensForSession(session) ctl.refreshTheTokensForSession(session)
cntrl.sessionsByUID[session.uid] = session ctl.sessionsByUID[session.uid] = session
return session, nil return session, nil
} }
func (cntrl *Controller) refreshTheTokensForSession(session *fakeSession) { func (ctl *Controller) refreshTheTokensForSession(session *fakeSession) {
session.refreshToken = cntrl.tokenGenerator.next("refresh") session.refreshToken = ctl.tokenGenerator.next("refresh")
} }
func (cntrl *Controller) deleteSession(uid string) { func (ctl *Controller) deleteSession(uid string) {
delete(cntrl.sessionsByUID, uid) delete(ctl.sessionsByUID, uid)
} }

View File

@ -103,6 +103,7 @@ func (api *FakePMAPI) checkInternetAndRecordCall(method method, path string, req
return nil return nil
} }
// TODO: This should be sent back to the ClientManager properly!
func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) { func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) {
if auth != nil { if auth != nil {
auth.DANGEROUSLYSetUID(api.uid) auth.DANGEROUSLYSetUID(api.uid)

View File

@ -43,6 +43,6 @@ func (api *FakePMAPI) SendSimpleMetric(category, action, label string) error {
return api.checkInternetAndRecordCall(GET, "/metrics?"+v.Encode(), nil) return api.checkInternetAndRecordCall(GET, "/metrics?"+v.Encode(), nil)
} }
func (api *FakePMAPI) ReportSentryCrash(reportErr error) (err error) { func (api *FakePMAPI) ReportSentryCrash(err error) error {
return nil return nil
} }

View File

@ -51,6 +51,9 @@ func (api *FakePMAPI) UpdateUser() (*pmapi.User, error) {
} }
func (api *FakePMAPI) GetAddresses() (pmapi.AddressList, error) { func (api *FakePMAPI) GetAddresses() (pmapi.AddressList, error) {
if err := api.checkAndRecordCall(GET, "/addresses", nil); err != nil {
return nil, err
}
return *api.addresses, nil return *api.addresses, nil
} }

View File

@ -3,7 +3,7 @@ Feature: Re-login to bridge
Given there is connected user "user" Given there is connected user "user"
And there is database file for "user" And there is database file for "user"
When "user" logs in to bridge When "user" logs in to bridge
Then bridge response is "failed to finish login: user is already logged in" Then bridge response is "failed to finish login: user is already connected"
And "user" is connected And "user" is connected
And "user" has running event loop And "user" has running event loop
@ -12,7 +12,7 @@ Feature: Re-login to bridge
Given there is connected user "user" Given there is connected user "user"
And there is no database file for "user" And there is no database file for "user"
When "user" logs in to bridge When "user" logs in to bridge
Then bridge response is "failed to finish login: user is already logged in" Then bridge response is "failed to finish login: user is already connected"
And "user" is connected And "user" is connected
And "user" has database file And "user" has database file
And "user" has running event loop And "user" has running event loop

View File

@ -29,20 +29,20 @@ type fakeCall struct {
request []byte request []byte
} }
func (cntrl *Controller) recordCall(method, path string, request []byte) { func (ctl *Controller) recordCall(method, path string, request []byte) {
cntrl.lock.Lock() ctl.lock.Lock()
defer cntrl.lock.Unlock() defer ctl.lock.Unlock()
cntrl.calls = append(cntrl.calls, &fakeCall{ ctl.calls = append(ctl.calls, &fakeCall{
method: method, method: method,
path: path, path: path,
request: request, request: request,
}) })
} }
func (cntrl *Controller) PrintCalls() { func (ctl *Controller) PrintCalls() {
fmt.Println("API calls:") fmt.Println("API calls:")
for idx, call := range cntrl.calls { for idx, call := range ctl.calls {
fmt.Printf("%02d: [%s] %s\n", idx+1, call.method, call.path) fmt.Printf("%02d: [%s] %s\n", idx+1, call.method, call.path)
if call.request != nil && string(call.request) != "null" { if call.request != nil && string(call.request) != "null" {
fmt.Printf("\t%s\n", call.request) fmt.Printf("\t%s\n", call.request)
@ -50,8 +50,8 @@ func (cntrl *Controller) PrintCalls() {
} }
} }
func (cntrl *Controller) WasCalled(method, path string, expectedRequest []byte) bool { func (ctl *Controller) WasCalled(method, path string, expectedRequest []byte) bool {
for _, call := range cntrl.calls { for _, call := range ctl.calls {
if call.method != method && call.path != path { if call.method != method && call.path != path {
continue continue
} }
@ -64,9 +64,9 @@ func (cntrl *Controller) WasCalled(method, path string, expectedRequest []byte)
return false return false
} }
func (cntrl *Controller) GetCalls(method, path string) [][]byte { func (ctl *Controller) GetCalls(method, path string) [][]byte {
requests := [][]byte{} requests := [][]byte{}
for _, call := range cntrl.calls { for _, call := range ctl.calls {
if call.method == method && call.path == path { if call.method == method && call.path == path {
requests = append(requests, call.request) requests = append(requests, call.request)
} }

View File

@ -28,6 +28,7 @@ type Controller struct {
// Internal states. // Internal states.
lock *sync.RWMutex lock *sync.RWMutex
calls []*fakeCall calls []*fakeCall
pmapiByUsername map[string]pmapi.Client
messageIDsByUsername map[string][]string messageIDsByUsername map[string][]string
clientManager *pmapi.ClientManager clientManager *pmapi.ClientManager
@ -35,10 +36,11 @@ type Controller struct {
noInternetConnection bool noInternetConnection bool
} }
func NewController(cm *pmapi.ClientManager) (cntrl *Controller) { func NewController(cm *pmapi.ClientManager) *Controller {
cntrl = &Controller{ controller := &Controller{
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
calls: []*fakeCall{}, calls: []*fakeCall{},
pmapiByUsername: map[string]pmapi.Client{},
messageIDsByUsername: map[string][]string{}, messageIDsByUsername: map[string][]string{},
clientManager: cm, clientManager: cm,
@ -46,9 +48,9 @@ func NewController(cm *pmapi.ClientManager) (cntrl *Controller) {
} }
cm.SetRoundTripper(&fakeTransport{ cm.SetRoundTripper(&fakeTransport{
cntrl: cntrl, ctl: controller,
transport: http.DefaultTransport, transport: http.DefaultTransport,
}) })
return return controller
} }

View File

@ -35,8 +35,12 @@ var systemLabelNameToID = map[string]string{ //nolint[gochecknoglobals]
"Drafts": pmapi.DraftLabel, "Drafts": pmapi.DraftLabel,
} }
func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error { func (ctl *Controller) AddUserLabel(username string, label *pmapi.Label) error {
client := cntrl.clientManager.GetClient(username) client, ok := ctl.pmapiByUsername[username]
if !ok {
return fmt.Errorf("user %s does not exist", username)
}
label.Exclusive = getLabelExclusive(label.Name) label.Exclusive = getLabelExclusive(label.Name)
label.Name = getLabelNameWithoutPrefix(label.Name) label.Name = getLabelNameWithoutPrefix(label.Name)
label.Color = pmapi.LabelColors[0] label.Color = pmapi.LabelColors[0]
@ -46,10 +50,10 @@ func (cntrl *Controller) AddUserLabel(username string, label *pmapi.Label) error
return nil return nil
} }
func (cntrl *Controller) GetLabelIDs(username string, labelNames []string) ([]string, error) { func (ctl *Controller) GetLabelIDs(username string, labelNames []string) ([]string, error) {
labelIDs := []string{} labelIDs := []string{}
for _, labelName := range labelNames { for _, labelName := range labelNames {
labelID, err := cntrl.getLabelID(username, labelName) labelID, err := ctl.getLabelID(username, labelName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -58,12 +62,16 @@ func (cntrl *Controller) GetLabelIDs(username string, labelNames []string) ([]st
return labelIDs, nil return labelIDs, nil
} }
func (cntrl *Controller) getLabelID(username, labelName string) (string, error) { func (ctl *Controller) getLabelID(username, labelName string) (string, error) {
if labelID, ok := systemLabelNameToID[labelName]; ok { if labelID, ok := systemLabelNameToID[labelName]; ok {
return labelID, nil return labelID, nil
} }
client := cntrl.clientManager.GetClient(username) client, ok := ctl.pmapiByUsername[username]
if !ok {
return "", fmt.Errorf("user %s does not exist", username)
}
labels, err := client.ListLabels() labels, err := client.ListLabels()
if err != nil { if err != nil {
return "", errors.Wrap(err, "failed to list labels") return "", errors.Wrap(err, "failed to list labels")

View File

@ -30,8 +30,11 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message) error { func (ctl *Controller) AddUserMessage(username string, message *pmapi.Message) error {
client := cntrl.clientManager.GetClient(username) client, ok := ctl.pmapiByUsername[username]
if !ok {
return fmt.Errorf("user %s does not exist", username)
}
body, err := buildMessage(client, message) body, err := buildMessage(client, message)
if err != nil { if err != nil {
@ -55,7 +58,7 @@ func (cntrl *Controller) AddUserMessage(username string, message *pmapi.Message)
if result.Error != nil { if result.Error != nil {
return errors.Wrap(result.Error, "failed to import message") return errors.Wrap(result.Error, "failed to import message")
} }
cntrl.messageIDsByUsername[username] = append(cntrl.messageIDsByUsername[username], result.MessageID) ctl.messageIDsByUsername[username] = append(ctl.messageIDsByUsername[username], result.MessageID)
} }
return nil return nil
@ -122,10 +125,10 @@ func buildMessageBody(message *pmapi.Message, body *bytes.Buffer) error {
return nil return nil
} }
func (cntrl *Controller) GetMessageID(username, messageIndex string) string { func (ctl *Controller) GetMessageID(username, messageIndex string) string {
idx, err := strconv.Atoi(messageIndex) idx, err := strconv.Atoi(messageIndex)
if err != nil { if err != nil {
panic(fmt.Sprintf("message index %s not found", messageIndex)) panic(fmt.Sprintf("message index %s not found", messageIndex))
} }
return cntrl.messageIDsByUsername[username][idx-1] return ctl.messageIDsByUsername[username][idx-1]
} }

View File

@ -24,21 +24,21 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func (cntrl *Controller) TurnInternetConnectionOff() { func (ctl *Controller) TurnInternetConnectionOff() {
cntrl.noInternetConnection = true ctl.noInternetConnection = true
} }
func (cntrl *Controller) TurnInternetConnectionOn() { func (ctl *Controller) TurnInternetConnectionOn() {
cntrl.noInternetConnection = false ctl.noInternetConnection = false
} }
type fakeTransport struct { type fakeTransport struct {
cntrl *Controller ctl *Controller
transport http.RoundTripper transport http.RoundTripper
} }
func (t *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.cntrl.noInternetConnection { if t.ctl.noInternetConnection {
return nil, errors.New("no route to host") return nil, errors.New("no route to host")
} }
@ -53,7 +53,7 @@ func (t *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, errors.Wrap(err, "failed to read body") return nil, errors.Wrap(err, "failed to read body")
} }
} }
t.cntrl.recordCall(req.Method, req.URL.Path, body) t.ctl.recordCall(req.Method, req.URL.Path, body)
return t.transport.RoundTrip(req) return t.transport.RoundTrip(req)
} }

View File

@ -23,12 +23,12 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error { func (ctl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error {
if twoFAEnabled { if twoFAEnabled {
return godog.ErrPending return godog.ErrPending
} }
client := cntrl.clientManager.GetClient(user.ID) client := ctl.clientManager.GetClient(user.ID)
authInfo, err := client.AuthInfo(user.Name) authInfo, err := client.AuthInfo(user.Name)
if err != nil { if err != nil {
@ -54,5 +54,7 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
return errors.Wrap(err, "failed to clean user") return errors.Wrap(err, "failed to clean user")
} }
ctl.pmapiByUsername[user.Name] = client
return nil return nil
} }