GODT-35: New pmapi client and manager using resty

This commit is contained in:
James Houlahan
2021-02-22 18:23:51 +01:00
committed by Jakub
parent 1d538e8540
commit 2284e9ede1
163 changed files with 3333 additions and 8124 deletions

View File

@ -254,12 +254,13 @@ coverage: test
go tool cover -html=/tmp/coverage.out -o=coverage.html go tool cover -html=/tmp/coverage.out -o=coverage.html
mocks: mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,ClientManager,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/transfer PanicHandler,ClientManager,IMAPClientProvider > internal/transfer/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,ClientManager,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/transfer PanicHandler,IMAPClientProvider > internal/transfer/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client > pkg/pmapi/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-changelog lint: gofiles lint-golang lint-license lint-changelog

5
go.mod
View File

@ -40,7 +40,7 @@ require (
github.com/fatih/color v1.9.0 github.com/fatih/color v1.9.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/getsentry/sentry-go v0.8.0 github.com/getsentry/sentry-go v0.8.0
github.com/go-resty/resty/v2 v2.3.0 github.com/go-resty/resty/v2 v2.4.0
github.com/golang/mock v1.4.4 github.com/golang/mock v1.4.4
github.com/google/go-cmp v0.5.1 github.com/google/go-cmp v0.5.1
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
@ -50,7 +50,6 @@ require (
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
github.com/logrusorgru/aurora v2.0.3+incompatible github.com/logrusorgru/aurora v2.0.3+incompatible
github.com/mattn/go-runewidth v0.0.9 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/miekg/dns v1.1.30
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
github.com/olekukonko/tablewriter v0.0.4 // indirect github.com/olekukonko/tablewriter v0.0.4 // indirect
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
@ -64,7 +63,7 @@ require (
github.com/urfave/cli/v2 v2.2.0 github.com/urfave/cli/v2 v2.2.0
github.com/vmihailenco/msgpack/v5 v5.1.3 github.com/vmihailenco/msgpack/v5 v5.1.3
go.etcd.io/bbolt v1.3.5 go.etcd.io/bbolt v1.3.5
golang.org/x/net v0.0.0-20200707034311-ab3426394381 golang.org/x/net v0.0.0-20201224014010-6772e930b67b
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec
) )

20
go.sum
View File

@ -113,8 +113,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
github.com/go-resty/resty/v2 v2.3.0 h1:JOOeAvjSlapTT92p8xiS19Zxev1neGikoHsXJeOq8So= github.com/go-resty/resty/v2 v2.4.0 h1:s6TItTLejEI+2mn98oijC5w/Rk2YU+OA6x0mnZN6r6k=
github.com/go-resty/resty/v2 v2.3.0/go.mod h1:UpN9CgLZNsv4e9XG50UU8xdI0F43UQ4HmxLBDwaroHU= github.com/go-resty/resty/v2 v2.4.0/go.mod h1:B88+xCTEwvfD94NOuE6GS1wMlnoKNY8eEiNizfNwOwA=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
@ -195,8 +195,6 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.1.30 h1:Qww6FseFn8PRfw07jueqIXqodm0JKiiKuK0DeXSqfyo=
github.com/miekg/dns v1.1.30/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -310,12 +308,10 @@ golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw=
golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU= golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
@ -330,14 +326,15 @@ golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec h1:A1qYjneJuzBZZ2gIB8rd6zrfq6l7SoEMJ8EsSilNK/U= golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec h1:A1qYjneJuzBZZ2gIB8rd6zrfq6l7SoEMJ8EsSilNK/U=
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -348,7 +345,6 @@ golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3
golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69 h1:yBHHx+XZqXJBm6Exke3N7V9gnlsyXxoCPEb1yVenjfk= golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69 h1:yBHHx+XZqXJBm6Exke3N7V9gnlsyXxoCPEb1yVenjfk=
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -23,7 +23,7 @@
// - persistent settings // - persistent settings
// - event listener // - event listener
// - credentials store // - credentials store
// - pmapi ClientManager // - pmapi Manager
// In addition, the base initialises logging and reacts to command line arguments // In addition, the base initialises logging and reacts to command line arguments
// which control the log verbosity and enable cpu/memory profiling. // which control the log verbosity and enable cpu/memory profiling.
package base package base
@ -85,7 +85,7 @@ type Base struct {
Cache *cache.Cache Cache *cache.Cache
Listener listener.Listener Listener listener.Listener
Creds *credentials.Store Creds *credentials.Store
CM *pmapi.ClientManager CM pmapi.Manager
CookieJar *cookies.Jar CookieJar *cookies.Jar
UserAgent *useragent.UserAgent UserAgent *useragent.UserAgent
Updater *updater.Updater Updater *updater.Updater
@ -181,13 +181,26 @@ func New( // nolint[funlen]
kc = keychain.NewMissingKeychain() kc = keychain.NewMissingKeychain()
} }
// FIXME(conman): Customize config depending on build type (app version, host URL).
cm := pmapi.New(pmapi.DefaultConfig)
// FIXME(conman): Should this be a real object, not just created via callbacks?
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
func() { listener.Emit(events.InternetOffEvent, "") },
func() { listener.Emit(events.InternetOnEvent, "") },
))
// FIXME(conman): Implement force upgrade observer.
// apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
// FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc.
// cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
jar, err := cookies.NewCookieJar(settingsObj) jar, err := cookies.NewCookieJar(settingsObj)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cm := pmapi.NewClientManager(getAPIConfig(configName, listener), userAgent)
cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
cm.SetCookieJar(jar) cm.SetCookieJar(jar)
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey) key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
@ -375,13 +388,3 @@ func (b *Base) doTeardown() error {
return nil return nil
} }
func getAPIConfig(configName string, listener listener.Listener) *pmapi.ClientConfig {
apiConfig := pmapi.GetAPIConfig(configName, constants.Version)
apiConfig.ConnectionOffHandler = func() { listener.Emit(events.InternetOffEvent, "") }
apiConfig.ConnectionOnHandler = func() { listener.Emit(events.InternetOnEvent, "") }
apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
return apiConfig
}

View File

@ -19,6 +19,7 @@
package bridge package bridge
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@ -44,7 +45,7 @@ type Bridge struct {
locations Locator locations Locator
settings SettingsProvider settings SettingsProvider
clientManager users.ClientManager clientManager pmapi.Manager
updater Updater updater Updater
versioner Versioner versioner Versioner
} }
@ -56,7 +57,7 @@ func New(
sentryReporter *sentry.Reporter, sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler, panicHandler users.PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
clientManager users.ClientManager, clientManager pmapi.Manager,
credStorer users.CredentialsStorer, credStorer users.CredentialsStorer,
updater Updater, updater Updater,
versioner Versioner, versioner Versioner,
@ -64,10 +65,11 @@ func New(
// Allow DoH before starting the app if the user has previously set this setting. // Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked. // This allows us to start even if protonmail is blocked.
if s.GetBool(settings.AllowProxyKey) { if s.GetBool(settings.AllowProxyKey) {
clientManager.AllowProxy() // FIXME(conman): Support enable/disable of DoH.
// clientManager.AllowProxy()
} }
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, clientManager, eventListener) storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, storeFactory, true) u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, storeFactory, true)
b := &Bridge{ b := &Bridge{
Users: u, Users: u,
@ -118,28 +120,15 @@ func (b *Bridge) heartbeat() {
// ReportBug reports a new bug from the user. // ReportBug reports a new bug from the user.
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error { func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
c := b.clientManager.GetAnonymousClient() return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
defer c.Logout()
title := "[Bridge] Bug"
report := pmapi.ReportReq{
OS: osType, OS: osType,
OSVersion: osVersion, OSVersion: osVersion,
Browser: emailClient, Browser: emailClient,
Title: title, Title: "[Bridge] Bug",
Description: description, Description: description,
Username: accountName, Username: accountName,
Email: address, Email: address,
} })
if err := c.Report(report); err != nil {
log.Error("Reporting bug failed: ", err)
return err
}
log.Info("Bug successfully reported")
return nil
} }
// GetUpdateChannel returns currently set update channel. // GetUpdateChannel returns currently set update channel.

View File

@ -31,7 +31,6 @@ type storeFactory struct {
cache Cacher cache Cacher
sentryReporter *sentry.Reporter sentryReporter *sentry.Reporter
panicHandler users.PanicHandler panicHandler users.PanicHandler
clientManager users.ClientManager
eventListener listener.Listener eventListener listener.Listener
storeCache *store.Cache storeCache *store.Cache
} }
@ -40,14 +39,12 @@ func newStoreFactory(
cache Cacher, cache Cacher,
sentryReporter *sentry.Reporter, sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler, panicHandler users.PanicHandler,
clientManager users.ClientManager,
eventListener listener.Listener, eventListener listener.Listener,
) *storeFactory { ) *storeFactory {
return &storeFactory{ return &storeFactory{
cache: cache, cache: cache,
sentryReporter: sentryReporter, sentryReporter: sentryReporter,
panicHandler: panicHandler, panicHandler: panicHandler,
clientManager: clientManager,
eventListener: eventListener, eventListener: eventListener,
storeCache: store.NewCache(cache.GetIMAPCachePath()), storeCache: store.NewCache(cache.GetIMAPCachePath()),
} }
@ -56,7 +53,7 @@ func newStoreFactory(
// New creates new store for given user. // New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) { func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
storePath := getUserStorePath(f.cache.GetDBDir(), user.ID()) storePath := getUserStorePath(f.cache.GetDBDir(), user.ID())
return store.New(f.sentryReporter, f.panicHandler, user, f.clientManager, f.eventListener, storePath, f.storeCache) return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache)
} }
// Remove removes all store files for given user. // Remove removes all store files for given user.

View File

@ -18,8 +18,10 @@
package cliie package cliie
import ( import (
"context"
"strings" "strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/abiosoft/ishell" "github.com/abiosoft/ishell"
) )
@ -73,13 +75,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return return
} }
if auth.HasTwoFactor() { if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty) twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
if twoFactor == "" { if twoFactor == "" {
return return
} }
err = client.Auth2FA(twoFactor, auth) err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
if err != nil { if err != nil {
f.processAPIError(err) f.processAPIError(err)
return return
@ -87,7 +89,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
} }
mailboxPassword := password mailboxPassword := password
if auth.HasMailboxPassword() { if auth.PasswordMode == pmapi.TwoPasswordMode {
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty) mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
} }
if mailboxPassword == "" { if mailboxPassword == "" {

View File

@ -20,7 +20,6 @@ package cliie
import ( import (
"strings" "strings"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/fatih/color" "github.com/fatih/color"
) )
@ -71,10 +70,13 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) { func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err) log.Warn("API error: ", err)
switch err { switch err {
case pmapi.ErrAPINotReachable: // FIXME(conman): How to handle various API errors?
/*
case pmapi.ErrNoConnection:
f.notifyInternetOff() f.notifyInternetOff()
case pmapi.ErrUpgradeApplication: case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade() f.notifyNeedUpgrade()
*/
default: default:
f.Println("Server error:", err.Error()) f.Println("Server error:", err.Error())
} }

View File

@ -18,11 +18,13 @@
package cli package cli
import ( import (
"context"
"strings" "strings"
"github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/frontend/types" "github.com/ProtonMail/proton-bridge/internal/frontend/types"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/abiosoft/ishell" "github.com/abiosoft/ishell"
) )
@ -120,13 +122,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return return
} }
if auth.HasTwoFactor() { if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty) twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
if twoFactor == "" { if twoFactor == "" {
return return
} }
err = client.Auth2FA(twoFactor, auth) err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
if err != nil { if err != nil {
f.processAPIError(err) f.processAPIError(err)
return return
@ -134,7 +136,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
} }
mailboxPassword := password mailboxPassword := password
if auth.HasMailboxPassword() { if auth.PasswordMode == pmapi.TwoPasswordMode {
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty) mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
} }
if mailboxPassword == "" { if mailboxPassword == "" {

View File

@ -20,7 +20,6 @@ package cli
import ( import (
"strings" "strings"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/fatih/color" "github.com/fatih/color"
) )
@ -71,10 +70,13 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) { func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err) log.Warn("API error: ", err)
switch err { switch err {
case pmapi.ErrAPINotReachable: // FIXME(conman): How to handle various API errors?
/*
case pmapi.ErrNoConnection:
f.notifyInternetOff() f.notifyInternetOff()
case pmapi.ErrUpgradeApplication: case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade() f.notifyNeedUpgrade()
*/
default: default:
f.Println("Server error:", err.Error()) f.Println("Server error:", err.Error())
} }

View File

@ -164,7 +164,7 @@ func (a *Accounts) showLoginError(err error, scope string) bool {
return false return false
} }
log.Warnf("%s: %v", scope, err) log.Warnf("%s: %v", scope, err)
if err == pmapi.ErrAPINotReachable { if err == pmapi.ErrNoConnection {
a.qml.SetConnectionStatus(false) a.qml.SetConnectionStatus(false)
SendNotification(a.qml, TabAccount, a.qml.CanNotReachAPI()) SendNotification(a.qml, TabAccount, a.qml.CanNotReachAPI())
a.qml.ProcessFinished() a.qml.ProcessFinished()

View File

@ -130,7 +130,7 @@ func (s *FrontendQt) showLoginError(err error, scope string) bool {
return false return false
} }
log.Warnf("%s: %v", scope, err) log.Warnf("%s: %v", scope, err)
if err == pmapi.ErrAPINotReachable { if err == pmapi.ErrNoConnection {
s.Qml.SetConnectionStatus(false) s.Qml.SetConnectionStatus(false)
s.SendNotification(TabAccount, s.Qml.CanNotReachAPI()) s.SendNotification(TabAccount, s.Qml.CanNotReachAPI())
s.Qml.ProcessFinished() s.Qml.ProcessFinished()

View File

@ -20,6 +20,7 @@ package importexport
import ( import (
"bytes" "bytes"
"context"
"github.com/ProtonMail/proton-bridge/internal/transfer" "github.com/ProtonMail/proton-bridge/internal/transfer"
"github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/internal/users"
@ -39,7 +40,7 @@ type ImportExport struct {
locations Locator locations Locator
cache Cacher cache Cacher
panicHandler users.PanicHandler panicHandler users.PanicHandler
clientManager users.ClientManager clientManager pmapi.Manager
} }
func New( func New(
@ -47,7 +48,7 @@ func New(
cache Cacher, cache Cacher,
panicHandler users.PanicHandler, panicHandler users.PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
clientManager users.ClientManager, clientManager pmapi.Manager,
credStorer users.CredentialsStorer, credStorer users.CredentialsStorer,
) *ImportExport { ) *ImportExport {
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, &storeFactory{}, false) u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, &storeFactory{}, false)
@ -64,57 +65,31 @@ func New(
// ReportBug reports a new bug from the user. // ReportBug reports a new bug from the user.
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error { func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
c := ie.clientManager.GetAnonymousClient() return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
defer c.Logout()
title := "[Import-Export] Bug"
report := pmapi.ReportReq{
OS: osType, OS: osType,
OSVersion: osVersion, OSVersion: osVersion,
Browser: emailClient, Browser: emailClient,
Title: title, Title: "[Import-Export] Bug",
Description: description, Description: description,
Username: accountName, Username: accountName,
Email: address, Email: address,
} })
if err := c.Report(report); err != nil {
log.Error("Reporting bug failed: ", err)
return err
}
log.Info("Bug successfully reported")
return nil
} }
// ReportFile submits import report file. // ReportFile submits import report file.
func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address string, logdata []byte) error { func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address string, logdata []byte) error {
c := ie.clientManager.GetAnonymousClient() report := pmapi.ReportBugReq{
defer c.Logout()
title := "[Import-Export] report file"
description := "An Import-Export report from the user swam down the river."
report := pmapi.ReportReq{
OS: osType, OS: osType,
OSVersion: osVersion, OSVersion: osVersion,
Description: description, Description: "An Import-Export report from the user swam down the river.",
Title: title, Title: "[Import-Export] report file",
Username: accountName, Username: accountName,
Email: address, Email: address,
} }
report.AddAttachment("log", "report.log", bytes.NewReader(logdata)) report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
if err := c.Report(report); err != nil { return ie.clientManager.ReportBug(context.TODO(), report)
log.Error("Sending report failed: ", err)
return err
}
log.Info("Report successfully sent")
return nil
} }
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account. // GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
@ -187,5 +162,5 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
log.WithError(err).Info("Address does not exist, using all addresses") log.WithError(err).Info("Address does not exist, using all addresses")
} }
return transfer.NewPMAPIProvider(ie.clientManager, user.ID(), addressID) return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
} }

View File

@ -18,6 +18,7 @@
package smtp package smtp
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"strings" "strings"
@ -28,7 +29,7 @@ import (
) )
type messageGetter interface { type messageGetter interface {
GetMessage(string) (*pmapi.Message, error) GetMessage(context.Context, string) (*pmapi.Message, error)
} }
type sendRecorderValue struct { type sendRecorderValue struct {
@ -126,7 +127,7 @@ func (q *sendRecorder) isSendingOrSent(client messageGetter, hash string) (isSen
return true, false return true, false
} }
message, err := client.GetMessage(value.messageID) message, err := client.GetMessage(context.TODO(), value.messageID)
// Message could be deleted or there could be an internet issue or whatever, // Message could be deleted or there could be an internet issue or whatever,
// so let's assume the message was not sent. // so let's assume the message was not sent.
if err != nil { if err != nil {

View File

@ -18,6 +18,7 @@
package smtp package smtp
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/mail" "net/mail"
@ -33,7 +34,7 @@ type testSendRecorderGetMessageMock struct {
err error err error
} }
func (m *testSendRecorderGetMessageMock) GetMessage(messageID string) (*pmapi.Message, error) { func (m *testSendRecorderGetMessageMock) GetMessage(_ context.Context, messageID string) (*pmapi.Message, error) {
return m.message, m.err return m.message, m.err
} }

View File

@ -21,6 +21,7 @@ package smtp
import ( import (
"bytes" "bytes"
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
@ -122,7 +123,7 @@ func (su *smtpUser) getSendPreferences(
} }
func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata, err error) { func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata, err error) {
emails, err := su.client().GetContactEmailByEmail(recipient, 0, 1000) emails, err := su.client().GetContactEmailByEmail(context.TODO(), recipient, 0, 1000)
if err != nil { if err != nil {
return return
} }
@ -134,7 +135,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
} }
var contact pmapi.Contact var contact pmapi.Contact
if contact, err = su.client().GetContactByID(email.ContactID); err != nil { if contact, err = su.client().GetContactByID(context.TODO(), email.ContactID); err != nil {
return return
} }
@ -150,7 +151,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
} }
func (su *smtpUser) getAPIKeyData(recipient string) (apiKeys []pmapi.PublicKey, isInternal bool, err error) { func (su *smtpUser) getAPIKeyData(recipient string) (apiKeys []pmapi.PublicKey, isInternal bool, err error) {
return su.client().GetPublicKeysForEmail(recipient) return su.client().GetPublicKeysForEmail(context.TODO(), recipient)
} }
// Discard currently processed message. // Discard currently processed message.
@ -218,7 +219,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
messageReader = io.TeeReader(messageReader, b) messageReader = io.TeeReader(messageReader, b)
mailSettings, err := su.client().GetMailSettings() mailSettings, err := su.client().GetMailSettings(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -339,7 +340,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
// can lead to sending the wrong message. Also clients do not necessarily // can lead to sending the wrong message. Also clients do not necessarily
// delete the old draft. // delete the old draft.
if draftID != "" { if draftID != "" {
if err := su.client().DeleteMessages([]string{draftID}); err != nil { if err := su.client().DeleteMessages(context.TODO(), []string{draftID}); err != nil {
log.WithError(err).WithField("draftID", draftID).Warn("Original draft cannot be deleted") log.WithError(err).WithField("draftID", draftID).Warn("Original draft cannot be deleted")
} }
} }
@ -393,7 +394,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
return errors.New("error decoding subject message " + message.Header.Get("Subject")) return errors.New("error decoding subject message " + message.Header.Get("Subject"))
} }
if !su.continueSendingUnencryptedMail(subject) { if !su.continueSendingUnencryptedMail(subject) {
if err := su.client().DeleteMessages([]string{message.ID}); err != nil { if err := su.client().DeleteMessages(context.TODO(), []string{message.ID}); err != nil {
log.WithError(err).Warn("Failed to delete canceled messages") log.WithError(err).Warn("Failed to delete canceled messages")
} }
return errors.New("sending was canceled by user") return errors.New("sending was canceled by user")
@ -422,7 +423,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
if su.addressID != "" { if su.addressID != "" {
filter.AddressID = su.addressID filter.AddressID = su.addressID
} }
metadata, _, _ := su.client().ListMessages(filter) metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
for _, m := range metadata { for _, m := range metadata {
if m.IsDraft() { if m.IsDraft() {
draftID = m.ID draftID = m.ID
@ -442,7 +443,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
if su.addressID != "" { if su.addressID != "" {
filter.AddressID = su.addressID filter.AddressID = su.addressID
} }
metadata, _, _ := su.client().ListMessages(filter) metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
// There can be two or messages with the same external ID and then we cannot // There can be two or messages with the same external ID and then we cannot
// be sure which message should be parent. Better to not choose any. // be sure which message should be parent. Better to not choose any.
if len(metadata) == 1 { if len(metadata) == 1 {

View File

@ -18,6 +18,7 @@
package store package store
import ( import (
"context"
"math/rand" "math/rand"
"time" "time"
@ -80,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
func (loop *eventLoop) setFirstEventID() (err error) { func (loop *eventLoop) setFirstEventID() (err error) {
loop.log.Info("Setting first event ID") loop.log.Info("Setting first event ID")
event, err := loop.client().GetEvent("") event, err := loop.client().GetEvent(context.TODO(), "")
if err != nil { if err != nil {
loop.log.WithError(err).Error("Could not get latest event ID") loop.log.WithError(err).Error("Could not get latest event ID")
return return
@ -221,7 +222,8 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually // We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
// (e.g. no internet, ulimit reached etc.) // (e.g. no internet, ulimit reached etc.)
defer func() { defer func() {
if errors.Cause(err) == pmapi.ErrAPINotReachable { // FIXME(conman): How to handle errors of different types?
if errors.Is(err, pmapi.ErrNoConnection) {
l.Warn("Internet unavailable") l.Warn("Internet unavailable")
err = nil err = nil
} }
@ -232,18 +234,20 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
err = nil err = nil
} }
// FIXME(conman): Handle force upgrade.
/*
if errors.Cause(err) == pmapi.ErrUpgradeApplication { if errors.Cause(err) == pmapi.ErrUpgradeApplication {
l.Warn("Need to upgrade application") l.Warn("Need to upgrade application")
err = nil err = nil
} }
*/
_, errUnauthorized := errors.Cause(err).(*pmapi.ErrUnauthorized)
if err == nil { if err == nil {
loop.errCounter = 0 loop.errCounter = 0
} }
// All errors except Invalid Token (which is not possible to recover from) are ignored.
if err != nil && !errUnauthorized && errors.Cause(err) != pmapi.ErrInvalidToken { // All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
if !errors.Is(err, pmapi.ErrUnauthorized) {
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped") l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
loop.errCounter++ loop.errCounter++
if loop.errCounter == errMaxSentry { if loop.errCounter == errMaxSentry {
@ -264,7 +268,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
loop.pollCounter++ loop.pollCounter++
var event *pmapi.Event var event *pmapi.Event
if event, err = loop.client().GetEvent(loop.currentEventID); err != nil { if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil {
return false, errors.Wrap(err, "failed to get event") return false, errors.Wrap(err, "failed to get event")
} }
@ -461,12 +465,16 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...") msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
if msg, err = loop.client().GetMessage(message.ID); err != nil { if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil {
// FIXME(conman): How to handle error of this particular type?
/*
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok { if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API") msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil err = nil
continue continue
} }
*/
return errors.Wrap(err, "failed to get message from API for updating") return errors.Wrap(err, "failed to get message from API for updating")
} }

View File

@ -18,6 +18,7 @@
package store package store
import ( import (
"context"
"net/mail" "net/mail"
"testing" "testing"
"time" "time"
@ -39,15 +40,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
// Doesn't matter which IDs are used. // Doesn't matter which IDs are used.
// This test is trying to see whether event loop will immediately process // This test is trying to see whether event loop will immediately process
// next event if there is `More` of them. // next event if there is `More` of them.
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{ m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
EventID: "event50", EventID: "event50",
More: 1, More: 1,
}, nil), }, nil),
m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{ m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
EventID: "event70", EventID: "event70",
More: 0, More: 0,
}, nil), }, nil),
m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{ m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
EventID: "event71", EventID: "event71",
More: 0, More: 0,
}, nil), }, nil),
@ -165,7 +166,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
func testEvent(t *testing.T, m *mocksForStore, event *pmapi.Event) { func testEvent(t *testing.T, m *mocksForStore, event *pmapi.Event) {
eventReceived := make(chan struct{}) eventReceived := make(chan struct{})
m.client.EXPECT().GetEvent("latestEventID").DoAndReturn(func(eventID string) (*pmapi.Event, error) { m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").DoAndReturn(func(_ context.Context, eventID string) (*pmapi.Event, error) {
defer close(eventReceived) defer close(eventReceived)
return event, nil return event, nil
}) })

View File

@ -18,6 +18,8 @@
package store package store
import ( import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -41,7 +43,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message // FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
// wrapping it. // wrapping it.
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) { func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.client().GetMessage(apiID) msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -58,15 +60,17 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
} }
importReqs := &pmapi.ImportMsgReq{ importReqs := &pmapi.ImportMsgReq{
Metadata: &pmapi.ImportMetadata{
AddressID: msg.AddressID, AddressID: msg.AddressID,
Body: body,
Unread: msg.Unread, Unread: msg.Unread,
Flags: msg.Flags, Flags: msg.Flags,
Time: msg.Time, Time: msg.Time,
LabelIDs: labelIDs, LabelIDs: labelIDs,
},
Message: body,
} }
res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs}) res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs})
if err != nil { if err != nil {
return err return err
} }
@ -95,7 +99,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed return ErrAllMailOpNotAllowed
} }
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID) return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
} }
// UnlabelMessages removes the label by calling an API. // UnlabelMessages removes the label by calling an API.
@ -108,7 +112,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed return ErrAllMailOpNotAllowed
} }
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID) return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
} }
// MarkMessagesRead marks the message read by calling an API. // MarkMessagesRead marks the message read by calling an API.
@ -135,7 +139,7 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
if len(ids) == 0 { if len(ids) == 0 {
return nil return nil
} }
return storeMailbox.client().MarkMessagesRead(ids) return storeMailbox.client().MarkMessagesRead(context.TODO(), ids)
} }
// MarkMessagesUnread marks the message unread by calling an API. // MarkMessagesUnread marks the message unread by calling an API.
@ -147,7 +151,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as unread") }).Trace("Marking messages as unread")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().MarkMessagesUnread(apiIDs) return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs)
} }
// MarkMessagesStarred adds the Starred label by calling an API. // MarkMessagesStarred adds the Starred label by calling an API.
@ -160,7 +164,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as starred") }).Trace("Marking messages as starred")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel) return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
} }
// MarkMessagesUnstarred removes the Starred label by calling an API. // MarkMessagesUnstarred removes the Starred label by calling an API.
@ -173,7 +177,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as unstarred") }).Trace("Marking messages as unstarred")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel) return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
} }
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API // MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
@ -257,11 +261,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
} }
case pmapi.DraftLabel: case pmapi.DraftLabel:
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts") storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil { if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil {
return err return err
} }
default: default:
if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil { if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil {
return err return err
} }
} }
@ -299,13 +303,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
} }
} }
if len(messageIDsToUnlabel) > 0 { if len(messageIDsToUnlabel) > 0 {
if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil { if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
l.WithError(err).Warning("Cannot unlabel before deleting") l.WithError(err).Warning("Cannot unlabel before deleting")
} }
} }
if len(messageIDsToDelete) > 0 { if len(messageIDsToDelete) > 0 {
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages") storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil { if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil {
return err return err
} }
} }

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser,ChangeNotifier) // Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
@ -46,43 +46,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
} }
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockBridgeUser is a mock of BridgeUser interface // MockBridgeUser is a mock of BridgeUser interface
type MockBridgeUser struct { type MockBridgeUser struct {
ctrl *gomock.Controller ctrl *gomock.Controller
@ -145,6 +108,20 @@ func (mr *MockBridgeUserMockRecorder) GetAddressID(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddressID", reflect.TypeOf((*MockBridgeUser)(nil).GetAddressID), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddressID", reflect.TypeOf((*MockBridgeUser)(nil).GetAddressID), arg0)
} }
// GetClient mocks base method
func (m *MockBridgeUser) GetClient() pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient")
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockBridgeUserMockRecorder) GetClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockBridgeUser)(nil).GetClient))
}
// GetPrimaryAddress mocks base method // GetPrimaryAddress mocks base method
func (m *MockBridgeUser) GetPrimaryAddress() string { func (m *MockBridgeUser) GetPrimaryAddress() string {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -19,6 +19,7 @@
package store package store
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@ -106,7 +107,6 @@ type Store struct {
panicHandler PanicHandler panicHandler PanicHandler
eventLoop *eventLoop eventLoop *eventLoop
user BridgeUser user BridgeUser
clientManager ClientManager
log *logrus.Entry log *logrus.Entry
@ -127,13 +127,12 @@ func New( // nolint[funlen]
sentryReporter *sentry.Reporter, sentryReporter *sentry.Reporter,
panicHandler PanicHandler, panicHandler PanicHandler,
user BridgeUser, user BridgeUser,
clientManager ClientManager,
events listener.Listener, events listener.Listener,
path string, path string,
cache *Cache, cache *Cache,
) (store *Store, err error) { ) (store *Store, err error) {
if user == nil || clientManager == nil || events == nil || cache == nil { if user == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache) return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache)
} }
l := log.WithField("user", user.ID()) l := log.WithField("user", user.ID())
@ -156,7 +155,6 @@ func New( // nolint[funlen]
store = &Store{ store = &Store{
sentryReporter: sentryReporter, sentryReporter: sentryReporter,
panicHandler: panicHandler, panicHandler: panicHandler,
clientManager: clientManager,
user: user, user: user,
cache: cache, cache: cache,
filePath: path, filePath: path,
@ -274,13 +272,13 @@ func (store *Store) init(firstInit bool) (err error) {
} }
func (store *Store) client() pmapi.Client { func (store *Store) client() pmapi.Client {
return store.clientManager.GetClient(store.UserID()) return store.user.GetClient()
} }
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if // initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
// the API is unavailable for whatever reason it tries to fetch the labels locally. // the API is unavailable for whatever reason it tries to fetch the labels locally.
func (store *Store) initCounts() (labels []*pmapi.Label, err error) { func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
if labels, err = store.client().ListLabels(); err != nil { if labels, err = store.client().ListLabels(context.TODO()); err != nil {
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.") store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
if labels, err = store.getLabelsFromLocalStorage(); err != nil { if labels, err = store.getLabelsFromLocalStorage(); err != nil {
store.log.WithError(err).Error("Cannot list local labels") store.log.WithError(err).Error("Cannot list local labels")

View File

@ -18,6 +18,7 @@
package store package store
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -133,7 +134,6 @@ type mocksForStore struct {
events *storemocks.MockListener events *storemocks.MockListener
user *storemocks.MockBridgeUser user *storemocks.MockBridgeUser
client *pmapimocks.MockClient client *pmapimocks.MockClient
clientManager *storemocks.MockClientManager
panicHandler *storemocks.MockPanicHandler panicHandler *storemocks.MockPanicHandler
changeNotifier *storemocks.MockChangeNotifier changeNotifier *storemocks.MockChangeNotifier
store *Store store *Store
@ -150,7 +150,6 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
events: storemocks.NewMockListener(ctrl), events: storemocks.NewMockListener(ctrl),
user: storemocks.NewMockBridgeUser(ctrl), user: storemocks.NewMockBridgeUser(ctrl),
client: pmapimocks.NewMockClient(ctrl), client: pmapimocks.NewMockClient(ctrl),
clientManager: storemocks.NewMockClientManager(ctrl),
panicHandler: storemocks.NewMockPanicHandler(ctrl), panicHandler: storemocks.NewMockPanicHandler(ctrl),
changeNotifier: storemocks.NewMockChangeNotifier(ctrl), changeNotifier: storemocks.NewMockChangeNotifier(ctrl),
} }
@ -182,30 +181,30 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.user.EXPECT().IsConnected().Return(true) mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode) mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client) mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{ mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive}, {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive}, {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
}) })
mocks.client.EXPECT().ListLabels().AnyTimes() mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
mocks.client.EXPECT().CountMessages("") mocks.client.EXPECT().CountMessages(gomock.Any(), "")
// Call to get latest event ID and then to process first event. // Call to get latest event ID and then to process first event.
eventAfterSyncRequested := make(chan struct{}) eventAfterSyncRequested := make(chan struct{})
mocks.client.EXPECT().GetEvent("").Return(&pmapi.Event{ mocks.client.EXPECT().GetEvent(gomock.Any(), "").Return(&pmapi.Event{
EventID: "firstEventID", EventID: "firstEventID",
}, nil) }, nil)
mocks.client.EXPECT().GetEvent("firstEventID").DoAndReturn(func(_ string) (*pmapi.Event, error) { mocks.client.EXPECT().GetEvent(gomock.Any(), "firstEventID").DoAndReturn(func(_ context.Context, _ string) (*pmapi.Event, error) {
close(eventAfterSyncRequested) close(eventAfterSyncRequested)
return &pmapi.Event{ return &pmapi.Event{
EventID: "latestEventID", EventID: "latestEventID",
}, nil }, nil
}) })
mocks.client.EXPECT().ListMessages(gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes() mocks.client.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes()
for _, msg := range msgs { for _, msg := range msgs {
mocks.client.EXPECT().GetMessage(msg.ID).Return(msg, nil).AnyTimes() mocks.client.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
} }
var err error var err error
@ -213,7 +212,6 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
nil, // Sentry reporter is not used under unit tests. nil, // Sentry reporter is not used under unit tests.
mocks.panicHandler, mocks.panicHandler,
mocks.user, mocks.user,
mocks.clientManager,
mocks.events, mocks.events,
filepath.Join(mocks.tmpDir, "mailbox-test.db"), filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache, mocks.cache,

View File

@ -18,6 +18,7 @@
package store package store
import ( import (
"context"
"math" "math"
"sync" "sync"
@ -39,10 +40,10 @@ type storeSynchronizer interface {
} }
type messageLister interface { type messageLister interface {
ListMessages(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) ListMessages(context.Context, *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
} }
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func() messageLister, syncState *syncState) error { func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api messageLister, syncState *syncState) error {
labelID := pmapi.AllMailLabel labelID := pmapi.AllMailLabel
// When the full sync starts (i.e. is not already in progress), we need to load // When the full sync starts (i.e. is not already in progress), we need to load
@ -53,7 +54,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
return errors.Wrap(err, "failed to load message IDs") return errors.Wrap(err, "failed to load message IDs")
} }
if err := findIDRanges(labelID, api(), syncState); err != nil { if err := findIDRanges(labelID, api, syncState); err != nil {
return errors.Wrap(err, "failed to load IDs ranges") return errors.Wrap(err, "failed to load IDs ranges")
} }
syncState.save() syncState.save()
@ -71,7 +72,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
defer panicHandler.HandlePanic() defer panicHandler.HandlePanic()
defer wg.Done() defer wg.Done()
err := syncBatch(labelID, store, api(), syncState, idRange, &shouldStop) err := syncBatch(labelID, store, api, syncState, idRange, &shouldStop)
if err != nil { if err != nil {
shouldStop = 1 shouldStop = 1
resultError = errors.Wrap(err, "failed to sync group") resultError = errors.Wrap(err, "failed to sync group")
@ -147,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
Limit: 1, Limit: 1,
} }
// If the page does not exist, an empty page instead of an error is returned. // If the page does not exist, an empty page instead of an error is returned.
messages, total, err := api.ListMessages(filter) messages, total, err := api.ListMessages(context.TODO(), filter)
if err != nil { if err != nil {
return "", 0, errors.Wrap(err, "failed to list messages") return "", 0, errors.Wrap(err, "failed to list messages")
} }
@ -189,7 +190,7 @@ func syncBatch( //nolint[funlen]
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page") log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
messages, _, err := api.ListMessages(filter) messages, _, err := api.ListMessages(context.TODO(), filter)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to list messages") return errors.Wrap(err, "failed to list messages")
} }

View File

@ -18,6 +18,7 @@
package store package store
import ( import (
"context"
"sort" "sort"
"strconv" "strconv"
"sync" "sync"
@ -34,7 +35,7 @@ type mockLister struct {
messageIDs []string messageIDs []string
} }
func (m *mockLister) ListMessages(filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) { func (m *mockLister) ListMessages(_ context.Context, filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
if m.err != nil { if m.err != nil {
return nil, 0, m.err return nil, 0, m.err
} }
@ -197,7 +198,7 @@ func TestSyncAllMail(t *testing.T) { //nolint[funlen]
syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted) syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted)
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState) err := syncAllMail(m.panicHandler, store, api, syncState)
require.Nil(t, err) require.Nil(t, err)
// Check all messages were created or updated. // Check all messages were created or updated.
@ -245,7 +246,7 @@ func TestSyncAllMail_FailedListing(t *testing.T) {
} }
syncState := newTestSyncState(store) syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState) err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to list messages: error") require.EqualError(t, err, "failed to sync group: failed to list messages: error")
} }
@ -264,7 +265,7 @@ func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) {
} }
syncState := newTestSyncState(store) syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState) err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to create or update messages: error") require.EqualError(t, err, "failed to sync group: failed to create or update messages: error")
} }

