From 8109831c07f4fa5a42c67ab4a7f47d432d460698 Mon Sep 17 00:00:00 2001 From: Michal Horejsek Date: Thu, 11 Mar 2021 14:37:15 +0100 Subject: [PATCH] GODT-35: Finish all details and make tests pass --- go.mod | 5 +- go.sum | 18 +- internal/app/base/base.go | 16 +- internal/bridge/bridge.go | 5 +- internal/frontend/cli-ie/accounts.go | 7 +- internal/frontend/cli-ie/frontend.go | 26 +- internal/frontend/cli-ie/system.go | 8 - internal/frontend/cli-ie/utils.go | 12 +- internal/frontend/cli/accounts.go | 7 +- internal/frontend/cli/frontend.go | 32 +-- internal/frontend/cli/system.go | 8 - internal/frontend/cli/utils.go | 12 +- .../qml/ImportExportUI/DialogExport.qml | 1 - .../qml/ImportExportUI/DialogImport.qml | 3 - .../frontend/qml/ProtonUI/InformationBar.qml | 55 +--- internal/frontend/qml/tst_GuiIE.qml | 4 - internal/frontend/qt-common/accounts.go | 3 +- internal/frontend/qt-common/common.go | 7 - internal/frontend/qt-ie/frontend.go | 25 +- internal/frontend/qt-ie/ui.go | 5 +- internal/frontend/qt/accounts.go | 3 +- internal/frontend/qt/frontend.go | 39 +-- internal/frontend/qt/ui.go | 5 +- internal/frontend/types/types.go | 1 - internal/imap/bridge.go | 3 +- internal/imap/mailbox_messages.go | 4 +- internal/imap/user.go | 2 +- internal/importexport/importexport.go | 27 +- internal/smtp/bridge.go | 2 +- internal/smtp/user.go | 2 +- internal/store/address.go | 2 +- internal/store/change_test.go | 12 +- internal/store/event_loop.go | 40 ++- internal/store/event_loop_test.go | 10 +- internal/store/mailbox_counts.go | 11 +- internal/store/mailbox_counts_test.go | 2 +- internal/store/mailbox_ids_test.go | 16 +- internal/store/mailbox_message.go | 28 +- internal/store/mocks/mocks.go | 12 +- internal/store/mocks/utils_mocks.go | 17 +- internal/store/store.go | 14 +- internal/store/store_test.go | 4 +- internal/store/sync.go | 4 +- internal/store/types.go | 8 +- internal/store/user.go | 6 +- internal/store/user_address_info.go | 2 +- internal/store/user_mailbox.go | 21 +- internal/store/user_message.go | 7 +- internal/store/user_message_test.go | 29 +- internal/store/user_sync.go | 2 +- internal/store/user_sync_test.go | 4 +- internal/transfer/mocks/mocks.go | 3 +- internal/transfer/provider_imap_utils.go | 44 ++- internal/transfer/provider_pmapi_target.go | 21 +- internal/transfer/provider_pmapi_test.go | 8 +- internal/transfer/provider_pmapi_utils.go | 36 +-- internal/updater/updater_test.go | 3 +- internal/users/_user_credentials_test.go | 251 ----------------- internal/users/_user_new_test.go | 112 -------- internal/users/_user_test.go | 89 ------ internal/users/_users_actions_test.go | 143 ---------- internal/users/_users_login_test.go | 219 --------------- internal/users/credentials/store.go | 2 +- internal/users/credentials/store_test.go | 23 +- internal/users/mock_listener.go | 107 ------- internal/users/mocks/listener_mocks.go | 120 ++++++++ internal/users/mocks/mocks.go | 3 +- internal/users/user.go | 96 ++++--- internal/users/user_credentials_test.go | 195 +++++++++++++ internal/users/user_new_test.go | 88 ++++++ internal/users/user_store_test.go | 51 ++++ internal/users/user_test.go | 41 +++ internal/users/users.go | 132 +++++---- internal/users/users_clear_test.go | 49 ++++ internal/users/users_delete_test.go | 69 +++++ internal/users/users_get_test.go | 76 +++++ internal/users/users_login_test.go | 132 +++++++++ internal/users/users_new_test.go | 103 +++---- internal/users/users_test.go | 260 +++++++++--------- pkg/keychain/helper_darwin.go | 4 + pkg/listener/listener.go | 10 + pkg/message/build_fetch.go | 4 +- pkg/message/build_framework_test.go | 4 +- pkg/message/build_test.go | 8 +- pkg/message/flags.go | 6 +- pkg/message/mocks/mocks.go | 17 +- pkg/mime/encoding.go | 2 +- pkg/mime/encoding_test.go | 1 - pkg/pmapi/addresses.go | 39 +-- pkg/pmapi/addresses_test.go | 28 +- pkg/pmapi/attachments.go | 39 --- pkg/pmapi/attachments_test.go | 110 +++----- pkg/pmapi/auth.go | 157 ++++++++++- pkg/pmapi/auth_test.go | 125 ++++++--- pkg/pmapi/auth_types.go | 72 ----- pkg/pmapi/boolean.go | 41 +++ pkg/pmapi/client.go | 38 +-- pkg/pmapi/client_keys.go | 32 ++- pkg/pmapi/client_types.go | 10 +- pkg/pmapi/config.go | 69 ++++- pkg/pmapi/config_default.go | 35 +++ pkg/pmapi/config_qa.go | 48 ++++ pkg/pmapi/contacts.go | 13 +- pkg/pmapi/contacts_test.go | 27 +- pkg/pmapi/context.go | 54 ++++ pkg/pmapi/data_test.go | 20 +- pkg/pmapi/dialer_basic.go | 76 +++++ pkg/pmapi/dialer_pinning.go | 110 ++++++++ pkg/pmapi/dialer_pinning_checker.go | 68 +++++ pkg/pmapi/dialer_pinning_report.go | 144 ++++++++++ pkg/pmapi/dialer_pinning_reporter.go | 107 +++++++ pkg/pmapi/dialer_pinning_reporter_test.go | 62 +++++ pkg/pmapi/dialer_pinning_test.go | 149 ++++++++++ pkg/pmapi/dialer_proxy.go | 144 ++++++++++ pkg/pmapi/dialer_proxy_provider.go | 249 +++++++++++++++++ pkg/pmapi/dialer_proxy_provider_test.go | 187 +++++++++++++ pkg/pmapi/dialer_proxy_test.go | 253 +++++++++++++++++ pkg/pmapi/errors.go | 34 ++- pkg/pmapi/events.go | 26 +- pkg/pmapi/events_test.go | 63 ++--- pkg/pmapi/import.go | 8 +- pkg/pmapi/import_test.go | 100 ++----- pkg/pmapi/key.go | 3 - pkg/pmapi/labels.go | 29 +- pkg/pmapi/labels_test.go | 81 ++---- pkg/pmapi/manager.go | 112 ++++---- pkg/pmapi/manager_auth.go | 54 ++-- pkg/pmapi/manager_download.go | 17 ++ pkg/pmapi/manager_log.go | 71 +++++ pkg/pmapi/manager_metrics.go | 31 ++- ...etrics_test.go => manager_metrics_test.go} | 20 +- pkg/pmapi/manager_ping.go | 53 +++- pkg/pmapi/manager_proxy.go | 32 +++ pkg/pmapi/manager_report.go | 39 ++- pkg/pmapi/manager_report_test.go | 73 ++--- pkg/pmapi/manager_report_types.go | 69 ++--- pkg/pmapi/manager_test.go | 160 +++++------ pkg/pmapi/manager_types.go | 26 +- pkg/pmapi/messages.go | 19 +- pkg/pmapi/messages_test.go | 46 ++-- pkg/pmapi/mocks/mocks.go | 56 ++-- pkg/pmapi/observer.go | 17 ++ pkg/pmapi/out | 25 -- pkg/pmapi/paging.go | 19 +- pkg/pmapi/pmapi.go | 24 ++ pkg/pmapi/response.go | 92 +++++-- pkg/pmapi/server_test.go | 46 +--- pkg/pmapi/types.go | 8 - pkg/pmapi/users_test.go | 20 +- test/context/context.go | 17 +- test/context/credentials.go | 3 - test/context/pmapi_controller.go | 19 +- test/context/pmapi_manager.go | 65 ----- test/context/users.go | 5 +- test/fakeapi/auth.go | 6 +- test/fakeapi/controller.go | 2 +- test/fakeapi/controller_calls.go | 4 +- test/fakeapi/controller_control.go | 17 +- test/fakeapi/counts.go | 5 +- test/fakeapi/fakeapi.go | 76 ++--- test/fakeapi/labels.go | 4 +- test/fakeapi/manager.go | 118 ++++---- test/fakeapi/messages.go | 18 +- test/fakeapi/user.go | 6 + test/features/bridge/start.feature | 14 +- test/liveapi/cleanup.go | 16 +- test/liveapi/controller.go | 15 +- test/liveapi/labels.go | 11 +- test/liveapi/messages.go | 4 +- test/liveapi/transport.go | 8 +- test/liveapi/users.go | 8 +- test/store_checks_test.go | 9 +- test/users_checks_test.go | 35 +-- 173 files changed, 4697 insertions(+), 2897 deletions(-) delete mode 100644 internal/users/_user_credentials_test.go delete mode 100644 internal/users/_user_new_test.go delete mode 100644 internal/users/_user_test.go delete mode 100644 internal/users/_users_actions_test.go delete mode 100644 internal/users/_users_login_test.go delete mode 100644 internal/users/mock_listener.go create mode 100644 internal/users/mocks/listener_mocks.go create mode 100644 internal/users/user_credentials_test.go create mode 100644 internal/users/user_new_test.go create mode 100644 internal/users/user_store_test.go create mode 100644 internal/users/user_test.go create mode 100644 internal/users/users_clear_test.go create mode 100644 internal/users/users_delete_test.go create mode 100644 internal/users/users_get_test.go create mode 100644 internal/users/users_login_test.go delete mode 100644 pkg/pmapi/auth_types.go create mode 100644 pkg/pmapi/boolean.go create mode 100644 pkg/pmapi/config_default.go create mode 100644 pkg/pmapi/config_qa.go create mode 100644 pkg/pmapi/context.go create mode 100644 pkg/pmapi/dialer_basic.go create mode 100644 pkg/pmapi/dialer_pinning.go create mode 100644 pkg/pmapi/dialer_pinning_checker.go create mode 100644 pkg/pmapi/dialer_pinning_report.go create mode 100644 pkg/pmapi/dialer_pinning_reporter.go create mode 100644 pkg/pmapi/dialer_pinning_reporter_test.go create mode 100644 pkg/pmapi/dialer_pinning_test.go create mode 100644 pkg/pmapi/dialer_proxy.go create mode 100644 pkg/pmapi/dialer_proxy_provider.go create mode 100644 pkg/pmapi/dialer_proxy_provider_test.go create mode 100644 pkg/pmapi/dialer_proxy_test.go create mode 100644 pkg/pmapi/manager_log.go rename pkg/pmapi/{metrics_test.go => manager_metrics_test.go} (70%) create mode 100644 pkg/pmapi/manager_proxy.go delete mode 100644 pkg/pmapi/out create mode 100644 pkg/pmapi/pmapi.go delete mode 100644 pkg/pmapi/types.go delete mode 100644 test/context/pmapi_manager.go diff --git a/go.mod b/go.mod index 6413fb98..7834e12e 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/fatih/color v1.9.0 github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/getsentry/sentry-go v0.8.0 - github.com/go-resty/resty/v2 v2.4.0 + github.com/go-resty/resty/v2 v2.6.0 github.com/golang/mock v1.4.4 github.com/google/go-cmp v0.5.1 github.com/google/uuid v1.1.1 @@ -50,6 +50,7 @@ require ( github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mattn/go-runewidth v0.0.9 // indirect + github.com/miekg/dns v1.1.41 github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce github.com/olekukonko/tablewriter v0.0.4 // indirect github.com/pkg/errors v0.9.1 @@ -63,7 +64,7 @@ require ( github.com/urfave/cli/v2 v2.2.0 github.com/vmihailenco/msgpack/v5 v5.1.3 go.etcd.io/bbolt v1.3.5 - golang.org/x/net v0.0.0-20201224014010-6772e930b67b + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec ) diff --git a/go.sum b/go.sum index 9b9c7ab6..3d4d4a44 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= -github.com/go-resty/resty/v2 v2.4.0 h1:s6TItTLejEI+2mn98oijC5w/Rk2YU+OA6x0mnZN6r6k= -github.com/go-resty/resty/v2 v2.4.0/go.mod h1:B88+xCTEwvfD94NOuE6GS1wMlnoKNY8eEiNizfNwOwA= +github.com/go-resty/resty/v2 v2.6.0 h1:joIR5PNLM2EFqqESUjCMGXrWmXNHEU9CEiK813oKYS4= +github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= @@ -195,6 +195,8 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= +github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY= +github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -310,12 +312,16 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw= -golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -330,6 +336,10 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 h1:cEhElsAv9LUt9ZUUocxzWe05oFLVd+AA2nstydTeI8g= +golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/app/base/base.go b/internal/app/base/base.go index 9f31701a..355ff05c 100644 --- a/internal/app/base/base.go +++ b/internal/app/base/base.go @@ -181,21 +181,18 @@ func New( // nolint[funlen] kc = keychain.NewMissingKeychain() } - // FIXME(conman): Customize config depending on build type (app version, host URL). - cm := pmapi.New(pmapi.DefaultConfig) + cfg := pmapi.NewConfig(configName, constants.Version) + cfg.GetUserAgent = userAgent.String + cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") } + cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") } + + cm := pmapi.New(cfg) - // FIXME(conman): Should this be a real object, not just created via callbacks? cm.AddConnectionObserver(pmapi.NewConnectionObserver( func() { listener.Emit(events.InternetOffEvent, "") }, func() { listener.Emit(events.InternetOnEvent, "") }, )) - // FIXME(conman): Implement force upgrade observer. - // apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") } - - // FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc. - // cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener)) - jar, err := cookies.NewCookieJar(settingsObj) if err != nil { return nil, err @@ -341,6 +338,7 @@ func (b *Base) run(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc { } logging.SetLevel(c.String(flagLogLevel)) + b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel) logrus. WithField("appName", b.Name). diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 95d32b64..b8c7bf29 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -65,8 +65,7 @@ func New( // Allow DoH before starting the app if the user has previously set this setting. // This allows us to start even if protonmail is blocked. if s.GetBool(settings.AllowProxyKey) { - // FIXME(conman): Support enable/disable of DoH. - // clientManager.AllowProxy() + clientManager.AllowProxy() } storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener) @@ -120,7 +119,7 @@ func (b *Bridge) heartbeat() { // ReportBug reports a new bug from the user. func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error { - return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{ + return b.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{ OS: osType, OSVersion: osVersion, Browser: emailClient, diff --git a/internal/frontend/cli-ie/accounts.go b/internal/frontend/cli-ie/accounts.go index 63b1b88b..8a877142 100644 --- a/internal/frontend/cli-ie/accounts.go +++ b/internal/frontend/cli-ie/accounts.go @@ -21,7 +21,6 @@ import ( "context" "strings" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/abiosoft/ishell" ) @@ -75,13 +74,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen] return } - if auth.TwoFA.Enabled == pmapi.TOTPEnabled { + if auth.HasTwoFactor() { twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty) if twoFactor == "" { return } - err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor}) + err = client.Auth2FA(context.Background(), twoFactor) if err != nil { f.processAPIError(err) return @@ -89,7 +88,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen] } mailboxPassword := password - if auth.PasswordMode == pmapi.TwoPasswordMode { + if auth.HasMailboxPassword() { mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty) } if mailboxPassword == "" { diff --git a/internal/frontend/cli-ie/frontend.go b/internal/frontend/cli-ie/frontend.go index 90e13c37..d945605e 100644 --- a/internal/frontend/cli-ie/frontend.go +++ b/internal/frontend/cli-ie/frontend.go @@ -84,11 +84,6 @@ func New( //nolint[funlen] Aliases: []string{"u", "version", "v"}, Func: fe.checkUpdates, }) - checkCmd.AddCmd(&ishell.Cmd{Name: "internet", - Help: "check internet connection. (aliases: i, conn, connection)", - Aliases: []string{"i", "con", "connection"}, - Func: fe.checkInternetConnection, - }) fe.AddCmd(checkCmd) // Print info commands. @@ -177,13 +172,13 @@ func New( //nolint[funlen] } func (f *frontendCLI) watchEvents() { - errorCh := f.getEventChannel(events.ErrorEvent) - credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent) - internetOffCh := f.getEventChannel(events.InternetOffEvent) - internetOnCh := f.getEventChannel(events.InternetOnEvent) - addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent) - logoutCh := f.getEventChannel(events.LogoutEvent) - certIssue := f.getEventChannel(events.TLSCertIssue) + errorCh := f.eventListener.ProvideChannel(events.ErrorEvent) + credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent) + internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent) + internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent) + addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent) + logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent) + certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue) for { select { case errorDetails := <-errorCh: @@ -208,13 +203,6 @@ func (f *frontendCLI) watchEvents() { } } -func (f *frontendCLI) getEventChannel(event string) <-chan string { - ch := make(chan string) - f.eventListener.Add(event, ch) - f.eventListener.RetryEmit(event) - return ch -} - // Loop starts the frontend loop with an interactive shell. func (f *frontendCLI) Loop() error { f.Print(` diff --git a/internal/frontend/cli-ie/system.go b/internal/frontend/cli-ie/system.go index 45f04efe..bcc1c69d 100644 --- a/internal/frontend/cli-ie/system.go +++ b/internal/frontend/cli-ie/system.go @@ -29,14 +29,6 @@ func (f *frontendCLI) restart(c *ishell.Context) { } } -func (f *frontendCLI) checkInternetConnection(c *ishell.Context) { - if f.ie.CheckConnection() == nil { - f.Println("Internet connection is available.") - } else { - f.Println("Can not contact the server, please check your internet connection.") - } -} - func (f *frontendCLI) printLogDir(c *ishell.Context) { if path, err := f.locations.ProvideLogsPath(); err != nil { f.Println("Failed to determine location of log files") diff --git a/internal/frontend/cli-ie/utils.go b/internal/frontend/cli-ie/utils.go index 225cd0b1..361303bd 100644 --- a/internal/frontend/cli-ie/utils.go +++ b/internal/frontend/cli-ie/utils.go @@ -20,6 +20,7 @@ package cliie import ( "strings" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/fatih/color" ) @@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) { func (f *frontendCLI) processAPIError(err error) { log.Warn("API error: ", err) switch err { - // FIXME(conman): How to handle various API errors? - /* - case pmapi.ErrNoConnection: - f.notifyInternetOff() - case pmapi.ErrUpgradeApplication: - f.notifyNeedUpgrade() - */ + case pmapi.ErrNoConnection: + f.notifyInternetOff() + case pmapi.ErrUpgradeApplication: + f.notifyNeedUpgrade() default: f.Println("Server error:", err.Error()) } diff --git a/internal/frontend/cli/accounts.go b/internal/frontend/cli/accounts.go index d3fa40b6..cc71a386 100644 --- a/internal/frontend/cli/accounts.go +++ b/internal/frontend/cli/accounts.go @@ -24,7 +24,6 @@ import ( "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/ProtonMail/proton-bridge/internal/frontend/types" - pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/abiosoft/ishell" ) @@ -122,13 +121,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen] return } - if auth.TwoFA.Enabled == pmapi.TOTPEnabled { + if auth.HasTwoFactor() { twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty) if twoFactor == "" { return } - err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor}) + err = client.Auth2FA(context.Background(), twoFactor) if err != nil { f.processAPIError(err) return @@ -136,7 +135,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen] } mailboxPassword := password - if auth.PasswordMode == pmapi.TwoPasswordMode { + if auth.HasMailboxPassword() { mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty) } if mailboxPassword == "" { diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index 9118b5b1..a3e7ff01 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -157,15 +157,6 @@ func New( //nolint[funlen] }) fe.AddCmd(updatesCmd) - // Check commands. - checkCmd := &ishell.Cmd{Name: "check", Help: "check internet connection or new version."} - checkCmd.AddCmd(&ishell.Cmd{Name: "internet", - Help: "check internet connection. (aliases: i, conn, connection)", - Aliases: []string{"i", "con", "connection"}, - Func: fe.checkInternetConnection, - }) - fe.AddCmd(checkCmd) - // Print info commands. fe.AddCmd(&ishell.Cmd{Name: "log-dir", Help: "print path to directory with logs. (aliases: log, logs)", @@ -228,14 +219,14 @@ func New( //nolint[funlen] } func (f *frontendCLI) watchEvents() { - errorCh := f.getEventChannel(events.ErrorEvent) - credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent) - internetOffCh := f.getEventChannel(events.InternetOffEvent) - internetOnCh := f.getEventChannel(events.InternetOnEvent) - addressChangedCh := f.getEventChannel(events.AddressChangedEvent) - addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent) - logoutCh := f.getEventChannel(events.LogoutEvent) - certIssue := f.getEventChannel(events.TLSCertIssue) + errorCh := f.eventListener.ProvideChannel(events.ErrorEvent) + credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent) + internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent) + internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent) + addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent) + addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent) + logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent) + certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue) for { select { case errorDetails := <-errorCh: @@ -262,13 +253,6 @@ func (f *frontendCLI) watchEvents() { } } -func (f *frontendCLI) getEventChannel(event string) <-chan string { - ch := make(chan string) - f.eventListener.Add(event, ch) - f.eventListener.RetryEmit(event) - return ch -} - // Loop starts the frontend loop with an interactive shell. func (f *frontendCLI) Loop() error { f.Print(` diff --git a/internal/frontend/cli/system.go b/internal/frontend/cli/system.go index 74c64765..f601d1a5 100644 --- a/internal/frontend/cli/system.go +++ b/internal/frontend/cli/system.go @@ -39,14 +39,6 @@ func (f *frontendCLI) restart(c *ishell.Context) { } } -func (f *frontendCLI) checkInternetConnection(c *ishell.Context) { - if f.bridge.CheckConnection() == nil { - f.Println("Internet connection is available.") - } else { - f.Println("Can not contact the server, please check your internet connection.") - } -} - func (f *frontendCLI) printLogDir(c *ishell.Context) { if path, err := f.locations.ProvideLogsPath(); err != nil { f.Println("Failed to determine location of log files") diff --git a/internal/frontend/cli/utils.go b/internal/frontend/cli/utils.go index 2d8eb195..0ce5183f 100644 --- a/internal/frontend/cli/utils.go +++ b/internal/frontend/cli/utils.go @@ -20,6 +20,7 @@ package cli import ( "strings" + pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/fatih/color" ) @@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) { func (f *frontendCLI) processAPIError(err error) { log.Warn("API error: ", err) switch err { - // FIXME(conman): How to handle various API errors? - /* - case pmapi.ErrNoConnection: - f.notifyInternetOff() - case pmapi.ErrUpgradeApplication: - f.notifyNeedUpgrade() - */ + case pmapi.ErrNoConnection: + f.notifyInternetOff() + case pmapi.ErrUpgradeApplication: + f.notifyNeedUpgrade() default: f.Println("Server error:", err.Error()) } diff --git a/internal/frontend/qml/ImportExportUI/DialogExport.qml b/internal/frontend/qml/ImportExportUI/DialogExport.qml index cf457245..16a48805 100644 --- a/internal/frontend/qml/ImportExportUI/DialogExport.qml +++ b/internal/frontend/qml/ImportExportUI/DialogExport.qml @@ -409,7 +409,6 @@ Dialog { onShow: { if (winMain.updateState==gui.enums.statusNoInternet) { - go.checkInternet() if (winMain.updateState==gui.enums.statusNoInternet) { go.notifyError(gui.enums.errNoInternet) root.hide() diff --git a/internal/frontend/qml/ImportExportUI/DialogImport.qml b/internal/frontend/qml/ImportExportUI/DialogImport.qml index 92b3a0b1..0a862039 100644 --- a/internal/frontend/qml/ImportExportUI/DialogImport.qml +++ b/internal/frontend/qml/ImportExportUI/DialogImport.qml @@ -857,14 +857,12 @@ Dialog { inputPort . checkIsANumber() //emailProvider . currentIndex!=0 )) isOK = false - go.checkInternet() if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this errorPopup.show(qsTr("Please check your internet connection.")) return false } break case 2: // loading structure - go.checkInternet() if (winMain.updateState == gui.enums.statusNoInternet) { errorPopup.show(qsTr("Please check your internet connection.")) return false @@ -949,7 +947,6 @@ Dialog { onShow : { root.clear() if (winMain.updateState==gui.enums.statusNoInternet) { - go.checkInternet() if (winMain.updateState==gui.enums.statusNoInternet) { winMain.popupMessage.show(go.canNotReachAPI) root.hide() diff --git a/internal/frontend/qml/ProtonUI/InformationBar.qml b/internal/frontend/qml/ProtonUI/InformationBar.qml index dc012c4f..ef2fa00a 100644 --- a/internal/frontend/qml/ProtonUI/InformationBar.qml +++ b/internal/frontend/qml/ProtonUI/InformationBar.qml @@ -25,33 +25,12 @@ import ProtonUI 1.0 Rectangle { id: root property var iTry: 0 - property var secLeft: 0 property var second: 1000 // convert millisecond to second - property var checkInterval: [ 5, 10, 30, 60, 120, 300, 600 ] // seconds property bool isVisible: true property var fontSize : 1.2 * Style.main.fontSize color : "black" state: "upToDate" - Timer { - id: retryInternet - interval: second - triggeredOnStart: false - repeat: true - onTriggered : { - secLeft-- - if (secLeft <= 0) { - retryInternet.stop() - go.checkInternet() - if (iTry < checkInterval.length-1) { - iTry++ - } - secLeft=checkInterval[iTry] - retryInternet.start() - } - } - } - Row { id: messageRow anchors.centerIn: root @@ -110,16 +89,12 @@ Rectangle { case "internetCheck": break; case "noInternet" : - retryInternet.start() - secLeft=checkInterval[iTry] break; case "oldVersion": break; case "forceUpdate": break; case "upToDate": - iTry = 0 - secLeft=checkInterval[iTry] break; case "updateRestart": break; @@ -128,24 +103,6 @@ Rectangle { default : break; } - - if (root.state!="noInternet") { - retryInternet.stop() - } - } - - function timeToRetry() { - if (secLeft==1){ - return qsTr("a second", "time to wait till internet connection is retried") - } else if (secLeft<60){ - return secLeft + " " + qsTr("seconds", "time to wait till internet connection is retried") - } else { - var leading = ""+secLeft%60 - if (leading.length < 2) { - leading = "0" + leading - } - return Math.floor(secLeft/60) + ":" + leading - } } states: [ @@ -194,23 +151,15 @@ Rectangle { PropertyChanges { target: message color: Style.main.line - text: qsTr("Cannot contact server. Retrying in ", "displayed when the app is disconnected from the internet or server has problems")+timeToRetry()+"." + text: qsTr("Cannot contact server. Please wait...", "displayed when the app is disconnected from the internet or server has problems") } PropertyChanges { target: linkText visible: false } - PropertyChanges { - target: actionText - visible: true - text: qsTr("Retry now", "click to try to connect to the internet when the app is disconnected from the internet") - onClicked: { - go.checkInternet() - } - } PropertyChanges { target: separatorText - visible: true + visible: false text: "|" } PropertyChanges { diff --git a/internal/frontend/qml/tst_GuiIE.qml b/internal/frontend/qml/tst_GuiIE.qml index 5988fd8a..cd58ae84 100644 --- a/internal/frontend/qml/tst_GuiIE.qml +++ b/internal/frontend/qml/tst_GuiIE.qml @@ -1331,10 +1331,6 @@ Window { return (fname!="fail") } - function checkInternet() { - // nothing to do - } - function loadImportReports(fname) { console.log("load import reports for ", fname) } diff --git a/internal/frontend/qt-common/accounts.go b/internal/frontend/qt-common/accounts.go index f372aee7..264aa946 100644 --- a/internal/frontend/qt-common/accounts.go +++ b/internal/frontend/qt-common/accounts.go @@ -20,6 +20,7 @@ package qtcommon import ( + "context" "fmt" "strings" "sync" @@ -207,7 +208,7 @@ func (a *Accounts) Auth2FA(twoFacAuth string) int { if a.auth == nil || a.authClient == nil { err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient) } else { - err = a.authClient.Auth2FA(twoFacAuth, a.auth) + err = a.authClient.Auth2FA(context.Background(), twoFacAuth) } if a.showLoginError(err, "auth2FA") { diff --git a/internal/frontend/qt-common/common.go b/internal/frontend/qt-common/common.go index 37eb57b1..f6281d4b 100644 --- a/internal/frontend/qt-common/common.go +++ b/internal/frontend/qt-common/common.go @@ -113,10 +113,3 @@ type Listener interface { Add(string, chan<- string) RetryEmit(string) } - -func MakeAndRegisterEvent(eventListener Listener, event string) <-chan string { - ch := make(chan string) - eventListener.Add(event, ch) - eventListener.RetryEmit(event) - return ch -} diff --git a/internal/frontend/qt-ie/frontend.go b/internal/frontend/qt-ie/frontend.go index 9cd12261..f63753b8 100644 --- a/internal/frontend/qt-ie/frontend.go +++ b/internal/frontend/qt-ie/frontend.go @@ -143,16 +143,16 @@ func (f *FrontendQt) NotifySilentUpdateError(err error) { } func (f *FrontendQt) watchEvents() { - credentialsErrorCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.CredentialsErrorEvent) - internetOffCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOffEvent) - internetOnCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOnEvent) - secondInstanceCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.SecondInstanceEvent) - restartBridgeCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.RestartBridgeEvent) - addressChangedCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedEvent) - addressChangedLogoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedLogoutEvent) - logoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.LogoutEvent) - updateApplicationCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UpgradeApplicationEvent) - newUserCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UserRefreshEvent) + credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent) + internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent) + internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent) + secondInstanceCh := f.eventListener.ProvideChannel(events.SecondInstanceEvent) + restartBridgeCh := f.eventListener.ProvideChannel(events.RestartBridgeEvent) + addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent) + addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent) + logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent) + updateApplicationCh := f.eventListener.ProvideChannel(events.UpgradeApplicationEvent) + newUserCh := f.eventListener.ProvideChannel(events.UserRefreshEvent) for { select { case <-credentialsErrorCh: @@ -351,11 +351,6 @@ func (f *FrontendQt) sendBug(description, emailClient, address string) bool { // } //} -// checkInternet is almost idetical to bridge -func (f *FrontendQt) checkInternet() { - f.Qml.SetConnectionStatus(f.ie.CheckConnection() == nil) -} - func (f *FrontendQt) showError(code int, err error) { f.Qml.SetErrorDescription(err.Error()) log.WithField("code", code).Errorln(err.Error()) diff --git a/internal/frontend/qt-ie/ui.go b/internal/frontend/qt-ie/ui.go index b7426691..79f19c98 100644 --- a/internal/frontend/qt-ie/ui.go +++ b/internal/frontend/qt-ie/ui.go @@ -77,8 +77,7 @@ type GoQMLInterface struct { _ string `property:"credentialsNotRemoved"` _ string `property:"versionCheckFailed"` // - _ func(isAvailable bool) `signal:"setConnectionStatus"` - _ func() `slot:"checkInternet"` + _ func(isAvailable bool) `signal:"setConnectionStatus"` _ func() `slot:"setToRestart"` @@ -189,8 +188,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) { return f.programVersion }) - s.ConnectCheckInternet(f.checkInternet) - s.ConnectSetToRestart(f.restarter.SetToRestart) s.ConnectLoadStructureForExport(f.LoadStructureForExport) diff --git a/internal/frontend/qt/accounts.go b/internal/frontend/qt/accounts.go index ec869f66..0244f6ff 100644 --- a/internal/frontend/qt/accounts.go +++ b/internal/frontend/qt/accounts.go @@ -20,6 +20,7 @@ package qt import ( + "context" "fmt" "strings" @@ -173,7 +174,7 @@ func (s *FrontendQt) auth2FA(twoFacAuth string) int { if s.auth == nil || s.authClient == nil { err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient) } else { - err = s.authClient.Auth2FA(twoFacAuth, s.auth) + err = s.authClient.Auth2FA(context.Background(), twoFacAuth) } if s.showLoginError(err, "auth2FA") { diff --git a/internal/frontend/qt/frontend.go b/internal/frontend/qt/frontend.go index c40fbcd2..c6b4f0f5 100644 --- a/internal/frontend/qt/frontend.go +++ b/internal/frontend/qt/frontend.go @@ -191,20 +191,20 @@ func (s *FrontendQt) NotifySilentUpdateError(err error) { func (s *FrontendQt) watchEvents() { s.WaitUntilFrontendIsReady() - errorCh := s.getEventChannel(events.ErrorEvent) - credentialsErrorCh := s.getEventChannel(events.CredentialsErrorEvent) - outgoingNoEncCh := s.getEventChannel(events.OutgoingNoEncEvent) - noActiveKeyForRecipientCh := s.getEventChannel(events.NoActiveKeyForRecipientEvent) - internetOffCh := s.getEventChannel(events.InternetOffEvent) - internetOnCh := s.getEventChannel(events.InternetOnEvent) - secondInstanceCh := s.getEventChannel(events.SecondInstanceEvent) - restartBridgeCh := s.getEventChannel(events.RestartBridgeEvent) - addressChangedCh := s.getEventChannel(events.AddressChangedEvent) - addressChangedLogoutCh := s.getEventChannel(events.AddressChangedLogoutEvent) - logoutCh := s.getEventChannel(events.LogoutEvent) - updateApplicationCh := s.getEventChannel(events.UpgradeApplicationEvent) - newUserCh := s.getEventChannel(events.UserRefreshEvent) - certIssue := s.getEventChannel(events.TLSCertIssue) + errorCh := s.eventListener.ProvideChannel(events.ErrorEvent) + credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent) + outgoingNoEncCh := s.eventListener.ProvideChannel(events.OutgoingNoEncEvent) + noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent) + internetOffCh := s.eventListener.ProvideChannel(events.InternetOffEvent) + internetOnCh := s.eventListener.ProvideChannel(events.InternetOnEvent) + secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent) + restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent) + addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent) + addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent) + logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent) + updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent) + newUserCh := s.eventListener.ProvideChannel(events.UserRefreshEvent) + certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue) for { select { case errorDetails := <-errorCh: @@ -254,13 +254,6 @@ func (s *FrontendQt) watchEvents() { } } -func (s *FrontendQt) getEventChannel(event string) <-chan string { - ch := make(chan string) - s.eventListener.Add(event, ch) - s.eventListener.RetryEmit(event) - return ch -} - // Loop function for tests. // // It runs QtExecute in new thread with function returning itself after setup. @@ -653,10 +646,6 @@ func (s *FrontendQt) isSMTPSTARTTLS() bool { return !s.settings.GetBool(settings.SMTPSSLKey) } -func (s *FrontendQt) checkInternet() { - s.Qml.SetConnectionStatus(s.bridge.CheckConnection() == nil) -} - func (s *FrontendQt) switchAddressModeUser(iAccount int) { defer s.Qml.ProcessFinished() userID := s.Accounts.get(iAccount).UserID() diff --git a/internal/frontend/qt/ui.go b/internal/frontend/qt/ui.go index 00363cb1..c8e2eda1 100644 --- a/internal/frontend/qt/ui.go +++ b/internal/frontend/qt/ui.go @@ -83,8 +83,7 @@ type GoQMLInterface struct { _ float32 `property:"progress"` _ string `property:"progressDescription"` - _ func(isAvailable bool) `signal:"setConnectionStatus"` - _ func() `slot:"checkInternet"` + _ func(isAvailable bool) `signal:"setConnectionStatus"` _ func() `slot:"setToRestart"` @@ -205,8 +204,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) { return f.programVer }) - s.ConnectCheckInternet(f.checkInternet) - s.ConnectSetToRestart(f.restarter.SetToRestart) s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc) diff --git a/internal/frontend/types/types.go b/internal/frontend/types/types.go index 8f6556fc..131e8cfa 100644 --- a/internal/frontend/types/types.go +++ b/internal/frontend/types/types.go @@ -55,7 +55,6 @@ type UserManager interface { GetUser(query string) (User, error) DeleteUser(userID string, clearCache bool) error ClearData() error - CheckConnection() error } // User is an interface of user needed by frontend. diff --git a/internal/imap/bridge.go b/internal/imap/bridge.go index 7f2ef36b..bd4580f5 100644 --- a/internal/imap/bridge.go +++ b/internal/imap/bridge.go @@ -38,11 +38,10 @@ type bridgeUser interface { IsCombinedAddressMode() bool GetAddressID(address string) (string, error) GetPrimaryAddress() string - UpdateUser() error Logout() error CloseConnection(address string) GetStore() storeUserProvider - GetTemporaryPMAPIClient() pmapi.Client + GetClient() pmapi.Client } type bridgeWrap struct { diff --git a/internal/imap/mailbox_messages.go b/internal/imap/mailbox_messages.go index 00e3f2a9..55ca7229 100644 --- a/internal/imap/mailbox_messages.go +++ b/internal/imap/mailbox_messages.go @@ -422,7 +422,7 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria) if isStringInList(m.LabelIDs, pmapi.StarredLabel) { messageFlagsMap[imap.FlaggedFlag] = true } - if m.Unread == 0 { + if !m.Unread { messageFlagsMap[imap.SeenFlag] = true } if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) { @@ -560,7 +560,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima return nil, err } - if storeMessage.Message().Unread == 1 { + if storeMessage.Message().Unread { for section := range msg.Body { // Peek means get messages without marking them as read. // If client does not only ask for peek, we have to mark them as read. diff --git a/internal/imap/user.go b/internal/imap/user.go index 35c5518c..8c4059c7 100644 --- a/internal/imap/user.go +++ b/internal/imap/user.go @@ -93,7 +93,7 @@ func newIMAPUser( // This method should eventually no longer be necessary. Everything should go via store. func (iu *imapUser) client() pmapi.Client { - return iu.user.GetTemporaryPMAPIClient() + return iu.user.GetClient() } func (iu *imapUser) isSubscribed(labelID string) bool { diff --git a/internal/importexport/importexport.go b/internal/importexport/importexport.go index 21ee1ac1..c7e38141 100644 --- a/internal/importexport/importexport.go +++ b/internal/importexport/importexport.go @@ -22,6 +22,7 @@ import ( "bytes" "context" + "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/transfer" "github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/pkg/pmapi" @@ -40,6 +41,7 @@ type ImportExport struct { locations Locator cache Cacher panicHandler users.PanicHandler + eventListener listener.Listener clientManager pmapi.Manager } @@ -59,13 +61,14 @@ func New( locations: locations, cache: cache, panicHandler: panicHandler, + eventListener: eventListener, clientManager: clientManager, } } // ReportBug reports a new bug from the user. func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error { - return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{ + return ie.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{ OS: osType, OSVersion: osVersion, Browser: emailClient, @@ -89,7 +92,7 @@ func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address strin report.AddAttachment("log", "report.log", bytes.NewReader(logdata)) - return ie.clientManager.ReportBug(context.TODO(), report) + return ie.clientManager.ReportBug(context.Background(), report) } // GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account. @@ -162,5 +165,23 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM log.WithError(err).Info("Address does not exist, using all addresses") } - return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID) + provider, err := transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID) + if err != nil { + return nil, err + } + + go func() { + internetOffCh := ie.eventListener.ProvideChannel(events.InternetOffEvent) + internetOnCh := ie.eventListener.ProvideChannel(events.InternetOnEvent) + for { + select { + case <-internetOffCh: + provider.SetConnectionDown() + case <-internetOnCh: + provider.SetConnectionUp() + } + } + }() + + return provider, nil } diff --git a/internal/smtp/bridge.go b/internal/smtp/bridge.go index af7e8c4d..aa133c9c 100644 --- a/internal/smtp/bridge.go +++ b/internal/smtp/bridge.go @@ -31,7 +31,7 @@ type bridgeUser interface { CheckBridgeLogin(password string) error IsCombinedAddressMode() bool GetAddressID(address string) (string, error) - GetTemporaryPMAPIClient() pmapi.Client + GetClient() pmapi.Client GetStore() storeUserProvider } diff --git a/internal/smtp/user.go b/internal/smtp/user.go index ecd7dae4..dadc5522 100644 --- a/internal/smtp/user.go +++ b/internal/smtp/user.go @@ -81,7 +81,7 @@ func newSMTPUser( // This method should eventually no longer be necessary. Everything should go via store. func (su *smtpUser) client() pmapi.Client { - return su.user.GetTemporaryPMAPIClient() + return su.user.GetClient() } // Send sends an email from the given address to the given addresses with the given body. diff --git a/internal/store/address.go b/internal/store/address.go index adea9324..3f281a96 100644 --- a/internal/store/address.go +++ b/internal/store/address.go @@ -90,7 +90,7 @@ func getLabelPrefix(l *pmapi.Label) string { switch { case pmapi.IsSystemLabel(l.ID): return "" - case l.Exclusive == 1: + case bool(l.Exclusive): return UserFoldersPrefix default: return UserLabelsPrefix diff --git a/internal/store/change_test.go b/internal/store/change_test.go index 574ca8be..c4866d58 100644 --- a/internal/store/change_test.go +++ b/internal/store/change_test.go @@ -37,8 +37,8 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) { m.newStoreNoEvents(true) m.store.SetChangeNotifier(m.changeNotifier) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) - insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) } func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) { @@ -52,8 +52,8 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) { m.newStoreNoEvents(true) m.store.SetChangeNotifier(m.changeNotifier) - msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) - msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) + msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) + msg2 := getTestMessage("msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2})) } @@ -63,8 +63,8 @@ func TestNotifyChangeDeleteMessage(t *testing.T) { m.newStoreNoEvents(true) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) - insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2)) m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1)) diff --git a/internal/store/event_loop.go b/internal/store/event_loop.go index b3d032b7..e17deaf5 100644 --- a/internal/store/event_loop.go +++ b/internal/store/event_loop.go @@ -81,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client { func (loop *eventLoop) setFirstEventID() (err error) { loop.log.Info("Setting first event ID") - event, err := loop.client().GetEvent(context.TODO(), "") + event, err := loop.client().GetEvent(context.Background(), "") if err != nil { loop.log.WithError(err).Error("Could not get latest event ID") return @@ -222,8 +222,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun // We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually // (e.g. no internet, ulimit reached etc.) defer func() { - // FIXME(conman): How to handle errors of different types? - if errors.Is(err, pmapi.ErrNoConnection) { + if errors.Cause(err) == pmapi.ErrNoConnection { l.Warn("Internet unavailable") err = nil } @@ -234,20 +233,17 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun err = nil } - // FIXME(conman): Handle force upgrade. - /* - if errors.Cause(err) == pmapi.ErrUpgradeApplication { - l.Warn("Need to upgrade application") - err = nil - } - */ + if errors.Cause(err) == pmapi.ErrUpgradeApplication { + l.Warn("Need to upgrade application") + err = nil + } if err == nil { loop.errCounter = 0 } // All errors except ErrUnauthorized (which is not possible to recover from) are ignored. - if !errors.Is(err, pmapi.ErrUnauthorized) { + if err != nil && errors.Cause(err) != pmapi.ErrUnauthorized { l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped") loop.errCounter++ if loop.errCounter == errMaxSentry { @@ -268,7 +264,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun loop.pollCounter++ var event *pmapi.Event - if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil { + if event, err = loop.client().GetEvent(context.Background(), loop.currentEventID); err != nil { return false, errors.Wrap(err, "failed to get event") } @@ -295,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun } } - return event.More == 1, err + return bool(event.More), err } func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) { @@ -354,7 +350,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap // Get old addresses for comparisons before updating user. oldList := loop.client().Addresses() - if err = loop.user.UpdateUser(); err != nil { + if err = loop.user.UpdateUser(context.Background()); err != nil { if logoutErr := loop.user.Logout(); logoutErr != nil { log.WithError(logoutErr).Error("Failed to logout user after failed update") } @@ -465,16 +461,12 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...") - if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil { - // FIXME(conman): How to handle error of this particular type? - - /* - if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok { - msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API") - err = nil - continue - } - */ + if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil { + if _, ok := err.(pmapi.ErrUnprocessableEntity); ok { + msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API") + err = nil + continue + } return errors.Wrap(err, "failed to get message from API for updating") } diff --git a/internal/store/event_loop_test.go b/internal/store/event_loop_test.go index a51566a2..845a03a6 100644 --- a/internal/store/event_loop_test.go +++ b/internal/store/event_loop_test.go @@ -42,15 +42,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) { // next event if there is `More` of them. m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{ EventID: "event50", - More: 1, + More: true, }, nil), m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{ EventID: "event70", - More: 0, + More: false, }, nil), m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{ EventID: "event71", - More: 0, + More: false, }, nil), ) m.newStoreNoEvents(true) @@ -188,7 +188,7 @@ func TestEventLoopUpdateMessage(t *testing.T) { msg := &pmapi.Message{ ID: "msg1", Subject: "old", - Unread: 0, + Unread: false, Flags: 10, Sender: address1, ToList: []*mail.Address{address2}, @@ -200,7 +200,7 @@ func TestEventLoopUpdateMessage(t *testing.T) { newMsg := &pmapi.Message{ ID: "msg1", Subject: "new", - Unread: 1, + Unread: true, Flags: 11, Sender: address2, ToList: []*mail.Address{address1}, diff --git a/internal/store/mailbox_counts.go b/internal/store/mailbox_counts.go index ca48a7c7..80d98ed8 100644 --- a/internal/store/mailbox_counts.go +++ b/internal/store/mailbox_counts.go @@ -129,17 +129,10 @@ func (mc *mailboxCounts) getPMLabel() *pmapi.Label { Color: mc.Color, Order: mc.Order, Type: pmapi.LabelTypeMailbox, - Exclusive: mc.isExclusive(), + Exclusive: pmapi.Boolean(mc.IsFolder), } } -func (mc *mailboxCounts) isExclusive() int { - if mc.IsFolder { - return 1 - } - return 0 -} - // createOrUpdateMailboxCountsBuckets will not change the on-API-counts. func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error { // Don't forget about system folders. @@ -162,7 +155,7 @@ func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) er mailbox.LabelName = label.Path mailbox.Color = label.Color mailbox.Order = label.Order - mailbox.IsFolder = label.Exclusive == 1 + mailbox.IsFolder = bool(label.Exclusive) // Write. if err = mailbox.txWriteToBucket(countsBkt); err != nil { diff --git a/internal/store/mailbox_counts_test.go b/internal/store/mailbox_counts_test.go index 6f9ce816..21f3c374 100644 --- a/internal/store/mailbox_counts_test.go +++ b/internal/store/mailbox_counts_test.go @@ -75,7 +75,7 @@ func TestMailboxNames(t *testing.T) { newLabel(100, "labelID1", "Label1"), newLabel(1000, "folderID1", "Folder1"), } - foldersAndLabels[1].Exclusive = 1 + foldersAndLabels[1].Exclusive = true for _, counts := range getSystemFolders() { foldersAndLabels = append(foldersAndLabels, counts.getPMLabel()) diff --git a/internal/store/mailbox_ids_test.go b/internal/store/mailbox_ids_test.go index da1ca4bb..e0ecc3ff 100644 --- a/internal/store/mailbox_ids_test.go +++ b/internal/store/mailbox_ids_test.go @@ -37,10 +37,10 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) { m.newStoreNoEvents(true) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) - insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) - insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) - insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) + insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) + insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) + insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{pmapi.AllMailLabel}) checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"}) @@ -82,20 +82,20 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen] m.newStoreNoEvents(true) - tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) + tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel}) require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) - tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) + tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg.ExternalID = " externalID-non-pm-com " require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) - tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) + tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg.ExternalID = "" tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}} require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) // Not sure if this is a real-world scenario but we should be able to address this properly. - tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) + tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > " require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) diff --git a/internal/store/mailbox_message.go b/internal/store/mailbox_message.go index af3d2162..57d129c0 100644 --- a/internal/store/mailbox_message.go +++ b/internal/store/mailbox_message.go @@ -18,8 +18,6 @@ package store import ( - "context" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -43,7 +41,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) { // FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message // wrapping it. func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) { - msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID) + msg, err := storeMailbox.client().GetMessage(exposeContextForIMAP(), apiID) if err != nil { return nil, err } @@ -70,7 +68,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe Message: body, } - res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs}) + res, err := storeMailbox.client().Import(exposeContextForIMAP(), pmapi.ImportMsgReqs{importReqs}) if err != nil { return err } @@ -99,7 +97,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error { return ErrAllMailOpNotAllowed } defer storeMailbox.pollNow() - return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID) + return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID) } // UnlabelMessages removes the label by calling an API. @@ -112,7 +110,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error { return ErrAllMailOpNotAllowed } defer storeMailbox.pollNow() - return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID) + return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID) } // MarkMessagesRead marks the message read by calling an API. @@ -132,14 +130,14 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error { // Therefore we do not issue API update if the message is already read. ids := []string{} for _, apiID := range apiIDs { - if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 { + if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread { ids = append(ids, apiID) } } if len(ids) == 0 { return nil } - return storeMailbox.client().MarkMessagesRead(context.TODO(), ids) + return storeMailbox.client().MarkMessagesRead(exposeContextForIMAP(), ids) } // MarkMessagesUnread marks the message unread by calling an API. @@ -151,7 +149,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Marking messages as unread") defer storeMailbox.pollNow() - return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs) + return storeMailbox.client().MarkMessagesUnread(exposeContextForIMAP(), apiIDs) } // MarkMessagesStarred adds the Starred label by calling an API. @@ -164,7 +162,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Marking messages as starred") defer storeMailbox.pollNow() - return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel) + return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel) } // MarkMessagesUnstarred removes the Starred label by calling an API. @@ -177,7 +175,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error { "mailbox": storeMailbox.Name, }).Trace("Marking messages as unstarred") defer storeMailbox.pollNow() - return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel) + return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel) } // MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API @@ -261,11 +259,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error { } case pmapi.DraftLabel: storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts") - if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil { + if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), apiIDs); err != nil { return err } default: - if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil { + if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID); err != nil { return err } } @@ -303,13 +301,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error { } } if len(messageIDsToUnlabel) > 0 { - if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil { + if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), messageIDsToUnlabel, storeMailbox.labelID); err != nil { l.WithError(err).Warning("Cannot unlabel before deleting") } } if len(messageIDsToDelete) > 0 { storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages") - if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil { + if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), messageIDsToDelete); err != nil { return err } } diff --git a/internal/store/mocks/mocks.go b/internal/store/mocks/mocks.go index f93d8900..80813820 100644 --- a/internal/store/mocks/mocks.go +++ b/internal/store/mocks/mocks.go @@ -5,10 +5,10 @@ package mocks import ( - reflect "reflect" - + context "context" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) // MockPanicHandler is a mock of PanicHandler interface @@ -207,17 +207,17 @@ func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call { } // UpdateUser mocks base method -func (m *MockBridgeUser) UpdateUser() error { +func (m *MockBridgeUser) UpdateUser(arg0 context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUser") + ret := m.ctrl.Call(m, "UpdateUser", arg0) ret0, _ := ret[0].(error) return ret0 } // UpdateUser indicates an expected call of UpdateUser -func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call { +func (mr *MockBridgeUserMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser), arg0) } // MockChangeNotifier is a mock of ChangeNotifier interface diff --git a/internal/store/mocks/utils_mocks.go b/internal/store/mocks/utils_mocks.go index 940bf172..e7f1cbb3 100644 --- a/internal/store/mocks/utils_mocks.go +++ b/internal/store/mocks/utils_mocks.go @@ -5,10 +5,9 @@ package mocks import ( + gomock "github.com/golang/mock/gomock" reflect "reflect" time "time" - - gomock "github.com/golang/mock/gomock" ) // MockListener is a mock of Listener interface @@ -58,6 +57,20 @@ func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1) } +// ProvideChannel mocks base method +func (m *MockListener) ProvideChannel(arg0 string) <-chan string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ProvideChannel", arg0) + ret0, _ := ret[0].(<-chan string) + return ret0 +} + +// ProvideChannel indicates an expected call of ProvideChannel +func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0) +} + // Remove mocks base method func (m *MockListener) Remove(arg0 string, arg1 chan<- string) { m.ctrl.T.Helper() diff --git a/internal/store/store.go b/internal/store/store.go index d280a907..ac960e1f 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -101,6 +101,18 @@ var ( ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals] ) +// exposeContextForIMAP should be replaced once with context passed +// as an argument from IMAP package and IMAP library should cancel +// context when IMAP client cancels the request. +func exposeContextForIMAP() context.Context { + return context.TODO() +} + +// exposeContextForSMTP is the same as above but for SMTP. +func exposeContextForSMTP() context.Context { + return context.TODO() +} + // Store is local user storage, which handles the synchronization between IMAP and PM API. type Store struct { sentryReporter *sentry.Reporter @@ -278,7 +290,7 @@ func (store *Store) client() pmapi.Client { // initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if // the API is unavailable for whatever reason it tries to fetch the labels locally. func (store *Store) initCounts() (labels []*pmapi.Label, err error) { - if labels, err = store.client().ListLabels(context.TODO()); err != nil { + if labels, err = store.client().ListLabels(context.Background()); err != nil { store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.") if labels, err = store.getLabelsFromLocalStorage(); err != nil { store.log.WithError(err).Error("Cannot list local labels") diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 8f65be46..029e26a3 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -184,8 +184,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client) mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{ - {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive}, - {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive}, + {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true}, + {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true}, }) mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes() mocks.client.EXPECT().CountMessages(gomock.Any(), "") diff --git a/internal/store/sync.go b/internal/store/sync.go index 03e2ed62..ede34a13 100644 --- a/internal/store/sync.go +++ b/internal/store/sync.go @@ -148,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in Limit: 1, } // If the page does not exist, an empty page instead of an error is returned. - messages, total, err := api.ListMessages(context.TODO(), filter) + messages, total, err := api.ListMessages(context.Background(), filter) if err != nil { return "", 0, errors.Wrap(err, "failed to list messages") } @@ -190,7 +190,7 @@ func syncBatch( //nolint[funlen] log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page") - messages, _, err := api.ListMessages(context.TODO(), filter) + messages, _, err := api.ListMessages(context.Background(), filter) if err != nil { return errors.Wrap(err, "failed to list messages") } diff --git a/internal/store/types.go b/internal/store/types.go index 033add45..881999b8 100644 --- a/internal/store/types.go +++ b/internal/store/types.go @@ -17,7 +17,11 @@ package store -import "github.com/ProtonMail/proton-bridge/pkg/pmapi" +import ( + "context" + + "github.com/ProtonMail/proton-bridge/pkg/pmapi" +) type PanicHandler interface { HandlePanic() @@ -32,7 +36,7 @@ type BridgeUser interface { GetPrimaryAddress() string GetStoreAddresses() []string GetClient() pmapi.Client - UpdateUser() error + UpdateUser(context.Context) error CloseAllConnections() CloseConnection(string) Logout() error diff --git a/internal/store/user.go b/internal/store/user.go index 9edd3cf5..c7f5f3e7 100644 --- a/internal/store/user.go +++ b/internal/store/user.go @@ -17,8 +17,6 @@ package store -import "context" - // UserID returns user ID. func (store *Store) UserID() string { return store.user.ID() @@ -26,7 +24,7 @@ func (store *Store) UserID() string { // GetSpace returns used and total space in bytes. func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) { - apiUser, err := store.client().CurrentUser(context.TODO()) + apiUser, err := store.client().CurrentUser(exposeContextForIMAP()) if err != nil { return 0, 0, err } @@ -35,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) { // GetMaxUpload returns max size of message + all attachments in bytes. func (store *Store) GetMaxUpload() (int64, error) { - apiUser, err := store.client().CurrentUser(context.TODO()) + apiUser, err := store.client().CurrentUser(exposeContextForIMAP()) if err != nil { return 0, err } diff --git a/internal/store/user_address_info.go b/internal/store/user_address_info.go index f034e3c9..c55f71a4 100644 --- a/internal/store/user_address_info.go +++ b/internal/store/user_address_info.go @@ -147,7 +147,7 @@ func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (er // filterAddresses filters out inactive addresses and ensures the original address is listed first. func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) { for _, address := range addressList { - if address.Receive != pmapi.CanReceive { + if !address.Receive { continue } diff --git a/internal/store/user_mailbox.go b/internal/store/user_mailbox.go index beca8283..ad1dc161 100644 --- a/internal/store/user_mailbox.go +++ b/internal/store/user_mailbox.go @@ -18,7 +18,6 @@ package store import ( - "context" "fmt" "strings" @@ -39,14 +38,14 @@ func (store *Store) createMailbox(name string) error { color := store.leastUsedColor() - var exclusive int + var exclusive bool switch { case strings.HasPrefix(name, UserLabelsPrefix): name = strings.TrimPrefix(name, UserLabelsPrefix) - exclusive = 0 + exclusive = false case strings.HasPrefix(name, UserFoldersPrefix): name = strings.TrimPrefix(name, UserFoldersPrefix) - exclusive = 1 + exclusive = true default: // Ideally we would throw an error here, but then Outlook for // macOS keeps trying to make an IMAP Drafts folder and popping @@ -56,10 +55,10 @@ func (store *Store) createMailbox(name string) error { return nil } - _, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{ + _, err := store.client().CreateLabel(exposeContextForIMAP(), &pmapi.Label{ Name: name, Color: color, - Exclusive: exclusive, + Exclusive: pmapi.Boolean(exclusive), Type: pmapi.LabelTypeMailbox, }) return err @@ -126,7 +125,7 @@ func (store *Store) leastUsedColor() string { func (store *Store) updateMailbox(labelID, newName, color string) error { defer store.eventLoop.pollNow() - _, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{ + _, err := store.client().UpdateLabel(exposeContextForIMAP(), &pmapi.Label{ ID: labelID, Name: newName, Color: color, @@ -143,15 +142,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error { var err error switch labelID { case pmapi.SpamLabel: - err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID) + err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.SpamLabel, addressID) case pmapi.TrashLabel: - err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID) + err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.TrashLabel, addressID) default: err = fmt.Errorf("cannot empty mailbox %v", labelID) } return err } - return store.client().DeleteLabel(context.TODO(), labelID) + return store.client().DeleteLabel(exposeContextForIMAP(), labelID) } func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error { @@ -166,7 +165,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro return nil } - labels, err := store.client().ListLabels(context.TODO()) + labels, err := store.client().ListLabels(exposeContextForIMAP()) if err != nil { return err } diff --git a/internal/store/user_message.go b/internal/store/user_message.go index a6155433..47470109 100644 --- a/internal/store/user_message.go +++ b/internal/store/user_message.go @@ -19,7 +19,6 @@ package store import ( "bytes" - "context" "encoding/json" "io" "io/ioutil" @@ -58,7 +57,7 @@ func (store *Store) CreateDraft( } draftAction := store.getDraftAction(message) - draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction) + draft, err := store.client().CreateDraft(exposeContextForSMTP(), message, parentID, draftAction) if err != nil { return nil, nil, errors.Wrap(err, "failed to create draft") } @@ -70,7 +69,7 @@ func (store *Store) CreateDraft( for _, att := range attachments { att.attachment.MessageID = draft.ID - createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader) + createdAttachment, err := store.client().CreateAttachment(exposeContextForSMTP(), att.attachment, att.encReader, att.sigReader) if err != nil { return nil, nil, errors.Wrap(err, "failed to create attachment") } @@ -184,7 +183,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int { // SendMessage sends the message. func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error { defer store.eventLoop.pollNow() - _, _, err := store.client().SendMessage(context.TODO(), messageID, req) + _, _, err := store.client().SendMessage(exposeContextForSMTP(), messageID, req) return err } diff --git a/internal/store/user_message_test.go b/internal/store/user_message_test.go index 68bb50ec..bc606723 100644 --- a/internal/store/user_message_test.go +++ b/internal/store/user_message_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/ProtonMail/proton-bridge/pkg/pmapi" + "github.com/golang/mock/gomock" a "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -34,10 +35,10 @@ func TestGetAllMessageIDs(t *testing.T) { m.newStoreNoEvents(true) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) - insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) - insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) - insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) + insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) + insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) + insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{}) checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"}) } @@ -47,7 +48,7 @@ func TestGetMessageFromDB(t *testing.T) { defer clear() m.newStoreNoEvents(true) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) tests := []struct{ msgID, wantErr string }{ {"msg1", ""}, @@ -72,7 +73,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) { defer clear() m.newStoreNoEvents(true) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) msg, err := m.store.getMessageFromDB("msg1") require.Nil(t, err) @@ -104,7 +105,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) { a.Equal(t, wantHeader, msg.Header) // Check calculated data are not overridden by reinsert. - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) msg, err = m.store.getMessageFromDB("msg1") require.Nil(t, err) @@ -118,8 +119,8 @@ func TestDeleteMessage(t *testing.T) { defer clear() m.newStoreNoEvents(true) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) - insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) + insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) require.Nil(t, m.store.deleteMessageEvent("msg1")) @@ -127,17 +128,17 @@ func TestDeleteMessage(t *testing.T) { checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}}) } -func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam] +func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam] msg := getTestMessage(id, subject, sender, unread, labelIDs) require.Nil(t, m.store.createOrUpdateMessageEvent(msg)) } -func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message { +func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message { address := &mail.Address{Address: sender} return &pmapi.Message{ ID: id, Subject: subject, - Unread: unread, + Unread: pmapi.Boolean(unread), Sender: address, ToList: []*mail.Address{address}, LabelIDs: labelIDs, @@ -162,7 +163,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) { defer clear() m.newStoreNoEvents(false) - m.client.EXPECT().CurrentUser().Return(&pmapi.User{ + m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{ MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+. }, nil) @@ -181,7 +182,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) { defer clear() m.newStoreNoEvents(false) - m.client.EXPECT().CurrentUser().Return(&pmapi.User{ + m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{ MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+. }, nil) diff --git a/internal/store/user_sync.go b/internal/store/user_sync.go index 780d7ca3..0ea02a45 100644 --- a/internal/store/user_sync.go +++ b/internal/store/user_sync.go @@ -35,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted" // updateCountsFromServer will download and set the counts. func (store *Store) updateCountsFromServer() error { - counts, err := store.client().CountMessages(context.TODO(), "") + counts, err := store.client().CountMessages(context.Background(), "") if err != nil { return errors.Wrap(err, "cannot update counts from server") } diff --git a/internal/store/user_sync_test.go b/internal/store/user_sync_test.go index eb822e1e..688f7a9a 100644 --- a/internal/store/user_sync_test.go +++ b/internal/store/user_sync_test.go @@ -31,8 +31,8 @@ func TestLoadSaveSyncState(t *testing.T) { defer clear() m.newStoreNoEvents(true) - insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) - insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) + insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) + insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) // Clear everything. diff --git a/internal/transfer/mocks/mocks.go b/internal/transfer/mocks/mocks.go index 1c8040e0..afac025d 100644 --- a/internal/transfer/mocks/mocks.go +++ b/internal/transfer/mocks/mocks.go @@ -5,11 +5,10 @@ package mocks import ( - reflect "reflect" - imap "github.com/emersion/go-imap" sasl "github.com/emersion/go-sasl" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) // MockPanicHandler is a mock of PanicHandler interface diff --git a/internal/transfer/provider_imap_utils.go b/internal/transfer/provider_imap_utils.go index 74a120ca..0ac8a8f0 100644 --- a/internal/transfer/provider_imap_utils.go +++ b/internal/transfer/provider_imap_utils.go @@ -19,7 +19,9 @@ package transfer import ( "crypto/tls" + "fmt" "net" + "net/http" "strings" "time" @@ -37,6 +39,8 @@ const ( imapRetries = 10 imapReconnectTimeout = 30 * time.Minute imapReconnectSleep = time.Minute + + protonStatusURL = "http://protonstatus.com/vpn_status" ) type imapErrorLogger struct { @@ -117,19 +121,15 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error { return previousErr } - // FIXME(conman): This should register as connection observer. + err := checkConnection() + log.WithError(err).Debug("Connection check") + if err != nil { + time.Sleep(imapReconnectSleep) + previousErr = err + continue + } - /* - err := pmapi.CheckConnection() - log.WithError(err).Debug("Connection check") - if err != nil { - time.Sleep(imapReconnectSleep) - previousErr = err - continue - } - */ - - err := p.reauth() + err = p.reauth() log.WithError(err).Debug("Reauth") if err != nil { time.Sleep(imapReconnectSleep) @@ -289,3 +289,23 @@ func (p *IMAPProvider) fetchHelper(uid bool, ensureSelectedIn string, seqSet *im return err }, ensureSelectedIn) } + +// checkConnection returns an error if there is no internet connection. +// Note we don't want to use client manager because it only reports connection +// issues with API; we are only interested here whether we can reach +// third-party IMAP servers. +func checkConnection() error { + client := &http.Client{Timeout: time.Second * 10} + + resp, err := client.Get(protonStatusURL) + if err != nil { + return err + } + + _ = resp.Body.Close() + if resp.StatusCode != 200 { + return fmt.Errorf("HTTP status code %d", resp.StatusCode) + } + + return nil +} diff --git a/internal/transfer/provider_pmapi_target.go b/internal/transfer/provider_pmapi_target.go index 714f1b73..4b27530d 100644 --- a/internal/transfer/provider_pmapi_target.go +++ b/internal/transfer/provider_pmapi_target.go @@ -52,15 +52,10 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) { return Mailbox{}, errors.New("mailbox is already created") } - exclusive := 0 - if mailbox.IsExclusive { - exclusive = 1 - } - - label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{ + label, err := p.client.CreateLabel(context.Background(), &pmapi.Label{ Name: mailbox.Name, Color: mailbox.Color, - Exclusive: exclusive, + Exclusive: pmapi.Boolean(mailbox.IsExclusive), Type: pmapi.LabelTypeMailbox, }) if err != nil { @@ -126,7 +121,7 @@ func (p *PMAPIProvider) importDraft(msg Message, globalMailbox *Mailbox) (string } if message.Sender == nil { - mainAddress := p.client().Addresses().Main() + mainAddress := p.client.Addresses().Main() message.Sender = &mail.Address{ Name: mainAddress.DisplayName, Address: mainAddress.Email, @@ -227,14 +222,6 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog } } - var unread pmapi.Boolean - - if msg.Unread { - unread = pmapi.True - } else { - unread = pmapi.False - } - labelIDs := []string{} for _, target := range msg.Targets { // Frontend should not set All Mail to Rules, but to be sure... @@ -249,7 +236,7 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog return &pmapi.ImportMsgReq{ Metadata: &pmapi.ImportMetadata{ AddressID: p.addressID, - Unread: unread, + Unread: pmapi.Boolean(msg.Unread), Time: message.Time, Flags: computeMessageFlags(message.Header), LabelIDs: labelIDs, diff --git a/internal/transfer/provider_pmapi_test.go b/internal/transfer/provider_pmapi_test.go index acb36cb9..e15d3ad1 100644 --- a/internal/transfer/provider_pmapi_test.go +++ b/internal/transfer/provider_pmapi_test.go @@ -153,10 +153,10 @@ func setupPMAPIRules(rules transferRules) { func setupPMAPIClientExpectationForExport(m *mocks) { m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes() m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{ - {ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2}, - {ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1}, - {ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1}, - {ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2}, + {ID: "label1", Name: "Foo", Color: "blue", Exclusive: false, Order: 2}, + {ID: "label2", Name: "Bar", Color: "green", Exclusive: false, Order: 1}, + {ID: "folder1", Name: "One", Color: "red", Exclusive: true, Order: 1}, + {ID: "folder2", Name: "Two", Color: "orange", Exclusive: true, Order: 2}, }, nil).AnyTimes() m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{ {LabelID: "label1", Total: 10}, diff --git a/internal/transfer/provider_pmapi_utils.go b/internal/transfer/provider_pmapi_utils.go index 5d1d1512..2fefcacd 100644 --- a/internal/transfer/provider_pmapi_utils.go +++ b/internal/transfer/provider_pmapi_utils.go @@ -30,9 +30,17 @@ import ( const ( pmapiRetries = 10 pmapiReconnectTimeout = 30 * time.Minute - pmapiReconnectSleep = time.Minute + pmapiReconnectSleep = 10 * time.Second ) +func (p *PMAPIProvider) SetConnectionUp() { + p.connection = true +} + +func (p *PMAPIProvider) SetConnectionDown() { + p.connection = false +} + func (p *PMAPIProvider) ensureConnection(callback func() error) error { var callErr error for i := 1; i <= pmapiRetries; i++ { @@ -58,18 +66,10 @@ func (p *PMAPIProvider) tryReconnect() error { return previousErr } - // FIXME(conman): This should register as a connection observer somehow... - // Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection? - - /* - err := p.clientManager.CheckConnection() - log.WithError(err).Debug("Connection check") - if err != nil { - time.Sleep(pmapiReconnectSleep) - previousErr = err - continue - } - */ + if !p.connection { + time.Sleep(pmapiReconnectSleep) + continue + } break } @@ -83,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []* p.timeIt.start("listing", key) defer p.timeIt.stop("listing", key) - messages, count, err = p.client.ListMessages(context.TODO(), filter) + messages, count, err = p.client.ListMessages(context.Background(), filter) return err }) return @@ -94,7 +94,7 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er p.timeIt.start("download", msgID) defer p.timeIt.stop("download", msgID) - message, err = p.client.GetMessage(context.TODO(), msgID) + message, err = p.client.GetMessage(context.Background(), msgID) return err }) return @@ -105,7 +105,7 @@ func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReq p.timeIt.start("upload", msgSourceID) defer p.timeIt.stop("upload", msgSourceID) - res, err = p.client.Import(context.TODO(), req) + res, err = p.client.Import(context.Background(), req) return err }) return @@ -116,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message, p.timeIt.start("upload", msgSourceID) defer p.timeIt.stop("upload", msgSourceID) - draft, err = p.client.CreateDraft(context.TODO(), message, parent, action) + draft, err = p.client.CreateDraft(context.Background(), message, parent, action) return err }) return @@ -129,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme p.timeIt.start("upload", key) defer p.timeIt.stop("upload", key) - created, err = p.client.CreateAttachment(context.TODO(), att, r, sig) + created, err = p.client.CreateAttachment(context.Background(), att, r, sig) return err }) return diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index c32a7d04..c9ab092e 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -28,6 +28,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/internal/config/settings" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -274,7 +275,7 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) { wg.Wait() } -func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater { +func newTestUpdater(manager pmapi.Manager, curVer string, earlyAccess bool) *Updater { return New( manager, &fakeInstaller{}, diff --git a/internal/users/_user_credentials_test.go b/internal/users/_user_credentials_test.go deleted file mode 100644 index ab0698ac..00000000 --- a/internal/users/_user_credentials_test.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package users - -import ( - "testing" - - "github.com/ProtonMail/proton-bridge/internal/events" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" - gomock "github.com/golang/mock/gomock" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -func TestUpdateUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUser(m) - defer cleanUpUserData(user) - - gomock.InOrder( - m.pmapiClient.EXPECT().IsUnlocked().Return(false), - m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil), - - m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil), - m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - - m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}), - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - ) - - gomock.InOrder( - m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1), - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1), - ) - - assert.NoError(t, user.UpdateUser()) - - waitForEvents() -} - -func TestUserSwitchAddressMode(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUser(m) - defer cleanUpUserData(user) - - assert.True(t, user.store.IsCombinedMode()) - assert.True(t, user.creds.IsCombinedAddressMode) - assert.True(t, user.IsCombinedAddressMode()) - waitForEvents() - - gomock.InOrder( - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"), - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - - m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil), - ) - - assert.NoError(t, user.SwitchAddressMode()) - assert.False(t, user.store.IsCombinedMode()) - assert.False(t, user.creds.IsCombinedAddressMode) - assert.False(t, user.IsCombinedAddressMode()) - - gomock.InOrder( - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"), - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"), - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"), - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - - m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - ) - m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() - - assert.NoError(t, user.SwitchAddressMode()) - assert.True(t, user.store.IsCombinedMode()) - assert.True(t, user.creds.IsCombinedAddressMode) - assert.True(t, user.IsCombinedAddressMode()) - - waitForEvents() -} - -func TestLogoutUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUserForLogout(m) - defer cleanUpUserData(user) - - gomock.InOrder( - m.pmapiClient.EXPECT().Logout().Return(), - m.credentialsStore.EXPECT().Logout("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), - ) - - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - - err := user.Logout() - - waitForEvents() - - assert.NoError(t, err) -} - -func TestLogoutUserFailsLogout(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUserForLogout(m) - defer cleanUpUserData(user) - - gomock.InOrder( - m.pmapiClient.EXPECT().Logout().Return(), - m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")), - m.credentialsStore.EXPECT().Delete("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), - ) - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - - err := user.Logout() - waitForEvents() - assert.NoError(t, err) -} - -func TestCheckBridgeLoginOK(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUser(m) - defer cleanUpUserData(user) - - gomock.InOrder( - m.pmapiClient.EXPECT().IsUnlocked().Return(false), - m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil), - ) - - err := user.CheckBridgeLogin(testCredentials.BridgePassword) - - waitForEvents() - - assert.NoError(t, err) -} - -func TestCheckBridgeLoginTwiceOK(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUser(m) - defer cleanUpUserData(user) - - gomock.InOrder( - m.pmapiClient.EXPECT().IsUnlocked().Return(false), - m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil), - m.pmapiClient.EXPECT().IsUnlocked().Return(true), - ) - - err := user.CheckBridgeLogin(testCredentials.BridgePassword) - waitForEvents() - assert.NoError(t, err) - - err = user.CheckBridgeLogin(testCredentials.BridgePassword) - waitForEvents() - assert.NoError(t, err) -} - -func TestCheckBridgeLoginUpgradeApplication(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUser(m) - defer cleanUpUserData(user) - - m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") - - isApplicationOutdated = true - - err := user.CheckBridgeLogin("any-pass") - waitForEvents() - assert.Equal(t, pmapi.ErrUpgradeApplication, err) - - isApplicationOutdated = false -} - -func TestCheckBridgeLoginLoggedOut(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) - - user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker) - assert.NoError(t, err) - - m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1) - gomock.InOrder( - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), - m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")), - m.pmapiClient.EXPECT().Addresses().Return(nil), - ) - - err = user.init() - assert.Error(t, err) - - defer cleanUpUserData(user) - - m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") - - err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword) - waitForEvents() - assert.Equal(t, ErrLoggedOutUser, err) -} - -func TestCheckBridgeLoginBadPassword(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUser(m) - defer cleanUpUserData(user) - - gomock.InOrder( - m.pmapiClient.EXPECT().IsUnlocked().Return(false), - m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil), - ) - - err := user.CheckBridgeLogin("wrong!") - waitForEvents() - assert.Equal(t, "backend/credentials: incorrect password", err.Error()) -} diff --git a/internal/users/_user_new_test.go b/internal/users/_user_new_test.go deleted file mode 100644 index 1c15b4d8..00000000 --- a/internal/users/_user_new_test.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package users - -import ( - "errors" - "testing" - - "github.com/ProtonMail/proton-bridge/internal/events" - "github.com/ProtonMail/proton-bridge/internal/users/credentials" - gomock "github.com/golang/mock/gomock" - a "github.com/stretchr/testify/assert" -) - -func TestNewUserNoCredentialsStore(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail")) - - _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker) - a.Error(t, err) -} - -func TestNewUserAuthRefreshFails(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - - gomock.InOrder( - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")), - m.credentialsStore.EXPECT().Logout("user").Return(nil), - - m.pmapiClient.EXPECT().Logout(), - m.credentialsStore.EXPECT().Logout("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), - ) - - checkNewUserHasCredentials(testCredentialsDisconnected, m) -} - -func TestNewUserUnlockFails(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - - m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - - gomock.InOrder( - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), - - m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(errors.New("bad password")), - m.credentialsStore.EXPECT().Logout("user").Return(nil), - m.pmapiClient.EXPECT().Logout(), - m.credentialsStore.EXPECT().Logout("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), - ) - - checkNewUserHasCredentials(testCredentialsDisconnected, m) -} - -func TestNewUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - mockConnectedUser(m) - mockEventLoopNoAction(m) - - checkNewUserHasCredentials(testCredentials, m) -} - -func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) { - user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker) - defer cleanUpUserData(user) - - _ = user.init() - - waitForEvents() - - a.Equal(m.t, creds, user.creds) -} - -func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen] - a.Fail(t, "not implemented") -} diff --git a/internal/users/_user_test.go b/internal/users/_user_test.go deleted file mode 100644 index 9c3aff6e..00000000 --- a/internal/users/_user_test.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package users - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// testNewUser sets up a new, authorised user. -func testNewUser(m mocks) *User { - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - - mockConnectedUser(m) - mockEventLoopNoAction(m) - - user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker) - assert.NoError(m.t, err) - - err = user.init() - assert.NoError(m.t, err) - - mockAuthUpdate(user, "reftok", m) - - return user -} - -func testNewUserForLogout(m mocks) *User { - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - - mockConnectedUser(m) - mockEventLoopNoAction(m) - - user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker) - assert.NoError(m.t, err) - - err = user.init() - assert.NoError(m.t, err) - - return user -} - -func cleanUpUserData(u *User) { - _ = u.clearStore() -} - -func _TestNeverLongStorePath(t *testing.T) { // nolint[unused] - assert.Fail(t, "not implemented") -} - -func TestClearStoreWithStore(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUserForLogout(m) - defer cleanUpUserData(user) - - require.Nil(t, user.store.Close()) - user.store = nil - assert.Nil(t, user.clearStore()) -} - -func TestClearStoreWithoutStore(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - user := testNewUserForLogout(m) - defer cleanUpUserData(user) - - assert.NotNil(t, user.store) - assert.Nil(t, user.clearStore()) -} diff --git a/internal/users/_users_actions_test.go b/internal/users/_users_actions_test.go deleted file mode 100644 index 74cf0563..00000000 --- a/internal/users/_users_actions_test.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package users - -import ( - "errors" - "testing" - - "github.com/ProtonMail/proton-bridge/internal/events" - gomock "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -func TestGetNoUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) - - checkUsersGetUser(t, m, "nouser", -1, "user nouser not found") -} - -func TestGetUserByID(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) - - checkUsersGetUser(t, m, "user", 0, "") - checkUsersGetUser(t, m, "users", 1, "") -} - -func TestGetUserByName(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) - - checkUsersGetUser(t, m, "username", 0, "") - checkUsersGetUser(t, m, "usersname", 1, "") -} - -func TestGetUserByEmail(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) - - checkUsersGetUser(t, m, "user@pm.me", 0, "") - checkUsersGetUser(t, m, "users@pm.me", 1, "") - checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "") - checkUsersGetUser(t, m, "alsouser@pm.me", 1, "") -} - -func TestDeleteUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) - - users := testNewUsersWithUsers(t, m) - defer cleanUpUsersData(users) - - gomock.InOrder( - m.pmapiClient.EXPECT().Logout().Return(), - m.credentialsStore.EXPECT().Logout("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), - m.credentialsStore.EXPECT().Delete("user").Return(nil), - ) - - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - - err := users.DeleteUser("user", true) - assert.NoError(t, err) - assert.Equal(t, 1, len(users.users)) -} - -// Even when logout fails, delete is done. -func TestDeleteUserWithFailingLogout(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) - - users := testNewUsersWithUsers(t, m) - defer cleanUpUsersData(users) - - gomock.InOrder( - m.pmapiClient.EXPECT().Logout().Return(), - m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")), - m.credentialsStore.EXPECT().Delete("user").Return(nil), - m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")), - m.credentialsStore.EXPECT().Delete("user").Return(nil), - ) - - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - - err := users.DeleteUser("user", true) - assert.NoError(t, err) - assert.Equal(t, 1, len(users.users)) -} - -func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) { - users := testNewUsersWithUsers(t, m) - defer cleanUpUsersData(users) - - user, err := users.GetUser(query) - waitForEvents() - - if expectedError != "" { - assert.Equal(m.t, expectedError, err.Error()) - } else { - assert.NoError(m.t, err) - } - - var expectedUser *User - if index >= 0 { - expectedUser = users.users[index] - } - - assert.Equal(m.t, expectedUser, user) -} diff --git a/internal/users/_users_login_test.go b/internal/users/_users_login_test.go deleted file mode 100644 index d23bb39e..00000000 --- a/internal/users/_users_login_test.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package users - -import ( - "testing" - - "github.com/ProtonMail/proton-bridge/internal/events" - "github.com/ProtonMail/proton-bridge/internal/metrics" - "github.com/ProtonMail/proton-bridge/internal/users/credentials" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" - gomock "github.com/golang/mock/gomock" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -func TestUsersFinishLoginBadMailboxPassword(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - gomock.InOrder( - // Init users with no user from keychain. - m.credentialsStore.EXPECT().List().Return([]string{}, nil), - - // Set up mocks for FinishLogin. - m.pmapiClient.EXPECT().AuthSalt().Return("", nil), - m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked")), - m.pmapiClient.EXPECT().DeleteAuth(), - m.pmapiClient.EXPECT().Logout(), - ) - - checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword) -} - -func refreshWithToken(token string) *pmapi.Auth { - return &pmapi.Auth{ - RefreshToken: token, - } -} - -func credentialsWithToken(token string) *credentials.Credentials { - tmp := &credentials.Credentials{} - *tmp = *testCredentials - tmp.APIToken = token - return tmp -} - -func TestUsersFinishLoginNewUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - // Basically every call client has get client manager - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - - gomock.InOrder( - // users.New() finds no users in keychain. - m.credentialsStore.EXPECT().List().Return([]string{}, nil), - - // getAPIUser() loads user info from API (e.g. userID). - m.pmapiClient.EXPECT().AuthSalt().Return("", nil), - m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil), - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - - // addNewUser() - m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil), - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}), - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil), - - // user.init() in addNewUser - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil), - m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil), - m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil), - - // store.New() in user.init - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - - // Emit event for new user and send metrics. - m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient), - m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)), - m.pmapiClient.EXPECT().Logout(), - - // Reload account list in GUI. - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"), - - // defer logout anonymous - m.pmapiClient.EXPECT().Logout(), - ) - - mockEventLoopNoAction(m) - - user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) - - mockAuthUpdate(user, "afterCredentials", m) -} - -func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - loggedOutCreds := *testCredentials - loggedOutCreds.APIToken = "" - loggedOutCreds.MailboxPassword = "" - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - - gomock.InOrder( - // users.New() finds one existing user in keychain. - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil), - - // newUser() - m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil), - - // user.init() - m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil), - - // store.New() in user.init - m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized), - m.pmapiClient.EXPECT().Addresses().Return(nil), - - // getAPIUser() loads user info from API (e.g. userID). - m.pmapiClient.EXPECT().AuthSalt().Return("", nil), - m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil), - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - - // connectExistingUser() - m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil), - m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil), - m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil), - - // user.init() in connectExistingUser - m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil), - m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil), - m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil), - - // store.New() in user.init - m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), - m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - - // Reload account list in GUI. - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"), - - // defer logout anonymous - m.pmapiClient.EXPECT().Logout(), - ) - - mockEventLoopNoAction(m) - - user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil) - - mockAuthUpdate(user, "afterCredentials", m) -} - -func TestUsersFinishLoginConnectedUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - - mockConnectedUser(m) - mockEventLoopNoAction(m) - - users := testNewUsers(t, m) - defer cleanUpUsersData(users) - - // Then, try to log in again... - gomock.InOrder( - m.pmapiClient.EXPECT().AuthSalt().Return("", nil), - m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil), - m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil), - m.pmapiClient.EXPECT().DeleteAuth(), - m.pmapiClient.EXPECT().Logout(), - ) - - _, err := users.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword) - assert.Equal(t, "user is already connected", err.Error()) -} - -func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User { - users := testNewUsers(t, m) - defer cleanUpUsersData(users) - - user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword) - - waitForEvents() - - assert.Equal(t, expectedErr, err) - - if expectedUserID != "" { - assert.Equal(t, expectedUserID, user.ID()) - assert.Equal(t, 1, len(users.users)) - assert.Equal(t, expectedUserID, users.users[0].ID()) - } else { - assert.Equal(t, (*User)(nil), user) - assert.Equal(t, 0, len(users.users)) - } - - return user -} diff --git a/internal/users/credentials/store.go b/internal/users/credentials/store.go index ad131873..b1599f8d 100644 --- a/internal/users/credentials/store.go +++ b/internal/users/credentials/store.go @@ -233,7 +233,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) { _, secret, err := s.secrets.Get(userID) if err != nil { - log.WithError(err).Error("Could not get credentials from native keychain") + log.WithError(err).Warn("Could not get credentials from native keychain") return } diff --git a/internal/users/credentials/store_test.go b/internal/users/credentials/store_test.go index 538e7609..bc94ccc1 100644 --- a/internal/users/credentials/store_test.go +++ b/internal/users/credentials/store_test.go @@ -26,8 +26,7 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + r "github.com/stretchr/testify/require" ) const testSep = "\n" @@ -249,26 +248,26 @@ func TestMarshalFormats(t *testing.T) { log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt)) output := testCredentials{APIToken: "refresh"} - require.NoError(t, output.UnmarshalStrings(secretStrings)) + r.NoError(t, output.UnmarshalStrings(secretStrings)) log.Infof("strings out %#v \n", output) - require.True(t, input.IsSame(&output), "strings out not same") + r.True(t, input.IsSame(&output), "strings out not same") output = testCredentials{APIToken: "refresh"} - require.NoError(t, output.UnmarshalGob(secretGob)) + r.NoError(t, output.UnmarshalGob(secretGob)) log.Infof("gob out %#v\n \n", output) - assert.Equal(t, input, output) + r.Equal(t, input, output) output = testCredentials{APIToken: "refresh"} - require.NoError(t, output.FromJSON(secretJSON)) + r.NoError(t, output.FromJSON(secretJSON)) log.Infof("json out %#v \n", output) - require.True(t, input.IsSame(&output), "json out not same") + r.True(t, input.IsSame(&output), "json out not same") /* // Simple Fscanf not working! output = testCredentials{APIToken: "refresh"} - require.NoError(t, output.UnmarshalFmt(secretFmt)) + r.NoError(t, output.UnmarshalFmt(secretFmt)) log.Infof("fmt out %#v \n", output) - require.True(t, input.IsSame(&output), "fmt out not same") + r.True(t, input.IsSame(&output), "fmt out not same") */ } @@ -291,7 +290,7 @@ func TestMarshal(t *testing.T) { log.Infof("secret %#v %d\n", secret, len(secret)) output := Credentials{APIToken: "refresh"} - require.NoError(t, output.Unmarshal(secret)) + r.NoError(t, output.Unmarshal(secret)) log.Infof("output %#v\n", output) - assert.Equal(t, input, output) + r.Equal(t, input, output) } diff --git a/internal/users/mock_listener.go b/internal/users/mock_listener.go deleted file mode 100644 index ce5d6145..00000000 --- a/internal/users/mock_listener.go +++ /dev/null @@ -1,107 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./listener/listener.go - -// Package users is a generated GoMock package. -package users - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" -) - -// MockListener is a mock of Listener interface -type MockListener struct { - ctrl *gomock.Controller - recorder *MockListenerMockRecorder -} - -// MockListenerMockRecorder is the mock recorder for MockListener -type MockListenerMockRecorder struct { - mock *MockListener -} - -// NewMockListener creates a new mock instance -func NewMockListener(ctrl *gomock.Controller) *MockListener { - mock := &MockListener{ctrl: ctrl} - mock.recorder = &MockListenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockListener) EXPECT() *MockListenerMockRecorder { - return m.recorder -} - -// SetLimit mocks base method -func (m *MockListener) SetLimit(eventName string, limit time.Duration) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLimit", eventName, limit) -} - -// SetLimit indicates an expected call of SetLimit -func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit) -} - -// Add mocks base method -func (m *MockListener) Add(eventName string, channel chan<- string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Add", eventName, channel) -} - -// Add indicates an expected call of Add -func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel) -} - -// Remove mocks base method -func (m *MockListener) Remove(eventName string, channel chan<- string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Remove", eventName, channel) -} - -// Remove indicates an expected call of Remove -func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel) -} - -// Emit mocks base method -func (m *MockListener) Emit(eventName, data string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Emit", eventName, data) -} - -// Emit indicates an expected call of Emit -func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data) -} - -// SetBuffer mocks base method -func (m *MockListener) SetBuffer(eventName string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetBuffer", eventName) -} - -// SetBuffer indicates an expected call of SetBuffer -func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName) -} - -// RetryEmit mocks base method -func (m *MockListener) RetryEmit(eventName string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RetryEmit", eventName) -} - -// RetryEmit indicates an expected call of RetryEmit -func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName) -} diff --git a/internal/users/mocks/listener_mocks.go b/internal/users/mocks/listener_mocks.go new file mode 100644 index 00000000..e7f1cbb3 --- /dev/null +++ b/internal/users/mocks/listener_mocks.go @@ -0,0 +1,120 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" + time "time" +) + +// MockListener is a mock of Listener interface +type MockListener struct { + ctrl *gomock.Controller + recorder *MockListenerMockRecorder +} + +// MockListenerMockRecorder is the mock recorder for MockListener +type MockListenerMockRecorder struct { + mock *MockListener +} + +// NewMockListener creates a new mock instance +func NewMockListener(ctrl *gomock.Controller) *MockListener { + mock := &MockListener{ctrl: ctrl} + mock.recorder = &MockListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockListener) EXPECT() *MockListenerMockRecorder { + return m.recorder +} + +// Add mocks base method +func (m *MockListener) Add(arg0 string, arg1 chan<- string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Add", arg0, arg1) +} + +// Add indicates an expected call of Add +func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1) +} + +// Emit mocks base method +func (m *MockListener) Emit(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Emit", arg0, arg1) +} + +// Emit indicates an expected call of Emit +func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1) +} + +// ProvideChannel mocks base method +func (m *MockListener) ProvideChannel(arg0 string) <-chan string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ProvideChannel", arg0) + ret0, _ := ret[0].(<-chan string) + return ret0 +} + +// ProvideChannel indicates an expected call of ProvideChannel +func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0) +} + +// Remove mocks base method +func (m *MockListener) Remove(arg0 string, arg1 chan<- string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Remove", arg0, arg1) +} + +// Remove indicates an expected call of Remove +func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1) +} + +// RetryEmit mocks base method +func (m *MockListener) RetryEmit(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RetryEmit", arg0) +} + +// RetryEmit indicates an expected call of RetryEmit +func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0) +} + +// SetBuffer mocks base method +func (m *MockListener) SetBuffer(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetBuffer", arg0) +} + +// SetBuffer indicates an expected call of SetBuffer +func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0) +} + +// SetLimit mocks base method +func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLimit", arg0, arg1) +} + +// SetLimit indicates an expected call of SetLimit +func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1) +} diff --git a/internal/users/mocks/mocks.go b/internal/users/mocks/mocks.go index 3304b3af..a4a8c616 100644 --- a/internal/users/mocks/mocks.go +++ b/internal/users/mocks/mocks.go @@ -5,11 +5,10 @@ package mocks import ( - reflect "reflect" - store "github.com/ProtonMail/proton-bridge/internal/store" credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) // MockLocator is a mock of Locator interface diff --git a/internal/users/user.go b/internal/users/user.go index 7df89416..5641526e 100644 --- a/internal/users/user.go +++ b/internal/users/user.go @@ -89,29 +89,25 @@ func newUser( // - providing it with an authorised API client // - loading its credentials from the credentials store // - loading and unlocking its PGP keys -// - loading its store -func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error { +// - loading its store. +func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) error { u.log.Info("Connecting user") // Connected users have an API client. u.client = client - // FIXME(conman): How to remove this auth handler when user is disconnected? - u.client.AddAuthHandler(u.handleAuth) + u.client.AddAuthRefreshHandler(u.handleAuthRefresh) // Save the latest credentials for the user. u.creds = creds // Connected users have unlocked keys. - // FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :( - if u.creds.IsConnected() { - if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil { - return err - } + if err := u.unlockIfNecessary(); err != nil { + return err } // Connected users have a store. - if err := u.loadStore(); err != nil { + if err := u.loadStore(); err != nil { //nolint[revive] easier to read return err } @@ -138,17 +134,25 @@ func (u *User) loadStore() error { return nil } -func (u *User) handleAuth(auth *pmapi.Auth) error { - u.log.Debug("User received auth") +func (u *User) handleAuthRefresh(auth *pmapi.AuthRefresh) { + u.log.Debug("User received auth refresh update") + + if auth == nil { + if err := u.logout(); err != nil { + log.WithError(err). + WithField("userID", u.userID). + Error("User logout failed while watching API auths") + } + return + } creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken) if err != nil { - return errors.Wrap(err, "failed to update refresh token in credentials store") + u.log.WithError(err).Error("Failed to update refresh token in credentials store") + return } u.creds = creds - - return nil } // clearStore removes the database. @@ -181,13 +185,6 @@ func (u *User) closeStore() error { return nil } -// GetTemporaryPMAPIClient returns an authorised PMAPI client. -// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations. -// After proper refactor of SMTP and IMAP remove this method. -func (u *User) GetTemporaryPMAPIClient() pmapi.Client { - return u.client -} - // ID returns the user's userID. func (u *User) ID() string { return u.userID @@ -210,9 +207,43 @@ func (u *User) IsConnected() bool { } func (u *User) GetClient() pmapi.Client { + if err := u.unlockIfNecessary(); err != nil { + u.log.WithError(err).Error("Failed to unlock user") + } return u.client } +// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked. +func (u *User) unlockIfNecessary() error { + if !u.creds.IsConnected() { + return nil + } + + if u.client.IsUnlocked() { + return nil + } + + // unlockIfNecessary is called with every access to underlying pmapi + // client. Unlock should only finish unlocking when connection is back up. + // That means it should try it fast enough and not retry if connection + // is still down. + err := u.client.Unlock(pmapi.ContextWithoutRetry(context.Background()), []byte(u.creds.MailboxPassword)) + if err == nil { + return nil + } + + switch errors.Cause(err) { + case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication: + u.log.WithError(err).Warn("Could not unlock user") + return nil + } + + if logoutErr := u.logout(); logoutErr != nil { + u.log.WithError(logoutErr).Warn("Could not logout user") + } + return errors.Wrap(err, "failed to unlock user") +} + // IsCombinedAddressMode returns whether user is set in combined or split mode. // Combined mode is the default mode and is what users typically need. // Split mode is mostly for outlook as it cannot handle sending e-mails from an @@ -307,14 +338,10 @@ func (u *User) GetBridgePassword() string { // CheckBridgeLogin checks whether the user is logged in and the bridge // IMAP/SMTP password is correct. func (u *User) CheckBridgeLogin(password string) error { - // FIXME(conman): Handle force upgrade? - - /* - if isApplicationOutdated { - u.listener.Emit(events.UpgradeApplicationEvent, "") - return pmapi.ErrUpgradeApplication - } - */ + if isApplicationOutdated { + u.listener.Emit(events.UpgradeApplicationEvent, "") + return pmapi.ErrUpgradeApplication + } u.lock.RLock() defer u.lock.RUnlock() @@ -328,16 +355,16 @@ func (u *User) CheckBridgeLogin(password string) error { } // UpdateUser updates user details from API and saves to the credentials. -func (u *User) UpdateUser() error { +func (u *User) UpdateUser(ctx context.Context) error { u.lock.Lock() defer u.lock.Unlock() - _, err := u.client.UpdateUser(context.TODO()) + _, err := u.client.UpdateUser(ctx) if err != nil { return err } - if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil { + if err := u.client.ReloadKeys(ctx, []byte(u.creds.MailboxPassword)); err != nil { return errors.Wrap(err, "failed to reload keys") } @@ -414,8 +441,7 @@ func (u *User) Logout() error { return nil } - // FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers? - if err := u.client.AuthDelete(context.TODO()); err != nil { + if err := u.client.AuthDelete(context.Background()); err != nil { u.log.WithError(err).Warn("Failed to delete auth") } diff --git a/internal/users/user_credentials_test.go b/internal/users/user_credentials_test.go new file mode 100644 index 00000000..b3a13946 --- /dev/null +++ b/internal/users/user_credentials_test.go @@ -0,0 +1,195 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + "context" + "testing" + + "github.com/ProtonMail/proton-bridge/internal/events" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" + gomock "github.com/golang/mock/gomock" + "github.com/pkg/errors" + r "github.com/stretchr/testify/require" +) + +func TestUpdateUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + gomock.InOrder( + m.pmapiClient.EXPECT().UpdateUser(gomock.Any()).Return(nil, nil), + m.pmapiClient.EXPECT().ReloadKeys(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + + m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}).Return(testCredentials, nil), + ) + + r.NoError(t, user.UpdateUser(context.Background())) +} + +func TestUserSwitchAddressMode(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + // Ignore any sync on background. + m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() + + // Check initial state. + r.True(t, user.store.IsCombinedMode()) + r.True(t, user.creds.IsCombinedAddressMode) + r.True(t, user.IsCombinedAddressMode()) + + // Mock change to split mode. + gomock.InOrder( + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"), + m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentialsSplit, nil), + ) + + // Check switch to split mode. + r.NoError(t, user.SwitchAddressMode()) + r.False(t, user.store.IsCombinedMode()) + r.False(t, user.creds.IsCombinedAddressMode) + r.False(t, user.IsCombinedAddressMode()) + + // MOck change to combined mode. + gomock.InOrder( + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"), + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"), + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"), + m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentials, nil), + ) + + // Check switch to combined mode. + r.NoError(t, user.SwitchAddressMode()) + r.True(t, user.store.IsCombinedMode()) + r.True(t, user.creds.IsCombinedAddressMode) + r.True(t, user.IsCombinedAddressMode()) +} + +func TestLogoutUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + gomock.InOrder( + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil), + m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil), + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"), + ) + + err := user.Logout() + r.NoError(t, err) +} + +func TestLogoutUserFailsLogout(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + gomock.InOrder( + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil), + m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")), + m.credentialsStore.EXPECT().Delete("user").Return(nil), + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"), + ) + + err := user.Logout() + r.NoError(t, err) +} + +func TestCheckBridgeLogin(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + err := user.CheckBridgeLogin(testCredentials.BridgePassword) + r.NoError(t, err) +} + +func TestCheckBridgeLoginUpgradeApplication(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") + + isApplicationOutdated = true + + err := user.CheckBridgeLogin("any-pass") + r.Equal(t, pmapi.ErrUpgradeApplication, err) + + isApplicationOutdated = false +} + +func TestCheckBridgeLoginLoggedOut(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + gomock.InOrder( + // Mock init of user. + m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), + m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()), + m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")), + m.pmapiClient.EXPECT().Addresses().Return(nil), + + // Mock CheckBridgeLogin. + m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"), + ) + + user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false) + r.NoError(t, err) + + err = user.connect(m.pmapiClient, testCredentialsDisconnected) + r.Error(t, err) + defer cleanUpUserData(user) + + err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword) + r.Equal(t, ErrLoggedOutUser, err) +} + +func TestCheckBridgeLoginBadPassword(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + err := user.CheckBridgeLogin("wrong!") + r.EqualError(t, err, "backend/credentials: incorrect password") +} diff --git a/internal/users/user_new_test.go b/internal/users/user_new_test.go new file mode 100644 index 00000000..50e42c8d --- /dev/null +++ b/internal/users/user_new_test.go @@ -0,0 +1,88 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + "errors" + "testing" + + "github.com/ProtonMail/proton-bridge/internal/events" + "github.com/ProtonMail/proton-bridge/internal/users/credentials" + gomock "github.com/golang/mock/gomock" + r "github.com/stretchr/testify/require" +) + +func TestNewUserNoCredentialsStore(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail")) + + _, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false) + r.Error(t, err) +} + +func TestNewUserUnlockFails(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + gomock.InOrder( + // Init of user. + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()), + m.pmapiClient.EXPECT().IsUnlocked().Return(false), + m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("bad password")), + + // Handle of unlock error. + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil), + m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil), + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"), + m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"), + m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"), + ) + + checkNewUserHasCredentials(m, "failed to unlock user: bad password", testCredentialsDisconnected) +} + +func TestNewUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + mockInitConnectedUser(m) + mockEventLoopNoAction(m) + + checkNewUserHasCredentials(m, "", testCredentials) +} + +func checkNewUserHasCredentials(m mocks, wantErr string, wantCreds *credentials.Credentials) { + user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false) + r.NoError(m.t, err) + defer cleanUpUserData(user) + + err = user.connect(m.pmapiClient, testCredentials) + if wantErr == "" { + r.NoError(m.t, err) + } else { + r.EqualError(m.t, err, wantErr) + } + + r.Equal(m.t, wantCreds, user.creds) + + waitForEvents() +} diff --git a/internal/users/user_store_test.go b/internal/users/user_store_test.go new file mode 100644 index 00000000..87ee0c8e --- /dev/null +++ b/internal/users/user_store_test.go @@ -0,0 +1,51 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + "testing" + + r "github.com/stretchr/testify/require" +) + +func _TestNeverLongStorePath(t *testing.T) { // nolint[unused] + r.Fail(t, "not implemented") +} + +func TestClearStoreWithStore(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + r.Nil(t, user.store.Close()) + user.store = nil + r.Nil(t, user.clearStore()) +} + +func TestClearStoreWithoutStore(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + user := testNewUser(m) + defer cleanUpUserData(user) + + r.NotNil(t, user.store) + r.Nil(t, user.clearStore()) +} diff --git a/internal/users/user_test.go b/internal/users/user_test.go new file mode 100644 index 00000000..01d08de8 --- /dev/null +++ b/internal/users/user_test.go @@ -0,0 +1,41 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + r "github.com/stretchr/testify/require" +) + +// testNewUser sets up a new, authorised user. +func testNewUser(m mocks) *User { + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + mockInitConnectedUser(m) + mockEventLoopNoAction(m) + + user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false) + r.NoError(m.t, err) + + err = user.connect(m.pmapiClient, creds) + r.NoError(m.t, err) + + return user +} + +func cleanUpUserData(u *User) { + _ = u.clearStore() +} diff --git a/internal/users/users.go b/internal/users/users.go index 01d18149..e819d53f 100644 --- a/internal/users/users.go +++ b/internal/users/users.go @@ -89,24 +89,42 @@ func New( lock: sync.RWMutex{}, } - // FIXME(conman): Handle force upgrade events. - /* - go func() { - defer panicHandler.HandlePanic() - u.watchAppOutdated() - }() - */ + go func() { + defer panicHandler.HandlePanic() + u.watchEvents() + }() if u.credStorer == nil { log.Error("No credentials store is available") - } else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil { + } else if err := u.loadUsersFromCredentialsStore(); err != nil { log.WithError(err).Error("Could not load all users from credentials store") } return u } -func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error { +func (u *Users) watchEvents() { + upgradeCh := u.events.ProvideChannel(events.UpgradeApplicationEvent) + internetOnCh := u.events.ProvideChannel(events.InternetOnEvent) + + for { + select { + case <-upgradeCh: + isApplicationOutdated = true + u.closeAllConnections() + case <-internetOnCh: + for _, user := range u.users { + if user.store == nil { + if err := user.loadStore(); err != nil { + log.WithError(err).Error("Failed to load store after reconnecting") + } + } + } + } + } +} + +func (u *Users) loadUsersFromCredentialsStore() error { u.lock.Lock() defer u.lock.Unlock() @@ -116,23 +134,26 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error { } for _, userID := range userIDs { + l := log.WithField("user", userID) user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses) if err != nil { - logrus.WithError(err).Warn("Could not create user, skipping") + l.WithError(err).Warn("Could not create user, skipping") continue } u.users = append(u.users, user) if creds.IsConnected() { - if err := u.loadConnectedUser(ctx, user, creds); err != nil { - logrus.WithError(err).Warn("Could not load connected user") + // If there is no connection, we don't want to retry. Load should + // happen fast enough to not block GUI. When connection is back up, + // watchEvents and unlockIfNecessary will finish user init later. + if err := u.loadConnectedUser(pmapi.ContextWithoutRetry(context.Background()), user, creds); err != nil { + l.WithError(err).Warn("Could not load connected user") } } else { - logrus.Warn("User is disconnected and must be connected manually") - - if err := u.loadDisconnectedUser(ctx, user, creds); err != nil { - logrus.WithError(err).Warn("Could not load disconnected user") + l.Warn("User is disconnected and must be connected manually") + if err := user.connect(u.clientManager.NewClient("", "", "", time.Time{}), creds); err != nil { + l.WithError(err).Warn("Could not load disconnected user") } } } @@ -140,11 +161,6 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error { return err } -func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error { - // FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor! - return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds) -} - func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error { uid, ref, err := creds.SplitAPIToken() if err != nil { @@ -153,38 +169,27 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref) if err != nil { - // FIXME(conman): This is a problem... if we weren't able to create a new client due to internet, - // we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff... - return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds) + // When client cannot be refreshed right away due to no connection, + // we create client which will refresh automatically when possible. + connectErr := user.connect(u.clientManager.NewClient(uid, "", ref, time.Time{}), creds) + + switch errors.Cause(err) { + case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication: + return connectErr + } + + if logoutErr := user.logout(); logoutErr != nil { + logrus.WithError(logoutErr).Warn("Could not logout user") + } + return errors.Wrap(err, "could not refresh token") } // Update the user's credentials with the latest auth used to connect this user. - if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil { + if creds, err = u.credStorer.UpdateToken(creds.UserID, auth.UID, auth.RefreshToken); err != nil { return errors.Wrap(err, "could not create get user's refresh token") } - return user.connect(ctx, client, creds) -} - -func (u *Users) watchAppOutdated() { - // FIXME(conman): handle force upgrade events. - - /* - ch := make(chan string) - - u.events.Add(events.UpgradeApplicationEvent, ch) - - for { - select { - case <-ch: - isApplicationOutdated = true - u.closeAllConnections() - - case <-u.stopAll: - return - } - } - */ + return user.connect(client, creds) } func (u *Users) closeAllConnections() { @@ -198,19 +203,19 @@ func (u *Users) closeAllConnections() { func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) { u.crashBandicoot(username) - return u.clientManager.NewClientWithLogin(context.TODO(), username, password) + return u.clientManager.NewClientWithLogin(context.Background(), username, password) } // FinishLogin finishes the login procedure and adds the user into the credentials store. func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen] - apiUser, passphrase, err := getAPIUser(context.TODO(), client, password) + apiUser, passphrase, err := getAPIUser(context.Background(), client, password) if err != nil { - return nil, errors.Wrap(err, "failed to get API user") + return nil, err } if user, ok := u.hasUser(apiUser.ID); ok { if user.IsConnected() { - if err := client.AuthDelete(context.TODO()); err != nil { + if err := client.AuthDelete(context.Background()); err != nil { logrus.WithError(err).Warn("Failed to delete new auth session") } @@ -228,14 +233,16 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri return nil, errors.Wrap(err, "failed to update password of user in credentials store") } - if err := user.connect(context.TODO(), client, creds); err != nil { + if err := user.connect(client, creds); err != nil { return nil, errors.Wrap(err, "failed to reconnect existing user") } + u.events.Emit(events.UserRefreshEvent, apiUser.ID) + return user, nil } - if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil { + if err := u.addNewUser(client, apiUser, auth, passphrase); err != nil { return nil, errors.Wrap(err, "failed to add new user") } @@ -245,7 +252,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri } // addNewUser adds a new user. -func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error { +func (u *Users) addNewUser(client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error { u.lock.Lock() defer u.lock.Unlock() @@ -266,7 +273,7 @@ func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pm return errors.Wrap(err, "failed to create new user") } - if err := user.connect(ctx, client, creds); err != nil { + if err := user.connect(client, creds); err != nil { return errors.Wrap(err, "failed to connect new user") } @@ -292,7 +299,7 @@ func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pma // We unlock the user's PGP key here to detect if the user's mailbox password is wrong. if err := client.Unlock(ctx, passphrase); err != nil { - return nil, nil, errors.Wrap(err, "failed to unlock client") + return nil, nil, ErrWrongMailboxPassword } user, err := client.CurrentUser(ctx) @@ -414,22 +421,13 @@ func (u *Users) SendMetric(m metrics.Metric) error { // AllowProxy instructs the app to use DoH to access an API proxy if necessary. // It also needs to work before the app is initialised (because we may need to use the proxy at startup). func (u *Users) AllowProxy() { - // FIXME(conman): Support DoH. - // u.apiManager.AllowProxy() + u.clientManager.AllowProxy() } // DisallowProxy instructs the app to not use DoH to access an API proxy if necessary. // It also needs to work before the app is initialised (because we may need to use the proxy at startup). func (u *Users) DisallowProxy() { - // FIXME(conman): Support DoH. - // u.apiManager.DisallowProxy() -} - -// CheckConnection returns whether there is an internet connection. -// This should use the connection manager when it is eventually implemented. -func (u *Users) CheckConnection() error { - // FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer. - panic("TODO: register as a connection observer to get this information") + u.clientManager.DisallowProxy() } // hasUser returns whether the struct currently has a user with ID `id`. diff --git a/internal/users/users_clear_test.go b/internal/users/users_clear_test.go new file mode 100644 index 00000000..91800e85 --- /dev/null +++ b/internal/users/users_clear_test.go @@ -0,0 +1,49 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + "testing" + + "github.com/ProtonMail/proton-bridge/internal/events" + gomock "github.com/golang/mock/gomock" + r "github.com/stretchr/testify/require" +) + +func TestClearData(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + users := testNewUsersWithUsers(t, m) + defer cleanUpUsersData(users) + + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me") + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me") + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me") + + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) + m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil) + + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) + m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil) + + m.locator.EXPECT().Clear() + + r.NoError(t, users.ClearData()) +} diff --git a/internal/users/users_delete_test.go b/internal/users/users_delete_test.go new file mode 100644 index 00000000..0aee1052 --- /dev/null +++ b/internal/users/users_delete_test.go @@ -0,0 +1,69 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + "errors" + "testing" + + "github.com/ProtonMail/proton-bridge/internal/events" + gomock "github.com/golang/mock/gomock" + r "github.com/stretchr/testify/require" +) + +func TestDeleteUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + users := testNewUsersWithUsers(t, m) + defer cleanUpUsersData(users) + + gomock.InOrder( + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil), + m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil), + m.credentialsStore.EXPECT().Delete("user").Return(nil), + ) + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") + + err := users.DeleteUser("user", true) + r.NoError(t, err) + r.Equal(t, 1, len(users.users)) +} + +// Even when logout fails, delete is done. +func TestDeleteUserWithFailingLogout(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + users := testNewUsersWithUsers(t, m) + defer cleanUpUsersData(users) + + gomock.InOrder( + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil), + m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")), + // Once called from user.Logout after failed creds.Logout as fallback, and once at the end of users.Logout. + m.credentialsStore.EXPECT().Delete("user").Return(nil), + m.credentialsStore.EXPECT().Delete("user").Return(nil), + ) + + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") + + err := users.DeleteUser("user", true) + r.NoError(t, err) + r.Equal(t, 1, len(users.users)) +} diff --git a/internal/users/users_get_test.go b/internal/users/users_get_test.go new file mode 100644 index 00000000..f79e1088 --- /dev/null +++ b/internal/users/users_get_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + "testing" + + r "github.com/stretchr/testify/require" +) + +func TestGetNoUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + checkUsersGetUser(t, m, "nouser", -1, "user nouser not found") +} + +func TestGetUserByID(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + checkUsersGetUser(t, m, "user", 0, "") + checkUsersGetUser(t, m, "users", 1, "") +} + +func TestGetUserByName(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + checkUsersGetUser(t, m, "username", 0, "") + checkUsersGetUser(t, m, "usersname", 1, "") +} + +func TestGetUserByEmail(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + checkUsersGetUser(t, m, "user@pm.me", 0, "") + checkUsersGetUser(t, m, "users@pm.me", 1, "") + checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "") + checkUsersGetUser(t, m, "alsouser@pm.me", 1, "") +} + +func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) { + users := testNewUsersWithUsers(t, m) + defer cleanUpUsersData(users) + + user, err := users.GetUser(query) + + if expectedError != "" { + r.EqualError(m.t, err, expectedError) + } else { + r.NoError(m.t, err) + } + + var expectedUser *User + if index >= 0 { + expectedUser = users.users[index] + } + r.Equal(m.t, expectedUser, user) +} diff --git a/internal/users/users_login_test.go b/internal/users/users_login_test.go new file mode 100644 index 00000000..9fc6015c --- /dev/null +++ b/internal/users/users_login_test.go @@ -0,0 +1,132 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package users + +import ( + "testing" + + "github.com/ProtonMail/proton-bridge/internal/events" + "github.com/ProtonMail/proton-bridge/internal/metrics" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" + gomock "github.com/golang/mock/gomock" + "github.com/pkg/errors" + r "github.com/stretchr/testify/require" +) + +func TestUsersFinishLoginBadMailboxPassword(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + // Init users with no user from keychain. + m.credentialsStore.EXPECT().List().Return([]string{}, nil) + + // Set up mocks for FinishLogin. + m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil) + m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked")) + + checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword) +} + +func TestUsersFinishLoginNewUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + // Init users with no user from keychain. + m.credentialsStore.EXPECT().List().Return([]string{}, nil) + + mockAddingConnectedUser(m) + mockEventLoopNoAction(m) + + m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)) + m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentials.UserID) + + checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, testCredentials.UserID, nil) +} + +func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + // Mock loading disconnected user. + m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil) + mockLoadingDisconnectedUser(m, testCredentialsDisconnected) + + // Mock process of FinishLogin of already added user. + gomock.InOrder( + m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil), + m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil), + m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUserDisconnected, nil), + m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil), + m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil), + ) + mockInitConnectedUser(m) + mockEventLoopNoAction(m) + m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID) + + authRefresh := &pmapi.Auth{ + UserID: testCredentialsDisconnected.UserID, + AuthRefresh: pmapi.AuthRefresh{ + UID: "uid", + AccessToken: "acc", + RefreshToken: "ref", + }, + } + checkUsersFinishLogin(t, m, authRefresh, testCredentials.MailboxPassword, testCredentialsDisconnected.UserID, nil) +} + +func TestUsersFinishLoginConnectedUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + // Mock loading connected user. + m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil) + mockLoadingConnectedUser(m, testCredentials) + mockEventLoopNoAction(m) + + // Mock process of FinishLogin of already connected user. + gomock.InOrder( + m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil), + m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil), + m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil), + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil), + ) + + users := testNewUsers(t, m) + defer cleanUpUsersData(users) + + _, err := users.FinishLogin(m.pmapiClient, testAuthRefresh, testCredentials.MailboxPassword) + r.EqualError(t, err, "user is already connected") +} + +func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) { + users := testNewUsers(t, m) + defer cleanUpUsersData(users) + + user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword) + + r.Equal(t, expectedErr, err) + + if expectedUserID != "" { + r.Equal(t, expectedUserID, user.ID()) + r.Equal(t, 1, len(users.users)) + r.Equal(t, expectedUserID, users.users[0].ID()) + } else { + r.Equal(t, (*User)(nil), user) + r.Equal(t, 0, len(users.users)) + } +} diff --git a/internal/users/users_new_test.go b/internal/users/users_new_test.go index ed3a8ee2..ae3e4ff7 100644 --- a/internal/users/users_new_test.go +++ b/internal/users/users_new_test.go @@ -22,9 +22,10 @@ import ( "testing" time "time" + "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/users/credentials" gomock "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" + r "github.com/stretchr/testify/require" ) func TestNewUsersNoKeychain(t *testing.T) { @@ -32,7 +33,6 @@ func TestNewUsersNoKeychain(t *testing.T) { defer m.ctrl.Finish() m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain")) - checkUsersNew(t, m, []*credentials.Credentials{}) } @@ -41,108 +41,73 @@ func TestNewUsersWithoutUsersInCredentialsStore(t *testing.T) { defer m.ctrl.Finish() m.credentialsStore.EXPECT().List().Return([]string{}, nil) - checkUsersNew(t, m, []*credentials.Credentials{}) } -func TestNewUsersWithDisconnectedUser(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - gomock.InOrder( - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil), - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), - m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient), - m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()), - m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")), - m.pmapiClient.EXPECT().Addresses().Return(nil), - ) - - checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) -} - -/* -func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() - - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2) - - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")) - - m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") - m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") - m.pmapiClient.EXPECT().Logout() - m.credentialsStore.EXPECT().Logout("user").Return(nil) - m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil) - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - - checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) -} - func TestNewUsersWithConnectedUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) - - mockConnectedUser(m) + m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil) + mockLoadingConnectedUser(m, testCredentials) mockEventLoopNoAction(m) - checkUsersNew(t, m, []*credentials.Credentials{testCredentials}) } +func TestNewUsersWithDisconnectedUser(t *testing.T) { + m := initMocks(t) + defer m.ctrl.Finish() + + m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil) + mockLoadingDisconnectedUser(m, testCredentialsDisconnected) + checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) +} + // Tests two users with different states and checks also the order from // credentials store is kept also in array of users. func TestNewUsersWithUsers(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil) - - gomock.InOrder( - m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil), - m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil), - // Set up mocks for store initialisation for the unauth user. - m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient), - m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")), - m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient), - m.pmapiClient.EXPECT().Addresses().Return(nil), - ) - - mockConnectedUser(m) - + m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil) + mockLoadingDisconnectedUser(m, testCredentialsDisconnected) + mockLoadingConnectedUser(m, testCredentials) mockEventLoopNoAction(m) - checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials}) } -func TestNewUsersFirstStart(t *testing.T) { +func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - m.credentialsStore.EXPECT().List().Return([]string{}, nil) + m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token")) + m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient) + m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()) + m.pmapiClient.EXPECT().IsUnlocked().Return(false) + m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("not authorized")) + m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) - testNewUsers(t, m) + m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) + m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil) + + m.eventListener.EXPECT().Emit(events.LogoutEvent, "user") + m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user") + m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") + + checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) } -*/ func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) { users := testNewUsers(t, m) defer cleanUpUsersData(users) - assert.Equal(m.t, len(expectedCredentials), len(users.GetUsers())) + r.Equal(m.t, len(expectedCredentials), len(users.GetUsers())) credentials := []*credentials.Credentials{} for _, user := range users.users { credentials = append(credentials, user.creds) } - assert.Equal(m.t, expectedCredentials, credentials) + r.Equal(m.t, expectedCredentials, credentials) } diff --git a/internal/users/users_test.go b/internal/users/users_test.go index 37fa84f0..08971d95 100644 --- a/internal/users/users_test.go +++ b/internal/users/users_test.go @@ -33,8 +33,9 @@ import ( "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" gomock "github.com/golang/mock/gomock" + "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" + r "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -49,9 +50,12 @@ func TestMain(m *testing.M) { var ( testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals] - UID: "uid", - AccessToken: "acc", - RefreshToken: "ref", + UserID: "user", + AuthRefresh: pmapi.AuthRefresh{ + UID: "uid", + AccessToken: "acc", + RefreshToken: "ref", + }, } testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals] @@ -81,7 +85,7 @@ var ( } testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals] - UserID: "user", + UserID: "userDisconnected", Name: "username", Emails: "user@pm.me", APIToken: "", @@ -94,7 +98,7 @@ var ( } testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals] - UserID: "users", + UserID: "usersDisconnected", Name: "usersname", Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me", APIToken: "", @@ -111,17 +115,22 @@ var ( Name: "username", } + testPMAPIUserDisconnected = &pmapi.User{ //nolint[gochecknoglobals] + ID: "userDisconnected", + Name: "username", + } + testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals] ID: "testAddressID", Type: pmapi.OriginalAddress, Email: "user@pm.me", - Receive: pmapi.CanReceive, + Receive: true, } testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals] - {ID: "usersAddress1ID", Email: "users@pm.me", Receive: pmapi.CanReceive, Type: pmapi.OriginalAddress}, - {ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress}, - {ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress}, + {ID: "usersAddress1ID", Email: "users@pm.me", Receive: true, Type: pmapi.OriginalAddress}, + {ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: true, Type: pmapi.AliasAddress}, + {ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: true, Type: pmapi.AliasAddress}, } testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals] @@ -129,15 +138,6 @@ var ( } ) -func waitForEvents() { - // Wait for goroutine to add listener. - // E.g. calling login to invoke firstsync event. Functions can end sooner than - // goroutines call the listener mock. We need to wait a little bit before the end of - // the test to capture all event calls. This allows us to detect whether there were - // missing calls, or perhaps whether something was called too many times. - time.Sleep(100 * time.Millisecond) -} - type mocks struct { t *testing.T @@ -146,7 +146,7 @@ type mocks struct { PanicHandler *usersmocks.MockPanicHandler credentialsStore *usersmocks.MockCredentialsStorer storeMaker *usersmocks.MockStoreMaker - eventListener *MockListener + eventListener *usersmocks.MockListener clientManager *pmapimocks.MockManager pmapiClient *pmapimocks.MockClient @@ -154,6 +154,48 @@ type mocks struct { storeCache *store.Cache } +func initMocks(t *testing.T) mocks { + var mockCtrl *gomock.Controller + if os.Getenv("VERBOSITY") == "trace" { + mockCtrl = gomock.NewController(&fullStackReporter{t}) + } else { + mockCtrl = gomock.NewController(t) + } + + cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db") + r.NoError(t, err, "could not get temporary file for store cache") + + m := mocks{ + t: t, + + ctrl: mockCtrl, + locator: usersmocks.NewMockLocator(mockCtrl), + PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl), + credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl), + storeMaker: usersmocks.NewMockStoreMaker(mockCtrl), + eventListener: usersmocks.NewMockListener(mockCtrl), + + clientManager: pmapimocks.NewMockManager(mockCtrl), + pmapiClient: pmapimocks.NewMockClient(mockCtrl), + + storeCache: store.NewCache(cacheFile.Name()), + } + + // Called during clean-up. + m.PanicHandler.EXPECT().HandlePanic().AnyTimes() + + // Set up store factory. + m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) { + var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests. + dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db") + r.NoError(t, err, "could not get temporary file for store db") + return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache) + }).AnyTimes() + m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes() + + return m +} + type fullStackReporter struct { T testing.TB } @@ -168,86 +210,18 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) { fr.T.FailNow() } -func initMocks(t *testing.T) mocks { - var mockCtrl *gomock.Controller - if os.Getenv("VERBOSITY") == "trace" { - mockCtrl = gomock.NewController(&fullStackReporter{t}) - } else { - mockCtrl = gomock.NewController(t) - } - - cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db") - require.NoError(t, err, "could not get temporary file for store cache") - - m := mocks{ - t: t, - - ctrl: mockCtrl, - locator: usersmocks.NewMockLocator(mockCtrl), - PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl), - credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl), - storeMaker: usersmocks.NewMockStoreMaker(mockCtrl), - eventListener: NewMockListener(mockCtrl), - - clientManager: pmapimocks.NewMockManager(mockCtrl), - pmapiClient: pmapimocks.NewMockClient(mockCtrl), - - storeCache: store.NewCache(cacheFile.Name()), - } - - // Called during clean-up. - m.PanicHandler.EXPECT().HandlePanic().AnyTimes() - - // Set up store factory. - m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) { - var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests. - dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db") - require.NoError(t, err, "could not get temporary file for store db") - return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache) - }).AnyTimes() - m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes() - - return m -} - func testNewUsersWithUsers(t *testing.T, m mocks) *Users { - // Events are asynchronous - m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2) - m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2) - m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2) - - gomock.InOrder( - m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil), - - // Init for user. - m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil), - m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil), - m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()), - m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil), - m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil), - m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil), - m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), - m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - - // Init for users. - m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil), - m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil), - m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()), - m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil), - m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil), - m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil), - m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), - m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), - m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses), - ) + m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil) + mockLoadingConnectedUser(m, testCredentials) + mockLoadingConnectedUser(m, testCredentialsSplit) + mockEventLoopNoAction(m) return testNewUsers(t, m) } func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam] - // FIXME(conman): How to handle force upgrade? - // m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) + m.eventListener.EXPECT().ProvideChannel(events.UpgradeApplicationEvent) + m.eventListener.EXPECT().ProvideChannel(events.InternetOnEvent) users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true) @@ -256,38 +230,84 @@ func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam] return users } +func waitForEvents() { + // Wait for goroutine to add listener. + // E.g. calling login to invoke firstsync event. Functions can end sooner than + // goroutines call the listener mock. We need to wait a little bit before the end of + // the test to capture all event calls. This allows us to detect whether there were + // missing calls, or perhaps whether something was called too many times. + time.Sleep(100 * time.Millisecond) +} + func cleanUpUsersData(b *Users) { for _, user := range b.users { _ = user.clearStore() } } -func TestClearData(t *testing.T) { - m := initMocks(t) - defer m.ctrl.Finish() +func mockAddingConnectedUser(m mocks) { + gomock.InOrder( + // Mock of users.FinishLogin. + m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil), + m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil), + m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + m.credentialsStore.EXPECT().Add("user", "username", testAuthRefresh.UID, testAuthRefresh.RefreshToken, testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}).Return(testCredentials, nil), + m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), + ) - // m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) - // m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) + mockInitConnectedUser(m) +} - users := testNewUsersWithUsers(t, m) - defer cleanUpUsersData(users) +func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) { + authRefresh := &pmapi.AuthRefresh{ + UID: "uid", + AccessToken: "acc", + RefreshToken: "ref", + } - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me") - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me") - m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me") + gomock.InOrder( + // Mock of users.loadUsersFromCredentialsStore. + m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil), + m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, authRefresh, nil), + m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil), + ) - m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) - m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil) + mockInitConnectedUser(m) +} - m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) - m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil) +func mockInitConnectedUser(m mocks) { + // Mock of user initialisation. + m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()) + m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes() - m.locator.EXPECT().Clear() + // Mock of store initialisation. + gomock.InOrder( + m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), + m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), + m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + ) +} - require.NoError(t, users.ClearData()) +func mockLoadingDisconnectedUser(m mocks, creds *credentials.Credentials) { + gomock.InOrder( + // Mock of users.loadUsersFromCredentialsStore. + m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil), + m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient), + ) - waitForEvents() + mockInitDisconnectedUser(m) +} + +func mockInitDisconnectedUser(m mocks) { + gomock.InOrder( + // Mock of user initialisation. + m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()), + + // Mock of store initialisation for the unauthorized user. + m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")), + m.pmapiClient.EXPECT().Addresses().Return(nil), + ) } func mockEventLoopNoAction(m mocks) { @@ -297,19 +317,3 @@ func mockEventLoopNoAction(m mocks) { m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() } - -func mockConnectedUser(m mocks) { - gomock.InOrder( - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - - m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), - // m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil), - - m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil), - - // Set up mocks for store initialisation for the authorized user. - m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), - m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), - m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), - ) -} diff --git a/pkg/keychain/helper_darwin.go b/pkg/keychain/helper_darwin.go index fb39725a..6fada788 100644 --- a/pkg/keychain/helper_darwin.go +++ b/pkg/keychain/helper_darwin.go @@ -105,6 +105,10 @@ func (h *macOSHelper) Get(secretURL string) (string, string, error) { return "", "", err } + if len(results) == 0 { + return "", "", errors.New("no result") + } + if len(results) != 1 { return "", "", errors.New("ambiguous results") } diff --git a/pkg/listener/listener.go b/pkg/listener/listener.go index 17121b9c..10cb9dde 100644 --- a/pkg/listener/listener.go +++ b/pkg/listener/listener.go @@ -29,6 +29,7 @@ var log = logrus.WithField("pkg", "bridgeUtils/listener") //nolint[gochecknoglob // Listener has a list of channels watching for updates. type Listener interface { SetLimit(eventName string, limit time.Duration) + ProvideChannel(eventName string) <-chan string Add(eventName string, channel chan<- string) Remove(eventName string, channel chan<- string) Emit(eventName string, data string) @@ -69,6 +70,15 @@ func (l *listener) SetLimit(eventName string, limit time.Duration) { l.limits[eventName] = limit } +// ProvideChannel creates new channel, adds it to listener and sends to it +// bufferent events. +func (l *listener) ProvideChannel(eventName string) <-chan string { + ch := make(chan string) + l.Add(eventName, ch) + l.RetryEmit(eventName) + return ch +} + // Add adds an event listener. func (l *listener) Add(eventName string, channel chan<- string) { l.lock.Lock() diff --git a/pkg/message/build_fetch.go b/pkg/message/build_fetch.go index 8076b621..81a99829 100644 --- a/pkg/message/build_fetch.go +++ b/pkg/message/build_fetch.go @@ -97,7 +97,7 @@ func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachW } func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) { - msg, err := req.api.GetMessage(req.messageID) + msg, err := req.api.GetMessage(req.ctx, req.messageID) if err != nil { return nil, nil, err } @@ -109,7 +109,7 @@ func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, er } process := func(value interface{}) (interface{}, error) { - rc, err := req.api.GetAttachment(value.(string)) + rc, err := req.api.GetAttachment(req.ctx, value.(string)) if err != nil { return nil, err } diff --git a/pkg/message/build_framework_test.go b/pkg/message/build_framework_test.go index 49e76bef..02cbc3e8 100644 --- a/pkg/message/build_framework_test.go +++ b/pkg/message/build_framework_test.go @@ -43,10 +43,10 @@ func newTestFetcher( ) Fetcher { f := mocks.NewMockFetcher(m) - f.EXPECT().GetMessage(msg.ID).Return(msg, nil) + f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil) for i, att := range msg.Attachments { - f.EXPECT().GetAttachment(att.ID).Return(newTestReadCloser(attData[i]), nil) + f.EXPECT().GetAttachment(gomock.Any(), att.ID).Return(newTestReadCloser(attData[i]), nil) } f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(kr, nil) diff --git a/pkg/message/build_test.go b/pkg/message/build_test.go index 2dfe1600..bbf38365 100644 --- a/pkg/message/build_test.go +++ b/pkg/message/build_test.go @@ -1230,7 +1230,7 @@ func TestBuildFetchMessageFail(t *testing.T) { // Pretend the message cannot be fetched. f := mocks.NewMockFetcher(m) - f.EXPECT().GetMessage(msg.ID).Return(nil, errors.New("oops")) + f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(nil, errors.New("oops")) // The job should fail, returning an error and a nil result. res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() @@ -1251,8 +1251,8 @@ func TestBuildFetchAttachmentFail(t *testing.T) { // Pretend the attachment cannot be fetched. f := mocks.NewMockFetcher(m) - f.EXPECT().GetMessage(msg.ID).Return(msg, nil) - f.EXPECT().GetAttachment(msg.Attachments[0].ID).Return(nil, errors.New("oops")) + f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil) + f.EXPECT().GetAttachment(gomock.Any(), msg.Attachments[0].ID).Return(nil, errors.New("oops")) // The job should fail, returning an error and a nil result. res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() @@ -1272,7 +1272,7 @@ func TestBuildNoSuchKeyRing(t *testing.T) { // Pretend there is no available keyring. f := mocks.NewMockFetcher(m) - f.EXPECT().GetMessage(msg.ID).Return(msg, nil) + f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil) f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops")) res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() diff --git a/pkg/message/flags.go b/pkg/message/flags.go index b895c61a..5ed2c2c6 100644 --- a/pkg/message/flags.go +++ b/pkg/message/flags.go @@ -31,7 +31,7 @@ const ( // GetFlags returns imap flags from pmapi message attributes. func GetFlags(m *pmapi.Message) (flags []string) { - if m.Unread == 0 { + if !m.Unread { flags = append(flags, imap.SeenFlag) } if !m.Has(pmapi.FlagSent) && !m.Has(pmapi.FlagReceived) { @@ -68,11 +68,11 @@ func ParseFlags(m *pmapi.Message, flags []string) { m.Flags = pmapi.FlagReceived } - m.Unread = 1 + m.Unread = true for _, f := range flags { switch f { case imap.SeenFlag: - m.Unread = 0 + m.Unread = false case imap.DraftFlag: m.Flags &= ^pmapi.FlagSent m.Flags &= ^pmapi.FlagReceived diff --git a/pkg/message/mocks/mocks.go b/pkg/message/mocks/mocks.go index d55de7f2..0aa4b7ed 100644 --- a/pkg/message/mocks/mocks.go +++ b/pkg/message/mocks/mocks.go @@ -5,6 +5,7 @@ package mocks import ( + context "context" io "io" reflect "reflect" @@ -37,33 +38,33 @@ func (m *MockFetcher) EXPECT() *MockFetcherMockRecorder { } // GetAttachment mocks base method -func (m *MockFetcher) GetAttachment(arg0 string) (io.ReadCloser, error) { +func (m *MockFetcher) GetAttachment(arg0 context.Context, arg1 string) (io.ReadCloser, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAttachment", arg0) + ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1) ret0, _ := ret[0].(io.ReadCloser) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAttachment indicates an expected call of GetAttachment -func (mr *MockFetcherMockRecorder) GetAttachment(arg0 interface{}) *gomock.Call { +func (mr *MockFetcherMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0, arg1) } // GetMessage mocks base method -func (m *MockFetcher) GetMessage(arg0 string) (*pmapi.Message, error) { +func (m *MockFetcher) GetMessage(arg0 context.Context, arg1 string) (*pmapi.Message, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMessage", arg0) + ret := m.ctrl.Call(m, "GetMessage", arg0, arg1) ret0, _ := ret[0].(*pmapi.Message) ret1, _ := ret[1].(error) return ret0, ret1 } // GetMessage indicates an expected call of GetMessage -func (mr *MockFetcherMockRecorder) GetMessage(arg0 interface{}) *gomock.Call { +func (mr *MockFetcherMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0, arg1) } // KeyRingForAddressID mocks base method diff --git a/pkg/mime/encoding.go b/pkg/mime/encoding.go index 5e0c7a9f..5caac7af 100644 --- a/pkg/mime/encoding.go +++ b/pkg/mime/encoding.go @@ -191,7 +191,7 @@ func DecodeHeader(raw string) (decoded string, err error) { return } -// EncodeHeader using quoted printable and utf8 +// EncodeHeader using quoted printable and utf8. func EncodeHeader(s string) string { return mime.QEncoding.Encode("utf-8", s) } diff --git a/pkg/mime/encoding_test.go b/pkg/mime/encoding_test.go index cc69177b..8bf93f89 100644 --- a/pkg/mime/encoding_test.go +++ b/pkg/mime/encoding_test.go @@ -19,7 +19,6 @@ package pmmime import ( "bytes" - //"fmt" "strings" "testing" diff --git a/pkg/pmapi/addresses.go b/pkg/pmapi/addresses.go index fff80821..28dbff69 100644 --- a/pkg/pmapi/addresses.go +++ b/pkg/pmapi/addresses.go @@ -32,12 +32,6 @@ const ( EnabledAddress ) -// Address receive values. -const ( - CannotReceive = iota - CanReceive -) - // Address HasKeys values. const ( MissingKeys = iota @@ -66,7 +60,7 @@ type Address struct { DomainID string Email string Send int - Receive int + Receive Boolean Status int Order int `json:",omitempty"` Type int @@ -103,7 +97,7 @@ func (l AddressList) AllEmails() (addresses []string) { // ActiveEmails returns only active emails. func (l AddressList) ActiveEmails() (addresses []string) { for _, a := range l { - if a.Receive == CanReceive { + if a.Receive { addresses = append(addresses, a.Email) } } @@ -175,8 +169,19 @@ func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err e return res.Addresses, nil } -func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) (err error) { - panic("TODO") +func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) error { + if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(&struct { + AddressIDs []string + }{ + AddressIDs: addressIDs, + }).Put("/addresses/order") + }); err != nil { + return err + } + + _, err := c.UpdateUser(ctx) + return err } // Addresses returns the addresses stored in the client object itself rather than fetching from the API. @@ -185,24 +190,22 @@ func (c *client) Addresses() AddressList { } // unlockAddresses unlocks all keys for all addresses of current user. -func (c *client) unlockAddress(passphrase []byte, address *Address) (err error) { +func (c *client) unlockAddress(passphrase []byte, address *Address) error { if address == nil { return errors.New("address data is missing") } if address.HasKeys == MissingKeys { - return + return nil } - var kr *crypto.KeyRing - - if kr, err = address.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil { - return + kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing) + if err != nil { + return err } c.addrKeyRing[address.ID] = kr - - return + return nil } func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) { diff --git a/pkg/pmapi/addresses_test.go b/pkg/pmapi/addresses_test.go index e8724642..4b8bfc91 100644 --- a/pkg/pmapi/addresses_test.go +++ b/pkg/pmapi/addresses_test.go @@ -20,6 +20,8 @@ package pmapi import ( "net/http" "testing" + + r "github.com/stretchr/testify/require" ) var testAddressList = AddressList{ @@ -46,39 +48,29 @@ var testAddressList = AddressList{ }, } -func routeGetAddresses(tb testing.TB, w http.ResponseWriter, r *http.Request) string { - Ok(tb, checkMethodAndPath(r, "GET", "/addresses")) - Ok(tb, isAuthReq(r, testUID, testAccessToken)) +func routeGetAddresses(tb testing.TB, w http.ResponseWriter, req *http.Request) string { + r.NoError(tb, checkMethodAndPath(req, "GET", "/addresses")) + r.NoError(tb, isAuthReq(req, testUID, testAccessToken)) return "addresses/get_response.json" } func TestAddressList(t *testing.T) { input := "1" addr := testAddressList.ByID(input) - if addr != testAddressList[0] { - t.Errorf("ById(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[0], addr) - } + r.Equal(t, testAddressList[0], addr) input = "42" addr = testAddressList.ByID(input) - if addr != nil { - t.Errorf("ById expected nil for %s but have : %v\n", input, addr) - } + r.Nil(t, addr) input = "root@protonmail.com" addr = testAddressList.ByEmail(input) - if addr != testAddressList[2] { - t.Errorf("ByEmail(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[2], addr) - } + r.Equal(t, testAddressList[2], addr) input = "idontexist@protonmail.com" addr = testAddressList.ByEmail(input) - if addr != nil { - t.Errorf("ByEmail expected nil for %s but have : %v\n", input, addr) - } + r.Nil(t, addr) addr = testAddressList.Main() - if addr != testAddressList[1] { - t.Errorf("Main() expected:\n%v\n but have:\n%v\n", testAddressList[1], addr) - } + r.Equal(t, testAddressList[1], addr) } diff --git a/pkg/pmapi/attachments.go b/pkg/pmapi/attachments.go index 5eb0a17a..adb0d08d 100644 --- a/pkg/pmapi/attachments.go +++ b/pkg/pmapi/attachments.go @@ -23,7 +23,6 @@ import ( "encoding/json" "fmt" "io" - "mime/multipart" "net/textproto" "github.com/ProtonMail/gopenpgp/v2/crypto" @@ -138,44 +137,6 @@ func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io. return signAttachment(kr, att) } -func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) { - // Create metadata fields. - if err = w.WriteField("Filename", att.Name); err != nil { - return - } - if err = w.WriteField("MessageID", att.MessageID); err != nil { - return - } - if err = w.WriteField("MIMEType", att.MIMEType); err != nil { - return - } - - if err = w.WriteField("ContentID", att.ContentID); err != nil { - return - } - - // And send attachment data. - ff, err := w.CreateFormFile("DataPacket", "DataPacket.pgp") - if err != nil { - return - } - if _, err = io.Copy(ff, r); err != nil { - return - } - - // And send attachment data. - sigff, err := w.CreateFormFile("Signature", "Signature.pgp") - if err != nil { - return - } - - if _, err = io.Copy(sigff, sig); err != nil { - return - } - - return err -} - // CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID. // // The returned created attachment contains the new attachment ID and its size. diff --git a/pkg/pmapi/attachments_test.go b/pkg/pmapi/attachments_test.go index 95ec24cc..0791d3a4 100644 --- a/pkg/pmapi/attachments_test.go +++ b/pkg/pmapi/attachments_test.go @@ -28,13 +28,13 @@ import ( "mime/multipart" "net/http" "net/textproto" - "reflect" "strings" "testing" pmmime "github.com/ProtonMail/proton-bridge/pkg/mime" - "github.com/stretchr/testify/assert" + a "github.com/stretchr/testify/assert" + r "github.com/stretchr/testify/require" ) var testAttachment = &Attachment{ @@ -77,65 +77,40 @@ const testCreateAttachmentBody = `{ "Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="} }` -const testDeleteAttachmentBody = `{ - "Code": 1000 -}` - func TestAttachment_UnmarshalJSON(t *testing.T) { att := new(Attachment) - if err := json.Unmarshal([]byte(testAttachmentJSON), att); err != nil { - t.Fatal("Expected no error while unmarshaling JSON, got:", err) - } + err := json.Unmarshal([]byte(testAttachmentJSON), att) + r.NoError(t, err) att.MessageID = testAttachment.MessageID // This isn't in the JSON object - if !reflect.DeepEqual(testAttachment, att) { - t.Errorf("Invalid attachment: expected %+v but got %+v", testAttachment, att) - } + r.Equal(t, testAttachment, att) } func TestClient_CreateAttachment(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments")) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/attachments")) - contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil { - t.Error("Expected no error while parsing request content type, got:", err) - } - if contentType != "multipart/form-data" { - t.Errorf("Invalid request content type: expected %v but got %v", "multipart/form-data", contentType) - } + contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type")) + r.NoError(t, err) + r.Equal(t, "multipart/form-data", contentType) - mr := multipart.NewReader(r.Body, params["boundary"]) + mr := multipart.NewReader(req.Body, params["boundary"]) form, err := mr.ReadForm(10 * 1024) - if err != nil { - t.Error("Expected no error while parsing request form, got:", err) - } - defer Ok(t, form.RemoveAll()) + r.NoError(t, err) + defer r.NoError(t, form.RemoveAll()) - if form.Value["Filename"][0] != testAttachment.Name { - t.Errorf("Invalid attachment filename: expected %v but got %v", testAttachment.Name, form.Value["Filename"][0]) - } - if form.Value["MessageID"][0] != testAttachment.MessageID { - t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MessageID, form.Value["MessageID"][0]) - } - if form.Value["MIMEType"][0] != testAttachment.MIMEType { - t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MIMEType, form.Value["MIMEType"][0]) - } + r.Equal(t, testAttachment.Name, form.Value["Filename"][0]) + r.Equal(t, testAttachment.MessageID, form.Value["MessageID"][0]) + r.Equal(t, testAttachment.MIMEType, form.Value["MIMEType"][0]) dataFile, err := form.File["DataPacket"][0].Open() - if err != nil { - t.Error("Expected no error while opening packets file, got:", err) - } - defer Ok(t, dataFile.Close()) + r.NoError(t, err) + defer r.NoError(t, dataFile.Close()) b, err := ioutil.ReadAll(dataFile) - if err != nil { - t.Error("Expected no error while reading packets file, got:", err) - } - if string(b) != testAttachmentCleartext { - t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b)) - } + r.NoError(t, err) + r.Equal(t, testAttachmentCleartext, string(b)) w.Header().Set("Content-Type", "application/json") @@ -143,50 +118,39 @@ func TestClient_CreateAttachment(t *testing.T) { })) defer s.Close() - r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted - created, err := c.CreateAttachment(context.TODO(), testAttachment, r, strings.NewReader("")) - if err != nil { - t.Fatal("Expected no error while creating attachment, got:", err) - } + reader := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted + created, err := c.CreateAttachment(context.Background(), testAttachment, reader, strings.NewReader("")) + r.NoError(t, err) - if created.ID != testAttachment.ID { - t.Errorf("Invalid attachment id: expected %v but got %v", testAttachment.ID, created.ID) - } + r.Equal(t, testAttachment.ID, created.ID) } func TestClient_GetAttachment(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID)) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "GET", "/mail/v4/attachments/"+testAttachment.ID)) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testAttachmentCleartext) })) defer s.Close() - r, err := c.GetAttachment(context.TODO(), testAttachment.ID) - if err != nil { - t.Fatal("Expected no error while getting attachment, got:", err) - } - defer r.Close() //nolint[errcheck] + att, err := c.GetAttachment(context.Background(), testAttachment.ID) + r.NoError(t, err) + defer att.Close() //nolint[errcheck] // In reality, r contains encrypted data - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal("Expected no error while reading attachment, got:", err) - } + b, err := ioutil.ReadAll(att) + r.NoError(t, err) - if string(b) != testAttachmentCleartext { - t.Errorf("Invalid attachment data: expected %q but got %q", testAttachmentCleartext, string(b)) - } + r.Equal(t, testAttachmentCleartext, string(b)) } func TestAttachment_Encrypt(t *testing.T) { data := bytes.NewBufferString(testAttachmentCleartext) r, err := testAttachment.Encrypt(testPublicKeyRing, data) - assert.Nil(t, err) + a.Nil(t, err) b, err := ioutil.ReadAll(r) - assert.Nil(t, err) + a.Nil(t, err) // Result is always different, so the best way is to test it by decrypting again. // Another test for decrypting will help us to be sure it's working. @@ -202,8 +166,8 @@ func TestAttachment_Decrypt(t *testing.T) { func decryptAndCheck(t *testing.T, data io.Reader) { r, err := testAttachment.Decrypt(data, testPrivateKeyRing) - assert.Nil(t, err) + a.Nil(t, err) b, err := ioutil.ReadAll(r) - assert.Nil(t, err) - assert.Equal(t, testAttachmentCleartext, string(b)) + a.Nil(t, err) + a.Equal(t, testAttachmentCleartext, string(b)) } diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index 011481eb..8773278d 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( @@ -6,15 +23,117 @@ import ( "encoding/base64" "errors" "io" + "net/http" "time" "github.com/go-resty/resty/v2" ) -func (c *client) Auth2FA(ctx context.Context, req Auth2FAReq) error { - if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { - return r.SetBody(req).Post("/auth/2fa") +type AuthModulus struct { + Modulus string + ModulusID string +} + +type GetAuthInfoReq struct { + Username string +} + +type AuthInfo struct { + Version int + Modulus string + ServerEphemeral string + Salt string + SRPSession string +} + +type TwoFAInfo struct { + Enabled TwoFAStatus +} + +func (twoFAInfo TwoFAInfo) hasTwoFactor() bool { + return twoFAInfo.Enabled > 0 +} + +type TwoFAStatus int + +const ( + TwoFADisabled TwoFAStatus = iota + TOTPEnabled + U2FEnabled + TOTPAndU2FEnabled +) + +type PasswordMode int + +const ( + OnePasswordMode PasswordMode = iota + 1 + TwoPasswordMode +) + +type AuthReq struct { + Username string + ClientProof string + ClientEphemeral string + SRPSession string +} + +type AuthRefresh struct { + UID string + AccessToken string + RefreshToken string + ExpiresIn int64 + Scopes []string +} + +type Auth struct { + AuthRefresh + + UserID string + ServerProof string + PasswordMode PasswordMode + TwoFA *TwoFAInfo `json:"2FA,omitempty"` +} + +func (a Auth) HasTwoFactor() bool { + if a.TwoFA == nil { + return false + } + return a.TwoFA.hasTwoFactor() +} + +func (a Auth) HasMailboxPassword() bool { + return a.PasswordMode == TwoPasswordMode +} + +type auth2FAReq struct { + TwoFactorCode string +} + +type authRefreshReq struct { + UID string + RefreshToken string + ResponseType string + GrantType string + RedirectURI string + State string +} + +func (c *client) Auth2FA(ctx context.Context, twoFactorCode string) error { + // 2FA is called during login procedure during which refresh token should + // be valid, therefore, no refresh is needed if there is an error. + ctx = ContextWithoutAuthRefresh(ctx) + + if res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(auth2FAReq{TwoFactorCode: twoFactorCode}).Post("/auth/2fa") }); err != nil { + if res != nil { + switch res.StatusCode() { + case http.StatusUnauthorized: + return ErrBad2FACode + case http.StatusUnprocessableEntity: + return ErrBad2FACodeTryAgain + } + } return err } @@ -29,9 +148,7 @@ func (c *client) AuthDelete(ctx context.Context) error { } c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{} - - // FIXME(conman): should we perhaps signal via AuthHandler that the auth was deleted? - + c.sendAuthRefresh(nil) return nil } @@ -54,7 +171,7 @@ func (c *client) AuthSalt(ctx context.Context) (string, error) { return "", errors.New("no matching salt found") } -func (c *client) AddAuthHandler(handler AuthHandler) { +func (c *client) AddAuthRefreshHandler(handler AuthRefreshHandler) { c.authHandlers = append(c.authHandlers, handler) } @@ -62,23 +179,35 @@ func (c *client) authRefresh(ctx context.Context) error { c.authLocker.Lock() defer c.authLocker.Unlock() - auth, err := c.req.authRefresh(ctx, c.uid, c.ref) + if c.ref == "" { + return ErrUnauthorized + } + + auth, err := c.manager.authRefresh(ctx, c.uid, c.ref) if err != nil { + if err != ErrNoConnection { + c.sendAuthRefresh(nil) + } return err } c.acc = auth.AccessToken c.ref = auth.RefreshToken + c.exp = expiresIn(auth.ExpiresIn) - for _, handler := range c.authHandlers { - if err := handler(auth); err != nil { - return err - } - } - + c.sendAuthRefresh(auth) return nil } +func (c *client) sendAuthRefresh(auth *AuthRefresh) { + for _, handler := range c.authHandlers { + go handler(auth) + } + if auth == nil { + c.authHandlers = []AuthRefreshHandler{} + } +} + func randomString(length int) string { noise := make([]byte, length) diff --git a/pkg/pmapi/auth_test.go b/pkg/pmapi/auth_test.go index b989488d..5ccb254b 100644 --- a/pkg/pmapi/auth_test.go +++ b/pkg/pmapi/auth_test.go @@ -1,22 +1,40 @@ -package pmapi_test +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi import ( "context" "encoding/json" - "errors" "net/http" "net/http/httptest" "testing" "time" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" + a "github.com/stretchr/testify/assert" + r "github.com/stretchr/testify/require" ) func TestAutomaticAuthRefresh(t *testing.T) { - var wantAuth = &pmapi.Auth{ + var wantAuthRefresh = &AuthRefresh{ UID: "testUID", AccessToken: "testAcc", RefreshToken: "testRef", + ExpiresIn: 100, } mux := http.NewServeMux() @@ -24,7 +42,7 @@ func TestAutomaticAuthRefresh(t *testing.T) { mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(wantAuth); err != nil { + if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil { panic(err) } }) @@ -35,28 +53,28 @@ func TestAutomaticAuthRefresh(t *testing.T) { ts := httptest.NewServer(mux) - var gotAuth *pmapi.Auth + var gotAuthRefresh *AuthRefresh - // Create a new client. - c := pmapi.New(pmapi.Config{HostURL: ts.URL}). + c := New(Config{HostURL: ts.URL}). NewClient("uid", "acc", "ref", time.Now().Add(-time.Second)) - // Register an auth handler. - c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil }) + c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth }) // Make a request with an access token that already expired one second ago. - if _, err := c.GetAddresses(context.Background()); err != nil { - t.Fatal("got unexpected error", err) - } + _, err := c.GetAddresses(context.Background()) + r.NoError(t, err) // The auth callback should have been called. - if *gotAuth != *wantAuth { - t.Fatal("got unexpected auth", gotAuth) - } + a.Equal(t, *wantAuthRefresh, *gotAuthRefresh) + + cl := c.(*client) //nolint[forcetypeassert] we want to panic here + a.Equal(t, wantAuthRefresh.AccessToken, cl.acc) + a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref) + a.WithinDuration(t, expiresIn(100), cl.exp, time.Second) } func Test401AuthRefresh(t *testing.T) { - var wantAuth = &pmapi.Auth{ + var wantAuthRefresh = &AuthRefresh{ UID: "testUID", AccessToken: "testAcc", RefreshToken: "testRef", @@ -67,7 +85,7 @@ func Test401AuthRefresh(t *testing.T) { mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(wantAuth); err != nil { + if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil { panic(err) } }) @@ -86,24 +104,21 @@ func Test401AuthRefresh(t *testing.T) { ts := httptest.NewServer(mux) - var gotAuth *pmapi.Auth + var gotAuthRefresh *AuthRefresh // Create a new client. - c := pmapi.New(pmapi.Config{HostURL: ts.URL}). + c := New(Config{HostURL: ts.URL}). NewClient("uid", "acc", "ref", time.Now().Add(time.Hour)) // Register an auth handler. - c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil }) + c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth }) // The first request will fail with 401, triggering a refresh and retry. - if _, err := c.GetAddresses(context.Background()); err != nil { - t.Fatal("got unexpected error", err) - } + _, err := c.GetAddresses(context.Background()) + r.NoError(t, err) // The auth callback should have been called. - if *gotAuth != *wantAuth { - t.Fatal("got unexpected auth", gotAuth) - } + r.Equal(t, *wantAuthRefresh, *gotAuthRefresh) } func Test401RevokedAuth(t *testing.T) { @@ -119,17 +134,57 @@ func Test401RevokedAuth(t *testing.T) { ts := httptest.NewServer(mux) - c := pmapi.New(pmapi.Config{HostURL: ts.URL}). + c := New(Config{HostURL: ts.URL}). NewClient("uid", "acc", "ref", time.Now().Add(time.Hour)) // The request will fail with 401, triggering a refresh. // The retry will also fail with 401, returning an error. _, err := c.GetAddresses(context.Background()) - if err == nil { - t.Fatal("expected error, instead got", err) - } - - if !errors.Is(err, pmapi.ErrUnauthorized) { - t.Fatal("expected error to be ErrUnauthorized, instead got", err) - } + r.EqualError(t, err, ErrUnauthorized.Error()) +} + +func TestAuth2FA(t *testing.T) { + twoFACode := "code" + + finish, c := newTestClientCallbacks(t, + func(tb testing.TB, w http.ResponseWriter, req *http.Request) string { + r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa")) + + var twoFAreq auth2FAReq + r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq)) + r.Equal(t, twoFAreq.TwoFactorCode, twoFACode) + + return "/auth/2fa/post_response.json" + }, + ) + defer finish() + + err := c.Auth2FA(context.Background(), twoFACode) + r.NoError(t, err) +} + +func TestAuth2FA_Fail(t *testing.T) { + finish, c := newTestClientCallbacks(t, + func(tb testing.TB, w http.ResponseWriter, req *http.Request) string { + r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa")) + return "/auth/2fa/post_401_bad_password.json" + }, + ) + defer finish() + + err := c.Auth2FA(context.Background(), "code") + r.Equal(t, ErrBad2FACode, err) +} + +func TestAuth2FA_Retry(t *testing.T) { + finish, c := newTestClientCallbacks(t, + func(tb testing.TB, w http.ResponseWriter, req *http.Request) string { + r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa")) + return "/auth/2fa/post_422_bad_password.json" + }, + ) + defer finish() + + err := c.Auth2FA(context.Background(), "code") + r.Equal(t, ErrBad2FACodeTryAgain, err) } diff --git a/pkg/pmapi/auth_types.go b/pkg/pmapi/auth_types.go deleted file mode 100644 index 18faa058..00000000 --- a/pkg/pmapi/auth_types.go +++ /dev/null @@ -1,72 +0,0 @@ -package pmapi - -type AuthModulus struct { - Modulus string - ModulusID string -} - -type GetAuthInfoReq struct { - Username string -} - -type AuthInfo struct { - Version int - Modulus string - ServerEphemeral string - Salt string - SRPSession string -} - -type TwoFAInfo struct { - Enabled TwoFAStatus -} - -type TwoFAStatus int - -const ( - TwoFADisabled TwoFAStatus = iota - TOTPEnabled - // TODO: Support UTF -) - -type PasswordMode int - -const ( - OnePasswordMode PasswordMode = iota + 1 - TwoPasswordMode -) - -type AuthReq struct { - Username string - ClientProof string - ClientEphemeral string - SRPSession string -} - -type Auth struct { - UserID string - - UID string - AccessToken string - RefreshToken string - ExpiresIn int64 - - Scope string - ServerProof string - - TwoFA TwoFAInfo `json:"2FA"` - PasswordMode PasswordMode -} - -type Auth2FAReq struct { - TwoFactorCode string -} - -type AuthRefreshReq struct { - UID string - RefreshToken string - ResponseType string - GrantType string - RedirectURI string - State string -} diff --git a/pkg/pmapi/boolean.go b/pkg/pmapi/boolean.go new file mode 100644 index 00000000..d4156d3e --- /dev/null +++ b/pkg/pmapi/boolean.go @@ -0,0 +1,41 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import "encoding/json" + +type Boolean bool + +func (boolean *Boolean) UnmarshalJSON(b []byte) error { + var value int + err := json.Unmarshal(b, &value) + if err != nil { + return err + } + + *boolean = Boolean(value == 1) + return nil +} + +func (boolean Boolean) MarshalJSON() ([]byte, error) { + var value int + if boolean { + value = 1 + } + return json.Marshal(value) +} diff --git a/pkg/pmapi/client.go b/pkg/pmapi/client.go index b73dca7b..edbae830 100644 --- a/pkg/pmapi/client.go +++ b/pkg/pmapi/client.go @@ -25,15 +25,14 @@ import ( "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/go-resty/resty/v2" - "github.com/pkg/errors" ) // client is a client of the protonmail API. It implements the Client interface. type client struct { - req requester + manager clientManager uid, acc, ref string - authHandlers []AuthHandler + authHandlers []AuthRefreshHandler authLocker sync.RWMutex user *User @@ -45,9 +44,9 @@ type client struct { exp time.Time } -func newClient(req requester, uid string) *client { +func newClient(manager clientManager, uid string) *client { return &client{ - req: req, + manager: manager, uid: uid, addrKeyRing: make(map[string]*crypto.KeyRing), keyRingLock: &sync.RWMutex{}, @@ -63,7 +62,7 @@ func (c *client) withAuth(acc, ref string, exp time.Time) *client { } func (c *client) r(ctx context.Context) (*resty.Request, error) { - r := c.req.r(ctx) + r := c.manager.r(ctx) if c.uid != "" { r.SetHeader("x-pm-uid", c.uid) @@ -91,30 +90,23 @@ func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Respons return nil, err } - res, err := wrapRestyError(fn(r)) + res, err := wrapNoConnection(fn(r)) if err != nil { if res.StatusCode() != http.StatusUnauthorized { - return nil, err + // Return also response so caller has more options to decide what to do. + return res, err } - if err := c.authRefresh(ctx); err != nil { - return nil, err + if !isAuthRefreshDisabled(ctx) { + if err := c.authRefresh(ctx); err != nil { + return nil, err + } + + return wrapNoConnection(fn(r)) } - return wrapRestyError(fn(r)) + return res, err } return res, nil } - -func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) { - if err, ok := err.(*resty.ResponseError); ok { - return res, err - } - - if res.RawResponse != nil { - return res, err - } - - return res, errors.Wrap(ErrNoConnection, err.Error()) -} diff --git a/pkg/pmapi/client_keys.go b/pkg/pmapi/client_keys.go index 56f7ee5c..59d36b15 100644 --- a/pkg/pmapi/client_keys.go +++ b/pkg/pmapi/client_keys.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( @@ -12,8 +29,6 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) { c.keyRingLock.Lock() defer c.keyRingLock.Unlock() - // FIXME(conman): Should this be done as part of NewClient somehow? - return c.unlock(ctx, passphrase) } @@ -65,6 +80,15 @@ func (c *client) clearKeys() { } func (c *client) IsUnlocked() bool { - // FIXME(conman): Better way to check? we don't currently check address keys. - return c.userKeyRing != nil + if c.userKeyRing == nil { + return false + } + + for _, address := range c.addresses { + if address.HasKeys != MissingKeys && c.addrKeyRing[address.ID] == nil { + return false + } + } + + return true } diff --git a/pkg/pmapi/client_types.go b/pkg/pmapi/client_types.go index 611e99d6..bf87bd60 100644 --- a/pkg/pmapi/client_types.go +++ b/pkg/pmapi/client_types.go @@ -27,10 +27,10 @@ import ( // Client defines the interface of a PMAPI client. type Client interface { - Auth2FA(context.Context, Auth2FAReq) error + Auth2FA(context.Context, string) error AuthSalt(ctx context.Context) (string, error) AuthDelete(context.Context) error - AddAuthHandler(AuthHandler) + AddAuthRefreshHandler(AuthRefreshHandler) CurrentUser(ctx context.Context) (*User, error) UpdateUser(ctx context.Context) (*User, error) @@ -75,9 +75,9 @@ type Client interface { GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error) } -type AuthHandler func(*Auth) error +type AuthRefreshHandler func(*AuthRefresh) -type requester interface { +type clientManager interface { r(context.Context) *resty.Request - authRefresh(context.Context, string, string) (*Auth, error) + authRefresh(context.Context, string, string) (*AuthRefresh, error) } diff --git a/pkg/pmapi/config.go b/pkg/pmapi/config.go index e87b4d00..65da20f1 100644 --- a/pkg/pmapi/config.go +++ b/pkg/pmapi/config.go @@ -1,11 +1,72 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi +import ( + "runtime" + "strings" +) + type Config struct { - HostURL string + // HostURL is the base URL of API. + HostURL string + + // AppVersion sets version to headers of each request. AppVersion string + + // UserAgent sets user agent to headers of each request. + // Used only if GetUserAgent is not set. + UserAgent string + + // GetUserAgent is dynamic version of UserAgent. + // Overrides UserAgent. + GetUserAgent func() string + + // UpgradeApplicationHandler is used to notify when there is a force upgrade. + UpgradeApplicationHandler func() + + // TLSIssueHandler is used to notify when there is a TLS issue. + TLSIssueHandler func() } -var DefaultConfig = Config{ - HostURL: "https://api.protonmail.ch", - AppVersion: "Other", +func NewConfig(appVersionName, appVersion string) Config { + return Config{ + HostURL: getRootURL(), + AppVersion: getAPIOS() + strings.Title(appVersionName) + "_" + appVersion, + } +} + +func (c *Config) getUserAgent() string { + if c.GetUserAgent == nil { + return c.UserAgent + } + return c.GetUserAgent() +} + +// getAPIOS returns actual operating system. +func getAPIOS() string { + switch os := runtime.GOOS; os { + case "darwin": // nolint: goconst + return "macOS" + case "linux": + return "Linux" + case "windows": + return "Windows" + } + return "Linux" } diff --git a/pkg/pmapi/config_default.go b/pkg/pmapi/config_default.go new file mode 100644 index 00000000..5e722c03 --- /dev/null +++ b/pkg/pmapi/config_default.go @@ -0,0 +1,35 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +// +build !build_qa + +package pmapi + +import ( + "net/http" +) + +func getRootURL() string { + return "https://api.protonmail.ch" +} + +func newProxyDialerAndTransport(cfg Config) (*ProxyTLSDialer, http.RoundTripper) { + basicDialer := NewBasicTLSDialer(cfg) + pinningDialer := NewPinningTLSDialer(cfg, basicDialer) + proxyDialer := NewProxyTLSDialer(cfg, pinningDialer) + return proxyDialer, CreateTransportWithDialer(proxyDialer) +} diff --git a/pkg/pmapi/config_qa.go b/pkg/pmapi/config_qa.go new file mode 100644 index 00000000..618fa690 --- /dev/null +++ b/pkg/pmapi/config_qa.go @@ -0,0 +1,48 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +// +build build_qa + +package pmapi + +import ( + "crypto/tls" + "net/http" + "os" + "strings" +) + +func getRootURL() string { + // This config allows to dynamically change ROOT URL. + url := os.Getenv("PMAPI_ROOT_URL") + if strings.HasPrefix(url, "http") { + return url + } + if url != "" { + return "https://" + url + } + return "https://api.protonmail.ch" +} + +func newProxyDialerAndTransport(cfg Config) (*ProxyTLSDialer, http.RoundTripper) { + transport := CreateTransportWithDialer(NewBasicTLSDialer(cfg)) + + // TLS certificate of testing environment might be self-signed. + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + return nil, transport +} diff --git a/pkg/pmapi/contacts.go b/pkg/pmapi/contacts.go index de1de029..0bfd7584 100644 --- a/pkg/pmapi/contacts.go +++ b/pkg/pmapi/contacts.go @@ -129,11 +129,14 @@ func (c *client) GetContactEmailByEmail(ctx context.Context, email string, page } if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { - return r.SetQueryParams(map[string]string{ - "Email": email, - "Page": strconv.Itoa(page), - "PageSize": strconv.Itoa(pageSize), - }).SetResult(&res).Get("/contacts/v4") + r = r.SetQueryParams(map[string]string{ + "Email": email, + "Page": strconv.Itoa(page), + }) + if pageSize != 0 { + r.SetQueryParam("PageSize", strconv.Itoa(pageSize)) + } + return r.SetResult(&res).Get("/contacts/v4") }); err != nil { return nil, err } diff --git a/pkg/pmapi/contacts_test.go b/pkg/pmapi/contacts_test.go index 5ac4f05f..4d158edf 100644 --- a/pkg/pmapi/contacts_test.go +++ b/pkg/pmapi/contacts_test.go @@ -24,7 +24,7 @@ import ( "reflect" "testing" - "github.com/stretchr/testify/assert" + r "github.com/stretchr/testify/require" ) var ( @@ -106,19 +106,16 @@ var testGetContactByID = Contact{ } func TestContact_GetContactById(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "GET", "/contacts/v4/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "GET", "/contacts/v4/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testGetContactByIDResponseBody) })) defer s.Close() - contact, err := c.GetContactByID(context.TODO(), "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==") - if err != nil { - t.Fatal("Expected no error while getting contacts, got:", err) - } + contact, err := c.GetContactByID(context.Background(), "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==") + r.NoError(t, err) if !reflect.DeepEqual(contact, testGetContactByID) { t.Fatalf("Invalid got contact: expected %+v, got %+v", testGetContactByID, contact) @@ -160,24 +157,24 @@ var testCardsCleartext = []Card{ } func TestClient_Encrypt(t *testing.T) { - c := newClient(newManager(DefaultConfig), "") + c := newClient(newManager(Config{}), "") c.userKeyRing = testPrivateKeyRing cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext) - assert.Nil(t, err) + r.Nil(t, err) // Result is always different, so the best way is to test it by decrypting again. // Another test for decrypting will help us to be sure it's working. cardCleartext, err := c.DecryptAndVerifyCards(cardEncrypted) - assert.Nil(t, err) - assert.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data) + r.Nil(t, err) + r.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data) } func TestClient_Decrypt(t *testing.T) { - c := newClient(newManager(DefaultConfig), "") + c := newClient(newManager(Config{}), "") c.userKeyRing = testPrivateKeyRing cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted) - assert.Nil(t, err) - assert.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data) + r.Nil(t, err) + r.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data) } diff --git a/pkg/pmapi/context.go b/pkg/pmapi/context.go new file mode 100644 index 00000000..bba58940 --- /dev/null +++ b/pkg/pmapi/context.go @@ -0,0 +1,54 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "context" +) + +type pmapiContextKey string + +const ( + retryContextKey = pmapiContextKey("retry") + retryDisabled = "disabled" + + authRefreshContextKey = pmapiContextKey("authRefresh") + authRefreshDisabled = "disabled" +) + +func ContextWithoutRetry(parent context.Context) context.Context { + return context.WithValue(parent, retryContextKey, retryDisabled) +} + +func isRetryDisabled(ctx context.Context) bool { + if v := ctx.Value(retryContextKey); v != nil { + return v == retryDisabled + } + return false +} + +func ContextWithoutAuthRefresh(parent context.Context) context.Context { + return context.WithValue(parent, authRefreshContextKey, authRefreshDisabled) +} + +func isAuthRefreshDisabled(ctx context.Context) bool { + if v := ctx.Value(authRefreshContextKey); v != nil { + return v == authRefreshDisabled + } + return false +} diff --git a/pkg/pmapi/data_test.go b/pkg/pmapi/data_test.go index 29a1ccd0..7e7f780a 100644 --- a/pkg/pmapi/data_test.go +++ b/pkg/pmapi/data_test.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import "github.com/ProtonMail/gopenpgp/v2/crypto" @@ -8,9 +25,6 @@ var testIdentity = &crypto.Identity{ } const ( - testUsername = "jason" - testAPIPassword = "apple" - testUID = "729ad6012421d67ad26950dc898bebe3a6e3caa2" //nolint[gosec] testAccessToken = "de0423049b44243afeec7d9c1d99be7b46da1e8a" //nolint[gosec] testAccessTokenOld = "feb3159ac63fb05119bcf4480d939278aa746926" //nolint[gosec] diff --git a/pkg/pmapi/dialer_basic.go b/pkg/pmapi/dialer_basic.go new file mode 100644 index 00000000..8a58c9e5 --- /dev/null +++ b/pkg/pmapi/dialer_basic.go @@ -0,0 +1,76 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge.Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "crypto/tls" + "net" + "net/http" + "time" +) + +type TLSDialer interface { + DialTLS(network, address string) (conn net.Conn, err error) +} + +// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections. +func CreateTransportWithDialer(dialer TLSDialer) *http.Transport { + return &http.Transport{ + DialTLS: dialer.DialTLS, + + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: 100, + IdleConnTimeout: 5 * time.Minute, + ExpectContinueTimeout: 500 * time.Millisecond, + + // GODT-126: this was initially 10s but logs from users showed a significant number + // were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect. + // Bumping to 30s for now to avoid this problem. + ResponseHeaderTimeout: 30 * time.Second, + + // If we allow up to 30 seconds for response headers, it is reasonable to allow up + // to 30 seconds for the TLS handshake to take place. + TLSHandshakeTimeout: 30 * time.Second, + } +} + +// BasicTLSDialer implements TLSDialer. +type BasicTLSDialer struct { + cfg Config +} + +// NewBasicTLSDialer returns a new BasicTLSDialer. +func NewBasicTLSDialer(cfg Config) *BasicTLSDialer { + return &BasicTLSDialer{ + cfg: cfg, + } +} + +// DialTLS returns a connection to the given address using the given network. +func (d *BasicTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) { + dialer := &net.Dialer{Timeout: 30 * time.Second} // Alternative Routes spec says this should be a 30s timeout. + + var tlsConfig *tls.Config + + // If we are not dialing the standard API then we should skip cert verification checks. + if address != d.cfg.HostURL { + tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec] + } + + return tls.DialWithDialer(dialer, network, address, tlsConfig) +} diff --git a/pkg/pmapi/dialer_pinning.go b/pkg/pmapi/dialer_pinning.go new file mode 100644 index 00000000..26fd8876 --- /dev/null +++ b/pkg/pmapi/dialer_pinning.go @@ -0,0 +1,110 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "crypto/tls" + "net" + + "github.com/sirupsen/logrus" +) + +// TrustedAPIPins contains trusted public keys of the protonmail API and proxies. +// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;). +var TrustedAPIPins = []string{ // nolint[gochecknoglobals] + // api.protonmail.ch + `pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current + `pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot backup + `pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold backup + + // protonmail.com + `pin-sha256="8joiNBdqaYiQpKskgtkJsqRxF7zN0C0aqfi8DacknnI="`, // current + `pin-sha256="JMI8yrbc6jB1FYGyyWRLFTmDNgIszrNEMGlgy972e7w="`, // hot backup + `pin-sha256="Iu44zU84EOCZ9vx/vz67/MRVrxF1IO4i4NIa8ETwiIY="`, // cold backup + + // proxies + `pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // main + `pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // backup 1 + `pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // backup 2 + `pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // backup 3 +} + +// TLSReportURI is the address where TLS reports should be sent. +const TLSReportURI = "https://reports.protonmail.ch/reports/tls" + +// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and +// to report errors if the fingerprint check fails. +type PinningTLSDialer struct { + dialer TLSDialer + + // pinChecker is used to check TLS keys of connections. + pinChecker *pinChecker + + reporter *tlsReporter + + // tlsIssueNotifier is used to notify something when there is a TLS issue. + tlsIssueNotifier func() + + // A logger for logging messages. + log logrus.FieldLogger +} + +// NewPinningTLSDialer constructs a new dialer which only returns tcp connections to servers +// which present known certificates. +// If enabled, it reports any invalid certificates it finds. +func NewPinningTLSDialer(cfg Config, dialer TLSDialer) *PinningTLSDialer { + return &PinningTLSDialer{ + dialer: dialer, + pinChecker: newPinChecker(TrustedAPIPins), + reporter: newTLSReporter(cfg, TrustedAPIPins), + tlsIssueNotifier: cfg.TLSIssueHandler, + log: logrus.WithField("pkg", "pmapi/tls-pinning"), + } +} + +// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins. +func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) { + conn, err := p.dialer.DialTLS(network, address) + if err != nil { + return nil, err + } + + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + if err := p.pinChecker.checkCertificate(conn); err != nil { + if p.tlsIssueNotifier != nil { + go p.tlsIssueNotifier() + } + + if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil { + p.reporter.reportCertIssue( + TLSReportURI, + host, + port, + tlsConn.ConnectionState(), + ) + } + + return nil, err + } + + return conn, nil +} diff --git a/pkg/pmapi/dialer_pinning_checker.go b/pkg/pmapi/dialer_pinning_checker.go new file mode 100644 index 00000000..f8cee2c5 --- /dev/null +++ b/pkg/pmapi/dialer_pinning_checker.go @@ -0,0 +1,68 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge.Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "net" +) + +// ErrTLSMismatch indicates that no TLS fingerprint match could be found. +var ErrTLSMismatch = errors.New("no TLS fingerprint match found") + +type pinChecker struct { + trustedPins []string +} + +func newPinChecker(trustedPins []string) *pinChecker { + return &pinChecker{ + trustedPins: trustedPins, + } +} + +// checkCertificate returns whether the connection presents a known TLS certificate. +func (p *pinChecker) checkCertificate(conn net.Conn) error { + tlsConn, ok := conn.(*tls.Conn) + if !ok { + return errors.New("connection is not a TLS connection") + } + + connState := tlsConn.ConnectionState() + + for _, peerCert := range connState.PeerCertificates { + fingerprint := certFingerprint(peerCert) + + for _, pin := range p.trustedPins { + if pin == fingerprint { + return nil + } + } + } + + return ErrTLSMismatch +} + +func certFingerprint(cert *x509.Certificate) string { + hash := sha256.Sum256(cert.RawSubjectPublicKeyInfo) + return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:])) +} diff --git a/pkg/pmapi/dialer_pinning_report.go b/pkg/pmapi/dialer_pinning_report.go new file mode 100644 index 00000000..4885474c --- /dev/null +++ b/pkg/pmapi/dialer_pinning_report.go @@ -0,0 +1,144 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge.Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "strconv" + "time" + + "github.com/sirupsen/logrus" +) + +// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3. +// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI. +type tlsReport struct { + // DateTime of observed pin validation in time.RFC3339 format. + DateTime string `json:"date-time"` + + // Hostname to which the UA made original request that failed pin validation. + Hostname string `json:"hostname"` + + // Port to which the UA made original request that failed pin validation. + Port int `json:"port"` + + // EffectiveExpirationDate for noted pins in time.RFC3339 format. + EffectiveExpirationDate string `json:"effective-expiration-date"` + + // IncludeSubdomains indicates whether or not the UA has noted the + // includeSubDomains directive for the Known Pinned Host. + IncludeSubdomains bool `json:"include-subdomains"` + + // NotedHostname indicates the hostname that the UA noted when it noted + // the Known Pinned Host. This field allows operators to understand why + // Pin Validation was performed for, e.g., foo.example.com when the + // noted Known Pinned Host was example.com with includeSubDomains set. + NotedHostname string `json:"noted-hostname"` + + // ServedCertificateChain is the certificate chain, as served by + // the Known Pinned Host during TLS session setup. It is provided as an + // array of strings; each string pem1, ... pemN is the Privacy-Enhanced + // Mail (PEM) representation of each X.509 certificate as described in + // [RFC7468]. + ServedCertificateChain []string `json:"served-certificate-chain"` + + // ValidatedCertificateChain is the certificate chain, as + // constructed by the UA during certificate chain verification. (This + // may differ from the served-certificate-chain.) It is provided as an + // array of strings; each string pem1, ... pemN is the PEM + // representation of each X.509 certificate as described in [RFC7468]. + // UAs that build certificate chains in more than one way during the + // validation process SHOULD send the last chain built. In this way, + // they can avoid keeping too much state during the validation process. + ValidatedCertificateChain []string `json:"validated-certificate-chain"` + + // The known-pins are the Pins that the UA has noted for the Known + // Pinned Host. They are provided as an array of strings with the + // syntax: known-pin = token "=" quoted-string + // e.g.: + // ``` + // "known-pins": [ + // 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="', + // "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\"" + // ] + // ``` + KnownPins []string `json:"known-pins"` + + // AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit. + AppVersion string `json:"app-version"` +} + +// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys. +// Temporal things (current date/time) are not set yet -- they are set when sendReport is called. +func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) { + // If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway. + intPort, _ := strconv.Atoi(port) + + report = tlsReport{ + Hostname: host, + Port: intPort, + NotedHostname: server, + ServedCertificateChain: certChain, + KnownPins: knownPins, + AppVersion: appVersion, + } + + return +} + +// sendReport posts the given TLS report to the standard TLS Report URI. +func (r tlsReport) sendReport(cfg Config, uri string) { + now := time.Now() + r.DateTime = now.Format(time.RFC3339) + r.EffectiveExpirationDate = now.Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339) + + b, err := json.Marshal(r) + if err != nil { + logrus.WithError(err).Error("Failed to marshal TLS report") + return + } + + req, err := http.NewRequest("POST", uri, bytes.NewReader(b)) + if err != nil { + logrus.WithError(err).Error("Failed to create http request") + return + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Set("User-Agent", cfg.getUserAgent()) + req.Header.Set("x-pm-appversion", r.AppVersion) + + logrus.WithField("request", req).Warn("Reporting TLS mismatch") + res, err := (&http.Client{Transport: CreateTransportWithDialer(NewBasicTLSDialer(cfg))}).Do(req) + if err != nil { + logrus.WithError(err).Error("Failed to report TLS mismatch") + return + } + + logrus.WithField("response", res).Error("Reported TLS mismatch") + + if res.StatusCode != http.StatusOK { + logrus.WithField("status", http.StatusOK).Error("StatusCode was not OK") + } + + _, _ = ioutil.ReadAll(res.Body) + _ = res.Body.Close() +} diff --git a/pkg/pmapi/dialer_pinning_reporter.go b/pkg/pmapi/dialer_pinning_reporter.go new file mode 100644 index 00000000..9ed0de69 --- /dev/null +++ b/pkg/pmapi/dialer_pinning_reporter.go @@ -0,0 +1,107 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge.Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/sirupsen/logrus" +) + +type sentReport struct { + r tlsReport + t time.Time +} + +type tlsReporter struct { + cfg Config + trustedPins []string + sentReports []sentReport +} + +func newTLSReporter(cfg Config, trustedPins []string) *tlsReporter { + return &tlsReporter{ + cfg: cfg, + trustedPins: trustedPins, + } +} + +// reportCertIssue reports a TLS key mismatch. +func (r *tlsReporter) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) { + var certChain []string + + if len(connState.VerifiedChains) > 0 { + certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1]) + } else { + certChain = marshalCert7468(connState.PeerCertificates) + } + + report := newTLSReport(host, port, connState.ServerName, certChain, r.trustedPins, r.cfg.AppVersion) + + if !r.hasRecentlySentReport(report) { + r.recordReport(report) + go report.sendReport(r.cfg, remoteURI) + } +} + +// hasRecentlySentReport returns whether the report was already sent within the last 24 hours. +func (r *tlsReporter) hasRecentlySentReport(report tlsReport) bool { + var validReports []sentReport + + for _, r := range r.sentReports { + if time.Since(r.t) < 24*time.Hour { + validReports = append(validReports, r) + } + } + + r.sentReports = validReports + + for _, r := range r.sentReports { + if cmp.Equal(report, r.r) { + return true + } + } + + return false +} + +// recordReport records the given report and the current time so we can check whether we recently sent this report. +func (r *tlsReporter) recordReport(report tlsReport) { + r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()}) +} + +func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) { + var buffer bytes.Buffer + for _, cert := range certs { + if err := pem.Encode(&buffer, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }); err != nil { + logrus.WithField("pkg", "pmapi/tls-pinning").WithError(err).Error("Failed to encode TLS certificate") + } + pemCerts = append(pemCerts, buffer.String()) + buffer.Reset() + } + + return pemCerts +} diff --git a/pkg/pmapi/dialer_pinning_reporter_test.go b/pkg/pmapi/dialer_pinning_reporter_test.go new file mode 100644 index 00000000..90a7cd3f --- /dev/null +++ b/pkg/pmapi/dialer_pinning_reporter_test.go @@ -0,0 +1,62 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge.Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTLSReporter_DoubleReport(t *testing.T) { + reportCounter := 0 + + reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reportCounter++ + })) + + cfg := Config{ + AppVersion: "3", + UserAgent: "useragent", + } + r := newTLSReporter(cfg, TrustedAPIPins) + + // Report the same issue many times. + for i := 0; i < 10; i++ { + r.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{}) + } + + // We should only report once. + assert.Eventually(t, func() bool { + return reportCounter == 1 + }, time.Second, time.Millisecond) + + // If we then report something else many times. + for i := 0; i < 10; i++ { + r.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{}) + } + + // We should get a second report. + assert.Eventually(t, func() bool { + return reportCounter == 2 + }, time.Second, time.Millisecond) +} diff --git a/pkg/pmapi/dialer_pinning_test.go b/pkg/pmapi/dialer_pinning_test.go new file mode 100644 index 00000000..08be764f --- /dev/null +++ b/pkg/pmapi/dialer_pinning_test.go @@ -0,0 +1,149 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + a "github.com/stretchr/testify/assert" + r "github.com/stretchr/testify/require" +) + +func TestTLSPinValid(t *testing.T) { + called, _, cm := createClientWithPinningDialer(getRootURL()) + + _, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint + checkTLSIssueHandler(t, 0, called) +} + +func TestTLSPinBackup(t *testing.T) { + called, dialer, cm := createClientWithPinningDialer(getRootURL()) + copyTrustedPins(dialer.pinChecker) + dialer.pinChecker.trustedPins[1] = dialer.pinChecker.trustedPins[0] + dialer.pinChecker.trustedPins[0] = "" + + _, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint + checkTLSIssueHandler(t, 0, called) +} + +func TestTLSPinInvalid(t *testing.T) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0) + })) + defer ts.Close() + + called, _, cm := createClientWithPinningDialer(ts.URL) + + _, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint + checkTLSIssueHandler(t, 1, called) +} + +func TestTLSPinNoMatch(t *testing.T) { + skipIfProxyIsSet(t) + + called, dialer, cm := createClientWithPinningDialer(getRootURL()) + + copyTrustedPins(dialer.pinChecker) + for i := 0; i < len(dialer.pinChecker.trustedPins); i++ { + dialer.pinChecker.trustedPins[i] = "testing" + } + + _, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint + _, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint + + // Check that it will be reported only once per session, but notified every time. + r.Equal(t, 1, len(dialer.reporter.sentReports)) + checkTLSIssueHandler(t, 2, called) +} + +func TestTLSSignedCertWrongPublicKey(t *testing.T) { + skipIfProxyIsSet(t) + + _, dialer, _ := createClientWithPinningDialer("") + _, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443") + r.Error(t, err, "expected dial to fail because of wrong public key") +} + +func TestTLSSignedCertTrustedPublicKey(t *testing.T) { + skipIfProxyIsSet(t) + + _, dialer, _ := createClientWithPinningDialer("") + copyTrustedPins(dialer.pinChecker) + dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`) + _, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443") + r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA") +} + +func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { + skipIfProxyIsSet(t) + + _, dialer, _ := createClientWithPinningDialer("") + copyTrustedPins(dialer.pinChecker) + dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`) + _, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443") + r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed") +} + +func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *manager) { + called := 0 + + cfg := Config{ + AppVersion: "Bridge_1.2.4-test", + HostURL: hostURL, + TLSIssueHandler: func() { called++ }, + } + + dialer := NewPinningTLSDialer(cfg, NewBasicTLSDialer(cfg)) + + cm := newManager(cfg) + cm.SetTransport(CreateTransportWithDialer(dialer)) + + return &called, dialer, cm +} + +func copyTrustedPins(pinChecker *pinChecker) { + copiedPins := make([]string, len(pinChecker.trustedPins)) + copy(copiedPins, pinChecker.trustedPins) + pinChecker.trustedPins = copiedPins +} + +func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) { + // TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called. + a.Eventually( + t, + func() bool { + if wantCalledAtLeast == 0 { + return *called == 0 + } + // Dialer can do more attempts resulting in more calls. + return *called >= wantCalledAtLeast + }, + time.Second, + 10*time.Millisecond, + ) + // Repeated again so it generates nice message. + if wantCalledAtLeast == 0 { + r.Equal(t, 0, *called) + } else { + r.GreaterOrEqual(t, *called, wantCalledAtLeast) + } +} diff --git a/pkg/pmapi/dialer_proxy.go b/pkg/pmapi/dialer_proxy.go new file mode 100644 index 00000000..f984cce4 --- /dev/null +++ b/pkg/pmapi/dialer_proxy.go @@ -0,0 +1,144 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge.Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "net" + "net/url" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails. +type ProxyTLSDialer struct { + dialer TLSDialer + + locker sync.RWMutex + directAddress string + proxyAddress string + allowProxy bool + proxyProvider *proxyProvider + proxyUseDuration time.Duration +} + +// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer. +func NewProxyTLSDialer(cfg Config, dialer TLSDialer) *ProxyTLSDialer { + return &ProxyTLSDialer{ + dialer: dialer, + locker: sync.RWMutex{}, + directAddress: formatAsAddress(cfg.HostURL), + proxyAddress: formatAsAddress(cfg.HostURL), + proxyProvider: newProxyProvider(cfg, dohProviders, proxyQuery), + proxyUseDuration: proxyUseDuration, + } +} + +// formatAsAddress returns URL as `host:port` for easy comparison in DialTLS. +func formatAsAddress(rawURL string) string { + url, err := url.Parse(rawURL) + if err != nil { + // This means wrong configuration. + // Developer should get feedback right away. + panic(err) + } + + port := "443" + if url.Scheme == "http" { + port = "80" + } + return net.JoinHostPort(url.Host, port) +} + +// DialTLS dials the given network/address. If it fails, it retries using a proxy. +func (d *ProxyTLSDialer) DialTLS(network, address string) (net.Conn, error) { + if address == d.directAddress { + address = d.proxyAddress + } + + conn, err := d.dialer.DialTLS(network, address) + if err == nil || !d.allowProxy { + return conn, err + } + + err = d.switchToReachableServer() + if err != nil { + return nil, err + } + + return d.dialer.DialTLS(network, d.proxyAddress) +} + +// switchToReachableServer switches to using a reachable server (either proxy or standard API). +func (d *ProxyTLSDialer) switchToReachableServer() error { + d.locker.Lock() + defer d.locker.Unlock() + + logrus.Info("Attempting to switch to a proxy") + + proxy, err := d.proxyProvider.findReachableServer() + if err != nil { + return errors.Wrap(err, "failed to find a usable proxy") + } + + proxyAddress := formatAsAddress(proxy) + + // If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen. + if proxyAddress == d.directAddress { + logrus.Info("The standard API is reachable again; connection drop was only intermittent") + d.proxyAddress = proxyAddress + return ErrNoConnection + } + + logrus.WithField("proxy", proxyAddress).Info("Switching to a proxy") + + // If the host is currently the rootURL, it's the first time we are enabling a proxy. + // This means we want to disable it again in 24 hours. + if d.proxyAddress == d.directAddress { + go func() { + <-time.After(d.proxyUseDuration) + + d.locker.Lock() + defer d.locker.Unlock() + + d.proxyAddress = d.directAddress + }() + } + + d.proxyAddress = proxyAddress + return nil +} + +// AllowProxy allows the dialer to switch to a proxy if need be. +func (d *ProxyTLSDialer) AllowProxy() { + d.locker.Lock() + defer d.locker.Unlock() + + d.allowProxy = true +} + +// DisallowProxy prevents the dialer from switching to a proxy if need be. +func (d *ProxyTLSDialer) DisallowProxy() { + d.locker.Lock() + defer d.locker.Unlock() + + d.allowProxy = false + d.proxyAddress = d.directAddress +} diff --git a/pkg/pmapi/dialer_proxy_provider.go b/pkg/pmapi/dialer_proxy_provider.go new file mode 100644 index 00000000..c22a7dd3 --- /dev/null +++ b/pkg/pmapi/dialer_proxy_provider.go @@ -0,0 +1,249 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "context" + "encoding/base64" + "strings" + "sync" + "time" + + "github.com/go-resty/resty/v2" + "github.com/miekg/dns" + "github.com/pkg/errors" +) + +const ( + proxyUseDuration = 24 * time.Hour + proxyLookupWait = 5 * time.Second + proxyCacheRefreshTimeout = 20 * time.Second + proxyDoHTimeout = 20 * time.Second + proxyCanReachTimeout = 20 * time.Second + proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz" +) + +var dohProviders = []string{ //nolint[gochecknoglobals] + "https://dns11.quad9.net/dns-query", + "https://dns.google/dns-query", +} + +// proxyProvider manages known proxies. +type proxyProvider struct { + cfg Config + + // dohLookup is used to look up the given query at the given DoH provider, returning the TXT records> + dohLookup func(ctx context.Context, query, provider string) (urls []string, err error) + + providers []string // List of known doh providers. + query string // The query string used to find proxies. + proxyCache []string // All known proxies, cached in case DoH providers are unreachable. + + cacheRefreshTimeout time.Duration + dohTimeout time.Duration + canReachTimeout time.Duration + + lastLookup time.Time // The time at which we last attempted to find a proxy. +} + +// newProxyProvider creates a new proxyProvider that queries the given DoH providers +// to retrieve DNS records for the given query string. +func newProxyProvider(cfg Config, providers []string, query string) (p *proxyProvider) { // nolint[unparam] + p = &proxyProvider{ + cfg: cfg, + providers: providers, + query: query, + cacheRefreshTimeout: proxyCacheRefreshTimeout, + dohTimeout: proxyDoHTimeout, + canReachTimeout: proxyCanReachTimeout, + } + + // Use the default DNS lookup method; this can be overridden if necessary. + p.dohLookup = p.defaultDoHLookup + + return +} + +// findReachableServer returns a working API server (either proxy or standard API). +func (p *proxyProvider) findReachableServer() (proxy string, err error) { + log.Debug("Trying to find a reachable server") + + if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) { + return "", errors.New("not looking for a proxy, too soon") + } + + p.lastLookup = time.Now() + + // We use a waitgroup to wait for both + // a) the check whether the API is reachable, and + // b) the DoH queries. + // This is because the Alternative Routes v2 spec says: + // Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished) + var wg sync.WaitGroup + var apiReachable bool + + wg.Add(2) + + go func() { + defer wg.Done() + apiReachable = p.canReach(p.cfg.HostURL) + }() + + go func() { + defer wg.Done() + err = p.refreshProxyCache() + }() + + wg.Wait() + + if apiReachable { + proxy = p.cfg.HostURL + return + } + + if err != nil { + return + } + + for _, url := range p.proxyCache { + if p.canReach(url) { + proxy = url + return + } + } + + return "", errors.New("no reachable server could be found") +} + +// refreshProxyCache loads the latest proxies from the known providers. +// If the process takes longer than proxyCacheRefreshTimeout, an error is returned. +func (p *proxyProvider) refreshProxyCache() error { + log.Info("Refreshing proxy cache") + + ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout) + defer cancel() + + resultChan := make(chan []string) + + go func() { + for _, provider := range p.providers { + if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil { + resultChan <- proxies + return + } + } + // If no dohLoopkup worked, cancel right after it's done to not + // block refreshing for the whole cacheRefreshTimeout. + cancel() + }() + + select { + case result := <-resultChan: + p.proxyCache = result + return nil + + case <-ctx.Done(): + return errors.New("timed out while refreshing proxy cache") + } +} + +// canReach returns whether we can reach the given url. +func (p *proxyProvider) canReach(url string) bool { + log.WithField("url", url).Debug("Trying to ping proxy") + + if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") { + url = "https://" + url + } + + dialer := NewPinningTLSDialer(p.cfg, NewBasicTLSDialer(p.cfg)) + + pinger := resty.New(). + SetHostURL(url). + SetTimeout(p.canReachTimeout). + SetTransport(CreateTransportWithDialer(dialer)) + + if _, err := pinger.R().Get("/tests/ping"); err != nil { + log.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy") + return false + } + + return true +} + +// defaultDoHLookup is the default implementation of the proxy manager's DoH lookup. +// It looks up DNS TXT records for the given query URL using the given DoH provider. +// It returns a list of all found TXT records. +// If the whole process takes more than proxyDoHTimeout then an error is returned. +func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) { + ctx, cancel := context.WithTimeout(ctx, p.dohTimeout) + defer cancel() + + dataChan, errChan := make(chan []string), make(chan error) + + go func() { + // Build new DNS request in RFC1035 format. + dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT) + + // Pack the DNS request message into wire format. + rawRequest, err := dnsRequest.Pack() + if err != nil { + errChan <- errors.Wrap(err, "failed to pack DNS request") + return + } + + // Encode wire-format DNS request message as base64url (RFC4648) without padding chars. + encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest) + + // Make DoH request to the given DoH provider. + rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider) + if err != nil { + errChan <- errors.Wrap(err, "failed to make DoH request") + return + } + + // Unpack the DNS response. + dnsResponse := new(dns.Msg) + if err = dnsResponse.Unpack(rawResponse.Body()); err != nil { + errChan <- errors.Wrap(err, "failed to unpack DNS response") + return + } + + // Pick out the TXT answers. + for _, answer := range dnsResponse.Answer { + if t, ok := answer.(*dns.TXT); ok { + data = append(data, t.Txt...) + } + } + + dataChan <- data + }() + + select { + case data = <-dataChan: + log.WithField("data", data).Info("Received TXT records") + return + + case err = <-errChan: + log.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records") + return + + case <-ctx.Done(): + log.WithField("provider", dohProvider).Error("Timed out querying DNS records") + return []string{}, errors.New("timed out querying DNS records") + } +} diff --git a/pkg/pmapi/dialer_proxy_provider_test.go b/pkg/pmapi/dialer_proxy_provider_test.go new file mode 100644 index 00000000..df751367 --- /dev/null +++ b/pkg/pmapi/dialer_proxy_provider_test.go @@ -0,0 +1,187 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "context" + "net/http" + "testing" + "time" + + r "github.com/stretchr/testify/require" + "golang.org/x/net/http/httpproxy" +) + +const ( + TestDoHQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz" + TestQuad9Provider = "https://dns11.quad9.net/dns-query" + TestGoogleProvider = "https://dns.google/dns-query" +) + +func TestProxyProvider_FindProxy(t *testing.T) { + proxy := getTrustedServer() + defer closeServer(proxy) + + p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used") + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil } + + url, err := p.findReachableServer() + r.NoError(t, err) + r.Equal(t, proxy.URL, url) +} + +func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) { + reachableProxy := getTrustedServer() + defer closeServer(reachableProxy) + + // We actually close the unreachable proxy straight away rather than deferring the closure. + unreachableProxy := getTrustedServer() + closeServer(unreachableProxy) + + p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used") + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { + return []string{reachableProxy.URL, unreachableProxy.URL}, nil + } + + url, err := p.findReachableServer() + r.NoError(t, err) + r.Equal(t, reachableProxy.URL, url) +} + +func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) { + trustedProxy := getTrustedServer() + defer closeServer(trustedProxy) + + untrustedProxy := getUntrustedServer() + defer closeServer(untrustedProxy) + + p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used") + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { + return []string{untrustedProxy.URL, trustedProxy.URL}, nil + } + + url, err := p.findReachableServer() + r.NoError(t, err) + r.Equal(t, trustedProxy.URL, url) +} + +func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) { + unreachableProxy1 := getTrustedServer() + closeServer(unreachableProxy1) + + unreachableProxy2 := getTrustedServer() + closeServer(unreachableProxy2) + + p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used") + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { + return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil + } + + _, err := p.findReachableServer() + r.Error(t, err) +} + +func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) { + untrustedProxy1 := getUntrustedServer() + defer closeServer(untrustedProxy1) + + untrustedProxy2 := getUntrustedServer() + defer closeServer(untrustedProxy2) + + p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used") + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { + return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil + } + + _, err := p.findReachableServer() + r.Error(t, err) +} + +func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) { + p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used") + p.cacheRefreshTimeout = 1 * time.Second + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } + + // We should fail to refresh the proxy cache because the doh provider + // takes 2 seconds to respond but we timeout after just 1 second. + _, err := p.findReachableServer() + + r.Error(t, err) +} + +func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) { + slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + time.Sleep(2 * time.Second) + })) + defer closeServer(slowProxy) + + p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used") + p.canReachTimeout = 1 * time.Second + p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } + + // We should fail to reach the returned proxy because it takes 2 seconds + // to reach it and we only allow 1. + _, err := p.findReachableServer() + + r.Error(t, err) +} + +func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { + p := newProxyProvider(Config{}, []string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) + + records, err := p.dohLookup(context.Background(), TestDoHQuery, TestQuad9Provider) + r.NoError(t, err) + r.NotEmpty(t, records) +} + +func TestProxyProvider_DoHLookup_Google(t *testing.T) { + p := newProxyProvider(Config{}, []string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) + + records, err := p.dohLookup(context.Background(), TestDoHQuery, TestGoogleProvider) + r.NoError(t, err) + r.NotEmpty(t, records) +} + +func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { + skipIfProxyIsSet(t) + + p := newProxyProvider(Config{}, []string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) + + url, err := p.findReachableServer() + r.NoError(t, err) + r.NotEmpty(t, url) +} + +func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { + skipIfProxyIsSet(t) + + p := newProxyProvider(Config{}, []string{"https://unreachable", TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) + + url, err := p.findReachableServer() + r.NoError(t, err) + r.NotEmpty(t, url) +} + +// skipIfProxyIsSet skips the tests if HTTPS proxy is set. +// Should be used for tests depending on proper certificate checks which +// is not possible under our CI setup. +func skipIfProxyIsSet(t *testing.T) { + if httpproxy.FromEnvironment().HTTPSProxy != "" { + t.SkipNow() + } +} diff --git a/pkg/pmapi/dialer_proxy_test.go b/pkg/pmapi/dialer_proxy_test.go new file mode 100644 index 00000000..29f00f28 --- /dev/null +++ b/pkg/pmapi/dialer_proxy_test.go @@ -0,0 +1,253 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "context" + "crypto/tls" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// getTrustedServer returns a server and sets its public key as one of the pinned ones. +func getTrustedServer() *httptest.Server { + return getTrustedServerWithHandler( + http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + // Do nothing. + }), + ) +} + +func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server { + proxy := httptest.NewTLSServer(handler) + + pin := certFingerprint(proxy.Certificate()) + TrustedAPIPins = append(TrustedAPIPins, pin) + + return proxy +} + +const servercrt = ` +-----BEGIN CERTIFICATE----- +MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD +VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp +dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t +T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j +b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx +MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR +BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf +MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR +aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt +r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2 +pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM +GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru +rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c +kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw +ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI +DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu +ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0 +MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3 +LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R +BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5 +L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA +CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P +3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg +yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l +z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc +u53wgIhCJGuX +-----END CERTIFICATE----- +` + +const serverkey = ` +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx +owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG +ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq +UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx +8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr +n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6 +mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1 +EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ +/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F +qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT +VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu +Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h +bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X +TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU +LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f +pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm +nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3 +Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5 +ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo +w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK +PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt +FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0 +ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD +z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2 +FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o +vwRMog6lPhlRhHh/FZ43Cg== +-----END PRIVATE KEY----- +` + +// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones. +func getUntrustedServer() *httptest.Server { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey)) + if err != nil { + panic(err) + } + server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + + server.StartTLS() + return server +} + +// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys. +func closeServer(server *httptest.Server) { + pin := certFingerprint(server.Certificate()) + + for i := range TrustedAPIPins { + if TrustedAPIPins[i] == pin { + TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...) + break + } + } + + server.Close() +} + +func TestProxyDialer_UseProxy(t *testing.T) { + trustedProxy := getTrustedServer() + defer closeServer(trustedProxy) + + cfg := Config{HostURL: ""} + d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg)) + d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } + + err := d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress) +} + +func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) { + proxy1 := getTrustedServer() + defer closeServer(proxy1) + proxy2 := getTrustedServer() + defer closeServer(proxy2) + proxy3 := getTrustedServer() + defer closeServer(proxy3) + + cfg := Config{HostURL: ""} + d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg)) + d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil } + + err := d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress) + + // Have to wait so as to not get rejected. + time.Sleep(proxyLookupWait) + + d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil } + err = d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress) + + // Have to wait so as to not get rejected. + time.Sleep(proxyLookupWait) + + d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil } + err = d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(proxy3.URL), d.proxyAddress) +} + +func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) { + trustedProxy := getTrustedServer() + defer closeServer(trustedProxy) + + cfg := Config{HostURL: ""} + d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg)) + d.proxyUseDuration = time.Second + + d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } + err := d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress) + + time.Sleep(2 * time.Second) + require.Equal(t, ":443", d.proxyAddress) +} + +func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) { + trustedProxy := getTrustedServer() + + cfg := Config{HostURL: ""} + d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg)) + d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } + + err := d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress) + + // Simulate that the proxy stops working and that the standard api is reachable again. + closeServer(trustedProxy) + d.directAddress = formatAsAddress(getRootURL()) + d.proxyProvider.cfg.HostURL = getRootURL() + time.Sleep(proxyLookupWait) + + // We should now find the original API URL if it is working again. + // The error should be ErrAPINotReachable because the connection dropped intermittently but + // the original API is now reachable (see Alternative-Routing-v2 spec for details). + err = d.switchToReachableServer() + require.Error(t, err) + require.Equal(t, formatAsAddress(getRootURL()), d.proxyAddress) +} + +func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) { + // proxy1 is closed later in this test so we don't defer it here. + proxy1 := getTrustedServer() + + proxy2 := getTrustedServer() + defer closeServer(proxy2) + + cfg := Config{HostURL: ""} + d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg)) + d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } + + err := d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress) + + // Have to wait so as to not get rejected. + time.Sleep(proxyLookupWait) + + // The proxy stops working and the protonmail API is still blocked. + closeServer(proxy1) + + // Should switch to the second proxy because both the first proxy and the protonmail API are blocked. + err = d.switchToReachableServer() + require.NoError(t, err) + require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress) +} diff --git a/pkg/pmapi/errors.go b/pkg/pmapi/errors.go index 7e48bad0..47776170 100644 --- a/pkg/pmapi/errors.go +++ b/pkg/pmapi/errors.go @@ -1,9 +1,37 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import "errors" var ( - ErrNoConnection = errors.New("no internet connection") - ErrAPIFailure = errors.New("API returned an error") - ErrUnauthorized = errors.New("API client is unauthorized") + ErrNoConnection = errors.New("no internet connection") + ErrUnauthorized = errors.New("API client is unauthorized") + ErrUpgradeApplication = errors.New("application upgrade required") + + ErrBad2FACode = errors.New("incorrect 2FA code") + ErrBad2FACodeTryAgain = errors.New("incorrect 2FA code: please try again") ) + +type ErrUnprocessableEntity struct { + originalError error +} + +func (err ErrUnprocessableEntity) Error() string { + return err.originalError.Error() +} diff --git a/pkg/pmapi/events.go b/pkg/pmapi/events.go index e8b1c49f..fb29b0a5 100644 --- a/pkg/pmapi/events.go +++ b/pkg/pmapi/events.go @@ -32,7 +32,7 @@ type Event struct { // If set to one, all cached data must be fetched again. Refresh int // If set to one, fetch more events. - More int + More Boolean // Changes applied to messages. Messages []*EventMessage // Counts of messages per labels. @@ -167,26 +167,32 @@ type EventAddress struct { // GetEvent returns a summary of events that occurred since last. To get the latest event, // provide an empty last value. The latest event is always empty. -func (c *client) GetEvent(ctx context.Context, eventID string) (event *Event, err error) { +func (c *client) GetEvent(ctx context.Context, eventID string) (*Event, error) { + return c.getEvent(ctx, eventID, 1) +} + +func (c *client) getEvent(ctx context.Context, eventID string, numberOfMergedEvents int) (*Event, error) { if eventID == "" { eventID = "latest" } - var res struct { - *Event - - More int - } + var event *Event if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { - return r.SetResult(&res).Get("/events/" + eventID) + return r.SetResult(&event).Get("/events/" + eventID) }); err != nil { return nil, err } - // FIXME(conman): use mergeEvents() function. + if event.More && numberOfMergedEvents < maxNumberOfMergedEvents { + nextEvent, err := c.getEvent(ctx, event.EventID, numberOfMergedEvents+1) + if err != nil { + return nil, err + } + event = mergeEvents(event, nextEvent) + } - return res.Event, nil + return event, nil } // mergeEvents combines an old events and a new events object. diff --git a/pkg/pmapi/events_test.go b/pkg/pmapi/events_test.go index 902b26e4..2cf7870a 100644 --- a/pkg/pmapi/events_test.go +++ b/pkg/pmapi/events_test.go @@ -27,13 +27,12 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + r "github.com/stretchr/testify/require" ) func TestClient_GetEvent(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.NoError(t, checkMethodAndPath(r, "GET", "/events/latest")) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "GET", "/events/latest")) w.Header().Set("Content-Type", "application/json") @@ -41,14 +40,14 @@ func TestClient_GetEvent(t *testing.T) { })) defer s.Close() - event, err := c.GetEvent(context.TODO(), "") - require.NoError(t, err) - require.Equal(t, testEvent, event) + event, err := c.GetEvent(context.Background(), "") + r.NoError(t, err) + r.Equal(t, testEvent, event) } func TestClient_GetEvent_withID(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.NoError(t, checkMethodAndPath(r, "GET", "/events/"+testEvent.EventID)) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "GET", "/events/"+testEvent.EventID)) w.Header().Set("Content-Type", "application/json") @@ -56,23 +55,22 @@ func TestClient_GetEvent_withID(t *testing.T) { })) defer s.Close() - event, err := c.GetEvent(context.TODO(), testEvent.EventID) - require.NoError(t, err) - require.Equal(t, testEvent, event) + event, err := c.GetEvent(context.Background(), testEvent.EventID) + r.NoError(t, err) + r.Equal(t, testEvent, event) } // We first call GetEvent with id of eventID1, which returns More=1 so we fetch with id eventID2. -// FIXME(conman): Merging is currently not supported. Implement it and then enable this test again! -func _TestClient_GetEvent_mergeEvents(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func TestClient_GetEvent_mergeEvents(t *testing.T) { + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-Type", "application/json") - switch r.URL.RequestURI() { + switch req.URL.RequestURI() { case "/events/eventID1": - assert.NoError(t, checkMethodAndPath(r, "GET", "/events/eventID1")) + r.NoError(t, checkMethodAndPath(req, "GET", "/events/eventID1")) fmt.Fprint(w, testEventBodyMore1) case "/events/eventID2": - assert.NoError(t, checkMethodAndPath(r, "GET", "/events/eventID2")) + r.NoError(t, checkMethodAndPath(req, "GET", "/events/eventID2")) fmt.Fprint(w, testEventBodyMore2) default: t.Fail() @@ -80,29 +78,26 @@ func _TestClient_GetEvent_mergeEvents(t *testing.T) { })) defer s.Close() - event, err := c.GetEvent(context.TODO(), "eventID1") - require.NoError(t, err) - require.Equal(t, testEventMerged, event) + event, err := c.GetEvent(context.Background(), "eventID1") + r.NoError(t, err) + r.Equal(t, testEventMerged, event) } -// FIXME(conman): Merging is currently not supported. Implement it and then enable this test again! -func _TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) { +func TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) { numberOfCalls := 0 - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { numberOfCalls++ re := regexp.MustCompile(`/eventID([0-9]+)`) - eventIDString := re.FindStringSubmatch(r.URL.RequestURI())[1] + eventIDString := re.FindStringSubmatch(req.URL.RequestURI())[1] eventID, err := strconv.Atoi(eventIDString) - require.NoError(t, err) + r.NoError(t, err) if numberOfCalls > maxNumberOfMergedEvents*2 { - require.Fail(t, "Too many calls!") + r.Fail(t, "Too many calls!") } - fmt.Println("") - body := strings.ReplaceAll(testEventBodyMore1, "eventID2", "eventID"+strconv.Itoa(eventID+1)) w.Header().Set("Content-Type", "application/json") @@ -110,14 +105,14 @@ func _TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) { })) defer s.Close() - event, err := c.GetEvent(context.TODO(), "eventID1") - require.NoError(t, err) - require.Equal(t, maxNumberOfMergedEvents, numberOfCalls) - require.Equal(t, 1, event.More) + event, err := c.GetEvent(context.Background(), "eventID1") + r.NoError(t, err) + r.Equal(t, maxNumberOfMergedEvents, numberOfCalls) + r.True(t, bool(event.More)) } var ( - testEventMessageUpdateUnread = False + testEventMessageUpdateUnread = Boolean(false) testEvent = &Event{ EventID: "eventID1", diff --git a/pkg/pmapi/import.go b/pkg/pmapi/import.go index d2ab3678..d95d7028 100644 --- a/pkg/pmapi/import.go +++ b/pkg/pmapi/import.go @@ -37,9 +37,8 @@ type ImportMsgReq struct { type ImportMsgReqs []*ImportMsgReq func (reqs ImportMsgReqs) buildMultipartFormData() ([]*resty.MultipartField, error) { - var fields []*resty.MultipartField - - metadata := make(map[string]*ImportMetadata) + metadata := make(map[string]*ImportMetadata, len(reqs)) + fields := make([]*resty.MultipartField, 0, len(reqs)) for i, req := range reqs { name := strconv.Itoa(i) @@ -68,7 +67,6 @@ func (reqs ImportMsgReqs) buildMultipartFormData() ([]*resty.MultipartField, err return fields, nil } -// TODO: Add other metadata. type ImportMetadata struct { AddressID string Unread Boolean // 0: read, 1: unread. @@ -114,7 +112,7 @@ func (c *client) Import(ctx context.Context, reqs ImportMsgReqs) ([]*ImportMsgRe return nil, err } - var resps []*ImportMsgRes + resps := make([]*ImportMsgRes, 0, len(res.Responses)) for _, resp := range res.Responses { var err error diff --git a/pkg/pmapi/import_test.go b/pkg/pmapi/import_test.go index 62f3f9c0..e0e9681b 100644 --- a/pkg/pmapi/import_test.go +++ b/pkg/pmapi/import_test.go @@ -25,17 +25,17 @@ import ( "io/ioutil" "mime/multipart" "net/http" - "reflect" "testing" pmmime "github.com/ProtonMail/proton-bridge/pkg/mime" + r "github.com/stretchr/testify/require" ) var testImportReqs = []*ImportMsgReq{ { Metadata: &ImportMetadata{ AddressID: "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==", - Unread: 0, + Unread: Boolean(false), Flags: FlagReceived | FlagImported, LabelIDs: []string{ArchiveLabel}, }, @@ -57,86 +57,52 @@ var testImportRes = &ImportMsgRes{ } func TestClient_Import(t *testing.T) { // nolint[funlen] - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/messages/import")) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/messages/import")) - contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil { - t.Error("Expected no error while parsing request content type, got:", err) - } - if contentType != "multipart/form-data" { - t.Errorf("Invalid request content type: expected %v but got %v", "multipart/form-data", contentType) - } + contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type")) + r.NoError(t, err) + r.Equal(t, "multipart/form-data", contentType) - mr := multipart.NewReader(r.Body, params["boundary"]) + mr := multipart.NewReader(req.Body, params["boundary"]) // First part is message body. p, err := mr.NextPart() - if err != nil { - t.Error("Expected no error while reading second part of request body, got:", err) - } + r.NoError(t, err) contentDisp, params, err := pmmime.ParseMediaType(p.Header.Get("Content-Disposition")) - if err != nil { - t.Error("Expected no error while parsing part content disposition, got:", err) - } - if contentDisp != "form-data" { - t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType) - } - if params["name"] != "0" { - t.Errorf("Invalid part name: expected %v but got %v", "0", params["name"]) - } + r.NoError(t, err) + r.Equal(t, "form-data", contentDisp) + r.Equal(t, "0", params["name"]) b, err := ioutil.ReadAll(p) - if err != nil { - t.Error("Expected no error while reading second part body, got:", err) - } - - if string(b) != string(testImportReqs[0].Message) { - t.Errorf("Invalid message body: expected %v but got %v", string(testImportReqs[0].Message), string(b)) - } + r.NoError(t, err) + r.Equal(t, string(testImportReqs[0].Message), string(b)) // Second part is metadata. p, err = mr.NextPart() - if err != nil { - t.Error("Expected no error while reading first part of request body, got:", err) - } + r.NoError(t, err) contentDisp, params, err = pmmime.ParseMediaType(p.Header.Get("Content-Disposition")) - if err != nil { - t.Error("Expected no error while parsing part content disposition, got:", err) - } - if contentDisp != "form-data" { - t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType) - } - if params["name"] != "Metadata" { - t.Errorf("Invalid part name: expected %v but got %v", "Metadata", params["name"]) - } + r.NoError(t, err) + r.Equal(t, "form-data", contentDisp) + r.Equal(t, "Metadata", params["name"]) metadata := map[string]*ImportMetadata{} - if err := json.NewDecoder(p).Decode(&metadata); err != nil { - t.Error("Expected no error while parsing metadata json, got:", err) - } + err = json.NewDecoder(p).Decode(&metadata) + r.NoError(t, err) - if len(metadata) != 1 { - t.Errorf("Expected metadata to contain exactly one item, got %v", metadata) - } + r.Equal(t, 1, len(metadata)) - req := metadata["0"] - if metadata["0"] == nil { - t.Errorf("Expected metadata to contain one item indexed by 0, got %v", metadata) - } + importReq := metadata["0"] + r.NotNil(t, req) expected := *testImportReqs[0].Metadata - if !reflect.DeepEqual(&expected, req) { - t.Errorf("Invalid message metadata: expected %v, got %v", &expected, req) - } + r.Equal(t, &expected, importReq) // No more parts. _, err = mr.NextPart() - if err != io.EOF { - t.Error("Expected no more parts but error was not EOF, got:", err) - } + r.EqualError(t, err, io.EOF.Error()) w.Header().Set("Content-Type", "application/json") @@ -144,16 +110,8 @@ func TestClient_Import(t *testing.T) { // nolint[funlen] })) defer s.Close() - imported, err := c.Import(context.TODO(), testImportReqs) - if err != nil { - t.Fatal("Expected no error while importing, got:", err) - } - - if len(imported) != 1 { - t.Fatalf("Expected exactly one imported message, got %v", len(imported)) - } - - if !reflect.DeepEqual(testImportRes, imported[0]) { - t.Errorf("Invalid response for imported message: expected %+v but got %+v", testImportRes, imported[0]) - } + imported, err := c.Import(context.Background(), testImportReqs) + r.NoError(t, err) + r.Equal(t, 1, len(imported)) + r.Equal(t, testImportRes, imported[0]) } diff --git a/pkg/pmapi/key.go b/pkg/pmapi/key.go index 28e99546..acc5f6dc 100644 --- a/pkg/pmapi/key.go +++ b/pkg/pmapi/key.go @@ -19,7 +19,6 @@ package pmapi import ( "context" - "net/url" "github.com/go-resty/resty/v2" ) @@ -44,8 +43,6 @@ const ( // GetPublicKeysForEmail returns all sending public keys for the given email address. func (c *client) GetPublicKeysForEmail(ctx context.Context, email string) (keys []PublicKey, internal bool, err error) { - email = url.QueryEscape(email) - var res struct { Keys []PublicKey RecipientType RecipientType diff --git a/pkg/pmapi/labels.go b/pkg/pmapi/labels.go index 7db85de3..1329c7c5 100644 --- a/pkg/pmapi/labels.go +++ b/pkg/pmapi/labels.go @@ -75,42 +75,35 @@ var LabelColors = []string{ //nolint[gochecknoglobals] "#dfb286", } -type LabelAction int - -const ( - RemoveLabel LabelAction = iota - AddLabel -) - // Label for message. -type Label struct { +type Label struct { //nolint[maligned] ID string Name string Path string Color string Order int `json:",omitempty"` Display int // Not used for now, leave it empty. - Exclusive int + Exclusive Boolean Type int - Notify int + Notify Boolean } func (c *client) ListLabels(ctx context.Context) (labels []*Label, err error) { - return c.ListLabelType(ctx, LabelTypeMailbox) + return c.listLabelType(ctx, LabelTypeMailbox) } func (c *client) ListContactGroups(ctx context.Context) (labels []*Label, err error) { - return c.ListLabelType(ctx, LabelTypeContactGroup) + return c.listLabelType(ctx, LabelTypeContactGroup) } -// ListLabelType lists all labels created by the user. -func (c *client) ListLabelType(ctx context.Context, labelType int) (labels []*Label, err error) { +// listLabelType lists all labels created by the user. +func (c *client) listLabelType(ctx context.Context, labelType int) (labels []*Label, err error) { var res struct { Labels []*Label } if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { - return r.SetQueryParam("Type", strconv.Itoa(labelType)).SetResult(&res).Get("/v4/labels") + return r.SetQueryParam("Type", strconv.Itoa(labelType)).SetResult(&res).Get("/labels") }); err != nil { return nil, err } @@ -135,7 +128,7 @@ func (c *client) CreateLabel(ctx context.Context, label *Label) (created *Label, if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(&LabelReq{ Label: label, - }).SetResult(&res).Post("/v4/labels") + }).SetResult(&res).Post("/labels") }); err != nil { return nil, err } @@ -156,7 +149,7 @@ func (c *client) UpdateLabel(ctx context.Context, label *Label) (updated *Label, if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(&LabelReq{ Label: label, - }).SetResult(&res).Put("/v4/labels/" + label.ID) + }).SetResult(&res).Put("/labels/" + label.ID) }); err != nil { return nil, err } @@ -167,7 +160,7 @@ func (c *client) UpdateLabel(ctx context.Context, label *Label) (updated *Label, // DeleteLabel deletes a label. func (c *client) DeleteLabel(ctx context.Context, labelID string) error { if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { - return r.Delete("/v4/labels/" + labelID) + return r.Delete("/labels/" + labelID) }); err != nil { return err } diff --git a/pkg/pmapi/labels_test.go b/pkg/pmapi/labels_test.go index d1610bfd..dcf4c5f0 100644 --- a/pkg/pmapi/labels_test.go +++ b/pkg/pmapi/labels_test.go @@ -91,61 +91,43 @@ const testDeleteLabelBody = `{ ` func TestClient_ListLabels(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "GET", "/v4/labels?Type=1")) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "GET", "/labels?Type=1")) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testLabelsBody) })) defer s.Close() - labels, err := c.ListLabels(context.TODO()) - if err != nil { - t.Fatal("Expected no error while listing labels, got:", err) - } - - if !reflect.DeepEqual(labels, testLabels) { - for i, l := range testLabels { - t.Errorf("expected %d: %#v\n", i, l) - } - for i, l := range labels { - t.Errorf("got %d: %#v\n", i, l) - } - t.Fatalf("Not same") - } + labels, err := c.ListLabels(context.Background()) + r.NoError(t, err) + r.Equal(t, testLabels, labels) } func TestClient_CreateLabel(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "POST", "/v4/labels")) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "POST", "/labels")) body := &bytes.Buffer{} - _, err := body.ReadFrom(r.Body) - Ok(t, err) + _, err := body.ReadFrom(req.Body) + r.NoError(t, err) if bytes.Contains(body.Bytes(), []byte("Order")) { t.Fatal("Body contains `Order`: ", body.String()) } var labelReq LabelReq - if err := json.NewDecoder(body).Decode(&labelReq); err != nil { - t.Error("Expecting no error while reading request body, got:", err) - } - if !reflect.DeepEqual(testLabelReq.Label, labelReq.Label) { - t.Errorf("Invalid label request: expected %+v but got %+v", testLabelReq.Label, labelReq.Label) - } + err = json.NewDecoder(body).Decode(&labelReq) + r.NoError(t, err) + r.Equal(t, testLabelReq.Label, labelReq.Label) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testCreateLabelBody) })) defer s.Close() - created, err := c.CreateLabel(context.TODO(), testLabelReq.Label) - if err != nil { - t.Fatal("Expected no error while creating label, got:", err) - } + created, err := c.CreateLabel(context.Background(), testLabelReq.Label) + r.NoError(t, err) if !reflect.DeepEqual(created, testLabelCreated) { t.Fatalf("Invalid created label: expected %+v, got %+v", testLabelCreated, created) @@ -158,32 +140,26 @@ func TestClient_CreateEmptyLabel(t *testing.T) { })) defer s.Close() - _, err := c.CreateLabel(context.TODO(), &Label{}) + _, err := c.CreateLabel(context.Background(), &Label{}) r.EqualError(t, err, "name is required") } func TestClient_UpdateLabel(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "PUT", "/v4/labels/"+testLabelCreated.ID)) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "PUT", "/labels/"+testLabelCreated.ID)) var labelReq LabelReq - if err := json.NewDecoder(r.Body).Decode(&labelReq); err != nil { - t.Error("Expecting no error while reading request body, got:", err) - } - if !reflect.DeepEqual(testLabelCreated, labelReq.Label) { - t.Errorf("Invalid label request: expected %+v but got %+v", testLabelCreated, labelReq.Label) - } + err := json.NewDecoder(req.Body).Decode(&labelReq) + r.NoError(t, err) + r.Equal(t, testLabelCreated, labelReq.Label) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testCreateLabelBody) })) defer s.Close() - updated, err := c.UpdateLabel(context.TODO(), testLabelCreated) - if err != nil { - t.Fatal("Expected no error while updating label, got:", err) - } + updated, err := c.UpdateLabel(context.Background(), testLabelCreated) + r.NoError(t, err) if !reflect.DeepEqual(updated, testLabelCreated) { t.Fatalf("Invalid updated label: expected %+v, got %+v", testLabelCreated, updated) @@ -196,24 +172,21 @@ func TestClient_UpdateLabelToEmptyName(t *testing.T) { })) defer s.Close() - _, err := c.UpdateLabel(context.TODO(), &Label{ID: "label"}) + _, err := c.UpdateLabel(context.Background(), &Label{ID: "label"}) r.EqualError(t, err, "name is required") } func TestClient_DeleteLabel(t *testing.T) { - s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "DELETE", "/v4/labels/"+testLabelCreated.ID)) + s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "DELETE", "/labels/"+testLabelCreated.ID)) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testDeleteLabelBody) })) defer s.Close() - err := c.DeleteLabel(context.TODO(), testLabelCreated.ID) - if err != nil { - t.Fatal("Expected no error while deleting label, got:", err) - } + err := c.DeleteLabel(context.Background(), testLabelCreated.ID) + r.NoError(t, err) } func TestLeastUsedColor(t *testing.T) { diff --git a/pkg/pmapi/manager.go b/pkg/pmapi/manager.go index bb73efe3..2314704b 100644 --- a/pkg/pmapi/manager.go +++ b/pkg/pmapi/manager.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( @@ -10,56 +27,53 @@ import ( ) type manager struct { - rc *resty.Client + cfg Config + rc *resty.Client - isDown bool - locker sync.Locker - observers []ConnectionObserver -} - -func newManager(cfg Config) *manager { - m := &manager{ - rc: resty.New(), - locker: &sync.Mutex{}, - } - - // Set the API host. - m.rc.SetHostURL(cfg.HostURL) - - // Set static header values. - m.rc.SetHeader("x-pm-appversion", cfg.AppVersion) - - // Set middleware. - m.rc.OnAfterResponse(catchAPIError) - - // Configure retry mechanism. - m.rc.SetRetryMaxWaitTime(time.Minute) - m.rc.SetRetryAfter(catchRetryAfter) - m.rc.AddRetryCondition(catchTooManyRequests) - m.rc.AddRetryCondition(catchNoResponse) - m.rc.AddRetryCondition(catchProxyAvailable) - - // Determine what happens when requests succeed/fail. - m.rc.OnAfterResponse(m.handleRequestSuccess) - m.rc.OnError(m.handleRequestFailure) - - // Set the data type of API errors. - m.rc.SetError(&Error{}) - - return m + isDown bool + locker sync.Locker + connectionObservers []ConnectionObserver + proxyDialer *ProxyTLSDialer } func New(cfg Config) Manager { return newManager(cfg) } -func (m *manager) SetLogger(logger resty.Logger) { - m.rc.SetLogger(logger) - m.rc.SetDebug(true) +func newManager(cfg Config) *manager { + m := &manager{ + cfg: cfg, + rc: resty.New(), + locker: &sync.Mutex{}, + } + + proxyDialer, transport := newProxyDialerAndTransport(cfg) + m.proxyDialer = proxyDialer + m.rc.SetTransport(transport) + + m.rc.SetHostURL(cfg.HostURL) + m.rc.OnBeforeRequest(m.setHeaderValues) + + // Any HTTP status code higher than 399 with JSON inside (and proper header) + // is converted to Error. `catchAPIError` then processes API custom errors + // wrapped in JSON. If error is returned, `handleRequestFailure` is called, + // otherwise `handleRequestSuccess` is called. + m.rc.SetError(&Error{}) + m.rc.OnAfterResponse(m.catchAPIError) + m.rc.OnAfterResponse(m.handleRequestSuccess) + m.rc.OnError(m.handleRequestFailure) + + // Configure retry mechanism. + m.rc.SetRetryMaxWaitTime(time.Minute) + m.rc.SetRetryAfter(catchRetryAfter) + m.rc.AddRetryCondition(shouldRetry) + + return m } func (m *manager) SetTransport(transport http.RoundTripper) { m.rc.SetTransport(transport) + m.proxyDialer = nil } func (m *manager) SetCookieJar(jar http.CookieJar) { @@ -71,7 +85,15 @@ func (m *manager) SetRetryCount(count int) { } func (m *manager) AddConnectionObserver(observer ConnectionObserver) { - m.observers = append(m.observers, observer) + m.connectionObservers = append(m.connectionObservers, observer) +} + +func (m *manager) setHeaderValues(_ *resty.Client, req *resty.Request) error { + req.SetHeaders(map[string]string{ + "x-pm-appversion": m.cfg.AppVersion, + "User-Agent": m.cfg.getUserAgent(), + }) + return nil } func (m *manager) r(ctx context.Context) *resty.Request { @@ -90,7 +112,7 @@ func (m *manager) handleRequestSuccess(_ *resty.Client, res *resty.Response) err m.isDown = false - for _, observer := range m.observers { + for _, observer := range m.connectionObservers { observer.OnUp() } @@ -113,15 +135,9 @@ func (m *manager) handleRequestFailure(req *resty.Request, err error) { m.isDown = true - for _, observer := range m.observers { + for _, observer := range m.connectionObservers { observer.OnDown() } go m.pingUntilSuccess() } - -func (m *manager) pingUntilSuccess() { - for m.testPing(context.Background()) != nil { - time.Sleep(time.Second) // TODO: How long to sleep here? - } -} diff --git a/pkg/pmapi/manager_auth.go b/pkg/pmapi/manager_auth.go index ace86320..7f47e34f 100644 --- a/pkg/pmapi/manager_auth.go +++ b/pkg/pmapi/manager_auth.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( @@ -9,10 +26,14 @@ import ( ) func (m *manager) NewClient(uid, acc, ref string, exp time.Time) Client { + log.Trace("New client") + return newClient(m, uid).withAuth(acc, ref, exp) } -func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Client, *Auth, error) { +func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Client, *AuthRefresh, error) { + log.Trace("New client with refresh") + c := newClient(m, uid) auth, err := m.authRefresh(ctx, uid, ref) @@ -24,6 +45,8 @@ func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Cl } func (m *manager) NewClientWithLogin(ctx context.Context, username, password string) (Client, *Auth, error) { + log.Trace("New client with login") + info, err := m.getAuthInfo(ctx, GetAuthInfoReq{Username: username}) if err != nil { return nil, nil, err @@ -52,24 +75,13 @@ func (m *manager) NewClientWithLogin(ctx context.Context, username, password str return newClient(m, auth.UID).withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil } -func (m *manager) getAuthModulus(ctx context.Context) (AuthModulus, error) { - var res struct { - AuthModulus - } - - if _, err := m.r(ctx).SetResult(&res).Get("/auth/modulus"); err != nil { - return AuthModulus{}, err - } - - return res.AuthModulus, nil -} - func (m *manager) getAuthInfo(ctx context.Context, req GetAuthInfoReq) (*AuthInfo, error) { var res struct { *AuthInfo } - if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/info"); err != nil { + _, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/info")) + if err != nil { return nil, err } @@ -81,15 +93,16 @@ func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) { *Auth } - if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth"); err != nil { + _, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth")) + if err != nil { return nil, err } return res.Auth, nil } -func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*Auth, error) { - var req = AuthRefreshReq{ +func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefresh, error) { + var req = authRefreshReq{ UID: uid, RefreshToken: ref, ResponseType: "token", @@ -99,14 +112,15 @@ func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*Auth, erro } var res struct { - *Auth + *AuthRefresh } - if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"); err != nil { + _, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh")) + if err != nil { return nil, err } - return res.Auth, nil + return res.AuthRefresh, nil } func expiresIn(seconds int64) time.Time { diff --git a/pkg/pmapi/manager_download.go b/pkg/pmapi/manager_download.go index 04e43458..15f75a37 100644 --- a/pkg/pmapi/manager_download.go +++ b/pkg/pmapi/manager_download.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( diff --git a/pkg/pmapi/manager_log.go b/pkg/pmapi/manager_log.go new file mode 100644 index 00000000..a8749590 --- /dev/null +++ b/pkg/pmapi/manager_log.go @@ -0,0 +1,71 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "github.com/go-resty/resty/v2" + "github.com/sirupsen/logrus" +) + +// restyLogger decreases debug level to trace level so resty logs +// are not logged as debug but trace instead. Resty logging is too +// verbose which we don't want to have in debug level. +type restyLogger struct { + logrus *logrus.Entry +} + +func (l *restyLogger) Errorf(format string, v ...interface{}) { + l.logrus.Errorf(format, v...) +} + +func (l *restyLogger) Warnf(format string, v ...interface{}) { + l.logrus.Warnf(format, v...) +} + +func (l *restyLogger) Debugf(format string, v ...interface{}) { + l.logrus.Tracef(format, v...) +} + +func (m *manager) SetLogging(logger *logrus.Entry, verbose bool) { + if verbose { + m.rc.SetLogger(&restyLogger{logrus: logger}) + m.rc.SetDebug(true) + return + } + + m.rc.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error { + logger.Infof("Requesting %s %s", req.Method, req.URL) + return nil + }) + m.rc.OnAfterResponse(func(_ *resty.Client, res *resty.Response) error { + log := logger.WithFields(logrus.Fields{ + "error": res.Error(), + "status": res.StatusCode(), + "duration": res.Time(), + }) + if res.Request == nil { + log.Warn("Requested unknown request") + return nil + } + log.Debugf("Requested %s %s", res.Request.Method, res.Request.URL) + return nil + }) + m.rc.OnError(func(req *resty.Request, err error) { + logger.WithError(err).Warnf("Failed request %s %s", req.Method, req.URL) + }) +} diff --git a/pkg/pmapi/manager_metrics.go b/pkg/pmapi/manager_metrics.go index d4e05a94..7bd50091 100644 --- a/pkg/pmapi/manager_metrics.go +++ b/pkg/pmapi/manager_metrics.go @@ -1,11 +1,34 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( "context" - "errors" ) -func (m *manager) SendSimpleMetric(context.Context, string, string, string) error { - // FIXME(conman): Implement. - return errors.New("not implemented") +func (m *manager) SendSimpleMetric(ctx context.Context, category, action, label string) error { + r := m.r(ctx).SetQueryParams(map[string]string{ + "Category": category, + "Action": action, + "Label": label, + }) + if _, err := wrapNoConnection(r.Get("/metrics")); err != nil { + return err + } + return nil } diff --git a/pkg/pmapi/metrics_test.go b/pkg/pmapi/manager_metrics_test.go similarity index 70% rename from pkg/pmapi/metrics_test.go rename to pkg/pmapi/manager_metrics_test.go index 43541087..c4e77114 100644 --- a/pkg/pmapi/metrics_test.go +++ b/pkg/pmapi/manager_metrics_test.go @@ -23,6 +23,8 @@ import ( "net/http" "net/http/httptest" "testing" + + r "github.com/stretchr/testify/require" ) const testSendSimpleMetricsBody = `{ @@ -30,21 +32,17 @@ const testSendSimpleMetricsBody = `{ } ` -// FIXME(conman): Implement metrics then enable this test. -func _TestClient_SendSimpleMetric(t *testing.T) { - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "GET", "/metrics?Action=some_action&Category=some_category&Label=some_label")) - +func TestClient_SendSimpleMetric(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "GET", "/metrics?Action=some_action&Category=some_category&Label=some_label")) w.Header().Set("Content-Type", "application/json") - + w.WriteHeader(http.StatusOK) fmt.Fprint(w, testSendSimpleMetricsBody) })) defer s.Close() - m := newManager(Config{HostURL: s.URL}) + m := newManager(newTestConfig(s.URL)) - err := m.SendSimpleMetric(context.TODO(), "some_category", "some_action", "some_label") - if err != nil { - t.Fatal("Expected no error while sending simple metric, got:", err) - } + err := m.SendSimpleMetric(context.Background(), "some_category", "some_action", "some_label") + r.NoError(t, err) } diff --git a/pkg/pmapi/manager_ping.go b/pkg/pmapi/manager_ping.go index 6589854f..a47bac76 100644 --- a/pkg/pmapi/manager_ping.go +++ b/pkg/pmapi/manager_ping.go @@ -1,11 +1,60 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi -import "context" +import ( + "context" + "time" + + "github.com/sirupsen/logrus" +) + +var ( + // retryConnectionSleeps defines a smooth cool down in seconds. + retryConnectionSleeps = []int{2, 5, 10, 30, 60} // nolint[gochecknoglobals] +) + +func (m *manager) pingUntilSuccess() { + attempt := 0 + for { + err := m.testPing(context.Background()) + if err == nil { + return + } + + waitTime := getRetryConnectionSleep(attempt) + attempt++ + logrus.WithError(err).WithField("attempt", attempt).WithField("wait", waitTime).Debug("Connection not available") + time.Sleep(waitTime) + } +} + +func getRetryConnectionSleep(idx int) time.Duration { + if idx >= len(retryConnectionSleeps) { + idx = len(retryConnectionSleeps) - 1 + } + sec := retryConnectionSleeps[idx] + return time.Duration(sec) * time.Second +} func (m *manager) testPing(ctx context.Context) error { if _, err := m.r(ctx).Get("/tests/ping"); err != nil { return err } - return nil } diff --git a/pkg/pmapi/manager_proxy.go b/pkg/pmapi/manager_proxy.go new file mode 100644 index 00000000..7dcf80ab --- /dev/null +++ b/pkg/pmapi/manager_proxy.go @@ -0,0 +1,32 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +// AllowProxy allows the client manager to switch clients over to a proxy if need be. +func (m *manager) AllowProxy() { + if m.proxyDialer != nil { + m.proxyDialer.AllowProxy() + } +} + +// DisallowProxy prevents the client manager from switching clients over to a proxy if need be. +func (m *manager) DisallowProxy() { + if m.proxyDialer != nil { + m.proxyDialer.DisallowProxy() + } +} diff --git a/pkg/pmapi/manager_report.go b/pkg/pmapi/manager_report.go index 3e7c05f8..7981891d 100644 --- a/pkg/pmapi/manager_report.go +++ b/pkg/pmapi/manager_report.go @@ -1,12 +1,43 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( "context" - "errors" ) // Report sends request as json or multipart (if has attachment). -func (m *manager) ReportBug(context.Context, ReportBugReq) error { - // FIXME(conman): Implement. - return errors.New("not implemented") +func (m *manager) ReportBug(ctx context.Context, rep ReportBugReq) error { + if rep.ClientType == 0 { + rep.ClientType = EmailClientType + } + + r := m.r(ctx) + if len(rep.Attachments) == 0 { + r = r.SetBody(rep) + } else { + r = r.SetMultipartFormData(rep.GetMultipartFormData()) + for _, att := range rep.Attachments { + r = r.SetMultipartField(att.name, att.filename, "application/octet-stream", att.body) + } + } + if _, err := wrapNoConnection(r.Post("/reports/bug")); err != nil { + return err + } + return nil } diff --git a/pkg/pmapi/manager_report_test.go b/pkg/pmapi/manager_report_test.go index 4941fd8f..53020e0a 100644 --- a/pkg/pmapi/manager_report_test.go +++ b/pkg/pmapi/manager_report_test.go @@ -24,9 +24,10 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "runtime" "strings" "testing" + + r "github.com/stretchr/testify/require" ) var testBugReportReq = ReportBugReq{ @@ -42,28 +43,15 @@ var testBugReportReq = ReportBugReq{ Email: "apple@gmail.com", } -var testBugsCrashReq = ReportBugReq{ - OS: runtime.GOOS, - Client: "demoapp", - ClientVersion: "GoPMAPI_1.0.14", - ClientType: 1, - Debug: "main.func·001()\n/Users/sunny/Code/Go/src/scratch/stack.go:21 +0xabruntime.panic(0x80b80, 0x2101fb150)\n/usr/local/Cellar/go/1.2/libexec/src/pkg/runtime/panic.c:248 +0x106\nmain.inner()/Users/sunny/Code/Go/src/scratch/stack.go:27 +0x68\nmain.outer()\n/Users/sunny/Code/Go/src/scratch/stack.go:13 +0x1a\nmain.main()\n/Users/sunny/Code/Go/src/scratch/stack.go:9 +0x1a", -} - const testBugsBody = `{ "Code": 1000 } ` -const testAttachmentJSONZipped = "PK\x03\x04\x14\x00\b\x00\b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00last.log\\Rَ\xaaH\x00}ﯨ\xf8r\x1f\xeeܖED;\xe9\ap\x03\x11\x11\x97\x0e8\x99L\xb0(\xa1\xa0\x16\x85b\x91I\xff\xfbD{\x99\xc9}\xab:K\x9d\xa4\xce\xf9\xe7\t\x00\x00z\xf6\xb4\xf7\x02z\xb7a\xe5\xd8\x04*V̭\x8d\xd1lvE}\xd6\xe3\x80\x1f\xd7nX\x9bI[\xa6\xe1a=\xd4a\xa8M\x97\xd9J\xf1F\xeb\x105U\xbd\xb0`XO\xce\xf1hu\x99q\xc3\xfe{\x11ߨ'-\v\x89Z\xa4\x9c5\xaf\xaf\xbd?>R\xd6\x11E\xf7\x1cX\xf0JpF#L\x9eE+\xbe\xe8\x1d\xee\ued2e\u007f\xde]\u06dd\xedo\x97\x87E\xa0V\xf4/$\xc2\xecK\xed\xa0\xdb&\x829\x12\xe5\x9do\xa0\xe9\x1a\xd2\x19\x1e\xf5`\x95гb\xf8\x89\x81\xb7\xa5G\x18\x95\xf3\x9d9\xe8\x93B\x17!\x1a^\xccr\xbb`\xb2\xb4\xb86\x87\xb4h\x0e\xda\xc6u<+\x9e$̓\x95\xccSo\xea\xa4\xdbH!\xe9g\x8b\xd4\b\xb3hܬ\xa6Wk\x14He\xae\x8aPU\xaa\xc1\xee$\xfbH\xb3\xab.I\f<\x89\x06q\xe3-3-\x99\xcdݽ\xe5v\x99\xedn\xac\xadn\xe8Rp=\xb4nJ\xed\xd5\r\x8d\xde\x06Ζ\xf6\xb3\x01\x94\xcb\xf6\xd4\x19r\xe1\xaa$4+\xeaW\xa6F\xfa0\x97\x9cD\f\x8e\xd7\xd6z\v,G\xf3e2\xd4\xe6V\xba\v\xb6\xd9\xe8\xca*\x16\x95V\xa4J\xfbp\xddmF\x8c\x9a\xc6\xc8Č-\xdb\v\xf6\xf5\xf9\x02*\x15e\x874\xc9\xe7\"\xa3\x1an\xabq}ˊq\x957\xd3\xfd\xa91\x82\xe0Lß\\\x17\x8e\x9e_\xed`\t\xe9~5̕\x03\x9a\f\xddN6\xa2\xc4\x17\xdb\xc9V\x1c~\x9e\xea\xbe\xda-xv\xed\x8b\xe2\xc8DŽS\x95E6\xf2\xc3H\x1d:HPx\xc9\x14\xbfɒ\xff\xea\xb4P\x14\xa3\xe2\xfe\xfd\x1f+z\x80\x903\x81\x98\xf8\x15\xa3\x12\x16\xf8\"0g\xf7~B^\xfd \x040T\xa3\x02\x9c\x10\xc1\xa8F\xa0I#\xf1\xa3\x04\x98\x01\x91\xe2\x12\xdc;\x06gL\xd0g\xc0\xe3\xbd\xf6\xd7}&\xa8轀?\xbfяy`X\xf0\x92\x9f\x05\xf0*A8ρ\xac=K\xff\xf3\xfe\xa6Z\xe1\x1a\x017\xc2\x04\f\x94g\xa9\xf7-\xfb\xebqz\u007fz\u007f\xfa7\x00\x00\xff\xffPK\a\b\xf5\\\v\xe5I\x02\x00\x00\r\x03\x00\x00PK\x01\x02\x14\x00\x14\x00\b\x00\b\x00\x00\x00\x00\x00\xf5\\\v\xe5I\x02\x00\x00\r\x03\x00\x00\b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00last.logPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x006\x00\x00\x00\u007f\x02\x00\x00\x00\x00" //nolint[misspell] - -// FIXME(conman): Implement bug reports then enable this test. -func _TestClient_BugReportWithAttachment(t *testing.T) { - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "POST", "/reports/bug")) - Ok(t, isAuthReq(r, testUID, testAccessToken)) - - Ok(t, r.ParseMultipartForm(10*1024)) +func TestClient_BugReportWithAttachment(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "POST", "/reports/bug")) + r.NoError(t, req.ParseMultipartForm(10*1024)) for field, expected := range map[string]string{ "OS": testBugReportReq.OS, @@ -76,60 +64,43 @@ func _TestClient_BugReportWithAttachment(t *testing.T) { "Username": testBugReportReq.Username, "Email": testBugReportReq.Email, } { - if r.PostFormValue(field) != expected { - t.Errorf("Field %q has %q but expected %q", field, r.PostFormValue(field), expected) - } + r.Equal(t, expected, req.PostFormValue(field)) } - attReader, err := r.MultipartForm.File["log"][0].Open() - Ok(t, err) - - log, err := ioutil.ReadAll(attReader) - Ok(t, err) - - Equals(t, []byte(testAttachmentJSONZipped), log) + attReader, err := req.MultipartForm.File["log"][0].Open() + r.NoError(t, err) + _, err = ioutil.ReadAll(attReader) + r.NoError(t, err) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testBugsBody) })) defer s.Close() - cm := newManager(Config{HostURL: s.URL}) + cm := newManager(newTestConfig(s.URL)) rep := testBugReportReq rep.AddAttachment("log", "last.log", strings.NewReader(testAttachmentJSON)) - Ok(t, cm.ReportBug(context.TODO(), rep)) + err := cm.ReportBug(context.Background(), rep) + r.NoError(t, err) } -// FIXME(conman): Implement bug reports then enable this test. -func _TestClient_BugReport(t *testing.T) { - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - Ok(t, checkMethodAndPath(r, "POST", "/reports/bug")) - Ok(t, isAuthReq(r, testUID, testAccessToken)) +func TestClient_BugReport(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + r.NoError(t, checkMethodAndPath(req, "POST", "/reports/bug")) var bugsReportReq ReportBugReq - Ok(t, json.NewDecoder(r.Body).Decode(&bugsReportReq)) - Equals(t, testBugReportReq, bugsReportReq) + r.NoError(t, json.NewDecoder(req.Body).Decode(&bugsReportReq)) + r.Equal(t, testBugReportReq, bugsReportReq) w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, testBugsBody) })) defer s.Close() - cm := newManager(Config{HostURL: s.URL}) + cm := newManager(newTestConfig(s.URL)) - r := ReportBugReq{ - OS: testBugReportReq.OS, - OSVersion: testBugReportReq.OSVersion, - Browser: testBugReportReq.Browser, - Title: testBugReportReq.Title, - Description: testBugReportReq.Description, - Username: testBugReportReq.Username, - Email: testBugReportReq.Email, - } - - Ok(t, cm.ReportBug(context.TODO(), r)) + err := cm.ReportBug(context.Background(), testBugReportReq) + r.NoError(t, err) } diff --git a/pkg/pmapi/manager_report_types.go b/pkg/pmapi/manager_report_types.go index 02bb56cd..e00d2a23 100644 --- a/pkg/pmapi/manager_report_types.go +++ b/pkg/pmapi/manager_report_types.go @@ -1,12 +1,25 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( - "archive/zip" "fmt" "io" - "mime/multipart" - "net/textproto" - "strings" ) // ClientType is required by API. @@ -47,8 +60,8 @@ func (rep *ReportBugReq) AddAttachment(name, filename string, r io.Reader) { rep.Attachments = append(rep.Attachments, reportAtt{name: name, filename: filename, body: r}) } -func writeMultipartReport(w *multipart.Writer, rep *ReportBugReq) error { // nolint[funlen] - fieldData := map[string]string{ +func (rep *ReportBugReq) GetMultipartFormData() map[string]string { + return map[string]string{ "OS": rep.OS, "OSVersion": rep.OSVersion, "Browser": rep.Browser, @@ -58,7 +71,7 @@ func writeMultipartReport(w *multipart.Writer, rep *ReportBugReq) error { // nol "DisplayMode": rep.DisplayMode, "Client": rep.Client, "ClientVersion": rep.ClientVersion, - "ClientType": "1", + "ClientType": fmt.Sprintf("%d", rep.ClientType), "Title": rep.Title, "Description": rep.Description, "Username": rep.Username, @@ -67,46 +80,4 @@ func writeMultipartReport(w *multipart.Writer, rep *ReportBugReq) error { // nol "ISP": rep.ISP, "Debug": rep.Debug, } - - for field, data := range fieldData { - if data == "" { - continue - } - if err := w.WriteField(field, data); err != nil { - return err - } - } - - quoteEscaper := strings.NewReplacer("\\", "\\\\", `"`, "\\\"") - - for _, att := range rep.Attachments { - h := make(textproto.MIMEHeader) - h.Set("Content-Disposition", - fmt.Sprintf(`form-data; name="%s"; filename="%s"`, - quoteEscaper.Replace(att.name), quoteEscaper.Replace(att.filename+".zip"))) - h.Set("Content-Type", "application/octet-stream") - // h.Set("Content-Transfer-Encoding", "base64") - attWr, err := w.CreatePart(h) - if err != nil { - return err - } - - zipArch := zip.NewWriter(attWr) - zipWr, err := zipArch.Create(att.filename) - // b64 := base64.NewEncoder(base64.StdEncoding, zipWr) - if err != nil { - return err - } - _, err = io.Copy(zipWr, att.body) - if err != nil { - return err - } - err = zipArch.Close() - // err = b64.Close() - if err != nil { - return err - } - } - - return nil } diff --git a/pkg/pmapi/manager_test.go b/pkg/pmapi/manager_test.go index 0657960f..bf592687 100644 --- a/pkg/pmapi/manager_test.go +++ b/pkg/pmapi/manager_test.go @@ -1,16 +1,39 @@ -package pmapi_test +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi import ( "context" "errors" + "fmt" "net/http" "net/http/httptest" "testing" "time" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" + r "github.com/stretchr/testify/require" ) +const testForceUpgradeBody = `{ + "Code":5003, + "Error":"Upgrade!" +}` + func TestHandleTooManyRequests(t *testing.T) { var numCalls int @@ -24,21 +47,17 @@ func TestHandleTooManyRequests(t *testing.T) { } })) - m := pmapi.New(pmapi.Config{HostURL: ts.URL}) + m := New(Config{HostURL: ts.URL}) - // Set the retry count to 5. m.SetRetryCount(5) // The call should succeed because the 5th retry should succeed (429s are retried). - if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil { - t.Fatal("got unexpected error", err) - } + _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) + r.NoError(t, err) // The server should be called 5 times. // The first four calls should return 429 and the last call should return 200. - if numCalls != 5 { - t.Fatal("expected numCalls to be 5, instead got", numCalls) - } + r.Equal(t, 5, numCalls) } func TestHandleUnprocessableEntity(t *testing.T) { @@ -49,27 +68,16 @@ func TestHandleUnprocessableEntity(t *testing.T) { w.WriteHeader(http.StatusUnprocessableEntity) })) - m := pmapi.New(pmapi.Config{HostURL: ts.URL}) + m := New(Config{HostURL: ts.URL}) - // Set the retry count to 5. m.SetRetryCount(5) // The call should fail because the first call should fail (422s are not retried). _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) - if err == nil { - t.Fatal("expected error, instead got", err) - } - - // API-side errors get ErrAPIFailure - if !errors.Is(err, pmapi.ErrAPIFailure) { - t.Fatal("expected error to be ErrAPIFailure, instead got", err) - } - + r.EqualError(t, err, "422 Unprocessable Entity") // The server should be called 1 time. // The first call should return 422. - if numCalls != 1 { - t.Fatal("expected numCalls to be 1, instead got", numCalls) - } + r.Equal(t, 1, numCalls) } func TestHandleDialFailure(t *testing.T) { @@ -81,24 +89,17 @@ func TestHandleDialFailure(t *testing.T) { })) // The failingRoundTripper will fail the first 5 times it is used. - m := pmapi.New(pmapi.Config{HostURL: ts.URL}) - - // Set a custom transport. + m := New(Config{HostURL: ts.URL}) m.SetTransport(newFailingRoundTripper(5)) - - // Set the retry count to 5. m.SetRetryCount(5) // The call should succeed because the last retry should succeed (dial errors are retried). - if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil { - t.Fatal("got unexpected error", err) - } + _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) + r.NoError(t, err) // The server should be called 1 time. // The first 4 attempts don't reach the server. - if numCalls != 1 { - t.Fatal("expected numCalls to be 1, instead got", numCalls) - } + r.Equal(t, 1, numCalls) } func TestHandleTooManyDialFailures(t *testing.T) { @@ -112,28 +113,15 @@ func TestHandleTooManyDialFailures(t *testing.T) { // The failingRoundTripper will fail the first 10 times it is used. // This is more than the number of retries we permit. // Thus, dials will fail. - m := pmapi.New(pmapi.Config{HostURL: ts.URL}) - - // Set a custom transport. + m := New(Config{HostURL: ts.URL}) m.SetTransport(newFailingRoundTripper(10)) - - // Set the retry count to 5. m.SetRetryCount(5) // The call should fail because every dial will fail and we'll run out of retries. _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) - if err == nil { - t.Fatal("expected error, instead got", err) - } - - if !errors.Is(err, pmapi.ErrNoConnection) { - t.Fatal("expected error to be ErrNoConnection, instead got", err) - } - + r.EqualError(t, err, "no internet connection") // The server should never be called. - if numCalls != 0 { - t.Fatal("expected numCalls to be 0, instead got", numCalls) - } + r.Equal(t, 0, numCalls) } func TestRetriesWithContextTimeout(t *testing.T) { @@ -150,24 +138,16 @@ func TestRetriesWithContextTimeout(t *testing.T) { })) // Theoretically, this should succeed; on the fifth retry, we'll get StatusOK. - m := pmapi.New(pmapi.Config{HostURL: ts.URL}) - - // Set the retry count to 5. + m := New(Config{HostURL: ts.URL}) m.SetRetryCount(5) - // However, that will take ~5s, and we only allow 1s in the context. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + // However, that will take ~0.5s, and we only allow 10ms in the context. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() // Thus, it will fail. _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(ctx) - if err == nil { - t.Fatal("expected error, instead got", err) - } - - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("expected error to be DeadlineExceeded, instead got", err) - } + r.EqualError(t, err, context.DeadlineExceeded.Error()) } func TestObserveConnectionStatus(t *testing.T) { @@ -177,36 +157,24 @@ func TestObserveConnectionStatus(t *testing.T) { var onDown, onUp bool - m := pmapi.New(pmapi.Config{HostURL: ts.URL}) - - // Set a custom transport. + m := New(Config{HostURL: ts.URL}) m.SetTransport(newFailingRoundTripper(10)) - - // Set the retry count to 5. m.SetRetryCount(5) - - // Add a connection observer. - m.AddConnectionObserver(pmapi.NewConnectionObserver(func() { onDown = true }, func() { onUp = true })) + m.AddConnectionObserver(NewConnectionObserver(func() { onDown = true }, func() { onUp = true })) // The call should fail because every dial will fail and we'll run out of retries. - if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err == nil { - t.Fatal("expected error, instead got", err) - } - - if onDown != true || onUp == true { - t.Fatal("expected onDown to have been called and onUp to not have been called") - } + _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) + r.Error(t, err) + r.False(t, onUp) + r.True(t, onDown) onDown, onUp = false, false // The call should succeed because the last dial attempt will succeed. - if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil { - t.Fatal("got unexpected error", err) - } - - if onDown == true || onUp != true { - t.Fatal("expected onUp to have been called and onDown to not have been called") - } + _, err = m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) + r.NoError(t, err) + r.True(t, onUp) + r.False(t, onDown) } func TestReturnErrNoConnection(t *testing.T) { @@ -215,19 +183,27 @@ func TestReturnErrNoConnection(t *testing.T) { })) // We will fail more times than we retry, so requests should fail with ErrNoConnection. - m := pmapi.New(pmapi.Config{HostURL: ts.URL}) + m := New(Config{HostURL: ts.URL}) m.SetTransport(newFailingRoundTripper(10)) m.SetRetryCount(5) // The call should fail because every dial will fail and we'll run out of retries. _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) - if err == nil { - t.Fatal("expected error, instead got", err) - } + r.EqualError(t, err, "no internet connection") +} - if !errors.Is(err, pmapi.ErrNoConnection) { - t.Fatal("expected error to be ErrNoConnection, instead got", err) - } +func TestReturnErrUpgradeApplication(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusUnprocessableEntity) + fmt.Fprint(w, testForceUpgradeBody) + })) + + m := New(Config{HostURL: ts.URL}) + + // The call should fail because every call return force upgrade error. + _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()) + r.EqualError(t, err, ErrUpgradeApplication.Error()) } type failingRoundTripper struct { diff --git a/pkg/pmapi/manager_types.go b/pkg/pmapi/manager_types.go index 1eed5829..0a9e5f0d 100644 --- a/pkg/pmapi/manager_types.go +++ b/pkg/pmapi/manager_types.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( @@ -6,21 +23,24 @@ import ( "time" "github.com/ProtonMail/gopenpgp/v2/crypto" - "github.com/go-resty/resty/v2" + "github.com/sirupsen/logrus" ) type Manager interface { NewClient(string, string, string, time.Time) Client - NewClientWithRefresh(context.Context, string, string) (Client, *Auth, error) + NewClientWithRefresh(context.Context, string, string) (Client, *AuthRefresh, error) NewClientWithLogin(context.Context, string, string) (Client, *Auth, error) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) ReportBug(context.Context, ReportBugReq) error SendSimpleMetric(context.Context, string, string, string) error - SetLogger(resty.Logger) + SetLogging(logger *logrus.Entry, verbose bool) SetTransport(http.RoundTripper) SetCookieJar(http.CookieJar) SetRetryCount(int) AddConnectionObserver(ConnectionObserver) + + AllowProxy() + DisallowProxy() } diff --git a/pkg/pmapi/messages.go b/pkg/pmapi/messages.go index cc5909d0..2f8087da 100644 --- a/pkg/pmapi/messages.go +++ b/pkg/pmapi/messages.go @@ -518,7 +518,20 @@ func (c *client) ListMessages(ctx context.Context, filter *MessagesFilter) ([]*M // CountMessages counts messages by label. func (c *client) CountMessages(ctx context.Context, addressID string) (counts []*MessagesCount, err error) { - panic("TODO") + var res struct { + Counts []*MessagesCount + } + + if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + if addressID != "" { + r = r.SetQueryParam("AddressID", addressID) + } + return r.SetResult(&res).Get("/mail/v4/messages/count") + }); err != nil { + return nil, err + } + + return res.Counts, nil } // GetMessage retrieves a message. @@ -640,6 +653,10 @@ func (c *client) UnlabelMessages(ctx context.Context, messageIDs []string, label } func (c *client) EmptyFolder(ctx context.Context, labelID, addressID string) error { + if labelID == "" { + return errors.New("labelID parameter is empty string") + } + if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { if addressID != "" { r.SetQueryParam("AddressID", addressID) diff --git a/pkg/pmapi/messages_test.go b/pkg/pmapi/messages_test.go index a15cdee2..20f6d8e0 100644 --- a/pkg/pmapi/messages_test.go +++ b/pkg/pmapi/messages_test.go @@ -24,7 +24,8 @@ import ( "testing" "github.com/ProtonMail/gopenpgp/v2/crypto" - "github.com/stretchr/testify/assert" + a "github.com/stretchr/testify/assert" + r "github.com/stretchr/testify/require" ) const testMessageCleartext = `
jeej saas

Sent from ProtonMail, encrypted email based in Switzerland.

` @@ -127,66 +128,65 @@ ClW54lp9eeOfYTsdTSbn9VaSO0E6m2/Q4Tk= func TestMessage_IsBodyEncrypted(t *testing.T) { msg := &Message{Body: testMessageEncrypted} - Assert(t, msg.IsBodyEncrypted(), "the body should be encrypted") + r.True(t, msg.IsBodyEncrypted(), "the body should be encrypted") msg.Body = testMessageCleartext - Assert(t, !msg.IsBodyEncrypted(), "the body should not be encrypted") + r.True(t, !msg.IsBodyEncrypted(), "the body should not be encrypted") } func TestMessage_Decrypt(t *testing.T) { msg := &Message{Body: testMessageEncrypted} dec, err := msg.Decrypt(testPrivateKeyRing) - Ok(t, err) - Equals(t, testMessageCleartext, string(dec)) + r.NoError(t, err) + r.Equal(t, testMessageCleartext, string(dec)) } func TestMessage_Decrypt_Legacy(t *testing.T) { testPrivateKeyLegacy := readTestFile("testPrivateKeyLegacy", false) key, err := crypto.NewKeyFromArmored(testPrivateKeyLegacy) - Ok(t, err) + r.NoError(t, err) unlockedKey, err := key.Unlock([]byte(testMailboxPasswordLegacy)) - Ok(t, err) + r.NoError(t, err) testPrivateKeyRingLegacy, err := crypto.NewKeyRing(unlockedKey) - Ok(t, err) + r.NoError(t, err) msg := &Message{Body: testMessageEncryptedLegacy} dec, err := msg.Decrypt(testPrivateKeyRingLegacy) - Ok(t, err) + r.NoError(t, err) - Equals(t, testMessageCleartextLegacy, string(dec)) + r.Equal(t, testMessageCleartextLegacy, string(dec)) } func TestMessage_Decrypt_signed(t *testing.T) { msg := &Message{Body: testMessageSigned} dec, err := msg.Decrypt(testPrivateKeyRing) - Ok(t, err) - Equals(t, testMessageCleartext, string(dec)) + r.NoError(t, err) + r.Equal(t, testMessageCleartext, string(dec)) } func TestMessage_Encrypt(t *testing.T) { key, err := crypto.NewKeyFromArmored(testMessageSigner) - Ok(t, err) + r.NoError(t, err) signer, err := crypto.NewKeyRing(key) - Ok(t, err) + r.NoError(t, err) msg := &Message{Body: testMessageCleartext} - Ok(t, msg.Encrypt(testPrivateKeyRing, testPrivateKeyRing)) + r.NoError(t, msg.Encrypt(testPrivateKeyRing, testPrivateKeyRing)) dec, err := msg.Decrypt(testPrivateKeyRing) - Ok(t, err) + r.NoError(t, err) - Equals(t, testMessageCleartext, string(dec)) - Equals(t, testIdentity, signer.GetIdentities()[0]) + r.Equal(t, testMessageCleartext, string(dec)) + r.Equal(t, testIdentity, signer.GetIdentities()[0]) } -func routeLabelMessages(tb testing.TB, w http.ResponseWriter, r *http.Request) string { - Ok(tb, checkMethodAndPath(r, "PUT", "/mail/v4/messages/label")) - +func routeLabelMessages(tb testing.TB, w http.ResponseWriter, req *http.Request) string { + r.NoError(tb, checkMethodAndPath(req, "PUT", "/mail/v4/messages/label")) return "messages/label/put_response.json" } @@ -203,7 +203,7 @@ func TestMessage_LabelMessages_NoPaging(t *testing.T) { ) defer finish() - assert.NoError(t, c.LabelMessages(context.TODO(), testIDs, "mylabel")) + a.NoError(t, c.LabelMessages(context.Background(), testIDs, "mylabel")) } func TestMessage_LabelMessages_Paging(t *testing.T) { @@ -221,5 +221,5 @@ func TestMessage_LabelMessages_Paging(t *testing.T) { ) defer finish() - assert.NoError(t, c.LabelMessages(context.TODO(), testIDs, "mylabel")) + a.NoError(t, c.LabelMessages(context.Background(), testIDs, "mylabel")) } diff --git a/pkg/pmapi/mocks/mocks.go b/pkg/pmapi/mocks/mocks.go index ecf86a3a..ac6cdaaf 100644 --- a/pkg/pmapi/mocks/mocks.go +++ b/pkg/pmapi/mocks/mocks.go @@ -13,8 +13,8 @@ import ( crypto "github.com/ProtonMail/gopenpgp/v2/crypto" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" - resty "github.com/go-resty/resty/v2" gomock "github.com/golang/mock/gomock" + logrus "github.com/sirupsen/logrus" ) // MockClient is a mock of Client interface @@ -40,16 +40,16 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } -// AddAuthHandler mocks base method -func (m *MockClient) AddAuthHandler(arg0 pmapi.AuthHandler) { +// AddAuthRefreshHandler mocks base method +func (m *MockClient) AddAuthRefreshHandler(arg0 pmapi.AuthRefreshHandler) { m.ctrl.T.Helper() - m.ctrl.Call(m, "AddAuthHandler", arg0) + m.ctrl.Call(m, "AddAuthRefreshHandler", arg0) } -// AddAuthHandler indicates an expected call of AddAuthHandler -func (mr *MockClientMockRecorder) AddAuthHandler(arg0 interface{}) *gomock.Call { +// AddAuthRefreshHandler indicates an expected call of AddAuthRefreshHandler +func (mr *MockClientMockRecorder) AddAuthRefreshHandler(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddAuthHandler", reflect.TypeOf((*MockClient)(nil).AddAuthHandler), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddAuthRefreshHandler", reflect.TypeOf((*MockClient)(nil).AddAuthRefreshHandler), arg0) } // Addresses mocks base method @@ -67,7 +67,7 @@ func (mr *MockClientMockRecorder) Addresses() *gomock.Call { } // Auth2FA mocks base method -func (m *MockClient) Auth2FA(arg0 context.Context, arg1 pmapi.Auth2FAReq) error { +func (m *MockClient) Auth2FA(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Auth2FA", arg0, arg1) ret0, _ := ret[0].(error) @@ -616,6 +616,30 @@ func (mr *MockManagerMockRecorder) AddConnectionObserver(arg0 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConnectionObserver", reflect.TypeOf((*MockManager)(nil).AddConnectionObserver), arg0) } +// AllowProxy mocks base method +func (m *MockManager) AllowProxy() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AllowProxy") +} + +// AllowProxy indicates an expected call of AllowProxy +func (mr *MockManagerMockRecorder) AllowProxy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockManager)(nil).AllowProxy)) +} + +// DisallowProxy mocks base method +func (m *MockManager) DisallowProxy() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisallowProxy") +} + +// DisallowProxy indicates an expected call of DisallowProxy +func (mr *MockManagerMockRecorder) DisallowProxy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockManager)(nil).DisallowProxy)) +} + // DownloadAndVerify mocks base method func (m *MockManager) DownloadAndVerify(arg0 *crypto.KeyRing, arg1, arg2 string) ([]byte, error) { m.ctrl.T.Helper() @@ -662,11 +686,11 @@ func (mr *MockManagerMockRecorder) NewClientWithLogin(arg0, arg1, arg2 interface } // NewClientWithRefresh mocks base method -func (m *MockManager) NewClientWithRefresh(arg0 context.Context, arg1, arg2 string) (pmapi.Client, *pmapi.Auth, error) { +func (m *MockManager) NewClientWithRefresh(arg0 context.Context, arg1, arg2 string) (pmapi.Client, *pmapi.AuthRefresh, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "NewClientWithRefresh", arg0, arg1, arg2) ret0, _ := ret[0].(pmapi.Client) - ret1, _ := ret[1].(*pmapi.Auth) + ret1, _ := ret[1].(*pmapi.AuthRefresh) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } @@ -717,16 +741,16 @@ func (mr *MockManagerMockRecorder) SetCookieJar(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCookieJar", reflect.TypeOf((*MockManager)(nil).SetCookieJar), arg0) } -// SetLogger mocks base method -func (m *MockManager) SetLogger(arg0 resty.Logger) { +// SetLogging mocks base method +func (m *MockManager) SetLogging(arg0 *logrus.Entry, arg1 bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLogger", arg0) + m.ctrl.Call(m, "SetLogging", arg0, arg1) } -// SetLogger indicates an expected call of SetLogger -func (mr *MockManagerMockRecorder) SetLogger(arg0 interface{}) *gomock.Call { +// SetLogging indicates an expected call of SetLogging +func (mr *MockManagerMockRecorder) SetLogging(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLogger", reflect.TypeOf((*MockManager)(nil).SetLogger), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLogging", reflect.TypeOf((*MockManager)(nil).SetLogging), arg0, arg1) } // SetRetryCount mocks base method diff --git a/pkg/pmapi/observer.go b/pkg/pmapi/observer.go index c8f18502..fa44e9d5 100644 --- a/pkg/pmapi/observer.go +++ b/pkg/pmapi/observer.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi type ConnectionObserver interface { diff --git a/pkg/pmapi/out b/pkg/pmapi/out deleted file mode 100644 index c56a5d51..00000000 --- a/pkg/pmapi/out +++ /dev/null @@ -1,25 +0,0 @@ --- addresses.go --- attachments.go --- auth.go --- contacts.go --- events.go --- import.go --- key.go --- keyring.go --- labels.go --- manager_auth.go --- manager_download.go --- manager.go --- manager_metrics.go --- manager_ping.go --- manager_report.go --- manager_report_types.go --- manager_types.go --- message_send.go --- messages.go --- metrics.go --- observer.go --- passwords.go --- settings.go --- users.go --- utils.go diff --git a/pkg/pmapi/paging.go b/pkg/pmapi/paging.go index de18cc97..a2dcc0a5 100644 --- a/pkg/pmapi/paging.go +++ b/pkg/pmapi/paging.go @@ -1,8 +1,25 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi const defaultPageSize = 100 -func doPaged(elements []string, pageSize int, fn func([]string) error) error { +func doPaged(elements []string, pageSize int, fn func([]string) error) error { //nolint[unparam] for len(elements) > pageSize { if err := fn(elements[:pageSize]); err != nil { return err diff --git a/pkg/pmapi/pmapi.go b/pkg/pmapi/pmapi.go new file mode 100644 index 00000000..5a8fb135 --- /dev/null +++ b/pkg/pmapi/pmapi.go @@ -0,0 +1,24 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pmapi + +import ( + "github.com/sirupsen/logrus" +) + +var log = logrus.WithField("pkg", "pmapi") //nolint[gochecknoglobals] diff --git a/pkg/pmapi/response.go b/pkg/pmapi/response.go index 55da1a92..13458015 100644 --- a/pkg/pmapi/response.go +++ b/pkg/pmapi/response.go @@ -1,12 +1,35 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package pmapi import ( + "math/rand" "net/http" "strconv" "time" "github.com/go-resty/resty/v2" "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + errCodeUpgradeApplication = 5003 ) type Error struct { @@ -18,26 +41,34 @@ func (err Error) Error() string { return err.Message } -func catchAPIError(_ *resty.Client, res *resty.Response) error { +func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error { if !res.IsError() { return nil } + if res.StatusCode() == http.StatusUnauthorized { + return ErrUnauthorized + } + var err error if apiErr, ok := res.Error().(*Error); ok { - err = apiErr + switch { + case apiErr.Code == errCodeUpgradeApplication: + err = ErrUpgradeApplication + if m.cfg.UpgradeApplicationHandler != nil { + m.cfg.UpgradeApplicationHandler() + } + case res.StatusCode() == http.StatusUnprocessableEntity: + err = ErrUnprocessableEntity{apiErr} + default: + err = apiErr + } } else { err = errors.New(res.Status()) } - switch res.StatusCode() { - case http.StatusUnauthorized: - return errors.Wrap(ErrUnauthorized, err.Error()) - - default: - return errors.Wrap(ErrAPIFailure, err.Error()) - } + return err } func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) { @@ -45,38 +76,47 @@ func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error if after := res.Header().Get("Retry-After"); after != "" { seconds, err := strconv.Atoi(after) if err != nil { - return 0, err + log.WithError(err).Warning("Cannot convert Retry-After to number") + seconds = 10 } + // To avoid spikes when all clients retry at the same time, we add some random wait. + seconds += rand.Intn(10) //nolint[gosec] It is OK to use weak random number generator here. + + log.Warningf("Retrying %s after %ds induced by http code %d", res.Request.URL, seconds, res.StatusCode()) return time.Duration(seconds) * time.Second, nil } } + // 0 and no error means default behaviour which is exponential backoff with jitter. return 0, nil } -func catchTooManyRequests(res *resty.Response, _ error) bool { +func shouldRetry(res *resty.Response, err error) bool { + if isRetryDisabled(res.Request.Context()) { + return false + } + return isTooManyRequest(res) || isNoResponse(res, err) +} + +func isTooManyRequest(res *resty.Response) bool { return res.StatusCode() == http.StatusTooManyRequests } -func catchNoResponse(res *resty.Response, err error) bool { +func isNoResponse(res *resty.Response, err error) bool { return res.RawResponse == nil && err != nil } -func catchProxyAvailable(res *resty.Response, err error) bool { - /* - if res.Request.Attempt < ... { - return false - } +func wrapNoConnection(res *resty.Response, err error) (*resty.Response, error) { + if err, ok := err.(*resty.ResponseError); ok { + return res, err + } - if response is not empty { - return false - } + if res.RawResponse != nil { + return res, err + } - if proxy is available { - return true - } - */ - - return false + // Log useful information and return back nicer and clear error message. + logrus.WithError(err).WithField("url", res.Request.URL).Warn("No internet connection") + return res, ErrNoConnection } diff --git a/pkg/pmapi/server_test.go b/pkg/pmapi/server_test.go index afbd54e6..8e912d61 100644 --- a/pkg/pmapi/server_test.go +++ b/pkg/pmapi/server_test.go @@ -24,7 +24,6 @@ import ( "net/http/httptest" "os" "path/filepath" - "reflect" "regexp" "runtime" "strconv" @@ -32,6 +31,7 @@ import ( "time" "github.com/hashicorp/go-multierror" + r "github.com/stretchr/testify/require" ) var ( @@ -40,36 +40,6 @@ var ( reHTTPCode = regexp.MustCompile(`(HTTP|get|post|put|delete)_(\d{3}).*.json`) ) -// Assert fails the test if the condition is false. -func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) { - if !condition { - _, file, line, _ := runtime.Caller(1) - vv := []interface{}{filepath.Base(file), line, colRed} - vv = append(vv, v...) - vv = append(vv, colNon) - fmt.Printf("%s:%d: %s"+msg+"%s\n\n", vv...) - tb.FailNow() - } -} - -// Ok fails the test if an err is not nil. -func Ok(tb testing.TB, err error) { - if err != nil { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("%s:%d: %sunexpected error: %s%s\n\n", filepath.Base(file), line, colRed, err.Error(), colNon) - tb.FailNow() - } -} - -// Equals fails the test if exp is not equal to act. -func Equals(tb testing.TB, exp, act interface{}) { - if !reflect.DeepEqual(exp, act) { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("%s:%d:\n\n%s\texp: %#v\n\n\tgot: %#v%s\n\n", filepath.Base(file), line, colRed, exp, act, colNon) - tb.FailNow() - } -} - func newTestConfig(url string) Config { return Config{ HostURL: url, @@ -77,7 +47,7 @@ func newTestConfig(url string) Config { } } -// newTestClient is old function and should be replaced everywhere by newTestServerCallbacks. +// newTestClient is old function and should be replaced everywhere by newTestClientCallbacks. func newTestClient(h http.Handler) (*httptest.Server, Client) { s := httptest.NewServer(h) @@ -93,7 +63,7 @@ func newTestClientCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.Re reqNum++ if reqNum > len(callbacks) { fmt.Printf( - "%s:%d: %sServer was requeted %d times which is more requests than expected %d%s\n\n", + "%s:%d: %sServer was requested %d times which is more requests than expected %d times%s\n\n", file, line, colRed, reqNum, len(callbacks), colNon, ) tb.FailNow() @@ -134,22 +104,18 @@ func checkMethodAndPath(r *http.Request, method, path string) error { return result.ErrorOrNil() } -func httpResponse(code int) string { - return fmt.Sprintf("HTTP_%d.json", code) -} - func writeJSONResponsefromFile(tb testing.TB, w http.ResponseWriter, response string, reqNum int) { if match := reHTTPCode.FindAllSubmatch([]byte(response), -1); len(match) != 0 { httpCode, err := strconv.Atoi(string(match[0][len(match[0])-1])) - Ok(tb, err) + r.NoError(tb, err) w.WriteHeader(httpCode) } f, err := os.Open("./testdata/routes/" + response) - Ok(tb, err) + r.NoError(tb, err) w.Header().Set("content-type", "application/json;charset=utf-8") w.Header().Set("x-test-pmapi-response", fmt.Sprintf("%s:%d", tb.Name(), reqNum)) _, err = io.Copy(w, f) - Ok(tb, err) + r.NoError(tb, err) } func checkHeader(h http.Header, field, exp string) error { diff --git a/pkg/pmapi/types.go b/pkg/pmapi/types.go deleted file mode 100644 index 43bc7399..00000000 --- a/pkg/pmapi/types.go +++ /dev/null @@ -1,8 +0,0 @@ -package pmapi - -type Boolean int - -const ( - False Boolean = iota - True -) diff --git a/pkg/pmapi/users_test.go b/pkg/pmapi/users_test.go index 7d75381b..1a2445ce 100644 --- a/pkg/pmapi/users_test.go +++ b/pkg/pmapi/users_test.go @@ -42,22 +42,12 @@ var testCurrentUser = &User{ Keys: *loadPMKeys(readTestFile("keyring_userKey_JSON", false)), } -func routeGetUsers(tb testing.TB, w http.ResponseWriter, r *http.Request) string { - Ok(tb, checkMethodAndPath(r, "GET", "/users")) - Ok(tb, isAuthReq(r, testUID, testAccessToken)) - +func routeGetUsers(tb testing.TB, w http.ResponseWriter, req *http.Request) string { + r.NoError(tb, checkMethodAndPath(req, "GET", "/users")) + r.NoError(tb, isAuthReq(req, testUID, testAccessToken)) return "users/get_response.json" } -const testPublicKeysBody = `{ - "Code": 1000, - "RecipientType": 1, - "MIMEType": "text/html", - "Keys": [ - { "Flags": 3, "PublicKey": "-----BEGIN PGP PUBLIC KEY BLOCK-----\nVersion: OpenPGP.js v0.7.1\nComment: http://openpgpjs.org\n\nxsBNBFSI0BMBB/9td6B5RDzVSFTlFzYOS4JxIb5agtNW1rbA4FeLoC47bGLR\n8E42IA6aKcO4H0vOZ1lFms0URiKk1DjCMXn3AUErbxqiV5IATRZLwliH6vwy\nPI6j5rtGF8dyxYfwmLtoUNkDcPdcFEb4NCdowsN7e8tKU0bcpouZcQhAqawC\n9nEdaG/gS5w+2k4hZX2lOKS1EF5SvP48UadlspEK2PLAIp5wB9XsFS9ey2wu\nelzkSfDh7KUAlteqFGSMqIgYH62/gaKm+TcckfZeyiMHWFw6sfrcFQ3QOZPq\nahWt0Rn9XM5xBAxx5vW0oceuQ1vpvdfFlM5ix4gn/9w6MhmStaCee8/fABEB\nAAHNBlVzZXJJRMLAcgQQAQgAJgUCVIjQHQYLCQgHAwIJEASDR1Fk7GNTBBUI\nAgoDFgIBAhsDAh4BAADmhAf/Yt0mCfWqQ25NNGUN14pKKgnPm68zwj1SmMGa\npU7+7ItRpoFNaDwV5QYiQSLC1SvSb1ZeKoY928GPKfqYyJlBpTPL9zC1OHQj\n9+2yYauHjYW9JWQM7hst2S2LBcdiQPOs3ybWPaO9yaccV4thxKOCPvyClaS5\nb9T4Iv9GEVZQIUvArkwI8hyzIi6skRgxflGheq1O+S1W4Gzt2VtYvo8g8r6W\nGzAGMw2nrs2h0+vUr+dLDgIbFCTc5QU99d5jE/e5Hw8iqBxv9tqB1hVATf8T\nwC8aU5MTtxtabOiBgG0PsBs6oIwjFqEjpOIza2/AflPZfo7stp6IiwbwvTHo\n1NlHoM7ATQRUiNAdAQf/eOLJYxX4lUQUzrNQgASDNE8gJPj7ywcGzySyqr0Y\n5rbG57EjtKMIgZrpzJRpSCuRbBjfsltqJ5Q9TBAbPO+oR3rue0LqPKMnmr/q\nKsHswBJRfsb/dbktUNmv/f7R9IVyOuvyP6RgdGeloxdGNeWiZSA6AZYI+WGc\nxaOvVDPz8thtnML4G4MUhXxxNZ7JzQ0Lfz6mN8CCkblIP5xpcJsyRU7lUsGD\nEJGZX0JH/I8bRVN1Xu08uFinIkZyiXRJ5ZGgF3Dns6VbIWmbttY54tBELtk+\n5g9pNSl9qiYwiCdwuZrA//NmD3xlZIN8sG4eM7ZUibZ23vEq+bUt1++6Mpba\nGQARAQABwsBfBBgBCAATBQJUiNAfCRAEg0dRZOxjUwIbDAAAlpMH/085qZdO\nmGRAlbvViUNhF2rtHvCletC48WHGO1ueSh9VTxalkP21YAYLJ4JgJzArJ7tH\nlEeiKiHm8YU9KhLe11Yv/o3AiKIAQjJiQluvk+mWdMcddB4fBjL6ttMTRAXe\ngHnjtMoamHbSZdeUTUadv05Fl6ivWtpXlODG4V02YvDiGBUbDosdGXEqDtpT\ng6MYlj3QMvUiUNQvt7YGMJS8A9iQ9qBNzErgRW8L6CON2RmpQ/wgwP5nwUHz\nJjY51d82Vj8bZeI8LdsX41SPoUhyC7kmNYpw9ZRy7NlrCt8dBIOB4/BKEJ2G\nClW54lp9eeOfYTsdTSbn9VaSO0E6m2/Q4Tk=\n=WFtr\n-----END PGP PUBLIC KEY BLOCK-----"}, - { "Flags": 1, "PublicKey": "-----BEGIN PGP PUBLIC KEY BLOCK-----\nVersion: OpenPGP.js v0.7.1\nComment: http://openpgpjs.org\n\nxsBNBFSI0BMBB/9td6B5RDzVSFTlFzYOS4JxIb5agtNW1rbA4FeLoC47bGLR\n8E42IA6aKcO4H0vOZ1lFms0URiKk1DjCMXn3AUErbxqiV5IATRZLwliH6vwy\nPI6j5rtGF8dyxYfwmLtoUNkDcPdcFEb4NCdowsN7e8tKU0bcpouZcQhAqawC\n9nEdaG/gS5w+2k4hZX2lOKS1EF5SvP48UadlspEK2PLAIp5wB9XsFS9ey2wu\nelzkSfDh7KUAlteqFGSMqIgYH62/gaKm+TcckfZeyiMHWFw6sfrcFQ3QOZPq\nahWt0Rn9XM5xBAxx5vW0oceuQ1vpvdfFlM5ix4gn/9w6MhmStaCee8/fABEB\nAAHNBlVzZXJJRMLAcgQQAQgAJgUCVIjQHQYLCQgHAwIJEASDR1Fk7GNTBBUI\nAgoDFgIBAhsDAh4BAADmhAf/Yt0mCfWqQ25NNGUN14pKKgnPm68zwj1SmMGa\npU7+7ItRpoFNaDwV5QYiQSLC1SvSb1ZeKoY928GPKfqYyJlBpTPL9zC1OHQj\n9+2yYauHjYW9JWQM7hst2S2LBcdiQPOs3ybWPaO9yaccV4thxKOCPvyClaS5\nb9T4Iv9GEVZQIUvArkwI8hyzIi6skRgxflGheq1O+S1W4Gzt2VtYvo8g8r6W\nGzAGMw2nrs2h0+vUr+dLDgIbFCTc5QU99d5jE/e5Hw8iqBxv9tqB1hVATf8T\nwC8aU5MTtxtabOiBgG0PsBs6oIwjFqEjpOIza2/AflPZfo7stp6IiwbwvTHo\n1NlHoM7ATQRUiNAdAQf/eOLJYxX4lUQUzrNQgASDNE8gJPj7ywcGzySyqr0Y\n5rbG57EjtKMIgZrpzJRpSCuRbBjfsltqJ5Q9TBAbPO+oR3rue0LqPKMnmr/q\nKsHswBJRfsb/dbktUNmv/f7R9IVyOuvyP6RgdGeloxdGNeWiZSA6AZYI+WGc\nxaOvVDPz8thtnML4G4MUhXxxNZ7JzQ0Lfz6mN8CCkblIP5xpcJsyRU7lUsGD\nEJGZX0JH/I8bRVN1Xu08uFinIkZyiXRJ5ZGgF3Dns6VbIWmbttY54tBELtk+\n5g9pNSl9qiYwiCdwuZrA//NmD3xlZIN8sG4eM7ZUibZ23vEq+bUt1++6Mpba\nGQARAQABwsBfBBgBCAATBQJUiNAfCRAEg0dRZOxjUwIbDAAAlpMH/085qZdO\nmGRAlbvViUNhF2rtHvCletC48WHGO1ueSh9VTxalkP21YAYLJ4JgJzArJ7tH\nlEeiKiHm8YU9KhLe11Yv/o3AiKIAQjJiQluvk+mWdMcddB4fBjL6ttMTRAXe\ngHnjtMoamHbSZdeUTUadv05Fl6ivWtpXlODG4V02YvDiGBUbDosdGXEqDtpT\ng6MYlj3QMvUiUNQvt7YGMJS8A9iQ9qBNzErgRW8L6CON2RmpQ/wgwP5nwUHz\nJjY51d82Vj8bZeI8LdsX41SPoUhyC7kmNYpw9ZRy7NlrCt8dBIOB4/BKEJ2G\nClW54lp9eeOfYTsdTSbn9VaSO0E6m2/Q4Tk=\n=WFtr\n-----END PGP PUBLIC KEY BLOCK-----"} - ]}` - func TestClient_CurrentUser(t *testing.T) { finish, c := newTestClientCallbacks(t, routeGetUsers, @@ -65,11 +55,11 @@ func TestClient_CurrentUser(t *testing.T) { ) defer finish() - user, err := c.CurrentUser(context.TODO()) + user, err := c.CurrentUser(context.Background()) r.Nil(t, err) // Ignore KeyRings during the check because they have unexported fields and cannot be compared r.True(t, cmp.Equal(user, testCurrentUser, cmpopts.IgnoreTypes(&crypto.Key{}))) - r.Nil(t, c.Unlock(context.TODO(), []byte(testMailboxPassword))) + r.Nil(t, c.Unlock(context.Background(), []byte(testMailboxPassword))) } diff --git a/test/context/context.go b/test/context/context.go index 239996b3..4d2a1dc0 100644 --- a/test/context/context.go +++ b/test/context/context.go @@ -96,17 +96,16 @@ type TestContext struct { func New(app string) *TestContext { setLogrusVerbosityFromEnv() - userAgent := useragent.New() - - pmapiController, clientManager := newPMAPIController() + listener := listener.New() + pmapiController, clientManager := newPMAPIController(app, listener) ctx := &TestContext{ t: &bddT{}, cache: newFakeCache(), locations: newFakeLocations(), settings: newFakeSettings(), - listener: listener.New(), - userAgent: userAgent, + listener: listener, + userAgent: useragent.New(), pmapiController: pmapiController, clientManager: clientManager, testAccounts: newTestAccounts(), @@ -137,14 +136,6 @@ func New(app string) *TestContext { return ctx } -func getConfigName(app string) string { - if app == "ie" { - return "importExport" - } - - return app -} - // Cleanup runs through all cleanup steps. // This can be a deferred call so that it is run even if the test steps failed the test. func (ctx *TestContext) Cleanup() *TestContext { diff --git a/test/context/credentials.go b/test/context/credentials.go index d20eecd3..0fc409cc 100644 --- a/test/context/credentials.go +++ b/test/context/credentials.go @@ -65,7 +65,6 @@ func (c *fakeCredStore) Add(userID, userName, uid, ref, mailboxPassword string, BridgePassword: bridgePassword, IsCombinedAddressMode: true, // otherwise by default starts in split mode } - return c.Get(userID) } @@ -74,12 +73,10 @@ func (c *fakeCredStore) Get(userID string) (*credentials.Credentials, error) { } func (c *fakeCredStore) SwitchAddressMode(userID string) (*credentials.Credentials, error) { - // FIXME(conman): Why is this empty? return c.credentials[userID], nil } func (c *fakeCredStore) UpdateEmails(userID string, emails []string) (*credentials.Credentials, error) { - // FIXME(conman): Why is this empty? return c.credentials[userID], nil } diff --git a/test/context/pmapi_controller.go b/test/context/pmapi_controller.go index 4c1265d7..bc768496 100644 --- a/test/context/pmapi_controller.go +++ b/test/context/pmapi_controller.go @@ -20,6 +20,8 @@ package context import ( "os" + "github.com/ProtonMail/proton-bridge/internal/events" + "github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/test/fakeapi" "github.com/ProtonMail/proton-bridge/test/liveapi" @@ -39,15 +41,26 @@ type PMAPIController interface { GetCalls(method, path string) [][]byte } -func newPMAPIController() (PMAPIController, pmapi.Manager) { +func newPMAPIController(app string, listener listener.Listener) (PMAPIController, pmapi.Manager) { switch os.Getenv(EnvName) { case EnvFake: - return fakeapi.NewController() + cntl, cm := fakeapi.NewController() + addConnectionObserver(cm, listener) + return cntl, cm case EnvLive: - return liveapi.NewController() + cntl, cm := liveapi.NewController(app) + addConnectionObserver(cm, listener) + return cntl, cm default: panic("unknown env") } } + +func addConnectionObserver(cm pmapi.Manager, listener listener.Listener) { + cm.AddConnectionObserver(pmapi.NewConnectionObserver( + func() { listener.Emit(events.InternetOffEvent, "") }, + func() { listener.Emit(events.InternetOnEvent, "") }, + )) +} diff --git a/test/context/pmapi_manager.go b/test/context/pmapi_manager.go deleted file mode 100644 index a4d6f0ed..00000000 --- a/test/context/pmapi_manager.go +++ /dev/null @@ -1,65 +0,0 @@ -package context - -import ( - "context" - "net/http" - "time" - - "github.com/ProtonMail/gopenpgp/v2/crypto" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" - "github.com/go-resty/resty/v2" -) - -func newLivePMAPIManager() pmapi.Manager { - return pmapi.New(pmapi.DefaultConfig) -} - -func newFakePMAPIManager() pmapi.Manager { - return &fakePMAPIManager{} -} - -type fakePMAPIManager struct{} - -func (*fakePMAPIManager) NewClient(string, string, string, time.Time) pmapi.Client { - panic("TODO") -} - -func (*fakePMAPIManager) NewClientWithRefresh(context.Context, string, string) (pmapi.Client, *pmapi.Auth, error) { - panic("TODO") -} - -func (*fakePMAPIManager) NewClientWithLogin(context.Context, string, string) (pmapi.Client, *pmapi.Auth, error) { - panic("TODO") -} - -func (*fakePMAPIManager) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) { - panic("TODO") -} - -func (*fakePMAPIManager) ReportBug(context.Context, pmapi.ReportBugReq) error { - panic("TODO") -} - -func (*fakePMAPIManager) SendSimpleMetric(context.Context, string, string, string) error { - panic("TODO") -} - -func (*fakePMAPIManager) SetLogger(resty.Logger) { - panic("TODO") -} - -func (*fakePMAPIManager) SetTransport(http.RoundTripper) { - panic("TODO") -} - -func (*fakePMAPIManager) SetCookieJar(http.CookieJar) { - panic("TODO") -} - -func (*fakePMAPIManager) SetRetryCount(int) { - panic("TODO") -} - -func (*fakePMAPIManager) AddConnectionObserver(pmapi.ConnectionObserver) { - panic("TODO") -} diff --git a/test/context/users.go b/test/context/users.go index 5b21d839..8beaafbe 100644 --- a/test/context/users.go +++ b/test/context/users.go @@ -26,7 +26,6 @@ import ( "github.com/ProtonMail/proton-bridge/internal/store" "github.com/ProtonMail/proton-bridge/internal/users" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/srp" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -46,8 +45,8 @@ func (ctx *TestContext) LoginUser(username, password, mailboxPassword string) er return errors.Wrap(err, "failed to login") } - if auth.TwoFA.Enabled == pmapi.TOTPEnabled { - if err := client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: "2fa code"}); err != nil { + if auth.HasTwoFactor() { + if err := client.Auth2FA(context.Background(), "2fa code"); err != nil { return errors.Wrap(err, "failed to login with 2FA") } } diff --git a/test/fakeapi/auth.go b/test/fakeapi/auth.go index c1aeb98c..8f45a8df 100644 --- a/test/fakeapi/auth.go +++ b/test/fakeapi/auth.go @@ -23,8 +23,8 @@ import ( "github.com/ProtonMail/proton-bridge/pkg/pmapi" ) -func (api *FakePMAPI) Auth2FA(_ context.Context, req pmapi.Auth2FAReq) error { - if err := api.checkAndRecordCall(POST, "/auth/2fa", req); err != nil { +func (api *FakePMAPI) Auth2FA(_ context.Context, twoFactorCode string) error { + if err := api.checkAndRecordCall(POST, "/auth/2fa", twoFactorCode); err != nil { return err } @@ -50,7 +50,7 @@ func (api *FakePMAPI) AuthSalt(_ context.Context) (string, error) { return "", nil } -func (api *FakePMAPI) AddAuthHandler(handler pmapi.AuthHandler) { +func (api *FakePMAPI) AddAuthRefreshHandler(handler pmapi.AuthRefreshHandler) { api.authHandlers = append(api.authHandlers, handler) } diff --git a/test/fakeapi/controller.go b/test/fakeapi/controller.go index 8d011a6b..a01f0064 100644 --- a/test/fakeapi/controller.go +++ b/test/fakeapi/controller.go @@ -32,7 +32,7 @@ type Controller struct { labelIDGenerator idGenerator messageIDGenerator idGenerator tokenGenerator idGenerator - clientManager pmapi.Manager + clientManager *fakePMAPIManager // State controlled by test. noInternetConnection bool diff --git a/test/fakeapi/controller_calls.go b/test/fakeapi/controller_calls.go index 8f9bccf7..da7a7351 100644 --- a/test/fakeapi/controller_calls.go +++ b/test/fakeapi/controller_calls.go @@ -40,7 +40,7 @@ type fakeCall struct { request []byte } -func (ctl *Controller) recordCall(method method, path string, req interface{}) error { +func (ctl *Controller) checkAndRecordCall(method method, path string, req interface{}) error { ctl.lock.Lock() defer ctl.lock.Unlock() @@ -50,7 +50,7 @@ func (ctl *Controller) recordCall(method method, path string, req interface{}) e var err error if request, err = json.Marshal(req); err != nil { - return err + panic(err) } } diff --git a/test/fakeapi/controller_control.go b/test/fakeapi/controller_control.go index 08ca5c86..4d2f92cf 100644 --- a/test/fakeapi/controller_control.go +++ b/test/fakeapi/controller_control.go @@ -39,11 +39,17 @@ var systemLabelNameToID = map[string]string{ //nolint[gochecknoglobals] func (ctl *Controller) TurnInternetConnectionOff() { ctl.log.Warn("Turning OFF internet") ctl.noInternetConnection = true + for _, observer := range ctl.clientManager.connectionObservers { + observer.OnDown() + } } func (ctl *Controller) TurnInternetConnectionOn() { ctl.log.Warn("Turning ON internet") ctl.noInternetConnection = false + for _, observer := range ctl.clientManager.connectionObservers { + observer.OnUp() + } } func (ctl *Controller) ReorderAddresses(user *pmapi.User, addressIDs []string) error { @@ -52,7 +58,7 @@ func (ctl *Controller) ReorderAddresses(user *pmapi.User, addressIDs []string) e return errors.New("no such user") } - return api.ReorderAddresses(context.TODO(), addressIDs) + return api.ReorderAddresses(context.Background(), addressIDs) } func (ctl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error { @@ -79,7 +85,7 @@ func (ctl *Controller) AddUserLabel(username string, label *pmapi.Label) error { label.Exclusive = getLabelExclusive(label.Name) prefix := "label" - if label.Exclusive == 1 { + if label.Exclusive { prefix = "folder" } label.ID = ctl.labelIDGenerator.next(prefix) @@ -127,11 +133,8 @@ func getLabelNameWithoutPrefix(name string) string { return name } -func getLabelExclusive(name string) int { - if strings.HasPrefix(name, "Folders/") { - return 1 - } - return 0 +func getLabelExclusive(name string) pmapi.Boolean { + return pmapi.Boolean(strings.HasPrefix(name, "Folders/")) } func (ctl *Controller) AddUserMessage(username string, message *pmapi.Message) (string, error) { diff --git a/test/fakeapi/counts.go b/test/fakeapi/counts.go index 9b2b1b64..5fcc888c 100644 --- a/test/fakeapi/counts.go +++ b/test/fakeapi/counts.go @@ -43,13 +43,12 @@ func (api *FakePMAPI) getCounts(addressID string) []*pmapi.MessagesCount { for _, labelID := range message.LabelIDs { if counts, ok := allCounts[labelID]; ok { counts.Total++ - if message.Unread == 1 { + if message.Unread { counts.Unread++ } } else { var unread int - - if message.Unread == pmapi.True { + if message.Unread { unread = 1 } diff --git a/test/fakeapi/fakeapi.go b/test/fakeapi/fakeapi.go index 3ab888f6..8d0d14c8 100644 --- a/test/fakeapi/fakeapi.go +++ b/test/fakeapi/fakeapi.go @@ -34,7 +34,7 @@ type FakePMAPI struct { controller *Controller eventIDGenerator idGenerator - authHandlers []pmapi.AuthHandler + authHandlers []pmapi.AuthRefreshHandler user *pmapi.User userKeyRing *crypto.KeyRing addresses *pmapi.AddressList @@ -45,25 +45,13 @@ type FakePMAPI struct { // uid represents the API UID. It is the unique session ID. uid string - acc string // FIXME(conman): Check this is correct! - ref string // FIXME(conman): Check this is correct! + acc string + ref string log *logrus.Entry } -func newFakePMAPI(controller *Controller, userID, uid, acc, ref string) *FakePMAPI { - return &FakePMAPI{ - controller: controller, - log: logrus.WithField("pkg", "fakeapi").WithField("uid", uid), - uid: uid, - acc: acc, // FIXME(conman): This should be checked! - ref: ref, // FIXME(conman): This should be checked! - userID: userID, - addrKeyRing: make(map[string]*crypto.KeyRing), - } -} - -func NewFakePMAPI(controller *Controller, username, userID, uid, acc, ref string) (*FakePMAPI, error) { +func newFakePMAPI(controller *Controller, username, userID, uid, acc, ref string) (*FakePMAPI, error) { user, ok := controller.usersByUsername[username] if !ok { return nil, fmt.Errorf("user %s does not exist", username) @@ -84,19 +72,28 @@ func NewFakePMAPI(controller *Controller, username, userID, uid, acc, ref string messages = []*pmapi.Message{} } - fakePMAPI := newFakePMAPI(controller, userID, uid, acc, ref) + fakePMAPI := &FakePMAPI{ + username: username, + userID: userID, + controller: controller, - fakePMAPI.log = fakePMAPI.log.WithField("username", username) - fakePMAPI.username = username - fakePMAPI.user = user.user - fakePMAPI.addresses = addresses - fakePMAPI.labels = labels - fakePMAPI.messages = messages + user: user.user, + addresses: addresses, + labels: labels, + messages: messages, + + uid: uid, + acc: acc, + ref: ref, + addrKeyRing: make(map[string]*crypto.KeyRing), + + log: logrus.WithField("pkg", "fakeapi").WithField("uid", uid).WithField("username", username), + } fakePMAPI.addEvent(&pmapi.Event{ EventID: fakePMAPI.eventIDGenerator.last("event"), Refresh: 0, - More: 0, + More: false, }) return fakePMAPI, nil @@ -112,13 +109,14 @@ func (api *FakePMAPI) checkAndRecordCall(method method, path string, request int api.log.WithField(string(method), path).Trace("CALL") - if err := api.controller.recordCall(method, path, request); err != nil { + if err := api.controller.checkAndRecordCall(method, path, request); err != nil { return err } - // FIXME(conman): This needs to match conman behaviour. Should try auth refresh somehow. if !api.controller.checkAccessToken(api.uid, api.acc) { - return pmapi.ErrUnauthorized + if err := api.authRefresh(); err != nil { + return err + } } if path != "/auth/2fa" && !api.controller.checkScope(api.uid) { @@ -128,6 +126,21 @@ func (api *FakePMAPI) checkAndRecordCall(method method, path string, request int return nil } +func (api *FakePMAPI) authRefresh() error { + if err := api.controller.checkAndRecordCall(POST, "/auth/refresh", []string{api.uid, api.ref}); err != nil { + return err + } + + session, err := api.controller.refreshSessionIfAuthorized(api.uid, api.ref) + if err != nil { + return err + } + + api.ref = session.ref + api.acc = session.acc + return nil +} + func (api *FakePMAPI) setUser(username string) error { api.username = username api.log = api.log.WithField("username", username) @@ -158,12 +171,3 @@ func (api *FakePMAPI) setUser(username string) error { return nil } - -func (api *FakePMAPI) unsetUser() { - api.uid = "" - api.acc = "" // FIXME(conman): This should be checked! - api.user = nil - api.labels = nil - api.messages = nil - api.events = nil -} diff --git a/test/fakeapi/labels.go b/test/fakeapi/labels.go index 0e81150d..9660f0b2 100644 --- a/test/fakeapi/labels.go +++ b/test/fakeapi/labels.go @@ -27,7 +27,7 @@ import ( func (api *FakePMAPI) isLabelFolder(labelID string) bool { for _, label := range api.labels { if label.ID == labelID { - return label.Exclusive == 1 + return bool(label.Exclusive) } } return labelID == pmapi.InboxLabel || labelID == pmapi.ArchiveLabel || labelID == pmapi.SentLabel @@ -50,7 +50,7 @@ func (api *FakePMAPI) CreateLabel(_ context.Context, label *pmapi.Label) (*pmapi } } prefix := "label" - if label.Exclusive == 1 { + if label.Exclusive { prefix = "folder" } label.ID = api.controller.labelIDGenerator.next(prefix) diff --git a/test/fakeapi/manager.go b/test/fakeapi/manager.go index 9cde6418..71da9d05 100644 --- a/test/fakeapi/manager.go +++ b/test/fakeapi/manager.go @@ -1,3 +1,20 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + package fakeapi import ( @@ -8,27 +25,36 @@ import ( "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/pkg/pmapi" - "github.com/go-resty/resty/v2" + "github.com/sirupsen/logrus" ) type fakePMAPIManager struct { - controller *Controller + controller *Controller + connectionObservers []pmapi.ConnectionObserver } func (m *fakePMAPIManager) NewClient(uid string, acc string, ref string, _ time.Time) pmapi.Client { + if uid == "" { + return &FakePMAPI{ + controller: m.controller, + log: logrus.WithField("pkg", "fakeapi"), + addrKeyRing: make(map[string]*crypto.KeyRing), + } + } + session, ok := m.controller.sessionsByUID[uid] if !ok { - return newFakePMAPI(m.controller, "", "", "", "") + panic("session " + uid + " is missing") } user, ok := m.controller.usersByUsername[session.username] if !ok { - return newFakePMAPI(m.controller, "", "", "", "") + panic("user " + session.username + " from session " + uid + " is missing") } - client, err := NewFakePMAPI(m.controller, session.username, user.user.ID, session.uid, session.acc, session.ref) + client, err := newFakePMAPI(m.controller, session.username, user.user.ID, session.uid, session.acc, session.ref) if err != nil { - return newFakePMAPI(m.controller, "", "", "", "") + panic(err) } m.controller.fakeAPIs = append(m.controller.fakeAPIs, client) @@ -36,15 +62,8 @@ func (m *fakePMAPIManager) NewClient(uid string, acc string, ref string, _ time. return client } -func (m *fakePMAPIManager) NewClientWithRefresh(_ context.Context, uid, ref string) (pmapi.Client, *pmapi.Auth, error) { - if err := m.controller.recordCall(POST, "/auth/refresh", &pmapi.AuthRefreshReq{ - UID: uid, - RefreshToken: ref, - ResponseType: "token", - GrantType: "refresh_token", - RedirectURI: "https://protonmail.ch", - State: "random_string", - }); err != nil { +func (m *fakePMAPIManager) NewClientWithRefresh(_ context.Context, uid, ref string) (pmapi.Client, *pmapi.AuthRefresh, error) { + if err := m.controller.checkAndRecordCall(POST, "/auth/refresh", []string{uid, ref}); err != nil { return nil, nil, err } @@ -58,31 +77,25 @@ func (m *fakePMAPIManager) NewClientWithRefresh(_ context.Context, uid, ref stri return nil, nil, errWrongNameOrPassword } - client, err := NewFakePMAPI(m.controller, session.username, user.user.ID, session.uid, session.acc, session.ref) + client, err := newFakePMAPI(m.controller, session.username, user.user.ID, session.uid, session.acc, session.ref) if err != nil { return nil, nil, err } m.controller.fakeAPIs = append(m.controller.fakeAPIs, client) - auth := &pmapi.Auth{ + auth := &pmapi.AuthRefresh{ UID: session.uid, AccessToken: session.acc, RefreshToken: session.ref, ExpiresIn: 86400, // seconds, } - if user.has2FA { - auth.TwoFA = pmapi.TwoFAInfo{ - Enabled: pmapi.TOTPEnabled, - } - } - return client, auth, nil } func (m *fakePMAPIManager) NewClientWithLogin(_ context.Context, username string, password string) (pmapi.Client, *pmapi.Auth, error) { - if err := m.controller.recordCall(POST, "/auth/info", &pmapi.GetAuthInfoReq{Username: username}); err != nil { + if err := m.controller.checkAndRecordCall(POST, "/auth/info", &pmapi.GetAuthInfoReq{Username: username}); err != nil { return nil, nil, err } @@ -93,7 +106,7 @@ func (m *fakePMAPIManager) NewClientWithLogin(_ context.Context, username string return nil, nil, errWrongNameOrPassword } - if err := m.controller.recordCall(POST, "/auth", &pmapi.AuthReq{Username: username}); err != nil { + if err := m.controller.checkAndRecordCall(POST, "/auth", &pmapi.AuthReq{Username: username}); err != nil { return nil, nil, err } @@ -102,7 +115,7 @@ func (m *fakePMAPIManager) NewClientWithLogin(_ context.Context, username string return nil, nil, err } - client, err := NewFakePMAPI(m.controller, username, user.user.ID, session.uid, session.acc, session.ref) + client, err := newFakePMAPI(m.controller, username, user.user.ID, session.uid, session.acc, session.ref) if err != nil { return nil, nil, err } @@ -110,14 +123,17 @@ func (m *fakePMAPIManager) NewClientWithLogin(_ context.Context, username string m.controller.fakeAPIs = append(m.controller.fakeAPIs, client) auth := &pmapi.Auth{ - UID: session.uid, - AccessToken: session.acc, - RefreshToken: session.ref, - ExpiresIn: 86400, // seconds, + UserID: user.user.ID, + AuthRefresh: pmapi.AuthRefresh{ + UID: session.uid, + AccessToken: session.acc, + RefreshToken: session.ref, + ExpiresIn: 86400, // seconds, + }, } if user.has2FA { - auth.TwoFA = pmapi.TwoFAInfo{ + auth.TwoFA = &pmapi.TwoFAInfo{ Enabled: pmapi.TOTPEnabled, } } @@ -125,40 +141,46 @@ func (m *fakePMAPIManager) NewClientWithLogin(_ context.Context, username string return client, auth, nil } -func (*fakePMAPIManager) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) { - panic("TODO") +func (m *fakePMAPIManager) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) { + panic("Not implemented: not used by tests") } -func (*fakePMAPIManager) ReportBug(context.Context, pmapi.ReportBugReq) error { - panic("TODO") +func (m *fakePMAPIManager) ReportBug(_ context.Context, bugReport pmapi.ReportBugReq) error { + return m.controller.checkAndRecordCall(POST, "/reports/bug", bugReport) } func (m *fakePMAPIManager) SendSimpleMetric(_ context.Context, cat string, act string, lab string) error { v := url.Values{} - v.Set("Category", cat) v.Set("Action", act) v.Set("Label", lab) - - return m.controller.recordCall(GET, "/metrics?"+v.Encode(), nil) + return m.controller.checkAndRecordCall(GET, "/metrics?"+v.Encode(), nil) } -func (*fakePMAPIManager) SetLogger(resty.Logger) { - panic("TODO") +func (m *fakePMAPIManager) SetLogging(*logrus.Entry, bool) { + // NOOP } -func (*fakePMAPIManager) SetTransport(http.RoundTripper) { - panic("TODO") +func (m *fakePMAPIManager) SetTransport(http.RoundTripper) { + // NOOP } -func (*fakePMAPIManager) SetCookieJar(http.CookieJar) { - panic("TODO") +func (m *fakePMAPIManager) SetCookieJar(http.CookieJar) { + // NOOP } -func (*fakePMAPIManager) SetRetryCount(int) { - panic("TODO") +func (m *fakePMAPIManager) SetRetryCount(int) { + // NOOP } -func (*fakePMAPIManager) AddConnectionObserver(pmapi.ConnectionObserver) { - panic("TODO") +func (m *fakePMAPIManager) AddConnectionObserver(connectionObserver pmapi.ConnectionObserver) { + m.connectionObservers = append(m.connectionObservers, connectionObserver) +} + +func (m *fakePMAPIManager) AllowProxy() { + // NOOP +} + +func (m *fakePMAPIManager) DisallowProxy() { + // NOOP } diff --git a/test/fakeapi/messages.go b/test/fakeapi/messages.go index c9713ec9..f3e9b5c8 100644 --- a/test/fakeapi/messages.go +++ b/test/fakeapi/messages.go @@ -132,15 +132,7 @@ func isMessageMatchingFilter(filter *pmapi.MessagesFilter, message *pmapi.Messag return false } if filter.Unread != nil { - var wantUnread pmapi.Boolean - - if *filter.Unread { - wantUnread = pmapi.True - } else { - wantUnread = pmapi.False - } - - if message.Unread != wantUnread { + if bool(message.Unread) != *filter.Unread { return false } } @@ -393,10 +385,10 @@ func (api *FakePMAPI) MarkMessagesRead(_ context.Context, apiIDs []string) error return api.updateMessages(PUT, "/mail/v4/messages/read", &pmapi.MessagesActionReq{ IDs: apiIDs, }, apiIDs, func(message *pmapi.Message) error { - if message.Unread == 0 { + if !message.Unread { return errWasNotUpdated } - message.Unread = 0 + message.Unread = false return nil }) } @@ -405,10 +397,10 @@ func (api *FakePMAPI) MarkMessagesUnread(_ context.Context, apiIDs []string) err err := api.updateMessages(PUT, "/mail/v4/messages/unread", &pmapi.MessagesActionReq{ IDs: apiIDs, }, apiIDs, func(message *pmapi.Message) error { - if message.Unread == 1 { + if message.Unread { return errWasNotUpdated } - message.Unread = 1 + message.Unread = true return nil }) if err != nil { diff --git a/test/fakeapi/user.go b/test/fakeapi/user.go index a1a739dd..7f42dda4 100644 --- a/test/fakeapi/user.go +++ b/test/fakeapi/user.go @@ -116,6 +116,12 @@ func (api *FakePMAPI) ReorderAddresses(_ context.Context, addressIDs []string) e } func (api *FakePMAPI) Addresses() pmapi.AddressList { + if api.controller.noInternetConnection { + return nil + } + if api.addresses == nil { + return pmapi.AddressList{} + } return *api.addresses } diff --git a/test/features/bridge/start.feature b/test/features/bridge/start.feature index 618f660b..2d0b82e7 100644 --- a/test/features/bridge/start.feature +++ b/test/features/bridge/start.feature @@ -16,7 +16,6 @@ Feature: Start bridge And "user" has loaded store And "user" has running event loop - @ignore Scenario: Start with connected user, no database file and internet connection Given there is connected user "user" And there is no database file for "user" @@ -24,15 +23,20 @@ Feature: Start bridge Then "user" is connected And "user" has loaded store And "user" has running event loop - And "user" is connected - @ignore Scenario: Start with connected user, no database file and no internet connection Given there is connected user "user" And there is no database file for "user" And there is no internet connection When bridge starts - Then "user" is disconnected + Then "user" is connected + And "user" does not have loaded store + And "user" does not have running event loop + And the internet connection is restored + And 5 seconds pass + Then "user" is connected + And "user" has loaded store + And "user" has running event loop Scenario: Start with disconnected user, database file and internet connection Given there is disconnected user "user" @@ -51,7 +55,6 @@ Feature: Start bridge And "user" has loaded store And "user" does not have running event loop - @ignore Scenario: Start with disconnected user, no database file and internet connection Given there is disconnected user "user" And there is no database file for "user" @@ -60,7 +63,6 @@ Feature: Start bridge And "user" does not have loaded store And "user" does not have running event loop - @ignore Scenario: Start with disconnected user, no database file and no internet connection Given there is disconnected user "user" And there is no database file for "user" diff --git a/test/liveapi/cleanup.go b/test/liveapi/cleanup.go index 3f5797c1..05638941 100644 --- a/test/liveapi/cleanup.go +++ b/test/liveapi/cleanup.go @@ -44,7 +44,7 @@ func cleanup(client pmapi.Client, addresses *pmapi.AddressList) error { func cleanSystemFolders(client pmapi.Client) error { for _, labelID := range []string{pmapi.InboxLabel, pmapi.SentLabel, pmapi.ArchiveLabel, pmapi.AllMailLabel, pmapi.DraftLabel} { for { - messages, total, err := client.ListMessages(context.TODO(), &pmapi.MessagesFilter{ + messages, total, err := client.ListMessages(context.Background(), &pmapi.MessagesFilter{ PageSize: 150, LabelID: labelID, }) @@ -61,7 +61,7 @@ func cleanSystemFolders(client pmapi.Client) error { messageIDs = append(messageIDs, message.ID) } - if err := client.DeleteMessages(context.TODO(), messageIDs); err != nil { + if err := client.DeleteMessages(context.Background(), messageIDs); err != nil { return errors.Wrap(err, "failed to delete messages") } @@ -74,7 +74,7 @@ func cleanSystemFolders(client pmapi.Client) error { } func cleanCustomLables(client pmapi.Client) error { - labels, err := client.ListLabels(context.TODO()) + labels, err := client.ListLabels(context.Background()) if err != nil { return errors.Wrap(err, "failed to list labels") } @@ -83,7 +83,7 @@ func cleanCustomLables(client pmapi.Client) error { if err := emptyFolder(client, label.ID); err != nil { return errors.Wrap(err, "failed to empty label") } - if err := client.DeleteLabel(context.TODO(), label.ID); err != nil { + if err := client.DeleteLabel(context.Background(), label.ID); err != nil { return errors.Wrap(err, "failed to delete label") } } @@ -93,7 +93,7 @@ func cleanCustomLables(client pmapi.Client) error { func cleanTrash(client pmapi.Client) error { for { - _, total, err := client.ListMessages(context.TODO(), &pmapi.MessagesFilter{ + _, total, err := client.ListMessages(context.Background(), &pmapi.MessagesFilter{ PageSize: 1, LabelID: pmapi.TrashLabel, }) @@ -115,12 +115,12 @@ func cleanTrash(client pmapi.Client) error { } func emptyFolder(client pmapi.Client, labelID string) error { - err := client.EmptyFolder(context.TODO(), labelID, "") + err := client.EmptyFolder(context.Background(), labelID, "") if err != nil { return err } for { - _, total, err := client.ListMessages(context.TODO(), &pmapi.MessagesFilter{ + _, total, err := client.ListMessages(context.Background(), &pmapi.MessagesFilter{ PageSize: 1, LabelID: labelID, }) @@ -142,5 +142,5 @@ func reorderAddresses(client pmapi.Client, addresses *pmapi.AddressList) error { addressIDs = append(addressIDs, address.ID) } - return client.ReorderAddresses(context.TODO(), addressIDs) + return client.ReorderAddresses(context.Background(), addressIDs) } diff --git a/test/liveapi/controller.go b/test/liveapi/controller.go index 811aeb8b..d4664007 100644 --- a/test/liveapi/controller.go +++ b/test/liveapi/controller.go @@ -21,6 +21,7 @@ import ( "net/http" "sync" + "github.com/ProtonMail/proton-bridge/internal/constants" "github.com/ProtonMail/proton-bridge/pkg/pmapi" ) @@ -36,19 +37,18 @@ type Controller struct { noInternetConnection bool } -func NewController() (*Controller, pmapi.Manager) { +func NewController(app string) (*Controller, pmapi.Manager) { + cm := pmapi.New(pmapi.NewConfig(getAppVersionName(app), constants.Version)) controller := &Controller{ lock: &sync.RWMutex{}, calls: []*fakeCall{}, pmapiByUsername: map[string]pmapi.Client{}, messageIDsByUsername: map[string][]string{}, + clientManager: cm, noInternetConnection: false, } - // FIXME(conman): Set connect values here? - cm := pmapi.New(pmapi.DefaultConfig) - cm.SetTransport(&fakeTransport{ ctl: controller, transport: http.DefaultTransport, @@ -56,3 +56,10 @@ func NewController() (*Controller, pmapi.Manager) { return controller, cm } + +func getAppVersionName(app string) string { + if app == "ie" { + return "importExport" + } + return app +} diff --git a/test/liveapi/labels.go b/test/liveapi/labels.go index 642177ff..105b74e9 100644 --- a/test/liveapi/labels.go +++ b/test/liveapi/labels.go @@ -45,7 +45,7 @@ func (ctl *Controller) AddUserLabel(username string, label *pmapi.Label) error { label.Exclusive = getLabelExclusive(label.Name) label.Name = getLabelNameWithoutPrefix(label.Name) label.Color = pmapi.LabelColors[0] - if _, err := client.CreateLabel(context.TODO(), label); err != nil { + if _, err := client.CreateLabel(context.Background(), label); err != nil { return errors.Wrap(err, "failed to create label") } return nil @@ -73,7 +73,7 @@ func (ctl *Controller) getLabelID(username, labelName string) (string, error) { return "", fmt.Errorf("user %s does not exist", username) } - labels, err := client.ListLabels(context.TODO()) + labels, err := client.ListLabels(context.Background()) if err != nil { return "", errors.Wrap(err, "failed to list labels") } @@ -98,9 +98,6 @@ func getLabelNameWithoutPrefix(name string) string { return name } -func getLabelExclusive(name string) int { - if strings.HasPrefix(name, "Folders/") { - return 1 - } - return 0 +func getLabelExclusive(name string) pmapi.Boolean { + return pmapi.Boolean(strings.HasPrefix(name, "Folders/")) } diff --git a/test/liveapi/messages.go b/test/liveapi/messages.go index 590dd308..ed696f81 100644 --- a/test/liveapi/messages.go +++ b/test/liveapi/messages.go @@ -61,7 +61,7 @@ func (ctl *Controller) AddUserMessage(username string, message *pmapi.Message) ( Message: body, } - results, err := client.Import(context.TODO(), pmapi.ImportMsgReqs{req}) + results, err := client.Import(context.Background(), pmapi.ImportMsgReqs{req}) if err != nil { return "", errors.Wrap(err, "failed to make an import") } @@ -85,7 +85,7 @@ func (ctl *Controller) GetMessages(username, labelID string) ([]*pmapi.Message, for { // ListMessages returns empty result, not error, asking for page out of range. - pageMessages, _, err := client.ListMessages(context.TODO(), &pmapi.MessagesFilter{ + pageMessages, _, err := client.ListMessages(context.Background(), &pmapi.MessagesFilter{ Page: page, PageSize: 150, LabelID: labelID, diff --git a/test/liveapi/transport.go b/test/liveapi/transport.go index 32c8d712..68ca4869 100644 --- a/test/liveapi/transport.go +++ b/test/liveapi/transport.go @@ -48,9 +48,11 @@ func (t *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return nil, errors.Wrap(err, "failed to get body") } - body, err = ioutil.ReadAll(bodyReader) - if err != nil { - return nil, errors.Wrap(err, "failed to read body") + if bodyReader != nil { + body, err = ioutil.ReadAll(bodyReader) + if err != nil { + return nil, errors.Wrap(err, "failed to read body") + } } } t.ctl.recordCall(req.Method, req.URL.Path, body) diff --git a/test/liveapi/users.go b/test/liveapi/users.go index 660b7c78..7f328329 100644 --- a/test/liveapi/users.go +++ b/test/liveapi/users.go @@ -30,12 +30,12 @@ func (ctl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, p return godog.ErrPending } - client, _, err := ctl.clientManager.NewClientWithLogin(context.TODO(), user.Name, password) + client, _, err := ctl.clientManager.NewClientWithLogin(context.Background(), user.Name, password) if err != nil { return errors.Wrap(err, "failed to create new client") } - salt, err := client.AuthSalt(context.TODO()) + salt, err := client.AuthSalt(context.Background()) if err != nil { return errors.Wrap(err, "failed to get salt") } @@ -45,7 +45,7 @@ func (ctl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, p return errors.Wrap(err, "failed to hash mailbox password") } - if err := client.Unlock(context.TODO(), mailboxPassword); err != nil { + if err := client.Unlock(context.Background(), mailboxPassword); err != nil { return errors.Wrap(err, "failed to unlock user") } @@ -59,5 +59,5 @@ func (ctl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, p } func (ctl *Controller) ReorderAddresses(user *pmapi.User, addressIDs []string) error { - return ctl.pmapiByUsername[user.Name].ReorderAddresses(context.TODO(), addressIDs) + return ctl.pmapiByUsername[user.Name].ReorderAddresses(context.Background(), addressIDs) } diff --git a/test/store_checks_test.go b/test/store_checks_test.go index fdd8e96d..2b25153b 100644 --- a/test/store_checks_test.go +++ b/test/store_checks_test.go @@ -256,11 +256,10 @@ func messagesContainsMessageRow(account *accounts.TestAccount, allMessages []int } case "read": var unread pmapi.Boolean - if cell.Value == "true" { //nolint[goconst] - unread = pmapi.False + unread = false } else { - unread = pmapi.True + unread = true } if message.Unread != unread { @@ -299,7 +298,7 @@ func areAddressesSame(first, second string) bool { func messagesInMailboxForUserIsMarkedAsRead(bddMessageIDs, mailboxName, bddUserID string) error { return checkMessages(bddUserID, mailboxName, bddMessageIDs, func(message *store.Message) error { - if message.Message().Unread == 0 { + if !message.Message().Unread { return nil } return fmt.Errorf("message %s \"%s\" is expected to be read but is not", message.ID(), message.Message().Subject) @@ -308,7 +307,7 @@ func messagesInMailboxForUserIsMarkedAsRead(bddMessageIDs, mailboxName, bddUserI func messagesInMailboxForUserIsMarkedAsUnread(bddMessageIDs, mailboxName, bddUserID string) error { return checkMessages(bddUserID, mailboxName, bddMessageIDs, func(message *store.Message) error { - if message.Message().Unread == 1 { + if message.Message().Unread { return nil } return fmt.Errorf("message %s \"%s\" is expected to not be read but is", message.ID(), message.Message().Subject) diff --git a/test/users_checks_test.go b/test/users_checks_test.go index d28906be..05f4fa61 100644 --- a/test/users_checks_test.go +++ b/test/users_checks_test.go @@ -34,10 +34,6 @@ func UsersChecksFeatureContext(s *godog.Suite) { s.Step(`^"([^"]*)" does not have loaded store$`, userDoesNotHaveLoadedStore) s.Step(`^"([^"]*)" has running event loop$`, userHasRunningEventLoop) s.Step(`^"([^"]*)" does not have running event loop$`, userDoesNotHaveRunningEventLoop) - - // FIXME(conman): Write tests for new "auth" system. - // s.Step(`^"([^"]*)" does not have API auth$`, isNotAuthorized) - // s.Step(`^"([^"]*)" has API auth$`, isAuthorized) } func userHasAddressModeInMode(bddUserID, wantAddressMode string) error { @@ -158,36 +154,11 @@ func userDoesNotHaveRunningEventLoop(bddUserID string) error { if err != nil { return internalError(err, "getting store of %s", account.Username()) } + if store == nil { + return nil + } a.Eventually(ctx.GetTestingT(), func() bool { return store.TestGetEventLoop() == nil || !store.TestGetEventLoop().IsRunning() }, 5*time.Second, 10*time.Millisecond) return ctx.GetTestingError() } - -/* -func isAuthorized(bddUserID string) error { - account := ctx.GetTestAccount(bddUserID) - if account == nil { - return godog.ErrPending - } - user, err := ctx.GetUser(account.Username()) - if err != nil { - return internalError(err, "getting user %s", account.Username()) - } - a.Eventually(ctx.GetTestingT(), user.IsAuthorized, 5*time.Second, 10*time.Millisecond) - return ctx.GetTestingError() -} - -func isNotAuthorized(bddUserID string) error { - account := ctx.GetTestAccount(bddUserID) - if account == nil { - return godog.ErrPending - } - user, err := ctx.GetUser(account.Username()) - if err != nil { - return internalError(err, "getting user %s", account.Username()) - } - a.Eventually(ctx.GetTestingT(), func() bool { return !user.IsAuthorized() }, 5*time.Second, 10*time.Millisecond) - return ctx.GetTestingError() -} -*/