View File

@ -23,10 +23,6 @@ type PanicHandler interface {
HandlePanic() HandlePanic()
} }
type ClientManager interface {
GetClient(userID string) pmapi.Client
}
// BridgeUser is subset of bridge.User for use by the Store. // BridgeUser is subset of bridge.User for use by the Store.
type BridgeUser interface { type BridgeUser interface {
ID() string ID() string
@ -35,6 +31,7 @@ type BridgeUser interface {
IsCombinedAddressMode() bool IsCombinedAddressMode() bool
GetPrimaryAddress() string GetPrimaryAddress() string
GetStoreAddresses() []string GetStoreAddresses() []string
GetClient() pmapi.Client
UpdateUser() error UpdateUser() error
CloseAllConnections() CloseAllConnections()
CloseConnection(string) CloseConnection(string)

View File

@ -17,6 +17,8 @@
package store package store
import "context"
// UserID returns user ID. // UserID returns user ID.
func (store *Store) UserID() string { func (store *Store) UserID() string {
return store.user.ID() return store.user.ID()
@ -24,7 +26,7 @@ func (store *Store) UserID() string {
// GetSpace returns used and total space in bytes. // GetSpace returns used and total space in bytes.
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) { func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
apiUser, err := store.client().CurrentUser() apiUser, err := store.client().CurrentUser(context.TODO())
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
@ -33,7 +35,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
// GetMaxUpload returns max size of message + all attachments in bytes. // GetMaxUpload returns max size of message + all attachments in bytes.
func (store *Store) GetMaxUpload() (int64, error) { func (store *Store) GetMaxUpload() (int64, error) {
apiUser, err := store.client().CurrentUser() apiUser, err := store.client().CurrentUser(context.TODO())
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -18,6 +18,7 @@
package store package store
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -55,7 +56,7 @@ func (store *Store) createMailbox(name string) error {
return nil return nil
} }
_, err := store.client().CreateLabel(&pmapi.Label{ _, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{
Name: name, Name: name,
Color: color, Color: color,
Exclusive: exclusive, Exclusive: exclusive,
@ -125,7 +126,7 @@ func (store *Store) leastUsedColor() string {
func (store *Store) updateMailbox(labelID, newName, color string) error { func (store *Store) updateMailbox(labelID, newName, color string) error {
defer store.eventLoop.pollNow() defer store.eventLoop.pollNow()
_, err := store.client().UpdateLabel(&pmapi.Label{ _, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{
ID: labelID, ID: labelID,
Name: newName, Name: newName,
Color: color, Color: color,
@ -142,15 +143,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
var err error var err error
switch labelID { switch labelID {
case pmapi.SpamLabel: case pmapi.SpamLabel:
err = store.client().EmptyFolder(pmapi.SpamLabel, addressID) err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID)
case pmapi.TrashLabel: case pmapi.TrashLabel:
err = store.client().EmptyFolder(pmapi.TrashLabel, addressID) err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID)
default: default:
err = fmt.Errorf("cannot empty mailbox %v", labelID) err = fmt.Errorf("cannot empty mailbox %v", labelID)
} }
return err return err
} }
return store.client().DeleteLabel(labelID) return store.client().DeleteLabel(context.TODO(), labelID)
} }
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error { func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
@ -165,7 +166,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
return nil return nil
} }
labels, err := store.client().ListLabels() labels, err := store.client().ListLabels(context.TODO())
if err != nil { if err != nil {
return err return err
} }

View File

@ -19,6 +19,7 @@ package store
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil" "io/ioutil"
@ -57,7 +58,7 @@ func (store *Store) CreateDraft(
} }
draftAction := store.getDraftAction(message) draftAction := store.getDraftAction(message)
draft, err := store.client().CreateDraft(message, parentID, draftAction) draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "failed to create draft") return nil, nil, errors.Wrap(err, "failed to create draft")
} }
@ -69,7 +70,7 @@ func (store *Store) CreateDraft(
for _, att := range attachments { for _, att := range attachments {
att.attachment.MessageID = draft.ID att.attachment.MessageID = draft.ID
createdAttachment, err := store.client().CreateAttachment(att.attachment, att.encReader, att.sigReader) createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "failed to create attachment") return nil, nil, errors.Wrap(err, "failed to create attachment")
} }
@ -183,7 +184,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
// SendMessage sends the message. // SendMessage sends the message.
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error { func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
defer store.eventLoop.pollNow() defer store.eventLoop.pollNow()
_, _, err := store.client().SendMessage(messageID, req) _, _, err := store.client().SendMessage(context.TODO(), messageID, req)
return err return err
} }

View File

@ -127,12 +127,12 @@ func TestDeleteMessage(t *testing.T) {
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}}) checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
} }
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread int, labelIDs []string) { //nolint[unparam] func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs) msg := getTestMessage(id, subject, sender, unread, labelIDs)
require.Nil(t, m.store.createOrUpdateMessageEvent(msg)) require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
} }
func getTestMessage(id, subject, sender string, unread int, labelIDs []string) *pmapi.Message { func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message {
address := &mail.Address{Address: sender} address := &mail.Address{Address: sender}
return &pmapi.Message{ return &pmapi.Message{
ID: id, ID: id,

View File

@ -18,6 +18,7 @@
package store package store
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
@ -34,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
// updateCountsFromServer will download and set the counts. // updateCountsFromServer will download and set the counts.
func (store *Store) updateCountsFromServer() error { func (store *Store) updateCountsFromServer() error {
counts, err := store.client().CountMessages("") counts, err := store.client().CountMessages(context.TODO(), "")
if err != nil { if err != nil {
return errors.Wrap(err, "cannot update counts from server") return errors.Wrap(err, "cannot update counts from server")
} }
@ -152,7 +153,7 @@ func (store *Store) triggerSync() {
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started") store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
err := syncAllMail(store.panicHandler, store, func() messageLister { return store.client() }, syncState) err := syncAllMail(store.panicHandler, store, store.client(), syncState)
if err != nil { if err != nil {
log.WithError(err).Error("Store sync failed") log.WithError(err).Error("Store sync failed")
store.syncCooldown.increaseWaitTime() store.syncCooldown.increaseWaitTime()

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,ClientManager,IMAPClientProvider) // Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,IMAPClientProvider)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
@ -7,7 +7,6 @@ package mocks
import ( import (
reflect "reflect" reflect "reflect"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
imap "github.com/emersion/go-imap" imap "github.com/emersion/go-imap"
sasl "github.com/emersion/go-sasl" sasl "github.com/emersion/go-sasl"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -48,57 +47,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
} }
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// CheckConnection mocks base method
func (m *MockClientManager) CheckConnection() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckConnection")
ret0, _ := ret[0].(error)
return ret0
}
// CheckConnection indicates an expected call of CheckConnection
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockIMAPClientProvider is a mock of IMAPClientProvider interface // MockIMAPClientProvider is a mock of IMAPClientProvider interface
type MockIMAPClientProvider struct { type MockIMAPClientProvider struct {
ctrl *gomock.Controller ctrl *gomock.Controller

View File

@ -25,7 +25,6 @@ import (
imapID "github.com/ProtonMail/go-imap-id" imapID "github.com/ProtonMail/go-imap-id"
"github.com/ProtonMail/proton-bridge/internal/constants" "github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap" "github.com/emersion/go-imap"
imapClient "github.com/emersion/go-imap/client" imapClient "github.com/emersion/go-imap/client"
"github.com/emersion/go-sasl" "github.com/emersion/go-sasl"
@ -118,6 +117,9 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
return previousErr return previousErr
} }
// FIXME(conman): This should register as connection observer.
/*
err := pmapi.CheckConnection() err := pmapi.CheckConnection()
log.WithError(err).Debug("Connection check") log.WithError(err).Debug("Connection check")
if err != nil { if err != nil {
@ -125,8 +127,9 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
previousErr = err previousErr = err
continue continue
} }
*/
err = p.reauth() err := p.reauth()
log.WithError(err).Debug("Reauth") log.WithError(err).Debug("Reauth")
if err != nil { if err != nil {
time.Sleep(imapReconnectSleep) time.Sleep(imapReconnectSleep)

View File

@ -18,6 +18,7 @@
package transfer package transfer
import ( import (
"context"
"sort" "sort"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
@ -34,7 +35,7 @@ const (
// PMAPIProvider implements import and export to/from ProtonMail server. // PMAPIProvider implements import and export to/from ProtonMail server.
type PMAPIProvider struct { type PMAPIProvider struct {
clientManager ClientManager client pmapi.Client
userID string userID string
addressID string addressID string
keyRing *crypto.KeyRing keyRing *crypto.KeyRing
@ -44,12 +45,14 @@ type PMAPIProvider struct {
nextImportRequestsSize int nextImportRequestsSize int
timeIt *timeIt timeIt *timeIt
connection bool
} }
// NewPMAPIProvider returns new PMAPIProvider. // NewPMAPIProvider returns new PMAPIProvider.
func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*PMAPIProvider, error) { func NewPMAPIProvider(client pmapi.Client, userID, addressID string) (*PMAPIProvider, error) {
provider := &PMAPIProvider{ provider := &PMAPIProvider{
clientManager: clientManager, client: client,
userID: userID, userID: userID,
addressID: addressID, addressID: addressID,
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers), builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
@ -61,7 +64,7 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
} }
if addressID != "" { if addressID != "" {
keyRing, err := clientManager.GetClient(userID).KeyRingForAddressID(addressID) keyRing, err := client.KeyRingForAddressID(addressID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get key ring") return nil, errors.Wrap(err, "failed to get key ring")
} }
@ -71,10 +74,6 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
return provider, nil return provider, nil
} }
func (p *PMAPIProvider) client() pmapi.Client {
return p.clientManager.GetClient(p.userID)
}
// ID returns identifier of current setup of PMAPI provider. // ID returns identifier of current setup of PMAPI provider.
// Identification is unique per user. // Identification is unique per user.
func (p *PMAPIProvider) ID() string { func (p *PMAPIProvider) ID() string {
@ -83,7 +82,7 @@ func (p *PMAPIProvider) ID() string {
// Mailboxes returns all available labels in ProtonMail account. // Mailboxes returns all available labels in ProtonMail account.
func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox, error) { func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox, error) {
labels, err := p.client().ListLabels() labels, err := p.client.ListLabels(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -92,7 +91,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
emptyLabelsMap := map[string]bool{} emptyLabelsMap := map[string]bool{}
if !includeEmpty { if !includeEmpty {
messagesCounts, err := p.client().CountMessages(p.addressID) messagesCounts, err := p.client.CountMessages(context.Background(), p.addressID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -120,7 +119,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
ID: label.ID, ID: label.ID,
Name: label.Name, Name: label.Name,
Color: label.Color, Color: label.Color,
IsExclusive: label.Exclusive == 1, IsExclusive: bool(label.Exclusive),
}) })
} }
return mailboxes, nil return mailboxes, nil
@ -160,10 +159,10 @@ func (l byFoldersLabels) Swap(i, j int) {
// Less sorts first folders, then labels, by user order. // Less sorts first folders, then labels, by user order.
func (l byFoldersLabels) Less(i, j int) bool { func (l byFoldersLabels) Less(i, j int) bool {
if l[i].Exclusive == 1 && l[j].Exclusive == 0 { if l[i].Exclusive && !l[j].Exclusive {
return true return true
} }
if l[i].Exclusive == 0 && l[j].Exclusive == 1 { if !l[i].Exclusive && l[j].Exclusive {
return false return false
} }
return l[i].Order < l[j].Order return l[i].Order < l[j].Order

View File

@ -157,7 +157,7 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
body, err := p.builder.NewJobWithOptions( body, err := p.builder.NewJobWithOptions(
context.Background(), context.Background(),
p.client(), p.client,
msg.ID, msg.ID,
message.JobOptions{IgnoreDecryptionErrors: !skipEncryptedMessages}, message.JobOptions{IgnoreDecryptionErrors: !skipEncryptedMessages},
).GetResult() ).GetResult()
@ -169,14 +169,9 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
return Message{Body: []byte(msg.Body)}, err return Message{Body: []byte(msg.Body)}, err
} }
unread := false
if msg.Unread == 1 {
unread = true
}
return Message{ return Message{
ID: msgID, ID: msgID,
Unread: unread, Unread: bool(msg.Unread),
Body: body, Body: body,
Sources: []Mailbox{rule.SourceMailbox}, Sources: []Mailbox{rule.SourceMailbox},
Targets: rule.TargetMailboxes, Targets: rule.TargetMailboxes,

View File

@ -19,6 +19,7 @@ package transfer
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -56,7 +57,7 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
exclusive = 1 exclusive = 1
} }
label, err := p.client().CreateLabel(&pmapi.Label{ label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{
Name: mailbox.Name, Name: mailbox.Name,
Color: mailbox.Color, Color: mailbox.Color,
Exclusive: exclusive, Exclusive: exclusive,
@ -194,7 +195,7 @@ func (p *PMAPIProvider) transferMessage(rules transferRules, progress *Progress,
return return
} }
importMsgReqSize := len(importMsgReq.Body) importMsgReqSize := len(importMsgReq.Message)
if p.nextImportRequestsSize+importMsgReqSize > pmapiImportBatchMaxSize || len(p.nextImportRequests) == pmapiImportBatchMaxItems { if p.nextImportRequestsSize+importMsgReqSize > pmapiImportBatchMaxSize || len(p.nextImportRequests) == pmapiImportBatchMaxItems {
preparedImportRequestsCh <- p.nextImportRequests preparedImportRequestsCh <- p.nextImportRequests
p.nextImportRequests = map[string]*pmapi.ImportMsgReq{} p.nextImportRequests = map[string]*pmapi.ImportMsgReq{}
@ -226,9 +227,12 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
} }
} }
unread := 0 var unread pmapi.Boolean
if msg.Unread { if msg.Unread {
unread = 1 unread = pmapi.True
} else {
unread = pmapi.False
} }
labelIDs := []string{} labelIDs := []string{}
@ -243,12 +247,14 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
} }
return &pmapi.ImportMsgReq{ return &pmapi.ImportMsgReq{
Metadata: &pmapi.ImportMetadata{
AddressID: p.addressID, AddressID: p.addressID,
Body: body,
Unread: unread, Unread: unread,
Time: message.Time, Time: message.Time,
Flags: computeMessageFlags(message.Header), Flags: computeMessageFlags(message.Header),
LabelIDs: labelIDs, LabelIDs: labelIDs,
},
Message: body,
}, nil }, nil
} }
@ -293,7 +299,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
} }
importMsgIDs := []string{} importMsgIDs := []string{}
importMsgRequests := []*pmapi.ImportMsgReq{} importMsgRequests := pmapi.ImportMsgReqs{}
for msgID, req := range importRequests { for msgID, req := range importRequests {
importMsgIDs = append(importMsgIDs, msgID) importMsgIDs = append(importMsgIDs, msgID)
importMsgRequests = append(importMsgRequests, req) importMsgRequests = append(importMsgRequests, req)
@ -327,7 +333,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
func (p *PMAPIProvider) importMessage(msgSourceID string, progress *Progress, req *pmapi.ImportMsgReq) (importedID string, importedErr error) { func (p *PMAPIProvider) importMessage(msgSourceID string, progress *Progress, req *pmapi.ImportMsgReq) (importedID string, importedErr error) {
progress.callWrap(func() error { progress.callWrap(func() error {
results, err := p.importRequest(msgSourceID, []*pmapi.ImportMsgReq{req}) results, err := p.importRequest(msgSourceID, pmapi.ImportMsgReqs{req})
if err != nil { if err != nil {
return errors.Wrap(err, "failed to import messages") return errors.Wrap(err, "failed to import messages")
} }

View File

@ -19,6 +19,7 @@ package transfer
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -33,7 +34,7 @@ func TestPMAPIProviderMailboxes(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
setupPMAPIClientExpectationForExport(&m) setupPMAPIClientExpectationForExport(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID") provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err) r.NoError(t, err)
tests := []struct { tests := []struct {
@ -78,7 +79,7 @@ func TestPMAPIProviderTransferTo(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
setupPMAPIClientExpectationForExport(&m) setupPMAPIClientExpectationForExport(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID") provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err) r.NoError(t, err)
rules, rulesClose := newTestRules(t) rules, rulesClose := newTestRules(t)
@ -96,7 +97,7 @@ func TestPMAPIProviderTransferFrom(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
setupPMAPIClientExpectationForImport(&m) setupPMAPIClientExpectationForImport(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID") provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err) r.NoError(t, err)
rules, rulesClose := newTestRules(t) rules, rulesClose := newTestRules(t)
@ -114,7 +115,7 @@ func TestPMAPIProviderTransferFromDraft(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
setupPMAPIClientExpectationForImportDraft(&m) setupPMAPIClientExpectationForImportDraft(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID") provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err) r.NoError(t, err)
rules, rulesClose := newTestRules(t) rules, rulesClose := newTestRules(t)
@ -133,9 +134,9 @@ func TestPMAPIProviderTransferFromTo(t *testing.T) {
setupPMAPIClientExpectationForExport(&m) setupPMAPIClientExpectationForExport(&m)
setupPMAPIClientExpectationForImport(&m) setupPMAPIClientExpectationForImport(&m)
source, err := NewPMAPIProvider(m.clientManager, "user", "addressID") source, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err) r.NoError(t, err)
target, err := NewPMAPIProvider(m.clientManager, "user", "addressID") target, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err) r.NoError(t, err)
rules, rulesClose := newTestRules(t) rules, rulesClose := newTestRules(t)
@ -151,22 +152,22 @@ func setupPMAPIRules(rules transferRules) {
func setupPMAPIClientExpectationForExport(m *mocks) { func setupPMAPIClientExpectationForExport(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes() m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{ m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2}, {ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2},
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1}, {ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1},
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1}, {ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1},
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2}, {ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2},
}, nil).AnyTimes() }, nil).AnyTimes()
m.pmapiClient.EXPECT().CountMessages(gomock.Any()).Return([]*pmapi.MessagesCount{ m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
{LabelID: "label1", Total: 10}, {LabelID: "label1", Total: 10},
{LabelID: "label2", Total: 0}, {LabelID: "label2", Total: 0},
{LabelID: "folder1", Total: 20}, {LabelID: "folder1", Total: 20},
}, nil).AnyTimes() }, nil).AnyTimes()
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{ m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{
{ID: "msg1"}, {ID: "msg1"},
{ID: "msg2"}, {ID: "msg2"},
}, 2, nil).AnyTimes() }, 2, nil).AnyTimes()
m.pmapiClient.EXPECT().GetMessage(gomock.Any()).DoAndReturn(func(msgID string) (*pmapi.Message, error) { m.pmapiClient.EXPECT().GetMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msgID string) (*pmapi.Message, error) {
return &pmapi.Message{ return &pmapi.Message{
ID: msgID, ID: msgID,
Body: string(getTestMsgBody(msgID)), Body: string(getTestMsgBody(msgID)),
@ -177,11 +178,11 @@ func setupPMAPIClientExpectationForExport(m *mocks) {
func setupPMAPIClientExpectationForImport(m *mocks) { func setupPMAPIClientExpectationForImport(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes() m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().Import(gomock.Any()).DoAndReturn(func(requests []*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error) { m.pmapiClient.EXPECT().Import(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, requests pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
results := []*pmapi.ImportMsgRes{} results := []*pmapi.ImportMsgRes{}
for _, request := range requests { for _, request := range requests {
for _, msgID := range []string{"msg1", "msg2"} { for _, msgID := range []string{"msg1", "msg2"} {
if bytes.Contains(request.Body, []byte(msgID)) { if bytes.Contains(request.Message, []byte(msgID)) {
results = append(results, &pmapi.ImportMsgRes{MessageID: msgID, Error: nil}) results = append(results, &pmapi.ImportMsgRes{MessageID: msgID, Error: nil})
} }
} }
@ -192,7 +193,7 @@ func setupPMAPIClientExpectationForImport(m *mocks) {
func setupPMAPIClientExpectationForImportDraft(m *mocks) { func setupPMAPIClientExpectationForImportDraft(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes() m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) { m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
r.Equal(m.t, msg.Subject, "draft1") r.Equal(m.t, msg.Subject, "draft1")
msg.ID = "draft1" msg.ID = "draft1"
return msg, nil return msg, nil

View File

@ -18,6 +18,7 @@
package transfer package transfer
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"time" "time"
@ -57,6 +58,10 @@ func (p *PMAPIProvider) tryReconnect() error {
return previousErr return previousErr
} }
// FIXME(conman): This should register as a connection observer somehow...
// Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection?
/*
err := p.clientManager.CheckConnection() err := p.clientManager.CheckConnection()
log.WithError(err).Debug("Connection check") log.WithError(err).Debug("Connection check")
if err != nil { if err != nil {
@ -64,6 +69,7 @@ func (p *PMAPIProvider) tryReconnect() error {
previousErr = err previousErr = err
continue continue
} }
*/
break break
} }
@ -77,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
p.timeIt.start("listing", key) p.timeIt.start("listing", key)
defer p.timeIt.stop("listing", key) defer p.timeIt.stop("listing", key)
messages, count, err = p.client().ListMessages(filter) messages, count, err = p.client.ListMessages(context.TODO(), filter)
return err return err
}) })
return return
@ -88,18 +94,18 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
p.timeIt.start("download", msgID) p.timeIt.start("download", msgID)
defer p.timeIt.stop("download", msgID) defer p.timeIt.stop("download", msgID)
message, err = p.client().GetMessage(msgID) message, err = p.client.GetMessage(context.TODO(), msgID)
return err return err
}) })
return return
} }
func (p *PMAPIProvider) importRequest(msgSourceID string, req []*pmapi.ImportMsgReq) (res []*pmapi.ImportMsgRes, err error) { func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReqs) (res []*pmapi.ImportMsgRes, err error) {
err = p.ensureConnection(func() error { err = p.ensureConnection(func() error {
p.timeIt.start("upload", msgSourceID) p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID) defer p.timeIt.stop("upload", msgSourceID)
res, err = p.client().Import(req) res, err = p.client.Import(context.TODO(), req)
return err return err
}) })
return return
@ -110,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
p.timeIt.start("upload", msgSourceID) p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID) defer p.timeIt.stop("upload", msgSourceID)
draft, err = p.client().CreateDraft(message, parent, action) draft, err = p.client.CreateDraft(context.TODO(), message, parent, action)
return err return err
}) })
return return
@ -123,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
p.timeIt.start("upload", key) p.timeIt.start("upload", key)
defer p.timeIt.stop("upload", key) defer p.timeIt.stop("upload", key)
created, err = p.client().CreateAttachment(att, r, sig) created, err = p.client.CreateAttachment(context.TODO(), att, r, sig)
return err return err
}) })
return return

View File

@ -23,7 +23,6 @@ import (
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
transfermocks "github.com/ProtonMail/proton-bridge/internal/transfer/mocks" transfermocks "github.com/ProtonMail/proton-bridge/internal/transfer/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
@ -33,10 +32,8 @@ type mocks struct {
ctrl *gomock.Controller ctrl *gomock.Controller
panicHandler *transfermocks.MockPanicHandler panicHandler *transfermocks.MockPanicHandler
clientManager *transfermocks.MockClientManager
imapClientProvider *transfermocks.MockIMAPClientProvider imapClientProvider *transfermocks.MockIMAPClientProvider
pmapiClient *pmapimocks.MockClient pmapiClient *pmapimocks.MockClient
pmapiConfig *pmapi.ClientConfig
keyring *crypto.KeyRing keyring *crypto.KeyRing
} }
@ -49,15 +46,11 @@ func initMocks(t *testing.T) mocks {
ctrl: mockCtrl, ctrl: mockCtrl,
panicHandler: transfermocks.NewMockPanicHandler(mockCtrl), panicHandler: transfermocks.NewMockPanicHandler(mockCtrl),
clientManager: transfermocks.NewMockClientManager(mockCtrl),
imapClientProvider: transfermocks.NewMockIMAPClientProvider(mockCtrl), imapClientProvider: transfermocks.NewMockIMAPClientProvider(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl), pmapiClient: pmapimocks.NewMockClient(mockCtrl),
pmapiConfig: &pmapi.ClientConfig{},
keyring: newTestKeyring(), keyring: newTestKeyring(),
} }
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).AnyTimes()
return m return m
} }

View File

@ -17,10 +17,6 @@
package transfer package transfer
import (
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type PanicHandler interface { type PanicHandler interface {
HandlePanic() HandlePanic()
} }
@ -32,8 +28,3 @@ type MetricsManager interface {
Cancel() Cancel()
Fail() Fail()
} }
type ClientManager interface {
GetClient(userID string) pmapi.Client
CheckConnection() error
}

View File

@ -18,6 +18,7 @@
package updater package updater
import ( import (
"bytes"
"encoding/json" "encoding/json"
"io" "io"
@ -31,10 +32,6 @@ import (
var ErrManualUpdateRequired = errors.New("manual update is required") var ErrManualUpdateRequired = errors.New("manual update is required")
type ClientProvider interface {
GetAnonymousClient() pmapi.Client
}
type Installer interface { type Installer interface {
InstallUpdate(*semver.Version, io.Reader) error InstallUpdate(*semver.Version, io.Reader) error
} }
@ -46,7 +43,7 @@ type Settings interface {
} }
type Updater struct { type Updater struct {
cm ClientProvider cm pmapi.Manager
installer Installer installer Installer
settings Settings settings Settings
kr *crypto.KeyRing kr *crypto.KeyRing
@ -59,7 +56,7 @@ type Updater struct {
} }
func New( func New(
cm ClientProvider, cm pmapi.Manager,
installer Installer, installer Installer,
s Settings, s Settings,
kr *crypto.KeyRing, kr *crypto.KeyRing,
@ -87,13 +84,10 @@ func New(
func (u *Updater) Check() (VersionInfo, error) { func (u *Updater) Check() (VersionInfo, error) {
logrus.Info("Checking for updates") logrus.Info("Checking for updates")
client := u.cm.GetAnonymousClient() b, err := u.cm.DownloadAndVerify(
defer client.Logout() u.kr,
r, err := client.DownloadAndVerify(
u.getVersionFileURL(), u.getVersionFileURL(),
u.getVersionFileURL()+".sig", u.getVersionFileURL()+".sig",
u.kr,
) )
if err != nil { if err != nil {
return VersionInfo{}, err return VersionInfo{}, err
@ -101,7 +95,7 @@ func (u *Updater) Check() (VersionInfo, error) {
var versionMap VersionMap var versionMap VersionMap
if err := json.NewDecoder(r).Decode(&versionMap); err != nil { if err := json.Unmarshal(b, &versionMap); err != nil {
return VersionInfo{}, err return VersionInfo{}, err
} }
@ -141,15 +135,12 @@ func (u *Updater) InstallUpdate(update VersionInfo) error {
return u.locker.doOnce(func() error { return u.locker.doOnce(func() error {
logrus.WithField("package", update.Package).Info("Installing update package") logrus.WithField("package", update.Package).Info("Installing update package")
client := u.cm.GetAnonymousClient() b, err := u.cm.DownloadAndVerify(u.kr, update.Package, update.Package+".sig")
defer client.Logout()
r, err := client.DownloadAndVerify(update.Package, update.Package+".sig", u.kr)
if err != nil { if err != nil {
return errors.Wrap(ErrDownloadVerify, err.Error()) return errors.Wrap(ErrDownloadVerify, err.Error())
} }
if err := u.installer.InstallUpdate(update.Version, r); err != nil { if err := u.installer.InstallUpdate(update.Version, bytes.NewReader(b)); err != nil {
return errors.Wrap(ErrInstall, err.Error()) return errors.Wrap(ErrInstall, err.Error())
} }

View File

@ -18,7 +18,6 @@
package updater package updater
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
@ -29,7 +28,6 @@ import (
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -40,9 +38,9 @@ func TestCheck(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.1.0", false) updater := newTestUpdater(cm, "1.1.0", false)
versionMap := VersionMap{ versionMap := VersionMap{
"stable": VersionInfo{ "stable": VersionInfo{
@ -53,13 +51,11 @@ func TestCheck(t *testing.T) {
}, },
} }
client.EXPECT().DownloadAndVerify( cm.EXPECT().DownloadAndVerify(
gomock.Any(),
updater.getVersionFileURL(), updater.getVersionFileURL(),
updater.getVersionFileURL()+".sig", updater.getVersionFileURL()+".sig",
gomock.Any(), ).Return(mustMarshal(t, versionMap), nil)
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
client.EXPECT().Logout()
version, err := updater.Check() version, err := updater.Check()
@ -71,9 +67,9 @@ func TestCheckEarlyAccess(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.1.0", true) updater := newTestUpdater(cm, "1.1.0", true)
versionMap := VersionMap{ versionMap := VersionMap{
"stable": VersionInfo{ "stable": VersionInfo{
@ -90,13 +86,11 @@ func TestCheckEarlyAccess(t *testing.T) {
}, },
} }
client.EXPECT().DownloadAndVerify( cm.EXPECT().DownloadAndVerify(
gomock.Any(),
updater.getVersionFileURL(), updater.getVersionFileURL(),
updater.getVersionFileURL()+".sig", updater.getVersionFileURL()+".sig",
gomock.Any(), ).Return(mustMarshal(t, versionMap), nil)
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
client.EXPECT().Logout()
version, err := updater.Check() version, err := updater.Check()
@ -108,18 +102,16 @@ func TestCheckBadSignature(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.2.0", false) updater := newTestUpdater(cm, "1.2.0", false)
client.EXPECT().DownloadAndVerify( cm.EXPECT().DownloadAndVerify(
gomock.Any(),
updater.getVersionFileURL(), updater.getVersionFileURL(),
updater.getVersionFileURL()+".sig", updater.getVersionFileURL()+".sig",
gomock.Any(),
).Return(nil, errors.New("bad signature")) ).Return(nil, errors.New("bad signature"))
client.EXPECT().Logout()
_, err := updater.Check() _, err := updater.Check()
assert.Error(t, err) assert.Error(t, err)
@ -129,9 +121,9 @@ func TestIsUpdateApplicable(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false) updater := newTestUpdater(cm, "1.4.0", false)
versionOld := VersionInfo{ versionOld := VersionInfo{
Version: semver.MustParse("1.3.0"), Version: semver.MustParse("1.3.0"),
@ -165,9 +157,9 @@ func TestCanInstall(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false) updater := newTestUpdater(cm, "1.4.0", false)
versionManual := VersionInfo{ versionManual := VersionInfo{
Version: semver.MustParse("1.5.0"), Version: semver.MustParse("1.5.0"),
@ -192,9 +184,9 @@ func TestInstallUpdate(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false) updater := newTestUpdater(cm, "1.4.0", false)
latestVersion := VersionInfo{ latestVersion := VersionInfo{
Version: semver.MustParse("1.5.0"), Version: semver.MustParse("1.5.0"),
@ -203,13 +195,11 @@ func TestInstallUpdate(t *testing.T) {
RolloutProportion: 1.0, RolloutProportion: 1.0,
} }
client.EXPECT().DownloadAndVerify( cm.EXPECT().DownloadAndVerify(
gomock.Any(),
latestVersion.Package, latestVersion.Package,
latestVersion.Package+".sig", latestVersion.Package+".sig",
gomock.Any(), ).Return([]byte("tgz_data_here"), nil)
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
client.EXPECT().Logout()
err := updater.InstallUpdate(latestVersion) err := updater.InstallUpdate(latestVersion)
@ -220,9 +210,9 @@ func TestInstallUpdateBadSignature(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false) updater := newTestUpdater(cm, "1.4.0", false)
latestVersion := VersionInfo{ latestVersion := VersionInfo{
Version: semver.MustParse("1.5.0"), Version: semver.MustParse("1.5.0"),
@ -231,14 +221,12 @@ func TestInstallUpdateBadSignature(t *testing.T) {
RolloutProportion: 1.0, RolloutProportion: 1.0,
} }
client.EXPECT().DownloadAndVerify( cm.EXPECT().DownloadAndVerify(
gomock.Any(),
latestVersion.Package, latestVersion.Package,
latestVersion.Package+".sig", latestVersion.Package+".sig",
gomock.Any(),
).Return(nil, errors.New("bad signature")) ).Return(nil, errors.New("bad signature"))
client.EXPECT().Logout()
err := updater.InstallUpdate(latestVersion) err := updater.InstallUpdate(latestVersion)
assert.Error(t, err) assert.Error(t, err)
@ -248,9 +236,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
c := gomock.NewController(t) c := gomock.NewController(t)
defer c.Finish() defer c.Finish()
client := mocks.NewMockClient(c) cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false) updater := newTestUpdater(cm, "1.4.0", false)
updater.installer = &fakeInstaller{delay: 2 * time.Second} updater.installer = &fakeInstaller{delay: 2 * time.Second}
@ -261,13 +249,11 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
RolloutProportion: 1.0, RolloutProportion: 1.0,
} }
client.EXPECT().DownloadAndVerify( cm.EXPECT().DownloadAndVerify(
gomock.Any(),
latestVersion.Package, latestVersion.Package,
latestVersion.Package+".sig", latestVersion.Package+".sig",
gomock.Any(), ).Return([]byte("tgz_data_here"), nil)
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
client.EXPECT().Logout()
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@ -288,9 +274,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
wg.Wait() wg.Wait()
} }
func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *Updater { func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater {
return New( return New(
&fakeClientProvider{client: client}, manager,
&fakeInstaller{}, &fakeInstaller{},
newFakeSettings(0.5, earlyAccess), newFakeSettings(0.5, earlyAccess),
nil, nil,
@ -299,14 +285,6 @@ func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *
) )
} }
type fakeClientProvider struct {
client *mocks.MockClient
}
func (p *fakeClientProvider) GetAnonymousClient() pmapi.Client {
return p.client
}
type fakeInstaller struct { type fakeInstaller struct {
bad bool bad bool
delay time.Duration delay time.Duration

View File

@ -133,7 +133,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil), m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// store.New() in user.init // store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken), m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized),
m.pmapiClient.EXPECT().Addresses().Return(nil), m.pmapiClient.EXPECT().Addresses().Return(nil),
// getAPIUser() loads user info from API (e.g. userID). // getAPIUser() loads user info from API (e.g. userID).

View File

@ -149,3 +149,13 @@ func (s *Credentials) Logout() {
func (s *Credentials) IsConnected() bool { func (s *Credentials) IsConnected() bool {
return s.APIToken != "" && s.MailboxPassword != "" return s.APIToken != "" && s.MailboxPassword != ""
} }
func (s *Credentials) SplitAPIToken() (string, string, error) {
split := strings.Split(s.APIToken, ":")
if len(split) != 2 {
return "", "", errors.New("malformed API token")
}
return split[0], split[1], nil
}

View File

@ -39,7 +39,7 @@ func NewStore(keychain *keychain.Keychain) *Store {
return &Store{secrets: keychain} return &Store{secrets: keychain}
} }
func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails []string) (creds *Credentials, err error) { func (s *Store) Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*Credentials, error) {
storeLocker.Lock() storeLocker.Lock()
defer storeLocker.Unlock() defer storeLocker.Unlock()
@ -49,10 +49,10 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
"emails": emails, "emails": emails,
}).Trace("Adding new credentials") }).Trace("Adding new credentials")
creds = &Credentials{ creds := &Credentials{
UserID: userID, UserID: userID,
Name: userName, Name: userName,
APIToken: apiToken, APIToken: uid + ":" + ref,
MailboxPassword: mailboxPassword, MailboxPassword: mailboxPassword,
IsHidden: false, IsHidden: false,
} }
@ -72,82 +72,82 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
creds.Timestamp = time.Now().Unix() creds.Timestamp = time.Now().Unix()
} }
if err = s.saveCredentials(creds); err != nil { if err := s.saveCredentials(creds); err != nil {
return return nil, err
} }
return creds, err return creds, nil
} }
func (s *Store) SwitchAddressMode(userID string) error { func (s *Store) SwitchAddressMode(userID string) (*Credentials, error) {
storeLocker.Lock() storeLocker.Lock()
defer storeLocker.Unlock() defer storeLocker.Unlock()
credentials, err := s.get(userID) credentials, err := s.get(userID)
if err != nil { if err != nil {
return err return nil, err
} }
credentials.IsCombinedAddressMode = !credentials.IsCombinedAddressMode credentials.IsCombinedAddressMode = !credentials.IsCombinedAddressMode
credentials.BridgePassword = generatePassword() credentials.BridgePassword = generatePassword()
return s.saveCredentials(credentials) return credentials, s.saveCredentials(credentials)
} }
func (s *Store) UpdateEmails(userID string, emails []string) error { func (s *Store) UpdateEmails(userID string, emails []string) (*Credentials, error) {
storeLocker.Lock() storeLocker.Lock()
defer storeLocker.Unlock() defer storeLocker.Unlock()
credentials, err := s.get(userID) credentials, err := s.get(userID)
if err != nil { if err != nil {
return err return nil, err
} }
credentials.SetEmailList(emails) credentials.SetEmailList(emails)
return s.saveCredentials(credentials) return credentials, s.saveCredentials(credentials)
} }
func (s *Store) UpdatePassword(userID, password string) error { func (s *Store) UpdatePassword(userID, password string) (*Credentials, error) {
storeLocker.Lock() storeLocker.Lock()
defer storeLocker.Unlock() defer storeLocker.Unlock()
credentials, err := s.get(userID) credentials, err := s.get(userID)
if err != nil { if err != nil {
return err return nil, err
} }
credentials.MailboxPassword = password credentials.MailboxPassword = password
return s.saveCredentials(credentials) return credentials, s.saveCredentials(credentials)
} }
func (s *Store) UpdateToken(userID, apiToken string) error { func (s *Store) UpdateToken(userID, uid, ref string) (*Credentials, error) {
storeLocker.Lock() storeLocker.Lock()
defer storeLocker.Unlock() defer storeLocker.Unlock()
credentials, err := s.get(userID) credentials, err := s.get(userID)
if err != nil { if err != nil {
return err return nil, err
} }
credentials.APIToken = apiToken credentials.APIToken = uid + ":" + ref
return s.saveCredentials(credentials) return credentials, s.saveCredentials(credentials)
} }
func (s *Store) Logout(userID string) error { func (s *Store) Logout(userID string) (*Credentials, error) {
storeLocker.Lock() storeLocker.Lock()
defer storeLocker.Unlock() defer storeLocker.Unlock()
credentials, err := s.get(userID) credentials, err := s.get(userID)
if err != nil { if err != nil {
return err return nil, err
} }
credentials.Logout() credentials.Logout()
return s.saveCredentials(credentials) return credentials, s.saveCredentials(credentials)
} }
// List returns a list of usernames that have credentials stored. // List returns a list of usernames that have credentials stored.
@ -249,7 +249,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
} }
// saveCredentials encrypts and saves password to the keychain store. // saveCredentials encrypts and saves password to the keychain store.
func (s *Store) saveCredentials(credentials *Credentials) (err error) { func (s *Store) saveCredentials(credentials *Credentials) error {
credentials.Version = keychain.Version credentials.Version = keychain.Version
return s.secrets.Put(credentials.UserID, credentials.Marshal()) return s.secrets.Put(credentials.UserID, credentials.Marshal())

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,ClientManager,CredentialsStorer,StoreMaker) // Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,CredentialsStorer,StoreMaker)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
@ -9,7 +9,6 @@ import (
store "github.com/ProtonMail/proton-bridge/internal/store" store "github.com/ProtonMail/proton-bridge/internal/store"
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials" credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
@ -85,109 +84,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
} }
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method
func (m *MockClientManager) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
}
// CheckConnection mocks base method
func (m *MockClientManager) CheckConnection() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckConnection")
ret0, _ := ret[0].(error)
return ret0
}
// CheckConnection indicates an expected call of CheckConnection
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
}
// DisallowProxy mocks base method
func (m *MockClientManager) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
}
// GetAnonymousClient mocks base method
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAnonymousClient")
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetAnonymousClient indicates an expected call of GetAnonymousClient
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
}
// GetAuthUpdateChannel mocks base method
func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthUpdateChannel")
ret0, _ := ret[0].(chan pmapi.ClientAuth)
return ret0
}
// GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel
func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel))
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockCredentialsStorer is a mock of CredentialsStorer interface // MockCredentialsStorer is a mock of CredentialsStorer interface
type MockCredentialsStorer struct { type MockCredentialsStorer struct {
ctrl *gomock.Controller ctrl *gomock.Controller
@ -212,18 +108,18 @@ func (m *MockCredentialsStorer) EXPECT() *MockCredentialsStorerMockRecorder {
} }
// Add mocks base method // Add mocks base method
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3 string, arg4 []string) (*credentials.Credentials, error) { func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3, arg4 string, arg5 []string) (*credentials.Credentials, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(*credentials.Credentials) ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// Add indicates an expected call of Add // Add indicates an expected call of Add
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4, arg5)
} }
// Delete mocks base method // Delete mocks base method
@ -271,11 +167,12 @@ func (mr *MockCredentialsStorerMockRecorder) List() *gomock.Call {
} }
// Logout mocks base method // Logout mocks base method
func (m *MockCredentialsStorer) Logout(arg0 string) error { func (m *MockCredentialsStorer) Logout(arg0 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logout", arg0) ret := m.ctrl.Call(m, "Logout", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(*credentials.Credentials)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// Logout indicates an expected call of Logout // Logout indicates an expected call of Logout
@ -285,11 +182,12 @@ func (mr *MockCredentialsStorerMockRecorder) Logout(arg0 interface{}) *gomock.Ca
} }
// SwitchAddressMode mocks base method // SwitchAddressMode mocks base method
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) error { func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SwitchAddressMode", arg0) ret := m.ctrl.Call(m, "SwitchAddressMode", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(*credentials.Credentials)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// SwitchAddressMode indicates an expected call of SwitchAddressMode // SwitchAddressMode indicates an expected call of SwitchAddressMode
@ -299,11 +197,12 @@ func (mr *MockCredentialsStorerMockRecorder) SwitchAddressMode(arg0 interface{})
} }
// UpdateEmails mocks base method // UpdateEmails mocks base method
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) error { func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) (*credentials.Credentials, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateEmails", arg0, arg1) ret := m.ctrl.Call(m, "UpdateEmails", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(*credentials.Credentials)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// UpdateEmails indicates an expected call of UpdateEmails // UpdateEmails indicates an expected call of UpdateEmails
@ -313,11 +212,12 @@ func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}
} }
// UpdatePassword mocks base method // UpdatePassword mocks base method
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error { func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1) ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(*credentials.Credentials)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// UpdatePassword indicates an expected call of UpdatePassword // UpdatePassword indicates an expected call of UpdatePassword
@ -327,17 +227,18 @@ func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface
} }
// UpdateToken mocks base method // UpdateToken mocks base method
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error { func (m *MockCredentialsStorer) UpdateToken(arg0, arg1, arg2 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1) ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1, arg2)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(*credentials.Credentials)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// UpdateToken indicates an expected call of UpdateToken // UpdateToken indicates an expected call of UpdateToken
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1 interface{}) *gomock.Call { func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1, arg2)
} }
// MockStoreMaker is a mock of StoreMaker interface // MockStoreMaker is a mock of StoreMaker interface

View File

@ -20,14 +20,8 @@ package users
import ( import (
"github.com/ProtonMail/proton-bridge/internal/store" "github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/users/credentials" "github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
) )
type Configer interface {
GetAppVersion() string
GetAPIConfig() *pmapi.ClientConfig
}
type Locator interface { type Locator interface {
Clear() error Clear() error
} }
@ -38,25 +32,16 @@ type PanicHandler interface {
type CredentialsStorer interface { type CredentialsStorer interface {
List() (userIDs []string, err error) List() (userIDs []string, err error)
Add(userID, userName, apiToken, mailboxPassword string, emails []string) (*credentials.Credentials, error) Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*credentials.Credentials, error)
Get(userID string) (*credentials.Credentials, error) Get(userID string) (*credentials.Credentials, error)
SwitchAddressMode(userID string) error SwitchAddressMode(userID string) (*credentials.Credentials, error)
UpdateEmails(userID string, emails []string) error UpdateEmails(userID string, emails []string) (*credentials.Credentials, error)
UpdatePassword(userID, password string) error UpdatePassword(userID, password string) (*credentials.Credentials, error)
UpdateToken(userID, apiToken string) error UpdateToken(userID, uid, ref string) (*credentials.Credentials, error)
Logout(userID string) error Logout(userID string) (*credentials.Credentials, error)
Delete(userID string) error Delete(userID string) error
} }
type ClientManager interface {
GetClient(userID string) pmapi.Client
GetAnonymousClient() pmapi.Client
AllowProxy()
DisallowProxy()
GetAuthUpdateChannel() chan pmapi.ClientAuth
CheckConnection() error
}
type StoreMaker interface { type StoreMaker interface {
New(user store.BridgeUser) (*store.Store, error) New(user store.BridgeUser) (*store.Store, error)
Remove(userID string) error Remove(userID string) error

View File

@ -18,6 +18,7 @@
package users package users
import ( import (
"context"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -39,7 +40,7 @@ type User struct {
log *logrus.Entry log *logrus.Entry
panicHandler PanicHandler panicHandler PanicHandler
listener listener.Listener listener listener.Listener
clientManager ClientManager client pmapi.Client
credStorer CredentialsStorer credStorer CredentialsStorer
storeFactory StoreMaker storeFactory StoreMaker
@ -49,74 +50,75 @@ type User struct {
creds *credentials.Credentials creds *credentials.Credentials
lock sync.RWMutex lock sync.RWMutex
isAuthorized bool
useOnlyActiveAddresses bool
} }
// newUser creates a new user. // newUser creates a new user.
// The user is initially disconnected and must be connected by calling connect().
func newUser( func newUser(
panicHandler PanicHandler, panicHandler PanicHandler,
userID string, userID string,
eventListener listener.Listener, eventListener listener.Listener,
credStorer CredentialsStorer, credStorer CredentialsStorer,
clientManager ClientManager,
storeFactory StoreMaker, storeFactory StoreMaker,
) (u *User, err error) { useOnlyActiveAddresses bool,
) (*User, *credentials.Credentials, error) {
log := log.WithField("user", userID) log := log.WithField("user", userID)
log.Debug("Creating or loading user") log.Debug("Creating or loading user")
creds, err := credStorer.Get(userID) creds, err := credStorer.Get(userID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to load user credentials") return nil, nil, errors.Wrap(err, "failed to load user credentials")
} }
u = &User{ return &User{
log: log, log: log,
panicHandler: panicHandler, panicHandler: panicHandler,
listener: eventListener, listener: eventListener,
credStorer: credStorer, credStorer: credStorer,
clientManager: clientManager,
storeFactory: storeFactory, storeFactory: storeFactory,
userID: userID, userID: userID,
creds: creds, creds: creds,
useOnlyActiveAddresses: useOnlyActiveAddresses,
}, creds, nil
} }
return // connect connects a user. This includes
} // - providing it with an authorised API client
// - loading its credentials from the credentials store
// - loading and unlocking its PGP keys
// - loading its store
func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error {
u.log.Info("Connecting user")
func (u *User) client() pmapi.Client { // Connected users have an API client.
return u.clientManager.GetClient(u.userID) u.client = client
}
// init initialises a user. This includes reloading its credentials from the credentials store // FIXME(conman): How to remove this auth handler when user is disconnected?
// (such as when logging out and back in, you need to reload the credentials because the new credentials will u.client.AddAuthHandler(u.handleAuth)
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
// something in the store changed).
func (u *User) init() (err error) {
u.log.Info("Initialising user")
// Reload the user's credentials (if they log out and back in we need the new // Save the latest credentials for the user.
// version with the apitoken and mailbox password).
creds, err := u.credStorer.Get(u.userID)
if err != nil {
return errors.Wrap(err, "failed to load user credentials")
}
u.creds = creds u.creds = creds
// Try to authorise the user if they aren't already authorised. // Connected users have unlocked keys.
// Note: we still allow users to set up accounts if the internet is off. // FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :(
if authErr := u.authorizeIfNecessary(false); authErr != nil { if u.creds.IsConnected() {
switch errors.Cause(authErr) { if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil {
case pmapi.ErrAPINotReachable, pmapi.ErrUpgradeApplication, ErrLoggedOutUser: return err
u.log.WithError(authErr).Warn("Could not authorize user")
default:
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(authErr, "failed to authorize user")
} }
} }
// Connected users have a store.
if err := u.loadStore(); err != nil {
return err
}
return nil
}
func (u *User) loadStore() error {
// Logged-out user keeps store running to access offline data. // Logged-out user keeps store running to access offline data.
// Therefore it is necessary to close it before re-init. // Therefore it is necessary to close it before re-init.
if u.store != nil { if u.store != nil {
@ -125,93 +127,28 @@ func (u *User) init() (err error) {
} }
u.store = nil u.store = nil
} }
store, err := u.storeFactory.New(u) store, err := u.storeFactory.New(u)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to create store") return errors.Wrap(err, "failed to create store")
} }
u.store = store u.store = store
return err
}
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
// If user is not already connected to the api auth channel (for example there was no internet during start),
// it tries to connect it.
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
// If user is connected and has an auth channel, then perfect, nothing to do here.
if u.creds.IsConnected() && u.isAuthorized {
// The keyring unlock is triggered here to resolve state where apiClient
// is authenticated (we have auth token) but it was not possible to download
// and unlock the keys (internet not reachable).
return u.unlockIfNecessary()
}
if !u.creds.IsConnected() {
err = ErrLoggedOutUser
} else if err = u.authorizeAndUnlock(); err != nil {
u.log.WithError(err).Error("Could not authorize and unlock user")
switch errors.Cause(err) {
case pmapi.ErrUpgradeApplication, pmapi.ErrAPINotReachable: // Ignore these errors.
default:
if errLogout := u.credStorer.Logout(u.userID); errLogout != nil {
u.log.WithField("err", errLogout).Error("Could not log user out from credentials store")
}
}
}
if emitEvent && err != nil &&
errors.Cause(err) != pmapi.ErrUpgradeApplication &&
errors.Cause(err) != pmapi.ErrAPINotReachable {
u.listener.Emit(events.LogoutEvent, u.userID)
}
return err
}
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
func (u *User) unlockIfNecessary() error {
if u.client().IsUnlocked() {
return nil return nil
} }
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil { func (u *User) handleAuth(auth *pmapi.Auth) error {
return errors.Wrap(err, "failed to unlock user")
}
return nil
}
// authorizeAndUnlock tries to authorize the user with the API using the the user's APIToken.
// If that succeeds, it tries to unlock the user's keys and addresses.
func (u *User) authorizeAndUnlock() (err error) {
if u.creds.APIToken == "" {
u.log.Warn("Could not connect to API auth channel, have no API token")
return nil
}
if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
return errors.Wrap(err, "failed to refresh API auth")
}
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
return nil
}
func (u *User) updateAuthToken(auth *pmapi.Auth) {
u.log.Debug("User received auth") u.log.Debug("User received auth")
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil { creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
u.log.WithError(err).Error("Failed to update refresh token in credentials store") if err != nil {
return return errors.Wrap(err, "failed to update refresh token in credentials store")
} }
u.refreshFromCredentials() u.creds = creds
u.isAuthorized = true return nil
} }
// clearStore removes the database. // clearStore removes the database.
@ -248,7 +185,7 @@ func (u *User) closeStore() error {
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations. // Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
// After proper refactor of SMTP and IMAP remove this method. // After proper refactor of SMTP and IMAP remove this method.
func (u *User) GetTemporaryPMAPIClient() pmapi.Client { func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
return u.client() return u.client
} }
// ID returns the user's userID. // ID returns the user's userID.
@ -272,6 +209,10 @@ func (u *User) IsConnected() bool {
return u.creds.IsConnected() return u.creds.IsConnected()
} }
func (u *User) GetClient() pmapi.Client {
return u.client
}
// IsCombinedAddressMode returns whether user is set in combined or split mode. // IsCombinedAddressMode returns whether user is set in combined or split mode.
// Combined mode is the default mode and is what users typically need. // Combined mode is the default mode and is what users typically need.
// Split mode is mostly for outlook as it cannot handle sending e-mails from an // Split mode is mostly for outlook as it cannot handle sending e-mails from an
@ -345,7 +286,7 @@ func (u *User) GetAddressID(address string) (id string, err error) {
return u.store.GetAddressID(address) return u.store.GetAddressID(address)
} }
addresses := u.client().Addresses() addresses := u.client.Addresses()
pmapiAddress := addresses.ByEmail(address) pmapiAddress := addresses.ByEmail(address)
if pmapiAddress != nil { if pmapiAddress != nil {
return pmapiAddress.ID, nil return pmapiAddress.ID, nil
@ -366,18 +307,21 @@ func (u *User) GetBridgePassword() string {
// CheckBridgeLogin checks whether the user is logged in and the bridge // CheckBridgeLogin checks whether the user is logged in and the bridge
// IMAP/SMTP password is correct. // IMAP/SMTP password is correct.
func (u *User) CheckBridgeLogin(password string) error { func (u *User) CheckBridgeLogin(password string) error {
// FIXME(conman): Handle force upgrade?
/*
if isApplicationOutdated { if isApplicationOutdated {
u.listener.Emit(events.UpgradeApplicationEvent, "") u.listener.Emit(events.UpgradeApplicationEvent, "")
return pmapi.ErrUpgradeApplication return pmapi.ErrUpgradeApplication
} }
*/
u.lock.RLock() u.lock.RLock()
defer u.lock.RUnlock() defer u.lock.RUnlock()
// True here because users should be notified by popup of auth failure. if !u.creds.IsConnected() {
if err := u.authorizeIfNecessary(true); err != nil { u.listener.Emit(events.LogoutEvent, u.userID)
u.log.WithError(err).Error("Failed to authorize user") return ErrLoggedOutUser
return err
} }
return u.creds.CheckPassword(password) return u.creds.CheckPassword(password)
@ -388,60 +332,57 @@ func (u *User) UpdateUser() error {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
if err := u.authorizeIfNecessary(true); err != nil { _, err := u.client.UpdateUser(context.TODO())
return errors.Wrap(err, "cannot update user")
}
_, err := u.client().UpdateUser()
if err != nil { if err != nil {
return err return err
} }
if err = u.client().ReloadKeys([]byte(u.creds.MailboxPassword)); err != nil { if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to reload keys") return errors.Wrap(err, "failed to reload keys")
} }
emails := u.client().Addresses().ActiveEmails() creds, err := u.credStorer.UpdateEmails(u.userID, u.client.Addresses().ActiveEmails())
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil { if err != nil {
return err return err
} }
u.refreshFromCredentials() u.creds = creds
return nil return nil
} }
// SwitchAddressMode changes mode from combined to split and vice versa. The mode to switch to is determined by the // SwitchAddressMode changes mode from combined to split and vice versa. The mode to switch to is determined by the
// state of the user's credentials in the credentials store. See `IsCombinedAddressMode` for more details. // state of the user's credentials in the credentials store. See `IsCombinedAddressMode` for more details.
func (u *User) SwitchAddressMode() (err error) { func (u *User) SwitchAddressMode() error {
u.log.Trace("Switching user address mode") u.log.Trace("Switching user address mode")
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
u.CloseAllConnections() u.CloseAllConnections()
if u.store == nil { if u.store == nil {
err = errors.New("store is not initialised") return errors.New("store is not initialised")
return
} }
newAddressModeState := !u.IsCombinedAddressMode() newAddressModeState := !u.IsCombinedAddressMode()
if err = u.store.UseCombinedMode(newAddressModeState); err != nil { if err := u.store.UseCombinedMode(newAddressModeState); err != nil {
u.log.WithError(err).Error("Could not switch store address mode") return errors.Wrap(err, "could not switch store address mode")
return
} }
if u.creds.IsCombinedAddressMode != newAddressModeState { if u.creds.IsCombinedAddressMode == newAddressModeState {
if err = u.credStorer.SwitchAddressMode(u.userID); err != nil { return nil
u.log.WithError(err).Error("Could not switch credentials store address mode")
return
}
} }
u.refreshFromCredentials() creds, err := u.credStorer.SwitchAddressMode(u.userID)
if err != nil {
return errors.Wrap(err, "could not switch credentials store address mode")
}
return err u.creds = creds
return nil
} }
// logout is the same as Logout, but for internal purposes (logged out from // logout is the same as Logout, but for internal purposes (logged out from
@ -458,35 +399,37 @@ func (u *User) logout() error {
u.listener.Emit(events.UserRefreshEvent, u.userID) u.listener.Emit(events.UserRefreshEvent, u.userID)
} }
u.isAuthorized = false
return err return err
} }
// Logout logs out the user from pmapi, the credentials store, the mail store, and tries to remove as much // Logout logs out the user from pmapi, the credentials store, the mail store, and tries to remove as much
// sensitive data as possible. // sensitive data as possible.
func (u *User) Logout() (err error) { func (u *User) Logout() error {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
u.log.Debug("Logging out user") u.log.Debug("Logging out user")
if !u.creds.IsConnected() { if !u.creds.IsConnected() {
return return nil
} }
u.client().Logout() // FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers?
if err := u.client.AuthDelete(context.TODO()); err != nil {
u.log.WithError(err).Warn("Failed to delete auth")
}
if err = u.credStorer.Logout(u.userID); err != nil { creds, err := u.credStorer.Logout(u.userID)
if err != nil {
u.log.WithError(err).Warn("Could not log user out from credentials store") u.log.WithError(err).Warn("Could not log user out from credentials store")
if err = u.credStorer.Delete(u.userID); err != nil { if err := u.credStorer.Delete(u.userID); err != nil {
u.log.WithError(err).Error("Could not delete user from credentials store") u.log.WithError(err).Error("Could not delete user from credentials store")
} }
} else {
u.creds = creds
} }
u.refreshFromCredentials()
// Do not close whole store, just event loop. Some information might be needed offline (e.g. addressID) // Do not close whole store, just event loop. Some information might be needed offline (e.g. addressID)
u.closeEventLoop() u.closeEventLoop()
@ -494,15 +437,7 @@ func (u *User) Logout() (err error) {
runtime.GC() runtime.GC()
return err return nil
}
func (u *User) refreshFromCredentials() {
if credentials, err := u.credStorer.Get(u.userID); err != nil {
log.WithError(err).Error("Cannot refresh user credentials")
} else {
u.creds = credentials
}
} }
func (u *User) closeEventLoop() { func (u *User) closeEventLoop() {

View File

@ -19,12 +19,15 @@
package users package users
import ( import (
"context"
"strings" "strings"
"sync" "sync"
"time"
"github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/events"
imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache" imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/internal/metrics" "github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@ -45,7 +48,7 @@ type Users struct {
locations Locator locations Locator
panicHandler PanicHandler panicHandler PanicHandler
events listener.Listener events listener.Listener
clientManager ClientManager clientManager pmapi.Manager
credStorer CredentialsStorer credStorer CredentialsStorer
storeFactory StoreMaker storeFactory StoreMaker
@ -62,16 +65,13 @@ type Users struct {
useOnlyActiveAddresses bool useOnlyActiveAddresses bool
lock sync.RWMutex lock sync.RWMutex
// stopAll can be closed to stop all goroutines from looping (watchAppOutdated, watchAPIAuths, heartbeat etc).
stopAll chan struct{}
} }
func New( func New(
locations Locator, locations Locator,
panicHandler PanicHandler, panicHandler PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
clientManager ClientManager, clientManager pmapi.Manager,
credStorer CredentialsStorer, credStorer CredentialsStorer,
storeFactory StoreMaker, storeFactory StoreMaker,
useOnlyActiveAddresses bool, useOnlyActiveAddresses bool,
@ -87,57 +87,89 @@ func New(
storeFactory: storeFactory, storeFactory: storeFactory,
useOnlyActiveAddresses: useOnlyActiveAddresses, useOnlyActiveAddresses: useOnlyActiveAddresses,
lock: sync.RWMutex{}, lock: sync.RWMutex{},
stopAll: make(chan struct{}),
} }
// FIXME(conman): Handle force upgrade events.
/*
go func() { go func() {
defer panicHandler.HandlePanic() defer panicHandler.HandlePanic()
u.watchAppOutdated() u.watchAppOutdated()
}() }()
*/
go func() {
defer panicHandler.HandlePanic()
u.watchAPIAuths()
}()
if u.credStorer == nil { if u.credStorer == nil {
log.Error("No credentials store is available") log.Error("No credentials store is available")
} else if err := u.loadUsersFromCredentialsStore(); err != nil { } else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil {
log.WithError(err).Error("Could not load all users from credentials store") log.WithError(err).Error("Could not load all users from credentials store")
} }
return u return u
} }
func (u *Users) loadUsersFromCredentialsStore() (err error) { func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
userIDs, err := u.credStorer.List() userIDs, err := u.credStorer.List()
if err != nil { if err != nil {
return return err
} }
for _, userID := range userIDs { for _, userID := range userIDs {
l := log.WithField("user", userID) user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
if err != nil {
user, newUserErr := newUser(u.panicHandler, userID, u.events, u.credStorer, u.clientManager, u.storeFactory) logrus.WithError(err).Warn("Could not create user, skipping")
if newUserErr != nil {
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
continue continue
} }
u.users = append(u.users, user) u.users = append(u.users, user)
if initUserErr := user.init(); initUserErr != nil { if creds.IsConnected() {
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user") if err := u.loadConnectedUser(ctx, user, creds); err != nil {
logrus.WithError(err).Warn("Could not load connected user")
}
} else {
logrus.Warn("User is disconnected and must be connected manually")
if err := u.loadDisconnectedUser(ctx, user, creds); err != nil {
logrus.WithError(err).Warn("Could not load disconnected user")
}
} }
} }
return err return err
} }
func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
// FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor!
return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds)
}
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
uid, ref, err := creds.SplitAPIToken()
if err != nil {
return errors.Wrap(err, "could not get user's refresh token")
}
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
if err != nil {
// FIXME(conman): This is a problem... if we weren't able to create a new client due to internet,
// we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff...
return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
}
// Update the user's credentials with the latest auth used to connect this user.
if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
return errors.Wrap(err, "could not create get user's refresh token")
}
return user.connect(ctx, client, creds)
}
func (u *Users) watchAppOutdated() { func (u *Users) watchAppOutdated() {
// FIXME(conman): handle force upgrade events.
/*
ch := make(chan string) ch := make(chan string)
u.events.Add(events.UpgradeApplicationEvent, ch) u.events.Add(events.UpgradeApplicationEvent, ch)
@ -152,33 +184,7 @@ func (u *Users) watchAppOutdated() {
return return
} }
} }
} */
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
func (u *Users) watchAPIAuths() {
for {
select {
case auth := <-u.clientManager.GetAuthUpdateChannel():
log.Debug("Users received auth from ClientManager")
user, ok := u.hasUser(auth.UserID)
if !ok {
log.WithField("userID", auth.UserID).Info("User not available for auth update")
continue
}
if auth.Auth != nil {
user.updateAuthToken(auth.Auth)
} else if err := user.logout(); err != nil {
log.WithError(err).
WithField("userID", auth.UserID).
Error("User logout failed while watching API auths")
}
case <-u.stopAll:
return
}
}
} }
func (u *Users) closeAllConnections() { func (u *Users) closeAllConnections() {
@ -192,63 +198,45 @@ func (u *Users) closeAllConnections() {
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) { func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
u.crashBandicoot(username) u.crashBandicoot(username)
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet. return u.clientManager.NewClientWithLogin(context.TODO(), username, password)
authClient = u.clientManager.GetAnonymousClient()
authInfo, err := authClient.AuthInfo(username)
if err != nil {
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
return
}
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
return
}
return
} }
// FinishLogin finishes the login procedure and adds the user into the credentials store. // FinishLogin finishes the login procedure and adds the user into the credentials store.
func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphrase string) (user *User, err error) { //nolint[funlen] func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
defer func() { apiUser, passphrase, err := getAPIUser(context.TODO(), client, password)
if err != nil { if err != nil {
log.WithError(err).Debug("Login not finished; removing auth session") return nil, errors.Wrap(err, "failed to get API user")
if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil {
log.WithError(delAuthErr).Error("Failed to clear login session after unlock")
} }
}
// The anonymous client will be removed from list and authentication will not be deleted.
authClient.Logout()
}()
apiUser, hashedPassphrase, err := getAPIUser(authClient, mbPassphrase) if user, ok := u.hasUser(apiUser.ID); ok {
if user.IsConnected() {
if err := client.AuthDelete(context.TODO()); err != nil {
logrus.WithError(err).Warn("Failed to delete new auth session")
}
return nil, errors.New("user is already connected")
}
// Update the user's credentials with the latest auth used to connect this user.
if _, err := u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
return nil, errors.Wrap(err, "failed to load user credentials")
}
// Update the password in case the user changed it.
creds, err := u.credStorer.UpdatePassword(apiUser.ID, string(passphrase))
if err != nil { if err != nil {
log.WithError(err).Error("Failed to get API user") return nil, errors.Wrap(err, "failed to update password of user in credentials store")
return
} }
log.Info("Got API user") if err := user.connect(context.TODO(), client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user")
var ok bool
if user, ok = u.hasUser(apiUser.ID); ok {
if err = u.connectExistingUser(user, auth, hashedPassphrase); err != nil {
log.WithError(err).Error("Failed to connect existing user")
return
}
} else {
if err = u.addNewUser(apiUser, auth, hashedPassphrase); err != nil {
log.WithError(err).Error("Failed to add new user")
return
}
} }
// Old credentials use username as key (user ID) which needs to be removed return user, nil
// once user logs in again with proper ID fetched from API.
if _, ok := u.hasUser(apiUser.Name); ok {
if err := u.DeleteUser(apiUser.Name, true); err != nil {
log.WithError(err).Error("Failed to delete old user")
} }
if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil {
return nil, errors.Wrap(err, "failed to add new user")
} }
u.events.Emit(events.UserRefreshEvent, apiUser.ID) u.events.Emit(events.UserRefreshEvent, apiUser.ID)
@ -256,107 +244,63 @@ func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphr
return u.GetUser(apiUser.ID) return u.GetUser(apiUser.ID)
} }
// connectExistingUser connects an existing user.
func (u *Users) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassphrase string) (err error) {
if user.IsConnected() {
return errors.New("user is already connected")
}
log.Info("Connecting existing user")
// Update the user's password in the cred store in case they changed it.
if err = u.credStorer.UpdatePassword(user.ID(), hashedPassphrase); err != nil {
return errors.Wrap(err, "failed to update password of user in credentials store")
}
client := u.clientManager.GetClient(user.ID())
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to refresh auth token of new client")
}
if err = u.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to update token of user in credentials store")
}
if err = user.init(); err != nil {
return errors.Wrap(err, "failed to initialise user")
}
return
}
// addNewUser adds a new user. // addNewUser adds a new user.
func (u *Users) addNewUser(apiUser *pmapi.User, auth *pmapi.Auth, hashedPassphrase string) (err error) { func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
client := u.clientManager.GetClient(apiUser.ID) var emails []string
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to refresh token in new client")
}
if apiUser, err = client.CurrentUser(); err != nil {
return errors.Wrap(err, "failed to update API user")
}
var emails []string //nolint[prealloc]
if u.useOnlyActiveAddresses { if u.useOnlyActiveAddresses {
emails = client.Addresses().ActiveEmails() emails = client.Addresses().ActiveEmails()
} else { } else {
emails = client.Addresses().AllEmails() emails = client.Addresses().AllEmails()
} }
if _, err = u.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), hashedPassphrase, emails); err != nil { if _, err := u.credStorer.Add(apiUser.ID, apiUser.Name, auth.UID, auth.RefreshToken, string(passphrase), emails); err != nil {
return errors.Wrap(err, "failed to add user to credentials store") return errors.Wrap(err, "failed to add user credentials to credentials store")
} }
user, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.clientManager, u.storeFactory) user, creds, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to create user") return errors.Wrap(err, "failed to create new user")
} }
// The user needs to be part of the users list in order for it to receive an auth during initialisation. if err := user.connect(ctx, client, creds); err != nil {
u.users = append(u.users, user) return errors.Wrap(err, "failed to connect new user")
if err = user.init(); err != nil {
u.users = u.users[:len(u.users)-1]
return errors.Wrap(err, "failed to initialise user")
} }
if err := u.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)); err != nil { if err := u.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)); err != nil {
log.WithError(err).Error("Failed to send metric") log.WithError(err).Error("Failed to send metric")
} }
return err u.users = append(u.users, user)
return nil
} }
func getAPIUser(client pmapi.Client, mbPassphrase string) (user *pmapi.User, hashedPassphrase string, err error) { func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pmapi.User, []byte, error) {
salt, err := client.AuthSalt() salt, err := client.AuthSalt(ctx)
if err != nil { if err != nil {
log.WithError(err).Error("Could not get salt") return nil, nil, errors.Wrap(err, "failed to get salt")
return nil, "", err
} }
hashedPassphrase, err = pmapi.HashMailboxPassword(mbPassphrase, salt) passphrase, err := pmapi.HashMailboxPassword(password, salt)
if err != nil { if err != nil {
log.WithError(err).Error("Could not hash mailbox password") return nil, nil, errors.Wrap(err, "failed to hash password")
return nil, "", err
} }
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong. // We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
if err = client.Unlock([]byte(hashedPassphrase)); err != nil { if err := client.Unlock(ctx, passphrase); err != nil {
log.WithError(err).Error("Wrong mailbox password") return nil, nil, errors.Wrap(err, "failed to unlock client")
return nil, "", ErrWrongMailboxPassword
} }
if user, err = client.CurrentUser(); err != nil { user, err := client.CurrentUser(ctx)
log.WithError(err).Error("Could not load user data") if err != nil {
return nil, "", err return nil, nil, errors.Wrap(err, "failed to load user data")
} }
return user, hashedPassphrase, nil return user, passphrase, nil
} }
// GetUsers returns all added users into keychain (even logged out users). // GetUsers returns all added users into keychain (even logged out users).
@ -452,11 +396,9 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
// SendMetric sends a metric. We don't want to return any errors, only log them. // SendMetric sends a metric. We don't want to return any errors, only log them.
func (u *Users) SendMetric(m metrics.Metric) error { func (u *Users) SendMetric(m metrics.Metric) error {
c := u.clientManager.GetAnonymousClient()
defer c.Logout()
cat, act, lab := m.Get() cat, act, lab := m.Get()
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
if err := u.clientManager.SendSimpleMetric(context.Background(), string(cat), string(act), string(lab)); err != nil {
return err return err
} }
@ -472,24 +414,22 @@ func (u *Users) SendMetric(m metrics.Metric) error {
// AllowProxy instructs the app to use DoH to access an API proxy if necessary. // AllowProxy instructs the app to use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup). // It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) AllowProxy() { func (u *Users) AllowProxy() {
u.clientManager.AllowProxy() // FIXME(conman): Support DoH.
// u.apiManager.AllowProxy()
} }
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary. // DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup). // It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) DisallowProxy() { func (u *Users) DisallowProxy() {
u.clientManager.DisallowProxy() // FIXME(conman): Support DoH.
// u.apiManager.DisallowProxy()
} }
// CheckConnection returns whether there is an internet connection. // CheckConnection returns whether there is an internet connection.
// This should use the connection manager when it is eventually implemented. // This should use the connection manager when it is eventually implemented.
func (u *Users) CheckConnection() error { func (u *Users) CheckConnection() error {
return u.clientManager.CheckConnection() // FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer.
} panic("TODO: register as a connection observer to get this information")
// StopWatchers stops all goroutines.
func (u *Users) StopWatchers() {
close(u.stopAll)
} }
// hasUser returns whether the struct currently has a user with ID `id`. // hasUser returns whether the struct currently has a user with ID `id`.

View File

@ -20,8 +20,8 @@ package users
import ( import (
"errors" "errors"
"testing" "testing"
time "time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials" "github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -49,20 +49,19 @@ func TestNewUsersWithDisconnectedUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
// Basically every call client has get client manager.
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder( gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil), m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil), m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")), m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil), m.pmapiClient.EXPECT().Addresses().Return(nil),
) )
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected}) checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
} }
/*
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) { func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
@ -132,6 +131,7 @@ func TestNewUsersFirstStart(t *testing.T) {
testNewUsers(t, m) testNewUsers(t, m)
} }
*/
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) { func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
users := testNewUsers(t, m) users := testNewUsers(t, m)

View File

@ -48,18 +48,17 @@ func TestMain(m *testing.M) {
} }
var ( var (
testAuth = &pmapi.Auth{ //nolint[gochecknoglobals]
RefreshToken: "tok",
}
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals] testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
RefreshToken: "reftok", UID: "uid",
AccessToken: "acc",
RefreshToken: "ref",
} }
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals] testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "user", UserID: "user",
Name: "username", Name: "username",
Emails: "user@pm.me", Emails: "user@pm.me",
APIToken: "token", APIToken: "uid:acc",
MailboxPassword: "pass", MailboxPassword: "pass",
BridgePassword: "0123456789abcdef", BridgePassword: "0123456789abcdef",
Version: "v1", Version: "v1",
@ -67,11 +66,12 @@ var (
IsHidden: false, IsHidden: false,
IsCombinedAddressMode: true, IsCombinedAddressMode: true,
} }
testCredentialsSplit = &credentials.Credentials{ //nolint[gochecknoglobals] testCredentialsSplit = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "users", UserID: "users",
Name: "usersname", Name: "usersname",
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me", Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
APIToken: "token", APIToken: "uid:acc",
MailboxPassword: "pass", MailboxPassword: "pass",
BridgePassword: "0123456789abcdef", BridgePassword: "0123456789abcdef",
Version: "v1", Version: "v1",
@ -79,6 +79,7 @@ var (
IsHidden: false, IsHidden: false,
IsCombinedAddressMode: false, IsCombinedAddressMode: false,
} }
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals] testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "user", UserID: "user",
Name: "username", Name: "username",
@ -92,6 +93,19 @@ var (
IsCombinedAddressMode: true, IsCombinedAddressMode: true,
} }
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "users",
Name: "usersname",
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
APIToken: "",
MailboxPassword: "",
BridgePassword: "0123456789abcdef",
Version: "v1",
Timestamp: 123456789,
IsHidden: false,
IsCombinedAddressMode: false,
}
testPMAPIUser = &pmapi.User{ //nolint[gochecknoglobals] testPMAPIUser = &pmapi.User{ //nolint[gochecknoglobals]
ID: "user", ID: "user",
Name: "username", Name: "username",
@ -130,11 +144,11 @@ type mocks struct {
ctrl *gomock.Controller ctrl *gomock.Controller
locator *usersmocks.MockLocator locator *usersmocks.MockLocator
PanicHandler *usersmocks.MockPanicHandler PanicHandler *usersmocks.MockPanicHandler
clientManager *usersmocks.MockClientManager
credentialsStore *usersmocks.MockCredentialsStorer credentialsStore *usersmocks.MockCredentialsStorer
storeMaker *usersmocks.MockStoreMaker storeMaker *usersmocks.MockStoreMaker
eventListener *MockListener eventListener *MockListener
clientManager *pmapimocks.MockManager
pmapiClient *pmapimocks.MockClient pmapiClient *pmapimocks.MockClient
storeCache *store.Cache storeCache *store.Cache
@ -171,11 +185,11 @@ func initMocks(t *testing.T) mocks {
ctrl: mockCtrl, ctrl: mockCtrl,
locator: usersmocks.NewMockLocator(mockCtrl), locator: usersmocks.NewMockLocator(mockCtrl),
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl), PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
clientManager: usersmocks.NewMockClientManager(mockCtrl),
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl), credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl), storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
eventListener: NewMockListener(mockCtrl), eventListener: NewMockListener(mockCtrl),
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl), pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()), storeCache: store.NewCache(cacheFile.Name()),
@ -189,7 +203,7 @@ func initMocks(t *testing.T) mocks {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests. var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db") dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
require.NoError(t, err, "could not get temporary file for store db") require.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.clientManager, m.eventListener, dbFile.Name(), m.storeCache) return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
}).AnyTimes() }).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes() m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
@ -198,46 +212,42 @@ func initMocks(t *testing.T) mocks {
func testNewUsersWithUsers(t *testing.T, m mocks) *Users { func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
// Events are asynchronous // Events are asynchronous
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).Times(2) m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2) m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2) m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
gomock.InOrder( gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil), m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
// Init for user. // Init for user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil), m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Init for users. // Init for users.
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil),
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil), m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil), m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses), m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
) )
users := testNewUsers(t, m) return testNewUsers(t, m)
user, _ := users.GetUser("user")
mockAuthUpdate(user, "reftok", m)
user, _ = users.GetUser("user")
mockAuthUpdate(user, "reftok", m)
return users
} }
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam] func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) // FIXME(conman): How to handle force upgrade?
m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan pmapi.ClientAuth)) // m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true) users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
@ -256,8 +266,8 @@ func TestClearData(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) // m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) // m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m) users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users) defer cleanUpUsersData(users)
@ -267,13 +277,11 @@ func TestClearData(t *testing.T) {
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me") m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
m.pmapiClient.EXPECT().Logout() m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("user").Return(nil) m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Logout() m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("users").Return(nil) m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil)
m.locator.EXPECT().Clear() m.locator.EXPECT().Clear()
@ -285,9 +293,9 @@ func TestClearData(t *testing.T) {
func mockEventLoopNoAction(m mocks) { func mockEventLoopNoAction(m mocks) {
// Set up mocks for starting the store's event loop (in store.New). // Set up mocks for starting the store's event loop (in store.New).
// The event loop runs in another goroutine so this might happen at any time. // The event loop runs in another goroutine so this might happen at any time.
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).AnyTimes()
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
} }
func mockConnectedUser(m mocks) { func mockConnectedUser(m mocks) {
@ -295,27 +303,13 @@ func mockConnectedUser(m mocks) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil), // m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil), m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
// Set up mocks for store initialisation for the authorized user. // Set up mocks for store initialisation for the authorized user.
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil), m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil), m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
) )
} }
// mockAuthUpdate simulates users calling UpdateAuthToken on the given user.
// This would normally be done by users when it receives an auth from the ClientManager,
// but as we don't have a full users instance here, we do this manually.
func mockAuthUpdate(user *User, token string, m mocks) {
gomock.InOrder(
m.credentialsStore.EXPECT().UpdateToken("user", ":"+token).Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(token), nil),
)
user.updateAuthToken(refreshWithToken(token))
waitForEvents()
}

View File

@ -1,23 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
// IsAuthorized returns whether the user has received an Auth from the API yet.
func (u *User) IsAuthorized() bool {
return u.isAuthorized
}

View File

@ -40,8 +40,8 @@ type Builder struct {
} }
type Fetcher interface { type Fetcher interface {
GetMessage(string) (*pmapi.Message, error) GetMessage(context.Context, string) (*pmapi.Message, error)
GetAttachment(string) (io.ReadCloser, error) GetAttachment(context.Context, string) (io.ReadCloser, error)
KeyRingForAddressID(string) (*crypto.KeyRing, error) KeyRingForAddressID(string) (*crypto.KeyRing, error)
} }

View File

@ -1,309 +0,0 @@
# Do not modify this file!
It is here for historical reasons only. All changes should be documented in the
Changelog at the root of this repository.
# Changelog for API
> NOTE we are using versioning for go-pmapi in format `major.minor.bugfix`
> * major stays at version 1 for the forseeable future
> * minor is increased when a force upgrade happens or in case of major breaking changes
> * patch is increased when new features are added
## v1.0.16
### Fixed
* Potential crash when reporting cert pin failure
## v1.0.15
### Changed
* Merge only 50 events into one
* Response header timeout increased from 10s to 30s
### Fixed
* Make keyring unlocking threadsafe
## v1.0.14
### Added
* Config for disabling TLS cert fingerprint checking
### Fixed
* Ensure sensitive stuff is cleared on client logout even if requests fail
## v1.0.13
### Fixed
* Correctly set Transport in http client
## v1.0.12
### Changed
* Only `http.RoundTripper` interface is needed instead of full `http.Transport` struct
### Added
* GODT-61 (and related): Use DoH to find and switch to a proxy server if the API becomes unreachable
* GODT-67 added random wait to not cause spikes on server after StatusTooManyRequests
### Fixed
* FirstReadTimeout was wrongly timeout of the whole request including repeating ones, now it's really only timeout for the first read
## v1.0.11
### Added
* GODT-53 `Message.Type` added with constants `MessageType*`
## v1.0.10
### Added
* GODT-55 exporting DANGEROUSLYSetUID
### Changed
* The full communication between clien and API is logged if logrus level is trace
## v1.0.9
### Fixed
* Use correct address type value (because API starts counting from 1 but we were counting from 0)
## v1.0.8
### Added
* Introdcution of connection manager
### Fixed
* Deadlock during the auth-refresh
* Fixed an issue where some events were being discarded when merging
## v1.0.7
### Changed
* The given access token is saved during auth refresh if none was available yet
## v1.0.6
### Added
* `ClientConfig.Timeout` to be able to configure the whole timeout of request
* `ClientConfig.FirstReadTimeout` to be able to configure the timeout of request to the first byte
* `ClientConfig.MinSpeed` to be able to configure the timeout when the connection is too slow (limitation in minimum bytes per second)
* Set default timeouts for http.Transport with certificate pinning
### Changed
* http.Client by default uses ProxyFromEnvironment to support HTTP_PROXY and HTTPS_PROXY environment variables
## v1.0.5
### Added
* `ContentTypeMultipartEncrypted` MIME content type for encrypted email
* `MessageCounts` in event struct
## v1.0.4
### Added
* `PMKeys` for parsing and reading KeyRing
* `clearableKey` to rewrite memory
* Proton/backend-communication#25 Unlock with tokens (OneKey2RuleThemAll Phase I)
### Changed
* Update of gopenpgp: convert JSON to KeyRing in PMAPI
* `user.KeyRing` -> `user.KeyRing()`
* typo `client.GetAddresses()`
### Removed
* `address.KeyRing`
## v1.0.2 v1.0.3
### Changed
* Fixed capitalisation in a few places
* Added /metrics API route
* Changed function names to be compliant with go linter
* Encrypt with primary key only
* Fix `client.doBuffered` - closing body before handling unauthorized request
* go-pm-crypto -> GopenPGP
* redefine old functions in `keyring.go`
* `attachment.Decrypt` drops returning signature (does signature check by default)
* `attachment.Encrypt` is using readers instead of writers
* `attachment.DetachedSign` drops writer param and returns signature as a reader
* `message.Decrypt` drops returning signature (does signature check by default)
* Changed TLS report URL to https://reports.protonmail.ch/reports/tls
* Moved from current to soon TLS pin
## v1.0.1
### Removed
* `ClientID` from all auth routes
* `ErrorDescription` from error
## v1.0.0
### Changed
* `client.AuthInfo` does return 2FA information only when authenticated, for the first login information available in `Auth.HasTwoFactor`
* `client.Auth` does not accept 2FA code in favor of `client.Auth2FA`
* `client.Unlock` supports only new way of unlock with directly available access token
### Added
* `Res.StatusCode` to pass HTTP status code to responses
* `Auth.HasTwoFactor` method to determine whether account has enabled 2FA (same as `AuthInfo.HasTwoFactor`)
* `Auth2FA*` structs for 2FA endpoint
* `client.Auth2FA` method to fully unlock session with 2FA code
* `ErrUnauthorized` when request cannot be authorized
* `ErrBad2FACode` when bad 2FA and user cannot try again
* `ErrBad2FACodeTryAgain` when bad 2FA but user can try again
## 2019-08-06
### Added
* Send TLS issue report to API
* Cert fingerpring with `TLSPinning` struct
* Check API certificate fingerprint and verify hostname
### Changed
* Using `AddressID` for `/messge/count` and `/conversations/count`
* Less of copying of responses from the server in the memory
## 2019-08-01
* low case for `sirupsen`
* using go modules
## 2019-07-15
### Changed
* `client.Auths` field is removed in favor of function `client.SetAuths` which opens possibility to use interface
## 2019-05-18
### Changed
* proton/backend-communication#11 x-pm-uid sent always for `/auth/refresh`
* proton/backend-communication#11 UID never changes
## 2019-05-28
### Added
* New test server patern using callbacks
* Responses are read from json files
### Changed
* `auth_tests.go` to new callback server pattern
* Linter fixes for tests
### Removed
* `TestClient_Do_expired` due to no effect, use `DoUnauthorized` instead
## 2019-05-24
* Help functions for test
* CI with Lint
## 2019-05-23
* Log userID
## 2019-05-21
* Fix unlocking user keys
## 2019-04-25
### Changed
* rename `Uid` -> `UID` proton/backend-communication#11
## 2019-04-09
### Added
* sending attachments as zip `application/octet-stream`
* function `ReportReq.AddAttachment()`
* data memeber `ReportReq.Attachments`
* general function to report bug `client.Report(req ReportReq)` with object as parameter
### Changed
* `client.ReportBug` and `client.ReportBugWithClient` functions are obsolete and they uses `client.Report(req ReportReq)`
* `client.ReportCrash` is obsolete. Use sentry instead
* `Api`->`API`, `Uid`->`UID`
## 2019-03-13
* user id in raven
* add file position of panic sender
## 2019-03-06
* #30 update `pm-crypto` to store `KeyRing.FirstKeyID`
* #30 Add key salt to `Auth` object from `GetKeySalts` request
* #30 Add route `GET /keys/salt`
* removed unused `PmCrypto`
## 2019-02-20
* removed unused `decryptAccessToken`
## 2019-01-21
* #29 Parsing all goroutines from pprof
* #29 Sentry `Threads` implementation
* #29 using sentry for crashes
## 2019-01-07
* refactor `pmapi.DecryptString` -> `pmcrypto.KeyRing.DecryptString`
* fixed tests
* `crypto` -> `pmcrypto`
* refactoring code using repos `go-pm-crypto`, `go-pm-mime` and `go-srp`
## 2018-12-10
* #26 adding `Flags` field to message
* #26 removing fields deprecated by `Flags`: `IsEncrypted`, `Type`, `IsReplied`, `IsRepliedAll`, `IsForwarded`
* #26 removing deprecated consts (see #26 for replacement)
* #26 fixing tests (compiling not working)
## 2018-11-19
### Added
* Wait and retry from `DoJson` if banned from api
### Changed
* `ErrNoInternet` -> `ErrAPINotReachable`
* Adding codes for force upgrade: 5004 and 5005
* Adding codes for API offline: 7001
* Adding codes for BansRequests: 85131
## 2018-09-18
### Added
* `client.decryptAccessToken` if privateKey is received (tested with local api) #23
### Changed
* added fields to User
* local config TLS skip verify
## 2018-09-06
### Changed
* decrypt token only if needed
### Broken
* Tests are not working
## APIv3 UPDATE (2018-08-01)
* issue Desktop-Bridge#561
### Added
* Key flag consts
* `EventAddress`
* `MailSettings` object and route call
* `Client.KeyRingForAddressID`
* `AuthInfo.HasTwoFactor()`
* `Auth.HasMailboxPassword()`
### Changed
* Addresses are part of client
* Update user updates also addresses
* `BodyKey` and `AttachmentKey` contains `Key` and `Algorithm`
* `keyPair` (not use Pubkey) -> `pmKeyObject`
* lots of indent
* bugs route
* two factor (ready to U2F)
* Reorder some to match order in doc (easier to )
* omit address Order when empty
* update user and addresses in `CurrentUser()`
* `User.Unlock()` -> `Client.UnlockAddresses()`
* `AuthInfo.Uid` -> `AuthInfo.Uid()`
* `User.Addresses` -> `Client.Addresses()`
### Removed
* User v3 removed plenty (now in settings)
* Message v3 removed plenty (Starred is label)

View File

@ -1,19 +0,0 @@
export GO111MODULE=on
LINTVER="v1.21.0"
LINTSRC="https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh"
check-has-go:
@which go || (echo "Install Go-lang!" && exit 1)
install-dev-dependencies: install-linter
install-linter: check-has-go
curl -sfL $(LINTSRC) | sh -s -- -b $(shell go env GOPATH)/bin $(LINTVER)
lint:
which golangci-lint || $(MAKE) install-linter
golangci-lint run ./... \
test:
go test -run=${TESTRUN} ./...

View File

@ -18,10 +18,12 @@
package pmapi package pmapi
import ( import (
"context"
"errors" "errors"
"strings" "strings"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
) )
// Address statuses. // Address statuses.
@ -80,11 +82,6 @@ type Address struct {
// AddressList is a list of addresses. // AddressList is a list of addresses.
type AddressList []*Address type AddressList []*Address
type AddressesRes struct {
Res
Addresses AddressList
}
// ByID returns an address by id. Returns nil if no address is found. // ByID returns an address by id. Returns nil if no address is found.
func (l AddressList) ByID(id string) *Address { func (l AddressList) ByID(id string) *Address {
for _, addr := range l { for _, addr := range l {
@ -164,40 +161,22 @@ func ConstructAddress(headerEmail string, addressEmail string) string {
} }
// GetAddresses requests all of current user addresses (without pagination). // GetAddresses requests all of current user addresses (without pagination).
func (c *client) GetAddresses() (addresses AddressList, err error) { func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err error) {
req, err := c.NewRequest("GET", "/addresses", nil) var res struct {
if err != nil { Addresses []*Address
return
} }
var res AddressesRes if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if err = c.DoJSON(req, &res); err != nil { return r.SetResult(&res).Get("/addresses")
return }); err != nil {
return nil, err
} }
return res.Addresses, res.Err() return res.Addresses, nil
} }
func (c *client) ReorderAddresses(addressIDs []string) (err error) { func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) (err error) {
var reqBody struct { panic("TODO")
AddressIDs []string
}
reqBody.AddressIDs = addressIDs
req, err := c.NewJSONRequest("PUT", "/addresses/order", reqBody)
if err != nil {
return
}
var addContactsRes AddContactsResponse
if err = c.DoJSON(req, &addContactsRes); err != nil {
return
}
_, err = c.UpdateUser()
return
} }
// Addresses returns the addresses stored in the client object itself rather than fetching from the API. // Addresses returns the addresses stored in the client object itself rather than fetching from the API.

View File

@ -21,13 +21,13 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"net/textproto" "net/textproto"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
) )
type header textproto.MIMEHeader type header textproto.MIMEHeader
@ -138,12 +138,6 @@ func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io.
return signAttachment(kr, att) return signAttachment(kr, att)
} }
type CreateAttachmentRes struct {
Res
Attachment *Attachment
}
func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) { func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) {
// Create metadata fields. // Create metadata fields.
if err = w.WriteField("Filename", att.Name); err != nil { if err = w.WriteField("Filename", att.Name); err != nil {
@ -185,91 +179,37 @@ func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.R
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID. // CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
// //
// The returned created attachment contains the new attachment ID and its size. // The returned created attachment contains the new attachment ID and its size.
func (c *client) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (*Attachment, error) { func (c *client) CreateAttachment(ctx context.Context, att *Attachment, attData io.Reader, sigData io.Reader) (*Attachment, error) {
req, w, err := c.NewMultipartRequest("POST", "/mail/v4/attachments") var res struct {
if err != nil { Attachment *Attachment
return nil, err
} }
cx, cancel := context.WithCancel(req.Context()) if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
defer cancel() return r.SetResult(&res).
req = req.WithContext(cx) SetMultipartFormData(map[string]string{
"Filename": att.Name,
// We will write the request as long as it is sent to the API. "MessageID": att.MessageID,
var res CreateAttachmentRes "MIMEType": att.MIMEType,
done := make(chan error, 1) "ContentID": att.ContentID,
go (func() { }).
done <- c.DoJSON(req, &res) SetMultipartField("DataPacket", "DataPacket.pgp", "application/octet-stream", attData).
})() SetMultipartField("Signature", "Signature.pgp", "application/octet-stream", sigData).
Post("/mail/v4/attachments")
if err := writeAttachment(w.Writer, att, r, sig); err != nil { }); err != nil {
_ = w.Close()
return nil, err
}
_ = w.Close()
if err := <-done; err != nil {
return nil, err
}
if err := res.Err(); err != nil {
return nil, err return nil, err
} }
return res.Attachment, nil return res.Attachment, nil
} }
type UpdateAttachmentSignatureReq struct {
Signature string
}
func (c *client) UpdateAttachmentSignature(attachmentID, signature string) (err error) {
updateReq := &UpdateAttachmentSignatureReq{signature}
req, err := c.NewJSONRequest("PUT", "/mail/v4/attachments/"+attachmentID+"/signature", updateReq)
if err != nil {
return
}
var res Res
if err = c.DoJSON(req, &res); err != nil {
return
}
return
}
// DeleteAttachment removes an attachment. message is the message ID, att is the attachment ID.
func (c *client) DeleteAttachment(attID string) (err error) {
req, err := c.NewRequest("DELETE", "/mail/v4/attachments/"+attID, nil)
if err != nil {
return
}
var res Res
if err = c.DoJSON(req, &res); err != nil {
return
}
err = res.Err()
return
}
// GetAttachment gets an attachment's content. The returned data is encrypted. // GetAttachment gets an attachment's content. The returned data is encrypted.
func (c *client) GetAttachment(id string) (att io.ReadCloser, err error) { func (c *client) GetAttachment(ctx context.Context, attachmentID string) (att io.ReadCloser, err error) {
if id == "" { res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
err = errors.New("pmapi: cannot get an attachment with an empty id") return r.SetDoNotParseResponse(true).Get("/mail/v4/attachments/" + attachmentID)
return })
}
req, err := c.NewRequest("GET", "/mail/v4/attachments/"+id, nil)
if err != nil { if err != nil {
return return nil, err
} }
res, err := c.Do(req, true) return res.RawBody(), nil
if err != nil {
return
}
att = res.Body
return
} }

View File

@ -19,6 +19,7 @@ package pmapi
import ( import (
"bytes" "bytes"
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -94,7 +95,7 @@ func TestAttachment_UnmarshalJSON(t *testing.T) {
} }
func TestClient_CreateAttachment(t *testing.T) { func TestClient_CreateAttachment(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments")) Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments"))
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type")) contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
@ -136,12 +137,14 @@ func TestClient_CreateAttachment(t *testing.T) {
t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b)) t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b))
} }
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreateAttachmentBody) fmt.Fprint(w, testCreateAttachmentBody)
})) }))
defer s.Close() defer s.Close()
r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
created, err := c.CreateAttachment(testAttachment, r, strings.NewReader("")) created, err := c.CreateAttachment(context.TODO(), testAttachment, r, strings.NewReader(""))
if err != nil { if err != nil {
t.Fatal("Expected no error while creating attachment, got:", err) t.Fatal("Expected no error while creating attachment, got:", err)
} }
@ -151,34 +154,17 @@ func TestClient_CreateAttachment(t *testing.T) {
} }
} }
func TestClient_DeleteAttachment(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "DELETE", "/mail/v4/attachments/"+testAttachment.ID))
b := &bytes.Buffer{}
if n, _ := b.ReadFrom(r.Body); n != 0 {
t.Fatal("expected no body but have: ", b.String())
}
fmt.Fprint(w, testDeleteAttachmentBody)
}))
defer s.Close()
err := c.DeleteAttachment(testAttachment.ID)
if err != nil {
t.Fatal("Expected no error while deleting attachment, got:", err)
}
}
func TestClient_GetAttachment(t *testing.T) { func TestClient_GetAttachment(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID)) Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testAttachmentCleartext) fmt.Fprint(w, testAttachmentCleartext)
})) }))
defer s.Close() defer s.Close()
r, err := c.GetAttachment(testAttachment.ID) r, err := c.GetAttachment(context.TODO(), testAttachment.ID)
if err != nil { if err != nil {
t.Fatal("Expected no error while getting attachment, got:", err) t.Fatal("Expected no error while getting attachment, got:", err)
} }

View File

@ -1,407 +1,47 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi package pmapi
import ( import (
"crypto/subtle" "context"
"crypto/rand"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "io"
"net/http" "time"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/srp" "github.com/go-resty/resty/v2"
) )
var ErrBad2FACode = errors.New("incorrect 2FA code") func (c *client) Auth2FA(ctx context.Context, req Auth2FAReq) error {
var ErrBad2FACodeTryAgain = errors.New("incorrect 2FA code: please try again") if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Post("/auth/2fa")
type AuthInfoReq struct { }); err != nil {
Username string
}
type U2FInfo struct {
Challenge string
RegisteredKeys []struct {
Version string
KeyHandle string
}
}
type TwoFactorInfo struct {
Enabled int // 0 for disabled, 1 for OTP, 2 for U2F, 3 for both.
TOTP int
U2F U2FInfo
}
func (twoFactor *TwoFactorInfo) hasTwoFactor() bool {
return twoFactor.Enabled > 0
}
// AuthInfo contains data used when authenticating a user. It should be
// provided to Client.Auth(). Each AuthInfo can be used for only one login attempt.
type AuthInfo struct {
TwoFA *TwoFactorInfo `json:"2FA,omitempty"`
version int
salt string
modulus string
srpSession string
serverEphemeral string
}
func (a *AuthInfo) HasTwoFactor() bool {
if a.TwoFA == nil {
return false
}
return a.TwoFA.hasTwoFactor()
}
type AuthInfoRes struct {
Res
AuthInfo
Modulus string
ServerEphemeral string
Version int
Salt string
SRPSession string
}
func (res *AuthInfoRes) getAuthInfo() *AuthInfo {
info := &res.AuthInfo
// Some fields in AuthInfo are private, so we need to copy them from AuthRes
// (private fields cannot be populated by json).
info.version = res.Version
info.salt = res.Salt
info.modulus = res.Modulus
info.srpSession = res.SRPSession
info.serverEphemeral = res.ServerEphemeral
return info
}
type AuthReq struct {
Username string
ClientProof string
ClientEphemeral string
SRPSession string
}
// Auth contains data after a successful authentication. It should be provided to Client.Unlock().
type Auth struct {
accessToken string // Read from AuthRes.
ExpiresIn int
uid string // Read from AuthRes.
RefreshToken string
EventID string
PasswordMode int
TwoFA *TwoFactorInfo `json:"2FA,omitempty"`
}
// UID returns the session UID from the Auth.
// Only Auths generated from the /auth route will have the UID.
// Auths generated from /auth/refresh are not required to.
func (s *Auth) UID() string {
return s.uid
}
// GenToken generates a string token containing the session UID and refresh token.
func (s *Auth) GenToken() string {
if s == nil {
return ""
}
return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken)
}
func (s *Auth) HasTwoFactor() bool {
if s.TwoFA == nil {
return false
}
return s.TwoFA.hasTwoFactor()
}
func (s *Auth) HasMailboxPassword() bool {
return s.PasswordMode == 2
}
type AuthRes struct {
Res
Auth
AccessToken string
TokenType string
// UID is the session UID. This is only present in an initial Auth (/auth), not in a refreshed Auth (/auth/refresh).
UID string
ServerProof string
}
func (res *AuthRes) getAuth() *Auth {
auth := &res.Auth
// Some fields in Auth are private, so we need to copy them from AuthRes
// (private fields cannot be populated by json).
auth.accessToken = res.AccessToken
auth.uid = res.UID
return auth
}
type Auth2FAReq struct {
TwoFactorCode string
// Prepared for U2F:
// U2F U2FRequest
}
type Auth2FARes struct {
Res
}
type AuthRefreshReq struct {
ResponseType string
GrantType string
RefreshToken string
UID string
RedirectURI string
State string
}
func (c *client) sendAuth(auth *Auth) {
if auth != nil {
c.log.WithField("auth", *auth).Debug("Client is sending auth to ClientManager")
} else {
c.log.Debug("Client is sending nil auth to ClientManager")
}
if auth != nil {
c.uid = auth.UID()
c.accessToken = auth.accessToken
}
c.cm.HandleAuth(ClientAuth{UserID: c.userID, Auth: auth})
}
// AuthInfo gets authentication info for a user.
func (c *client) AuthInfo(username string) (info *AuthInfo, err error) {
infoReq := &AuthInfoReq{
Username: username,
}
req, err := c.NewJSONRequest("POST", "/auth/info", infoReq)
if err != nil {
return
}
var infoRes AuthInfoRes
if err = c.DoJSON(req, &infoRes); err != nil {
return
}
info, err = infoRes.getAuthInfo(), infoRes.Err()
return
}
func srpProofsFromInfo(info *AuthInfo, username, password string, fallbackVersion int) (proofs *srp.SrpProofs, err error) {
version := info.version
if version == 0 {
version = fallbackVersion
}
srpAuth, err := srp.NewSrpAuth(version, username, password, info.salt, info.modulus, info.serverEphemeral)
if err != nil {
return
}
proofs, err = srpAuth.GenerateSrpProofs(2048)
return
}
func (c *client) tryAuth(username, password string, info *AuthInfo, fallbackVersion int) (res *AuthRes, err error) {
proofs, err := srpProofsFromInfo(info, username, password, fallbackVersion)
if err != nil {
return
}
authReq := &AuthReq{
Username: username,
ClientEphemeral: base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
ClientProof: base64.StdEncoding.EncodeToString(proofs.ClientProof),
SRPSession: info.srpSession,
}
req, err := c.NewJSONRequest("POST", "/auth", authReq)
if err != nil {
return
}
var authRes AuthRes
if err = c.DoJSON(req, &authRes); err != nil {
return
}
if err = authRes.Err(); err != nil {
return
}
serverProof, err := base64.StdEncoding.DecodeString(authRes.ServerProof)
if err != nil {
return
}
if subtle.ConstantTimeCompare(proofs.ExpectedServerProof, serverProof) != 1 {
return nil, errors.New("pmapi: bad server proof")
}
res, err = &authRes, authRes.Err()
return res, err
}
func (c *client) tryFullAuth(username, password string, fallbackVersion int) (info *AuthInfo, authRes *AuthRes, err error) {
info, err = c.AuthInfo(username)
if err != nil {
return
}
authRes, err = c.tryAuth(username, password, info, fallbackVersion)
return
}
// Auth will authenticate a user.
func (c *client) Auth(username, password string, info *AuthInfo) (auth *Auth, err error) {
if info == nil {
if info, err = c.AuthInfo(username); err != nil {
return
}
}
authRes, err := c.tryAuth(username, password, info, 2)
if err != nil && info.version == 0 && srp.CleanUserName(username) != strings.ToLower(username) {
info, authRes, err = c.tryFullAuth(username, password, 1)
}
if err != nil && info.version == 0 {
_, authRes, err = c.tryFullAuth(username, password, 0)
}
if err != nil {
return
}
auth = authRes.getAuth()
c.sendAuth(auth)
return auth, err
}
// Auth2FA will authenticate a user into full scope.
// `Auth` struct contains method `HasTwoFactor` deciding whether this has to be done.
func (c *client) Auth2FA(twoFactorCode string, auth *Auth) error {
auth2FAReq := &Auth2FAReq{
TwoFactorCode: twoFactorCode,
}
req, err := c.NewJSONRequest("POST", "/auth/2fa", auth2FAReq)
if err != nil {
return err return err
} }
var auth2FARes Auth2FARes
if err := c.DoJSON(req, &auth2FARes); err != nil {
return err
}
if err := auth2FARes.Err(); err != nil {
switch auth2FARes.StatusCode {
case http.StatusUnauthorized:
return ErrBad2FACode
case http.StatusUnprocessableEntity:
return ErrBad2FACodeTryAgain
default:
return err
}
}
return nil return nil
} }
// AuthRefresh will refresh an expired access token. func (c *client) AuthDelete(ctx context.Context) error {
func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) { if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
c.refreshLocker.Lock() return r.Delete("/auth")
defer c.refreshLocker.Unlock() }); err != nil {
return err
// If we don't yet have a saved access token, save this one in case the refresh fails!
// That way we can try again later (see handleUnauthorizedStatus).
c.cm.setTokenIfUnset(c.userID, uidAndRefreshToken)
split := strings.Split(uidAndRefreshToken, ":")
if len(split) != 2 {
err = ErrInvalidToken
return
} }
refreshReq := &AuthRefreshReq{ c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
ResponseType: "token",
GrantType: "refresh_token", // FIXME(conman): should we perhaps signal via AuthHandler that the auth was deleted?
RefreshToken: split[1],
UID: split[0], return nil
RedirectURI: "https://protonmail.ch",
State: "random_string",
} }
// UID must be set for `x-pm-uid` header field, see backend-communication#11 func (c *client) AuthSalt(ctx context.Context) (string, error) {
c.uid = split[0] salts, err := c.GetKeySalts(ctx)
req, err := c.NewJSONRequest("POST", "/auth/refresh", refreshReq)
if err != nil {
return
}
var res AuthRes
if err = c.DoJSON(req, &res); err != nil {
return
}
if err = res.Err(); err != nil {
return
}
auth = res.getAuth()
// Responses from /auth/refresh are not guaranteed to return the UID if it has not changed.
// But we want to always return it.
if auth.uid == "" {
auth.uid = c.uid
}
c.sendAuth(auth)
return auth, err
}
func (c *client) AuthSalt() (string, error) {
salts, err := c.GetKeySalts()
if err != nil { if err != nil {
return "", err return "", err
} }
if _, err := c.CurrentUser(); err != nil { if _, err := c.CurrentUser(ctx); err != nil {
return "", err return "", err
} }
@ -414,40 +54,37 @@ func (c *client) AuthSalt() (string, error) {
return "", errors.New("no matching salt found") return "", errors.New("no matching salt found")
} }
// Logout instructs the client manager to log this client out. func (c *client) AddAuthHandler(handler AuthHandler) {
func (c *client) Logout() { c.authHandlers = append(c.authHandlers, handler)
c.cm.LogoutClient(c.userID)
} }
// DeleteAuth deletes the API session. func (c *client) authRefresh(ctx context.Context) error {
func (c *client) DeleteAuth() (err error) { c.authLocker.Lock()
req, err := c.NewRequest("DELETE", "/auth", nil) defer c.authLocker.Unlock()
auth, err := c.req.authRefresh(ctx, c.uid, c.ref)
if err != nil { if err != nil {
return return err
} }
var res Res c.acc = auth.AccessToken
if err = c.DoJSON(req, &res); err != nil { c.ref = auth.RefreshToken
return
for _, handler := range c.authHandlers {
if err := handler(auth); err != nil {
return err
}
} }
if err = res.Err(); err != nil { return nil
return
} }
return func randomString(length int) string {
noise := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, noise); err != nil {
panic(err)
} }
// IsConnected returns whether the client is authorized to access the API. return base64.StdEncoding.EncodeToString(noise)[:length]
func (c *client) IsConnected() bool {
return c.uid != "" && c.accessToken != ""
}
// ClearData clears sensitive data from the client.
func (c *client) ClearData() {
c.uid = ""
c.accessToken = ""
c.addresses = nil
c.user = nil
c.clearKeys()
} }

View File

@ -1,351 +1,135 @@
// Copyright (c) 2021 Proton Technologies AG package pmapi_test
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import ( import (
"context"
"encoding/json" "encoding/json"
"math/rand" "errors"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"time" "time"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/srp"
"github.com/sirupsen/logrus"
a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
) )
var testIdentity = &crypto.Identity{ func TestAutomaticAuthRefresh(t *testing.T) {
Name: "UserID", var wantAuth = &pmapi.Auth{
Email: "", UID: "testUID",
AccessToken: "testAcc",
RefreshToken: "testRef",
} }
const ( mux := http.NewServeMux()
testUsername = "jason"
testAPIPassword = "apple"
testUID = "729ad6012421d67ad26950dc898bebe3a6e3caa2" //nolint[gosec] mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
testAccessToken = "de0423049b44243afeec7d9c1d99be7b46da1e8a" //nolint[gosec] w.Header().Set("Content-Type", "application/json")
testAccessTokenOld = "feb3159ac63fb05119bcf4480d939278aa746926" //nolint[gosec]
testRefreshToken = "a49b98256745bb497bec20e9b55f5de16f01fb52" //nolint[gosec]
testRefreshTokenNew = "b894b4c4f20003f12d486900d8b88c7d68e67235" //nolint[gosec]
)
var testAuthInfo = &AuthInfo{ if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
TwoFA: &TwoFactorInfo{TOTP: 1}, panic(err)
}
})
version: 4, mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
salt: "yKlc5/CvObfoiw==", w.WriteHeader(http.StatusOK)
modulus: "-----BEGIN PGP SIGNED MESSAGE-----\nHash: SHA256\n\nW2z5HBi8RvsfYzZTS7qBaUxxPhsfHJFZpu3Kd6s1JafNrCCH9rfvPLrfuqocxWPgWDH2R8neK7PkNvjxto9TStuY5z7jAzWRvFWN9cQhAKkdWgy0JY6ywVn22+HFpF4cYesHrqFIKUPDMSSIlWjBVmEJZ/MusD44ZT29xcPrOqeZvwtCffKtGAIjLYPZIEbZKnDM1Dm3q2K/xS5h+xdhjnndhsrkwm9U9oyA2wxzSXFL+pdfj2fOdRwuR5nW0J2NFrq3kJjkRmpO/Genq1UW+TEknIWAb6VzJJJA244K/H8cnSx2+nSNZO3bbo6Ys228ruV9A8m6DhxmS+bihN3ttQ==\n-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwl4EARYIABAFAlwB1j0JEDUFhcTpUY8mAAD8CgEAnsFnF4cF0uSHKkXa1GIa\nGO86yMV4zDZEZcDSJo0fgr8A/AlupGN9EdHlsrZLmTA1vhIx+rOgxdEff28N\nkvNM7qIK\n=q6vu\n-----END PGP SIGNATURE-----\n", })
srpSession: "9b2946bbd9055f17c34940abdce0c3d3",
serverEphemeral: "5tfigcLKoM0DPWYB+EqYE7QlqsiT63iOVlO5ZX0lTMEILSsrRdVCYrN8L3zkinsAjUZ/cx5wIS7N05k66uZb+ZE3lFOJS2s1BkqLvCrGxYL0e3n5YAnzHYlvCCJKXw/sK57ntfF1OOoblBXX6dw5LjeeDglEep2/DaE0TjD8WUpq4Ls2HlQGn9wrC7dFO2lJXsMhRffxKghiOsdvCLXDmwXginzn/LFezA8KrDsWOBSEGntwpg3s1xFj5h8BqtRHvC0igmoscqgw+3GCMTJ0NZAQ/L+5aJ/0ccL0WBK208ltCNl+/X6Sz0kpyvOP4RqFJhC1auVDJ9AjZQYSYZ1NEQ==", ts := httptest.NewServer(mux)
var gotAuth *pmapi.Auth
// Create a new client.
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
// Register an auth handler.
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
// Make a request with an access token that already expired one second ago.
if _, err := c.GetAddresses(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
} }
// testAuth has default values which are adjusted in each test. // The auth callback should have been called.
var testAuth = &Auth{ if *gotAuth != *wantAuth {
EventID: "NcKPtU5eMNPMrDkIMbEJrgMtC9yQ7Xc5ZBT-tB3UtV1rZ324RWfCIdBI758q0UnsfywS8CkNenIQlWLIX_dUng==", t.Fatal("got unexpected auth", gotAuth)
ExpiresIn: 86400, }
RefreshToken: "feb3159ac63fb05119bcf4480d939278aa746926",
accessToken: testAccessToken,
uid: testUID,
} }
var testAuthRefreshReq = AuthRefreshReq{ func Test401AuthRefresh(t *testing.T) {
ResponseType: "token", var wantAuth = &pmapi.Auth{
GrantType: "refresh_token", UID: "testUID",
RefreshToken: testRefreshToken, AccessToken: "testAcc",
UID: testUID, RefreshToken: "testRef",
RedirectURI: "https://protonmail.ch",
State: "random_string",
} }
var testAuthReq = AuthReq{ mux := http.NewServeMux()
Username: testUsername,
ClientProof: "axfvYdl9iXZjY6zQ+hBYmY7X3TDc/9JtSvrmyZXhDxjxkXB3Hro27t1KItmFIJloItY5sLZDs0eEEZJI34oFZD4ViSG0kfB7ZXcCZ9Jse+U5OFu4vdnPTGolnSofRMEs1NR6ePXzH7mQ10qoq43ity3ve2vmhQNuJNlHAPynKf2WqKOgxq7mmkBzEpXES4mIhwwgVbOygKcUSvguz5E5g13ATF0ZX2d9SJWAbZ262Tks+h99Cdk/dOfgLQhr0nO/r0cpwP84W2RWU2Q34LNkKuuQHkjmxelgBleGq54tCbhoCAYPP6vapgrQjNoVAC/dkjIIAoNL9bJSIynFM5znAA==", mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
ClientEphemeral: "mK+eSMosfZO/Cs5s+vcbjpsN7F8UAObwlKKnCy/z9FpoMRM2PfTe5ywLBgffmLYaapPq7XOxaqaj08kcZLHcM1fIA2JQZZTKPnESN1qAQztJ3/YHMI0op6yBgzx9803OjIznjCD2B3XBSMOHIG4oG0UwocsIX32hiMnYlMMkt8NGrityPlnmEbxpRna3fu9LEZ+v0uo6PjKCrO7+9E3uaMi64HadXBfyx2raBFFwA+yh7FvE7U+hl3AJclEre4d8pmfhMdxXze1soJI8fMuqaa07rY0r0rF5mLLTuqTIGRFkU1qG9loq9+IMsSwgkt1P3ghW63JK7Y6LWdDy0d6cAg==", w.Header().Set("Content-Type", "application/json")
SRPSession: "9b2946bbd9055f17c34940abdce0c3d3",
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
panic(err)
}
})
var call int
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
call++
if call == 1 {
w.WriteHeader(http.StatusUnauthorized)
} else {
w.WriteHeader(http.StatusOK)
}
})
ts := httptest.NewServer(mux)
var gotAuth *pmapi.Auth
// Create a new client.
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
// Register an auth handler.
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
// The first request will fail with 401, triggering a refresh and retry.
if _, err := c.GetAddresses(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
} }
var testAuth2FAReq = Auth2FAReq{ // The auth callback should have been called.
TwoFactorCode: "424242", if *gotAuth != *wantAuth {
t.Fatal("got unexpected auth", gotAuth)
}
} }
func init() { func Test401RevokedAuth(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel) mux := http.NewServeMux()
srp.RandReader = rand.New(rand.NewSource(42))
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
})
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
})
ts := httptest.NewServer(mux)
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh.
// The retry will also fail with 401, returning an error.
_, err := c.GetAddresses(context.Background())
if err == nil {
t.Fatal("expected error, instead got", err)
} }
func TestClient_AuthInfo(t *testing.T) { if !errors.Is(err, pmapi.ErrUnauthorized) {
finish, c := newTestServerCallbacks(t, t.Fatal("expected error to be ErrUnauthorized, instead got", err)
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(t, checkMethodAndPath(r, "POST", "/auth/info"))
var infoReq AuthInfoReq
Ok(t, json.NewDecoder(r.Body).Decode(&infoReq))
Equals(t, infoReq.Username, testUsername)
return "/auth/info/post_response.json"
},
)
defer finish()
info, err := c.AuthInfo(testCurrentUser.Name)
Ok(t, err)
Equals(t, testAuthInfo, info)
} }
// TestClient_Auth reflects changes from proton/backend-communcation#3.
func TestClient_Auth(t *testing.T) {
srp.RandReader = rand.New(rand.NewSource(42))
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
a.Nil(t, checkMethodAndPath(req, "POST", "/auth"))
var authReq AuthReq
r.Nil(t, json.NewDecoder(req.Body).Decode(&authReq))
r.Equal(t, testAuthReq, authReq)
return "/auth/post_response.json"
},
)
defer finish()
auth, err := c.Auth(testUsername, testAPIPassword, testAuthInfo)
r.Nil(t, err)
exp := &Auth{}
*exp = *testAuth
exp.accessToken = testAccessToken
exp.RefreshToken = testRefreshToken
a.Equal(t, exp, auth)
}
func TestClient_Auth2FA(t *testing.T) {
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(t, checkMethodAndPath(r, "POST", "/auth/2fa"))
var info2FAReq Auth2FAReq
Ok(t, json.NewDecoder(r.Body).Decode(&info2FAReq))
Equals(t, info2FAReq.TwoFactorCode, testAuth2FAReq.TwoFactorCode)
return "/auth/2fa/post_response.json"
},
)
defer finish()
c.uid = testUID
c.accessToken = testAccessToken
err := c.Auth2FA(testAuth2FAReq.TwoFactorCode, testAuth)
Ok(t, err)
}
func TestClient_Auth2FA_Fail(t *testing.T) {
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(t, checkMethodAndPath(r, "POST", "/auth/2fa"))
var info2FAReq Auth2FAReq
Ok(t, json.NewDecoder(r.Body).Decode(&info2FAReq))
Equals(t, info2FAReq.TwoFactorCode, testAuth2FAReq.TwoFactorCode)
return "/auth/2fa/post_401_bad_password.json"
},
)
defer finish()
c.uid = testUID
c.accessToken = testAccessToken
err := c.Auth2FA(testAuth2FAReq.TwoFactorCode, testAuth)
Equals(t, ErrBad2FACode, err)
}
func TestClient_Auth2FA_Retry(t *testing.T) {
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(t, checkMethodAndPath(r, "POST", "/auth/2fa"))
var info2FAReq Auth2FAReq
Ok(t, json.NewDecoder(r.Body).Decode(&info2FAReq))
Equals(t, info2FAReq.TwoFactorCode, testAuth2FAReq.TwoFactorCode)
return "/auth/2fa/post_422_bad_password.json"
},
)
defer finish()
c.uid = testUID
c.accessToken = testAccessToken
err := c.Auth2FA(testAuth2FAReq.TwoFactorCode, testAuth)
Equals(t, ErrBad2FACodeTryAgain, err)
}
func TestClient_Unlock(t *testing.T) {
finish, c := newTestServerCallbacks(t,
routeGetUsers,
routeGetAddresses,
)
defer finish()
c.uid = testUID
c.accessToken = testAccessToken
err := c.Unlock([]byte("wrong"))
a.Error(t, err, "expected error, pasword is wrong")
err = c.Unlock([]byte(testMailboxPassword))
a.Nil(t, err)
a.Equal(t, testUID, c.uid)
a.Equal(t, testAccessToken, c.accessToken)
// second try should not fail because there is an unlocked key already
err = c.Unlock([]byte("wrong"))
a.Nil(t, err)
}
func TestClient_Unlock_EncPrivKey(t *testing.T) {
finish, c := newTestServerCallbacks(t,
routeGetUsers,
routeGetAddresses,
)
defer finish()
c.uid = testUID
c.accessToken = testAccessToken
err := c.Unlock([]byte(testMailboxPassword))
Ok(t, err)
Equals(t, testUID, c.uid)
Equals(t, testAccessToken, c.accessToken)
}
func routeAuthRefresh(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(tb, checkMethodAndPath(r, "POST", "/auth/refresh"))
Ok(tb, checkHeader(r.Header, "x-pm-uid", testUID))
var refreshReq AuthRefreshReq
Ok(tb, json.NewDecoder(r.Body).Decode(&refreshReq))
Equals(tb, testAuthRefreshReq, refreshReq)
return "/auth/refresh/post_response.json"
}
// TestClient_AuthRefresh reflects changes from proton/backend-communcation#11.
func TestClient_AuthRefresh(t *testing.T) {
finish, c := newTestServerCallbacks(t,
routeAuthRefresh,
)
defer finish()
c.uid = "" // Testing that we always send correct `x-pm-uid`.
c.accessToken = "oldToken"
auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken)
Ok(t, err)
Equals(t, testUID, c.uid)
exp := &Auth{}
*exp = *testAuth
exp.uid = testUID // AuthRefresh will not return UID (only Auth returns the UID) we should set testUID to be able to generate token, see `GetToken`
exp.accessToken = testAccessToken
exp.EventID = ""
exp.ExpiresIn = 360000
exp.RefreshToken = testRefreshTokenNew
Equals(t, exp, auth)
}
func routeAuthRefreshHasUID(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(tb, checkMethodAndPath(r, "POST", "/auth/refresh"))
Ok(tb, checkHeader(r.Header, "x-pm-uid", testUID))
var refreshReq AuthRefreshReq
Ok(tb, json.NewDecoder(r.Body).Decode(&refreshReq))
Equals(tb, testAuthRefreshReq, refreshReq)
return "/auth/refresh/post_resp_has_uid.json"
}
// TestClient_AuthRefresh reflects changes from proton/backend-communcation#3.
func TestClient_AuthRefresh_HasUID(t *testing.T) {
finish, c := newTestServerCallbacks(t,
routeAuthRefreshHasUID,
)
defer finish()
c.uid = testUID
c.accessToken = "oldToken"
auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken)
Ok(t, err)
exp := &Auth{}
*exp = *testAuth
exp.accessToken = testAccessToken
exp.EventID = ""
exp.ExpiresIn = 360000
exp.RefreshToken = testRefreshTokenNew
Equals(t, exp, auth)
}
func TestClient_Logout(t *testing.T) {
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(t, checkMethodAndPath(r, "DELETE", "/auth"))
Ok(t, isAuthReq(r, testUID, testAccessToken))
return "auth/delete_response.json"
},
)
defer finish()
c.uid = testUID
c.accessToken = testAccessToken
c.Logout()
r.Eventually(t, func() bool {
return c.IsConnected() == false && c.userKeyRing == nil && c.addresses == nil && c.user == nil
}, 10*time.Second, 10*time.Millisecond)
}
func TestClient_DoUnauthorized(t *testing.T) {
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(t, checkMethodAndPath(r, "GET", "/"))
return httpResponse(http.StatusUnauthorized)
},
routeAuthRefresh,
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
Ok(t, checkMethodAndPath(r, "GET", "/"))
Ok(t, isAuthReq(r, testUID, testAccessToken))
return httpResponse(http.StatusOK)
},
)
defer finish()
c.uid = testUID
c.accessToken = testAccessTokenOld
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
req, err := c.NewRequest("GET", "/", nil)
Ok(t, err)
res, err := c.Do(req, true)
Ok(t, err)
defer Ok(t, res.Body.Close())
} }

72
pkg/pmapi/auth_types.go Normal file
View File

@ -0,0 +1,72 @@
package pmapi
type AuthModulus struct {
Modulus string
ModulusID string
}
type GetAuthInfoReq struct {
Username string
}
type AuthInfo struct {
Version int
Modulus string
ServerEphemeral string
Salt string
SRPSession string
}
type TwoFAInfo struct {
Enabled TwoFAStatus
}
type TwoFAStatus int
const (
TwoFADisabled TwoFAStatus = iota
TOTPEnabled
// TODO: Support UTF
)
type PasswordMode int
const (
OnePasswordMode PasswordMode = iota + 1
TwoPasswordMode
)
type AuthReq struct {
Username string
ClientProof string
ClientEphemeral string
SRPSession string
}
type Auth struct {
UserID string
UID string
AccessToken string
RefreshToken string
ExpiresIn int64
Scope string
ServerProof string
TwoFA TwoFAInfo `json:"2FA"`
PasswordMode PasswordMode
}
type Auth2FAReq struct {
TwoFactorCode string
}
type AuthRefreshReq struct {
UID string
RefreshToken string
ResponseType string
GrantType string
RedirectURI string
State string
}

View File

@ -1,94 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"errors"
"fmt"
"net/http"
"time"
)
const protonStatusURL = "http://protonstatus.com/vpn_status"
// ErrNoInternetConnection indicates that both protonstatus and the API are unreachable.
var ErrNoInternetConnection = errors.New("no internet connection")
// CheckConnection returns an error if there is no internet connection.
// This should be moved to the ConnectionManager when it is implemented.
func (cm *ClientManager) CheckConnection() error {
// We use a normal dialer here which doesn't check tls fingerprints.
client := &http.Client{Timeout: time.Second * 10}
// Do not cumulate timeouts, use goroutines.
retStatus := make(chan error)
retAPI := make(chan error)
// vpn_status endpoint is fast and returns only OK. We check the connection only.
go checkConnection(client, protonStatusURL, retStatus)
// Check of API reachability also uses a fast endpoint.
go checkConnection(client, cm.GetRootURL()+"/tests/ping", retAPI)
errStatus := <-retStatus
errAPI := <-retAPI
switch {
case errStatus == nil && errAPI == nil:
return nil
case errStatus == nil && errAPI != nil:
cm.log.Error("ProtonStatus is reachable but API is not")
return ErrAPINotReachable
case errStatus != nil && errAPI == nil:
cm.log.Warn("API is reachable but protonstatus is not")
return nil
case errStatus != nil && errAPI != nil:
cm.log.Error("Both ProtonStatus and API are unreachable")
return ErrNoInternetConnection
}
return nil
}
// CheckConnection returns an error if there is no internet connection.
func CheckConnection() error {
client := &http.Client{Timeout: time.Second * 10}
retStatus := make(chan error)
go checkConnection(client, protonStatusURL, retStatus)
return <-retStatus
}
func checkConnection(client *http.Client, url string, errorChannel chan error) {
resp, err := client.Get(url)
if err != nil {
errorChannel <- err
return
}
_ = resp.Body.Close()
if resp.StatusCode != 200 {
errorChannel <- fmt.Errorf("HTTP status code %d", resp.StatusCode)
return
}
errorChannel <- nil
}

View File

@ -1,91 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"net/http"
"os"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/pkg/dialer"
"github.com/stretchr/testify/require"
)
const testServerPort = "18000"
const testRequestTimeout = 10 * time.Second
func TestMain(m *testing.M) {
go startServer()
time.Sleep(100 * time.Millisecond) // We need to wait till server is fully running.
code := m.Run()
os.Exit(code)
}
func startServer() {
http.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})
http.HandleFunc("/timeout", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(testRequestTimeout + time.Second) // Add extra second to be sure it will timeout.
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})
http.HandleFunc("/serverError", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", http.StatusInternalServerError)
})
panic(http.ListenAndServe(":"+testServerPort, nil))
}
func TestCheckConnection(t *testing.T) {
checkCheckConnection(t, "ok", "")
}
func TestCheckConnectionTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
checkCheckConnection(t, "timeout", "Client.Timeout exceeded while awaiting headers")
}
func TestCheckConnectionServerError(t *testing.T) {
checkCheckConnection(t, "serverError", "HTTP status code 500")
}
func checkCheckConnection(t *testing.T, path string, expectedErrMessage string) {
client := dialer.DialTimeoutClient()
client.Timeout = testRequestTimeout
ch := make(chan error)
go checkConnection(client, "http://localhost:"+testServerPort+"/"+path, ch)
timeout := time.After(testRequestTimeout + time.Second)
select {
case err := <-ch:
if expectedErrMessage == "" {
require.NoError(t, err)
} else {
require.Error(t, err, expectedErrMessage)
}
case <-timeout:
t.Error("checkConnection timeout failed")
}
}

View File

@ -18,97 +18,23 @@
package pmapi package pmapi
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http" "net/http"
"reflect"
"strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/jaytaylor/html2text" "github.com/go-resty/resty/v2"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
) )
// Version of the API.
const Version = 3
// API return codes.
const (
ForceUpgradeBadAppVersion = 5003
APIOffline = 7001
ImportMessageTooLong = 36022
BansRequests = 85131
)
// The output errors.
var (
ErrInvalidToken = errors.New("refresh token invalid")
ErrAPINotReachable = errors.New("cannot reach the server")
ErrUpgradeApplication = errors.New("application upgrade required")
ErrConnectionSlow = errors.New("request canceled because connection speed was too slow")
)
type ErrUnprocessableEntity struct {
error
}
func (err *ErrUnprocessableEntity) Error() string {
return err.error.Error()
}
type ErrUnauthorized struct {
error
}
func (err *ErrUnauthorized) Error() string {
return fmt.Sprintf("unauthorized access: %+v", err.error.Error())
}
// ClientConfig contains Client configuration.
type ClientConfig struct {
// The client application name and version.
AppVersion string
// The client ID.
ClientID string
// Timeout is the timeout of the full request. It is passed to http.Client.
// If it is left unset, it means no timeout is applied.
Timeout time.Duration
// FirstReadTimeout specifies the timeout from getting response to the first read of body response.
// This timeout is applied only when MinBytesPerSecond is used.
// Default is 5 minutes.
FirstReadTimeout time.Duration
// MinBytesPerSecond specifies minimum Bytes per second or the request will be canceled.
// Zero means no limitation.
MinBytesPerSecond int64
ConnectionOnHandler func()
ConnectionOffHandler func()
UpgradeApplicationHandler func()
}
// client is a client of the protonmail API. It implements the Client interface. // client is a client of the protonmail API. It implements the Client interface.
type client struct { type client struct {
cm *ClientManager req requester
hc *http.Client
uid string uid, acc, ref string
accessToken string authHandlers []AuthHandler
userID string authLocker sync.RWMutex
requestLocker sync.Locker
refreshLocker sync.Locker
user *User user *User
addresses AddressList addresses AddressList
@ -116,404 +42,79 @@ type client struct {
addrKeyRing map[string]*crypto.KeyRing addrKeyRing map[string]*crypto.KeyRing
keyRingLock sync.Locker keyRingLock sync.Locker
log *logrus.Entry exp time.Time
} }
// newClient creates a new API client. func newClient(req requester, uid string) *client {
func newClient(cm *ClientManager, userID string) *client {
return &client{ return &client{
cm: cm, req: req,
hc: getHTTPClient(cm.config, cm.roundTripper, cm.cookieJar), uid: uid,
userID: userID,
requestLocker: &sync.Mutex{},
refreshLocker: &sync.Mutex{},
keyRingLock: &sync.Mutex{},
addrKeyRing: make(map[string]*crypto.KeyRing), addrKeyRing: make(map[string]*crypto.KeyRing),
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID), keyRingLock: &sync.RWMutex{},
} }
} }
// getHTTPClient returns a http client configured by the given client config and using the given transport. func (c *client) withAuth(acc, ref string, exp time.Time) *client {
func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper, jar http.CookieJar) (hc *http.Client) { c.acc = acc
return &http.Client{ c.ref = ref
Transport: rt, c.exp = exp
Jar: jar,
Timeout: cfg.Timeout, return c
}
func (c *client) r(ctx context.Context) (*resty.Request, error) {
r := c.req.r(ctx)
if c.uid != "" {
r.SetHeader("x-pm-uid", c.uid)
}
if time.Now().After(c.exp) {
if err := c.authRefresh(ctx); err != nil {
return nil, err
} }
} }
func (c *client) IsUnlocked() bool { c.authLocker.RLock()
return c.userKeyRing != nil defer c.authLocker.RUnlock()
if c.acc != "" {
r.SetAuthToken(c.acc)
} }
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings. return r, nil
// If the keyrings are already present, they are not recreated.
func (c *client) Unlock(passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
return c.unlock(passphrase)
} }
// unlock unlocks the user's keys but without locking the keyring lock first. func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) {
// Should only be used internally by methods that first lock the lock. r, err := c.r(ctx)
func (c *client) unlock(passphrase []byte) (err error) {
if _, err = c.CurrentUser(); err != nil {
return
}
if c.userKeyRing == nil {
if err = c.unlockUser(passphrase); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
}
for _, address := range c.addresses {
if c.addrKeyRing[address.ID] == nil {
if err = c.unlockAddress(passphrase, address); err != nil {
return errors.Wrap(err, "failed to unlock address")
}
}
}
return
}
func (c *client) ReloadKeys(passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
c.clearKeys()
return c.unlock(passphrase)
}
func (c *client) clearKeys() {
if c.userKeyRing != nil {
c.userKeyRing.ClearPrivateParams()
c.userKeyRing = nil
}
for id, kr := range c.addrKeyRing {
if kr != nil {
kr.ClearPrivateParams()
}
delete(c.addrKeyRing, id)
}
}
func (c *client) CloseConnections() {
c.hc.CloseIdleConnections()
}
// Do makes an API request. It does not check for HTTP status code errors.
func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) {
// Copy the request body in case we need to retry it.
var bodyBuffer []byte
if req.Body != nil {
defer req.Body.Close() //nolint[errcheck]
bodyBuffer, err = ioutil.ReadAll(req.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := bytes.NewReader(bodyBuffer) res, err := wrapRestyError(fn(r))
req.Body = ioutil.NopCloser(r) if err != nil {
if res.StatusCode() != http.StatusUnauthorized {
return nil, err
} }
return c.doBuffered(req, bodyBuffer, retryUnauthorized) if err := c.authRefresh(ctx); err != nil {
return nil, err
} }
// If needed it retries using req and buffered body. return wrapRestyError(fn(r))
func (c *client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
isAuthReq := strings.Contains(req.URL.Path, "/auth")
req.Header.Set("User-Agent", c.cm.userAgent.String())
req.Header.Set("x-pm-appversion", c.cm.config.AppVersion)
if c.uid != "" {
req.Header.Set("x-pm-uid", c.uid)
} }
if c.accessToken != "" { return res, nil
req.Header.Set("Authorization", "Bearer "+c.accessToken)
}
c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI())
if logrus.GetLevel() == logrus.TraceLevel {
head := ""
for i, v := range req.Header {
head += i + ": "
head += strings.Join(v, "")
head += "\n"
}
c.log.Tracef("REQHEAD \n%s", head)
c.log.Tracef("REQBODY '%s'", printBytes(bodyBuffer))
}
hasBody := len(bodyBuffer) > 0
if res, err = c.hc.Do(req); err != nil {
if res == nil {
c.log.WithError(err).Error("Cannot get response")
err = ErrAPINotReachable
c.cm.noConnection()
}
return
}
// Cookies are returned only after request was sent.
c.log.Tracef("REQCOOKIES '%v'", req.Cookies())
resDate := res.Header.Get("Date")
if resDate != "" {
if serverTime, err := http.ParseTime(resDate); err == nil {
crypto.UpdateTime(serverTime.Unix())
}
}
if res.StatusCode == http.StatusUnauthorized {
if hasBody {
r := bytes.NewReader(bodyBuffer)
req.Body = ioutil.NopCloser(r)
}
if !isAuthReq {
_, _ = io.Copy(ioutil.Discard, res.Body)
_ = res.Body.Close()
return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized)
}
}
// Retry induced by HTTP status code>
retryAfter := 10
doRetry := res.StatusCode == http.StatusTooManyRequests
if doRetry {
if headerAfter, err := strconv.Atoi(res.Header.Get("Retry-After")); err == nil && headerAfter > 0 {
retryAfter = headerAfter
}
// To avoid spikes when all clients retry at the same time, we add some random wait.
retryAfter += rand.Intn(10) //nolint[gosec] It is OK to use weak random number generator here
if hasBody {
r := bytes.NewReader(bodyBuffer)
req.Body = ioutil.NopCloser(r)
}
c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, retryAfter, res.StatusCode)
time.Sleep(time.Duration(retryAfter) * time.Second)
_, _ = io.Copy(ioutil.Discard, res.Body)
_ = res.Body.Close()
return c.doBuffered(req, bodyBuffer, false)
} }
func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) {
if err, ok := err.(*resty.ResponseError); ok {
return res, err return res, err
} }
// DoJSON performs the request and unmarshals the response as JSON into data. if res.RawResponse != nil {
// If the API returns a non-2xx HTTP status code, the error returned will contain status return res, err
// and response as plaintext. API errors must be checked by the caller.
// It is performed buffered, in case we need to retry.
func (c *client) DoJSON(req *http.Request, data interface{}) error {
// Copy the request body in case we need to retry it
var reqBodyBuffer []byte
if req.Body != nil {
defer req.Body.Close() //nolint[errcheck]
var err error
if reqBodyBuffer, err = ioutil.ReadAll(req.Body); err != nil {
return err
} }
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer)) return res, errors.Wrap(ErrNoConnection, err.Error())
}
return c.doJSONBuffered(req, reqBodyBuffer, data)
}
// doJSONBuffered performs a buffered json request (see DoJSON for more information).
func (c *client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen]
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
var cancelRequest context.CancelFunc
if c.cm.config.MinBytesPerSecond > 0 {
var ctx context.Context
ctx, cancelRequest = context.WithCancel(req.Context())
defer func() {
cancelRequest()
}()
req = req.WithContext(ctx)
}
res, err := c.doBuffered(req, reqBodyBuffer, false)
if err != nil {
return err
}
defer res.Body.Close() //nolint[errcheck]
var resBody []byte
if c.cm.config.MinBytesPerSecond == 0 {
resBody, err = ioutil.ReadAll(res.Body)
} else {
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
if err == context.Canceled {
err = ErrConnectionSlow
}
}
// The server response may contain data which we want to have in memory
// for as little time as possible (such as keys). Go is garbage collected,
// so we are not in charge of when the memory will actually be cleared.
// We can at least try to rewrite the original data to mitigate this problem.
defer func() {
for i := 0; i < len(resBody); i++ {
resBody[i] = byte(65)
}
}()
if logrus.GetLevel() == logrus.TraceLevel {
head := ""
for i, v := range res.Header {
head += i + ": "
head += strings.Join(v, "")
head += "\n"
}
c.log.Tracef("RESHEAD \n%s", head)
c.log.Tracef("RESBODY '%s'", resBody)
}
if err != nil {
return err
}
// Retry induced by API code.
errCode := &Res{}
if err := json.Unmarshal(resBody, errCode); err == nil {
if errCode.Code == BansRequests {
retryAfter := 3
c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, errCode.Code)
time.Sleep(time.Duration(retryAfter) * time.Second)
if len(reqBodyBuffer) > 0 {
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
}
return c.doJSONBuffered(req, reqBodyBuffer, data)
}
if errCode.Err() == ErrAPINotReachable {
c.cm.noConnection()
}
if errCode.Err() == ErrUpgradeApplication {
c.cm.config.UpgradeApplicationHandler()
}
}
if err := json.Unmarshal(resBody, data); err != nil {
// Check to see if this is due to a non 2xx HTTP status code.
if res.StatusCode != http.StatusOK {
r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n")))
plaintext, err := html2text.FromReader(r)
if err == nil {
return fmt.Errorf("Error: \n\n" + res.Status + "\n\n" + plaintext)
}
}
if errJS, ok := err.(*json.SyntaxError); ok {
return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset)
}
return fmt.Errorf("unmarshal fail: %v ", err)
}
// Set StatusCode in case data struct supports that field.
// It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code.
dataValue := reflect.ValueOf(data).Elem()
statusCodeField := dataValue.FieldByName("StatusCode")
if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int {
statusCodeField.SetInt(int64(res.StatusCode))
}
if res.StatusCode != http.StatusOK {
c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status)
}
return nil
}
func (c *client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
firstReadTimeout := c.cm.config.FirstReadTimeout
if firstReadTimeout == 0 {
firstReadTimeout = 5 * time.Minute
}
timer := time.AfterFunc(firstReadTimeout, func() {
cancelRequest()
})
// speedCheckSeconds controls how often we check the transfer speed.
// Note that connection can be unstable, on average very fast, but can be
// idle for few seconds; or that API can take its time before sending
// another data, e.g., API can send some data and take some time before
// processing and sending the rest of the response.
const speedCheckSeconds = 30
var buffer bytes.Buffer
for {
_, err := io.CopyN(&buffer, data, c.cm.config.MinBytesPerSecond*speedCheckSeconds)
timer.Stop()
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
timer.Reset(speedCheckSeconds * time.Second)
}
return ioutil.ReadAll(&buffer)
}
func (c *client) refreshAccessToken() (err error) {
c.log.Debug("Refreshing token")
refreshToken := c.cm.GetToken(c.userID)
if refreshToken == "" {
c.sendAuth(nil)
return ErrInvalidToken
}
if _, err := c.AuthRefresh(refreshToken); err != nil {
if err != ErrAPINotReachable {
c.sendAuth(nil)
}
return errors.Wrap(err, "failed to refresh auth")
}
return
}
func (c *client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
c.log.Info("Handling unauthorized status")
// If this is not a retry, then it is the first time handling status unauthorized,
// so try again without refreshing the access token.
if !retry {
c.log.Debug("Handling unauthorized status by retrying")
c.requestLocker.Lock()
defer c.requestLocker.Unlock()
_, _ = io.Copy(ioutil.Discard, res.Body)
_ = res.Body.Close()
return c.doBuffered(req, reqBodyBuffer, true)
}
// This is already a retry, so we will try to refresh the access token before trying again.
if err = c.refreshAccessToken(); err != nil {
c.log.WithError(err).Warn("Cannot refresh token")
err = &ErrUnauthorized{err}
return
}
_, err = io.Copy(ioutil.Discard, res.Body)
if err != nil {
c.log.WithError(err).Warn("Failed to read out response body")
}
_ = res.Body.Close()
return c.doBuffered(req, reqBodyBuffer, true)
} }

70
pkg/pmapi/client_keys.go Normal file
View File

@ -0,0 +1,70 @@
package pmapi
import (
"context"
"github.com/pkg/errors"
)
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
// If the keyrings are already present, they are not recreated.
func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
// FIXME(conman): Should this be done as part of NewClient somehow?
return c.unlock(ctx, passphrase)
}
// unlock unlocks the user's keys but without locking the keyring lock first.
// Should only be used internally by methods that first lock the lock.
func (c *client) unlock(ctx context.Context, passphrase []byte) (err error) {
if _, err = c.CurrentUser(ctx); err != nil {
return
}
if c.userKeyRing == nil {
if err = c.unlockUser(passphrase); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
}
for _, address := range c.addresses {
if c.addrKeyRing[address.ID] == nil {
if err = c.unlockAddress(passphrase, address); err != nil {
return errors.Wrap(err, "failed to unlock address")
}
}
}
return
}
func (c *client) ReloadKeys(ctx context.Context, passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
c.clearKeys()
return c.unlock(ctx, passphrase)
}
func (c *client) clearKeys() {
if c.userKeyRing != nil {
c.userKeyRing.ClearPrivateParams()
c.userKeyRing = nil
}
for id, kr := range c.addrKeyRing {
if kr != nil {
kr.ClearPrivateParams()
}
delete(c.addrKeyRing, id)
}
}
func (c *client) IsUnlocked() bool {
// FIXME(conman): Better way to check? we don't currently check address keys.
return c.userKeyRing != nil
}

View File

@ -1,210 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
var testClientConfig = &ClientConfig{
AppVersion: "GoPMAPI_1.0.14",
ClientID: "demoapp",
FirstReadTimeout: 500 * time.Millisecond,
MinBytesPerSecond: 256,
}
func newTestClient(cm *ClientManager) *client {
return cm.GetClient("tester").(*client)
}
func TestClient_Do(t *testing.T) {
const testResBody = "Hello World!"
var receivedReq *http.Request
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedReq = r
fmt.Fprint(w, testResBody)
}))
defer s.Close()
req, err := c.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal("Expected no error while creating request, got:", err)
}
res, err := c.Do(req, true)
if err != nil {
t.Fatal("Expected no error while executing request, got:", err)
}
b, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal("Expected no error while reading response, got:", err)
}
require.Nil(t, res.Body.Close())
if string(b) != testResBody {
t.Fatalf("Invalid response body: expected %v, got %v", testResBody, string(b))
}
h := receivedReq.Header
if h.Get("x-pm-appversion") != testClientConfig.AppVersion {
t.Fatalf("Invalid app version header: expected %v, got %v", testClientConfig.AppVersion, h.Get("x-pm-appversion"))
}
if h.Get("x-pm-uid") != "" {
t.Fatalf("Expected no uid header when not authenticated, got %v", h.Get("x-pm-uid"))
}
if h.Get("Authorization") != "" {
t.Fatalf("Expected no authentication header when not authenticated, got %v", h.Get("Authorization"))
}
}
func TestClient_DoRetryAfter(t *testing.T) {
testStart := time.Now()
secondAttemptTime := time.Now()
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
w.Header().Set("content-type", "application/json;charset=utf-8")
w.Header().Set("Retry-After", "1")
w.WriteHeader(http.StatusTooManyRequests)
return ""
},
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
w.Header().Set("content-type", "application/json;charset=utf-8")
w.WriteHeader(http.StatusOK)
secondAttemptTime = time.Now()
return "/HTTP_200.json"
},
)
defer finish()
require.Nil(t, c.SendSimpleMetric("some_category", "some_action", "some_label"))
waitedTime := secondAttemptTime.Sub(testStart)
isInRange := 1*time.Second < waitedTime && waitedTime <= 11*time.Second
require.True(t, isInRange, "Waited time: %v", waitedTime)
}
type slowTransport struct {
transport http.RoundTripper
firstBodySleep time.Duration
}
func (t *slowTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := t.transport.RoundTrip(req)
if err == nil {
resp.Body = &slowReadCloser{
req: req,
readCloser: resp.Body,
firstBodySleep: t.firstBodySleep,
}
}
return resp, err
}
type slowReadCloser struct {
req *http.Request
readCloser io.ReadCloser
firstBodySleep time.Duration
}
func (r *slowReadCloser) Read(p []byte) (n int, err error) {
// Normally timeout is processed by Read function.
// It's hard to test slow connection; we need to manually
// check when context is Done, because otherwise timeout
// happens only during failed Read which will not happen
// in this artificial environment.
select {
case <-r.req.Context().Done():
return 0, context.Canceled
case <-time.After(r.firstBodySleep):
}
return r.readCloser.Read(p)
}
func (r *slowReadCloser) Close() error {
return r.readCloser.Close()
}
func TestClient_FirstReadTimeout(t *testing.T) {
requestTimeout := testClientConfig.FirstReadTimeout + 1*time.Second
finish, c := newTestServerCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
return "/HTTP_200.json"
},
)
defer finish()
c.hc.Transport = &slowTransport{
transport: c.hc.Transport,
firstBodySleep: requestTimeout,
}
started := time.Now()
err := c.SendSimpleMetric("some_category", "some_action", "some_label")
require.Error(t, err, "cannot reach the server")
require.True(t, time.Since(started) < requestTimeout, "Actual waited time: %v", time.Since(started))
}
func TestClient_MinSpeedTimeout(t *testing.T) {
finish, c := newTestServerCallbacks(t,
routeSlow(31*time.Second), // 1 second longer than the minimum transfer speed poll time.
)
defer finish()
err := c.SendSimpleMetric("some_category", "some_action", "some_label")
require.Error(t, err, "cannot reach the server")
}
func TestClient_MinSpeedNoTimeout(t *testing.T) {
finish, c := newTestServerCallbacks(t,
routeSlow(500*time.Millisecond),
)
defer finish()
err := c.SendSimpleMetric("some_category", "some_action", "some_label")
require.Nil(t, err)
}
func routeSlow(delay time.Duration) func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
return func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
w.Header().Set("content-type", "application/json;charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("{\"code\":1000,\"key\":\""))
for chunk := 1; chunk <= 10; chunk++ {
// We need to write enough bytes which enforce flushing data
// because writer used by httptest does not implement Flusher.
for i := 1; i <= 10000; i++ {
_, _ = w.Write([]byte("a"))
}
time.Sleep(delay)
}
_, _ = w.Write([]byte("\"}"))
return ""
}
}

View File

@ -18,69 +18,66 @@
package pmapi package pmapi
import ( import (
"context"
"io" "io"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
) )
// Client defines the interface of a PMAPI client. // Client defines the interface of a PMAPI client.
type Client interface { type Client interface {
Auth(username, password string, info *AuthInfo) (*Auth, error) Auth2FA(context.Context, Auth2FAReq) error
AuthInfo(username string) (*AuthInfo, error) AuthSalt(ctx context.Context) (string, error)
AuthRefresh(token string) (*Auth, error) AuthDelete(context.Context) error
Auth2FA(twoFactorCode string, auth *Auth) error AddAuthHandler(AuthHandler)
AuthSalt() (salt string, err error)
Logout()
DeleteAuth() error
IsConnected() bool
CloseConnections()
ClearData()
CurrentUser() (*User, error) CurrentUser(ctx context.Context) (*User, error)
UpdateUser() (*User, error) UpdateUser(ctx context.Context) (*User, error)
Unlock(passphrase []byte) (err error) Unlock(ctx context.Context, passphrase []byte) (err error)
ReloadKeys(passphrase []byte) (err error) ReloadKeys(ctx context.Context, passphrase []byte) (err error)
IsUnlocked() bool IsUnlocked() bool
GetAddresses() (addresses AddressList, err error)
Addresses() AddressList Addresses() AddressList
ReorderAddresses(addressIDs []string) error GetAddresses(context.Context) (addresses AddressList, err error)
ReorderAddresses(ctx context.Context, addressIDs []string) error
GetEvent(eventID string) (*Event, error) GetEvent(ctx context.Context, eventID string) (*Event, error)
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error) SendMessage(context.Context, string, *SendMessageReq) (sent, parent *Message, err error)
CreateDraft(m *Message, parent string, action int) (created *Message, err error) CreateDraft(ctx context.Context, m *Message, parent string, action int) (created *Message, err error)
Import([]*ImportMsgReq) ([]*ImportMsgRes, error) Import(context.Context, ImportMsgReqs) ([]*ImportMsgRes, error)
CountMessages(addressID string) ([]*MessagesCount, error) CountMessages(ctx context.Context, addressID string) ([]*MessagesCount, error)
ListMessages(filter *MessagesFilter) ([]*Message, int, error) ListMessages(ctx context.Context, filter *MessagesFilter) ([]*Message, int, error)
GetMessage(apiID string) (*Message, error) GetMessage(ctx context.Context, apiID string) (*Message, error)
DeleteMessages(apiIDs []string) error DeleteMessages(ctx context.Context, apiIDs []string) error
LabelMessages(apiIDs []string, labelID string) error LabelMessages(ctx context.Context, apiIDs []string, labelID string) error
UnlabelMessages(apiIDs []string, labelID string) error UnlabelMessages(ctx context.Context, apiIDs []string, labelID string) error
MarkMessagesRead(apiIDs []string) error MarkMessagesRead(ctx context.Context, apiIDs []string) error
MarkMessagesUnread(apiIDs []string) error MarkMessagesUnread(ctx context.Context, apiIDs []string) error
ListLabels() ([]*Label, error) ListLabels(ctx context.Context) ([]*Label, error)
CreateLabel(label *Label) (*Label, error) CreateLabel(ctx context.Context, label *Label) (*Label, error)
UpdateLabel(label *Label) (*Label, error) UpdateLabel(ctx context.Context, label *Label) (*Label, error)
DeleteLabel(labelID string) error DeleteLabel(ctx context.Context, labelID string) error
EmptyFolder(labelID string, addressID string) error EmptyFolder(ctx context.Context, labelID string, addressID string) error
Report(report ReportReq) error GetMailSettings(ctx context.Context) (MailSettings, error)
SendSimpleMetric(category, action, label string) error GetContactEmailByEmail(context.Context, string, int, int) ([]ContactEmail, error)
GetContactByID(context.Context, string) (Contact, error)
GetMailSettings() (MailSettings, error)
GetContactEmailByEmail(string, int, int) ([]ContactEmail, error)
GetContactByID(string) (Contact, error)
DecryptAndVerifyCards([]Card) ([]Card, error) DecryptAndVerifyCards([]Card) ([]Card, error)
GetAttachment(id string) (att io.ReadCloser, err error) GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error)
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
DeleteAttachment(attID string) (err error)
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error) KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
GetPublicKeysForEmail(string) ([]PublicKey, bool, error) GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
}
DownloadAndVerify(string, string, *crypto.KeyRing) (io.Reader, error)
type AuthHandler func(*Auth) error
type requester interface {
r(context.Context) *resty.Request
authRefresh(context.Context, string, string) (*Auth, error)
} }

View File

@ -1,459 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/internal/config/useragent"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const maxLogoutRetries = 5
// ClientManager is a manager of clients.
type ClientManager struct { //nolint[maligned]
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
// create other types of clients (e.g. for integration tests).
newClient func(userID string) Client
config *ClientConfig
userAgent *useragent.UserAgent
roundTripper http.RoundTripper
clients map[string]Client
clientsLocker sync.Locker
cookieJar http.CookieJar
tokens map[string]string
tokensLocker sync.Locker
expirations map[string]*tokenExpiration
expiredTokens chan string
expirationsLocker sync.Locker
authUpdates chan ClientAuth
host, scheme string
hostLocker sync.RWMutex
allowProxy bool
proxyProvider *proxyProvider
proxyUseDuration time.Duration
idGen idGen
connectionOff bool
log *logrus.Entry
}
type idGen int
func (i *idGen) next() int {
(*i)++
return int(*i)
}
// ClientAuth holds an API auth produced by a Client for a specific user.
type ClientAuth struct {
UserID string
Auth *Auth
}
// tokenExpiration manages the expiration of an access token.
type tokenExpiration struct {
timer *time.Timer
cancel chan (struct{})
}
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
func NewClientManager(config *ClientConfig, userAgent *useragent.UserAgent) (cm *ClientManager) {
cm = &ClientManager{
config: config,
userAgent: userAgent,
roundTripper: http.DefaultTransport,
clients: make(map[string]Client),
clientsLocker: &sync.Mutex{},
tokens: make(map[string]string),
tokensLocker: &sync.Mutex{},
expirations: make(map[string]*tokenExpiration),
expiredTokens: make(chan string),
expirationsLocker: &sync.Mutex{},
host: rootURL,
scheme: rootScheme,
hostLocker: sync.RWMutex{},
authUpdates: make(chan ClientAuth),
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
proxyUseDuration: proxyUseDuration,
connectionOff: false,
log: logrus.WithField("pkg", "pmapi-manager"),
}
cm.newClient = func(userID string) Client {
return newClient(cm, userID)
}
go cm.watchTokenExpirations()
return cm
}
func (cm *ClientManager) noConnection() {
cm.log.Trace("No connection available")
if cm.connectionOff {
return
}
cm.log.Warn("Connection lost")
if cm.config.ConnectionOffHandler != nil {
cm.config.ConnectionOffHandler()
}
cm.connectionOff = true
go func() {
for {
time.Sleep(30 * time.Second)
if err := cm.CheckConnection(); err == nil {
cm.log.Info("Connection re-established")
if cm.config.ConnectionOnHandler != nil {
cm.config.ConnectionOnHandler()
}
cm.connectionOff = false
return
}
}
}()
}
// SetClientConstructor sets the method used to construct clients.
// By default this is `pmapi.newClient` but can be overridden with this method.
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
cm.newClient = f
}
// SetCookieJar sets the cookie jar given to clients.
func (cm *ClientManager) SetCookieJar(jar http.CookieJar) {
cm.cookieJar = jar
}
// SetRoundTripper sets the roundtripper used by clients created by this client manager.
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
cm.roundTripper = rt
}
func (cm *ClientManager) GetAppVersion() string {
return cm.config.AppVersion
}
func (cm *ClientManager) GetUserAgent() string {
return cm.userAgent.String()
}
// GetClient returns a client for the given userID.
// If the client does not exist already, it is created.
func (cm *ClientManager) GetClient(userID string) Client {
cm.clientsLocker.Lock()
defer cm.clientsLocker.Unlock()
if client, ok := cm.clients[userID]; ok {
return client
}
client := cm.newClient(userID)
cm.clients[userID] = client
return client
}
// GetAnonymousClient returns an anonymous client.
func (cm *ClientManager) GetAnonymousClient() Client {
return cm.GetClient(fmt.Sprintf("anonymous-%v", cm.idGen.next()))
}
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
func (cm *ClientManager) LogoutClient(userID string) {
cm.clientsLocker.Lock()
defer cm.clientsLocker.Unlock()
client, ok := cm.clients[userID]
if !ok {
return
}
delete(cm.clients, userID)
go func() {
defer client.ClearData()
defer cm.clearToken(userID)
if strings.HasPrefix(userID, "anonymous-") {
return
}
var retries int
for client.DeleteAuth() == ErrAPINotReachable {
retries++
if retries > maxLogoutRetries {
cm.log.Error("Failed to delete client auth (retried too many times)")
break
}
cm.log.Warn("Failed to delete client auth because API was not reachable, retrying...")
}
}()
}
// GetRootURL returns the full root URL (scheme+host).
func (cm *ClientManager) GetRootURL() string {
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
}
// getHost returns the host to make requests to.
// It does not include the protocol i.e. no "https://" (use getScheme for that).
func (cm *ClientManager) getHost() string {
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return cm.host
}
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
func (cm *ClientManager) IsProxyAllowed() bool {
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return cm.allowProxy
}
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
func (cm *ClientManager) AllowProxy() {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
cm.allowProxy = true
}
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
func (cm *ClientManager) DisallowProxy() {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
cm.allowProxy = false
cm.host = rootURL
for _, client := range cm.clients {
client.CloseConnections()
}
}
// IsProxyEnabled returns whether we are currently proxying requests.
func (cm *ClientManager) IsProxyEnabled() bool {
cm.hostLocker.RLock()
defer cm.hostLocker.RUnlock()
return cm.host != rootURL
}
// switchToReachableServer switches to using a reachable server (either proxy or standard API).
func (cm *ClientManager) switchToReachableServer() (proxy string, err error) {
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
logrus.Info("Attempting to switch to a proxy")
if proxy, err = cm.proxyProvider.findReachableServer(); err != nil {
err = errors.Wrap(err, "failed to find a usable proxy")
return
}
// If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
if proxy == rootURL {
logrus.Info("The standard API is reachable again; connection drop was only intermittent")
err = ErrAPINotReachable
cm.host = proxy
return
}
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
// This means we want to disable it again in 24 hours.
if cm.host == rootURL {
go func() {
<-time.After(cm.proxyUseDuration)
cm.hostLocker.Lock()
defer cm.hostLocker.Unlock()
cm.host = rootURL
}()
}
cm.host = proxy
return proxy, err
}
// GetToken returns the token for the given userID.
func (cm *ClientManager) GetToken(userID string) string {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
return cm.tokens[userID]
}
// GetAuthUpdateChannel returns a channel on which client auths can be received.
func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth {
return cm.authUpdates
}
// setTokenIfUnset sets the token for the given userID if it wasn't already set.
// The set token does not expire.
func (cm *ClientManager) setTokenIfUnset(userID, token string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
if _, ok := cm.tokens[userID]; ok {
return
}
logrus.WithField("userID", userID).Info("Setting token because it is currently unset")
cm.tokens[userID] = token
}
// setToken sets the token for the given userID with the given expiration time.
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).Info("Updating token")
cm.tokens[userID] = token
cm.setTokenExpiration(userID, expiration)
}
// setTokenExpiration will ensure the token is refreshed if it expires.
// If the token already has an expiration time set, it is replaced.
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
cm.expirationsLocker.Lock()
defer cm.expirationsLocker.Unlock()
// Reduce the expiration by one minute so we can do the refresh with enough time to spare.
expiration -= time.Minute
if exp, ok := cm.expirations[userID]; ok {
exp.timer.Stop()
close(exp.cancel)
}
cm.expirations[userID] = &tokenExpiration{
timer: time.NewTimer(expiration),
cancel: make(chan struct{}),
}
go func(expiration *tokenExpiration) {
select {
case <-expiration.timer.C:
cm.expiredTokens <- userID
case <-expiration.cancel:
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
}
}(cm.expirations[userID])
}
func (cm *ClientManager) clearToken(userID string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).Debug("Clearing token")
delete(cm.tokens, userID)
}
// HandleAuth updates or clears client authorisation based on auths received and then forwards the auth onwards.
func (cm *ClientManager) HandleAuth(ca ClientAuth) {
cm.clientsLocker.Lock()
defer cm.clientsLocker.Unlock()
// If we aren't managing this client, there's nothing to do.
if _, ok := cm.clients[ca.UserID]; !ok {
logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client")
return
}
// If the auth is nil, we should clear the token.
if ca.Auth == nil {
cm.clearToken(ca.UserID)
go cm.LogoutClient(ca.UserID)
} else {
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
}
logrus.Debug("ClientManager is forwarding auth update...")
cm.authUpdates <- ca
logrus.Debug("Auth update was forwarded")
}
// watchTokenExpirations refreshes any tokens which are about to expire.
func (cm *ClientManager) watchTokenExpirations() {
for userID := range cm.expiredTokens {
log := cm.log.WithField("userID", userID)
log.Info("Auth token expired! Refreshing")
client, ok := cm.clients[userID]
if !ok {
log.Warn("Can't refresh expired token because there is no such client")
continue
}
token, ok := cm.tokens[userID]
if !ok {
log.Warn("Can't refresh expired token because there is no such token")
continue
}
if _, err := client.AuthRefresh(token); err != nil {
log.WithError(err).Error("Failed to refresh expired token")
}
}
}

View File

@ -1,31 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import "github.com/ProtonMail/proton-bridge/internal/config/useragent"
func newTestClientManager(cfg *ClientConfig) *ClientManager {
cm := NewClientManager(cfg, useragent.New())
go func() {
for range cm.authUpdates {
}
}()
return cm
}

View File

@ -1,55 +1,11 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi package pmapi
import ( type Config struct {
"runtime" HostURL string
"strings" AppVersion string
"time"
)
// rootURL is the API root URL.
// It must not contain the protocol! The protocol should be in rootScheme.
var rootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
// rootScheme is the scheme to use for connections to the root URL.
var rootScheme = "https" //nolint[gochecknoglobals]
func GetAPIConfig(configName, appVersion string) *ClientConfig {
return &ClientConfig{
AppVersion: getAPIOS() + strings.Title(configName) + "_" + appVersion,
ClientID: configName,
Timeout: 25 * time.Minute, // Overall request timeout (~25MB / 25 mins => ~16kB/s, should be reasonable).
FirstReadTimeout: 30 * time.Second, // 30s to match 30s response header timeout.
MinBytesPerSecond: 1 << 10, // Enforce minimum download speed of 1kB/s.
}
} }
// getAPIOS returns actual operating system. var DefaultConfig = Config{
func getAPIOS() string { HostURL: "https://api.protonmail.ch",
switch os := runtime.GOOS; os { AppVersion: "Other",
case "darwin": // nolint: goconst
return "macOS"
case "linux":
return "Linux"
case "windows":
return "Windows"
}
return "Linux"
} }

View File

@ -1,44 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build !build_qa
package pmapi
import (
"net/http"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
)
func GetRoundTripper(cm *ClientManager, listener listener.Listener) http.RoundTripper {
// We use a TLS dialer.
basicDialer := NewBasicTLSDialer()
// We wrap the TLS dialer in a layer which enforces connections to trusted servers.
pinningDialer := NewPinningTLSDialer(basicDialer)
// We want any pin mismatches to be communicated back to bridge GUI and reported.
pinningDialer.SetTLSIssueNotifier(func() { listener.Emit(events.TLSCertIssue, "") })
pinningDialer.EnableRemoteTLSIssueReporting(cm)
// We wrap the pinning dialer in a layer which adds "alternative routing" feature.
proxyDialer := NewProxyTLSDialer(pinningDialer, cm)
return CreateTransportWithDialer(proxyDialer)
}

View File

@ -1,51 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build build_qa
package pmapi
import (
"crypto/tls"
"net/http"
"os"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/listener"
)
func init() {
// This config allows to dynamically change ROOT URL.
fullRootURL := os.Getenv("PMAPI_ROOT_URL")
if strings.HasPrefix(fullRootURL, "http") {
rootURLparts := strings.SplitN(fullRootURL, "://", 2)
rootScheme = rootURLparts[0]
rootURL = rootURLparts[1]
} else if fullRootURL != "" {
rootURL = fullRootURL
rootScheme = "https"
}
}
func GetRoundTripper(_ *ClientManager, _ listener.Listener) http.RoundTripper {
transport := CreateTransportWithDialer(NewBasicTLSDialer())
// TLS certificate of testing environment might be self-signed.
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return transport
}

View File

@ -1,23 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
// ConnectionReporter provides a way to report when internet connection is lost.
type ConnectionReporter interface {
NotifyConnectionLost() error
}

View File

@ -18,9 +18,11 @@
package pmapi package pmapi
import ( import (
"context"
"errors" "errors"
"net/url"
"strconv" "strconv"
"github.com/go-resty/resty/v2"
) )
type Card struct { type Card struct {
@ -105,322 +107,40 @@ func (c *client) DecryptAndVerifyCards(cards []Card) ([]Card, error) {
return cards, nil return cards, nil
} }
// ====================== READ ===========================
type ContactsListRes struct {
Res
Contacts []*Contact
}
// GetContacts gets all contacts.
func (c *client) GetContacts(page int, pageSize int) (contacts []*Contact, err error) {
v := url.Values{}
v.Set("Page", strconv.Itoa(page))
if pageSize > 0 {
v.Set("PageSize", strconv.Itoa(pageSize))
}
req, err := c.NewRequest("GET", "/contacts?"+v.Encode(), nil)
if err != nil {
return
}
var res ContactsListRes
if err = c.DoJSON(req, &res); err != nil {
return
}
contacts, err = res.Contacts, res.Err()
return
}
// GetContactByID gets contact details specified by contact ID. // GetContactByID gets contact details specified by contact ID.
func (c *client) GetContactByID(id string) (contactDetail Contact, err error) { func (c *client) GetContactByID(ctx context.Context, contactID string) (contactDetail Contact, err error) {
req, err := c.NewRequest("GET", "/contacts/"+id, nil) var res struct {
if err != nil {
return
}
type ContactRes struct {
Res
Contact Contact Contact Contact
} }
var res ContactRes
if err = c.DoJSON(req, &res); err != nil { if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return return r.SetResult(&res).Get("/contacts/v4/" + contactID)
}); err != nil {
return Contact{}, err
} }
contactDetail, err = res.Contact, res.Err() return res.Contact, nil
return
}
// GetContactsForExport gets contacts in vCard format, signed and encrypted.
func (c *client) GetContactsForExport(page int, pageSize int) (contacts []Contact, err error) {
v := url.Values{}
v.Set("Page", strconv.Itoa(page))
if pageSize > 0 {
v.Set("PageSize", strconv.Itoa(pageSize))
}
req, err := c.NewRequest("GET", "/contacts/export?"+v.Encode(), nil)
if err != nil {
return
}
type ContactsDetailsRes struct {
Res
Contacts []Contact
}
var res ContactsDetailsRes
if err = c.DoJSON(req, &res); err != nil {
return
}
contacts, err = res.Contacts, res.Err()
return
}
type ContactsEmailsRes struct {
Res
ContactEmails []ContactEmail
Total int
}
// GetAllContactsEmails gets all emails from all contacts.
func (c *client) GetAllContactsEmails(page int, pageSize int) (contactsEmails []ContactEmail, err error) {
v := url.Values{}
v.Set("Page", strconv.Itoa(page))
if pageSize > 0 {
v.Set("PageSize", strconv.Itoa(pageSize))
}
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
if err != nil {
return
}
var res ContactsEmailsRes
if err = c.DoJSON(req, &res); err != nil {
return
}
contactsEmails, err = res.ContactEmails, res.Err()
return
} }
// GetContactEmailByEmail gets all emails from all contacts matching a specified email string. // GetContactEmailByEmail gets all emails from all contacts matching a specified email string.
func (c *client) GetContactEmailByEmail(email string, page int, pageSize int) (contactEmails []ContactEmail, err error) { func (c *client) GetContactEmailByEmail(ctx context.Context, email string, page int, pageSize int) (contactEmails []ContactEmail, err error) {
v := url.Values{} var res struct {
v.Set("Page", strconv.Itoa(page)) ContactEmails []ContactEmail
if pageSize > 0 {
v.Set("PageSize", strconv.Itoa(pageSize))
}
v.Set("Email", email)
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
if err != nil {
return
} }
var res ContactsEmailsRes if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if err = c.DoJSON(req, &res); err != nil { return r.SetQueryParams(map[string]string{
return "Email": email,
"Page": strconv.Itoa(page),
"PageSize": strconv.Itoa(pageSize),
}).SetResult(&res).Get("/contacts/v4")
}); err != nil {
return nil, err
} }
contactEmails, err = res.ContactEmails, res.Err() return res.ContactEmails, nil
return
} }
// ============================ CREATE ====================================
type CardsList struct {
Cards []Card
}
type ContactsCards struct {
Contacts []CardsList
}
type SingleContactResponse struct {
Res
Contact Contact
}
type IndexedContactResponse struct {
Index int
Response SingleContactResponse
}
type AddContactsResponse struct {
Res
Responses []IndexedContactResponse
}
type AddContactsReq struct {
ContactsCards
Overwrite int
Groups int
Labels int
}
// AddContacts adds contacts specified by cards. Performs signing and encrypting based on card type.
func (c *client) AddContacts(cards ContactsCards, overwrite int, groups int, labels int) (res *AddContactsResponse, err error) {
reqBody := AddContactsReq{
ContactsCards: cards,
Overwrite: overwrite,
Groups: groups,
Labels: labels,
}
req, err := c.NewJSONRequest("POST", "/contacts", reqBody)
if err != nil {
return
}
var addContactsRes AddContactsResponse
if err = c.DoJSON(req, &addContactsRes); err != nil {
return
}
res, err = &addContactsRes, addContactsRes.Err()
return
}
// ================================= UPDATE =======================================
type UpdateContactResponse struct {
Res
Contact Contact
}
type UpdateContactReq struct {
Cards []Card
}
// UpdateContact updates contact identified by contact ID. Modified contact is specified by cards.
func (c *client) UpdateContact(id string, cards []Card) (res *UpdateContactResponse, err error) {
reqBody := UpdateContactReq{
Cards: cards,
}
req, err := c.NewJSONRequest("PUT", "/contacts/"+id, reqBody)
if err != nil {
return
}
var updateContactRes UpdateContactResponse
if err = c.DoJSON(req, &updateContactRes); err != nil {
return
}
res, err = &updateContactRes, updateContactRes.Err()
return
}
type SingleIDResponse struct {
Res
ID string
}
type UpdateContactGroupsResponse struct {
Res
Response SingleIDResponse
}
func (c *client) AddContactGroups(groupID string, contactEmailIDs []string) (res *UpdateContactGroupsResponse, err error) {
return c.modifyContactGroups(groupID, addContactGroupsAction, contactEmailIDs)
}
func (c *client) RemoveContactGroups(groupID string, contactEmailIDs []string) (res *UpdateContactGroupsResponse, err error) {
return c.modifyContactGroups(groupID, removeContactGroupsAction, contactEmailIDs)
}
const (
removeContactGroupsAction = 0
addContactGroupsAction = 1
)
type ModifyContactGroupsReq struct {
LabelID string
Action int
ContactEmailIDs []string
}
func (c *client) modifyContactGroups(groupID string, modifyContactGroupsAction int, contactEmailIDs []string) (res *UpdateContactGroupsResponse, err error) {
reqBody := ModifyContactGroupsReq{
LabelID: groupID,
Action: modifyContactGroupsAction,
ContactEmailIDs: contactEmailIDs,
}
req, err := c.NewJSONRequest("PUT", "/contacts/group", reqBody)
if err != nil {
return
}
if err = c.DoJSON(req, &res); err != nil {
return
}
err = res.Err()
return
}
// ================================= DELETE =======================================
type DeleteReq struct {
IDs []string
}
// DeleteContacts deletes contacts specified by an array of contact IDs.
func (c *client) DeleteContacts(ids []string) (err error) {
deleteReq := DeleteReq{
IDs: ids,
}
req, err := c.NewJSONRequest("PUT", "/contacts/delete", deleteReq)
if err != nil {
return
}
type DeleteContactsRes struct {
Res
Responses []struct {
ID string
Response Res
}
}
var res DeleteContactsRes
if err = c.DoJSON(req, &res); err != nil {
return
}
if err = res.Err(); err != nil {
return
}
return
}
// DeleteAllContacts deletes all contacts.
func (c *client) DeleteAllContacts() (err error) {
req, err := c.NewRequest("DELETE", "/contacts", nil)
if err != nil {
return
}
var res Res
if err = c.DoJSON(req, &res); err != nil {
return
}
if err = res.Err(); err != nil {
return
}
return
}
// ===================== Private utility methods =======================
func isSignedCardType(cardType int) bool { func isSignedCardType(cardType int) bool {
return (cardType & CardSigned) == CardSigned return (cardType & CardSigned) == CardSigned
} }

View File

@ -18,7 +18,7 @@
package pmapi package pmapi
import ( import (
"encoding/json" "context"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
@ -34,221 +34,6 @@ var (
EncryptedSignedCard = 3 EncryptedSignedCard = 3
) )
var testAddContactsReq = AddContactsReq{
ContactsCards: ContactsCards{
Contacts: []CardsList{
{
Cards: []Card{
{
Type: 2,
Data: `BEGIN:VCARD
VERSION:4.0
FN;TYPE=fn:Bob
item1.EMAIL:bob.tester@protonmail.com
UID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece
END:VCARD
`,
Signature: ``,
},
},
},
},
},
Overwrite: 0,
Groups: 0,
Labels: 0,
}
var testAddContactsResponseBody = `{
"Code": 1001,
"Responses": [
{
"Index": 0,
"Response": {
"Code": 1000,
"Contact": {
"ID": "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
"Name": "Bob",
"UID": "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
"Size": 139,
"CreateTime": 1517319495,
"ModifyTime": 1517319495,
"ContactEmails": [
{
"ID": "VT4NoPeQPk48_vg0CVmk63n5mB6CZn9q-P_DYODhOUemhuzUkgBFGF1MktVArjX5zsVdfVlEBFObvt0_K5NwPg==",
"Name": "Bob",
"Email": "bob.tester@protonmail.com",
"Type": [],
"Defaults": 1,
"Order": 1,
"ContactID": "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
"LabelIDs": []
}
],
"LabelIDs": []
}
}
}
]
}`
var testContactCreated = &AddContactsResponse{
Res: Res{
Code: 1001,
StatusCode: 200,
},
Responses: []IndexedContactResponse{
{
Index: 0,
Response: SingleContactResponse{
Res: Res{
Code: 1000,
},
Contact: Contact{
ID: "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
Name: "Bob",
UID: "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
Size: 139,
CreateTime: 1517319495,
ModifyTime: 1517319495,
ContactEmails: []ContactEmail{
{
ID: "VT4NoPeQPk48_vg0CVmk63n5mB6CZn9q-P_DYODhOUemhuzUkgBFGF1MktVArjX5zsVdfVlEBFObvt0_K5NwPg==",
Name: "Bob",
Email: "bob.tester@protonmail.com",
Type: []string{},
Defaults: 1,
Order: 1,
ContactID: "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
LabelIDs: []string{},
},
},
LabelIDs: []string{},
},
},
},
},
}
var testContactUpdated = &UpdateContactResponse{
Res: Res{
Code: 1000,
StatusCode: 200,
},
Contact: Contact{
ID: "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
Name: "Bob",
UID: "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
Size: 303,
CreateTime: 1517416603,
ModifyTime: 1517416656,
ContactEmails: []ContactEmail{
{
ID: "14n6vuf1zbeo3zsYzgV471S6xJ9gzl7-VZ8tcOTQq6ifBlNEre0SUdUM7sXh6e2Q_4NhJZaU9c7jLdB1HCV6dA==",
Name: "Bob",
Email: "bob.changed.tester@protonmail.com",
Type: []string{},
Defaults: 1,
Order: 1,
ContactID: "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
LabelIDs: []string{},
},
},
LabelIDs: []string{},
},
}
func TestContact_AddContact(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "POST", "/contacts"))
var addContactsReq AddContactsReq
if err := json.NewDecoder(r.Body).Decode(&addContactsReq); err != nil {
t.Error("Expecting no error while reading request body, got:", err)
}
if !reflect.DeepEqual(testAddContactsReq.ContactsCards, addContactsReq.ContactsCards) {
t.Errorf("Invalid contacts request: expected %+v but got %+v", testAddContactsReq.ContactsCards, addContactsReq.ContactsCards)
}
fmt.Fprint(w, testAddContactsResponseBody)
}))
defer s.Close()
created, err := c.AddContacts(testAddContactsReq.ContactsCards, 0, 0, 0)
if err != nil {
t.Fatal("Expected no error while adding contact, got:", err)
}
if !reflect.DeepEqual(created, testContactCreated) {
t.Fatalf("Invalid created contact: expected %+v, got %+v", testContactCreated, created)
}
}
var testGetContactsResponseBody = `{
"Code": 1000,
"Contacts": [
{
"ID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
"Name": "Alice",
"UID": "proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b",
"Size": 243,
"CreateTime": 1517395498,
"ModifyTime": 1517395498,
"LabelIDs": []
},
{
"ID": "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
"Name": "Bob",
"UID": "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
"Size": 303,
"CreateTime": 1517394677,
"ModifyTime": 1517394678,
"LabelIDs": []
}
],
"Total": 2
}`
var testGetContacts = []*Contact{
{
ID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
Name: "Alice",
UID: "proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b",
Size: 243,
CreateTime: 1517395498,
ModifyTime: 1517395498,
LabelIDs: []string{},
},
{
ID: "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
Name: "Bob",
UID: "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
Size: 303,
CreateTime: 1517394677,
ModifyTime: 1517394678,
LabelIDs: []string{},
},
}
func TestContact_GetContacts(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "GET", "/contacts?Page=0&PageSize=1000"))
fmt.Fprint(w, testGetContactsResponseBody)
}))
defer s.Close()
contacts, err := c.GetContacts(0, 1000)
if err != nil {
t.Fatal("Expected no error while getting contacts, got:", err)
}
if !reflect.DeepEqual(contacts, testGetContacts) {
t.Fatalf("Invalid created contact: expected %+v, got %+v", testGetContacts, contacts)
}
}
var testGetContactByIDResponseBody = `{ var testGetContactByIDResponseBody = `{
"Code": 1000, "Code": 1000,
"Contact": { "Contact": {
@ -321,14 +106,16 @@ var testGetContactByID = Contact{
} }
func TestContact_GetContactById(t *testing.T) { func TestContact_GetContactById(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "GET", "/contacts/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")) Ok(t, checkMethodAndPath(r, "GET", "/contacts/v4/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testGetContactByIDResponseBody) fmt.Fprint(w, testGetContactByIDResponseBody)
})) }))
defer s.Close() defer s.Close()
contact, err := c.GetContactByID("s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==") contact, err := c.GetContactByID(context.TODO(), "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")
if err != nil { if err != nil {
t.Fatal("Expected no error while getting contacts, got:", err) t.Fatal("Expected no error while getting contacts, got:", err)
} }
@ -338,287 +125,6 @@ func TestContact_GetContactById(t *testing.T) {
} }
} }
var testGetContactsForExportResponseBody = `{
"Code": 1000,
"Contacts": [
{
"ID": "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
"Cards": [
{
"Type": 2,
"Data": "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Bob\nitem1.EMAIL:bob.changed.tester@protonmail.com\nUID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece\nEND:VCARD\n",
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAAtxwIAFGgPO+xH4PHppffQC1R\nxCp/Bjzaq5rDUE3ZMKVJ1sFqGVlq2bP5CIN4w2XCe/MuZ+z2o87fSEtt2n7i\n0/8Ah35u4czn7t8FZoW8u9WwHPURa8gUbP3fYpVASBY1Bt2fUxJrSUYn5KQp\njJM/DgF99bhIjOTuhx9IN7DFKG647Arq+GJh9M6RJNxkb3CBfcCVUXoIwMB7\nnM/fA1r+mcl8dQam0WKVJgy9aO2XUUR62w1SpqJlXY3z8hKvXjjskzU3DQk5\net07RLVQvhy2nCZePsM+TJzL8OBbTa1aF/p1xPe+HND7t3ZCm9tQOY+UhK5H\nbhPbQY48KGdci1dTcm2HbsQ=\n=iOnV\n-----END PGP SIGNATURE-----\n"
},
{
"Type": 3,
"Data": "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQgABsQWnqZadqrHDN43McGhEYfJjOB66R5HhkQAUavP\nHaAHpJciGxfz6tbztQu4C6kdMA80ElbD8c+bJqalw6ZbT4seoP4TTQLykD1n\n0LuNBlaW4x8kfd8rZzFdckk/dY2PruX6byAjSZslnZlZSwp99AJJbvJtfXRR\nzunKMbDieRkaApGZYT25wT5mz1embpXFesvO4nDkOEQCa0uyti3mNSLhYlf/\ntbaOS3WM9VYM9eB9YRZGzJNxMtTxOsd45tBlGCHnCzWEUnJdqZuYzH2QOky7\nMckXhk6YwyemYi/q7OOgSYEg/0lCs2EK3b//14yPDx8Bj5G7rZrnDgsP+BHj\nu9KaAZb2pSBPQoJ2DY3Y4A2Sg8GjaX5CMO9D6GKJkZSYkXddQgcmw7sVPUS+\n+5JaPXlfxoJOOn9kj9A6LDC6eMhYaLujG1BKcZ16DB0jqfwMnPLJ+bYEdatr\nKMvd9rbdsDwQ/tfk11VvHpiEBCNZjxM2+bdBLl9q2EXaLXi+dz/rJg5C0A9u\nNS2CzCUvg6+jNUzHo/RBfRXvlNV8tw==\n=mE2b\n-----END PGP MESSAGE-----\n",
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAApucIAD/uwWuV6DOg127XIPG6\n/jluL8jmwyCJX9noL6S8ZVMOymziKSh4/P1QyMPC5SL4lMPEiuaEdyetfBkU\n+5hW3tcZ+ptxmDi59SVYqmXTVewPgeB7t8c5nbzCuVuzA7ZAo8HAXHzFVQDS\nj9fKVGjZzQkmlwdcfnkXHAF0Ejilv9wxOOYgqVDuzm7JXVF3Um7nAgGKTJE5\n5CNnrEjmJGapj96mQFwXzET/kAhNIBw9tL5FAkDlKImdw8C0w9sXdvDu3yVM\ntvUZ5o2rR6ft0SC1byFso49vgJ/syeK6P2pPzltZJbsp4MvmlPUB0/G1XRU+\nI7q4IOWCvs8RD88ADmOty2o=\n=hyZE\n-----END PGP SIGNATURE-----\n"
}
]
},
{
"ID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
"Cards": [
{
"Type": 3,
"Data": "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQf/RnOQRpo8DVJHQSJRgckEaUQvdMcADiM4L23diyiS\nQfclby/Ve2WInmvZc2RJ3rWENfeqyDZE6krQT642pKiW09GOIyVIjl+hje9y\nE4HBX0AIAWv7QhhKX6UZcM5dYSFbV3j3QxQB8A4Thng2G6ltotMTlbtcHbhu\n96Lt6ngA1tngXLSF5seyflnoiSQ5gLi2qVzrd95dIP6D4Ottcp929/4hDGmq\nPyxw9dColx6gVd1bmIDSI6ewkET4Grmo6QYqjSvjqLOf0PqHKzqypSFLkI5l\nmmnWKYTQCgl9wX+hq6Qz5E+m/BtbkdeX0YxYUss2e+oSAzJmnfdETErG9U5z\n3NJqAc3sgdwDzfWHBzogAxAbDHiqrF6zMlR5SFvZ6nRU7M2DTOE5dJhf+zOp\n1WSKn5LR46LGyt0m5wJPDjaGyQdPffAO4EULvwhGENe10UxRjY1qcUmjYOtS\nunl/vh3afI9PC1jj+HHJD2VgCA==\n=UpcY\n-----END PGP MESSAGE-----\n",
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4pCRDMO9BwcW4mpAAA6h0H/2+97koXzly5pu9hpbaW\n75d1Q976RjMr5DjAx6tKFtSzznel8YfWgvA6OQmMGdPY8ae7/+3mwCJZYWy/\nXVvUfCSflmYpSIKGfP+Vm1XezWY1W84DGhiFj5n8sdaWisv3bpFwFf1YR3Ae\noBoZ4ufNzaQALRqGPMgXETtXZCtzuL/+0vGSKj5SLECiRcSE4jCPEVRy2bcl\nWJyB9r4VmcjF042OMHxphXoYmTEWvgigyaQFHNORu5cK9EHfHpCG6IcjGbdx\n+9Px5YnDY1ix+YpBKePGSTlLE0u6ow0VTUrdvNjl7IUBaRcfJcIIdgCBOTMw\n1uQ/yeyP46V5AFXFnIKeZeQ=\n=FlOf\n-----END PGP SIGNATURE-----\n"
},
{
"Type": 2,
"Data": "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Alice\nitem1.EMAIL:alice@protonmail.com\nUID:proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b\nEND:VCARD",
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4qCRDMO9BwcW4mpAAA3jUIAJ88mIyO8Yj0+evSFXnK\nNxNdjNgn7t1leY0BWlh1nkK76XrZEPipdw2QU8cOcZzn1Wby2SGfZVkwoPc4\nzAhPT4WKbkFVqXhDry5399kLwGYJCxdEcw/oPyYj+YgpQKMxhTrQq21tbEwr\n7JDRBXgi3Cckh/XsteFHOIiAVnM7BV6zFudipnYxa4uNF0Bf4VbUZx1Mm0Wb\nMJaGsO5reqQUQzDPO5TdSAZ8qGSdjVv7RESgUu5DckcDSsnB987Zbh9uFc22\nfPYmb6zA0cEZh3dAjpDPT7cg8hlvfYBb+kP3sLFyLiIkdEG8Pcagjf0k+l76\nr1IsPlYBx2LJmsJf+WDNlj8=\n=Xn+3\n-----END PGP SIGNATURE-----\n"
}
]
}
],
"Total": 2
}`
var testGetContactsForExport = []Contact{
{
ID: "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
Cards: []Card{
{
Type: 2,
Data: "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Bob\nitem1.EMAIL:bob.changed.tester@protonmail.com\nUID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece\nEND:VCARD\n",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAAtxwIAFGgPO+xH4PHppffQC1R\nxCp/Bjzaq5rDUE3ZMKVJ1sFqGVlq2bP5CIN4w2XCe/MuZ+z2o87fSEtt2n7i\n0/8Ah35u4czn7t8FZoW8u9WwHPURa8gUbP3fYpVASBY1Bt2fUxJrSUYn5KQp\njJM/DgF99bhIjOTuhx9IN7DFKG647Arq+GJh9M6RJNxkb3CBfcCVUXoIwMB7\nnM/fA1r+mcl8dQam0WKVJgy9aO2XUUR62w1SpqJlXY3z8hKvXjjskzU3DQk5\net07RLVQvhy2nCZePsM+TJzL8OBbTa1aF/p1xPe+HND7t3ZCm9tQOY+UhK5H\nbhPbQY48KGdci1dTcm2HbsQ=\n=iOnV\n-----END PGP SIGNATURE-----\n",
},
{
Type: 3,
Data: "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQgABsQWnqZadqrHDN43McGhEYfJjOB66R5HhkQAUavP\nHaAHpJciGxfz6tbztQu4C6kdMA80ElbD8c+bJqalw6ZbT4seoP4TTQLykD1n\n0LuNBlaW4x8kfd8rZzFdckk/dY2PruX6byAjSZslnZlZSwp99AJJbvJtfXRR\nzunKMbDieRkaApGZYT25wT5mz1embpXFesvO4nDkOEQCa0uyti3mNSLhYlf/\ntbaOS3WM9VYM9eB9YRZGzJNxMtTxOsd45tBlGCHnCzWEUnJdqZuYzH2QOky7\nMckXhk6YwyemYi/q7OOgSYEg/0lCs2EK3b//14yPDx8Bj5G7rZrnDgsP+BHj\nu9KaAZb2pSBPQoJ2DY3Y4A2Sg8GjaX5CMO9D6GKJkZSYkXddQgcmw7sVPUS+\n+5JaPXlfxoJOOn9kj9A6LDC6eMhYaLujG1BKcZ16DB0jqfwMnPLJ+bYEdatr\nKMvd9rbdsDwQ/tfk11VvHpiEBCNZjxM2+bdBLl9q2EXaLXi+dz/rJg5C0A9u\nNS2CzCUvg6+jNUzHo/RBfRXvlNV8tw==\n=mE2b\n-----END PGP MESSAGE-----\n",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAApucIAD/uwWuV6DOg127XIPG6\n/jluL8jmwyCJX9noL6S8ZVMOymziKSh4/P1QyMPC5SL4lMPEiuaEdyetfBkU\n+5hW3tcZ+ptxmDi59SVYqmXTVewPgeB7t8c5nbzCuVuzA7ZAo8HAXHzFVQDS\nj9fKVGjZzQkmlwdcfnkXHAF0Ejilv9wxOOYgqVDuzm7JXVF3Um7nAgGKTJE5\n5CNnrEjmJGapj96mQFwXzET/kAhNIBw9tL5FAkDlKImdw8C0w9sXdvDu3yVM\ntvUZ5o2rR6ft0SC1byFso49vgJ/syeK6P2pPzltZJbsp4MvmlPUB0/G1XRU+\nI7q4IOWCvs8RD88ADmOty2o=\n=hyZE\n-----END PGP SIGNATURE-----\n",
},
},
},
{
ID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
Cards: []Card{
{
Type: 3,
Data: "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQf/RnOQRpo8DVJHQSJRgckEaUQvdMcADiM4L23diyiS\nQfclby/Ve2WInmvZc2RJ3rWENfeqyDZE6krQT642pKiW09GOIyVIjl+hje9y\nE4HBX0AIAWv7QhhKX6UZcM5dYSFbV3j3QxQB8A4Thng2G6ltotMTlbtcHbhu\n96Lt6ngA1tngXLSF5seyflnoiSQ5gLi2qVzrd95dIP6D4Ottcp929/4hDGmq\nPyxw9dColx6gVd1bmIDSI6ewkET4Grmo6QYqjSvjqLOf0PqHKzqypSFLkI5l\nmmnWKYTQCgl9wX+hq6Qz5E+m/BtbkdeX0YxYUss2e+oSAzJmnfdETErG9U5z\n3NJqAc3sgdwDzfWHBzogAxAbDHiqrF6zMlR5SFvZ6nRU7M2DTOE5dJhf+zOp\n1WSKn5LR46LGyt0m5wJPDjaGyQdPffAO4EULvwhGENe10UxRjY1qcUmjYOtS\nunl/vh3afI9PC1jj+HHJD2VgCA==\n=UpcY\n-----END PGP MESSAGE-----\n",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4pCRDMO9BwcW4mpAAA6h0H/2+97koXzly5pu9hpbaW\n75d1Q976RjMr5DjAx6tKFtSzznel8YfWgvA6OQmMGdPY8ae7/+3mwCJZYWy/\nXVvUfCSflmYpSIKGfP+Vm1XezWY1W84DGhiFj5n8sdaWisv3bpFwFf1YR3Ae\noBoZ4ufNzaQALRqGPMgXETtXZCtzuL/+0vGSKj5SLECiRcSE4jCPEVRy2bcl\nWJyB9r4VmcjF042OMHxphXoYmTEWvgigyaQFHNORu5cK9EHfHpCG6IcjGbdx\n+9Px5YnDY1ix+YpBKePGSTlLE0u6ow0VTUrdvNjl7IUBaRcfJcIIdgCBOTMw\n1uQ/yeyP46V5AFXFnIKeZeQ=\n=FlOf\n-----END PGP SIGNATURE-----\n",
},
{
Type: 2,
Data: "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Alice\nitem1.EMAIL:alice@protonmail.com\nUID:proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b\nEND:VCARD",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4qCRDMO9BwcW4mpAAA3jUIAJ88mIyO8Yj0+evSFXnK\nNxNdjNgn7t1leY0BWlh1nkK76XrZEPipdw2QU8cOcZzn1Wby2SGfZVkwoPc4\nzAhPT4WKbkFVqXhDry5399kLwGYJCxdEcw/oPyYj+YgpQKMxhTrQq21tbEwr\n7JDRBXgi3Cckh/XsteFHOIiAVnM7BV6zFudipnYxa4uNF0Bf4VbUZx1Mm0Wb\nMJaGsO5reqQUQzDPO5TdSAZ8qGSdjVv7RESgUu5DckcDSsnB987Zbh9uFc22\nfPYmb6zA0cEZh3dAjpDPT7cg8hlvfYBb+kP3sLFyLiIkdEG8Pcagjf0k+l76\nr1IsPlYBx2LJmsJf+WDNlj8=\n=Xn+3\n-----END PGP SIGNATURE-----\n",
},
},
},
}
func TestContact_GetContactsForExport(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "GET", "/contacts/export?Page=0&PageSize=1000"))
fmt.Fprint(w, testGetContactsForExportResponseBody)
}))
defer s.Close()
contacts, err := c.GetContactsForExport(0, 1000)
if err != nil {
t.Fatal("Expected no error while getting contacts for export, got:", err)
}
if !reflect.DeepEqual(contacts, testGetContactsForExport) {
t.Fatalf("Invalid contact for export: expected %+v, got %+v", testGetContactsForExport, contacts)
}
}
var testGetContactsEmailsResponseBody = `{
"Code": 1000,
"ContactEmails": [
{
"ID": "Hgyz1tG0OiC2v_hMIVOa6juMOAp_recWNzWII7a79Tfwdx08Jy3FJY0_Y_UtFYwbi6mN-Xx1sOI9_GmUGAcwWg==",
"Name": "Bob",
"Email": "bob.changed.tester@protonmail.com",
"Type": [],
"Defaults": 1,
"Order": 1,
"ContactID": "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
"LabelIDs": []
},
{
"ID": "4m2sBxLq4McqD0D330Kuy5xG-yyDNXyLEjG5_RYcjy9X-3qHGNP07DNOWLY40TYtUAQr4fAVp8zOcZ_z2o6H-A==",
"Name": "Alice",
"Email": "alice@protonmail.com",
"Type": [],
"Defaults": 1,
"Order": 1,
"ContactID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
"LabelIDs": []
}
],
"Total": 2
}`
var testGetContactsEmails = []ContactEmail{
{
ID: "Hgyz1tG0OiC2v_hMIVOa6juMOAp_recWNzWII7a79Tfwdx08Jy3FJY0_Y_UtFYwbi6mN-Xx1sOI9_GmUGAcwWg==",
Name: "Bob",
Email: "bob.changed.tester@protonmail.com",
Type: []string{},
Defaults: 1,
Order: 1,
ContactID: "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
LabelIDs: []string{},
},
{
ID: "4m2sBxLq4McqD0D330Kuy5xG-yyDNXyLEjG5_RYcjy9X-3qHGNP07DNOWLY40TYtUAQr4fAVp8zOcZ_z2o6H-A==",
Name: "Alice",
Email: "alice@protonmail.com",
Type: []string{},
Defaults: 1,
Order: 1,
ContactID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
LabelIDs: []string{},
},
}
func TestContact_GetAllContactsEmails(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "GET", "/contacts/emails?Page=0&PageSize=1000"))
fmt.Fprint(w, testGetContactsEmailsResponseBody)
}))
defer s.Close()
contactsEmails, err := c.GetAllContactsEmails(0, 1000)
if err != nil {
t.Fatal("Expected no error while getting contacts for export, got:", err)
}
if !reflect.DeepEqual(contactsEmails, testGetContactsEmails) {
t.Fatalf("Invalid contact for export: expected %+v, got %+v", testGetContactsEmails, contactsEmails)
}
}
var testUpdateContactReq = UpdateContactReq{
Cards: []Card{
{
Type: 2,
Data: `BEGIN:VCARD
VERSION:4.0
FN;TYPE=fn:Bob
item1.EMAIL:bob.changed.tester@protonmail.com
UID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece
END:VCARD
`,
Signature: ``,
},
},
}
var testUpdateContactResponseBody = `{
"Code": 1000,
"Contact": {
"ID": "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
"Name": "Bob",
"UID": "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
"Size": 303,
"CreateTime": 1517416603,
"ModifyTime": 1517416656,
"ContactEmails": [
{
"ID": "14n6vuf1zbeo3zsYzgV471S6xJ9gzl7-VZ8tcOTQq6ifBlNEre0SUdUM7sXh6e2Q_4NhJZaU9c7jLdB1HCV6dA==",
"Name": "Bob",
"Email": "bob.changed.tester@protonmail.com",
"Type": [],
"Defaults": 1,
"Order": 1,
"ContactID": "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
"LabelIDs": []
}
],
"LabelIDs": []
}
}`
func TestContact_UpdateContact(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "PUT", "/contacts/l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ=="))
var updateContactReq UpdateContactReq
if err := json.NewDecoder(r.Body).Decode(&updateContactReq); err != nil {
t.Error("Expecting no error while reading request body, got:", err)
}
if !reflect.DeepEqual(testUpdateContactReq.Cards, updateContactReq.Cards) {
t.Errorf("Invalid contacts request: expected %+v but got %+v", testUpdateContactReq.Cards, updateContactReq.Cards)
}
fmt.Fprint(w, testUpdateContactResponseBody)
}))
defer s.Close()
created, err := c.UpdateContact("l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==", testUpdateContactReq.Cards)
if err != nil {
t.Fatal("Expected no error while updating contact, got:", err)
}
if !reflect.DeepEqual(created, testContactUpdated) {
t.Fatalf("Invalid updated contact: expected\n%+v\ngot\n%+v\n", testContactUpdated, created)
}
}
var testDeleteContactsReq = DeleteReq{
IDs: []string{
"s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
},
}
var testDeleteContactsResponseBody = `{
"Code": 1001,
"Responses": [
{
"ID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
"Response": {
"Code": 1000
}
}
]
}`
func TestContact_DeleteContacts(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "PUT", "/contacts/delete"))
var deleteContactsReq DeleteReq
if err := json.NewDecoder(r.Body).Decode(&deleteContactsReq); err != nil {
t.Error("Expecting no error while reading request body, got:", err)
}
if !reflect.DeepEqual(testDeleteContactsReq.IDs, deleteContactsReq.IDs) {
t.Errorf("Invalid delete contacts request: expected %+v but got %+v", deleteContactsReq.IDs, testDeleteContactsReq.IDs)
}
fmt.Fprint(w, testDeleteContactsResponseBody)
}))
defer s.Close()
err := c.DeleteContacts([]string{"s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="})
if err != nil {
t.Fatal("Expected no error while getting contacts for export, got:", err)
}
}
var testDeleteAllResponseBody = `{
"Code": 1000
}`
func TestContact_DeleteAllContacts(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "DELETE", "/contacts"))
fmt.Fprint(w, testDeleteAllResponseBody)
}))
defer s.Close()
err := c.DeleteAllContacts()
if err != nil {
t.Fatal("Expected no error while getting contacts for export, got:", err)
}
}
func TestContact_isSignedCardType(t *testing.T) { func TestContact_isSignedCardType(t *testing.T) {
if !isSignedCardType(SignedCard) || !isSignedCardType(EncryptedSignedCard) { if !isSignedCardType(SignedCard) || !isSignedCardType(EncryptedSignedCard) {
t.Fatal("isSignedCardType shouldn't return false for signed card types") t.Fatal("isSignedCardType shouldn't return false for signed card types")
@ -654,7 +160,7 @@ var testCardsCleartext = []Card{
} }
func TestClient_Encrypt(t *testing.T) { func TestClient_Encrypt(t *testing.T) {
c := newTestClient(newTestClientManager(testClientConfig)) c := newClient(newManager(DefaultConfig), "")
c.userKeyRing = testPrivateKeyRing c.userKeyRing = testPrivateKeyRing
cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext) cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext)
@ -668,7 +174,7 @@ func TestClient_Encrypt(t *testing.T) {
} }
func TestClient_Decrypt(t *testing.T) { func TestClient_Decrypt(t *testing.T) {
c := newTestClient(newTestClientManager(testClientConfig)) c := newClient(newManager(DefaultConfig), "")
c.userKeyRing = testPrivateKeyRing c.userKeyRing = testPrivateKeyRing
cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted) cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted)

View File

@ -1,51 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
// ConversationsCount have same structure as MessagesCount.
type ConversationsCount MessagesCount
// ConversationsCountsRes holds response from server.
type ConversationsCountsRes struct {
Res
Counts []*ConversationsCount
}
// Conversation contains one body and multiple metadata.
type Conversation struct{}
// CountConversations counts conversations by label.
func (c *client) CountConversations(addressID string) (counts []*ConversationsCount, err error) {
reqURL := "/mail/v4/conversations/count"
if addressID != "" {
reqURL += ("?AddressID=" + addressID)
}
req, err := c.NewRequest("GET", reqURL, nil)
if err != nil {
return
}
var res ConversationsCountsRes
if err = c.DoJSON(req, &res); err != nil {
return
}
counts, err = res.Counts, res.Err()
return
}

19
pkg/pmapi/data_test.go Normal file
View File

@ -0,0 +1,19 @@
package pmapi
import "github.com/ProtonMail/gopenpgp/v2/crypto"
var testIdentity = &crypto.Identity{
Name: "UserID",
Email: "",
}
const (
testUsername = "jason"
testAPIPassword = "apple"
testUID = "729ad6012421d67ad26950dc898bebe3a6e3caa2" //nolint[gosec]
testAccessToken = "de0423049b44243afeec7d9c1d99be7b46da1e8a" //nolint[gosec]
testAccessTokenOld = "feb3159ac63fb05119bcf4480d939278aa746926" //nolint[gosec]
testRefreshToken = "a49b98256745bb497bec20e9b55f5de16f01fb52" //nolint[gosec]
testRefreshTokenNew = "b894b4c4f20003f12d486900d8b88c7d68e67235" //nolint[gosec]
)

View File

@ -1,43 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import "unicode/utf8"
func printBytes(body []byte) string {
if utf8.Valid(body) {
return string(body)
}
enc := []rune{}
for _, b := range body {
switch {
case b == 9:
enc = append(enc, rune('⟼'))
case b == 13:
enc = append(enc, rune('↵'))
case b < 32, b == 127:
enc = append(enc, '◡')
case b > 31 && b < 127, b == 10:
enc = append(enc, rune(b))
default:
enc = append(enc, 9728+rune(b))
}
}
return string(enc)
}

View File

@ -1,72 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"crypto/tls"
"net"
"net/http"
"time"
)
type TLSDialer interface {
DialTLS(network, address string) (conn net.Conn, err error)
}
// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
return &http.Transport{
DialTLS: dialer.DialTLS,
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: 100,
IdleConnTimeout: 5 * time.Minute,
ExpectContinueTimeout: 500 * time.Millisecond,
// GODT-126: this was initially 10s but logs from users showed a significant number
// were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
// Bumping to 30s for now to avoid this problem.
ResponseHeaderTimeout: 30 * time.Second,
// If we allow up to 30 seconds for response headers, it is reasonable to allow up
// to 30 seconds for the TLS handshake to take place.
TLSHandshakeTimeout: 30 * time.Second,
}
}
// BasicTLSDialer implements TLSDialer.
type BasicTLSDialer struct{}
// NewBasicTLSDialer returns a new BasicTLSDialer.
func NewBasicTLSDialer() *BasicTLSDialer {
return &BasicTLSDialer{}
}
// DialTLS returns a connection to the given address using the given network.
func (b *BasicTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
dialer := &net.Dialer{Timeout: 30 * time.Second} // Alternative Routes spec says this should be a 30s timeout.
var tlsConfig *tls.Config
// If we are not dialing the standard API then we should skip cert verification checks.
if address != rootURL {
tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
}
return tls.DialWithDialer(dialer, network, address, tlsConfig)
}

View File

@ -1,93 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"crypto/tls"
"net"
"github.com/sirupsen/logrus"
)
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
// to report errors if the fingerprint check fails.
type PinningTLSDialer struct {
dialer TLSDialer
// pinChecker is used to check TLS keys of connections.
pinChecker *pinChecker
// tlsIssueNotifier is used to notify something when there is a TLS issue.
tlsIssueNotifier func()
reporter *tlsReporter
// A logger for logging messages.
log logrus.FieldLogger
}
// NewPinningTLSDialer constructs a new dialer which only returns tcp connections to servers
// which present known certificates.
// If enabled, it reports any invalid certificates it finds.
func NewPinningTLSDialer(dialer TLSDialer) *PinningTLSDialer {
return &PinningTLSDialer{
dialer: dialer,
pinChecker: newPinChecker(TrustedAPIPins),
log: logrus.WithField("pkg", "pmapi/tls-pinning"),
}
}
func (p *PinningTLSDialer) SetTLSIssueNotifier(notifier func()) {
p.tlsIssueNotifier = notifier
}
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(cm *ClientManager) {
p.reporter = newTLSReporter(p.pinChecker, cm)
}
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) {
conn, err := p.dialer.DialTLS(network, address)
if err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
if err := p.pinChecker.checkCertificate(conn); err != nil {
if p.tlsIssueNotifier != nil {
go p.tlsIssueNotifier()
}
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.reportCertIssue(
TLSReportURI,
host,
port,
tlsConn.ConnectionState(),
)
}
return nil, err
}
return conn, nil
}

View File

@ -1,141 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
const liveAPI = "api.protonmail.ch"
var testLiveConfig = &ClientConfig{
AppVersion: "Bridge_1.2.4-test",
ClientID: "Bridge",
}
func createAndSetPinningDialer(cm *ClientManager) (*int, *PinningTLSDialer) {
called := 0
dialer := NewPinningTLSDialer(NewBasicTLSDialer())
dialer.SetTLSIssueNotifier(func() { called++ })
cm.SetRoundTripper(CreateTransportWithDialer(dialer))
return &called, dialer
}
func TestTLSPinValid(t *testing.T) {
cm := newTestClientManager(testLiveConfig)
cm.host = liveAPI
rootScheme = "https"
called, _ := createAndSetPinningDialer(cm)
client := cm.GetClient("pmapi" + t.Name())
_, err := client.AuthInfo("this.address.is.disabled")
Ok(t, err)
Equals(t, 0, *called)
}
func TestTLSPinBackup(t *testing.T) {
cm := newTestClientManager(testLiveConfig)
cm.host = liveAPI
called, p := createAndSetPinningDialer(cm)
p.pinChecker.trustedPins[1] = p.pinChecker.trustedPins[0]
p.pinChecker.trustedPins[0] = ""
client := cm.GetClient("pmapi" + t.Name())
_, err := client.AuthInfo("this.address.is.disabled")
Ok(t, err)
Equals(t, 0, *called)
}
func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
cm.host = liveAPI
called, p := createAndSetPinningDialer(cm)
for i := 0; i < len(p.pinChecker.trustedPins); i++ {
p.pinChecker.trustedPins[i] = "testing"
}
client := cm.GetClient("pmapi" + t.Name())
_, err := client.AuthInfo("this.address.is.disabled")
Ok(t, err)
// check that it will be called only once per session
client = cm.GetClient("pmapi" + t.Name())
_, err = client.AuthInfo("this.address.is.disabled")
Ok(t, err)
Equals(t, 1, *called)
}
func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0)
}))
defer ts.Close()
called, _ := createAndSetPinningDialer(cm)
client := cm.GetClient("pmapi" + t.Name())
cm.host = liveAPI
_, err := client.AuthInfo("this.address.is.disabled")
Ok(t, err)
cm.host = ts.URL
_, err = client.AuthInfo("this.address.is.disabled")
Assert(t, err != nil, "error is expected but have %v", err)
Equals(t, 1, *called)
}
// The tests below should pass but cannot run in CI due to proxy issues.
func _TestTLSSignedCertWrongPublicKey(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
_, dialer := createAndSetPinningDialer(cm)
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
assert.Error(t, err, "expected dial to fail because of wrong public key")
}
func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
_, dialer := createAndSetPinningDialer(cm)
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
assert.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
}
func _TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
cm := newTestClientManager(testLiveConfig)
_, dialer := createAndSetPinningDialer(cm)
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
_, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443")
assert.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
}

View File

@ -1,61 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"net"
)
// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
type ProxyTLSDialer struct {
dialer TLSDialer
cm *ClientManager
}
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
func NewProxyTLSDialer(dialer TLSDialer, cm *ClientManager) *ProxyTLSDialer {
return &ProxyTLSDialer{
dialer: dialer,
cm: cm,
}
}
// DialTLS dials the given network/address. If it fails, it retries using a proxy.
func (d *ProxyTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
if conn, err = d.dialer.DialTLS(network, address); err == nil {
return conn, nil
}
if !d.cm.allowProxy {
return
}
var proxy string
if proxy, err = d.cm.switchToReachableServer(); err != nil {
return
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return
}
return d.dialer.DialTLS(network, net.JoinHostPort(proxy, port))
}

View File

@ -1,74 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
// The file and its signature are verified using the given keyring `kr`.
// If the file is verified successfully, it can be read from the returned reader.
// TLS fingerprinting is used to verify that connections are only made to known servers.
func (c *client) DownloadAndVerify(file, sig string, kr *crypto.KeyRing) (io.Reader, error) {
var fb, sb []byte
if err := c.fetchFile(file, func(r io.Reader) (err error) {
fb, err = ioutil.ReadAll(r)
return err
}); err != nil {
return nil, err
}
if err := c.fetchFile(sig, func(r io.Reader) (err error) {
sb, err = ioutil.ReadAll(r)
return err
}); err != nil {
return nil, err
}
if err := kr.VerifyDetached(
crypto.NewPlainMessage(fb),
crypto.NewPGPSignature(sb),
crypto.GetUnixTime(),
); err != nil {
return nil, err
}
return bytes.NewReader(fb), nil
}
func (c *client) fetchFile(file string, fn func(io.Reader) error) error {
res, err := c.hc.Get(file)
if err != nil {
return err
}
defer func() { _ = res.Body.Close() }()
if res.StatusCode != http.StatusOK {
return fmt.Errorf("failed to get file: http error %v", res.StatusCode)
}
return fn(res.Body)
}

9
pkg/pmapi/errors.go Normal file
View File

@ -0,0 +1,9 @@
package pmapi
import "errors"
var (
ErrNoConnection = errors.New("no internet connection")
ErrAPIFailure = errors.New("API returned an error")
ErrUnauthorized = errors.New("API client is unauthorized")
)

View File

@ -18,9 +18,11 @@
package pmapi package pmapi
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http"
"net/mail" "net/mail"
"github.com/go-resty/resty/v2"
) )
// Event represents changes since the last check. // Event represents changes since the last check.
@ -137,7 +139,7 @@ type EventMessageUpdated struct {
ID string ID string
Subject *string Subject *string
Unread *int Unread *Boolean
Flags *int64 Flags *int64
Sender *mail.Address Sender *mail.Address
ToList *[]*mail.Address ToList *[]*mail.Address
@ -163,62 +165,28 @@ type EventAddress struct {
Address *Address Address *Address
} }
type EventRes struct {
Res
*Event
}
type LatestEventRes struct {
Res
*Event
}
// GetEvent returns a summary of events that occurred since last. To get the latest event, // GetEvent returns a summary of events that occurred since last. To get the latest event,
// provide an empty last value. The latest event is always empty. // provide an empty last value. The latest event is always empty.
func (c *client) GetEvent(last string) (event *Event, err error) { func (c *client) GetEvent(ctx context.Context, eventID string) (event *Event, err error) {
return c.getEvent(last, 1) if eventID == "" {
eventID = "latest"
} }
func (c *client) getEvent(last string, numberOfMergedEvents int) (event *Event, err error) { var res struct {
var req *http.Request *Event
if last == "" {
req, err = c.NewRequest("GET", "/events/latest", nil) More int
if err != nil {
return
} }
var res LatestEventRes if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if err = c.DoJSON(req, &res); err != nil { return r.SetResult(&res).Get("/events/" + eventID)
return }); err != nil {
return nil, err
} }
event, err = res.Event, res.Err() // FIXME(conman): use mergeEvents() function.
} else {
req, err = c.NewRequest("GET", "/events/"+last, nil)
if err != nil {
return
}
var res EventRes return res.Event, nil
if err = c.DoJSON(req, &res); err != nil {
return
}
event, err = res.Event, res.Err()
if err != nil {
return
}
if event.More == 1 && numberOfMergedEvents < maxNumberOfMergedEvents {
var moreEvents *Event
if moreEvents, err = c.getEvent(event.EventID, numberOfMergedEvents+1); err != nil {
return
}
event = mergeEvents(event, moreEvents)
}
}
return event, err
} }
// mergeEvents combines an old events and a new events object. // mergeEvents combines an old events and a new events object.

View File

@ -18,6 +18,7 @@
package pmapi package pmapi
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/mail" "net/mail"
@ -31,32 +32,41 @@ import (
) )
func TestClient_GetEvent(t *testing.T) { func TestClient_GetEvent(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/latest")) assert.NoError(t, checkMethodAndPath(r, "GET", "/events/latest"))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testEventBody) fmt.Fprint(w, testEventBody)
})) }))
defer s.Close() defer s.Close()
event, err := c.GetEvent("") event, err := c.GetEvent(context.TODO(), "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testEvent, event) require.Equal(t, testEvent, event)
} }
func TestClient_GetEvent_withID(t *testing.T) { func TestClient_GetEvent_withID(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/"+testEvent.EventID)) assert.NoError(t, checkMethodAndPath(r, "GET", "/events/"+testEvent.EventID))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testEventBody) fmt.Fprint(w, testEventBody)
})) }))
defer s.Close() defer s.Close()
event, err := c.GetEvent(testEvent.EventID) event, err := c.GetEvent(context.TODO(), testEvent.EventID)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testEvent, event) require.Equal(t, testEvent, event)
} }
// We first call GetEvent with id of eventID1, which returns More=1 so we fetch with id eventID2. // We first call GetEvent with id of eventID1, which returns More=1 so we fetch with id eventID2.
func TestClient_GetEvent_mergeEvents(t *testing.T) { // FIXME(conman): Merging is currently not supported. Implement it and then enable this test again!
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { func _TestClient_GetEvent_mergeEvents(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.RequestURI() { switch r.URL.RequestURI() {
case "/events/eventID1": case "/events/eventID1":
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/eventID1")) assert.NoError(t, checkMethodAndPath(r, "GET", "/events/eventID1"))
@ -70,15 +80,16 @@ func TestClient_GetEvent_mergeEvents(t *testing.T) {
})) }))
defer s.Close() defer s.Close()
event, err := c.GetEvent("eventID1") event, err := c.GetEvent(context.TODO(), "eventID1")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testEventMerged, event) require.Equal(t, testEventMerged, event)
} }
func TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) { // FIXME(conman): Merging is currently not supported. Implement it and then enable this test again!
func _TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
numberOfCalls := 0 numberOfCalls := 0
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numberOfCalls++ numberOfCalls++
re := regexp.MustCompile(`/eventID([0-9]+)`) re := regexp.MustCompile(`/eventID([0-9]+)`)
@ -93,18 +104,20 @@ func TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
fmt.Println("") fmt.Println("")
body := strings.ReplaceAll(testEventBodyMore1, "eventID2", "eventID"+strconv.Itoa(eventID+1)) body := strings.ReplaceAll(testEventBodyMore1, "eventID2", "eventID"+strconv.Itoa(eventID+1))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, body) fmt.Fprint(w, body)
})) }))
defer s.Close() defer s.Close()
event, err := c.GetEvent("eventID1") event, err := c.GetEvent(context.TODO(), "eventID1")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, maxNumberOfMergedEvents, numberOfCalls) require.Equal(t, maxNumberOfMergedEvents, numberOfCalls)
require.Equal(t, 1, event.More) require.Equal(t, 1, event.More)
} }
var ( var (
testEventMessageUpdateUnread = 0 testEventMessageUpdateUnread = False
testEvent = &Event{ testEvent = &Event{
EventID: "eventID1", EventID: "eventID1",

View File

@ -18,110 +18,68 @@
package pmapi package pmapi
import ( import (
"bytes"
"context"
"encoding/json" "encoding/json"
"io" "errors"
"mime/multipart"
"strconv" "strconv"
"github.com/go-resty/resty/v2"
) )
// Import errors. const MaxImportMessageRequestLength = 10
const (
ImportMessageTooLarge = 36022
)
// ImportReq is an import request.
type ImportReq struct {
// A list of messages that will be imported.
Messages []*ImportMsgReq
}
// WriteTo writes the import request to a multipart writer.
func (req *ImportReq) WriteTo(w *multipart.Writer) (err error) {
// Create Metadata field.
mw, err := w.CreateFormField("Metadata")
if err != nil {
return
}
// Build metadata.
metadata := map[string]*ImportMsgReq{}
for i, msg := range req.Messages {
name := strconv.Itoa(i)
metadata[name] = msg
}
// Write metadata.
if err = json.NewEncoder(mw).Encode(metadata); err != nil {
return
}
// Write messages.
for i, msg := range req.Messages {
name := strconv.Itoa(i)
var fw io.Writer
if fw, err = w.CreateFormFile(name, name+".eml"); err != nil {
return err
}
if _, err = fw.Write(msg.Body); err != nil {
return
}
// Adding new line to properly fetch the whole body on the API side.
// The reason is the bug in PHP: https://bugs.php.net/bug.php?id=75923
// Messages generated by PM already have it but importing already
// encrypted messages might not have it.
if _, err = fw.Write([]byte("\r\n")); err != nil {
return
}
}
return err
}
// ImportMsgReq is a request to import a message. All fields are optional except AddressID and Body.
type ImportMsgReq struct { type ImportMsgReq struct {
// The address where the message will be imported. Metadata *ImportMetadata // Metadata about the message to import.
Message []byte // The raw RFC822 message.
}
type ImportMsgReqs []*ImportMsgReq
func (reqs ImportMsgReqs) buildMultipartFormData() ([]*resty.MultipartField, error) {
var fields []*resty.MultipartField
metadata := make(map[string]*ImportMetadata)
for i, req := range reqs {
name := strconv.Itoa(i)
metadata[name] = req.Metadata
fields = append(fields, &resty.MultipartField{
Param: name,
FileName: name + ".eml",
ContentType: "message/rfc822",
Reader: bytes.NewReader(req.Message),
})
}
b, err := json.Marshal(metadata)
if err != nil {
return nil, err
}
fields = append(fields, &resty.MultipartField{
Param: "Metadata",
ContentType: "application/json",
Reader: bytes.NewReader(b),
})
return fields, nil
}
// TODO: Add other metadata.
type ImportMetadata struct {
AddressID string AddressID string
// The full MIME message. Unread Boolean // 0: read, 1: unread.
Body []byte `json:"-"` IsReplied Boolean // 1 if the message has been replied.
IsRepliedAll Boolean // 1 if the message has been replied to all.
// 0: read, 1: unread. IsForwarded Boolean // 1 if the message has been forwarded.
Unread int Time int64 // The time when the message was received as a Unix time.
// 1 if the message has been replied. Flags int64 // The type of the imported message.
IsReplied int LabelIDs []string // The labels to apply to the imported message. Must contain at least one system label.
// 1 if the message has been replied to all.
IsRepliedAll int
// 1 if the message has been forwarded.
IsForwarded int
// The time when the message was received as a Unix time.
Time int64
// The type of the imported message.
Flags int64
// The labels to apply to the imported message. Must contain at least one system label.
LabelIDs []string
} }
func (req ImportMsgReq) String() string {
data, _ := json.Marshal(req)
return string(data)
}
// ImportRes is a response to an import request.
type ImportRes struct {
Res
Responses []struct {
Name string
Response struct {
Res
MessageID string
}
}
}
// ImportMsgRes is a response to a single message import request.
type ImportMsgRes struct { type ImportMsgRes struct {
// The error encountered while importing the message, if any. // The error encountered while importing the message, if any.
Error error Error error
@ -130,41 +88,46 @@ type ImportMsgRes struct {
} }
// Import imports messages to the user's account. // Import imports messages to the user's account.
func (c *client) Import(reqs []*ImportMsgReq) (resps []*ImportMsgRes, err error) { func (c *client) Import(ctx context.Context, reqs ImportMsgReqs) ([]*ImportMsgRes, error) {
importReq := &ImportReq{Messages: reqs} if len(reqs) > MaxImportMessageRequestLength {
return nil, errors.New("request is too long")
}
req, w, err := c.NewMultipartRequest("POST", "/mail/v4/messages/import") fields, err := reqs.buildMultipartFormData()
if err != nil { if err != nil {
return return nil, err
} }
// We will write the request as long as it is sent to the API. var res struct {
var importRes ImportRes Responses []struct {
done := make(chan error, 1) Name string
go (func() { Response struct {
done <- c.DoJSON(req, &importRes) Error
})() MessageID string
// Write the request.
if err = importReq.WriteTo(w.Writer); err != nil {
return
} }
_ = w.Close()
if err = <-done; err != nil {
return
}
if err = importRes.Err(); err != nil {
return
}
resps = make([]*ImportMsgRes, len(importRes.Responses))
for i, r := range importRes.Responses {
resps[i] = &ImportMsgRes{
Error: r.Response.Err(),
MessageID: r.Response.MessageID,
} }
} }
return resps, err if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetMultipartFields(fields...).SetResult(&res).Post("/mail/v4/messages/import")
}); err != nil {
return nil, err
}
var resps []*ImportMsgRes
for _, resp := range res.Responses {
var err error
if resp.Response.Code != 1000 {
err = resp.Response.Error
}
resps = append(resps, &ImportMsgRes{
Error: err,
MessageID: resp.Response.MessageID,
})
}
return resps, nil
} }

View File

@ -18,6 +18,7 @@
package pmapi package pmapi
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -32,12 +33,14 @@ import (
var testImportReqs = []*ImportMsgReq{ var testImportReqs = []*ImportMsgReq{
{ {
Metadata: &ImportMetadata{
AddressID: "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==", AddressID: "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==",
Body: []byte("Hello World!"),
Unread: 0, Unread: 0,
Flags: FlagReceived | FlagImported, Flags: FlagReceived | FlagImported,
LabelIDs: []string{ArchiveLabel}, LabelIDs: []string{ArchiveLabel},
}, },
Message: []byte("Hello World!"),
},
} }
const testImportBody = `{ const testImportBody = `{
@ -54,7 +57,7 @@ var testImportRes = &ImportMsgRes{
} }
func TestClient_Import(t *testing.T) { // nolint[funlen] func TestClient_Import(t *testing.T) { // nolint[funlen]
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/messages/import")) Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/messages/import"))
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type")) contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
@ -67,51 +70,13 @@ func TestClient_Import(t *testing.T) { // nolint[funlen]
mr := multipart.NewReader(r.Body, params["boundary"]) mr := multipart.NewReader(r.Body, params["boundary"])
// First part is metadata. // First part is message body.
p, err := mr.NextPart() p, err := mr.NextPart()
if err != nil {
t.Error("Expected no error while reading first part of request body, got:", err)
}
contentDisp, params, err := pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
if err != nil {
t.Error("Expected no error while parsing part content disposition, got:", err)
}
if contentDisp != "form-data" {
t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType)
}
if params["name"] != "Metadata" {
t.Errorf("Invalid part name: expected %v but got %v", "Metadata", params["name"])
}
metadata := map[string]*ImportMsgReq{}
if err := json.NewDecoder(p).Decode(&metadata); err != nil {
t.Error("Expected no error while parsing metadata json, got:", err)
}
if len(metadata) != 1 {
t.Errorf("Expected metadata to contain exactly one item, got %v", metadata)
}
req := metadata["0"]
if metadata["0"] == nil {
t.Errorf("Expected metadata to contain one item indexed by 0, got %v", metadata)
}
// No Body in metadata.
expected := *testImportReqs[0]
expected.Body = nil
if !reflect.DeepEqual(&expected, req) {
t.Errorf("Invalid message metadata: expected %v, got %v", &expected, req)
}
// Second part is message body.
p, err = mr.NextPart()
if err != nil { if err != nil {
t.Error("Expected no error while reading second part of request body, got:", err) t.Error("Expected no error while reading second part of request body, got:", err)
} }
contentDisp, params, err = pmmime.ParseMediaType(p.Header.Get("Content-Disposition")) contentDisp, params, err := pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
if err != nil { if err != nil {
t.Error("Expected no error while parsing part content disposition, got:", err) t.Error("Expected no error while parsing part content disposition, got:", err)
} }
@ -127,8 +92,44 @@ func TestClient_Import(t *testing.T) { // nolint[funlen]
t.Error("Expected no error while reading second part body, got:", err) t.Error("Expected no error while reading second part body, got:", err)
} }
if string(b) != string(testImportReqs[0].Body)+"\r\n" { if string(b) != string(testImportReqs[0].Message) {
t.Errorf("Invalid message body: expected %v but got %v", string(testImportReqs[0].Body), string(b)) t.Errorf("Invalid message body: expected %v but got %v", string(testImportReqs[0].Message), string(b))
}
// Second part is metadata.
p, err = mr.NextPart()
if err != nil {
t.Error("Expected no error while reading first part of request body, got:", err)
}
contentDisp, params, err = pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
if err != nil {
t.Error("Expected no error while parsing part content disposition, got:", err)
}
if contentDisp != "form-data" {
t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType)
}
if params["name"] != "Metadata" {
t.Errorf("Invalid part name: expected %v but got %v", "Metadata", params["name"])
}
metadata := map[string]*ImportMetadata{}
if err := json.NewDecoder(p).Decode(&metadata); err != nil {
t.Error("Expected no error while parsing metadata json, got:", err)
}
if len(metadata) != 1 {
t.Errorf("Expected metadata to contain exactly one item, got %v", metadata)
}
req := metadata["0"]
if metadata["0"] == nil {
t.Errorf("Expected metadata to contain one item indexed by 0, got %v", metadata)
}
expected := *testImportReqs[0].Metadata
if !reflect.DeepEqual(&expected, req) {
t.Errorf("Invalid message metadata: expected %v, got %v", &expected, req)
} }
// No more parts. // No more parts.
@ -137,11 +138,13 @@ func TestClient_Import(t *testing.T) { // nolint[funlen]
t.Error("Expected no more parts but error was not EOF, got:", err) t.Error("Expected no more parts but error was not EOF, got:", err)
} }
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testImportBody) fmt.Fprint(w, testImportBody)
})) }))
defer s.Close() defer s.Close()
imported, err := c.Import(testImportReqs) imported, err := c.Import(context.TODO(), testImportReqs)
if err != nil { if err != nil {
t.Fatal("Expected no error while importing, got:", err) t.Fatal("Expected no error while importing, got:", err)
} }

View File

@ -18,11 +18,10 @@
package pmapi package pmapi
import ( import (
"fmt" "context"
"net/http"
"net/url" "net/url"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/go-resty/resty/v2"
) )
// Key flags. // Key flags.
@ -31,84 +30,34 @@ const (
UseToEncryptFlag UseToEncryptFlag
) )
type PublicKeyRes struct {
Res
RecipientType int
MIMEType string
Keys []PublicKey
}
type PublicKey struct { type PublicKey struct {
Flags int Flags int
PublicKey string PublicKey string
} }
// PublicKeys returns the public keys of the given email addresses. type RecipientType int
func (c *client) PublicKeys(emails []string) (keys map[string]*crypto.Key, err error) {
if len(emails) == 0 {
err = fmt.Errorf("pmapi: cannot get public keys: no email address provided")
return
}
keys = make(map[string]*crypto.Key)
for _, email := range emails {
email = url.QueryEscape(email)
var req *http.Request
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
return
}
var res PublicKeyRes
if err = c.DoJSON(req, &res); err != nil {
return
}
for _, rawKey := range res.Keys {
if rawKey.Flags&UseToEncryptFlag == UseToEncryptFlag {
var key *crypto.Key
if key, err = crypto.NewKeyFromArmored(rawKey.PublicKey); err != nil {
return
}
keys[email] = key
}
}
}
return keys, err
}
const ( const (
RecipientInternal = 1 RecipientTypeInternal RecipientType = iota + 1
RecipientExternal = 2 RecipientTypeExternal
) )
// GetPublicKeysForEmail returns all sending public keys for the given email address. // GetPublicKeysForEmail returns all sending public keys for the given email address.
func (c *client) GetPublicKeysForEmail(email string) (keys []PublicKey, internal bool, err error) { func (c *client) GetPublicKeysForEmail(ctx context.Context, email string) (keys []PublicKey, internal bool, err error) {
email = url.QueryEscape(email) email = url.QueryEscape(email)
var req *http.Request var res struct {
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil { Keys []PublicKey
return RecipientType RecipientType
} }
var res PublicKeyRes if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if err = c.DoJSON(req, &res); err != nil { return r.SetResult(&res).SetQueryParam("Email", email).Get("/keys")
return }); err != nil {
return nil, false, err
} }
internal = res.RecipientType == RecipientInternal return res.Keys, res.RecipientType == RecipientTypeInternal, nil
for _, key := range res.Keys {
if key.Flags&UseToEncryptFlag == UseToEncryptFlag {
keys = append(keys, key)
}
}
return
} }
// KeySalt contains id and salt for key. // KeySalt contains id and salt for key.
@ -116,25 +65,17 @@ type KeySalt struct {
ID, KeySalt string ID, KeySalt string
} }
// KeySaltRes is used to unmarshal API response. // GetKeySalts sends request to get list of key salts (n.b. locked route).
type KeySaltRes struct { func (c *client) GetKeySalts(ctx context.Context) (keySalts []KeySalt, err error) {
Res var res struct {
KeySalts []KeySalt KeySalts []KeySalt
} }
// GetKeySalts sends request to get list of key salts (n.b. locked route). if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
func (c *client) GetKeySalts() (keySalts []KeySalt, err error) { return r.SetResult(&res).Get("/keys/salts")
var req *http.Request }); err != nil {
if req, err = c.NewRequest("GET", "/keys/salts", nil); err != nil { return nil, err
return
} }
var res KeySaltRes return res.KeySalts, nil
if err = c.DoJSON(req, &res); err != nil {
return
}
keySalts = res.KeySalts
return
} }

View File

@ -18,8 +18,11 @@
package pmapi package pmapi
import ( import (
"context"
"errors" "errors"
"fmt" "strconv"
"github.com/go-resty/resty/v2"
) )
// System labels. // System labels.
@ -92,100 +95,84 @@ type Label struct {
Notify int Notify int
} }
type LabelListRes struct { func (c *client) ListLabels(ctx context.Context) (labels []*Label, err error) {
Res return c.ListLabelType(ctx, LabelTypeMailbox)
Labels []*Label
} }
func (c *client) ListLabels() (labels []*Label, err error) { func (c *client) ListContactGroups(ctx context.Context) (labels []*Label, err error) {
return c.ListLabelType(LabelTypeMailbox) return c.ListLabelType(ctx, LabelTypeContactGroup)
}
func (c *client) ListContactGroups() (labels []*Label, err error) {
return c.ListLabelType(LabelTypeContactGroup)
} }
// ListLabelType lists all labels created by the user. // ListLabelType lists all labels created by the user.
func (c *client) ListLabelType(labelType int) (labels []*Label, err error) { func (c *client) ListLabelType(ctx context.Context, labelType int) (labels []*Label, err error) {
req, err := c.NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil) var res struct {
if err != nil { Labels []*Label
return
} }
var res LabelListRes if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if err = c.DoJSON(req, &res); err != nil { return r.SetQueryParam("Type", strconv.Itoa(labelType)).SetResult(&res).Get("/v4/labels")
return }); err != nil {
return nil, err
} }
labels, err = res.Labels, res.Err() return res.Labels, nil
return
} }
type LabelReq struct { type LabelReq struct {
*Label *Label
} }
type LabelRes struct { // CreateLabel creates a new label.
Res func (c *client) CreateLabel(ctx context.Context, label *Label) (created *Label, err error) {
if label.Name == "" {
return nil, errors.New("name is required")
}
var res struct {
Label *Label Label *Label
} }
// CreateLabel creates a new label. if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
func (c *client) CreateLabel(label *Label) (created *Label, err error) { return r.SetBody(&LabelReq{
if label.Name == "" { Label: label,
return nil, errors.New("name is required") }).SetResult(&res).Post("/v4/labels")
}); err != nil {
return nil, err
} }
labelReq := &LabelReq{label} return res.Label, nil
req, err := c.NewJSONRequest("POST", "/labels", labelReq)
if err != nil {
return
}
var res LabelRes
if err = c.DoJSON(req, &res); err != nil {
return
}
created, err = res.Label, res.Err()
return
} }
// UpdateLabel updates a label. // UpdateLabel updates a label.
func (c *client) UpdateLabel(label *Label) (updated *Label, err error) { func (c *client) UpdateLabel(ctx context.Context, label *Label) (updated *Label, err error) {
if label.Name == "" { if label.Name == "" {
return nil, errors.New("name is required") return nil, errors.New("name is required")
} }
labelReq := &LabelReq{label} var res struct {
req, err := c.NewJSONRequest("PUT", "/labels/"+label.ID, labelReq) Label *Label
if err != nil {
return
} }
var res LabelRes if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if err = c.DoJSON(req, &res); err != nil { return r.SetBody(&LabelReq{
return Label: label,
}).SetResult(&res).Put("/v4/labels/" + label.ID)
}); err != nil {
return nil, err
} }
updated, err = res.Label, res.Err() return res.Label, nil
return
} }
// DeleteLabel deletes a label. // DeleteLabel deletes a label.
func (c *client) DeleteLabel(id string) (err error) { func (c *client) DeleteLabel(ctx context.Context, labelID string) error {
req, err := c.NewRequest("DELETE", "/labels/"+id, nil) if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if err != nil { return r.Delete("/v4/labels/" + labelID)
return }); err != nil {
return err
} }
var res Res return nil
if err = c.DoJSON(req, &res); err != nil {
return
}
err = res.Err()
return
} }
// LeastUsedColor is intended to return color for creating a new inbox or label. // LeastUsedColor is intended to return color for creating a new inbox or label.

View File

@ -19,6 +19,7 @@ package pmapi
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -90,14 +91,16 @@ const testDeleteLabelBody = `{
` `
func TestClient_ListLabels(t *testing.T) { func TestClient_ListLabels(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "GET", "/labels?1")) Ok(t, checkMethodAndPath(r, "GET", "/v4/labels?Type=1"))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testLabelsBody) fmt.Fprint(w, testLabelsBody)
})) }))
defer s.Close() defer s.Close()
labels, err := c.ListLabels() labels, err := c.ListLabels(context.TODO())
if err != nil { if err != nil {
t.Fatal("Expected no error while listing labels, got:", err) t.Fatal("Expected no error while listing labels, got:", err)
} }
@ -114,8 +117,8 @@ func TestClient_ListLabels(t *testing.T) {
} }
func TestClient_CreateLabel(t *testing.T) { func TestClient_CreateLabel(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "POST", "/labels")) Ok(t, checkMethodAndPath(r, "POST", "/v4/labels"))
body := &bytes.Buffer{} body := &bytes.Buffer{}
_, err := body.ReadFrom(r.Body) _, err := body.ReadFrom(r.Body)
@ -133,11 +136,13 @@ func TestClient_CreateLabel(t *testing.T) {
t.Errorf("Invalid label request: expected %+v but got %+v", testLabelReq.Label, labelReq.Label) t.Errorf("Invalid label request: expected %+v but got %+v", testLabelReq.Label, labelReq.Label)
} }
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreateLabelBody) fmt.Fprint(w, testCreateLabelBody)
})) }))
defer s.Close() defer s.Close()
created, err := c.CreateLabel(testLabelReq.Label) created, err := c.CreateLabel(context.TODO(), testLabelReq.Label)
if err != nil { if err != nil {
t.Fatal("Expected no error while creating label, got:", err) t.Fatal("Expected no error while creating label, got:", err)
} }
@ -148,18 +153,18 @@ func TestClient_CreateLabel(t *testing.T) {
} }
func TestClient_CreateEmptyLabel(t *testing.T) { func TestClient_CreateEmptyLabel(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
r.Fail(t, "API should not be called") r.Fail(t, "API should not be called")
})) }))
defer s.Close() defer s.Close()
_, err := c.CreateLabel(&Label{}) _, err := c.CreateLabel(context.TODO(), &Label{})
r.EqualError(t, err, "name is required") r.EqualError(t, err, "name is required")
} }
func TestClient_UpdateLabel(t *testing.T) { func TestClient_UpdateLabel(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "PUT", "/labels/"+testLabelCreated.ID)) Ok(t, checkMethodAndPath(r, "PUT", "/v4/labels/"+testLabelCreated.ID))
var labelReq LabelReq var labelReq LabelReq
if err := json.NewDecoder(r.Body).Decode(&labelReq); err != nil { if err := json.NewDecoder(r.Body).Decode(&labelReq); err != nil {
@ -169,11 +174,13 @@ func TestClient_UpdateLabel(t *testing.T) {
t.Errorf("Invalid label request: expected %+v but got %+v", testLabelCreated, labelReq.Label) t.Errorf("Invalid label request: expected %+v but got %+v", testLabelCreated, labelReq.Label)
} }
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreateLabelBody) fmt.Fprint(w, testCreateLabelBody)
})) }))
defer s.Close() defer s.Close()
updated, err := c.UpdateLabel(testLabelCreated) updated, err := c.UpdateLabel(context.TODO(), testLabelCreated)
if err != nil { if err != nil {
t.Fatal("Expected no error while updating label, got:", err) t.Fatal("Expected no error while updating label, got:", err)
} }
@ -184,24 +191,26 @@ func TestClient_UpdateLabel(t *testing.T) {
} }
func TestClient_UpdateLabelToEmptyName(t *testing.T) { func TestClient_UpdateLabelToEmptyName(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
r.Fail(t, "API should not be called") r.Fail(t, "API should not be called")
})) }))
defer s.Close() defer s.Close()
_, err := c.UpdateLabel(&Label{ID: "label"}) _, err := c.UpdateLabel(context.TODO(), &Label{ID: "label"})
r.EqualError(t, err, "name is required") r.EqualError(t, err, "name is required")
} }
func TestClient_DeleteLabel(t *testing.T) { func TestClient_DeleteLabel(t *testing.T) {
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Ok(t, checkMethodAndPath(r, "DELETE", "/labels/"+testLabelCreated.ID)) Ok(t, checkMethodAndPath(r, "DELETE", "/v4/labels/"+testLabelCreated.ID))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testDeleteLabelBody) fmt.Fprint(w, testDeleteLabelBody)
})) }))
defer s.Close() defer s.Close()
err := c.DeleteLabel(testLabelCreated.ID) err := c.DeleteLabel(context.TODO(), testLabelCreated.ID)
if err != nil { if err != nil {
t.Fatal("Expected no error while deleting label, got:", err) t.Fatal("Expected no error while deleting label, got:", err)
} }

127
pkg/pmapi/manager.go Normal file
View File

@ -0,0 +1,127 @@
package pmapi
import (
"context"
"net/http"
"sync"
"time"
"github.com/go-resty/resty/v2"
)
type manager struct {
rc *resty.Client
isDown bool
locker sync.Locker
observers []ConnectionObserver
}
func newManager(cfg Config) *manager {
m := &manager{
rc: resty.New(),
locker: &sync.Mutex{},
}
// Set the API host.
m.rc.SetHostURL(cfg.HostURL)
// Set static header values.
m.rc.SetHeader("x-pm-appversion", cfg.AppVersion)
// Set middleware.
m.rc.OnAfterResponse(catchAPIError)
// Configure retry mechanism.
m.rc.SetRetryMaxWaitTime(time.Minute)
m.rc.SetRetryAfter(catchRetryAfter)
m.rc.AddRetryCondition(catchTooManyRequests)
m.rc.AddRetryCondition(catchNoResponse)
m.rc.AddRetryCondition(catchProxyAvailable)
// Determine what happens when requests succeed/fail.
m.rc.OnAfterResponse(m.handleRequestSuccess)
m.rc.OnError(m.handleRequestFailure)
// Set the data type of API errors.
m.rc.SetError(&Error{})
return m
}
func New(cfg Config) Manager {
return newManager(cfg)
}
func (m *manager) SetLogger(logger resty.Logger) {
m.rc.SetLogger(logger)
m.rc.SetDebug(true)
}
func (m *manager) SetTransport(transport http.RoundTripper) {
m.rc.SetTransport(transport)
}
func (m *manager) SetCookieJar(jar http.CookieJar) {
m.rc.SetCookieJar(jar)
}
func (m *manager) SetRetryCount(count int) {
m.rc.SetRetryCount(count)
}
func (m *manager) AddConnectionObserver(observer ConnectionObserver) {
m.observers = append(m.observers, observer)
}
func (m *manager) r(ctx context.Context) *resty.Request {
return m.rc.R().SetContext(ctx)
}
func (m *manager) handleRequestSuccess(_ *resty.Client, res *resty.Response) error {
m.locker.Lock()
defer m.locker.Unlock()
if !m.isDown {
return nil
}
// We successfully got a response; connection must be up.
m.isDown = false
for _, observer := range m.observers {
observer.OnUp()
}
return nil
}
func (m *manager) handleRequestFailure(req *resty.Request, err error) {
m.locker.Lock()
defer m.locker.Unlock()
if m.isDown {
return
}
if res, ok := err.(*resty.ResponseError); ok && res.Response.RawResponse != nil {
return
}
// We didn't get any response; connection must be down.
m.isDown = true
for _, observer := range m.observers {
observer.OnDown()
}
go m.pingUntilSuccess()
}
func (m *manager) pingUntilSuccess() {
for m.testPing(context.Background()) != nil {
time.Sleep(time.Second) // TODO: How long to sleep here?
}
}

114
pkg/pmapi/manager_auth.go Normal file
View File

@ -0,0 +1,114 @@
package pmapi
import (
"context"
"encoding/base64"
"time"
"github.com/ProtonMail/proton-bridge/pkg/srp"
)
func (m *manager) NewClient(uid, acc, ref string, exp time.Time) Client {
return newClient(m, uid).withAuth(acc, ref, exp)
}
func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Client, *Auth, error) {
c := newClient(m, uid)
auth, err := m.authRefresh(ctx, uid, ref)
if err != nil {
return nil, nil, err
}
return c.withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
}
func (m *manager) NewClientWithLogin(ctx context.Context, username, password string) (Client, *Auth, error) {
info, err := m.getAuthInfo(ctx, GetAuthInfoReq{Username: username})
if err != nil {
return nil, nil, err
}
srpAuth, err := srp.NewSrpAuth(info.Version, username, password, info.Salt, info.Modulus, info.ServerEphemeral)
if err != nil {
return nil, nil, err
}
proofs, err := srpAuth.GenerateSrpProofs(2048)
if err != nil {
return nil, nil, err
}
auth, err := m.auth(ctx, AuthReq{
Username: username,
ClientProof: base64.StdEncoding.EncodeToString(proofs.ClientProof),
ClientEphemeral: base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
SRPSession: info.SRPSession,
})
if err != nil {
return nil, nil, err
}
return newClient(m, auth.UID).withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
}
func (m *manager) getAuthModulus(ctx context.Context) (AuthModulus, error) {
var res struct {
AuthModulus
}
if _, err := m.r(ctx).SetResult(&res).Get("/auth/modulus"); err != nil {
return AuthModulus{}, err
}
return res.AuthModulus, nil
}
func (m *manager) getAuthInfo(ctx context.Context, req GetAuthInfoReq) (*AuthInfo, error) {
var res struct {
*AuthInfo
}
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/info"); err != nil {
return nil, err
}
return res.AuthInfo, nil
}
func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) {
var res struct {
*Auth
}
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth"); err != nil {
return nil, err
}
return res.Auth, nil
}
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*Auth, error) {
var req = AuthRefreshReq{
UID: uid,
RefreshToken: ref,
ResponseType: "token",
GrantType: "refresh_token",
RedirectURI: "https://protonmail.ch",
State: randomString(32),
}
var res struct {
*Auth
}
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"); err != nil {
return nil, err
}
return res.Auth, nil
}
func expiresIn(seconds int64) time.Time {
return time.Now().Add(time.Duration(seconds) * time.Second)
}

View File

@ -0,0 +1,51 @@
package pmapi
import (
"io/ioutil"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
// The file and its signature are verified using the given keyring `kr`.
// If the file is verified successfully, it can be read from the returned reader.
// TLS fingerprinting is used to verify that connections are only made to known servers.
func (m *manager) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) {
fb, err := m.fetchFile(url)
if err != nil {
return nil, err
}
sb, err := m.fetchFile(sig)
if err != nil {
return nil, err
}
if err := kr.VerifyDetached(
crypto.NewPlainMessage(fb),
crypto.NewPGPSignature(sb),
crypto.GetUnixTime(),
); err != nil {
return nil, err
}
return fb, nil
}
func (m *manager) fetchFile(url string) ([]byte, error) {
res, err := m.rc.R().SetDoNotParseResponse(true).Get(url)
if err != nil {
return nil, err
}
b, err := ioutil.ReadAll(res.RawBody())
if err != nil {
return nil, err
}
if err := res.RawBody().Close(); err != nil {
return nil, err
}
return b, nil
}

View File

@ -0,0 +1,11 @@
package pmapi
import (
"context"
"errors"
)
func (m *manager) SendSimpleMetric(context.Context, string, string, string) error {
// FIXME(conman): Implement.
return errors.New("not implemented")
}

11
pkg/pmapi/manager_ping.go Normal file
View File

@ -0,0 +1,11 @@
package pmapi
import "context"
func (m *manager) testPing(ctx context.Context) error {
if _, err := m.r(ctx).Get("/tests/ping"); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,12 @@
package pmapi
import (
"context"
"errors"
)
// Report sends request as json or multipart (if has attachment).
func (m *manager) ReportBug(context.Context, ReportBugReq) error {
// FIXME(conman): Implement.
return errors.New("not implemented")
}

Some files were not shown because too many files have changed in this diff Show More