mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
GODT-35: New pmapi client and manager using resty
This commit is contained in:
9
Makefile
9
Makefile
@ -254,12 +254,13 @@ coverage: test
|
||||
go tool cover -html=/tmp/coverage.out -o=coverage.html
|
||||
|
||||
mocks:
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,ClientManager,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/transfer PanicHandler,ClientManager,IMAPClientProvider > internal/transfer/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,ClientManager,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/transfer PanicHandler,IMAPClientProvider > internal/transfer/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go
|
||||
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client > pkg/pmapi/mocks/mocks.go
|
||||
|
||||
lint: gofiles lint-golang lint-license lint-changelog
|
||||
|
||||
|
||||
5
go.mod
5
go.mod
@ -40,7 +40,7 @@ require (
|
||||
github.com/fatih/color v1.9.0
|
||||
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
|
||||
github.com/getsentry/sentry-go v0.8.0
|
||||
github.com/go-resty/resty/v2 v2.3.0
|
||||
github.com/go-resty/resty/v2 v2.4.0
|
||||
github.com/golang/mock v1.4.4
|
||||
github.com/google/go-cmp v0.5.1
|
||||
github.com/google/uuid v1.1.1
|
||||
@ -50,7 +50,6 @@ require (
|
||||
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
|
||||
github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||
github.com/mattn/go-runewidth v0.0.9 // indirect
|
||||
github.com/miekg/dns v1.1.30
|
||||
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
||||
github.com/olekukonko/tablewriter v0.0.4 // indirect
|
||||
github.com/pkg/errors v0.9.1
|
||||
@ -64,7 +63,7 @@ require (
|
||||
github.com/urfave/cli/v2 v2.2.0
|
||||
github.com/vmihailenco/msgpack/v5 v5.1.3
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b
|
||||
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec
|
||||
)
|
||||
|
||||
|
||||
20
go.sum
20
go.sum
@ -113,8 +113,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK
|
||||
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
|
||||
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
||||
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
|
||||
github.com/go-resty/resty/v2 v2.3.0 h1:JOOeAvjSlapTT92p8xiS19Zxev1neGikoHsXJeOq8So=
|
||||
github.com/go-resty/resty/v2 v2.3.0/go.mod h1:UpN9CgLZNsv4e9XG50UU8xdI0F43UQ4HmxLBDwaroHU=
|
||||
github.com/go-resty/resty/v2 v2.4.0 h1:s6TItTLejEI+2mn98oijC5w/Rk2YU+OA6x0mnZN6r6k=
|
||||
github.com/go-resty/resty/v2 v2.4.0/go.mod h1:B88+xCTEwvfD94NOuE6GS1wMlnoKNY8eEiNizfNwOwA=
|
||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||
@ -195,8 +195,6 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m
|
||||
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
|
||||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||
github.com/miekg/dns v1.1.30 h1:Qww6FseFn8PRfw07jueqIXqodm0JKiiKuK0DeXSqfyo=
|
||||
github.com/miekg/dns v1.1.30/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@ -310,12 +308,10 @@ golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU=
|
||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw=
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
|
||||
@ -330,14 +326,15 @@ golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec h1:A1qYjneJuzBZZ2gIB8rd6zrfq6l7SoEMJ8EsSilNK/U=
|
||||
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@ -348,7 +345,6 @@ golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3
|
||||
golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69 h1:yBHHx+XZqXJBm6Exke3N7V9gnlsyXxoCPEb1yVenjfk=
|
||||
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
// - persistent settings
|
||||
// - event listener
|
||||
// - credentials store
|
||||
// - pmapi ClientManager
|
||||
// - pmapi Manager
|
||||
// In addition, the base initialises logging and reacts to command line arguments
|
||||
// which control the log verbosity and enable cpu/memory profiling.
|
||||
package base
|
||||
@ -85,7 +85,7 @@ type Base struct {
|
||||
Cache *cache.Cache
|
||||
Listener listener.Listener
|
||||
Creds *credentials.Store
|
||||
CM *pmapi.ClientManager
|
||||
CM pmapi.Manager
|
||||
CookieJar *cookies.Jar
|
||||
UserAgent *useragent.UserAgent
|
||||
Updater *updater.Updater
|
||||
@ -181,13 +181,26 @@ func New( // nolint[funlen]
|
||||
kc = keychain.NewMissingKeychain()
|
||||
}
|
||||
|
||||
// FIXME(conman): Customize config depending on build type (app version, host URL).
|
||||
cm := pmapi.New(pmapi.DefaultConfig)
|
||||
|
||||
// FIXME(conman): Should this be a real object, not just created via callbacks?
|
||||
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
|
||||
func() { listener.Emit(events.InternetOffEvent, "") },
|
||||
func() { listener.Emit(events.InternetOnEvent, "") },
|
||||
))
|
||||
|
||||
// FIXME(conman): Implement force upgrade observer.
|
||||
// apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||
|
||||
// FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc.
|
||||
// cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
|
||||
|
||||
jar, err := cookies.NewCookieJar(settingsObj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm := pmapi.NewClientManager(getAPIConfig(configName, listener), userAgent)
|
||||
cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
|
||||
cm.SetCookieJar(jar)
|
||||
|
||||
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
|
||||
@ -375,13 +388,3 @@ func (b *Base) doTeardown() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAPIConfig(configName string, listener listener.Listener) *pmapi.ClientConfig {
|
||||
apiConfig := pmapi.GetAPIConfig(configName, constants.Version)
|
||||
|
||||
apiConfig.ConnectionOffHandler = func() { listener.Emit(events.InternetOffEvent, "") }
|
||||
apiConfig.ConnectionOnHandler = func() { listener.Emit(events.InternetOnEvent, "") }
|
||||
apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||
|
||||
return apiConfig
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
@ -44,7 +45,7 @@ type Bridge struct {
|
||||
|
||||
locations Locator
|
||||
settings SettingsProvider
|
||||
clientManager users.ClientManager
|
||||
clientManager pmapi.Manager
|
||||
updater Updater
|
||||
versioner Versioner
|
||||
}
|
||||
@ -56,7 +57,7 @@ func New(
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler users.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
clientManager users.ClientManager,
|
||||
clientManager pmapi.Manager,
|
||||
credStorer users.CredentialsStorer,
|
||||
updater Updater,
|
||||
versioner Versioner,
|
||||
@ -64,10 +65,11 @@ func New(
|
||||
// Allow DoH before starting the app if the user has previously set this setting.
|
||||
// This allows us to start even if protonmail is blocked.
|
||||
if s.GetBool(settings.AllowProxyKey) {
|
||||
clientManager.AllowProxy()
|
||||
// FIXME(conman): Support enable/disable of DoH.
|
||||
// clientManager.AllowProxy()
|
||||
}
|
||||
|
||||
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, clientManager, eventListener)
|
||||
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
|
||||
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, storeFactory, true)
|
||||
b := &Bridge{
|
||||
Users: u,
|
||||
@ -118,28 +120,15 @@ func (b *Bridge) heartbeat() {
|
||||
|
||||
// ReportBug reports a new bug from the user.
|
||||
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||
c := b.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
title := "[Bridge] Bug"
|
||||
report := pmapi.ReportReq{
|
||||
return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Browser: emailClient,
|
||||
Title: title,
|
||||
Title: "[Bridge] Bug",
|
||||
Description: description,
|
||||
Username: accountName,
|
||||
Email: address,
|
||||
}
|
||||
|
||||
if err := c.Report(report); err != nil {
|
||||
log.Error("Reporting bug failed: ", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Bug successfully reported")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetUpdateChannel returns currently set update channel.
|
||||
|
||||
@ -31,7 +31,6 @@ type storeFactory struct {
|
||||
cache Cacher
|
||||
sentryReporter *sentry.Reporter
|
||||
panicHandler users.PanicHandler
|
||||
clientManager users.ClientManager
|
||||
eventListener listener.Listener
|
||||
storeCache *store.Cache
|
||||
}
|
||||
@ -40,14 +39,12 @@ func newStoreFactory(
|
||||
cache Cacher,
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler users.PanicHandler,
|
||||
clientManager users.ClientManager,
|
||||
eventListener listener.Listener,
|
||||
) *storeFactory {
|
||||
return &storeFactory{
|
||||
cache: cache,
|
||||
sentryReporter: sentryReporter,
|
||||
panicHandler: panicHandler,
|
||||
clientManager: clientManager,
|
||||
eventListener: eventListener,
|
||||
storeCache: store.NewCache(cache.GetIMAPCachePath()),
|
||||
}
|
||||
@ -56,7 +53,7 @@ func newStoreFactory(
|
||||
// New creates new store for given user.
|
||||
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
|
||||
storePath := getUserStorePath(f.cache.GetDBDir(), user.ID())
|
||||
return store.New(f.sentryReporter, f.panicHandler, user, f.clientManager, f.eventListener, storePath, f.storeCache)
|
||||
return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache)
|
||||
}
|
||||
|
||||
// Remove removes all store files for given user.
|
||||
|
||||
@ -18,8 +18,10 @@
|
||||
package cliie
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -73,13 +75,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
return
|
||||
}
|
||||
|
||||
if auth.HasTwoFactor() {
|
||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||
if twoFactor == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = client.Auth2FA(twoFactor, auth)
|
||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
@ -87,7 +89,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
}
|
||||
|
||||
mailboxPassword := password
|
||||
if auth.HasMailboxPassword() {
|
||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||
}
|
||||
if mailboxPassword == "" {
|
||||
|
||||
@ -20,7 +20,6 @@ package cliie
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
@ -71,10 +70,13 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
||||
func (f *frontendCLI) processAPIError(err error) {
|
||||
log.Warn("API error: ", err)
|
||||
switch err {
|
||||
case pmapi.ErrAPINotReachable:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
// FIXME(conman): How to handle various API errors?
|
||||
/*
|
||||
case pmapi.ErrNoConnection:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
*/
|
||||
default:
|
||||
f.Println("Server error:", err.Error())
|
||||
}
|
||||
|
||||
@ -18,11 +18,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -120,13 +122,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
return
|
||||
}
|
||||
|
||||
if auth.HasTwoFactor() {
|
||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||
if twoFactor == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = client.Auth2FA(twoFactor, auth)
|
||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
@ -134,7 +136,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
}
|
||||
|
||||
mailboxPassword := password
|
||||
if auth.HasMailboxPassword() {
|
||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||
}
|
||||
if mailboxPassword == "" {
|
||||
|
||||
@ -20,7 +20,6 @@ package cli
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
@ -71,10 +70,13 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
||||
func (f *frontendCLI) processAPIError(err error) {
|
||||
log.Warn("API error: ", err)
|
||||
switch err {
|
||||
case pmapi.ErrAPINotReachable:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
// FIXME(conman): How to handle various API errors?
|
||||
/*
|
||||
case pmapi.ErrNoConnection:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
*/
|
||||
default:
|
||||
f.Println("Server error:", err.Error())
|
||||
}
|
||||
|
||||
@ -164,7 +164,7 @@ func (a *Accounts) showLoginError(err error, scope string) bool {
|
||||
return false
|
||||
}
|
||||
log.Warnf("%s: %v", scope, err)
|
||||
if err == pmapi.ErrAPINotReachable {
|
||||
if err == pmapi.ErrNoConnection {
|
||||
a.qml.SetConnectionStatus(false)
|
||||
SendNotification(a.qml, TabAccount, a.qml.CanNotReachAPI())
|
||||
a.qml.ProcessFinished()
|
||||
|
||||
@ -130,7 +130,7 @@ func (s *FrontendQt) showLoginError(err error, scope string) bool {
|
||||
return false
|
||||
}
|
||||
log.Warnf("%s: %v", scope, err)
|
||||
if err == pmapi.ErrAPINotReachable {
|
||||
if err == pmapi.ErrNoConnection {
|
||||
s.Qml.SetConnectionStatus(false)
|
||||
s.SendNotification(TabAccount, s.Qml.CanNotReachAPI())
|
||||
s.Qml.ProcessFinished()
|
||||
|
||||
@ -20,6 +20,7 @@ package importexport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/transfer"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users"
|
||||
@ -39,7 +40,7 @@ type ImportExport struct {
|
||||
locations Locator
|
||||
cache Cacher
|
||||
panicHandler users.PanicHandler
|
||||
clientManager users.ClientManager
|
||||
clientManager pmapi.Manager
|
||||
}
|
||||
|
||||
func New(
|
||||
@ -47,7 +48,7 @@ func New(
|
||||
cache Cacher,
|
||||
panicHandler users.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
clientManager users.ClientManager,
|
||||
clientManager pmapi.Manager,
|
||||
credStorer users.CredentialsStorer,
|
||||
) *ImportExport {
|
||||
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, &storeFactory{}, false)
|
||||
@ -64,57 +65,31 @@ func New(
|
||||
|
||||
// ReportBug reports a new bug from the user.
|
||||
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||
c := ie.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
title := "[Import-Export] Bug"
|
||||
report := pmapi.ReportReq{
|
||||
return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Browser: emailClient,
|
||||
Title: title,
|
||||
Title: "[Import-Export] Bug",
|
||||
Description: description,
|
||||
Username: accountName,
|
||||
Email: address,
|
||||
}
|
||||
|
||||
if err := c.Report(report); err != nil {
|
||||
log.Error("Reporting bug failed: ", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Bug successfully reported")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ReportFile submits import report file.
|
||||
func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address string, logdata []byte) error {
|
||||
c := ie.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
title := "[Import-Export] report file"
|
||||
description := "An Import-Export report from the user swam down the river."
|
||||
|
||||
report := pmapi.ReportReq{
|
||||
report := pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Description: description,
|
||||
Title: title,
|
||||
Description: "An Import-Export report from the user swam down the river.",
|
||||
Title: "[Import-Export] report file",
|
||||
Username: accountName,
|
||||
Email: address,
|
||||
}
|
||||
|
||||
report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
|
||||
|
||||
if err := c.Report(report); err != nil {
|
||||
log.Error("Sending report failed: ", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Report successfully sent")
|
||||
|
||||
return nil
|
||||
return ie.clientManager.ReportBug(context.TODO(), report)
|
||||
}
|
||||
|
||||
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
|
||||
@ -187,5 +162,5 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
|
||||
log.WithError(err).Info("Address does not exist, using all addresses")
|
||||
}
|
||||
|
||||
return transfer.NewPMAPIProvider(ie.clientManager, user.ID(), addressID)
|
||||
return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"strings"
|
||||
@ -28,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
type messageGetter interface {
|
||||
GetMessage(string) (*pmapi.Message, error)
|
||||
GetMessage(context.Context, string) (*pmapi.Message, error)
|
||||
}
|
||||
|
||||
type sendRecorderValue struct {
|
||||
@ -126,7 +127,7 @@ func (q *sendRecorder) isSendingOrSent(client messageGetter, hash string) (isSen
|
||||
return true, false
|
||||
}
|
||||
|
||||
message, err := client.GetMessage(value.messageID)
|
||||
message, err := client.GetMessage(context.TODO(), value.messageID)
|
||||
// Message could be deleted or there could be an internet issue or whatever,
|
||||
// so let's assume the message was not sent.
|
||||
if err != nil {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
@ -33,7 +34,7 @@ type testSendRecorderGetMessageMock struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *testSendRecorderGetMessageMock) GetMessage(messageID string) (*pmapi.Message, error) {
|
||||
func (m *testSendRecorderGetMessageMock) GetMessage(_ context.Context, messageID string) (*pmapi.Message, error) {
|
||||
return m.message, m.err
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -122,7 +123,7 @@ func (su *smtpUser) getSendPreferences(
|
||||
}
|
||||
|
||||
func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata, err error) {
|
||||
emails, err := su.client().GetContactEmailByEmail(recipient, 0, 1000)
|
||||
emails, err := su.client().GetContactEmailByEmail(context.TODO(), recipient, 0, 1000)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -134,7 +135,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
|
||||
}
|
||||
|
||||
var contact pmapi.Contact
|
||||
if contact, err = su.client().GetContactByID(email.ContactID); err != nil {
|
||||
if contact, err = su.client().GetContactByID(context.TODO(), email.ContactID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -150,7 +151,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
|
||||
}
|
||||
|
||||
func (su *smtpUser) getAPIKeyData(recipient string) (apiKeys []pmapi.PublicKey, isInternal bool, err error) {
|
||||
return su.client().GetPublicKeysForEmail(recipient)
|
||||
return su.client().GetPublicKeysForEmail(context.TODO(), recipient)
|
||||
}
|
||||
|
||||
// Discard currently processed message.
|
||||
@ -218,7 +219,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
|
||||
|
||||
messageReader = io.TeeReader(messageReader, b)
|
||||
|
||||
mailSettings, err := su.client().GetMailSettings()
|
||||
mailSettings, err := su.client().GetMailSettings(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -339,7 +340,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
|
||||
// can lead to sending the wrong message. Also clients do not necessarily
|
||||
// delete the old draft.
|
||||
if draftID != "" {
|
||||
if err := su.client().DeleteMessages([]string{draftID}); err != nil {
|
||||
if err := su.client().DeleteMessages(context.TODO(), []string{draftID}); err != nil {
|
||||
log.WithError(err).WithField("draftID", draftID).Warn("Original draft cannot be deleted")
|
||||
}
|
||||
}
|
||||
@ -393,7 +394,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
|
||||
return errors.New("error decoding subject message " + message.Header.Get("Subject"))
|
||||
}
|
||||
if !su.continueSendingUnencryptedMail(subject) {
|
||||
if err := su.client().DeleteMessages([]string{message.ID}); err != nil {
|
||||
if err := su.client().DeleteMessages(context.TODO(), []string{message.ID}); err != nil {
|
||||
log.WithError(err).Warn("Failed to delete canceled messages")
|
||||
}
|
||||
return errors.New("sending was canceled by user")
|
||||
@ -422,7 +423,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
|
||||
if su.addressID != "" {
|
||||
filter.AddressID = su.addressID
|
||||
}
|
||||
metadata, _, _ := su.client().ListMessages(filter)
|
||||
metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
|
||||
for _, m := range metadata {
|
||||
if m.IsDraft() {
|
||||
draftID = m.ID
|
||||
@ -442,7 +443,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
|
||||
if su.addressID != "" {
|
||||
filter.AddressID = su.addressID
|
||||
}
|
||||
metadata, _, _ := su.client().ListMessages(filter)
|
||||
metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
|
||||
// There can be two or messages with the same external ID and then we cannot
|
||||
// be sure which message should be parent. Better to not choose any.
|
||||
if len(metadata) == 1 {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
@ -80,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
|
||||
func (loop *eventLoop) setFirstEventID() (err error) {
|
||||
loop.log.Info("Setting first event ID")
|
||||
|
||||
event, err := loop.client().GetEvent("")
|
||||
event, err := loop.client().GetEvent(context.TODO(), "")
|
||||
if err != nil {
|
||||
loop.log.WithError(err).Error("Could not get latest event ID")
|
||||
return
|
||||
@ -221,7 +222,8 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
|
||||
// (e.g. no internet, ulimit reached etc.)
|
||||
defer func() {
|
||||
if errors.Cause(err) == pmapi.ErrAPINotReachable {
|
||||
// FIXME(conman): How to handle errors of different types?
|
||||
if errors.Is(err, pmapi.ErrNoConnection) {
|
||||
l.Warn("Internet unavailable")
|
||||
err = nil
|
||||
}
|
||||
@ -232,18 +234,20 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
err = nil
|
||||
}
|
||||
|
||||
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
||||
l.Warn("Need to upgrade application")
|
||||
err = nil
|
||||
}
|
||||
|
||||
_, errUnauthorized := errors.Cause(err).(*pmapi.ErrUnauthorized)
|
||||
// FIXME(conman): Handle force upgrade.
|
||||
/*
|
||||
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
||||
l.Warn("Need to upgrade application")
|
||||
err = nil
|
||||
}
|
||||
*/
|
||||
|
||||
if err == nil {
|
||||
loop.errCounter = 0
|
||||
}
|
||||
// All errors except Invalid Token (which is not possible to recover from) are ignored.
|
||||
if err != nil && !errUnauthorized && errors.Cause(err) != pmapi.ErrInvalidToken {
|
||||
|
||||
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
|
||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
||||
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
|
||||
loop.errCounter++
|
||||
if loop.errCounter == errMaxSentry {
|
||||
@ -264,7 +268,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
loop.pollCounter++
|
||||
|
||||
var event *pmapi.Event
|
||||
if event, err = loop.client().GetEvent(loop.currentEventID); err != nil {
|
||||
if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil {
|
||||
return false, errors.Wrap(err, "failed to get event")
|
||||
}
|
||||
|
||||
@ -461,12 +465,16 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
|
||||
|
||||
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
|
||||
|
||||
if msg, err = loop.client().GetMessage(message.ID); err != nil {
|
||||
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
|
||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||
err = nil
|
||||
continue
|
||||
}
|
||||
if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil {
|
||||
// FIXME(conman): How to handle error of this particular type?
|
||||
|
||||
/*
|
||||
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
|
||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||
err = nil
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
return errors.Wrap(err, "failed to get message from API for updating")
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/mail"
|
||||
"testing"
|
||||
"time"
|
||||
@ -39,15 +40,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
|
||||
// Doesn't matter which IDs are used.
|
||||
// This test is trying to see whether event loop will immediately process
|
||||
// next event if there is `More` of them.
|
||||
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
|
||||
EventID: "event50",
|
||||
More: 1,
|
||||
}, nil),
|
||||
m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
|
||||
EventID: "event70",
|
||||
More: 0,
|
||||
}, nil),
|
||||
m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
|
||||
EventID: "event71",
|
||||
More: 0,
|
||||
}, nil),
|
||||
@ -165,7 +166,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
|
||||
|
||||
func testEvent(t *testing.T, m *mocksForStore, event *pmapi.Event) {
|
||||
eventReceived := make(chan struct{})
|
||||
m.client.EXPECT().GetEvent("latestEventID").DoAndReturn(func(eventID string) (*pmapi.Event, error) {
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").DoAndReturn(func(_ context.Context, eventID string) (*pmapi.Event, error) {
|
||||
defer close(eventReceived)
|
||||
return event, nil
|
||||
})
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -41,7 +43,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
|
||||
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
|
||||
// wrapping it.
|
||||
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
|
||||
msg, err := storeMailbox.client().GetMessage(apiID)
|
||||
msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -58,15 +60,17 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
|
||||
}
|
||||
|
||||
importReqs := &pmapi.ImportMsgReq{
|
||||
AddressID: msg.AddressID,
|
||||
Body: body,
|
||||
Unread: msg.Unread,
|
||||
Flags: msg.Flags,
|
||||
Time: msg.Time,
|
||||
LabelIDs: labelIDs,
|
||||
Metadata: &pmapi.ImportMetadata{
|
||||
AddressID: msg.AddressID,
|
||||
Unread: msg.Unread,
|
||||
Flags: msg.Flags,
|
||||
Time: msg.Time,
|
||||
LabelIDs: labelIDs,
|
||||
},
|
||||
Message: body,
|
||||
}
|
||||
|
||||
res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs})
|
||||
res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -95,7 +99,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
|
||||
return ErrAllMailOpNotAllowed
|
||||
}
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// UnlabelMessages removes the label by calling an API.
|
||||
@ -108,7 +112,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
|
||||
return ErrAllMailOpNotAllowed
|
||||
}
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// MarkMessagesRead marks the message read by calling an API.
|
||||
@ -135,7 +139,7 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return storeMailbox.client().MarkMessagesRead(ids)
|
||||
return storeMailbox.client().MarkMessagesRead(context.TODO(), ids)
|
||||
}
|
||||
|
||||
// MarkMessagesUnread marks the message unread by calling an API.
|
||||
@ -147,7 +151,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unread")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().MarkMessagesUnread(apiIDs)
|
||||
return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs)
|
||||
}
|
||||
|
||||
// MarkMessagesStarred adds the Starred label by calling an API.
|
||||
@ -160,7 +164,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as starred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesUnstarred removes the Starred label by calling an API.
|
||||
@ -173,7 +177,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unstarred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
|
||||
@ -257,11 +261,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
|
||||
}
|
||||
case pmapi.DraftLabel:
|
||||
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
|
||||
if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -299,13 +303,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
|
||||
}
|
||||
}
|
||||
if len(messageIDsToUnlabel) > 0 {
|
||||
if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
l.WithError(err).Warning("Cannot unlabel before deleting")
|
||||
}
|
||||
}
|
||||
if len(messageIDsToDelete) > 0 {
|
||||
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
|
||||
if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser,ChangeNotifier)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -46,43 +46,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockBridgeUser is a mock of BridgeUser interface
|
||||
type MockBridgeUser struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -145,6 +108,20 @@ func (mr *MockBridgeUserMockRecorder) GetAddressID(arg0 interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddressID", reflect.TypeOf((*MockBridgeUser)(nil).GetAddressID), arg0)
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockBridgeUser) GetClient() pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient")
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockBridgeUserMockRecorder) GetClient() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockBridgeUser)(nil).GetClient))
|
||||
}
|
||||
|
||||
// GetPrimaryAddress mocks base method
|
||||
func (m *MockBridgeUser) GetPrimaryAddress() string {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@ -106,7 +107,6 @@ type Store struct {
|
||||
panicHandler PanicHandler
|
||||
eventLoop *eventLoop
|
||||
user BridgeUser
|
||||
clientManager ClientManager
|
||||
|
||||
log *logrus.Entry
|
||||
|
||||
@ -127,13 +127,12 @@ func New( // nolint[funlen]
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler PanicHandler,
|
||||
user BridgeUser,
|
||||
clientManager ClientManager,
|
||||
events listener.Listener,
|
||||
path string,
|
||||
cache *Cache,
|
||||
) (store *Store, err error) {
|
||||
if user == nil || clientManager == nil || events == nil || cache == nil {
|
||||
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache)
|
||||
if user == nil || events == nil || cache == nil {
|
||||
return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache)
|
||||
}
|
||||
|
||||
l := log.WithField("user", user.ID())
|
||||
@ -156,7 +155,6 @@ func New( // nolint[funlen]
|
||||
store = &Store{
|
||||
sentryReporter: sentryReporter,
|
||||
panicHandler: panicHandler,
|
||||
clientManager: clientManager,
|
||||
user: user,
|
||||
cache: cache,
|
||||
filePath: path,
|
||||
@ -274,13 +272,13 @@ func (store *Store) init(firstInit bool) (err error) {
|
||||
}
|
||||
|
||||
func (store *Store) client() pmapi.Client {
|
||||
return store.clientManager.GetClient(store.UserID())
|
||||
return store.user.GetClient()
|
||||
}
|
||||
|
||||
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
|
||||
// the API is unavailable for whatever reason it tries to fetch the labels locally.
|
||||
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
|
||||
if labels, err = store.client().ListLabels(); err != nil {
|
||||
if labels, err = store.client().ListLabels(context.TODO()); err != nil {
|
||||
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
|
||||
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
|
||||
store.log.WithError(err).Error("Cannot list local labels")
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@ -133,7 +134,6 @@ type mocksForStore struct {
|
||||
events *storemocks.MockListener
|
||||
user *storemocks.MockBridgeUser
|
||||
client *pmapimocks.MockClient
|
||||
clientManager *storemocks.MockClientManager
|
||||
panicHandler *storemocks.MockPanicHandler
|
||||
changeNotifier *storemocks.MockChangeNotifier
|
||||
store *Store
|
||||
@ -150,7 +150,6 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
|
||||
events: storemocks.NewMockListener(ctrl),
|
||||
user: storemocks.NewMockBridgeUser(ctrl),
|
||||
client: pmapimocks.NewMockClient(ctrl),
|
||||
clientManager: storemocks.NewMockClientManager(ctrl),
|
||||
panicHandler: storemocks.NewMockPanicHandler(ctrl),
|
||||
changeNotifier: storemocks.NewMockChangeNotifier(ctrl),
|
||||
}
|
||||
@ -182,30 +181,30 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
|
||||
mocks.user.EXPECT().IsConnected().Return(true)
|
||||
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
|
||||
|
||||
mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client)
|
||||
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
|
||||
|
||||
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
|
||||
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
|
||||
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
|
||||
})
|
||||
mocks.client.EXPECT().ListLabels().AnyTimes()
|
||||
mocks.client.EXPECT().CountMessages("")
|
||||
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
|
||||
mocks.client.EXPECT().CountMessages(gomock.Any(), "")
|
||||
|
||||
// Call to get latest event ID and then to process first event.
|
||||
eventAfterSyncRequested := make(chan struct{})
|
||||
mocks.client.EXPECT().GetEvent("").Return(&pmapi.Event{
|
||||
mocks.client.EXPECT().GetEvent(gomock.Any(), "").Return(&pmapi.Event{
|
||||
EventID: "firstEventID",
|
||||
}, nil)
|
||||
mocks.client.EXPECT().GetEvent("firstEventID").DoAndReturn(func(_ string) (*pmapi.Event, error) {
|
||||
mocks.client.EXPECT().GetEvent(gomock.Any(), "firstEventID").DoAndReturn(func(_ context.Context, _ string) (*pmapi.Event, error) {
|
||||
close(eventAfterSyncRequested)
|
||||
return &pmapi.Event{
|
||||
EventID: "latestEventID",
|
||||
}, nil
|
||||
})
|
||||
|
||||
mocks.client.EXPECT().ListMessages(gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes()
|
||||
mocks.client.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes()
|
||||
for _, msg := range msgs {
|
||||
mocks.client.EXPECT().GetMessage(msg.ID).Return(msg, nil).AnyTimes()
|
||||
mocks.client.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
}
|
||||
|
||||
var err error
|
||||
@ -213,7 +212,6 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
|
||||
nil, // Sentry reporter is not used under unit tests.
|
||||
mocks.panicHandler,
|
||||
mocks.user,
|
||||
mocks.clientManager,
|
||||
mocks.events,
|
||||
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
|
||||
mocks.cache,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
@ -39,10 +40,10 @@ type storeSynchronizer interface {
|
||||
}
|
||||
|
||||
type messageLister interface {
|
||||
ListMessages(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
|
||||
ListMessages(context.Context, *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
|
||||
}
|
||||
|
||||
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func() messageLister, syncState *syncState) error {
|
||||
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api messageLister, syncState *syncState) error {
|
||||
labelID := pmapi.AllMailLabel
|
||||
|
||||
// When the full sync starts (i.e. is not already in progress), we need to load
|
||||
@ -53,7 +54,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
|
||||
return errors.Wrap(err, "failed to load message IDs")
|
||||
}
|
||||
|
||||
if err := findIDRanges(labelID, api(), syncState); err != nil {
|
||||
if err := findIDRanges(labelID, api, syncState); err != nil {
|
||||
return errors.Wrap(err, "failed to load IDs ranges")
|
||||
}
|
||||
syncState.save()
|
||||
@ -71,7 +72,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
|
||||
defer panicHandler.HandlePanic()
|
||||
defer wg.Done()
|
||||
|
||||
err := syncBatch(labelID, store, api(), syncState, idRange, &shouldStop)
|
||||
err := syncBatch(labelID, store, api, syncState, idRange, &shouldStop)
|
||||
if err != nil {
|
||||
shouldStop = 1
|
||||
resultError = errors.Wrap(err, "failed to sync group")
|
||||
@ -147,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
|
||||
Limit: 1,
|
||||
}
|
||||
// If the page does not exist, an empty page instead of an error is returned.
|
||||
messages, total, err := api.ListMessages(filter)
|
||||
messages, total, err := api.ListMessages(context.TODO(), filter)
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
@ -189,7 +190,7 @@ func syncBatch( //nolint[funlen]
|
||||
|
||||
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
|
||||
|
||||
messages, _, err := api.ListMessages(filter)
|
||||
messages, _, err := api.ListMessages(context.TODO(), filter)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
@ -34,7 +35,7 @@ type mockLister struct {
|
||||
messageIDs []string
|
||||
}
|
||||
|
||||
func (m *mockLister) ListMessages(filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
|
||||
func (m *mockLister) ListMessages(_ context.Context, filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
|
||||
if m.err != nil {
|
||||
return nil, 0, m.err
|
||||
}
|
||||
@ -197,7 +198,7 @@ func TestSyncAllMail(t *testing.T) { //nolint[funlen]
|
||||
|
||||
syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Check all messages were created or updated.
|
||||
@ -245,7 +246,7 @@ func TestSyncAllMail_FailedListing(t *testing.T) {
|
||||
}
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.EqualError(t, err, "failed to sync group: failed to list messages: error")
|
||||
}
|
||||
|
||||
@ -264,7 +265,7 @@ func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) {
|
||||
}
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.EqualError(t, err, "failed to sync group: failed to create or update messages: error")
|
||||
}
|
||||
|
||||
|
||||
@ -23,10 +23,6 @@ type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
}
|
||||
|
||||
// BridgeUser is subset of bridge.User for use by the Store.
|
||||
type BridgeUser interface {
|
||||
ID() string
|
||||
@ -35,6 +31,7 @@ type BridgeUser interface {
|
||||
IsCombinedAddressMode() bool
|
||||
GetPrimaryAddress() string
|
||||
GetStoreAddresses() []string
|
||||
GetClient() pmapi.Client
|
||||
UpdateUser() error
|
||||
CloseAllConnections()
|
||||
CloseConnection(string)
|
||||
|
||||
@ -17,6 +17,8 @@
|
||||
|
||||
package store
|
||||
|
||||
import "context"
|
||||
|
||||
// UserID returns user ID.
|
||||
func (store *Store) UserID() string {
|
||||
return store.user.ID()
|
||||
@ -24,7 +26,7 @@ func (store *Store) UserID() string {
|
||||
|
||||
// GetSpace returns used and total space in bytes.
|
||||
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
apiUser, err := store.client().CurrentUser()
|
||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
@ -33,7 +35,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
|
||||
// GetMaxUpload returns max size of message + all attachments in bytes.
|
||||
func (store *Store) GetMaxUpload() (int64, error) {
|
||||
apiUser, err := store.client().CurrentUser()
|
||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -55,7 +56,7 @@ func (store *Store) createMailbox(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := store.client().CreateLabel(&pmapi.Label{
|
||||
_, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{
|
||||
Name: name,
|
||||
Color: color,
|
||||
Exclusive: exclusive,
|
||||
@ -125,7 +126,7 @@ func (store *Store) leastUsedColor() string {
|
||||
func (store *Store) updateMailbox(labelID, newName, color string) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
_, err := store.client().UpdateLabel(&pmapi.Label{
|
||||
_, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{
|
||||
ID: labelID,
|
||||
Name: newName,
|
||||
Color: color,
|
||||
@ -142,15 +143,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
|
||||
var err error
|
||||
switch labelID {
|
||||
case pmapi.SpamLabel:
|
||||
err = store.client().EmptyFolder(pmapi.SpamLabel, addressID)
|
||||
err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID)
|
||||
case pmapi.TrashLabel:
|
||||
err = store.client().EmptyFolder(pmapi.TrashLabel, addressID)
|
||||
err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID)
|
||||
default:
|
||||
err = fmt.Errorf("cannot empty mailbox %v", labelID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return store.client().DeleteLabel(labelID)
|
||||
return store.client().DeleteLabel(context.TODO(), labelID)
|
||||
}
|
||||
|
||||
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
|
||||
@ -165,7 +166,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
labels, err := store.client().ListLabels()
|
||||
labels, err := store.client().ListLabels(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -57,7 +58,7 @@ func (store *Store) CreateDraft(
|
||||
}
|
||||
|
||||
draftAction := store.getDraftAction(message)
|
||||
draft, err := store.client().CreateDraft(message, parentID, draftAction)
|
||||
draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create draft")
|
||||
}
|
||||
@ -69,7 +70,7 @@ func (store *Store) CreateDraft(
|
||||
for _, att := range attachments {
|
||||
att.attachment.MessageID = draft.ID
|
||||
|
||||
createdAttachment, err := store.client().CreateAttachment(att.attachment, att.encReader, att.sigReader)
|
||||
createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create attachment")
|
||||
}
|
||||
@ -183,7 +184,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
|
||||
// SendMessage sends the message.
|
||||
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
_, _, err := store.client().SendMessage(messageID, req)
|
||||
_, _, err := store.client().SendMessage(context.TODO(), messageID, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -127,12 +127,12 @@ func TestDeleteMessage(t *testing.T) {
|
||||
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
|
||||
}
|
||||
|
||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread int, labelIDs []string) { //nolint[unparam]
|
||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam]
|
||||
msg := getTestMessage(id, subject, sender, unread, labelIDs)
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
|
||||
}
|
||||
|
||||
func getTestMessage(id, subject, sender string, unread int, labelIDs []string) *pmapi.Message {
|
||||
func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message {
|
||||
address := &mail.Address{Address: sender}
|
||||
return &pmapi.Message{
|
||||
ID: id,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@ -34,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
|
||||
|
||||
// updateCountsFromServer will download and set the counts.
|
||||
func (store *Store) updateCountsFromServer() error {
|
||||
counts, err := store.client().CountMessages("")
|
||||
counts, err := store.client().CountMessages(context.TODO(), "")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot update counts from server")
|
||||
}
|
||||
@ -152,7 +153,7 @@ func (store *Store) triggerSync() {
|
||||
|
||||
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
|
||||
|
||||
err := syncAllMail(store.panicHandler, store, func() messageLister { return store.client() }, syncState)
|
||||
err := syncAllMail(store.panicHandler, store, store.client(), syncState)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Store sync failed")
|
||||
store.syncCooldown.increaseWaitTime()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,ClientManager,IMAPClientProvider)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,IMAPClientProvider)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -7,7 +7,6 @@ package mocks
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
imap "github.com/emersion/go-imap"
|
||||
sasl "github.com/emersion/go-sasl"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@ -48,57 +47,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CheckConnection mocks base method
|
||||
func (m *MockClientManager) CheckConnection() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CheckConnection")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CheckConnection indicates an expected call of CheckConnection
|
||||
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockIMAPClientProvider is a mock of IMAPClientProvider interface
|
||||
type MockIMAPClientProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
|
||||
imapID "github.com/ProtonMail/go-imap-id"
|
||||
"github.com/ProtonMail/proton-bridge/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/emersion/go-imap"
|
||||
imapClient "github.com/emersion/go-imap/client"
|
||||
"github.com/emersion/go-sasl"
|
||||
@ -118,15 +117,19 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
|
||||
return previousErr
|
||||
}
|
||||
|
||||
err := pmapi.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
// FIXME(conman): This should register as connection observer.
|
||||
|
||||
err = p.reauth()
|
||||
/*
|
||||
err := pmapi.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
err := p.reauth()
|
||||
log.WithError(err).Debug("Reauth")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
@ -34,25 +35,27 @@ const (
|
||||
|
||||
// PMAPIProvider implements import and export to/from ProtonMail server.
|
||||
type PMAPIProvider struct {
|
||||
clientManager ClientManager
|
||||
userID string
|
||||
addressID string
|
||||
keyRing *crypto.KeyRing
|
||||
builder *message.Builder
|
||||
client pmapi.Client
|
||||
userID string
|
||||
addressID string
|
||||
keyRing *crypto.KeyRing
|
||||
builder *message.Builder
|
||||
|
||||
nextImportRequests map[string]*pmapi.ImportMsgReq // Key is msg transfer ID.
|
||||
nextImportRequestsSize int
|
||||
|
||||
timeIt *timeIt
|
||||
|
||||
connection bool
|
||||
}
|
||||
|
||||
// NewPMAPIProvider returns new PMAPIProvider.
|
||||
func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*PMAPIProvider, error) {
|
||||
func NewPMAPIProvider(client pmapi.Client, userID, addressID string) (*PMAPIProvider, error) {
|
||||
provider := &PMAPIProvider{
|
||||
clientManager: clientManager,
|
||||
userID: userID,
|
||||
addressID: addressID,
|
||||
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
|
||||
client: client,
|
||||
userID: userID,
|
||||
addressID: addressID,
|
||||
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
|
||||
|
||||
nextImportRequests: map[string]*pmapi.ImportMsgReq{},
|
||||
nextImportRequestsSize: 0,
|
||||
@ -61,7 +64,7 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
|
||||
}
|
||||
|
||||
if addressID != "" {
|
||||
keyRing, err := clientManager.GetClient(userID).KeyRingForAddressID(addressID)
|
||||
keyRing, err := client.KeyRingForAddressID(addressID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get key ring")
|
||||
}
|
||||
@ -71,10 +74,6 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (p *PMAPIProvider) client() pmapi.Client {
|
||||
return p.clientManager.GetClient(p.userID)
|
||||
}
|
||||
|
||||
// ID returns identifier of current setup of PMAPI provider.
|
||||
// Identification is unique per user.
|
||||
func (p *PMAPIProvider) ID() string {
|
||||
@ -83,7 +82,7 @@ func (p *PMAPIProvider) ID() string {
|
||||
|
||||
// Mailboxes returns all available labels in ProtonMail account.
|
||||
func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox, error) {
|
||||
labels, err := p.client().ListLabels()
|
||||
labels, err := p.client.ListLabels(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -92,7 +91,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
|
||||
|
||||
emptyLabelsMap := map[string]bool{}
|
||||
if !includeEmpty {
|
||||
messagesCounts, err := p.client().CountMessages(p.addressID)
|
||||
messagesCounts, err := p.client.CountMessages(context.Background(), p.addressID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -120,7 +119,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
|
||||
ID: label.ID,
|
||||
Name: label.Name,
|
||||
Color: label.Color,
|
||||
IsExclusive: label.Exclusive == 1,
|
||||
IsExclusive: bool(label.Exclusive),
|
||||
})
|
||||
}
|
||||
return mailboxes, nil
|
||||
@ -160,10 +159,10 @@ func (l byFoldersLabels) Swap(i, j int) {
|
||||
|
||||
// Less sorts first folders, then labels, by user order.
|
||||
func (l byFoldersLabels) Less(i, j int) bool {
|
||||
if l[i].Exclusive == 1 && l[j].Exclusive == 0 {
|
||||
if l[i].Exclusive && !l[j].Exclusive {
|
||||
return true
|
||||
}
|
||||
if l[i].Exclusive == 0 && l[j].Exclusive == 1 {
|
||||
if !l[i].Exclusive && l[j].Exclusive {
|
||||
return false
|
||||
}
|
||||
return l[i].Order < l[j].Order
|
||||
|
||||
@ -157,7 +157,7 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
|
||||
|
||||
body, err := p.builder.NewJobWithOptions(
|
||||
context.Background(),
|
||||
p.client(),
|
||||
p.client,
|
||||
msg.ID,
|
||||
message.JobOptions{IgnoreDecryptionErrors: !skipEncryptedMessages},
|
||||
).GetResult()
|
||||
@ -169,14 +169,9 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
|
||||
return Message{Body: []byte(msg.Body)}, err
|
||||
}
|
||||
|
||||
unread := false
|
||||
if msg.Unread == 1 {
|
||||
unread = true
|
||||
}
|
||||
|
||||
return Message{
|
||||
ID: msgID,
|
||||
Unread: unread,
|
||||
Unread: bool(msg.Unread),
|
||||
Body: body,
|
||||
Sources: []Mailbox{rule.SourceMailbox},
|
||||
Targets: rule.TargetMailboxes,
|
||||
|
||||
@ -19,6 +19,7 @@ package transfer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -56,7 +57,7 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
|
||||
exclusive = 1
|
||||
}
|
||||
|
||||
label, err := p.client().CreateLabel(&pmapi.Label{
|
||||
label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{
|
||||
Name: mailbox.Name,
|
||||
Color: mailbox.Color,
|
||||
Exclusive: exclusive,
|
||||
@ -194,7 +195,7 @@ func (p *PMAPIProvider) transferMessage(rules transferRules, progress *Progress,
|
||||
return
|
||||
}
|
||||
|
||||
importMsgReqSize := len(importMsgReq.Body)
|
||||
importMsgReqSize := len(importMsgReq.Message)
|
||||
if p.nextImportRequestsSize+importMsgReqSize > pmapiImportBatchMaxSize || len(p.nextImportRequests) == pmapiImportBatchMaxItems {
|
||||
preparedImportRequestsCh <- p.nextImportRequests
|
||||
p.nextImportRequests = map[string]*pmapi.ImportMsgReq{}
|
||||
@ -226,9 +227,12 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
||||
}
|
||||
}
|
||||
|
||||
unread := 0
|
||||
var unread pmapi.Boolean
|
||||
|
||||
if msg.Unread {
|
||||
unread = 1
|
||||
unread = pmapi.True
|
||||
} else {
|
||||
unread = pmapi.False
|
||||
}
|
||||
|
||||
labelIDs := []string{}
|
||||
@ -243,12 +247,14 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
||||
}
|
||||
|
||||
return &pmapi.ImportMsgReq{
|
||||
AddressID: p.addressID,
|
||||
Body: body,
|
||||
Unread: unread,
|
||||
Time: message.Time,
|
||||
Flags: computeMessageFlags(message.Header),
|
||||
LabelIDs: labelIDs,
|
||||
Metadata: &pmapi.ImportMetadata{
|
||||
AddressID: p.addressID,
|
||||
Unread: unread,
|
||||
Time: message.Time,
|
||||
Flags: computeMessageFlags(message.Header),
|
||||
LabelIDs: labelIDs,
|
||||
},
|
||||
Message: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -293,7 +299,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
|
||||
}
|
||||
|
||||
importMsgIDs := []string{}
|
||||
importMsgRequests := []*pmapi.ImportMsgReq{}
|
||||
importMsgRequests := pmapi.ImportMsgReqs{}
|
||||
for msgID, req := range importRequests {
|
||||
importMsgIDs = append(importMsgIDs, msgID)
|
||||
importMsgRequests = append(importMsgRequests, req)
|
||||
@ -327,7 +333,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
|
||||
|
||||
func (p *PMAPIProvider) importMessage(msgSourceID string, progress *Progress, req *pmapi.ImportMsgReq) (importedID string, importedErr error) {
|
||||
progress.callWrap(func() error {
|
||||
results, err := p.importRequest(msgSourceID, []*pmapi.ImportMsgReq{req})
|
||||
results, err := p.importRequest(msgSourceID, pmapi.ImportMsgReqs{req})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to import messages")
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package transfer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@ -33,7 +34,7 @@ func TestPMAPIProviderMailboxes(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForExport(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@ -78,7 +79,7 @@ func TestPMAPIProviderTransferTo(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForExport(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -96,7 +97,7 @@ func TestPMAPIProviderTransferFrom(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForImport(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -114,7 +115,7 @@ func TestPMAPIProviderTransferFromDraft(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForImportDraft(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -133,9 +134,9 @@ func TestPMAPIProviderTransferFromTo(t *testing.T) {
|
||||
setupPMAPIClientExpectationForExport(&m)
|
||||
setupPMAPIClientExpectationForImport(&m)
|
||||
|
||||
source, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
source, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
target, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
target, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -151,22 +152,22 @@ func setupPMAPIRules(rules transferRules) {
|
||||
|
||||
func setupPMAPIClientExpectationForExport(m *mocks) {
|
||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
|
||||
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2},
|
||||
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1},
|
||||
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1},
|
||||
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2},
|
||||
}, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any()).Return([]*pmapi.MessagesCount{
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
|
||||
{LabelID: "label1", Total: 10},
|
||||
{LabelID: "label2", Total: 0},
|
||||
{LabelID: "folder1", Total: 20},
|
||||
}, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{
|
||||
{ID: "msg1"},
|
||||
{ID: "msg2"},
|
||||
}, 2, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetMessage(gomock.Any()).DoAndReturn(func(msgID string) (*pmapi.Message, error) {
|
||||
m.pmapiClient.EXPECT().GetMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msgID string) (*pmapi.Message, error) {
|
||||
return &pmapi.Message{
|
||||
ID: msgID,
|
||||
Body: string(getTestMsgBody(msgID)),
|
||||
@ -177,11 +178,11 @@ func setupPMAPIClientExpectationForExport(m *mocks) {
|
||||
|
||||
func setupPMAPIClientExpectationForImport(m *mocks) {
|
||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().Import(gomock.Any()).DoAndReturn(func(requests []*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error) {
|
||||
m.pmapiClient.EXPECT().Import(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, requests pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
|
||||
results := []*pmapi.ImportMsgRes{}
|
||||
for _, request := range requests {
|
||||
for _, msgID := range []string{"msg1", "msg2"} {
|
||||
if bytes.Contains(request.Body, []byte(msgID)) {
|
||||
if bytes.Contains(request.Message, []byte(msgID)) {
|
||||
results = append(results, &pmapi.ImportMsgRes{MessageID: msgID, Error: nil})
|
||||
}
|
||||
}
|
||||
@ -192,7 +193,7 @@ func setupPMAPIClientExpectationForImport(m *mocks) {
|
||||
|
||||
func setupPMAPIClientExpectationForImportDraft(m *mocks) {
|
||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
|
||||
m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
|
||||
r.Equal(m.t, msg.Subject, "draft1")
|
||||
msg.ID = "draft1"
|
||||
return msg, nil
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
@ -57,13 +58,18 @@ func (p *PMAPIProvider) tryReconnect() error {
|
||||
return previousErr
|
||||
}
|
||||
|
||||
err := p.clientManager.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(pmapiReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
// FIXME(conman): This should register as a connection observer somehow...
|
||||
// Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection?
|
||||
|
||||
/*
|
||||
err := p.clientManager.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(pmapiReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
break
|
||||
}
|
||||
@ -77,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
|
||||
p.timeIt.start("listing", key)
|
||||
defer p.timeIt.stop("listing", key)
|
||||
|
||||
messages, count, err = p.client().ListMessages(filter)
|
||||
messages, count, err = p.client.ListMessages(context.TODO(), filter)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -88,18 +94,18 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
|
||||
p.timeIt.start("download", msgID)
|
||||
defer p.timeIt.stop("download", msgID)
|
||||
|
||||
message, err = p.client().GetMessage(msgID)
|
||||
message, err = p.client.GetMessage(context.TODO(), msgID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (p *PMAPIProvider) importRequest(msgSourceID string, req []*pmapi.ImportMsgReq) (res []*pmapi.ImportMsgRes, err error) {
|
||||
func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReqs) (res []*pmapi.ImportMsgRes, err error) {
|
||||
err = p.ensureConnection(func() error {
|
||||
p.timeIt.start("upload", msgSourceID)
|
||||
defer p.timeIt.stop("upload", msgSourceID)
|
||||
|
||||
res, err = p.client().Import(req)
|
||||
res, err = p.client.Import(context.TODO(), req)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -110,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
|
||||
p.timeIt.start("upload", msgSourceID)
|
||||
defer p.timeIt.stop("upload", msgSourceID)
|
||||
|
||||
draft, err = p.client().CreateDraft(message, parent, action)
|
||||
draft, err = p.client.CreateDraft(context.TODO(), message, parent, action)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -123,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
|
||||
p.timeIt.start("upload", key)
|
||||
defer p.timeIt.stop("upload", key)
|
||||
|
||||
created, err = p.client().CreateAttachment(att, r, sig)
|
||||
created, err = p.client.CreateAttachment(context.TODO(), att, r, sig)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
||||
@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
transfermocks "github.com/ProtonMail/proton-bridge/internal/transfer/mocks"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
@ -33,10 +32,8 @@ type mocks struct {
|
||||
|
||||
ctrl *gomock.Controller
|
||||
panicHandler *transfermocks.MockPanicHandler
|
||||
clientManager *transfermocks.MockClientManager
|
||||
imapClientProvider *transfermocks.MockIMAPClientProvider
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
pmapiConfig *pmapi.ClientConfig
|
||||
|
||||
keyring *crypto.KeyRing
|
||||
}
|
||||
@ -49,15 +46,11 @@ func initMocks(t *testing.T) mocks {
|
||||
|
||||
ctrl: mockCtrl,
|
||||
panicHandler: transfermocks.NewMockPanicHandler(mockCtrl),
|
||||
clientManager: transfermocks.NewMockClientManager(mockCtrl),
|
||||
imapClientProvider: transfermocks.NewMockIMAPClientProvider(mockCtrl),
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
pmapiConfig: &pmapi.ClientConfig{},
|
||||
keyring: newTestKeyring(),
|
||||
}
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).AnyTimes()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
|
||||
@ -17,10 +17,6 @@
|
||||
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
@ -32,8 +28,3 @@ type MetricsManager interface {
|
||||
Cancel()
|
||||
Fail()
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
CheckConnection() error
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
@ -31,10 +32,6 @@ import (
|
||||
|
||||
var ErrManualUpdateRequired = errors.New("manual update is required")
|
||||
|
||||
type ClientProvider interface {
|
||||
GetAnonymousClient() pmapi.Client
|
||||
}
|
||||
|
||||
type Installer interface {
|
||||
InstallUpdate(*semver.Version, io.Reader) error
|
||||
}
|
||||
@ -46,7 +43,7 @@ type Settings interface {
|
||||
}
|
||||
|
||||
type Updater struct {
|
||||
cm ClientProvider
|
||||
cm pmapi.Manager
|
||||
installer Installer
|
||||
settings Settings
|
||||
kr *crypto.KeyRing
|
||||
@ -59,7 +56,7 @@ type Updater struct {
|
||||
}
|
||||
|
||||
func New(
|
||||
cm ClientProvider,
|
||||
cm pmapi.Manager,
|
||||
installer Installer,
|
||||
s Settings,
|
||||
kr *crypto.KeyRing,
|
||||
@ -87,13 +84,10 @@ func New(
|
||||
func (u *Updater) Check() (VersionInfo, error) {
|
||||
logrus.Info("Checking for updates")
|
||||
|
||||
client := u.cm.GetAnonymousClient()
|
||||
defer client.Logout()
|
||||
|
||||
r, err := client.DownloadAndVerify(
|
||||
b, err := u.cm.DownloadAndVerify(
|
||||
u.kr,
|
||||
u.getVersionFileURL(),
|
||||
u.getVersionFileURL()+".sig",
|
||||
u.kr,
|
||||
)
|
||||
if err != nil {
|
||||
return VersionInfo{}, err
|
||||
@ -101,7 +95,7 @@ func (u *Updater) Check() (VersionInfo, error) {
|
||||
|
||||
var versionMap VersionMap
|
||||
|
||||
if err := json.NewDecoder(r).Decode(&versionMap); err != nil {
|
||||
if err := json.Unmarshal(b, &versionMap); err != nil {
|
||||
return VersionInfo{}, err
|
||||
}
|
||||
|
||||
@ -141,15 +135,12 @@ func (u *Updater) InstallUpdate(update VersionInfo) error {
|
||||
return u.locker.doOnce(func() error {
|
||||
logrus.WithField("package", update.Package).Info("Installing update package")
|
||||
|
||||
client := u.cm.GetAnonymousClient()
|
||||
defer client.Logout()
|
||||
|
||||
r, err := client.DownloadAndVerify(update.Package, update.Package+".sig", u.kr)
|
||||
b, err := u.cm.DownloadAndVerify(u.kr, update.Package, update.Package+".sig")
|
||||
if err != nil {
|
||||
return errors.Wrap(ErrDownloadVerify, err.Error())
|
||||
}
|
||||
|
||||
if err := u.installer.InstallUpdate(update.Version, r); err != nil {
|
||||
if err := u.installer.InstallUpdate(update.Version, bytes.NewReader(b)); err != nil {
|
||||
return errors.Wrap(ErrInstall, err.Error())
|
||||
}
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
@ -29,7 +28,6 @@ import (
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -40,9 +38,9 @@ func TestCheck(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.1.0", false)
|
||||
updater := newTestUpdater(cm, "1.1.0", false)
|
||||
|
||||
versionMap := VersionMap{
|
||||
"stable": VersionInfo{
|
||||
@ -53,13 +51,11 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
updater.getVersionFileURL(),
|
||||
updater.getVersionFileURL()+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return(mustMarshal(t, versionMap), nil)
|
||||
|
||||
version, err := updater.Check()
|
||||
|
||||
@ -71,9 +67,9 @@ func TestCheckEarlyAccess(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.1.0", true)
|
||||
updater := newTestUpdater(cm, "1.1.0", true)
|
||||
|
||||
versionMap := VersionMap{
|
||||
"stable": VersionInfo{
|
||||
@ -90,13 +86,11 @@ func TestCheckEarlyAccess(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
updater.getVersionFileURL(),
|
||||
updater.getVersionFileURL()+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return(mustMarshal(t, versionMap), nil)
|
||||
|
||||
version, err := updater.Check()
|
||||
|
||||
@ -108,18 +102,16 @@ func TestCheckBadSignature(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.2.0", false)
|
||||
updater := newTestUpdater(cm, "1.2.0", false)
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
updater.getVersionFileURL(),
|
||||
updater.getVersionFileURL()+".sig",
|
||||
gomock.Any(),
|
||||
).Return(nil, errors.New("bad signature"))
|
||||
|
||||
client.EXPECT().Logout()
|
||||
|
||||
_, err := updater.Check()
|
||||
|
||||
assert.Error(t, err)
|
||||
@ -129,9 +121,9 @@ func TestIsUpdateApplicable(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
versionOld := VersionInfo{
|
||||
Version: semver.MustParse("1.3.0"),
|
||||
@ -165,9 +157,9 @@ func TestCanInstall(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
versionManual := VersionInfo{
|
||||
Version: semver.MustParse("1.5.0"),
|
||||
@ -192,9 +184,9 @@ func TestInstallUpdate(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
latestVersion := VersionInfo{
|
||||
Version: semver.MustParse("1.5.0"),
|
||||
@ -203,13 +195,11 @@ func TestInstallUpdate(t *testing.T) {
|
||||
RolloutProportion: 1.0,
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
latestVersion.Package,
|
||||
latestVersion.Package+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return([]byte("tgz_data_here"), nil)
|
||||
|
||||
err := updater.InstallUpdate(latestVersion)
|
||||
|
||||
@ -220,9 +210,9 @@ func TestInstallUpdateBadSignature(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
latestVersion := VersionInfo{
|
||||
Version: semver.MustParse("1.5.0"),
|
||||
@ -231,14 +221,12 @@ func TestInstallUpdateBadSignature(t *testing.T) {
|
||||
RolloutProportion: 1.0,
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
latestVersion.Package,
|
||||
latestVersion.Package+".sig",
|
||||
gomock.Any(),
|
||||
).Return(nil, errors.New("bad signature"))
|
||||
|
||||
client.EXPECT().Logout()
|
||||
|
||||
err := updater.InstallUpdate(latestVersion)
|
||||
|
||||
assert.Error(t, err)
|
||||
@ -248,9 +236,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
updater.installer = &fakeInstaller{delay: 2 * time.Second}
|
||||
|
||||
@ -261,13 +249,11 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
||||
RolloutProportion: 1.0,
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
latestVersion.Package,
|
||||
latestVersion.Package+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return([]byte("tgz_data_here"), nil)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@ -288,9 +274,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *Updater {
|
||||
func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater {
|
||||
return New(
|
||||
&fakeClientProvider{client: client},
|
||||
manager,
|
||||
&fakeInstaller{},
|
||||
newFakeSettings(0.5, earlyAccess),
|
||||
nil,
|
||||
@ -299,14 +285,6 @@ func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *
|
||||
)
|
||||
}
|
||||
|
||||
type fakeClientProvider struct {
|
||||
client *mocks.MockClient
|
||||
}
|
||||
|
||||
func (p *fakeClientProvider) GetAnonymousClient() pmapi.Client {
|
||||
return p.client
|
||||
}
|
||||
|
||||
type fakeInstaller struct {
|
||||
bad bool
|
||||
delay time.Duration
|
||||
|
||||
@ -133,7 +133,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
|
||||
// getAPIUser() loads user info from API (e.g. userID).
|
||||
@ -149,3 +149,13 @@ func (s *Credentials) Logout() {
|
||||
func (s *Credentials) IsConnected() bool {
|
||||
return s.APIToken != "" && s.MailboxPassword != ""
|
||||
}
|
||||
|
||||
func (s *Credentials) SplitAPIToken() (string, string, error) {
|
||||
split := strings.Split(s.APIToken, ":")
|
||||
|
||||
if len(split) != 2 {
|
||||
return "", "", errors.New("malformed API token")
|
||||
}
|
||||
|
||||
return split[0], split[1], nil
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ func NewStore(keychain *keychain.Keychain) *Store {
|
||||
return &Store{secrets: keychain}
|
||||
}
|
||||
|
||||
func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails []string) (creds *Credentials, err error) {
|
||||
func (s *Store) Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
@ -49,10 +49,10 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
|
||||
"emails": emails,
|
||||
}).Trace("Adding new credentials")
|
||||
|
||||
creds = &Credentials{
|
||||
creds := &Credentials{
|
||||
UserID: userID,
|
||||
Name: userName,
|
||||
APIToken: apiToken,
|
||||
APIToken: uid + ":" + ref,
|
||||
MailboxPassword: mailboxPassword,
|
||||
IsHidden: false,
|
||||
}
|
||||
@ -72,82 +72,82 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
|
||||
creds.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
if err = s.saveCredentials(creds); err != nil {
|
||||
return
|
||||
if err := s.saveCredentials(creds); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return creds, err
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
func (s *Store) SwitchAddressMode(userID string) error {
|
||||
func (s *Store) SwitchAddressMode(userID string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.IsCombinedAddressMode = !credentials.IsCombinedAddressMode
|
||||
credentials.BridgePassword = generatePassword()
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateEmails(userID string, emails []string) error {
|
||||
func (s *Store) UpdateEmails(userID string, emails []string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.SetEmailList(emails)
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdatePassword(userID, password string) error {
|
||||
func (s *Store) UpdatePassword(userID, password string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.MailboxPassword = password
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateToken(userID, apiToken string) error {
|
||||
func (s *Store) UpdateToken(userID, uid, ref string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.APIToken = apiToken
|
||||
credentials.APIToken = uid + ":" + ref
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) Logout(userID string) error {
|
||||
func (s *Store) Logout(userID string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.Logout()
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
// List returns a list of usernames that have credentials stored.
|
||||
@ -249,7 +249,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
|
||||
}
|
||||
|
||||
// saveCredentials encrypts and saves password to the keychain store.
|
||||
func (s *Store) saveCredentials(credentials *Credentials) (err error) {
|
||||
func (s *Store) saveCredentials(credentials *Credentials) error {
|
||||
credentials.Version = keychain.Version
|
||||
|
||||
return s.secrets.Put(credentials.UserID, credentials.Marshal())
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,ClientManager,CredentialsStorer,StoreMaker)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,CredentialsStorer,StoreMaker)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -9,7 +9,6 @@ import (
|
||||
|
||||
store "github.com/ProtonMail/proton-bridge/internal/store"
|
||||
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
@ -85,109 +84,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method
|
||||
func (m *MockClientManager) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy
|
||||
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// CheckConnection mocks base method
|
||||
func (m *MockClientManager) CheckConnection() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CheckConnection")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CheckConnection indicates an expected call of CheckConnection
|
||||
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method
|
||||
func (m *MockClientManager) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy
|
||||
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// GetAnonymousClient mocks base method
|
||||
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAnonymousClient")
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAnonymousClient indicates an expected call of GetAnonymousClient
|
||||
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
|
||||
}
|
||||
|
||||
// GetAuthUpdateChannel mocks base method
|
||||
func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthUpdateChannel")
|
||||
ret0, _ := ret[0].(chan pmapi.ClientAuth)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel
|
||||
func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel))
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockCredentialsStorer is a mock of CredentialsStorer interface
|
||||
type MockCredentialsStorer struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -212,18 +108,18 @@ func (m *MockCredentialsStorer) EXPECT() *MockCredentialsStorerMockRecorder {
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3 string, arg4 []string) (*credentials.Credentials, error) {
|
||||
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3, arg4 string, arg5 []string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4)
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
}
|
||||
|
||||
// Delete mocks base method
|
||||
@ -271,11 +167,12 @@ func (mr *MockCredentialsStorerMockRecorder) List() *gomock.Call {
|
||||
}
|
||||
|
||||
// Logout mocks base method
|
||||
func (m *MockCredentialsStorer) Logout(arg0 string) error {
|
||||
func (m *MockCredentialsStorer) Logout(arg0 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Logout", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Logout indicates an expected call of Logout
|
||||
@ -285,11 +182,12 @@ func (mr *MockCredentialsStorerMockRecorder) Logout(arg0 interface{}) *gomock.Ca
|
||||
}
|
||||
|
||||
// SwitchAddressMode mocks base method
|
||||
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) error {
|
||||
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SwitchAddressMode", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SwitchAddressMode indicates an expected call of SwitchAddressMode
|
||||
@ -299,11 +197,12 @@ func (mr *MockCredentialsStorerMockRecorder) SwitchAddressMode(arg0 interface{})
|
||||
}
|
||||
|
||||
// UpdateEmails mocks base method
|
||||
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) error {
|
||||
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateEmails", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateEmails indicates an expected call of UpdateEmails
|
||||
@ -313,11 +212,12 @@ func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}
|
||||
}
|
||||
|
||||
// UpdatePassword mocks base method
|
||||
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error {
|
||||
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdatePassword indicates an expected call of UpdatePassword
|
||||
@ -327,17 +227,18 @@ func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface
|
||||
}
|
||||
|
||||
// UpdateToken mocks base method
|
||||
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
|
||||
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1, arg2 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateToken indicates an expected call of UpdateToken
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1 interface{}) *gomock.Call {
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// MockStoreMaker is a mock of StoreMaker interface
|
||||
|
||||
@ -20,14 +20,8 @@ package users
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/internal/store"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
type Configer interface {
|
||||
GetAppVersion() string
|
||||
GetAPIConfig() *pmapi.ClientConfig
|
||||
}
|
||||
|
||||
type Locator interface {
|
||||
Clear() error
|
||||
}
|
||||
@ -38,25 +32,16 @@ type PanicHandler interface {
|
||||
|
||||
type CredentialsStorer interface {
|
||||
List() (userIDs []string, err error)
|
||||
Add(userID, userName, apiToken, mailboxPassword string, emails []string) (*credentials.Credentials, error)
|
||||
Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*credentials.Credentials, error)
|
||||
Get(userID string) (*credentials.Credentials, error)
|
||||
SwitchAddressMode(userID string) error
|
||||
UpdateEmails(userID string, emails []string) error
|
||||
UpdatePassword(userID, password string) error
|
||||
UpdateToken(userID, apiToken string) error
|
||||
Logout(userID string) error
|
||||
SwitchAddressMode(userID string) (*credentials.Credentials, error)
|
||||
UpdateEmails(userID string, emails []string) (*credentials.Credentials, error)
|
||||
UpdatePassword(userID, password string) (*credentials.Credentials, error)
|
||||
UpdateToken(userID, uid, ref string) (*credentials.Credentials, error)
|
||||
Logout(userID string) (*credentials.Credentials, error)
|
||||
Delete(userID string) error
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
GetAnonymousClient() pmapi.Client
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
GetAuthUpdateChannel() chan pmapi.ClientAuth
|
||||
CheckConnection() error
|
||||
}
|
||||
|
||||
type StoreMaker interface {
|
||||
New(user store.BridgeUser) (*store.Store, error)
|
||||
Remove(userID string) error
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -36,11 +37,11 @@ var ErrLoggedOutUser = errors.New("account is logged out, use the app to login a
|
||||
|
||||
// User is a struct on top of API client and credentials store.
|
||||
type User struct {
|
||||
log *logrus.Entry
|
||||
panicHandler PanicHandler
|
||||
listener listener.Listener
|
||||
clientManager ClientManager
|
||||
credStorer CredentialsStorer
|
||||
log *logrus.Entry
|
||||
panicHandler PanicHandler
|
||||
listener listener.Listener
|
||||
client pmapi.Client
|
||||
credStorer CredentialsStorer
|
||||
|
||||
storeFactory StoreMaker
|
||||
store *store.Store
|
||||
@ -48,75 +49,76 @@ type User struct {
|
||||
userID string
|
||||
creds *credentials.Credentials
|
||||
|
||||
lock sync.RWMutex
|
||||
isAuthorized bool
|
||||
lock sync.RWMutex
|
||||
|
||||
useOnlyActiveAddresses bool
|
||||
}
|
||||
|
||||
// newUser creates a new user.
|
||||
// The user is initially disconnected and must be connected by calling connect().
|
||||
func newUser(
|
||||
panicHandler PanicHandler,
|
||||
userID string,
|
||||
eventListener listener.Listener,
|
||||
credStorer CredentialsStorer,
|
||||
clientManager ClientManager,
|
||||
storeFactory StoreMaker,
|
||||
) (u *User, err error) {
|
||||
useOnlyActiveAddresses bool,
|
||||
) (*User, *credentials.Credentials, error) {
|
||||
log := log.WithField("user", userID)
|
||||
|
||||
log.Debug("Creating or loading user")
|
||||
|
||||
creds, err := credStorer.Get(userID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load user credentials")
|
||||
return nil, nil, errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
|
||||
u = &User{
|
||||
log: log,
|
||||
panicHandler: panicHandler,
|
||||
listener: eventListener,
|
||||
credStorer: credStorer,
|
||||
clientManager: clientManager,
|
||||
storeFactory: storeFactory,
|
||||
userID: userID,
|
||||
creds: creds,
|
||||
}
|
||||
|
||||
return
|
||||
return &User{
|
||||
log: log,
|
||||
panicHandler: panicHandler,
|
||||
listener: eventListener,
|
||||
credStorer: credStorer,
|
||||
storeFactory: storeFactory,
|
||||
userID: userID,
|
||||
creds: creds,
|
||||
useOnlyActiveAddresses: useOnlyActiveAddresses,
|
||||
}, creds, nil
|
||||
}
|
||||
|
||||
func (u *User) client() pmapi.Client {
|
||||
return u.clientManager.GetClient(u.userID)
|
||||
}
|
||||
// connect connects a user. This includes
|
||||
// - providing it with an authorised API client
|
||||
// - loading its credentials from the credentials store
|
||||
// - loading and unlocking its PGP keys
|
||||
// - loading its store
|
||||
func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error {
|
||||
u.log.Info("Connecting user")
|
||||
|
||||
// init initialises a user. This includes reloading its credentials from the credentials store
|
||||
// (such as when logging out and back in, you need to reload the credentials because the new credentials will
|
||||
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
|
||||
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
|
||||
// something in the store changed).
|
||||
func (u *User) init() (err error) {
|
||||
u.log.Info("Initialising user")
|
||||
// Connected users have an API client.
|
||||
u.client = client
|
||||
|
||||
// Reload the user's credentials (if they log out and back in we need the new
|
||||
// version with the apitoken and mailbox password).
|
||||
creds, err := u.credStorer.Get(u.userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
// FIXME(conman): How to remove this auth handler when user is disconnected?
|
||||
u.client.AddAuthHandler(u.handleAuth)
|
||||
|
||||
// Save the latest credentials for the user.
|
||||
u.creds = creds
|
||||
|
||||
// Try to authorise the user if they aren't already authorised.
|
||||
// Note: we still allow users to set up accounts if the internet is off.
|
||||
if authErr := u.authorizeIfNecessary(false); authErr != nil {
|
||||
switch errors.Cause(authErr) {
|
||||
case pmapi.ErrAPINotReachable, pmapi.ErrUpgradeApplication, ErrLoggedOutUser:
|
||||
u.log.WithError(authErr).Warn("Could not authorize user")
|
||||
default:
|
||||
if logoutErr := u.logout(); logoutErr != nil {
|
||||
u.log.WithError(logoutErr).Warn("Could not logout user")
|
||||
}
|
||||
return errors.Wrap(authErr, "failed to authorize user")
|
||||
// Connected users have unlocked keys.
|
||||
// FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :(
|
||||
if u.creds.IsConnected() {
|
||||
if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Connected users have a store.
|
||||
if err := u.loadStore(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) loadStore() error {
|
||||
// Logged-out user keeps store running to access offline data.
|
||||
// Therefore it is necessary to close it before re-init.
|
||||
if u.store != nil {
|
||||
@ -125,93 +127,28 @@ func (u *User) init() (err error) {
|
||||
}
|
||||
u.store = nil
|
||||
}
|
||||
|
||||
store, err := u.storeFactory.New(u)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create store")
|
||||
}
|
||||
|
||||
u.store = store
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
|
||||
// If user is not already connected to the api auth channel (for example there was no internet during start),
|
||||
// it tries to connect it.
|
||||
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
|
||||
// If user is connected and has an auth channel, then perfect, nothing to do here.
|
||||
if u.creds.IsConnected() && u.isAuthorized {
|
||||
// The keyring unlock is triggered here to resolve state where apiClient
|
||||
// is authenticated (we have auth token) but it was not possible to download
|
||||
// and unlock the keys (internet not reachable).
|
||||
return u.unlockIfNecessary()
|
||||
}
|
||||
|
||||
if !u.creds.IsConnected() {
|
||||
err = ErrLoggedOutUser
|
||||
} else if err = u.authorizeAndUnlock(); err != nil {
|
||||
u.log.WithError(err).Error("Could not authorize and unlock user")
|
||||
|
||||
switch errors.Cause(err) {
|
||||
case pmapi.ErrUpgradeApplication, pmapi.ErrAPINotReachable: // Ignore these errors.
|
||||
default:
|
||||
if errLogout := u.credStorer.Logout(u.userID); errLogout != nil {
|
||||
u.log.WithField("err", errLogout).Error("Could not log user out from credentials store")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if emitEvent && err != nil &&
|
||||
errors.Cause(err) != pmapi.ErrUpgradeApplication &&
|
||||
errors.Cause(err) != pmapi.ErrAPINotReachable {
|
||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
|
||||
func (u *User) unlockIfNecessary() error {
|
||||
if u.client().IsUnlocked() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorizeAndUnlock tries to authorize the user with the API using the the user's APIToken.
|
||||
// If that succeeds, it tries to unlock the user's keys and addresses.
|
||||
func (u *User) authorizeAndUnlock() (err error) {
|
||||
if u.creds.APIToken == "" {
|
||||
u.log.Warn("Could not connect to API auth channel, have no API token")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh API auth")
|
||||
}
|
||||
|
||||
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) updateAuthToken(auth *pmapi.Auth) {
|
||||
func (u *User) handleAuth(auth *pmapi.Auth) error {
|
||||
u.log.Debug("User received auth")
|
||||
|
||||
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
|
||||
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||
return
|
||||
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to update refresh token in credentials store")
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
u.creds = creds
|
||||
|
||||
u.isAuthorized = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearStore removes the database.
|
||||
@ -248,7 +185,7 @@ func (u *User) closeStore() error {
|
||||
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
|
||||
// After proper refactor of SMTP and IMAP remove this method.
|
||||
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
|
||||
return u.client()
|
||||
return u.client
|
||||
}
|
||||
|
||||
// ID returns the user's userID.
|
||||
@ -272,6 +209,10 @@ func (u *User) IsConnected() bool {
|
||||
return u.creds.IsConnected()
|
||||
}
|
||||
|
||||
func (u *User) GetClient() pmapi.Client {
|
||||
return u.client
|
||||
}
|
||||
|
||||
// IsCombinedAddressMode returns whether user is set in combined or split mode.
|
||||
// Combined mode is the default mode and is what users typically need.
|
||||
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
|
||||
@ -345,7 +286,7 @@ func (u *User) GetAddressID(address string) (id string, err error) {
|
||||
return u.store.GetAddressID(address)
|
||||
}
|
||||
|
||||
addresses := u.client().Addresses()
|
||||
addresses := u.client.Addresses()
|
||||
pmapiAddress := addresses.ByEmail(address)
|
||||
if pmapiAddress != nil {
|
||||
return pmapiAddress.ID, nil
|
||||
@ -366,18 +307,21 @@ func (u *User) GetBridgePassword() string {
|
||||
// CheckBridgeLogin checks whether the user is logged in and the bridge
|
||||
// IMAP/SMTP password is correct.
|
||||
func (u *User) CheckBridgeLogin(password string) error {
|
||||
if isApplicationOutdated {
|
||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||
return pmapi.ErrUpgradeApplication
|
||||
}
|
||||
// FIXME(conman): Handle force upgrade?
|
||||
|
||||
/*
|
||||
if isApplicationOutdated {
|
||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||
return pmapi.ErrUpgradeApplication
|
||||
}
|
||||
*/
|
||||
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
// True here because users should be notified by popup of auth failure.
|
||||
if err := u.authorizeIfNecessary(true); err != nil {
|
||||
u.log.WithError(err).Error("Failed to authorize user")
|
||||
return err
|
||||
if !u.creds.IsConnected() {
|
||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||
return ErrLoggedOutUser
|
||||
}
|
||||
|
||||
return u.creds.CheckPassword(password)
|
||||
@ -388,60 +332,57 @@ func (u *User) UpdateUser() error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
if err := u.authorizeIfNecessary(true); err != nil {
|
||||
return errors.Wrap(err, "cannot update user")
|
||||
}
|
||||
|
||||
_, err := u.client().UpdateUser()
|
||||
_, err := u.client.UpdateUser(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = u.client().ReloadKeys([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to reload keys")
|
||||
}
|
||||
|
||||
emails := u.client().Addresses().ActiveEmails()
|
||||
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil {
|
||||
creds, err := u.credStorer.UpdateEmails(u.userID, u.client.Addresses().ActiveEmails())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
u.creds = creds
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SwitchAddressMode changes mode from combined to split and vice versa. The mode to switch to is determined by the
|
||||
// state of the user's credentials in the credentials store. See `IsCombinedAddressMode` for more details.
|
||||
func (u *User) SwitchAddressMode() (err error) {
|
||||
func (u *User) SwitchAddressMode() error {
|
||||
u.log.Trace("Switching user address mode")
|
||||
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
u.CloseAllConnections()
|
||||
|
||||
if u.store == nil {
|
||||
err = errors.New("store is not initialised")
|
||||
return
|
||||
return errors.New("store is not initialised")
|
||||
}
|
||||
|
||||
newAddressModeState := !u.IsCombinedAddressMode()
|
||||
|
||||
if err = u.store.UseCombinedMode(newAddressModeState); err != nil {
|
||||
u.log.WithError(err).Error("Could not switch store address mode")
|
||||
return
|
||||
if err := u.store.UseCombinedMode(newAddressModeState); err != nil {
|
||||
return errors.Wrap(err, "could not switch store address mode")
|
||||
}
|
||||
|
||||
if u.creds.IsCombinedAddressMode != newAddressModeState {
|
||||
if err = u.credStorer.SwitchAddressMode(u.userID); err != nil {
|
||||
u.log.WithError(err).Error("Could not switch credentials store address mode")
|
||||
return
|
||||
}
|
||||
if u.creds.IsCombinedAddressMode == newAddressModeState {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
creds, err := u.credStorer.SwitchAddressMode(u.userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not switch credentials store address mode")
|
||||
}
|
||||
|
||||
return err
|
||||
u.creds = creds
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logout is the same as Logout, but for internal purposes (logged out from
|
||||
@ -458,35 +399,37 @@ func (u *User) logout() error {
|
||||
u.listener.Emit(events.UserRefreshEvent, u.userID)
|
||||
}
|
||||
|
||||
u.isAuthorized = false
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Logout logs out the user from pmapi, the credentials store, the mail store, and tries to remove as much
|
||||
// sensitive data as possible.
|
||||
func (u *User) Logout() (err error) {
|
||||
func (u *User) Logout() error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
u.log.Debug("Logging out user")
|
||||
|
||||
if !u.creds.IsConnected() {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
u.client().Logout()
|
||||
// FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers?
|
||||
if err := u.client.AuthDelete(context.TODO()); err != nil {
|
||||
u.log.WithError(err).Warn("Failed to delete auth")
|
||||
}
|
||||
|
||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||
creds, err := u.credStorer.Logout(u.userID)
|
||||
if err != nil {
|
||||
u.log.WithError(err).Warn("Could not log user out from credentials store")
|
||||
|
||||
if err = u.credStorer.Delete(u.userID); err != nil {
|
||||
if err := u.credStorer.Delete(u.userID); err != nil {
|
||||
u.log.WithError(err).Error("Could not delete user from credentials store")
|
||||
}
|
||||
} else {
|
||||
u.creds = creds
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
|
||||
// Do not close whole store, just event loop. Some information might be needed offline (e.g. addressID)
|
||||
u.closeEventLoop()
|
||||
|
||||
@ -494,15 +437,7 @@ func (u *User) Logout() (err error) {
|
||||
|
||||
runtime.GC()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *User) refreshFromCredentials() {
|
||||
if credentials, err := u.credStorer.Get(u.userID); err != nil {
|
||||
log.WithError(err).Error("Cannot refresh user credentials")
|
||||
} else {
|
||||
u.creds = credentials
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) closeEventLoop() {
|
||||
|
||||
@ -19,12 +19,15 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@ -45,7 +48,7 @@ type Users struct {
|
||||
locations Locator
|
||||
panicHandler PanicHandler
|
||||
events listener.Listener
|
||||
clientManager ClientManager
|
||||
clientManager pmapi.Manager
|
||||
credStorer CredentialsStorer
|
||||
storeFactory StoreMaker
|
||||
|
||||
@ -62,16 +65,13 @@ type Users struct {
|
||||
useOnlyActiveAddresses bool
|
||||
|
||||
lock sync.RWMutex
|
||||
|
||||
// stopAll can be closed to stop all goroutines from looping (watchAppOutdated, watchAPIAuths, heartbeat etc).
|
||||
stopAll chan struct{}
|
||||
}
|
||||
|
||||
func New(
|
||||
locations Locator,
|
||||
panicHandler PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
clientManager ClientManager,
|
||||
clientManager pmapi.Manager,
|
||||
credStorer CredentialsStorer,
|
||||
storeFactory StoreMaker,
|
||||
useOnlyActiveAddresses bool,
|
||||
@ -87,98 +87,104 @@ func New(
|
||||
storeFactory: storeFactory,
|
||||
useOnlyActiveAddresses: useOnlyActiveAddresses,
|
||||
lock: sync.RWMutex{},
|
||||
stopAll: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAppOutdated()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAPIAuths()
|
||||
}()
|
||||
// FIXME(conman): Handle force upgrade events.
|
||||
/*
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAppOutdated()
|
||||
}()
|
||||
*/
|
||||
|
||||
if u.credStorer == nil {
|
||||
log.Error("No credentials store is available")
|
||||
} else if err := u.loadUsersFromCredentialsStore(); err != nil {
|
||||
} else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil {
|
||||
log.WithError(err).Error("Could not load all users from credentials store")
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *Users) loadUsersFromCredentialsStore() (err error) {
|
||||
func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
userIDs, err := u.credStorer.List()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
l := log.WithField("user", userID)
|
||||
|
||||
user, newUserErr := newUser(u.panicHandler, userID, u.events, u.credStorer, u.clientManager, u.storeFactory)
|
||||
if newUserErr != nil {
|
||||
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
|
||||
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warn("Could not create user, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
u.users = append(u.users, user)
|
||||
|
||||
if initUserErr := user.init(); initUserErr != nil {
|
||||
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user")
|
||||
if creds.IsConnected() {
|
||||
if err := u.loadConnectedUser(ctx, user, creds); err != nil {
|
||||
logrus.WithError(err).Warn("Could not load connected user")
|
||||
}
|
||||
} else {
|
||||
logrus.Warn("User is disconnected and must be connected manually")
|
||||
|
||||
if err := u.loadDisconnectedUser(ctx, user, creds); err != nil {
|
||||
logrus.WithError(err).Warn("Could not load disconnected user")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *Users) watchAppOutdated() {
|
||||
ch := make(chan string)
|
||||
|
||||
u.events.Add(events.UpgradeApplicationEvent, ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
isApplicationOutdated = true
|
||||
u.closeAllConnections()
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
||||
// FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor!
|
||||
return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds)
|
||||
}
|
||||
|
||||
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
|
||||
func (u *Users) watchAPIAuths() {
|
||||
for {
|
||||
select {
|
||||
case auth := <-u.clientManager.GetAuthUpdateChannel():
|
||||
log.Debug("Users received auth from ClientManager")
|
||||
|
||||
user, ok := u.hasUser(auth.UserID)
|
||||
if !ok {
|
||||
log.WithField("userID", auth.UserID).Info("User not available for auth update")
|
||||
continue
|
||||
}
|
||||
|
||||
if auth.Auth != nil {
|
||||
user.updateAuthToken(auth.Auth)
|
||||
} else if err := user.logout(); err != nil {
|
||||
log.WithError(err).
|
||||
WithField("userID", auth.UserID).
|
||||
Error("User logout failed while watching API auths")
|
||||
}
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
||||
uid, ref, err := creds.SplitAPIToken()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not get user's refresh token")
|
||||
}
|
||||
|
||||
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
|
||||
if err != nil {
|
||||
// FIXME(conman): This is a problem... if we weren't able to create a new client due to internet,
|
||||
// we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff...
|
||||
return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
|
||||
}
|
||||
|
||||
// Update the user's credentials with the latest auth used to connect this user.
|
||||
if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
|
||||
return errors.Wrap(err, "could not create get user's refresh token")
|
||||
}
|
||||
|
||||
return user.connect(ctx, client, creds)
|
||||
}
|
||||
|
||||
func (u *Users) watchAppOutdated() {
|
||||
// FIXME(conman): handle force upgrade events.
|
||||
|
||||
/*
|
||||
ch := make(chan string)
|
||||
|
||||
u.events.Add(events.UpgradeApplicationEvent, ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
isApplicationOutdated = true
|
||||
u.closeAllConnections()
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
func (u *Users) closeAllConnections() {
|
||||
@ -192,63 +198,45 @@ func (u *Users) closeAllConnections() {
|
||||
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
|
||||
u.crashBandicoot(username)
|
||||
|
||||
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet.
|
||||
authClient = u.clientManager.GetAnonymousClient()
|
||||
|
||||
authInfo, err := authClient.AuthInfo(username)
|
||||
if err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
|
||||
return
|
||||
}
|
||||
|
||||
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
return u.clientManager.NewClientWithLogin(context.TODO(), username, password)
|
||||
}
|
||||
|
||||
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
||||
func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphrase string) (user *User, err error) { //nolint[funlen]
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("Login not finished; removing auth session")
|
||||
if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil {
|
||||
log.WithError(delAuthErr).Error("Failed to clear login session after unlock")
|
||||
}
|
||||
}
|
||||
// The anonymous client will be removed from list and authentication will not be deleted.
|
||||
authClient.Logout()
|
||||
}()
|
||||
|
||||
apiUser, hashedPassphrase, err := getAPIUser(authClient, mbPassphrase)
|
||||
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
|
||||
apiUser, passphrase, err := getAPIUser(context.TODO(), client, password)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to get API user")
|
||||
return
|
||||
return nil, errors.Wrap(err, "failed to get API user")
|
||||
}
|
||||
|
||||
log.Info("Got API user")
|
||||
if user, ok := u.hasUser(apiUser.ID); ok {
|
||||
if user.IsConnected() {
|
||||
if err := client.AuthDelete(context.TODO()); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to delete new auth session")
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if user, ok = u.hasUser(apiUser.ID); ok {
|
||||
if err = u.connectExistingUser(user, auth, hashedPassphrase); err != nil {
|
||||
log.WithError(err).Error("Failed to connect existing user")
|
||||
return
|
||||
return nil, errors.New("user is already connected")
|
||||
}
|
||||
} else {
|
||||
if err = u.addNewUser(apiUser, auth, hashedPassphrase); err != nil {
|
||||
log.WithError(err).Error("Failed to add new user")
|
||||
return
|
||||
|
||||
// Update the user's credentials with the latest auth used to connect this user.
|
||||
if _, err := u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
|
||||
// Update the password in case the user changed it.
|
||||
creds, err := u.credStorer.UpdatePassword(apiUser.ID, string(passphrase))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
|
||||
}
|
||||
|
||||
if err := user.connect(context.TODO(), client, creds); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to reconnect existing user")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Old credentials use username as key (user ID) which needs to be removed
|
||||
// once user logs in again with proper ID fetched from API.
|
||||
if _, ok := u.hasUser(apiUser.Name); ok {
|
||||
if err := u.DeleteUser(apiUser.Name, true); err != nil {
|
||||
log.WithError(err).Error("Failed to delete old user")
|
||||
}
|
||||
if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to add new user")
|
||||
}
|
||||
|
||||
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||
@ -256,107 +244,63 @@ func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphr
|
||||
return u.GetUser(apiUser.ID)
|
||||
}
|
||||
|
||||
// connectExistingUser connects an existing user.
|
||||
func (u *Users) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassphrase string) (err error) {
|
||||
if user.IsConnected() {
|
||||
return errors.New("user is already connected")
|
||||
}
|
||||
|
||||
log.Info("Connecting existing user")
|
||||
|
||||
// Update the user's password in the cred store in case they changed it.
|
||||
if err = u.credStorer.UpdatePassword(user.ID(), hashedPassphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to update password of user in credentials store")
|
||||
}
|
||||
|
||||
client := u.clientManager.GetClient(user.ID())
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh auth token of new client")
|
||||
}
|
||||
|
||||
if err = u.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to update token of user in credentials store")
|
||||
}
|
||||
|
||||
if err = user.init(); err != nil {
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// addNewUser adds a new user.
|
||||
func (u *Users) addNewUser(apiUser *pmapi.User, auth *pmapi.Auth, hashedPassphrase string) (err error) {
|
||||
func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
client := u.clientManager.GetClient(apiUser.ID)
|
||||
var emails []string
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh token in new client")
|
||||
}
|
||||
|
||||
if apiUser, err = client.CurrentUser(); err != nil {
|
||||
return errors.Wrap(err, "failed to update API user")
|
||||
}
|
||||
|
||||
var emails []string //nolint[prealloc]
|
||||
if u.useOnlyActiveAddresses {
|
||||
emails = client.Addresses().ActiveEmails()
|
||||
} else {
|
||||
emails = client.Addresses().AllEmails()
|
||||
}
|
||||
|
||||
if _, err = u.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), hashedPassphrase, emails); err != nil {
|
||||
return errors.Wrap(err, "failed to add user to credentials store")
|
||||
if _, err := u.credStorer.Add(apiUser.ID, apiUser.Name, auth.UID, auth.RefreshToken, string(passphrase), emails); err != nil {
|
||||
return errors.Wrap(err, "failed to add user credentials to credentials store")
|
||||
}
|
||||
|
||||
user, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.clientManager, u.storeFactory)
|
||||
user, creds, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create user")
|
||||
return errors.Wrap(err, "failed to create new user")
|
||||
}
|
||||
|
||||
// The user needs to be part of the users list in order for it to receive an auth during initialisation.
|
||||
u.users = append(u.users, user)
|
||||
|
||||
if err = user.init(); err != nil {
|
||||
u.users = u.users[:len(u.users)-1]
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
if err := user.connect(ctx, client, creds); err != nil {
|
||||
return errors.Wrap(err, "failed to connect new user")
|
||||
}
|
||||
|
||||
if err := u.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)); err != nil {
|
||||
log.WithError(err).Error("Failed to send metric")
|
||||
}
|
||||
|
||||
return err
|
||||
u.users = append(u.users, user)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAPIUser(client pmapi.Client, mbPassphrase string) (user *pmapi.User, hashedPassphrase string, err error) {
|
||||
salt, err := client.AuthSalt()
|
||||
func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pmapi.User, []byte, error) {
|
||||
salt, err := client.AuthSalt(ctx)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not get salt")
|
||||
return nil, "", err
|
||||
return nil, nil, errors.Wrap(err, "failed to get salt")
|
||||
}
|
||||
|
||||
hashedPassphrase, err = pmapi.HashMailboxPassword(mbPassphrase, salt)
|
||||
passphrase, err := pmapi.HashMailboxPassword(password, salt)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not hash mailbox password")
|
||||
return nil, "", err
|
||||
return nil, nil, errors.Wrap(err, "failed to hash password")
|
||||
}
|
||||
|
||||
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
|
||||
if err = client.Unlock([]byte(hashedPassphrase)); err != nil {
|
||||
log.WithError(err).Error("Wrong mailbox password")
|
||||
return nil, "", ErrWrongMailboxPassword
|
||||
if err := client.Unlock(ctx, passphrase); err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to unlock client")
|
||||
}
|
||||
|
||||
if user, err = client.CurrentUser(); err != nil {
|
||||
log.WithError(err).Error("Could not load user data")
|
||||
return nil, "", err
|
||||
user, err := client.CurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to load user data")
|
||||
}
|
||||
|
||||
return user, hashedPassphrase, nil
|
||||
return user, passphrase, nil
|
||||
}
|
||||
|
||||
// GetUsers returns all added users into keychain (even logged out users).
|
||||
@ -452,11 +396,9 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
|
||||
|
||||
// SendMetric sends a metric. We don't want to return any errors, only log them.
|
||||
func (u *Users) SendMetric(m metrics.Metric) error {
|
||||
c := u.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
cat, act, lab := m.Get()
|
||||
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
|
||||
|
||||
if err := u.clientManager.SendSimpleMetric(context.Background(), string(cat), string(act), string(lab)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -472,24 +414,22 @@ func (u *Users) SendMetric(m metrics.Metric) error {
|
||||
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) AllowProxy() {
|
||||
u.clientManager.AllowProxy()
|
||||
// FIXME(conman): Support DoH.
|
||||
// u.apiManager.AllowProxy()
|
||||
}
|
||||
|
||||
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) DisallowProxy() {
|
||||
u.clientManager.DisallowProxy()
|
||||
// FIXME(conman): Support DoH.
|
||||
// u.apiManager.DisallowProxy()
|
||||
}
|
||||
|
||||
// CheckConnection returns whether there is an internet connection.
|
||||
// This should use the connection manager when it is eventually implemented.
|
||||
func (u *Users) CheckConnection() error {
|
||||
return u.clientManager.CheckConnection()
|
||||
}
|
||||
|
||||
// StopWatchers stops all goroutines.
|
||||
func (u *Users) StopWatchers() {
|
||||
close(u.stopAll)
|
||||
// FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer.
|
||||
panic("TODO: register as a connection observer to get this information")
|
||||
}
|
||||
|
||||
// hasUser returns whether the struct currently has a user with ID `id`.
|
||||
|
||||
@ -20,8 +20,8 @@ package users
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
time "time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -49,20 +49,19 @@ func TestNewUsersWithDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Basically every call client has get client manager.
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
|
||||
/*
|
||||
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
@ -132,6 +131,7 @@ func TestNewUsersFirstStart(t *testing.T) {
|
||||
|
||||
testNewUsers(t, m)
|
||||
}
|
||||
*/
|
||||
|
||||
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
|
||||
users := testNewUsers(t, m)
|
||||
|
||||
@ -48,18 +48,17 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
var (
|
||||
testAuth = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||
RefreshToken: "tok",
|
||||
}
|
||||
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||
RefreshToken: "reftok",
|
||||
UID: "uid",
|
||||
AccessToken: "acc",
|
||||
RefreshToken: "ref",
|
||||
}
|
||||
|
||||
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
Name: "username",
|
||||
Emails: "user@pm.me",
|
||||
APIToken: "token",
|
||||
APIToken: "uid:acc",
|
||||
MailboxPassword: "pass",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
@ -67,11 +66,12 @@ var (
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: true,
|
||||
}
|
||||
|
||||
testCredentialsSplit = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "users",
|
||||
Name: "usersname",
|
||||
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
||||
APIToken: "token",
|
||||
APIToken: "uid:acc",
|
||||
MailboxPassword: "pass",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
@ -79,6 +79,7 @@ var (
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: false,
|
||||
}
|
||||
|
||||
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
Name: "username",
|
||||
@ -92,6 +93,19 @@ var (
|
||||
IsCombinedAddressMode: true,
|
||||
}
|
||||
|
||||
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "users",
|
||||
Name: "usersname",
|
||||
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
||||
APIToken: "",
|
||||
MailboxPassword: "",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
Timestamp: 123456789,
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: false,
|
||||
}
|
||||
|
||||
testPMAPIUser = &pmapi.User{ //nolint[gochecknoglobals]
|
||||
ID: "user",
|
||||
Name: "username",
|
||||
@ -130,12 +144,12 @@ type mocks struct {
|
||||
ctrl *gomock.Controller
|
||||
locator *usersmocks.MockLocator
|
||||
PanicHandler *usersmocks.MockPanicHandler
|
||||
clientManager *usersmocks.MockClientManager
|
||||
credentialsStore *usersmocks.MockCredentialsStorer
|
||||
storeMaker *usersmocks.MockStoreMaker
|
||||
eventListener *MockListener
|
||||
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
clientManager *pmapimocks.MockManager
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
|
||||
storeCache *store.Cache
|
||||
}
|
||||
@ -171,12 +185,12 @@ func initMocks(t *testing.T) mocks {
|
||||
ctrl: mockCtrl,
|
||||
locator: usersmocks.NewMockLocator(mockCtrl),
|
||||
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
|
||||
clientManager: usersmocks.NewMockClientManager(mockCtrl),
|
||||
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
|
||||
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
|
||||
eventListener: NewMockListener(mockCtrl),
|
||||
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
clientManager: pmapimocks.NewMockManager(mockCtrl),
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
|
||||
storeCache: store.NewCache(cacheFile.Name()),
|
||||
}
|
||||
@ -189,7 +203,7 @@ func initMocks(t *testing.T) mocks {
|
||||
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
|
||||
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
|
||||
require.NoError(t, err, "could not get temporary file for store db")
|
||||
return store.New(sentryReporter, m.PanicHandler, user, m.clientManager, m.eventListener, dbFile.Name(), m.storeCache)
|
||||
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
|
||||
}).AnyTimes()
|
||||
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
|
||||
|
||||
@ -198,46 +212,42 @@ func initMocks(t *testing.T) mocks {
|
||||
|
||||
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
|
||||
// Events are asynchronous
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
|
||||
|
||||
// Init for user.
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Init for users.
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil),
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil),
|
||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
|
||||
)
|
||||
|
||||
users := testNewUsers(t, m)
|
||||
|
||||
user, _ := users.GetUser("user")
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
user, _ = users.GetUser("user")
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
return users
|
||||
return testNewUsers(t, m)
|
||||
}
|
||||
|
||||
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
|
||||
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan pmapi.ClientAuth))
|
||||
// FIXME(conman): How to handle force upgrade?
|
||||
// m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
|
||||
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
|
||||
|
||||
@ -256,8 +266,8 @@ func TestClearData(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
// m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
// m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
@ -267,13 +277,11 @@ func TestClearData(t *testing.T) {
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
|
||||
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
||||
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("users").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil)
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
|
||||
|
||||
m.locator.EXPECT().Clear()
|
||||
|
||||
@ -285,9 +293,9 @@ func TestClearData(t *testing.T) {
|
||||
func mockEventLoopNoAction(m mocks) {
|
||||
// Set up mocks for starting the store's event loop (in store.New).
|
||||
// The event loop runs in another goroutine so this might happen at any time.
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
}
|
||||
|
||||
func mockConnectedUser(m mocks) {
|
||||
@ -295,27 +303,13 @@ func mockConnectedUser(m mocks) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
// m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// Set up mocks for store initialisation for the authorized user.
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
)
|
||||
}
|
||||
|
||||
// mockAuthUpdate simulates users calling UpdateAuthToken on the given user.
|
||||
// This would normally be done by users when it receives an auth from the ClientManager,
|
||||
// but as we don't have a full users instance here, we do this manually.
|
||||
func mockAuthUpdate(user *User, token string, m mocks) {
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":"+token).Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(token), nil),
|
||||
)
|
||||
|
||||
user.updateAuthToken(refreshWithToken(token))
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
// IsAuthorized returns whether the user has received an Auth from the API yet.
|
||||
func (u *User) IsAuthorized() bool {
|
||||
return u.isAuthorized
|
||||
}
|
||||
@ -40,8 +40,8 @@ type Builder struct {
|
||||
}
|
||||
|
||||
type Fetcher interface {
|
||||
GetMessage(string) (*pmapi.Message, error)
|
||||
GetAttachment(string) (io.ReadCloser, error)
|
||||
GetMessage(context.Context, string) (*pmapi.Message, error)
|
||||
GetAttachment(context.Context, string) (io.ReadCloser, error)
|
||||
KeyRingForAddressID(string) (*crypto.KeyRing, error)
|
||||
}
|
||||
|
||||
|
||||
@ -1,309 +0,0 @@
|
||||
# Do not modify this file!
|
||||
It is here for historical reasons only. All changes should be documented in the
|
||||
Changelog at the root of this repository.
|
||||
|
||||
|
||||
# Changelog for API
|
||||
> NOTE we are using versioning for go-pmapi in format `major.minor.bugfix`
|
||||
> * major stays at version 1 for the forseeable future
|
||||
> * minor is increased when a force upgrade happens or in case of major breaking changes
|
||||
> * patch is increased when new features are added
|
||||
|
||||
## v1.0.16
|
||||
|
||||
### Fixed
|
||||
* Potential crash when reporting cert pin failure
|
||||
|
||||
## v1.0.15
|
||||
|
||||
### Changed
|
||||
* Merge only 50 events into one
|
||||
* Response header timeout increased from 10s to 30s
|
||||
|
||||
### Fixed
|
||||
* Make keyring unlocking threadsafe
|
||||
|
||||
## v1.0.14
|
||||
|
||||
### Added
|
||||
* Config for disabling TLS cert fingerprint checking
|
||||
|
||||
### Fixed
|
||||
* Ensure sensitive stuff is cleared on client logout even if requests fail
|
||||
|
||||
## v1.0.13
|
||||
|
||||
### Fixed
|
||||
* Correctly set Transport in http client
|
||||
|
||||
## v1.0.12
|
||||
|
||||
### Changed
|
||||
* Only `http.RoundTripper` interface is needed instead of full `http.Transport` struct
|
||||
|
||||
### Added
|
||||
* GODT-61 (and related): Use DoH to find and switch to a proxy server if the API becomes unreachable
|
||||
* GODT-67 added random wait to not cause spikes on server after StatusTooManyRequests
|
||||
|
||||
### Fixed
|
||||
* FirstReadTimeout was wrongly timeout of the whole request including repeating ones, now it's really only timeout for the first read
|
||||
|
||||
## v1.0.11
|
||||
|
||||
### Added
|
||||
* GODT-53 `Message.Type` added with constants `MessageType*`
|
||||
|
||||
## v1.0.10
|
||||
|
||||
### Added
|
||||
* GODT-55 exporting DANGEROUSLYSetUID
|
||||
|
||||
### Changed
|
||||
* The full communication between clien and API is logged if logrus level is trace
|
||||
|
||||
## v1.0.9
|
||||
|
||||
### Fixed
|
||||
* Use correct address type value (because API starts counting from 1 but we were counting from 0)
|
||||
|
||||
## v1.0.8
|
||||
|
||||
### Added
|
||||
* Introdcution of connection manager
|
||||
|
||||
### Fixed
|
||||
* Deadlock during the auth-refresh
|
||||
* Fixed an issue where some events were being discarded when merging
|
||||
|
||||
## v1.0.7
|
||||
|
||||
### Changed
|
||||
* The given access token is saved during auth refresh if none was available yet
|
||||
|
||||
|
||||
## v1.0.6
|
||||
|
||||
### Added
|
||||
* `ClientConfig.Timeout` to be able to configure the whole timeout of request
|
||||
* `ClientConfig.FirstReadTimeout` to be able to configure the timeout of request to the first byte
|
||||
* `ClientConfig.MinSpeed` to be able to configure the timeout when the connection is too slow (limitation in minimum bytes per second)
|
||||
* Set default timeouts for http.Transport with certificate pinning
|
||||
|
||||
### Changed
|
||||
* http.Client by default uses ProxyFromEnvironment to support HTTP_PROXY and HTTPS_PROXY environment variables
|
||||
|
||||
## v1.0.5
|
||||
|
||||
### Added
|
||||
* `ContentTypeMultipartEncrypted` MIME content type for encrypted email
|
||||
* `MessageCounts` in event struct
|
||||
|
||||
## v1.0.4
|
||||
|
||||
### Added
|
||||
* `PMKeys` for parsing and reading KeyRing
|
||||
* `clearableKey` to rewrite memory
|
||||
* Proton/backend-communication#25 Unlock with tokens (OneKey2RuleThemAll Phase I)
|
||||
|
||||
### Changed
|
||||
* Update of gopenpgp: convert JSON to KeyRing in PMAPI
|
||||
* `user.KeyRing` -> `user.KeyRing()`
|
||||
* typo `client.GetAddresses()`
|
||||
|
||||
### Removed
|
||||
* `address.KeyRing`
|
||||
|
||||
## v1.0.2 v1.0.3
|
||||
|
||||
### Changed
|
||||
* Fixed capitalisation in a few places
|
||||
* Added /metrics API route
|
||||
* Changed function names to be compliant with go linter
|
||||
* Encrypt with primary key only
|
||||
* Fix `client.doBuffered` - closing body before handling unauthorized request
|
||||
* go-pm-crypto -> GopenPGP
|
||||
* redefine old functions in `keyring.go`
|
||||
* `attachment.Decrypt` drops returning signature (does signature check by default)
|
||||
* `attachment.Encrypt` is using readers instead of writers
|
||||
* `attachment.DetachedSign` drops writer param and returns signature as a reader
|
||||
* `message.Decrypt` drops returning signature (does signature check by default)
|
||||
* Changed TLS report URL to https://reports.protonmail.ch/reports/tls
|
||||
* Moved from current to soon TLS pin
|
||||
|
||||
## v1.0.1
|
||||
|
||||
### Removed
|
||||
* `ClientID` from all auth routes
|
||||
* `ErrorDescription` from error
|
||||
|
||||
## v1.0.0
|
||||
|
||||
### Changed
|
||||
* `client.AuthInfo` does return 2FA information only when authenticated, for the first login information available in `Auth.HasTwoFactor`
|
||||
* `client.Auth` does not accept 2FA code in favor of `client.Auth2FA`
|
||||
* `client.Unlock` supports only new way of unlock with directly available access token
|
||||
|
||||
### Added
|
||||
* `Res.StatusCode` to pass HTTP status code to responses
|
||||
* `Auth.HasTwoFactor` method to determine whether account has enabled 2FA (same as `AuthInfo.HasTwoFactor`)
|
||||
* `Auth2FA*` structs for 2FA endpoint
|
||||
* `client.Auth2FA` method to fully unlock session with 2FA code
|
||||
* `ErrUnauthorized` when request cannot be authorized
|
||||
* `ErrBad2FACode` when bad 2FA and user cannot try again
|
||||
* `ErrBad2FACodeTryAgain` when bad 2FA but user can try again
|
||||
|
||||
## 2019-08-06
|
||||
|
||||
### Added
|
||||
* Send TLS issue report to API
|
||||
* Cert fingerpring with `TLSPinning` struct
|
||||
* Check API certificate fingerprint and verify hostname
|
||||
|
||||
### Changed
|
||||
* Using `AddressID` for `/messge/count` and `/conversations/count`
|
||||
* Less of copying of responses from the server in the memory
|
||||
|
||||
## 2019-08-01
|
||||
* low case for `sirupsen`
|
||||
* using go modules
|
||||
|
||||
## 2019-07-15
|
||||
|
||||
### Changed
|
||||
* `client.Auths` field is removed in favor of function `client.SetAuths` which opens possibility to use interface
|
||||
|
||||
## 2019-05-18
|
||||
|
||||
### Changed
|
||||
* proton/backend-communication#11 x-pm-uid sent always for `/auth/refresh`
|
||||
* proton/backend-communication#11 UID never changes
|
||||
|
||||
## 2019-05-28
|
||||
|
||||
### Added
|
||||
* New test server patern using callbacks
|
||||
* Responses are read from json files
|
||||
|
||||
### Changed
|
||||
* `auth_tests.go` to new callback server pattern
|
||||
* Linter fixes for tests
|
||||
|
||||
### Removed
|
||||
* `TestClient_Do_expired` due to no effect, use `DoUnauthorized` instead
|
||||
|
||||
## 2019-05-24
|
||||
* Help functions for test
|
||||
* CI with Lint
|
||||
|
||||
## 2019-05-23
|
||||
* Log userID
|
||||
|
||||
## 2019-05-21
|
||||
* Fix unlocking user keys
|
||||
|
||||
## 2019-04-25
|
||||
|
||||
### Changed
|
||||
* rename `Uid` -> `UID` proton/backend-communication#11
|
||||
|
||||
## 2019-04-09
|
||||
|
||||
### Added
|
||||
* sending attachments as zip `application/octet-stream`
|
||||
* function `ReportReq.AddAttachment()`
|
||||
* data memeber `ReportReq.Attachments`
|
||||
* general function to report bug `client.Report(req ReportReq)` with object as parameter
|
||||
|
||||
### Changed
|
||||
* `client.ReportBug` and `client.ReportBugWithClient` functions are obsolete and they uses `client.Report(req ReportReq)`
|
||||
* `client.ReportCrash` is obsolete. Use sentry instead
|
||||
* `Api`->`API`, `Uid`->`UID`
|
||||
|
||||
## 2019-03-13
|
||||
* user id in raven
|
||||
* add file position of panic sender
|
||||
|
||||
## 2019-03-06
|
||||
* #30 update `pm-crypto` to store `KeyRing.FirstKeyID`
|
||||
* #30 Add key salt to `Auth` object from `GetKeySalts` request
|
||||
* #30 Add route `GET /keys/salt`
|
||||
* removed unused `PmCrypto`
|
||||
|
||||
## 2019-02-20
|
||||
* removed unused `decryptAccessToken`
|
||||
|
||||
## 2019-01-21
|
||||
* #29 Parsing all goroutines from pprof
|
||||
* #29 Sentry `Threads` implementation
|
||||
* #29 using sentry for crashes
|
||||
|
||||
## 2019-01-07
|
||||
* refactor `pmapi.DecryptString` -> `pmcrypto.KeyRing.DecryptString`
|
||||
* fixed tests
|
||||
* `crypto` -> `pmcrypto`
|
||||
* refactoring code using repos `go-pm-crypto`, `go-pm-mime` and `go-srp`
|
||||
|
||||
|
||||
## 2018-12-10
|
||||
* #26 adding `Flags` field to message
|
||||
* #26 removing fields deprecated by `Flags`: `IsEncrypted`, `Type`, `IsReplied`, `IsRepliedAll`, `IsForwarded`
|
||||
* #26 removing deprecated consts (see #26 for replacement)
|
||||
* #26 fixing tests (compiling not working)
|
||||
|
||||
## 2018-11-19
|
||||
|
||||
### Added
|
||||
* Wait and retry from `DoJson` if banned from api
|
||||
|
||||
### Changed
|
||||
* `ErrNoInternet` -> `ErrAPINotReachable`
|
||||
* Adding codes for force upgrade: 5004 and 5005
|
||||
* Adding codes for API offline: 7001
|
||||
* Adding codes for BansRequests: 85131
|
||||
|
||||
## 2018-09-18
|
||||
|
||||
### Added
|
||||
* `client.decryptAccessToken` if privateKey is received (tested with local api) #23
|
||||
|
||||
### Changed
|
||||
* added fields to User
|
||||
* local config TLS skip verify
|
||||
|
||||
## 2018-09-06
|
||||
|
||||
### Changed
|
||||
* decrypt token only if needed
|
||||
|
||||
### Broken
|
||||
* Tests are not working
|
||||
|
||||
## APIv3 UPDATE (2018-08-01)
|
||||
* issue Desktop-Bridge#561
|
||||
|
||||
### Added
|
||||
* Key flag consts
|
||||
* `EventAddress`
|
||||
* `MailSettings` object and route call
|
||||
* `Client.KeyRingForAddressID`
|
||||
* `AuthInfo.HasTwoFactor()`
|
||||
* `Auth.HasMailboxPassword()`
|
||||
|
||||
### Changed
|
||||
* Addresses are part of client
|
||||
* Update user updates also addresses
|
||||
* `BodyKey` and `AttachmentKey` contains `Key` and `Algorithm`
|
||||
* `keyPair` (not use Pubkey) -> `pmKeyObject`
|
||||
* lots of indent
|
||||
* bugs route
|
||||
* two factor (ready to U2F)
|
||||
* Reorder some to match order in doc (easier to )
|
||||
* omit address Order when empty
|
||||
* update user and addresses in `CurrentUser()`
|
||||
* `User.Unlock()` -> `Client.UnlockAddresses()`
|
||||
* `AuthInfo.Uid` -> `AuthInfo.Uid()`
|
||||
* `User.Addresses` -> `Client.Addresses()`
|
||||
|
||||
### Removed
|
||||
* User v3 removed plenty (now in settings)
|
||||
* Message v3 removed plenty (Starred is label)
|
||||
@ -1,19 +0,0 @@
|
||||
export GO111MODULE=on
|
||||
|
||||
LINTVER="v1.21.0"
|
||||
LINTSRC="https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh"
|
||||
|
||||
check-has-go:
|
||||
@which go || (echo "Install Go-lang!" && exit 1)
|
||||
|
||||
install-dev-dependencies: install-linter
|
||||
|
||||
install-linter: check-has-go
|
||||
curl -sfL $(LINTSRC) | sh -s -- -b $(shell go env GOPATH)/bin $(LINTVER)
|
||||
|
||||
lint:
|
||||
which golangci-lint || $(MAKE) install-linter
|
||||
golangci-lint run ./... \
|
||||
|
||||
test:
|
||||
go test -run=${TESTRUN} ./...
|
||||
@ -18,10 +18,12 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// Address statuses.
|
||||
@ -80,11 +82,6 @@ type Address struct {
|
||||
// AddressList is a list of addresses.
|
||||
type AddressList []*Address
|
||||
|
||||
type AddressesRes struct {
|
||||
Res
|
||||
Addresses AddressList
|
||||
}
|
||||
|
||||
// ByID returns an address by id. Returns nil if no address is found.
|
||||
func (l AddressList) ByID(id string) *Address {
|
||||
for _, addr := range l {
|
||||
@ -164,40 +161,22 @@ func ConstructAddress(headerEmail string, addressEmail string) string {
|
||||
}
|
||||
|
||||
// GetAddresses requests all of current user addresses (without pagination).
|
||||
func (c *client) GetAddresses() (addresses AddressList, err error) {
|
||||
req, err := c.NewRequest("GET", "/addresses", nil)
|
||||
if err != nil {
|
||||
return
|
||||
func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err error) {
|
||||
var res struct {
|
||||
Addresses []*Address
|
||||
}
|
||||
|
||||
var res AddressesRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetResult(&res).Get("/addresses")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Addresses, res.Err()
|
||||
return res.Addresses, nil
|
||||
}
|
||||
|
||||
func (c *client) ReorderAddresses(addressIDs []string) (err error) {
|
||||
var reqBody struct {
|
||||
AddressIDs []string
|
||||
}
|
||||
|
||||
reqBody.AddressIDs = addressIDs
|
||||
|
||||
req, err := c.NewJSONRequest("PUT", "/addresses/order", reqBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var addContactsRes AddContactsResponse
|
||||
if err = c.DoJSON(req, &addContactsRes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.UpdateUser()
|
||||
|
||||
return
|
||||
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) (err error) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// Addresses returns the addresses stored in the client object itself rather than fetching from the API.
|
||||
|
||||
@ -21,13 +21,13 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type header textproto.MIMEHeader
|
||||
@ -138,12 +138,6 @@ func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io.
|
||||
return signAttachment(kr, att)
|
||||
}
|
||||
|
||||
type CreateAttachmentRes struct {
|
||||
Res
|
||||
|
||||
Attachment *Attachment
|
||||
}
|
||||
|
||||
func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) {
|
||||
// Create metadata fields.
|
||||
if err = w.WriteField("Filename", att.Name); err != nil {
|
||||
@ -185,91 +179,37 @@ func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.R
|
||||
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
|
||||
//
|
||||
// The returned created attachment contains the new attachment ID and its size.
|
||||
func (c *client) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (*Attachment, error) {
|
||||
req, w, err := c.NewMultipartRequest("POST", "/mail/v4/attachments")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (c *client) CreateAttachment(ctx context.Context, att *Attachment, attData io.Reader, sigData io.Reader) (*Attachment, error) {
|
||||
var res struct {
|
||||
Attachment *Attachment
|
||||
}
|
||||
|
||||
cx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
req = req.WithContext(cx)
|
||||
|
||||
// We will write the request as long as it is sent to the API.
|
||||
var res CreateAttachmentRes
|
||||
done := make(chan error, 1)
|
||||
go (func() {
|
||||
done <- c.DoJSON(req, &res)
|
||||
})()
|
||||
|
||||
if err := writeAttachment(w.Writer, att, r, sig); err != nil {
|
||||
_ = w.Close()
|
||||
return nil, err
|
||||
}
|
||||
_ = w.Close()
|
||||
|
||||
if err := <-done; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := res.Err(); err != nil {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetResult(&res).
|
||||
SetMultipartFormData(map[string]string{
|
||||
"Filename": att.Name,
|
||||
"MessageID": att.MessageID,
|
||||
"MIMEType": att.MIMEType,
|
||||
"ContentID": att.ContentID,
|
||||
}).
|
||||
SetMultipartField("DataPacket", "DataPacket.pgp", "application/octet-stream", attData).
|
||||
SetMultipartField("Signature", "Signature.pgp", "application/octet-stream", sigData).
|
||||
Post("/mail/v4/attachments")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Attachment, nil
|
||||
}
|
||||
|
||||
type UpdateAttachmentSignatureReq struct {
|
||||
Signature string
|
||||
}
|
||||
|
||||
func (c *client) UpdateAttachmentSignature(attachmentID, signature string) (err error) {
|
||||
updateReq := &UpdateAttachmentSignatureReq{signature}
|
||||
req, err := c.NewJSONRequest("PUT", "/mail/v4/attachments/"+attachmentID+"/signature", updateReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res Res
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteAttachment removes an attachment. message is the message ID, att is the attachment ID.
|
||||
func (c *client) DeleteAttachment(attID string) (err error) {
|
||||
req, err := c.NewRequest("DELETE", "/mail/v4/attachments/"+attID, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res Res
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = res.Err()
|
||||
return
|
||||
}
|
||||
|
||||
// GetAttachment gets an attachment's content. The returned data is encrypted.
|
||||
func (c *client) GetAttachment(id string) (att io.ReadCloser, err error) {
|
||||
if id == "" {
|
||||
err = errors.New("pmapi: cannot get an attachment with an empty id")
|
||||
return
|
||||
}
|
||||
|
||||
req, err := c.NewRequest("GET", "/mail/v4/attachments/"+id, nil)
|
||||
func (c *client) GetAttachment(ctx context.Context, attachmentID string) (att io.ReadCloser, err error) {
|
||||
res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetDoNotParseResponse(true).Get("/mail/v4/attachments/" + attachmentID)
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := c.Do(req, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
att = res.Body
|
||||
return
|
||||
return res.RawBody(), nil
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -94,7 +95,7 @@ func TestAttachment_UnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_CreateAttachment(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments"))
|
||||
|
||||
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
@ -136,12 +137,14 @@ func TestClient_CreateAttachment(t *testing.T) {
|
||||
t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testCreateAttachmentBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
|
||||
created, err := c.CreateAttachment(testAttachment, r, strings.NewReader(""))
|
||||
created, err := c.CreateAttachment(context.TODO(), testAttachment, r, strings.NewReader(""))
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while creating attachment, got:", err)
|
||||
}
|
||||
@ -151,34 +154,17 @@ func TestClient_CreateAttachment(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_DeleteAttachment(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "DELETE", "/mail/v4/attachments/"+testAttachment.ID))
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
if n, _ := b.ReadFrom(r.Body); n != 0 {
|
||||
t.Fatal("expected no body but have: ", b.String())
|
||||
}
|
||||
|
||||
fmt.Fprint(w, testDeleteAttachmentBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
err := c.DeleteAttachment(testAttachment.ID)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while deleting attachment, got:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetAttachment(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testAttachmentCleartext)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
r, err := c.GetAttachment(testAttachment.ID)
|
||||
r, err := c.GetAttachment(context.TODO(), testAttachment.ID)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting attachment, got:", err)
|
||||
}
|
||||
|
||||
@ -1,407 +1,47 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/srp"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
var ErrBad2FACode = errors.New("incorrect 2FA code")
|
||||
var ErrBad2FACodeTryAgain = errors.New("incorrect 2FA code: please try again")
|
||||
|
||||
type AuthInfoReq struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
type U2FInfo struct {
|
||||
Challenge string
|
||||
RegisteredKeys []struct {
|
||||
Version string
|
||||
KeyHandle string
|
||||
}
|
||||
}
|
||||
|
||||
type TwoFactorInfo struct {
|
||||
Enabled int // 0 for disabled, 1 for OTP, 2 for U2F, 3 for both.
|
||||
TOTP int
|
||||
U2F U2FInfo
|
||||
}
|
||||
|
||||
func (twoFactor *TwoFactorInfo) hasTwoFactor() bool {
|
||||
return twoFactor.Enabled > 0
|
||||
}
|
||||
|
||||
// AuthInfo contains data used when authenticating a user. It should be
|
||||
// provided to Client.Auth(). Each AuthInfo can be used for only one login attempt.
|
||||
type AuthInfo struct {
|
||||
TwoFA *TwoFactorInfo `json:"2FA,omitempty"`
|
||||
|
||||
version int
|
||||
salt string
|
||||
modulus string
|
||||
srpSession string
|
||||
serverEphemeral string
|
||||
}
|
||||
|
||||
func (a *AuthInfo) HasTwoFactor() bool {
|
||||
if a.TwoFA == nil {
|
||||
return false
|
||||
}
|
||||
return a.TwoFA.hasTwoFactor()
|
||||
}
|
||||
|
||||
type AuthInfoRes struct {
|
||||
Res
|
||||
AuthInfo
|
||||
|
||||
Modulus string
|
||||
ServerEphemeral string
|
||||
Version int
|
||||
Salt string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
func (res *AuthInfoRes) getAuthInfo() *AuthInfo {
|
||||
info := &res.AuthInfo
|
||||
|
||||
// Some fields in AuthInfo are private, so we need to copy them from AuthRes
|
||||
// (private fields cannot be populated by json).
|
||||
info.version = res.Version
|
||||
info.salt = res.Salt
|
||||
info.modulus = res.Modulus
|
||||
info.srpSession = res.SRPSession
|
||||
info.serverEphemeral = res.ServerEphemeral
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
type AuthReq struct {
|
||||
Username string
|
||||
ClientProof string
|
||||
ClientEphemeral string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
// Auth contains data after a successful authentication. It should be provided to Client.Unlock().
|
||||
type Auth struct {
|
||||
accessToken string // Read from AuthRes.
|
||||
ExpiresIn int
|
||||
uid string // Read from AuthRes.
|
||||
RefreshToken string
|
||||
EventID string
|
||||
PasswordMode int
|
||||
TwoFA *TwoFactorInfo `json:"2FA,omitempty"`
|
||||
}
|
||||
|
||||
// UID returns the session UID from the Auth.
|
||||
// Only Auths generated from the /auth route will have the UID.
|
||||
// Auths generated from /auth/refresh are not required to.
|
||||
func (s *Auth) UID() string {
|
||||
return s.uid
|
||||
}
|
||||
|
||||
// GenToken generates a string token containing the session UID and refresh token.
|
||||
func (s *Auth) GenToken() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *Auth) HasTwoFactor() bool {
|
||||
if s.TwoFA == nil {
|
||||
return false
|
||||
}
|
||||
return s.TwoFA.hasTwoFactor()
|
||||
}
|
||||
|
||||
func (s *Auth) HasMailboxPassword() bool {
|
||||
return s.PasswordMode == 2
|
||||
}
|
||||
|
||||
type AuthRes struct {
|
||||
Res
|
||||
Auth
|
||||
|
||||
AccessToken string
|
||||
TokenType string
|
||||
|
||||
// UID is the session UID. This is only present in an initial Auth (/auth), not in a refreshed Auth (/auth/refresh).
|
||||
UID string
|
||||
|
||||
ServerProof string
|
||||
}
|
||||
|
||||
func (res *AuthRes) getAuth() *Auth {
|
||||
auth := &res.Auth
|
||||
|
||||
// Some fields in Auth are private, so we need to copy them from AuthRes
|
||||
// (private fields cannot be populated by json).
|
||||
auth.accessToken = res.AccessToken
|
||||
auth.uid = res.UID
|
||||
|
||||
return auth
|
||||
}
|
||||
|
||||
type Auth2FAReq struct {
|
||||
TwoFactorCode string
|
||||
|
||||
// Prepared for U2F:
|
||||
// U2F U2FRequest
|
||||
}
|
||||
|
||||
type Auth2FARes struct {
|
||||
Res
|
||||
}
|
||||
|
||||
type AuthRefreshReq struct {
|
||||
ResponseType string
|
||||
GrantType string
|
||||
RefreshToken string
|
||||
UID string
|
||||
RedirectURI string
|
||||
State string
|
||||
}
|
||||
|
||||
func (c *client) sendAuth(auth *Auth) {
|
||||
if auth != nil {
|
||||
c.log.WithField("auth", *auth).Debug("Client is sending auth to ClientManager")
|
||||
} else {
|
||||
c.log.Debug("Client is sending nil auth to ClientManager")
|
||||
}
|
||||
|
||||
if auth != nil {
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
}
|
||||
|
||||
c.cm.HandleAuth(ClientAuth{UserID: c.userID, Auth: auth})
|
||||
}
|
||||
|
||||
// AuthInfo gets authentication info for a user.
|
||||
func (c *client) AuthInfo(username string) (info *AuthInfo, err error) {
|
||||
infoReq := &AuthInfoReq{
|
||||
Username: username,
|
||||
}
|
||||
|
||||
req, err := c.NewJSONRequest("POST", "/auth/info", infoReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var infoRes AuthInfoRes
|
||||
if err = c.DoJSON(req, &infoRes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
info, err = infoRes.getAuthInfo(), infoRes.Err()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func srpProofsFromInfo(info *AuthInfo, username, password string, fallbackVersion int) (proofs *srp.SrpProofs, err error) {
|
||||
version := info.version
|
||||
if version == 0 {
|
||||
version = fallbackVersion
|
||||
}
|
||||
|
||||
srpAuth, err := srp.NewSrpAuth(version, username, password, info.salt, info.modulus, info.serverEphemeral)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
proofs, err = srpAuth.GenerateSrpProofs(2048)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) tryAuth(username, password string, info *AuthInfo, fallbackVersion int) (res *AuthRes, err error) {
|
||||
proofs, err := srpProofsFromInfo(info, username, password, fallbackVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
authReq := &AuthReq{
|
||||
Username: username,
|
||||
ClientEphemeral: base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
|
||||
ClientProof: base64.StdEncoding.EncodeToString(proofs.ClientProof),
|
||||
SRPSession: info.srpSession,
|
||||
}
|
||||
|
||||
req, err := c.NewJSONRequest("POST", "/auth", authReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var authRes AuthRes
|
||||
if err = c.DoJSON(req, &authRes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = authRes.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
serverProof, err := base64.StdEncoding.DecodeString(authRes.ServerProof)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(proofs.ExpectedServerProof, serverProof) != 1 {
|
||||
return nil, errors.New("pmapi: bad server proof")
|
||||
}
|
||||
|
||||
res, err = &authRes, authRes.Err()
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (c *client) tryFullAuth(username, password string, fallbackVersion int) (info *AuthInfo, authRes *AuthRes, err error) {
|
||||
info, err = c.AuthInfo(username)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
authRes, err = c.tryAuth(username, password, info, fallbackVersion)
|
||||
return
|
||||
}
|
||||
|
||||
// Auth will authenticate a user.
|
||||
func (c *client) Auth(username, password string, info *AuthInfo) (auth *Auth, err error) {
|
||||
if info == nil {
|
||||
if info, err = c.AuthInfo(username); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
authRes, err := c.tryAuth(username, password, info, 2)
|
||||
if err != nil && info.version == 0 && srp.CleanUserName(username) != strings.ToLower(username) {
|
||||
info, authRes, err = c.tryFullAuth(username, password, 1)
|
||||
}
|
||||
if err != nil && info.version == 0 {
|
||||
_, authRes, err = c.tryFullAuth(username, password, 0)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
auth = authRes.getAuth()
|
||||
c.sendAuth(auth)
|
||||
|
||||
return auth, err
|
||||
}
|
||||
|
||||
// Auth2FA will authenticate a user into full scope.
|
||||
// `Auth` struct contains method `HasTwoFactor` deciding whether this has to be done.
|
||||
func (c *client) Auth2FA(twoFactorCode string, auth *Auth) error {
|
||||
auth2FAReq := &Auth2FAReq{
|
||||
TwoFactorCode: twoFactorCode,
|
||||
}
|
||||
|
||||
req, err := c.NewJSONRequest("POST", "/auth/2fa", auth2FAReq)
|
||||
if err != nil {
|
||||
func (c *client) Auth2FA(ctx context.Context, req Auth2FAReq) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(req).Post("/auth/2fa")
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var auth2FARes Auth2FARes
|
||||
if err := c.DoJSON(req, &auth2FARes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := auth2FARes.Err(); err != nil {
|
||||
switch auth2FARes.StatusCode {
|
||||
case http.StatusUnauthorized:
|
||||
return ErrBad2FACode
|
||||
case http.StatusUnprocessableEntity:
|
||||
return ErrBad2FACodeTryAgain
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthRefresh will refresh an expired access token.
|
||||
func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
||||
c.refreshLocker.Lock()
|
||||
defer c.refreshLocker.Unlock()
|
||||
|
||||
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
||||
// That way we can try again later (see handleUnauthorizedStatus).
|
||||
c.cm.setTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||
|
||||
split := strings.Split(uidAndRefreshToken, ":")
|
||||
if len(split) != 2 {
|
||||
err = ErrInvalidToken
|
||||
return
|
||||
func (c *client) AuthDelete(ctx context.Context) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.Delete("/auth")
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
refreshReq := &AuthRefreshReq{
|
||||
ResponseType: "token",
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: split[1],
|
||||
UID: split[0],
|
||||
RedirectURI: "https://protonmail.ch",
|
||||
State: "random_string",
|
||||
}
|
||||
c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
|
||||
|
||||
// UID must be set for `x-pm-uid` header field, see backend-communication#11
|
||||
c.uid = split[0]
|
||||
// FIXME(conman): should we perhaps signal via AuthHandler that the auth was deleted?
|
||||
|
||||
req, err := c.NewJSONRequest("POST", "/auth/refresh", refreshReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res AuthRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
if err = res.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
auth = res.getAuth()
|
||||
|
||||
// Responses from /auth/refresh are not guaranteed to return the UID if it has not changed.
|
||||
// But we want to always return it.
|
||||
if auth.uid == "" {
|
||||
auth.uid = c.uid
|
||||
}
|
||||
|
||||
c.sendAuth(auth)
|
||||
|
||||
return auth, err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) AuthSalt() (string, error) {
|
||||
salts, err := c.GetKeySalts()
|
||||
func (c *client) AuthSalt(ctx context.Context) (string, error) {
|
||||
salts, err := c.GetKeySalts(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := c.CurrentUser(); err != nil {
|
||||
if _, err := c.CurrentUser(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@ -414,40 +54,37 @@ func (c *client) AuthSalt() (string, error) {
|
||||
return "", errors.New("no matching salt found")
|
||||
}
|
||||
|
||||
// Logout instructs the client manager to log this client out.
|
||||
func (c *client) Logout() {
|
||||
c.cm.LogoutClient(c.userID)
|
||||
func (c *client) AddAuthHandler(handler AuthHandler) {
|
||||
c.authHandlers = append(c.authHandlers, handler)
|
||||
}
|
||||
|
||||
// DeleteAuth deletes the API session.
|
||||
func (c *client) DeleteAuth() (err error) {
|
||||
req, err := c.NewRequest("DELETE", "/auth", nil)
|
||||
func (c *client) authRefresh(ctx context.Context) error {
|
||||
c.authLocker.Lock()
|
||||
defer c.authLocker.Unlock()
|
||||
|
||||
auth, err := c.req.authRefresh(ctx, c.uid, c.ref)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
var res Res
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
c.acc = auth.AccessToken
|
||||
c.ref = auth.RefreshToken
|
||||
|
||||
for _, handler := range c.authHandlers {
|
||||
if err := handler(auth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = res.Err(); err != nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomString(length int) string {
|
||||
noise := make([]byte, length)
|
||||
|
||||
if _, err := io.ReadFull(rand.Reader, noise); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsConnected returns whether the client is authorized to access the API.
|
||||
func (c *client) IsConnected() bool {
|
||||
return c.uid != "" && c.accessToken != ""
|
||||
}
|
||||
|
||||
// ClearData clears sensitive data from the client.
|
||||
func (c *client) ClearData() {
|
||||
c.uid = ""
|
||||
c.accessToken = ""
|
||||
c.addresses = nil
|
||||
c.user = nil
|
||||
c.clearKeys()
|
||||
return base64.StdEncoding.EncodeToString(noise)[:length]
|
||||
}
|
||||
|
||||
@ -1,351 +1,135 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
package pmapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/srp"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
var testIdentity = &crypto.Identity{
|
||||
Name: "UserID",
|
||||
Email: "",
|
||||
func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
|
||||
// Make a request with an access token that already expired one second ago.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
testUsername = "jason"
|
||||
testAPIPassword = "apple"
|
||||
func Test401AuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
}
|
||||
|
||||
testUID = "729ad6012421d67ad26950dc898bebe3a6e3caa2" //nolint[gosec]
|
||||
testAccessToken = "de0423049b44243afeec7d9c1d99be7b46da1e8a" //nolint[gosec]
|
||||
testAccessTokenOld = "feb3159ac63fb05119bcf4480d939278aa746926" //nolint[gosec]
|
||||
testRefreshToken = "a49b98256745bb497bec20e9b55f5de16f01fb52" //nolint[gosec]
|
||||
testRefreshTokenNew = "b894b4c4f20003f12d486900d8b88c7d68e67235" //nolint[gosec]
|
||||
)
|
||||
mux := http.NewServeMux()
|
||||
|
||||
var testAuthInfo = &AuthInfo{
|
||||
TwoFA: &TwoFactorInfo{TOTP: 1},
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
version: 4,
|
||||
salt: "yKlc5/CvObfoiw==",
|
||||
modulus: "-----BEGIN PGP SIGNED MESSAGE-----\nHash: SHA256\n\nW2z5HBi8RvsfYzZTS7qBaUxxPhsfHJFZpu3Kd6s1JafNrCCH9rfvPLrfuqocxWPgWDH2R8neK7PkNvjxto9TStuY5z7jAzWRvFWN9cQhAKkdWgy0JY6ywVn22+HFpF4cYesHrqFIKUPDMSSIlWjBVmEJZ/MusD44ZT29xcPrOqeZvwtCffKtGAIjLYPZIEbZKnDM1Dm3q2K/xS5h+xdhjnndhsrkwm9U9oyA2wxzSXFL+pdfj2fOdRwuR5nW0J2NFrq3kJjkRmpO/Genq1UW+TEknIWAb6VzJJJA244K/H8cnSx2+nSNZO3bbo6Ys228ruV9A8m6DhxmS+bihN3ttQ==\n-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwl4EARYIABAFAlwB1j0JEDUFhcTpUY8mAAD8CgEAnsFnF4cF0uSHKkXa1GIa\nGO86yMV4zDZEZcDSJo0fgr8A/AlupGN9EdHlsrZLmTA1vhIx+rOgxdEff28N\nkvNM7qIK\n=q6vu\n-----END PGP SIGNATURE-----\n",
|
||||
srpSession: "9b2946bbd9055f17c34940abdce0c3d3",
|
||||
serverEphemeral: "5tfigcLKoM0DPWYB+EqYE7QlqsiT63iOVlO5ZX0lTMEILSsrRdVCYrN8L3zkinsAjUZ/cx5wIS7N05k66uZb+ZE3lFOJS2s1BkqLvCrGxYL0e3n5YAnzHYlvCCJKXw/sK57ntfF1OOoblBXX6dw5LjeeDglEep2/DaE0TjD8WUpq4Ls2HlQGn9wrC7dFO2lJXsMhRffxKghiOsdvCLXDmwXginzn/LFezA8KrDsWOBSEGntwpg3s1xFj5h8BqtRHvC0igmoscqgw+3GCMTJ0NZAQ/L+5aJ/0ccL0WBK208ltCNl+/X6Sz0kpyvOP4RqFJhC1auVDJ9AjZQYSYZ1NEQ==",
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
var call int
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
call++
|
||||
|
||||
if call == 1 {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
|
||||
// The first request will fail with 401, triggering a refresh and retry.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
}
|
||||
|
||||
// testAuth has default values which are adjusted in each test.
|
||||
var testAuth = &Auth{
|
||||
EventID: "NcKPtU5eMNPMrDkIMbEJrgMtC9yQ7Xc5ZBT-tB3UtV1rZ324RWfCIdBI758q0UnsfywS8CkNenIQlWLIX_dUng==",
|
||||
ExpiresIn: 86400,
|
||||
RefreshToken: "feb3159ac63fb05119bcf4480d939278aa746926",
|
||||
func Test401RevokedAuth(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
accessToken: testAccessToken,
|
||||
uid: testUID,
|
||||
}
|
||||
|
||||
var testAuthRefreshReq = AuthRefreshReq{
|
||||
ResponseType: "token",
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: testRefreshToken,
|
||||
UID: testUID,
|
||||
RedirectURI: "https://protonmail.ch",
|
||||
State: "random_string",
|
||||
}
|
||||
|
||||
var testAuthReq = AuthReq{
|
||||
Username: testUsername,
|
||||
ClientProof: "axfvYdl9iXZjY6zQ+hBYmY7X3TDc/9JtSvrmyZXhDxjxkXB3Hro27t1KItmFIJloItY5sLZDs0eEEZJI34oFZD4ViSG0kfB7ZXcCZ9Jse+U5OFu4vdnPTGolnSofRMEs1NR6ePXzH7mQ10qoq43ity3ve2vmhQNuJNlHAPynKf2WqKOgxq7mmkBzEpXES4mIhwwgVbOygKcUSvguz5E5g13ATF0ZX2d9SJWAbZ262Tks+h99Cdk/dOfgLQhr0nO/r0cpwP84W2RWU2Q34LNkKuuQHkjmxelgBleGq54tCbhoCAYPP6vapgrQjNoVAC/dkjIIAoNL9bJSIynFM5znAA==",
|
||||
ClientEphemeral: "mK+eSMosfZO/Cs5s+vcbjpsN7F8UAObwlKKnCy/z9FpoMRM2PfTe5ywLBgffmLYaapPq7XOxaqaj08kcZLHcM1fIA2JQZZTKPnESN1qAQztJ3/YHMI0op6yBgzx9803OjIznjCD2B3XBSMOHIG4oG0UwocsIX32hiMnYlMMkt8NGrityPlnmEbxpRna3fu9LEZ+v0uo6PjKCrO7+9E3uaMi64HadXBfyx2raBFFwA+yh7FvE7U+hl3AJclEre4d8pmfhMdxXze1soJI8fMuqaa07rY0r0rF5mLLTuqTIGRFkU1qG9loq9+IMsSwgkt1P3ghW63JK7Y6LWdDy0d6cAg==",
|
||||
SRPSession: "9b2946bbd9055f17c34940abdce0c3d3",
|
||||
}
|
||||
|
||||
var testAuth2FAReq = Auth2FAReq{
|
||||
TwoFactorCode: "424242",
|
||||
}
|
||||
|
||||
func init() {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
srp.RandReader = rand.New(rand.NewSource(42))
|
||||
}
|
||||
|
||||
func TestClient_AuthInfo(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/auth/info"))
|
||||
|
||||
var infoReq AuthInfoReq
|
||||
Ok(t, json.NewDecoder(r.Body).Decode(&infoReq))
|
||||
Equals(t, infoReq.Username, testUsername)
|
||||
|
||||
return "/auth/info/post_response.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
info, err := c.AuthInfo(testCurrentUser.Name)
|
||||
Ok(t, err)
|
||||
Equals(t, testAuthInfo, info)
|
||||
}
|
||||
|
||||
// TestClient_Auth reflects changes from proton/backend-communcation#3.
|
||||
func TestClient_Auth(t *testing.T) {
|
||||
srp.RandReader = rand.New(rand.NewSource(42))
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
a.Nil(t, checkMethodAndPath(req, "POST", "/auth"))
|
||||
|
||||
var authReq AuthReq
|
||||
r.Nil(t, json.NewDecoder(req.Body).Decode(&authReq))
|
||||
r.Equal(t, testAuthReq, authReq)
|
||||
|
||||
return "/auth/post_response.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
auth, err := c.Auth(testUsername, testAPIPassword, testAuthInfo)
|
||||
r.Nil(t, err)
|
||||
|
||||
exp := &Auth{}
|
||||
*exp = *testAuth
|
||||
exp.accessToken = testAccessToken
|
||||
exp.RefreshToken = testRefreshToken
|
||||
a.Equal(t, exp, auth)
|
||||
}
|
||||
|
||||
func TestClient_Auth2FA(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/auth/2fa"))
|
||||
|
||||
var info2FAReq Auth2FAReq
|
||||
Ok(t, json.NewDecoder(r.Body).Decode(&info2FAReq))
|
||||
Equals(t, info2FAReq.TwoFactorCode, testAuth2FAReq.TwoFactorCode)
|
||||
|
||||
return "/auth/2fa/post_response.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessToken
|
||||
err := c.Auth2FA(testAuth2FAReq.TwoFactorCode, testAuth)
|
||||
Ok(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Auth2FA_Fail(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/auth/2fa"))
|
||||
|
||||
var info2FAReq Auth2FAReq
|
||||
Ok(t, json.NewDecoder(r.Body).Decode(&info2FAReq))
|
||||
Equals(t, info2FAReq.TwoFactorCode, testAuth2FAReq.TwoFactorCode)
|
||||
|
||||
return "/auth/2fa/post_401_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessToken
|
||||
err := c.Auth2FA(testAuth2FAReq.TwoFactorCode, testAuth)
|
||||
Equals(t, ErrBad2FACode, err)
|
||||
}
|
||||
|
||||
func TestClient_Auth2FA_Retry(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/auth/2fa"))
|
||||
|
||||
var info2FAReq Auth2FAReq
|
||||
Ok(t, json.NewDecoder(r.Body).Decode(&info2FAReq))
|
||||
Equals(t, info2FAReq.TwoFactorCode, testAuth2FAReq.TwoFactorCode)
|
||||
|
||||
return "/auth/2fa/post_422_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessToken
|
||||
err := c.Auth2FA(testAuth2FAReq.TwoFactorCode, testAuth)
|
||||
Equals(t, ErrBad2FACodeTryAgain, err)
|
||||
}
|
||||
|
||||
func TestClient_Unlock(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
routeGetUsers,
|
||||
routeGetAddresses,
|
||||
)
|
||||
defer finish()
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessToken
|
||||
|
||||
err := c.Unlock([]byte("wrong"))
|
||||
a.Error(t, err, "expected error, pasword is wrong")
|
||||
|
||||
err = c.Unlock([]byte(testMailboxPassword))
|
||||
a.Nil(t, err)
|
||||
a.Equal(t, testUID, c.uid)
|
||||
a.Equal(t, testAccessToken, c.accessToken)
|
||||
|
||||
// second try should not fail because there is an unlocked key already
|
||||
err = c.Unlock([]byte("wrong"))
|
||||
a.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Unlock_EncPrivKey(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
routeGetUsers,
|
||||
routeGetAddresses,
|
||||
)
|
||||
defer finish()
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessToken
|
||||
|
||||
err := c.Unlock([]byte(testMailboxPassword))
|
||||
Ok(t, err)
|
||||
Equals(t, testUID, c.uid)
|
||||
Equals(t, testAccessToken, c.accessToken)
|
||||
}
|
||||
|
||||
func routeAuthRefresh(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(tb, checkMethodAndPath(r, "POST", "/auth/refresh"))
|
||||
Ok(tb, checkHeader(r.Header, "x-pm-uid", testUID))
|
||||
|
||||
var refreshReq AuthRefreshReq
|
||||
Ok(tb, json.NewDecoder(r.Body).Decode(&refreshReq))
|
||||
Equals(tb, testAuthRefreshReq, refreshReq)
|
||||
|
||||
return "/auth/refresh/post_response.json"
|
||||
}
|
||||
|
||||
// TestClient_AuthRefresh reflects changes from proton/backend-communcation#11.
|
||||
func TestClient_AuthRefresh(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
routeAuthRefresh,
|
||||
)
|
||||
defer finish()
|
||||
c.uid = "" // Testing that we always send correct `x-pm-uid`.
|
||||
c.accessToken = "oldToken"
|
||||
|
||||
auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken)
|
||||
Ok(t, err)
|
||||
Equals(t, testUID, c.uid)
|
||||
|
||||
exp := &Auth{}
|
||||
*exp = *testAuth
|
||||
exp.uid = testUID // AuthRefresh will not return UID (only Auth returns the UID) we should set testUID to be able to generate token, see `GetToken`
|
||||
exp.accessToken = testAccessToken
|
||||
exp.EventID = ""
|
||||
exp.ExpiresIn = 360000
|
||||
exp.RefreshToken = testRefreshTokenNew
|
||||
Equals(t, exp, auth)
|
||||
}
|
||||
|
||||
func routeAuthRefreshHasUID(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(tb, checkMethodAndPath(r, "POST", "/auth/refresh"))
|
||||
Ok(tb, checkHeader(r.Header, "x-pm-uid", testUID))
|
||||
|
||||
var refreshReq AuthRefreshReq
|
||||
Ok(tb, json.NewDecoder(r.Body).Decode(&refreshReq))
|
||||
Equals(tb, testAuthRefreshReq, refreshReq)
|
||||
|
||||
return "/auth/refresh/post_resp_has_uid.json"
|
||||
}
|
||||
|
||||
// TestClient_AuthRefresh reflects changes from proton/backend-communcation#3.
|
||||
func TestClient_AuthRefresh_HasUID(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
routeAuthRefreshHasUID,
|
||||
)
|
||||
defer finish()
|
||||
c.uid = testUID
|
||||
c.accessToken = "oldToken"
|
||||
|
||||
auth, err := c.AuthRefresh(testUID + ":" + testRefreshToken)
|
||||
Ok(t, err)
|
||||
|
||||
exp := &Auth{}
|
||||
*exp = *testAuth
|
||||
exp.accessToken = testAccessToken
|
||||
exp.EventID = ""
|
||||
exp.ExpiresIn = 360000
|
||||
exp.RefreshToken = testRefreshTokenNew
|
||||
Equals(t, exp, auth)
|
||||
}
|
||||
|
||||
func TestClient_Logout(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(t, checkMethodAndPath(r, "DELETE", "/auth"))
|
||||
Ok(t, isAuthReq(r, testUID, testAccessToken))
|
||||
return "auth/delete_response.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessToken
|
||||
|
||||
c.Logout()
|
||||
|
||||
r.Eventually(t, func() bool {
|
||||
return c.IsConnected() == false && c.userKeyRing == nil && c.addresses == nil && c.user == nil
|
||||
}, 10*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestClient_DoUnauthorized(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/"))
|
||||
return httpResponse(http.StatusUnauthorized)
|
||||
},
|
||||
routeAuthRefresh,
|
||||
func(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/"))
|
||||
Ok(t, isAuthReq(r, testUID, testAccessToken))
|
||||
return httpResponse(http.StatusOK)
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessTokenOld
|
||||
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
|
||||
|
||||
req, err := c.NewRequest("GET", "/", nil)
|
||||
Ok(t, err)
|
||||
|
||||
res, err := c.Do(req, true)
|
||||
Ok(t, err)
|
||||
|
||||
defer Ok(t, res.Body.Close())
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
||||
t.Fatal("expected error to be ErrUnauthorized, instead got", err)
|
||||
}
|
||||
}
|
||||
|
||||
72
pkg/pmapi/auth_types.go
Normal file
72
pkg/pmapi/auth_types.go
Normal file
@ -0,0 +1,72 @@
|
||||
package pmapi
|
||||
|
||||
type AuthModulus struct {
|
||||
Modulus string
|
||||
ModulusID string
|
||||
}
|
||||
|
||||
type GetAuthInfoReq struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
type AuthInfo struct {
|
||||
Version int
|
||||
Modulus string
|
||||
ServerEphemeral string
|
||||
Salt string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type TwoFAInfo struct {
|
||||
Enabled TwoFAStatus
|
||||
}
|
||||
|
||||
type TwoFAStatus int
|
||||
|
||||
const (
|
||||
TwoFADisabled TwoFAStatus = iota
|
||||
TOTPEnabled
|
||||
// TODO: Support UTF
|
||||
)
|
||||
|
||||
type PasswordMode int
|
||||
|
||||
const (
|
||||
OnePasswordMode PasswordMode = iota + 1
|
||||
TwoPasswordMode
|
||||
)
|
||||
|
||||
type AuthReq struct {
|
||||
Username string
|
||||
ClientProof string
|
||||
ClientEphemeral string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
UserID string
|
||||
|
||||
UID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresIn int64
|
||||
|
||||
Scope string
|
||||
ServerProof string
|
||||
|
||||
TwoFA TwoFAInfo `json:"2FA"`
|
||||
PasswordMode PasswordMode
|
||||
}
|
||||
|
||||
type Auth2FAReq struct {
|
||||
TwoFactorCode string
|
||||
}
|
||||
|
||||
type AuthRefreshReq struct {
|
||||
UID string
|
||||
RefreshToken string
|
||||
ResponseType string
|
||||
GrantType string
|
||||
RedirectURI string
|
||||
State string
|
||||
}
|
||||
@ -1,94 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const protonStatusURL = "http://protonstatus.com/vpn_status"
|
||||
|
||||
// ErrNoInternetConnection indicates that both protonstatus and the API are unreachable.
|
||||
var ErrNoInternetConnection = errors.New("no internet connection")
|
||||
|
||||
// CheckConnection returns an error if there is no internet connection.
|
||||
// This should be moved to the ConnectionManager when it is implemented.
|
||||
func (cm *ClientManager) CheckConnection() error {
|
||||
// We use a normal dialer here which doesn't check tls fingerprints.
|
||||
client := &http.Client{Timeout: time.Second * 10}
|
||||
|
||||
// Do not cumulate timeouts, use goroutines.
|
||||
retStatus := make(chan error)
|
||||
retAPI := make(chan error)
|
||||
|
||||
// vpn_status endpoint is fast and returns only OK. We check the connection only.
|
||||
go checkConnection(client, protonStatusURL, retStatus)
|
||||
|
||||
// Check of API reachability also uses a fast endpoint.
|
||||
go checkConnection(client, cm.GetRootURL()+"/tests/ping", retAPI)
|
||||
|
||||
errStatus := <-retStatus
|
||||
errAPI := <-retAPI
|
||||
|
||||
switch {
|
||||
case errStatus == nil && errAPI == nil:
|
||||
return nil
|
||||
|
||||
case errStatus == nil && errAPI != nil:
|
||||
cm.log.Error("ProtonStatus is reachable but API is not")
|
||||
return ErrAPINotReachable
|
||||
|
||||
case errStatus != nil && errAPI == nil:
|
||||
cm.log.Warn("API is reachable but protonstatus is not")
|
||||
return nil
|
||||
|
||||
case errStatus != nil && errAPI != nil:
|
||||
cm.log.Error("Both ProtonStatus and API are unreachable")
|
||||
return ErrNoInternetConnection
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckConnection returns an error if there is no internet connection.
|
||||
func CheckConnection() error {
|
||||
client := &http.Client{Timeout: time.Second * 10}
|
||||
retStatus := make(chan error)
|
||||
go checkConnection(client, protonStatusURL, retStatus)
|
||||
return <-retStatus
|
||||
}
|
||||
|
||||
func checkConnection(client *http.Client, url string, errorChannel chan error) {
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
errorChannel <- err
|
||||
return
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
errorChannel <- fmt.Errorf("HTTP status code %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
errorChannel <- nil
|
||||
}
|
||||
@ -1,91 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/dialer"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testServerPort = "18000"
|
||||
const testRequestTimeout = 10 * time.Second
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
go startServer()
|
||||
time.Sleep(100 * time.Millisecond) // We need to wait till server is fully running.
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func startServer() {
|
||||
http.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("OK"))
|
||||
})
|
||||
http.HandleFunc("/timeout", func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(testRequestTimeout + time.Second) // Add extra second to be sure it will timeout.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("OK"))
|
||||
})
|
||||
http.HandleFunc("/serverError", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
})
|
||||
panic(http.ListenAndServe(":"+testServerPort, nil))
|
||||
}
|
||||
|
||||
func TestCheckConnection(t *testing.T) {
|
||||
checkCheckConnection(t, "ok", "")
|
||||
}
|
||||
|
||||
func TestCheckConnectionTimeout(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
checkCheckConnection(t, "timeout", "Client.Timeout exceeded while awaiting headers")
|
||||
}
|
||||
|
||||
func TestCheckConnectionServerError(t *testing.T) {
|
||||
checkCheckConnection(t, "serverError", "HTTP status code 500")
|
||||
}
|
||||
|
||||
func checkCheckConnection(t *testing.T, path string, expectedErrMessage string) {
|
||||
client := dialer.DialTimeoutClient()
|
||||
client.Timeout = testRequestTimeout
|
||||
|
||||
ch := make(chan error)
|
||||
|
||||
go checkConnection(client, "http://localhost:"+testServerPort+"/"+path, ch)
|
||||
|
||||
timeout := time.After(testRequestTimeout + time.Second)
|
||||
select {
|
||||
case err := <-ch:
|
||||
if expectedErrMessage == "" {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err, expectedErrMessage)
|
||||
}
|
||||
case <-timeout:
|
||||
t.Error("checkConnection timeout failed")
|
||||
}
|
||||
}
|
||||
@ -18,97 +18,23 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/jaytaylor/html2text"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Version of the API.
|
||||
const Version = 3
|
||||
|
||||
// API return codes.
|
||||
const (
|
||||
ForceUpgradeBadAppVersion = 5003
|
||||
APIOffline = 7001
|
||||
ImportMessageTooLong = 36022
|
||||
BansRequests = 85131
|
||||
)
|
||||
|
||||
// The output errors.
|
||||
var (
|
||||
ErrInvalidToken = errors.New("refresh token invalid")
|
||||
ErrAPINotReachable = errors.New("cannot reach the server")
|
||||
ErrUpgradeApplication = errors.New("application upgrade required")
|
||||
ErrConnectionSlow = errors.New("request canceled because connection speed was too slow")
|
||||
)
|
||||
|
||||
type ErrUnprocessableEntity struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (err *ErrUnprocessableEntity) Error() string {
|
||||
return err.error.Error()
|
||||
}
|
||||
|
||||
type ErrUnauthorized struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (err *ErrUnauthorized) Error() string {
|
||||
return fmt.Sprintf("unauthorized access: %+v", err.error.Error())
|
||||
}
|
||||
|
||||
// ClientConfig contains Client configuration.
|
||||
type ClientConfig struct {
|
||||
// The client application name and version.
|
||||
AppVersion string
|
||||
|
||||
// The client ID.
|
||||
ClientID string
|
||||
|
||||
// Timeout is the timeout of the full request. It is passed to http.Client.
|
||||
// If it is left unset, it means no timeout is applied.
|
||||
Timeout time.Duration
|
||||
|
||||
// FirstReadTimeout specifies the timeout from getting response to the first read of body response.
|
||||
// This timeout is applied only when MinBytesPerSecond is used.
|
||||
// Default is 5 minutes.
|
||||
FirstReadTimeout time.Duration
|
||||
|
||||
// MinBytesPerSecond specifies minimum Bytes per second or the request will be canceled.
|
||||
// Zero means no limitation.
|
||||
MinBytesPerSecond int64
|
||||
|
||||
ConnectionOnHandler func()
|
||||
ConnectionOffHandler func()
|
||||
UpgradeApplicationHandler func()
|
||||
}
|
||||
|
||||
// client is a client of the protonmail API. It implements the Client interface.
|
||||
type client struct {
|
||||
cm *ClientManager
|
||||
hc *http.Client
|
||||
req requester
|
||||
|
||||
uid string
|
||||
accessToken string
|
||||
userID string
|
||||
requestLocker sync.Locker
|
||||
refreshLocker sync.Locker
|
||||
uid, acc, ref string
|
||||
authHandlers []AuthHandler
|
||||
authLocker sync.RWMutex
|
||||
|
||||
user *User
|
||||
addresses AddressList
|
||||
@ -116,404 +42,79 @@ type client struct {
|
||||
addrKeyRing map[string]*crypto.KeyRing
|
||||
keyRingLock sync.Locker
|
||||
|
||||
log *logrus.Entry
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
// newClient creates a new API client.
|
||||
func newClient(cm *ClientManager, userID string) *client {
|
||||
func newClient(req requester, uid string) *client {
|
||||
return &client{
|
||||
cm: cm,
|
||||
hc: getHTTPClient(cm.config, cm.roundTripper, cm.cookieJar),
|
||||
userID: userID,
|
||||
requestLocker: &sync.Mutex{},
|
||||
refreshLocker: &sync.Mutex{},
|
||||
keyRingLock: &sync.Mutex{},
|
||||
addrKeyRing: make(map[string]*crypto.KeyRing),
|
||||
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
||||
req: req,
|
||||
uid: uid,
|
||||
addrKeyRing: make(map[string]*crypto.KeyRing),
|
||||
keyRingLock: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// getHTTPClient returns a http client configured by the given client config and using the given transport.
|
||||
func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper, jar http.CookieJar) (hc *http.Client) {
|
||||
return &http.Client{
|
||||
Transport: rt,
|
||||
Jar: jar,
|
||||
Timeout: cfg.Timeout,
|
||||
}
|
||||
func (c *client) withAuth(acc, ref string, exp time.Time) *client {
|
||||
c.acc = acc
|
||||
c.ref = ref
|
||||
c.exp = exp
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *client) IsUnlocked() bool {
|
||||
return c.userKeyRing != nil
|
||||
}
|
||||
|
||||
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
|
||||
// If the keyrings are already present, they are not recreated.
|
||||
func (c *client) Unlock(passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
return c.unlock(passphrase)
|
||||
}
|
||||
|
||||
// unlock unlocks the user's keys but without locking the keyring lock first.
|
||||
// Should only be used internally by methods that first lock the lock.
|
||||
func (c *client) unlock(passphrase []byte) (err error) {
|
||||
if _, err = c.CurrentUser(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.userKeyRing == nil {
|
||||
if err = c.unlockUser(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
}
|
||||
|
||||
for _, address := range c.addresses {
|
||||
if c.addrKeyRing[address.ID] == nil {
|
||||
if err = c.unlockAddress(passphrase, address); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock address")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) ReloadKeys(passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
c.clearKeys()
|
||||
|
||||
return c.unlock(passphrase)
|
||||
}
|
||||
|
||||
func (c *client) clearKeys() {
|
||||
if c.userKeyRing != nil {
|
||||
c.userKeyRing.ClearPrivateParams()
|
||||
c.userKeyRing = nil
|
||||
}
|
||||
|
||||
for id, kr := range c.addrKeyRing {
|
||||
if kr != nil {
|
||||
kr.ClearPrivateParams()
|
||||
}
|
||||
delete(c.addrKeyRing, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) CloseConnections() {
|
||||
c.hc.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// Do makes an API request. It does not check for HTTP status code errors.
|
||||
func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) {
|
||||
// Copy the request body in case we need to retry it.
|
||||
var bodyBuffer []byte
|
||||
if req.Body != nil {
|
||||
defer req.Body.Close() //nolint[errcheck]
|
||||
bodyBuffer, err = ioutil.ReadAll(req.Body)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := bytes.NewReader(bodyBuffer)
|
||||
req.Body = ioutil.NopCloser(r)
|
||||
}
|
||||
|
||||
return c.doBuffered(req, bodyBuffer, retryUnauthorized)
|
||||
}
|
||||
|
||||
// If needed it retries using req and buffered body.
|
||||
func (c *client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
|
||||
isAuthReq := strings.Contains(req.URL.Path, "/auth")
|
||||
|
||||
req.Header.Set("User-Agent", c.cm.userAgent.String())
|
||||
req.Header.Set("x-pm-appversion", c.cm.config.AppVersion)
|
||||
func (c *client) r(ctx context.Context) (*resty.Request, error) {
|
||||
r := c.req.r(ctx)
|
||||
|
||||
if c.uid != "" {
|
||||
req.Header.Set("x-pm-uid", c.uid)
|
||||
r.SetHeader("x-pm-uid", c.uid)
|
||||
}
|
||||
|
||||
if c.accessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.accessToken)
|
||||
}
|
||||
|
||||
c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI())
|
||||
if logrus.GetLevel() == logrus.TraceLevel {
|
||||
head := ""
|
||||
for i, v := range req.Header {
|
||||
head += i + ": "
|
||||
head += strings.Join(v, "")
|
||||
head += "\n"
|
||||
}
|
||||
c.log.Tracef("REQHEAD \n%s", head)
|
||||
c.log.Tracef("REQBODY '%s'", printBytes(bodyBuffer))
|
||||
}
|
||||
|
||||
hasBody := len(bodyBuffer) > 0
|
||||
if res, err = c.hc.Do(req); err != nil {
|
||||
if res == nil {
|
||||
c.log.WithError(err).Error("Cannot get response")
|
||||
err = ErrAPINotReachable
|
||||
c.cm.noConnection()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cookies are returned only after request was sent.
|
||||
c.log.Tracef("REQCOOKIES '%v'", req.Cookies())
|
||||
|
||||
resDate := res.Header.Get("Date")
|
||||
if resDate != "" {
|
||||
if serverTime, err := http.ParseTime(resDate); err == nil {
|
||||
crypto.UpdateTime(serverTime.Unix())
|
||||
}
|
||||
}
|
||||
|
||||
if res.StatusCode == http.StatusUnauthorized {
|
||||
if hasBody {
|
||||
r := bytes.NewReader(bodyBuffer)
|
||||
req.Body = ioutil.NopCloser(r)
|
||||
}
|
||||
|
||||
if !isAuthReq {
|
||||
_, _ = io.Copy(ioutil.Discard, res.Body)
|
||||
_ = res.Body.Close()
|
||||
return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
// Retry induced by HTTP status code>
|
||||
retryAfter := 10
|
||||
doRetry := res.StatusCode == http.StatusTooManyRequests
|
||||
if doRetry {
|
||||
if headerAfter, err := strconv.Atoi(res.Header.Get("Retry-After")); err == nil && headerAfter > 0 {
|
||||
retryAfter = headerAfter
|
||||
}
|
||||
// To avoid spikes when all clients retry at the same time, we add some random wait.
|
||||
retryAfter += rand.Intn(10) //nolint[gosec] It is OK to use weak random number generator here
|
||||
|
||||
if hasBody {
|
||||
r := bytes.NewReader(bodyBuffer)
|
||||
req.Body = ioutil.NopCloser(r)
|
||||
}
|
||||
|
||||
c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, retryAfter, res.StatusCode)
|
||||
time.Sleep(time.Duration(retryAfter) * time.Second)
|
||||
_, _ = io.Copy(ioutil.Discard, res.Body)
|
||||
_ = res.Body.Close()
|
||||
return c.doBuffered(req, bodyBuffer, false)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// DoJSON performs the request and unmarshals the response as JSON into data.
|
||||
// If the API returns a non-2xx HTTP status code, the error returned will contain status
|
||||
// and response as plaintext. API errors must be checked by the caller.
|
||||
// It is performed buffered, in case we need to retry.
|
||||
func (c *client) DoJSON(req *http.Request, data interface{}) error {
|
||||
// Copy the request body in case we need to retry it
|
||||
var reqBodyBuffer []byte
|
||||
|
||||
if req.Body != nil {
|
||||
defer req.Body.Close() //nolint[errcheck]
|
||||
var err error
|
||||
if reqBodyBuffer, err = ioutil.ReadAll(req.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
|
||||
}
|
||||
|
||||
return c.doJSONBuffered(req, reqBodyBuffer, data)
|
||||
}
|
||||
|
||||
// doJSONBuffered performs a buffered json request (see DoJSON for more information).
|
||||
func (c *client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen]
|
||||
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
|
||||
|
||||
var cancelRequest context.CancelFunc
|
||||
if c.cm.config.MinBytesPerSecond > 0 {
|
||||
var ctx context.Context
|
||||
ctx, cancelRequest = context.WithCancel(req.Context())
|
||||
defer func() {
|
||||
cancelRequest()
|
||||
}()
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
res, err := c.doBuffered(req, reqBodyBuffer, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close() //nolint[errcheck]
|
||||
|
||||
var resBody []byte
|
||||
if c.cm.config.MinBytesPerSecond == 0 {
|
||||
resBody, err = ioutil.ReadAll(res.Body)
|
||||
} else {
|
||||
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
|
||||
if err == context.Canceled {
|
||||
err = ErrConnectionSlow
|
||||
}
|
||||
}
|
||||
|
||||
// The server response may contain data which we want to have in memory
|
||||
// for as little time as possible (such as keys). Go is garbage collected,
|
||||
// so we are not in charge of when the memory will actually be cleared.
|
||||
// We can at least try to rewrite the original data to mitigate this problem.
|
||||
defer func() {
|
||||
for i := 0; i < len(resBody); i++ {
|
||||
resBody[i] = byte(65)
|
||||
}
|
||||
}()
|
||||
|
||||
if logrus.GetLevel() == logrus.TraceLevel {
|
||||
head := ""
|
||||
for i, v := range res.Header {
|
||||
head += i + ": "
|
||||
head += strings.Join(v, "")
|
||||
head += "\n"
|
||||
}
|
||||
c.log.Tracef("RESHEAD \n%s", head)
|
||||
c.log.Tracef("RESBODY '%s'", resBody)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Retry induced by API code.
|
||||
errCode := &Res{}
|
||||
if err := json.Unmarshal(resBody, errCode); err == nil {
|
||||
if errCode.Code == BansRequests {
|
||||
retryAfter := 3
|
||||
c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, errCode.Code)
|
||||
time.Sleep(time.Duration(retryAfter) * time.Second)
|
||||
if len(reqBodyBuffer) > 0 {
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
|
||||
}
|
||||
return c.doJSONBuffered(req, reqBodyBuffer, data)
|
||||
}
|
||||
if errCode.Err() == ErrAPINotReachable {
|
||||
c.cm.noConnection()
|
||||
}
|
||||
if errCode.Err() == ErrUpgradeApplication {
|
||||
c.cm.config.UpgradeApplicationHandler()
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resBody, data); err != nil {
|
||||
// Check to see if this is due to a non 2xx HTTP status code.
|
||||
if res.StatusCode != http.StatusOK {
|
||||
r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n")))
|
||||
plaintext, err := html2text.FromReader(r)
|
||||
if err == nil {
|
||||
return fmt.Errorf("Error: \n\n" + res.Status + "\n\n" + plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
if errJS, ok := err.(*json.SyntaxError); ok {
|
||||
return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unmarshal fail: %v ", err)
|
||||
}
|
||||
|
||||
// Set StatusCode in case data struct supports that field.
|
||||
// It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code.
|
||||
dataValue := reflect.ValueOf(data).Elem()
|
||||
statusCodeField := dataValue.FieldByName("StatusCode")
|
||||
if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int {
|
||||
statusCodeField.SetInt(int64(res.StatusCode))
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
|
||||
firstReadTimeout := c.cm.config.FirstReadTimeout
|
||||
if firstReadTimeout == 0 {
|
||||
firstReadTimeout = 5 * time.Minute
|
||||
}
|
||||
timer := time.AfterFunc(firstReadTimeout, func() {
|
||||
cancelRequest()
|
||||
})
|
||||
|
||||
// speedCheckSeconds controls how often we check the transfer speed.
|
||||
// Note that connection can be unstable, on average very fast, but can be
|
||||
// idle for few seconds; or that API can take its time before sending
|
||||
// another data, e.g., API can send some data and take some time before
|
||||
// processing and sending the rest of the response.
|
||||
const speedCheckSeconds = 30
|
||||
|
||||
var buffer bytes.Buffer
|
||||
for {
|
||||
_, err := io.CopyN(&buffer, data, c.cm.config.MinBytesPerSecond*speedCheckSeconds)
|
||||
timer.Stop()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
if time.Now().After(c.exp) {
|
||||
if err := c.authRefresh(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
timer.Reset(speedCheckSeconds * time.Second)
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(&buffer)
|
||||
c.authLocker.RLock()
|
||||
defer c.authLocker.RUnlock()
|
||||
|
||||
if c.acc != "" {
|
||||
r.SetAuthToken(c.acc)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *client) refreshAccessToken() (err error) {
|
||||
c.log.Debug("Refreshing token")
|
||||
|
||||
refreshToken := c.cm.GetToken(c.userID)
|
||||
|
||||
if refreshToken == "" {
|
||||
c.sendAuth(nil)
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
if _, err := c.AuthRefresh(refreshToken); err != nil {
|
||||
if err != ErrAPINotReachable {
|
||||
c.sendAuth(nil)
|
||||
}
|
||||
return errors.Wrap(err, "failed to refresh auth")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
|
||||
c.log.Info("Handling unauthorized status")
|
||||
|
||||
// If this is not a retry, then it is the first time handling status unauthorized,
|
||||
// so try again without refreshing the access token.
|
||||
if !retry {
|
||||
c.log.Debug("Handling unauthorized status by retrying")
|
||||
c.requestLocker.Lock()
|
||||
defer c.requestLocker.Unlock()
|
||||
|
||||
_, _ = io.Copy(ioutil.Discard, res.Body)
|
||||
_ = res.Body.Close()
|
||||
return c.doBuffered(req, reqBodyBuffer, true)
|
||||
}
|
||||
|
||||
// This is already a retry, so we will try to refresh the access token before trying again.
|
||||
if err = c.refreshAccessToken(); err != nil {
|
||||
c.log.WithError(err).Warn("Cannot refresh token")
|
||||
err = &ErrUnauthorized{err}
|
||||
return
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, res.Body)
|
||||
func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) {
|
||||
r, err := c.r(ctx)
|
||||
if err != nil {
|
||||
c.log.WithError(err).Warn("Failed to read out response body")
|
||||
return nil, err
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
return c.doBuffered(req, reqBodyBuffer, true)
|
||||
|
||||
res, err := wrapRestyError(fn(r))
|
||||
if err != nil {
|
||||
if res.StatusCode() != http.StatusUnauthorized {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.authRefresh(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapRestyError(fn(r))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) {
|
||||
if err, ok := err.(*resty.ResponseError); ok {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.RawResponse != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, errors.Wrap(ErrNoConnection, err.Error())
|
||||
}
|
||||
|
||||
70
pkg/pmapi/client_keys.go
Normal file
70
pkg/pmapi/client_keys.go
Normal file
@ -0,0 +1,70 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
|
||||
// If the keyrings are already present, they are not recreated.
|
||||
func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
// FIXME(conman): Should this be done as part of NewClient somehow?
|
||||
|
||||
return c.unlock(ctx, passphrase)
|
||||
}
|
||||
|
||||
// unlock unlocks the user's keys but without locking the keyring lock first.
|
||||
// Should only be used internally by methods that first lock the lock.
|
||||
func (c *client) unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
if _, err = c.CurrentUser(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.userKeyRing == nil {
|
||||
if err = c.unlockUser(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
}
|
||||
|
||||
for _, address := range c.addresses {
|
||||
if c.addrKeyRing[address.ID] == nil {
|
||||
if err = c.unlockAddress(passphrase, address); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock address")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) ReloadKeys(ctx context.Context, passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
c.clearKeys()
|
||||
|
||||
return c.unlock(ctx, passphrase)
|
||||
}
|
||||
|
||||
func (c *client) clearKeys() {
|
||||
if c.userKeyRing != nil {
|
||||
c.userKeyRing.ClearPrivateParams()
|
||||
c.userKeyRing = nil
|
||||
}
|
||||
|
||||
for id, kr := range c.addrKeyRing {
|
||||
if kr != nil {
|
||||
kr.ClearPrivateParams()
|
||||
}
|
||||
delete(c.addrKeyRing, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) IsUnlocked() bool {
|
||||
// FIXME(conman): Better way to check? we don't currently check address keys.
|
||||
return c.userKeyRing != nil
|
||||
}
|
||||
@ -1,210 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testClientConfig = &ClientConfig{
|
||||
AppVersion: "GoPMAPI_1.0.14",
|
||||
ClientID: "demoapp",
|
||||
FirstReadTimeout: 500 * time.Millisecond,
|
||||
MinBytesPerSecond: 256,
|
||||
}
|
||||
|
||||
func newTestClient(cm *ClientManager) *client {
|
||||
return cm.GetClient("tester").(*client)
|
||||
}
|
||||
|
||||
func TestClient_Do(t *testing.T) {
|
||||
const testResBody = "Hello World!"
|
||||
|
||||
var receivedReq *http.Request
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedReq = r
|
||||
fmt.Fprint(w, testResBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
req, err := c.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while creating request, got:", err)
|
||||
}
|
||||
|
||||
res, err := c.Do(req, true)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while executing request, got:", err)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while reading response, got:", err)
|
||||
}
|
||||
require.Nil(t, res.Body.Close())
|
||||
|
||||
if string(b) != testResBody {
|
||||
t.Fatalf("Invalid response body: expected %v, got %v", testResBody, string(b))
|
||||
}
|
||||
|
||||
h := receivedReq.Header
|
||||
if h.Get("x-pm-appversion") != testClientConfig.AppVersion {
|
||||
t.Fatalf("Invalid app version header: expected %v, got %v", testClientConfig.AppVersion, h.Get("x-pm-appversion"))
|
||||
}
|
||||
if h.Get("x-pm-uid") != "" {
|
||||
t.Fatalf("Expected no uid header when not authenticated, got %v", h.Get("x-pm-uid"))
|
||||
}
|
||||
if h.Get("Authorization") != "" {
|
||||
t.Fatalf("Expected no authentication header when not authenticated, got %v", h.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_DoRetryAfter(t *testing.T) {
|
||||
testStart := time.Now()
|
||||
secondAttemptTime := time.Now()
|
||||
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
w.Header().Set("content-type", "application/json;charset=utf-8")
|
||||
w.Header().Set("Retry-After", "1")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return ""
|
||||
},
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
w.Header().Set("content-type", "application/json;charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
secondAttemptTime = time.Now()
|
||||
return "/HTTP_200.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
require.Nil(t, c.SendSimpleMetric("some_category", "some_action", "some_label"))
|
||||
waitedTime := secondAttemptTime.Sub(testStart)
|
||||
isInRange := 1*time.Second < waitedTime && waitedTime <= 11*time.Second
|
||||
require.True(t, isInRange, "Waited time: %v", waitedTime)
|
||||
}
|
||||
|
||||
type slowTransport struct {
|
||||
transport http.RoundTripper
|
||||
firstBodySleep time.Duration
|
||||
}
|
||||
|
||||
func (t *slowTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := t.transport.RoundTrip(req)
|
||||
if err == nil {
|
||||
resp.Body = &slowReadCloser{
|
||||
req: req,
|
||||
readCloser: resp.Body,
|
||||
firstBodySleep: t.firstBodySleep,
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
type slowReadCloser struct {
|
||||
req *http.Request
|
||||
readCloser io.ReadCloser
|
||||
firstBodySleep time.Duration
|
||||
}
|
||||
|
||||
func (r *slowReadCloser) Read(p []byte) (n int, err error) {
|
||||
// Normally timeout is processed by Read function.
|
||||
// It's hard to test slow connection; we need to manually
|
||||
// check when context is Done, because otherwise timeout
|
||||
// happens only during failed Read which will not happen
|
||||
// in this artificial environment.
|
||||
select {
|
||||
case <-r.req.Context().Done():
|
||||
return 0, context.Canceled
|
||||
case <-time.After(r.firstBodySleep):
|
||||
}
|
||||
return r.readCloser.Read(p)
|
||||
}
|
||||
|
||||
func (r *slowReadCloser) Close() error {
|
||||
return r.readCloser.Close()
|
||||
}
|
||||
|
||||
func TestClient_FirstReadTimeout(t *testing.T) {
|
||||
requestTimeout := testClientConfig.FirstReadTimeout + 1*time.Second
|
||||
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
return "/HTTP_200.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
c.hc.Transport = &slowTransport{
|
||||
transport: c.hc.Transport,
|
||||
firstBodySleep: requestTimeout,
|
||||
}
|
||||
|
||||
started := time.Now()
|
||||
err := c.SendSimpleMetric("some_category", "some_action", "some_label")
|
||||
require.Error(t, err, "cannot reach the server")
|
||||
require.True(t, time.Since(started) < requestTimeout, "Actual waited time: %v", time.Since(started))
|
||||
}
|
||||
|
||||
func TestClient_MinSpeedTimeout(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
routeSlow(31*time.Second), // 1 second longer than the minimum transfer speed poll time.
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.SendSimpleMetric("some_category", "some_action", "some_label")
|
||||
require.Error(t, err, "cannot reach the server")
|
||||
}
|
||||
|
||||
func TestClient_MinSpeedNoTimeout(t *testing.T) {
|
||||
finish, c := newTestServerCallbacks(t,
|
||||
routeSlow(500*time.Millisecond),
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.SendSimpleMetric("some_category", "some_action", "some_label")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func routeSlow(delay time.Duration) func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
return func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
w.Header().Set("content-type", "application/json;charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
_, _ = w.Write([]byte("{\"code\":1000,\"key\":\""))
|
||||
for chunk := 1; chunk <= 10; chunk++ {
|
||||
// We need to write enough bytes which enforce flushing data
|
||||
// because writer used by httptest does not implement Flusher.
|
||||
for i := 1; i <= 10000; i++ {
|
||||
_, _ = w.Write([]byte("a"))
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
_, _ = w.Write([]byte("\"}"))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@ -18,69 +18,66 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// Client defines the interface of a PMAPI client.
|
||||
type Client interface {
|
||||
Auth(username, password string, info *AuthInfo) (*Auth, error)
|
||||
AuthInfo(username string) (*AuthInfo, error)
|
||||
AuthRefresh(token string) (*Auth, error)
|
||||
Auth2FA(twoFactorCode string, auth *Auth) error
|
||||
AuthSalt() (salt string, err error)
|
||||
Logout()
|
||||
DeleteAuth() error
|
||||
IsConnected() bool
|
||||
CloseConnections()
|
||||
ClearData()
|
||||
Auth2FA(context.Context, Auth2FAReq) error
|
||||
AuthSalt(ctx context.Context) (string, error)
|
||||
AuthDelete(context.Context) error
|
||||
AddAuthHandler(AuthHandler)
|
||||
|
||||
CurrentUser() (*User, error)
|
||||
UpdateUser() (*User, error)
|
||||
Unlock(passphrase []byte) (err error)
|
||||
ReloadKeys(passphrase []byte) (err error)
|
||||
CurrentUser(ctx context.Context) (*User, error)
|
||||
UpdateUser(ctx context.Context) (*User, error)
|
||||
Unlock(ctx context.Context, passphrase []byte) (err error)
|
||||
ReloadKeys(ctx context.Context, passphrase []byte) (err error)
|
||||
IsUnlocked() bool
|
||||
|
||||
GetAddresses() (addresses AddressList, err error)
|
||||
Addresses() AddressList
|
||||
ReorderAddresses(addressIDs []string) error
|
||||
GetAddresses(context.Context) (addresses AddressList, err error)
|
||||
ReorderAddresses(ctx context.Context, addressIDs []string) error
|
||||
|
||||
GetEvent(eventID string) (*Event, error)
|
||||
GetEvent(ctx context.Context, eventID string) (*Event, error)
|
||||
|
||||
SendMessage(string, *SendMessageReq) (sent, parent *Message, err error)
|
||||
CreateDraft(m *Message, parent string, action int) (created *Message, err error)
|
||||
Import([]*ImportMsgReq) ([]*ImportMsgRes, error)
|
||||
SendMessage(context.Context, string, *SendMessageReq) (sent, parent *Message, err error)
|
||||
CreateDraft(ctx context.Context, m *Message, parent string, action int) (created *Message, err error)
|
||||
Import(context.Context, ImportMsgReqs) ([]*ImportMsgRes, error)
|
||||
|
||||
CountMessages(addressID string) ([]*MessagesCount, error)
|
||||
ListMessages(filter *MessagesFilter) ([]*Message, int, error)
|
||||
GetMessage(apiID string) (*Message, error)
|
||||
DeleteMessages(apiIDs []string) error
|
||||
LabelMessages(apiIDs []string, labelID string) error
|
||||
UnlabelMessages(apiIDs []string, labelID string) error
|
||||
MarkMessagesRead(apiIDs []string) error
|
||||
MarkMessagesUnread(apiIDs []string) error
|
||||
CountMessages(ctx context.Context, addressID string) ([]*MessagesCount, error)
|
||||
ListMessages(ctx context.Context, filter *MessagesFilter) ([]*Message, int, error)
|
||||
GetMessage(ctx context.Context, apiID string) (*Message, error)
|
||||
DeleteMessages(ctx context.Context, apiIDs []string) error
|
||||
LabelMessages(ctx context.Context, apiIDs []string, labelID string) error
|
||||
UnlabelMessages(ctx context.Context, apiIDs []string, labelID string) error
|
||||
MarkMessagesRead(ctx context.Context, apiIDs []string) error
|
||||
MarkMessagesUnread(ctx context.Context, apiIDs []string) error
|
||||
|
||||
ListLabels() ([]*Label, error)
|
||||
CreateLabel(label *Label) (*Label, error)
|
||||
UpdateLabel(label *Label) (*Label, error)
|
||||
DeleteLabel(labelID string) error
|
||||
EmptyFolder(labelID string, addressID string) error
|
||||
ListLabels(ctx context.Context) ([]*Label, error)
|
||||
CreateLabel(ctx context.Context, label *Label) (*Label, error)
|
||||
UpdateLabel(ctx context.Context, label *Label) (*Label, error)
|
||||
DeleteLabel(ctx context.Context, labelID string) error
|
||||
EmptyFolder(ctx context.Context, labelID string, addressID string) error
|
||||
|
||||
Report(report ReportReq) error
|
||||
SendSimpleMetric(category, action, label string) error
|
||||
|
||||
GetMailSettings() (MailSettings, error)
|
||||
GetContactEmailByEmail(string, int, int) ([]ContactEmail, error)
|
||||
GetContactByID(string) (Contact, error)
|
||||
GetMailSettings(ctx context.Context) (MailSettings, error)
|
||||
GetContactEmailByEmail(context.Context, string, int, int) ([]ContactEmail, error)
|
||||
GetContactByID(context.Context, string) (Contact, error)
|
||||
DecryptAndVerifyCards([]Card) ([]Card, error)
|
||||
|
||||
GetAttachment(id string) (att io.ReadCloser, err error)
|
||||
CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
|
||||
DeleteAttachment(attID string) (err error)
|
||||
GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error)
|
||||
CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
|
||||
|
||||
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
|
||||
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
|
||||
|
||||
DownloadAndVerify(string, string, *crypto.KeyRing) (io.Reader, error)
|
||||
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
|
||||
}
|
||||
|
||||
type AuthHandler func(*Auth) error
|
||||
|
||||
type requester interface {
|
||||
r(context.Context) *resty.Request
|
||||
authRefresh(context.Context, string, string) (*Auth, error)
|
||||
}
|
||||
|
||||
@ -1,459 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/config/useragent"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const maxLogoutRetries = 5
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct { //nolint[maligned]
|
||||
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to
|
||||
// create other types of clients (e.g. for integration tests).
|
||||
newClient func(userID string) Client
|
||||
|
||||
config *ClientConfig
|
||||
userAgent *useragent.UserAgent
|
||||
roundTripper http.RoundTripper
|
||||
|
||||
clients map[string]Client
|
||||
clientsLocker sync.Locker
|
||||
cookieJar http.CookieJar
|
||||
|
||||
tokens map[string]string
|
||||
tokensLocker sync.Locker
|
||||
|
||||
expirations map[string]*tokenExpiration
|
||||
expiredTokens chan string
|
||||
expirationsLocker sync.Locker
|
||||
|
||||
authUpdates chan ClientAuth
|
||||
|
||||
host, scheme string
|
||||
hostLocker sync.RWMutex
|
||||
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
proxyUseDuration time.Duration
|
||||
|
||||
idGen idGen
|
||||
|
||||
connectionOff bool
|
||||
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
type idGen int
|
||||
|
||||
func (i *idGen) next() int {
|
||||
(*i)++
|
||||
return int(*i)
|
||||
}
|
||||
|
||||
// ClientAuth holds an API auth produced by a Client for a specific user.
|
||||
type ClientAuth struct {
|
||||
UserID string
|
||||
Auth *Auth
|
||||
}
|
||||
|
||||
// tokenExpiration manages the expiration of an access token.
|
||||
type tokenExpiration struct {
|
||||
timer *time.Timer
|
||||
cancel chan (struct{})
|
||||
}
|
||||
|
||||
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
||||
func NewClientManager(config *ClientConfig, userAgent *useragent.UserAgent) (cm *ClientManager) {
|
||||
cm = &ClientManager{
|
||||
config: config,
|
||||
userAgent: userAgent,
|
||||
roundTripper: http.DefaultTransport,
|
||||
|
||||
clients: make(map[string]Client),
|
||||
clientsLocker: &sync.Mutex{},
|
||||
|
||||
tokens: make(map[string]string),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
|
||||
expirations: make(map[string]*tokenExpiration),
|
||||
expiredTokens: make(chan string),
|
||||
expirationsLocker: &sync.Mutex{},
|
||||
|
||||
host: rootURL,
|
||||
scheme: rootScheme,
|
||||
hostLocker: sync.RWMutex{},
|
||||
|
||||
authUpdates: make(chan ClientAuth),
|
||||
|
||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||
proxyUseDuration: proxyUseDuration,
|
||||
|
||||
connectionOff: false,
|
||||
|
||||
log: logrus.WithField("pkg", "pmapi-manager"),
|
||||
}
|
||||
|
||||
cm.newClient = func(userID string) Client {
|
||||
return newClient(cm, userID)
|
||||
}
|
||||
|
||||
go cm.watchTokenExpirations()
|
||||
|
||||
return cm
|
||||
}
|
||||
|
||||
func (cm *ClientManager) noConnection() {
|
||||
cm.log.Trace("No connection available")
|
||||
if cm.connectionOff {
|
||||
return
|
||||
}
|
||||
|
||||
cm.log.Warn("Connection lost")
|
||||
if cm.config.ConnectionOffHandler != nil {
|
||||
cm.config.ConnectionOffHandler()
|
||||
}
|
||||
cm.connectionOff = true
|
||||
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(30 * time.Second)
|
||||
|
||||
if err := cm.CheckConnection(); err == nil {
|
||||
cm.log.Info("Connection re-established")
|
||||
if cm.config.ConnectionOnHandler != nil {
|
||||
cm.config.ConnectionOnHandler()
|
||||
}
|
||||
cm.connectionOff = false
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SetClientConstructor sets the method used to construct clients.
|
||||
// By default this is `pmapi.newClient` but can be overridden with this method.
|
||||
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) {
|
||||
cm.newClient = f
|
||||
}
|
||||
|
||||
// SetCookieJar sets the cookie jar given to clients.
|
||||
func (cm *ClientManager) SetCookieJar(jar http.CookieJar) {
|
||||
cm.cookieJar = jar
|
||||
}
|
||||
|
||||
// SetRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
|
||||
cm.roundTripper = rt
|
||||
}
|
||||
|
||||
func (cm *ClientManager) GetAppVersion() string {
|
||||
return cm.config.AppVersion
|
||||
}
|
||||
|
||||
func (cm *ClientManager) GetUserAgent() string {
|
||||
return cm.userAgent.String()
|
||||
}
|
||||
|
||||
// GetClient returns a client for the given userID.
|
||||
// If the client does not exist already, it is created.
|
||||
func (cm *ClientManager) GetClient(userID string) Client {
|
||||
cm.clientsLocker.Lock()
|
||||
defer cm.clientsLocker.Unlock()
|
||||
|
||||
if client, ok := cm.clients[userID]; ok {
|
||||
return client
|
||||
}
|
||||
|
||||
client := cm.newClient(userID)
|
||||
|
||||
cm.clients[userID] = client
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// GetAnonymousClient returns an anonymous client.
|
||||
func (cm *ClientManager) GetAnonymousClient() Client {
|
||||
return cm.GetClient(fmt.Sprintf("anonymous-%v", cm.idGen.next()))
|
||||
}
|
||||
|
||||
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||
func (cm *ClientManager) LogoutClient(userID string) {
|
||||
cm.clientsLocker.Lock()
|
||||
defer cm.clientsLocker.Unlock()
|
||||
|
||||
client, ok := cm.clients[userID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(cm.clients, userID)
|
||||
|
||||
go func() {
|
||||
defer client.ClearData()
|
||||
defer cm.clearToken(userID)
|
||||
|
||||
if strings.HasPrefix(userID, "anonymous-") {
|
||||
return
|
||||
}
|
||||
|
||||
var retries int
|
||||
|
||||
for client.DeleteAuth() == ErrAPINotReachable {
|
||||
retries++
|
||||
|
||||
if retries > maxLogoutRetries {
|
||||
cm.log.Error("Failed to delete client auth (retried too many times)")
|
||||
break
|
||||
}
|
||||
|
||||
cm.log.Warn("Failed to delete client auth because API was not reachable, retrying...")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetRootURL returns the full root URL (scheme+host).
|
||||
func (cm *ClientManager) GetRootURL() string {
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return fmt.Sprintf("%v://%v", cm.scheme, cm.host)
|
||||
}
|
||||
|
||||
// getHost returns the host to make requests to.
|
||||
// It does not include the protocol i.e. no "https://" (use getScheme for that).
|
||||
func (cm *ClientManager) getHost() string {
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.host
|
||||
}
|
||||
|
||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.allowProxy
|
||||
}
|
||||
|
||||
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
||||
func (cm *ClientManager) AllowProxy() {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
cm.allowProxy = true
|
||||
}
|
||||
|
||||
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
||||
func (cm *ClientManager) DisallowProxy() {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
cm.allowProxy = false
|
||||
cm.host = rootURL
|
||||
|
||||
for _, client := range cm.clients {
|
||||
client.CloseConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// IsProxyEnabled returns whether we are currently proxying requests.
|
||||
func (cm *ClientManager) IsProxyEnabled() bool {
|
||||
cm.hostLocker.RLock()
|
||||
defer cm.hostLocker.RUnlock()
|
||||
|
||||
return cm.host != rootURL
|
||||
}
|
||||
|
||||
// switchToReachableServer switches to using a reachable server (either proxy or standard API).
|
||||
func (cm *ClientManager) switchToReachableServer() (proxy string, err error) {
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
logrus.Info("Attempting to switch to a proxy")
|
||||
|
||||
if proxy, err = cm.proxyProvider.findReachableServer(); err != nil {
|
||||
err = errors.Wrap(err, "failed to find a usable proxy")
|
||||
return
|
||||
}
|
||||
|
||||
// If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
|
||||
if proxy == rootURL {
|
||||
logrus.Info("The standard API is reachable again; connection drop was only intermittent")
|
||||
err = ErrAPINotReachable
|
||||
cm.host = proxy
|
||||
return
|
||||
}
|
||||
|
||||
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
|
||||
|
||||
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
|
||||
// This means we want to disable it again in 24 hours.
|
||||
if cm.host == rootURL {
|
||||
go func() {
|
||||
<-time.After(cm.proxyUseDuration)
|
||||
|
||||
cm.hostLocker.Lock()
|
||||
defer cm.hostLocker.Unlock()
|
||||
|
||||
cm.host = rootURL
|
||||
}()
|
||||
}
|
||||
|
||||
cm.host = proxy
|
||||
|
||||
return proxy, err
|
||||
}
|
||||
|
||||
// GetToken returns the token for the given userID.
|
||||
func (cm *ClientManager) GetToken(userID string) string {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
return cm.tokens[userID]
|
||||
}
|
||||
|
||||
// GetAuthUpdateChannel returns a channel on which client auths can be received.
|
||||
func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth {
|
||||
return cm.authUpdates
|
||||
}
|
||||
|
||||
// setTokenIfUnset sets the token for the given userID if it wasn't already set.
|
||||
// The set token does not expire.
|
||||
func (cm *ClientManager) setTokenIfUnset(userID, token string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
if _, ok := cm.tokens[userID]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.WithField("userID", userID).Info("Setting token because it is currently unset")
|
||||
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
||||
// setToken sets the token for the given userID with the given expiration time.
|
||||
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Info("Updating token")
|
||||
|
||||
cm.tokens[userID] = token
|
||||
|
||||
cm.setTokenExpiration(userID, expiration)
|
||||
}
|
||||
|
||||
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||
// If the token already has an expiration time set, it is replaced.
|
||||
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
|
||||
cm.expirationsLocker.Lock()
|
||||
defer cm.expirationsLocker.Unlock()
|
||||
|
||||
// Reduce the expiration by one minute so we can do the refresh with enough time to spare.
|
||||
expiration -= time.Minute
|
||||
|
||||
if exp, ok := cm.expirations[userID]; ok {
|
||||
exp.timer.Stop()
|
||||
close(exp.cancel)
|
||||
}
|
||||
|
||||
cm.expirations[userID] = &tokenExpiration{
|
||||
timer: time.NewTimer(expiration),
|
||||
cancel: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func(expiration *tokenExpiration) {
|
||||
select {
|
||||
case <-expiration.timer.C:
|
||||
cm.expiredTokens <- userID
|
||||
|
||||
case <-expiration.cancel:
|
||||
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
|
||||
}
|
||||
}(cm.expirations[userID])
|
||||
}
|
||||
|
||||
func (cm *ClientManager) clearToken(userID string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Debug("Clearing token")
|
||||
|
||||
delete(cm.tokens, userID)
|
||||
}
|
||||
|
||||
// HandleAuth updates or clears client authorisation based on auths received and then forwards the auth onwards.
|
||||
func (cm *ClientManager) HandleAuth(ca ClientAuth) {
|
||||
cm.clientsLocker.Lock()
|
||||
defer cm.clientsLocker.Unlock()
|
||||
|
||||
// If we aren't managing this client, there's nothing to do.
|
||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||
logrus.WithField("userID", ca.UserID).Info("Not handling auth for unmanaged client")
|
||||
return
|
||||
}
|
||||
|
||||
// If the auth is nil, we should clear the token.
|
||||
if ca.Auth == nil {
|
||||
cm.clearToken(ca.UserID)
|
||||
go cm.LogoutClient(ca.UserID)
|
||||
} else {
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
|
||||
}
|
||||
|
||||
logrus.Debug("ClientManager is forwarding auth update...")
|
||||
cm.authUpdates <- ca
|
||||
logrus.Debug("Auth update was forwarded")
|
||||
}
|
||||
|
||||
// watchTokenExpirations refreshes any tokens which are about to expire.
|
||||
func (cm *ClientManager) watchTokenExpirations() {
|
||||
for userID := range cm.expiredTokens {
|
||||
log := cm.log.WithField("userID", userID)
|
||||
|
||||
log.Info("Auth token expired! Refreshing")
|
||||
|
||||
client, ok := cm.clients[userID]
|
||||
if !ok {
|
||||
log.Warn("Can't refresh expired token because there is no such client")
|
||||
continue
|
||||
}
|
||||
|
||||
token, ok := cm.tokens[userID]
|
||||
if !ok {
|
||||
log.Warn("Can't refresh expired token because there is no such token")
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := client.AuthRefresh(token); err != nil {
|
||||
log.WithError(err).Error("Failed to refresh expired token")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import "github.com/ProtonMail/proton-bridge/internal/config/useragent"
|
||||
|
||||
func newTestClientManager(cfg *ClientConfig) *ClientManager {
|
||||
cm := NewClientManager(cfg, useragent.New())
|
||||
|
||||
go func() {
|
||||
for range cm.authUpdates {
|
||||
}
|
||||
}()
|
||||
|
||||
return cm
|
||||
}
|
||||
@ -1,55 +1,11 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// rootURL is the API root URL.
|
||||
// It must not contain the protocol! The protocol should be in rootScheme.
|
||||
var rootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
|
||||
|
||||
// rootScheme is the scheme to use for connections to the root URL.
|
||||
var rootScheme = "https" //nolint[gochecknoglobals]
|
||||
|
||||
func GetAPIConfig(configName, appVersion string) *ClientConfig {
|
||||
return &ClientConfig{
|
||||
AppVersion: getAPIOS() + strings.Title(configName) + "_" + appVersion,
|
||||
ClientID: configName,
|
||||
Timeout: 25 * time.Minute, // Overall request timeout (~25MB / 25 mins => ~16kB/s, should be reasonable).
|
||||
FirstReadTimeout: 30 * time.Second, // 30s to match 30s response header timeout.
|
||||
MinBytesPerSecond: 1 << 10, // Enforce minimum download speed of 1kB/s.
|
||||
}
|
||||
type Config struct {
|
||||
HostURL string
|
||||
AppVersion string
|
||||
}
|
||||
|
||||
// getAPIOS returns actual operating system.
|
||||
func getAPIOS() string {
|
||||
switch os := runtime.GOOS; os {
|
||||
case "darwin": // nolint: goconst
|
||||
return "macOS"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
}
|
||||
|
||||
return "Linux"
|
||||
var DefaultConfig = Config{
|
||||
HostURL: "https://api.protonmail.ch",
|
||||
AppVersion: "Other",
|
||||
}
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build !build_qa
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
)
|
||||
|
||||
func GetRoundTripper(cm *ClientManager, listener listener.Listener) http.RoundTripper {
|
||||
// We use a TLS dialer.
|
||||
basicDialer := NewBasicTLSDialer()
|
||||
|
||||
// We wrap the TLS dialer in a layer which enforces connections to trusted servers.
|
||||
pinningDialer := NewPinningTLSDialer(basicDialer)
|
||||
|
||||
// We want any pin mismatches to be communicated back to bridge GUI and reported.
|
||||
pinningDialer.SetTLSIssueNotifier(func() { listener.Emit(events.TLSCertIssue, "") })
|
||||
pinningDialer.EnableRemoteTLSIssueReporting(cm)
|
||||
|
||||
// We wrap the pinning dialer in a layer which adds "alternative routing" feature.
|
||||
proxyDialer := NewProxyTLSDialer(pinningDialer, cm)
|
||||
|
||||
return CreateTransportWithDialer(proxyDialer)
|
||||
}
|
||||
@ -1,51 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build build_qa
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// This config allows to dynamically change ROOT URL.
|
||||
fullRootURL := os.Getenv("PMAPI_ROOT_URL")
|
||||
if strings.HasPrefix(fullRootURL, "http") {
|
||||
rootURLparts := strings.SplitN(fullRootURL, "://", 2)
|
||||
rootScheme = rootURLparts[0]
|
||||
rootURL = rootURLparts[1]
|
||||
} else if fullRootURL != "" {
|
||||
rootURL = fullRootURL
|
||||
rootScheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
func GetRoundTripper(_ *ClientManager, _ listener.Listener) http.RoundTripper {
|
||||
transport := CreateTransportWithDialer(NewBasicTLSDialer())
|
||||
|
||||
// TLS certificate of testing environment might be self-signed.
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
return transport
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
// ConnectionReporter provides a way to report when internet connection is lost.
|
||||
type ConnectionReporter interface {
|
||||
NotifyConnectionLost() error
|
||||
}
|
||||
@ -18,9 +18,11 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type Card struct {
|
||||
@ -105,322 +107,40 @@ func (c *client) DecryptAndVerifyCards(cards []Card) ([]Card, error) {
|
||||
return cards, nil
|
||||
}
|
||||
|
||||
// ====================== READ ===========================
|
||||
|
||||
type ContactsListRes struct {
|
||||
Res
|
||||
Contacts []*Contact
|
||||
}
|
||||
|
||||
// GetContacts gets all contacts.
|
||||
func (c *client) GetContacts(page int, pageSize int) (contacts []*Contact, err error) {
|
||||
v := url.Values{}
|
||||
v.Set("Page", strconv.Itoa(page))
|
||||
if pageSize > 0 {
|
||||
v.Set("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
req, err := c.NewRequest("GET", "/contacts?"+v.Encode(), nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res ContactsListRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
contacts, err = res.Contacts, res.Err()
|
||||
return
|
||||
}
|
||||
|
||||
// GetContactByID gets contact details specified by contact ID.
|
||||
func (c *client) GetContactByID(id string) (contactDetail Contact, err error) {
|
||||
req, err := c.NewRequest("GET", "/contacts/"+id, nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
type ContactRes struct {
|
||||
Res
|
||||
func (c *client) GetContactByID(ctx context.Context, contactID string) (contactDetail Contact, err error) {
|
||||
var res struct {
|
||||
Contact Contact
|
||||
}
|
||||
var res ContactRes
|
||||
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetResult(&res).Get("/contacts/v4/" + contactID)
|
||||
}); err != nil {
|
||||
return Contact{}, err
|
||||
}
|
||||
|
||||
contactDetail, err = res.Contact, res.Err()
|
||||
return
|
||||
}
|
||||
|
||||
// GetContactsForExport gets contacts in vCard format, signed and encrypted.
|
||||
func (c *client) GetContactsForExport(page int, pageSize int) (contacts []Contact, err error) {
|
||||
v := url.Values{}
|
||||
v.Set("Page", strconv.Itoa(page))
|
||||
if pageSize > 0 {
|
||||
v.Set("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
|
||||
req, err := c.NewRequest("GET", "/contacts/export?"+v.Encode(), nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
type ContactsDetailsRes struct {
|
||||
Res
|
||||
Contacts []Contact
|
||||
}
|
||||
var res ContactsDetailsRes
|
||||
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
contacts, err = res.Contacts, res.Err()
|
||||
return
|
||||
}
|
||||
|
||||
type ContactsEmailsRes struct {
|
||||
Res
|
||||
ContactEmails []ContactEmail
|
||||
Total int
|
||||
}
|
||||
|
||||
// GetAllContactsEmails gets all emails from all contacts.
|
||||
func (c *client) GetAllContactsEmails(page int, pageSize int) (contactsEmails []ContactEmail, err error) {
|
||||
v := url.Values{}
|
||||
v.Set("Page", strconv.Itoa(page))
|
||||
if pageSize > 0 {
|
||||
v.Set("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
|
||||
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res ContactsEmailsRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
contactsEmails, err = res.ContactEmails, res.Err()
|
||||
return
|
||||
return res.Contact, nil
|
||||
}
|
||||
|
||||
// GetContactEmailByEmail gets all emails from all contacts matching a specified email string.
|
||||
func (c *client) GetContactEmailByEmail(email string, page int, pageSize int) (contactEmails []ContactEmail, err error) {
|
||||
v := url.Values{}
|
||||
v.Set("Page", strconv.Itoa(page))
|
||||
if pageSize > 0 {
|
||||
v.Set("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
v.Set("Email", email)
|
||||
|
||||
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
|
||||
if err != nil {
|
||||
return
|
||||
func (c *client) GetContactEmailByEmail(ctx context.Context, email string, page int, pageSize int) (contactEmails []ContactEmail, err error) {
|
||||
var res struct {
|
||||
ContactEmails []ContactEmail
|
||||
}
|
||||
|
||||
var res ContactsEmailsRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetQueryParams(map[string]string{
|
||||
"Email": email,
|
||||
"Page": strconv.Itoa(page),
|
||||
"PageSize": strconv.Itoa(pageSize),
|
||||
}).SetResult(&res).Get("/contacts/v4")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contactEmails, err = res.ContactEmails, res.Err()
|
||||
return
|
||||
return res.ContactEmails, nil
|
||||
}
|
||||
|
||||
// ============================ CREATE ====================================
|
||||
|
||||
type CardsList struct {
|
||||
Cards []Card
|
||||
}
|
||||
|
||||
type ContactsCards struct {
|
||||
Contacts []CardsList
|
||||
}
|
||||
|
||||
type SingleContactResponse struct {
|
||||
Res
|
||||
Contact Contact
|
||||
}
|
||||
|
||||
type IndexedContactResponse struct {
|
||||
Index int
|
||||
Response SingleContactResponse
|
||||
}
|
||||
|
||||
type AddContactsResponse struct {
|
||||
Res
|
||||
Responses []IndexedContactResponse
|
||||
}
|
||||
|
||||
type AddContactsReq struct {
|
||||
ContactsCards
|
||||
Overwrite int
|
||||
Groups int
|
||||
Labels int
|
||||
}
|
||||
|
||||
// AddContacts adds contacts specified by cards. Performs signing and encrypting based on card type.
|
||||
func (c *client) AddContacts(cards ContactsCards, overwrite int, groups int, labels int) (res *AddContactsResponse, err error) {
|
||||
reqBody := AddContactsReq{
|
||||
ContactsCards: cards,
|
||||
Overwrite: overwrite,
|
||||
Groups: groups,
|
||||
Labels: labels,
|
||||
}
|
||||
|
||||
req, err := c.NewJSONRequest("POST", "/contacts", reqBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var addContactsRes AddContactsResponse
|
||||
if err = c.DoJSON(req, &addContactsRes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
res, err = &addContactsRes, addContactsRes.Err()
|
||||
return
|
||||
}
|
||||
|
||||
// ================================= UPDATE =======================================
|
||||
|
||||
type UpdateContactResponse struct {
|
||||
Res
|
||||
Contact Contact
|
||||
}
|
||||
|
||||
type UpdateContactReq struct {
|
||||
Cards []Card
|
||||
}
|
||||
|
||||
// UpdateContact updates contact identified by contact ID. Modified contact is specified by cards.
|
||||
func (c *client) UpdateContact(id string, cards []Card) (res *UpdateContactResponse, err error) {
|
||||
reqBody := UpdateContactReq{
|
||||
Cards: cards,
|
||||
}
|
||||
req, err := c.NewJSONRequest("PUT", "/contacts/"+id, reqBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var updateContactRes UpdateContactResponse
|
||||
if err = c.DoJSON(req, &updateContactRes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
res, err = &updateContactRes, updateContactRes.Err()
|
||||
return
|
||||
}
|
||||
|
||||
type SingleIDResponse struct {
|
||||
Res
|
||||
ID string
|
||||
}
|
||||
|
||||
type UpdateContactGroupsResponse struct {
|
||||
Res
|
||||
Response SingleIDResponse
|
||||
}
|
||||
|
||||
func (c *client) AddContactGroups(groupID string, contactEmailIDs []string) (res *UpdateContactGroupsResponse, err error) {
|
||||
return c.modifyContactGroups(groupID, addContactGroupsAction, contactEmailIDs)
|
||||
}
|
||||
|
||||
func (c *client) RemoveContactGroups(groupID string, contactEmailIDs []string) (res *UpdateContactGroupsResponse, err error) {
|
||||
return c.modifyContactGroups(groupID, removeContactGroupsAction, contactEmailIDs)
|
||||
}
|
||||
|
||||
const (
|
||||
removeContactGroupsAction = 0
|
||||
addContactGroupsAction = 1
|
||||
)
|
||||
|
||||
type ModifyContactGroupsReq struct {
|
||||
LabelID string
|
||||
Action int
|
||||
ContactEmailIDs []string
|
||||
}
|
||||
|
||||
func (c *client) modifyContactGroups(groupID string, modifyContactGroupsAction int, contactEmailIDs []string) (res *UpdateContactGroupsResponse, err error) {
|
||||
reqBody := ModifyContactGroupsReq{
|
||||
LabelID: groupID,
|
||||
Action: modifyContactGroupsAction,
|
||||
ContactEmailIDs: contactEmailIDs,
|
||||
}
|
||||
req, err := c.NewJSONRequest("PUT", "/contacts/group", reqBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
err = res.Err()
|
||||
return
|
||||
}
|
||||
|
||||
// ================================= DELETE =======================================
|
||||
|
||||
type DeleteReq struct {
|
||||
IDs []string
|
||||
}
|
||||
|
||||
// DeleteContacts deletes contacts specified by an array of contact IDs.
|
||||
func (c *client) DeleteContacts(ids []string) (err error) {
|
||||
deleteReq := DeleteReq{
|
||||
IDs: ids,
|
||||
}
|
||||
|
||||
req, err := c.NewJSONRequest("PUT", "/contacts/delete", deleteReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
type DeleteContactsRes struct {
|
||||
Res
|
||||
Responses []struct {
|
||||
ID string
|
||||
Response Res
|
||||
}
|
||||
}
|
||||
var res DeleteContactsRes
|
||||
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
if err = res.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteAllContacts deletes all contacts.
|
||||
func (c *client) DeleteAllContacts() (err error) {
|
||||
req, err := c.NewRequest("DELETE", "/contacts", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res Res
|
||||
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
if err = res.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ===================== Private utility methods =======================
|
||||
|
||||
func isSignedCardType(cardType int) bool {
|
||||
return (cardType & CardSigned) == CardSigned
|
||||
}
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
@ -34,221 +34,6 @@ var (
|
||||
EncryptedSignedCard = 3
|
||||
)
|
||||
|
||||
var testAddContactsReq = AddContactsReq{
|
||||
ContactsCards: ContactsCards{
|
||||
Contacts: []CardsList{
|
||||
{
|
||||
Cards: []Card{
|
||||
{
|
||||
Type: 2,
|
||||
Data: `BEGIN:VCARD
|
||||
VERSION:4.0
|
||||
FN;TYPE=fn:Bob
|
||||
item1.EMAIL:bob.tester@protonmail.com
|
||||
UID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece
|
||||
END:VCARD
|
||||
`,
|
||||
Signature: ``,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Overwrite: 0,
|
||||
Groups: 0,
|
||||
Labels: 0,
|
||||
}
|
||||
|
||||
var testAddContactsResponseBody = `{
|
||||
"Code": 1001,
|
||||
"Responses": [
|
||||
{
|
||||
"Index": 0,
|
||||
"Response": {
|
||||
"Code": 1000,
|
||||
"Contact": {
|
||||
"ID": "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
|
||||
"Name": "Bob",
|
||||
"UID": "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
|
||||
"Size": 139,
|
||||
"CreateTime": 1517319495,
|
||||
"ModifyTime": 1517319495,
|
||||
"ContactEmails": [
|
||||
{
|
||||
"ID": "VT4NoPeQPk48_vg0CVmk63n5mB6CZn9q-P_DYODhOUemhuzUkgBFGF1MktVArjX5zsVdfVlEBFObvt0_K5NwPg==",
|
||||
"Name": "Bob",
|
||||
"Email": "bob.tester@protonmail.com",
|
||||
"Type": [],
|
||||
"Defaults": 1,
|
||||
"Order": 1,
|
||||
"ContactID": "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
|
||||
"LabelIDs": []
|
||||
}
|
||||
],
|
||||
"LabelIDs": []
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
var testContactCreated = &AddContactsResponse{
|
||||
Res: Res{
|
||||
Code: 1001,
|
||||
StatusCode: 200,
|
||||
},
|
||||
Responses: []IndexedContactResponse{
|
||||
{
|
||||
Index: 0,
|
||||
Response: SingleContactResponse{
|
||||
Res: Res{
|
||||
Code: 1000,
|
||||
},
|
||||
Contact: Contact{
|
||||
ID: "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
|
||||
Name: "Bob",
|
||||
UID: "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
|
||||
Size: 139,
|
||||
CreateTime: 1517319495,
|
||||
ModifyTime: 1517319495,
|
||||
ContactEmails: []ContactEmail{
|
||||
{
|
||||
ID: "VT4NoPeQPk48_vg0CVmk63n5mB6CZn9q-P_DYODhOUemhuzUkgBFGF1MktVArjX5zsVdfVlEBFObvt0_K5NwPg==",
|
||||
Name: "Bob",
|
||||
Email: "bob.tester@protonmail.com",
|
||||
Type: []string{},
|
||||
Defaults: 1,
|
||||
Order: 1,
|
||||
ContactID: "EU7qYvPAdgJ-zl53hw_btO1WG8TN2FYh2cTIFq1_T6KqulwgxF8CzPjVk_RBUdEejtLvfynlelVNoZwMK_9X2g==",
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
},
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var testContactUpdated = &UpdateContactResponse{
|
||||
Res: Res{
|
||||
Code: 1000,
|
||||
StatusCode: 200,
|
||||
},
|
||||
Contact: Contact{
|
||||
ID: "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
|
||||
Name: "Bob",
|
||||
UID: "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
|
||||
Size: 303,
|
||||
CreateTime: 1517416603,
|
||||
ModifyTime: 1517416656,
|
||||
ContactEmails: []ContactEmail{
|
||||
{
|
||||
ID: "14n6vuf1zbeo3zsYzgV471S6xJ9gzl7-VZ8tcOTQq6ifBlNEre0SUdUM7sXh6e2Q_4NhJZaU9c7jLdB1HCV6dA==",
|
||||
Name: "Bob",
|
||||
Email: "bob.changed.tester@protonmail.com",
|
||||
Type: []string{},
|
||||
Defaults: 1,
|
||||
Order: 1,
|
||||
ContactID: "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
},
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestContact_AddContact(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/contacts"))
|
||||
|
||||
var addContactsReq AddContactsReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&addContactsReq); err != nil {
|
||||
t.Error("Expecting no error while reading request body, got:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(testAddContactsReq.ContactsCards, addContactsReq.ContactsCards) {
|
||||
t.Errorf("Invalid contacts request: expected %+v but got %+v", testAddContactsReq.ContactsCards, addContactsReq.ContactsCards)
|
||||
}
|
||||
|
||||
fmt.Fprint(w, testAddContactsResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
created, err := c.AddContacts(testAddContactsReq.ContactsCards, 0, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while adding contact, got:", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(created, testContactCreated) {
|
||||
t.Fatalf("Invalid created contact: expected %+v, got %+v", testContactCreated, created)
|
||||
}
|
||||
}
|
||||
|
||||
var testGetContactsResponseBody = `{
|
||||
"Code": 1000,
|
||||
"Contacts": [
|
||||
{
|
||||
"ID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
"Name": "Alice",
|
||||
"UID": "proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b",
|
||||
"Size": 243,
|
||||
"CreateTime": 1517395498,
|
||||
"ModifyTime": 1517395498,
|
||||
"LabelIDs": []
|
||||
},
|
||||
{
|
||||
"ID": "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
|
||||
"Name": "Bob",
|
||||
"UID": "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
|
||||
"Size": 303,
|
||||
"CreateTime": 1517394677,
|
||||
"ModifyTime": 1517394678,
|
||||
"LabelIDs": []
|
||||
}
|
||||
],
|
||||
"Total": 2
|
||||
}`
|
||||
|
||||
var testGetContacts = []*Contact{
|
||||
|
||||
{
|
||||
ID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
Name: "Alice",
|
||||
UID: "proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b",
|
||||
Size: 243,
|
||||
CreateTime: 1517395498,
|
||||
ModifyTime: 1517395498,
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
{
|
||||
ID: "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
|
||||
Name: "Bob",
|
||||
UID: "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
|
||||
Size: 303,
|
||||
CreateTime: 1517394677,
|
||||
ModifyTime: 1517394678,
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestContact_GetContacts(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/contacts?Page=0&PageSize=1000"))
|
||||
|
||||
fmt.Fprint(w, testGetContactsResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
contacts, err := c.GetContacts(0, 1000)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting contacts, got:", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(contacts, testGetContacts) {
|
||||
t.Fatalf("Invalid created contact: expected %+v, got %+v", testGetContacts, contacts)
|
||||
}
|
||||
}
|
||||
|
||||
var testGetContactByIDResponseBody = `{
|
||||
"Code": 1000,
|
||||
"Contact": {
|
||||
@ -321,14 +106,16 @@ var testGetContactByID = Contact{
|
||||
}
|
||||
|
||||
func TestContact_GetContactById(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/contacts/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/contacts/v4/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testGetContactByIDResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
contact, err := c.GetContactByID("s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")
|
||||
contact, err := c.GetContactByID(context.TODO(), "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting contacts, got:", err)
|
||||
}
|
||||
@ -338,287 +125,6 @@ func TestContact_GetContactById(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var testGetContactsForExportResponseBody = `{
|
||||
"Code": 1000,
|
||||
"Contacts": [
|
||||
{
|
||||
"ID": "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
|
||||
"Cards": [
|
||||
{
|
||||
"Type": 2,
|
||||
"Data": "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Bob\nitem1.EMAIL:bob.changed.tester@protonmail.com\nUID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece\nEND:VCARD\n",
|
||||
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAAtxwIAFGgPO+xH4PHppffQC1R\nxCp/Bjzaq5rDUE3ZMKVJ1sFqGVlq2bP5CIN4w2XCe/MuZ+z2o87fSEtt2n7i\n0/8Ah35u4czn7t8FZoW8u9WwHPURa8gUbP3fYpVASBY1Bt2fUxJrSUYn5KQp\njJM/DgF99bhIjOTuhx9IN7DFKG647Arq+GJh9M6RJNxkb3CBfcCVUXoIwMB7\nnM/fA1r+mcl8dQam0WKVJgy9aO2XUUR62w1SpqJlXY3z8hKvXjjskzU3DQk5\net07RLVQvhy2nCZePsM+TJzL8OBbTa1aF/p1xPe+HND7t3ZCm9tQOY+UhK5H\nbhPbQY48KGdci1dTcm2HbsQ=\n=iOnV\n-----END PGP SIGNATURE-----\n"
|
||||
},
|
||||
{
|
||||
"Type": 3,
|
||||
"Data": "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQgABsQWnqZadqrHDN43McGhEYfJjOB66R5HhkQAUavP\nHaAHpJciGxfz6tbztQu4C6kdMA80ElbD8c+bJqalw6ZbT4seoP4TTQLykD1n\n0LuNBlaW4x8kfd8rZzFdckk/dY2PruX6byAjSZslnZlZSwp99AJJbvJtfXRR\nzunKMbDieRkaApGZYT25wT5mz1embpXFesvO4nDkOEQCa0uyti3mNSLhYlf/\ntbaOS3WM9VYM9eB9YRZGzJNxMtTxOsd45tBlGCHnCzWEUnJdqZuYzH2QOky7\nMckXhk6YwyemYi/q7OOgSYEg/0lCs2EK3b//14yPDx8Bj5G7rZrnDgsP+BHj\nu9KaAZb2pSBPQoJ2DY3Y4A2Sg8GjaX5CMO9D6GKJkZSYkXddQgcmw7sVPUS+\n+5JaPXlfxoJOOn9kj9A6LDC6eMhYaLujG1BKcZ16DB0jqfwMnPLJ+bYEdatr\nKMvd9rbdsDwQ/tfk11VvHpiEBCNZjxM2+bdBLl9q2EXaLXi+dz/rJg5C0A9u\nNS2CzCUvg6+jNUzHo/RBfRXvlNV8tw==\n=mE2b\n-----END PGP MESSAGE-----\n",
|
||||
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAApucIAD/uwWuV6DOg127XIPG6\n/jluL8jmwyCJX9noL6S8ZVMOymziKSh4/P1QyMPC5SL4lMPEiuaEdyetfBkU\n+5hW3tcZ+ptxmDi59SVYqmXTVewPgeB7t8c5nbzCuVuzA7ZAo8HAXHzFVQDS\nj9fKVGjZzQkmlwdcfnkXHAF0Ejilv9wxOOYgqVDuzm7JXVF3Um7nAgGKTJE5\n5CNnrEjmJGapj96mQFwXzET/kAhNIBw9tL5FAkDlKImdw8C0w9sXdvDu3yVM\ntvUZ5o2rR6ft0SC1byFso49vgJ/syeK6P2pPzltZJbsp4MvmlPUB0/G1XRU+\nI7q4IOWCvs8RD88ADmOty2o=\n=hyZE\n-----END PGP SIGNATURE-----\n"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"ID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
"Cards": [
|
||||
{
|
||||
"Type": 3,
|
||||
"Data": "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQf/RnOQRpo8DVJHQSJRgckEaUQvdMcADiM4L23diyiS\nQfclby/Ve2WInmvZc2RJ3rWENfeqyDZE6krQT642pKiW09GOIyVIjl+hje9y\nE4HBX0AIAWv7QhhKX6UZcM5dYSFbV3j3QxQB8A4Thng2G6ltotMTlbtcHbhu\n96Lt6ngA1tngXLSF5seyflnoiSQ5gLi2qVzrd95dIP6D4Ottcp929/4hDGmq\nPyxw9dColx6gVd1bmIDSI6ewkET4Grmo6QYqjSvjqLOf0PqHKzqypSFLkI5l\nmmnWKYTQCgl9wX+hq6Qz5E+m/BtbkdeX0YxYUss2e+oSAzJmnfdETErG9U5z\n3NJqAc3sgdwDzfWHBzogAxAbDHiqrF6zMlR5SFvZ6nRU7M2DTOE5dJhf+zOp\n1WSKn5LR46LGyt0m5wJPDjaGyQdPffAO4EULvwhGENe10UxRjY1qcUmjYOtS\nunl/vh3afI9PC1jj+HHJD2VgCA==\n=UpcY\n-----END PGP MESSAGE-----\n",
|
||||
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4pCRDMO9BwcW4mpAAA6h0H/2+97koXzly5pu9hpbaW\n75d1Q976RjMr5DjAx6tKFtSzznel8YfWgvA6OQmMGdPY8ae7/+3mwCJZYWy/\nXVvUfCSflmYpSIKGfP+Vm1XezWY1W84DGhiFj5n8sdaWisv3bpFwFf1YR3Ae\noBoZ4ufNzaQALRqGPMgXETtXZCtzuL/+0vGSKj5SLECiRcSE4jCPEVRy2bcl\nWJyB9r4VmcjF042OMHxphXoYmTEWvgigyaQFHNORu5cK9EHfHpCG6IcjGbdx\n+9Px5YnDY1ix+YpBKePGSTlLE0u6ow0VTUrdvNjl7IUBaRcfJcIIdgCBOTMw\n1uQ/yeyP46V5AFXFnIKeZeQ=\n=FlOf\n-----END PGP SIGNATURE-----\n"
|
||||
},
|
||||
{
|
||||
"Type": 2,
|
||||
"Data": "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Alice\nitem1.EMAIL:alice@protonmail.com\nUID:proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b\nEND:VCARD",
|
||||
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4qCRDMO9BwcW4mpAAA3jUIAJ88mIyO8Yj0+evSFXnK\nNxNdjNgn7t1leY0BWlh1nkK76XrZEPipdw2QU8cOcZzn1Wby2SGfZVkwoPc4\nzAhPT4WKbkFVqXhDry5399kLwGYJCxdEcw/oPyYj+YgpQKMxhTrQq21tbEwr\n7JDRBXgi3Cckh/XsteFHOIiAVnM7BV6zFudipnYxa4uNF0Bf4VbUZx1Mm0Wb\nMJaGsO5reqQUQzDPO5TdSAZ8qGSdjVv7RESgUu5DckcDSsnB987Zbh9uFc22\nfPYmb6zA0cEZh3dAjpDPT7cg8hlvfYBb+kP3sLFyLiIkdEG8Pcagjf0k+l76\nr1IsPlYBx2LJmsJf+WDNlj8=\n=Xn+3\n-----END PGP SIGNATURE-----\n"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"Total": 2
|
||||
}`
|
||||
|
||||
var testGetContactsForExport = []Contact{
|
||||
{
|
||||
ID: "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
|
||||
Cards: []Card{
|
||||
{
|
||||
Type: 2,
|
||||
Data: "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Bob\nitem1.EMAIL:bob.changed.tester@protonmail.com\nUID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece\nEND:VCARD\n",
|
||||
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAAtxwIAFGgPO+xH4PHppffQC1R\nxCp/Bjzaq5rDUE3ZMKVJ1sFqGVlq2bP5CIN4w2XCe/MuZ+z2o87fSEtt2n7i\n0/8Ah35u4czn7t8FZoW8u9WwHPURa8gUbP3fYpVASBY1Bt2fUxJrSUYn5KQp\njJM/DgF99bhIjOTuhx9IN7DFKG647Arq+GJh9M6RJNxkb3CBfcCVUXoIwMB7\nnM/fA1r+mcl8dQam0WKVJgy9aO2XUUR62w1SpqJlXY3z8hKvXjjskzU3DQk5\net07RLVQvhy2nCZePsM+TJzL8OBbTa1aF/p1xPe+HND7t3ZCm9tQOY+UhK5H\nbhPbQY48KGdci1dTcm2HbsQ=\n=iOnV\n-----END PGP SIGNATURE-----\n",
|
||||
},
|
||||
{
|
||||
Type: 3,
|
||||
Data: "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQgABsQWnqZadqrHDN43McGhEYfJjOB66R5HhkQAUavP\nHaAHpJciGxfz6tbztQu4C6kdMA80ElbD8c+bJqalw6ZbT4seoP4TTQLykD1n\n0LuNBlaW4x8kfd8rZzFdckk/dY2PruX6byAjSZslnZlZSwp99AJJbvJtfXRR\nzunKMbDieRkaApGZYT25wT5mz1embpXFesvO4nDkOEQCa0uyti3mNSLhYlf/\ntbaOS3WM9VYM9eB9YRZGzJNxMtTxOsd45tBlGCHnCzWEUnJdqZuYzH2QOky7\nMckXhk6YwyemYi/q7OOgSYEg/0lCs2EK3b//14yPDx8Bj5G7rZrnDgsP+BHj\nu9KaAZb2pSBPQoJ2DY3Y4A2Sg8GjaX5CMO9D6GKJkZSYkXddQgcmw7sVPUS+\n+5JaPXlfxoJOOn9kj9A6LDC6eMhYaLujG1BKcZ16DB0jqfwMnPLJ+bYEdatr\nKMvd9rbdsDwQ/tfk11VvHpiEBCNZjxM2+bdBLl9q2EXaLXi+dz/rJg5C0A9u\nNS2CzCUvg6+jNUzHo/RBfRXvlNV8tw==\n=mE2b\n-----END PGP MESSAGE-----\n",
|
||||
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZr2CRDMO9BwcW4mpAAApucIAD/uwWuV6DOg127XIPG6\n/jluL8jmwyCJX9noL6S8ZVMOymziKSh4/P1QyMPC5SL4lMPEiuaEdyetfBkU\n+5hW3tcZ+ptxmDi59SVYqmXTVewPgeB7t8c5nbzCuVuzA7ZAo8HAXHzFVQDS\nj9fKVGjZzQkmlwdcfnkXHAF0Ejilv9wxOOYgqVDuzm7JXVF3Um7nAgGKTJE5\n5CNnrEjmJGapj96mQFwXzET/kAhNIBw9tL5FAkDlKImdw8C0w9sXdvDu3yVM\ntvUZ5o2rR6ft0SC1byFso49vgJ/syeK6P2pPzltZJbsp4MvmlPUB0/G1XRU+\nI7q4IOWCvs8RD88ADmOty2o=\n=hyZE\n-----END PGP SIGNATURE-----\n",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
Cards: []Card{
|
||||
{
|
||||
Type: 3,
|
||||
Data: "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQf/RnOQRpo8DVJHQSJRgckEaUQvdMcADiM4L23diyiS\nQfclby/Ve2WInmvZc2RJ3rWENfeqyDZE6krQT642pKiW09GOIyVIjl+hje9y\nE4HBX0AIAWv7QhhKX6UZcM5dYSFbV3j3QxQB8A4Thng2G6ltotMTlbtcHbhu\n96Lt6ngA1tngXLSF5seyflnoiSQ5gLi2qVzrd95dIP6D4Ottcp929/4hDGmq\nPyxw9dColx6gVd1bmIDSI6ewkET4Grmo6QYqjSvjqLOf0PqHKzqypSFLkI5l\nmmnWKYTQCgl9wX+hq6Qz5E+m/BtbkdeX0YxYUss2e+oSAzJmnfdETErG9U5z\n3NJqAc3sgdwDzfWHBzogAxAbDHiqrF6zMlR5SFvZ6nRU7M2DTOE5dJhf+zOp\n1WSKn5LR46LGyt0m5wJPDjaGyQdPffAO4EULvwhGENe10UxRjY1qcUmjYOtS\nunl/vh3afI9PC1jj+HHJD2VgCA==\n=UpcY\n-----END PGP MESSAGE-----\n",
|
||||
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4pCRDMO9BwcW4mpAAA6h0H/2+97koXzly5pu9hpbaW\n75d1Q976RjMr5DjAx6tKFtSzznel8YfWgvA6OQmMGdPY8ae7/+3mwCJZYWy/\nXVvUfCSflmYpSIKGfP+Vm1XezWY1W84DGhiFj5n8sdaWisv3bpFwFf1YR3Ae\noBoZ4ufNzaQALRqGPMgXETtXZCtzuL/+0vGSKj5SLECiRcSE4jCPEVRy2bcl\nWJyB9r4VmcjF042OMHxphXoYmTEWvgigyaQFHNORu5cK9EHfHpCG6IcjGbdx\n+9Px5YnDY1ix+YpBKePGSTlLE0u6ow0VTUrdvNjl7IUBaRcfJcIIdgCBOTMw\n1uQ/yeyP46V5AFXFnIKeZeQ=\n=FlOf\n-----END PGP SIGNATURE-----\n",
|
||||
},
|
||||
{
|
||||
Type: 2,
|
||||
Data: "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Alice\nitem1.EMAIL:alice@protonmail.com\nUID:proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b\nEND:VCARD",
|
||||
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4qCRDMO9BwcW4mpAAA3jUIAJ88mIyO8Yj0+evSFXnK\nNxNdjNgn7t1leY0BWlh1nkK76XrZEPipdw2QU8cOcZzn1Wby2SGfZVkwoPc4\nzAhPT4WKbkFVqXhDry5399kLwGYJCxdEcw/oPyYj+YgpQKMxhTrQq21tbEwr\n7JDRBXgi3Cckh/XsteFHOIiAVnM7BV6zFudipnYxa4uNF0Bf4VbUZx1Mm0Wb\nMJaGsO5reqQUQzDPO5TdSAZ8qGSdjVv7RESgUu5DckcDSsnB987Zbh9uFc22\nfPYmb6zA0cEZh3dAjpDPT7cg8hlvfYBb+kP3sLFyLiIkdEG8Pcagjf0k+l76\nr1IsPlYBx2LJmsJf+WDNlj8=\n=Xn+3\n-----END PGP SIGNATURE-----\n",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestContact_GetContactsForExport(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/contacts/export?Page=0&PageSize=1000"))
|
||||
|
||||
fmt.Fprint(w, testGetContactsForExportResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
contacts, err := c.GetContactsForExport(0, 1000)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting contacts for export, got:", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(contacts, testGetContactsForExport) {
|
||||
t.Fatalf("Invalid contact for export: expected %+v, got %+v", testGetContactsForExport, contacts)
|
||||
}
|
||||
}
|
||||
|
||||
var testGetContactsEmailsResponseBody = `{
|
||||
"Code": 1000,
|
||||
"ContactEmails": [
|
||||
{
|
||||
"ID": "Hgyz1tG0OiC2v_hMIVOa6juMOAp_recWNzWII7a79Tfwdx08Jy3FJY0_Y_UtFYwbi6mN-Xx1sOI9_GmUGAcwWg==",
|
||||
"Name": "Bob",
|
||||
"Email": "bob.changed.tester@protonmail.com",
|
||||
"Type": [],
|
||||
"Defaults": 1,
|
||||
"Order": 1,
|
||||
"ContactID": "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
|
||||
"LabelIDs": []
|
||||
},
|
||||
{
|
||||
"ID": "4m2sBxLq4McqD0D330Kuy5xG-yyDNXyLEjG5_RYcjy9X-3qHGNP07DNOWLY40TYtUAQr4fAVp8zOcZ_z2o6H-A==",
|
||||
"Name": "Alice",
|
||||
"Email": "alice@protonmail.com",
|
||||
"Type": [],
|
||||
"Defaults": 1,
|
||||
"Order": 1,
|
||||
"ContactID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
"LabelIDs": []
|
||||
}
|
||||
],
|
||||
"Total": 2
|
||||
}`
|
||||
|
||||
var testGetContactsEmails = []ContactEmail{
|
||||
{
|
||||
ID: "Hgyz1tG0OiC2v_hMIVOa6juMOAp_recWNzWII7a79Tfwdx08Jy3FJY0_Y_UtFYwbi6mN-Xx1sOI9_GmUGAcwWg==",
|
||||
Name: "Bob",
|
||||
Email: "bob.changed.tester@protonmail.com",
|
||||
Type: []string{},
|
||||
Defaults: 1,
|
||||
Order: 1,
|
||||
ContactID: "c6CWuyEE6mMRApAxvvCO9MQKydTU8Do1iikL__M5MoWWjDEebzChAUx-73qa1jTV54RzFO5p9pLBPsIIgCwpww==",
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
{
|
||||
ID: "4m2sBxLq4McqD0D330Kuy5xG-yyDNXyLEjG5_RYcjy9X-3qHGNP07DNOWLY40TYtUAQr4fAVp8zOcZ_z2o6H-A==",
|
||||
Name: "Alice",
|
||||
Email: "alice@protonmail.com",
|
||||
Type: []string{},
|
||||
Defaults: 1,
|
||||
Order: 1,
|
||||
ContactID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
LabelIDs: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestContact_GetAllContactsEmails(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/contacts/emails?Page=0&PageSize=1000"))
|
||||
|
||||
fmt.Fprint(w, testGetContactsEmailsResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
contactsEmails, err := c.GetAllContactsEmails(0, 1000)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting contacts for export, got:", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(contactsEmails, testGetContactsEmails) {
|
||||
t.Fatalf("Invalid contact for export: expected %+v, got %+v", testGetContactsEmails, contactsEmails)
|
||||
}
|
||||
}
|
||||
|
||||
var testUpdateContactReq = UpdateContactReq{
|
||||
Cards: []Card{
|
||||
{
|
||||
Type: 2,
|
||||
Data: `BEGIN:VCARD
|
||||
VERSION:4.0
|
||||
FN;TYPE=fn:Bob
|
||||
item1.EMAIL:bob.changed.tester@protonmail.com
|
||||
UID:proton-web-cd974706-5cde-0e53-e131-c49c88a92ece
|
||||
END:VCARD
|
||||
`,
|
||||
Signature: ``,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var testUpdateContactResponseBody = `{
|
||||
"Code": 1000,
|
||||
"Contact": {
|
||||
"ID": "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
|
||||
"Name": "Bob",
|
||||
"UID": "proton-web-cd974706-5cde-0e53-e131-c49c88a92ece",
|
||||
"Size": 303,
|
||||
"CreateTime": 1517416603,
|
||||
"ModifyTime": 1517416656,
|
||||
"ContactEmails": [
|
||||
{
|
||||
"ID": "14n6vuf1zbeo3zsYzgV471S6xJ9gzl7-VZ8tcOTQq6ifBlNEre0SUdUM7sXh6e2Q_4NhJZaU9c7jLdB1HCV6dA==",
|
||||
"Name": "Bob",
|
||||
"Email": "bob.changed.tester@protonmail.com",
|
||||
"Type": [],
|
||||
"Defaults": 1,
|
||||
"Order": 1,
|
||||
"ContactID": "l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==",
|
||||
"LabelIDs": []
|
||||
}
|
||||
],
|
||||
"LabelIDs": []
|
||||
}
|
||||
}`
|
||||
|
||||
func TestContact_UpdateContact(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "PUT", "/contacts/l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ=="))
|
||||
|
||||
var updateContactReq UpdateContactReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&updateContactReq); err != nil {
|
||||
t.Error("Expecting no error while reading request body, got:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(testUpdateContactReq.Cards, updateContactReq.Cards) {
|
||||
t.Errorf("Invalid contacts request: expected %+v but got %+v", testUpdateContactReq.Cards, updateContactReq.Cards)
|
||||
}
|
||||
|
||||
fmt.Fprint(w, testUpdateContactResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
created, err := c.UpdateContact("l4PrVkmDsIIDba9aln829uwPK0nnyWZHnFtrsyb7CJsYgrD6JTVTuuoaVmaANfO2jIVxzZ2vtbt74rznGjjwFQ==", testUpdateContactReq.Cards)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while updating contact, got:", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(created, testContactUpdated) {
|
||||
t.Fatalf("Invalid updated contact: expected\n%+v\ngot\n%+v\n", testContactUpdated, created)
|
||||
}
|
||||
}
|
||||
|
||||
var testDeleteContactsReq = DeleteReq{
|
||||
IDs: []string{
|
||||
"s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
},
|
||||
}
|
||||
|
||||
var testDeleteContactsResponseBody = `{
|
||||
"Code": 1001,
|
||||
"Responses": [
|
||||
{
|
||||
"ID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
|
||||
"Response": {
|
||||
"Code": 1000
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
func TestContact_DeleteContacts(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "PUT", "/contacts/delete"))
|
||||
|
||||
var deleteContactsReq DeleteReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&deleteContactsReq); err != nil {
|
||||
t.Error("Expecting no error while reading request body, got:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(testDeleteContactsReq.IDs, deleteContactsReq.IDs) {
|
||||
t.Errorf("Invalid delete contacts request: expected %+v but got %+v", deleteContactsReq.IDs, testDeleteContactsReq.IDs)
|
||||
}
|
||||
|
||||
fmt.Fprint(w, testDeleteContactsResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
err := c.DeleteContacts([]string{"s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="})
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting contacts for export, got:", err)
|
||||
}
|
||||
}
|
||||
|
||||
var testDeleteAllResponseBody = `{
|
||||
"Code": 1000
|
||||
}`
|
||||
|
||||
func TestContact_DeleteAllContacts(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "DELETE", "/contacts"))
|
||||
|
||||
fmt.Fprint(w, testDeleteAllResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
err := c.DeleteAllContacts()
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting contacts for export, got:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContact_isSignedCardType(t *testing.T) {
|
||||
if !isSignedCardType(SignedCard) || !isSignedCardType(EncryptedSignedCard) {
|
||||
t.Fatal("isSignedCardType shouldn't return false for signed card types")
|
||||
@ -654,7 +160,7 @@ var testCardsCleartext = []Card{
|
||||
}
|
||||
|
||||
func TestClient_Encrypt(t *testing.T) {
|
||||
c := newTestClient(newTestClientManager(testClientConfig))
|
||||
c := newClient(newManager(DefaultConfig), "")
|
||||
c.userKeyRing = testPrivateKeyRing
|
||||
|
||||
cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext)
|
||||
@ -668,7 +174,7 @@ func TestClient_Encrypt(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_Decrypt(t *testing.T) {
|
||||
c := newTestClient(newTestClientManager(testClientConfig))
|
||||
c := newClient(newManager(DefaultConfig), "")
|
||||
c.userKeyRing = testPrivateKeyRing
|
||||
|
||||
cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted)
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
// ConversationsCount have same structure as MessagesCount.
|
||||
type ConversationsCount MessagesCount
|
||||
|
||||
// ConversationsCountsRes holds response from server.
|
||||
type ConversationsCountsRes struct {
|
||||
Res
|
||||
|
||||
Counts []*ConversationsCount
|
||||
}
|
||||
|
||||
// Conversation contains one body and multiple metadata.
|
||||
type Conversation struct{}
|
||||
|
||||
// CountConversations counts conversations by label.
|
||||
func (c *client) CountConversations(addressID string) (counts []*ConversationsCount, err error) {
|
||||
reqURL := "/mail/v4/conversations/count"
|
||||
if addressID != "" {
|
||||
reqURL += ("?AddressID=" + addressID)
|
||||
}
|
||||
req, err := c.NewRequest("GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res ConversationsCountsRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
counts, err = res.Counts, res.Err()
|
||||
return
|
||||
}
|
||||
19
pkg/pmapi/data_test.go
Normal file
19
pkg/pmapi/data_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package pmapi
|
||||
|
||||
import "github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
|
||||
var testIdentity = &crypto.Identity{
|
||||
Name: "UserID",
|
||||
Email: "",
|
||||
}
|
||||
|
||||
const (
|
||||
testUsername = "jason"
|
||||
testAPIPassword = "apple"
|
||||
|
||||
testUID = "729ad6012421d67ad26950dc898bebe3a6e3caa2" //nolint[gosec]
|
||||
testAccessToken = "de0423049b44243afeec7d9c1d99be7b46da1e8a" //nolint[gosec]
|
||||
testAccessTokenOld = "feb3159ac63fb05119bcf4480d939278aa746926" //nolint[gosec]
|
||||
testRefreshToken = "a49b98256745bb497bec20e9b55f5de16f01fb52" //nolint[gosec]
|
||||
testRefreshTokenNew = "b894b4c4f20003f12d486900d8b88c7d68e67235" //nolint[gosec]
|
||||
)
|
||||
@ -1,43 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
func printBytes(body []byte) string {
|
||||
if utf8.Valid(body) {
|
||||
return string(body)
|
||||
}
|
||||
enc := []rune{}
|
||||
for _, b := range body {
|
||||
switch {
|
||||
case b == 9:
|
||||
enc = append(enc, rune('⟼'))
|
||||
case b == 13:
|
||||
enc = append(enc, rune('↵'))
|
||||
case b < 32, b == 127:
|
||||
enc = append(enc, '◡')
|
||||
case b > 31 && b < 127, b == 10:
|
||||
enc = append(enc, rune(b))
|
||||
default:
|
||||
enc = append(enc, 9728+rune(b))
|
||||
}
|
||||
}
|
||||
|
||||
return string(enc)
|
||||
}
|
||||
@ -1,72 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TLSDialer interface {
|
||||
DialTLS(network, address string) (conn net.Conn, err error)
|
||||
}
|
||||
|
||||
// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
|
||||
func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
|
||||
return &http.Transport{
|
||||
DialTLS: dialer.DialTLS,
|
||||
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 5 * time.Minute,
|
||||
ExpectContinueTimeout: 500 * time.Millisecond,
|
||||
|
||||
// GODT-126: this was initially 10s but logs from users showed a significant number
|
||||
// were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
|
||||
// Bumping to 30s for now to avoid this problem.
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
|
||||
// If we allow up to 30 seconds for response headers, it is reasonable to allow up
|
||||
// to 30 seconds for the TLS handshake to take place.
|
||||
TLSHandshakeTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// BasicTLSDialer implements TLSDialer.
|
||||
type BasicTLSDialer struct{}
|
||||
|
||||
// NewBasicTLSDialer returns a new BasicTLSDialer.
|
||||
func NewBasicTLSDialer() *BasicTLSDialer {
|
||||
return &BasicTLSDialer{}
|
||||
}
|
||||
|
||||
// DialTLS returns a connection to the given address using the given network.
|
||||
func (b *BasicTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
|
||||
dialer := &net.Dialer{Timeout: 30 * time.Second} // Alternative Routes spec says this should be a 30s timeout.
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
// If we are not dialing the standard API then we should skip cert verification checks.
|
||||
if address != rootURL {
|
||||
tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
|
||||
}
|
||||
|
||||
return tls.DialWithDialer(dialer, network, address, tlsConfig)
|
||||
}
|
||||
@ -1,93 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
|
||||
// to report errors if the fingerprint check fails.
|
||||
type PinningTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
|
||||
// pinChecker is used to check TLS keys of connections.
|
||||
pinChecker *pinChecker
|
||||
|
||||
// tlsIssueNotifier is used to notify something when there is a TLS issue.
|
||||
tlsIssueNotifier func()
|
||||
|
||||
reporter *tlsReporter
|
||||
|
||||
// A logger for logging messages.
|
||||
log logrus.FieldLogger
|
||||
}
|
||||
|
||||
// NewPinningTLSDialer constructs a new dialer which only returns tcp connections to servers
|
||||
// which present known certificates.
|
||||
// If enabled, it reports any invalid certificates it finds.
|
||||
func NewPinningTLSDialer(dialer TLSDialer) *PinningTLSDialer {
|
||||
return &PinningTLSDialer{
|
||||
dialer: dialer,
|
||||
pinChecker: newPinChecker(TrustedAPIPins),
|
||||
log: logrus.WithField("pkg", "pmapi/tls-pinning"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PinningTLSDialer) SetTLSIssueNotifier(notifier func()) {
|
||||
p.tlsIssueNotifier = notifier
|
||||
}
|
||||
|
||||
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(cm *ClientManager) {
|
||||
p.reporter = newTLSReporter(p.pinChecker, cm)
|
||||
}
|
||||
|
||||
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
|
||||
func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) {
|
||||
conn, err := p.dialer.DialTLS(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.pinChecker.checkCertificate(conn); err != nil {
|
||||
if p.tlsIssueNotifier != nil {
|
||||
go p.tlsIssueNotifier()
|
||||
}
|
||||
|
||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
|
||||
p.reporter.reportCertIssue(
|
||||
TLSReportURI,
|
||||
host,
|
||||
port,
|
||||
tlsConn.ConnectionState(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
@ -1,141 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const liveAPI = "api.protonmail.ch"
|
||||
|
||||
var testLiveConfig = &ClientConfig{
|
||||
AppVersion: "Bridge_1.2.4-test",
|
||||
ClientID: "Bridge",
|
||||
}
|
||||
|
||||
func createAndSetPinningDialer(cm *ClientManager) (*int, *PinningTLSDialer) {
|
||||
called := 0
|
||||
|
||||
dialer := NewPinningTLSDialer(NewBasicTLSDialer())
|
||||
dialer.SetTLSIssueNotifier(func() { called++ })
|
||||
cm.SetRoundTripper(CreateTransportWithDialer(dialer))
|
||||
|
||||
return &called, dialer
|
||||
}
|
||||
|
||||
func TestTLSPinValid(t *testing.T) {
|
||||
cm := newTestClientManager(testLiveConfig)
|
||||
cm.host = liveAPI
|
||||
rootScheme = "https"
|
||||
called, _ := createAndSetPinningDialer(cm)
|
||||
client := cm.GetClient("pmapi" + t.Name())
|
||||
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
|
||||
Equals(t, 0, *called)
|
||||
}
|
||||
|
||||
func TestTLSPinBackup(t *testing.T) {
|
||||
cm := newTestClientManager(testLiveConfig)
|
||||
cm.host = liveAPI
|
||||
called, p := createAndSetPinningDialer(cm)
|
||||
p.pinChecker.trustedPins[1] = p.pinChecker.trustedPins[0]
|
||||
p.pinChecker.trustedPins[0] = ""
|
||||
|
||||
client := cm.GetClient("pmapi" + t.Name())
|
||||
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
|
||||
Equals(t, 0, *called)
|
||||
}
|
||||
|
||||
func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
|
||||
cm := newTestClientManager(testLiveConfig)
|
||||
cm.host = liveAPI
|
||||
|
||||
called, p := createAndSetPinningDialer(cm)
|
||||
for i := 0; i < len(p.pinChecker.trustedPins); i++ {
|
||||
p.pinChecker.trustedPins[i] = "testing"
|
||||
}
|
||||
|
||||
client := cm.GetClient("pmapi" + t.Name())
|
||||
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
|
||||
// check that it will be called only once per session
|
||||
client = cm.GetClient("pmapi" + t.Name())
|
||||
_, err = client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
|
||||
Equals(t, 1, *called)
|
||||
}
|
||||
|
||||
func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
|
||||
cm := newTestClientManager(testLiveConfig)
|
||||
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
called, _ := createAndSetPinningDialer(cm)
|
||||
|
||||
client := cm.GetClient("pmapi" + t.Name())
|
||||
|
||||
cm.host = liveAPI
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
|
||||
cm.host = ts.URL
|
||||
_, err = client.AuthInfo("this.address.is.disabled")
|
||||
Assert(t, err != nil, "error is expected but have %v", err)
|
||||
|
||||
Equals(t, 1, *called)
|
||||
}
|
||||
|
||||
// The tests below should pass but cannot run in CI due to proxy issues.
|
||||
|
||||
func _TestTLSSignedCertWrongPublicKey(t *testing.T) { // nolint[unused]
|
||||
cm := newTestClientManager(testLiveConfig)
|
||||
_, dialer := createAndSetPinningDialer(cm)
|
||||
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
|
||||
assert.Error(t, err, "expected dial to fail because of wrong public key")
|
||||
}
|
||||
|
||||
func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
|
||||
cm := newTestClientManager(testLiveConfig)
|
||||
_, dialer := createAndSetPinningDialer(cm)
|
||||
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
|
||||
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
|
||||
assert.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
|
||||
}
|
||||
|
||||
func _TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) { // nolint[unused]
|
||||
cm := newTestClientManager(testLiveConfig)
|
||||
_, dialer := createAndSetPinningDialer(cm)
|
||||
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
|
||||
_, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443")
|
||||
assert.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
|
||||
}
|
||||
@ -1,61 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
|
||||
type ProxyTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
|
||||
cm *ClientManager
|
||||
}
|
||||
|
||||
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
|
||||
func NewProxyTLSDialer(dialer TLSDialer, cm *ClientManager) *ProxyTLSDialer {
|
||||
return &ProxyTLSDialer{
|
||||
dialer: dialer,
|
||||
cm: cm,
|
||||
}
|
||||
}
|
||||
|
||||
// DialTLS dials the given network/address. If it fails, it retries using a proxy.
|
||||
func (d *ProxyTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
|
||||
if conn, err = d.dialer.DialTLS(network, address); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
if !d.cm.allowProxy {
|
||||
return
|
||||
}
|
||||
|
||||
var proxy string
|
||||
|
||||
if proxy, err = d.cm.switchToReachableServer(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return d.dialer.DialTLS(network, net.JoinHostPort(proxy, port))
|
||||
}
|
||||
@ -1,74 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
)
|
||||
|
||||
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
|
||||
// The file and its signature are verified using the given keyring `kr`.
|
||||
// If the file is verified successfully, it can be read from the returned reader.
|
||||
// TLS fingerprinting is used to verify that connections are only made to known servers.
|
||||
func (c *client) DownloadAndVerify(file, sig string, kr *crypto.KeyRing) (io.Reader, error) {
|
||||
var fb, sb []byte
|
||||
|
||||
if err := c.fetchFile(file, func(r io.Reader) (err error) {
|
||||
fb, err = ioutil.ReadAll(r)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.fetchFile(sig, func(r io.Reader) (err error) {
|
||||
sb, err = ioutil.ReadAll(r)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := kr.VerifyDetached(
|
||||
crypto.NewPlainMessage(fb),
|
||||
crypto.NewPGPSignature(sb),
|
||||
crypto.GetUnixTime(),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bytes.NewReader(fb), nil
|
||||
}
|
||||
|
||||
func (c *client) fetchFile(file string, fn func(io.Reader) error) error {
|
||||
res, err := c.hc.Get(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = res.Body.Close() }()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to get file: http error %v", res.StatusCode)
|
||||
}
|
||||
|
||||
return fn(res.Body)
|
||||
}
|
||||
9
pkg/pmapi/errors.go
Normal file
9
pkg/pmapi/errors.go
Normal file
@ -0,0 +1,9 @@
|
||||
package pmapi
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoConnection = errors.New("no internet connection")
|
||||
ErrAPIFailure = errors.New("API returned an error")
|
||||
ErrUnauthorized = errors.New("API client is unauthorized")
|
||||
)
|
||||
@ -18,9 +18,11 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// Event represents changes since the last check.
|
||||
@ -137,7 +139,7 @@ type EventMessageUpdated struct {
|
||||
ID string
|
||||
|
||||
Subject *string
|
||||
Unread *int
|
||||
Unread *Boolean
|
||||
Flags *int64
|
||||
Sender *mail.Address
|
||||
ToList *[]*mail.Address
|
||||
@ -163,62 +165,28 @@ type EventAddress struct {
|
||||
Address *Address
|
||||
}
|
||||
|
||||
type EventRes struct {
|
||||
Res
|
||||
*Event
|
||||
}
|
||||
|
||||
type LatestEventRes struct {
|
||||
Res
|
||||
*Event
|
||||
}
|
||||
|
||||
// GetEvent returns a summary of events that occurred since last. To get the latest event,
|
||||
// provide an empty last value. The latest event is always empty.
|
||||
func (c *client) GetEvent(last string) (event *Event, err error) {
|
||||
return c.getEvent(last, 1)
|
||||
}
|
||||
|
||||
func (c *client) getEvent(last string, numberOfMergedEvents int) (event *Event, err error) {
|
||||
var req *http.Request
|
||||
if last == "" {
|
||||
req, err = c.NewRequest("GET", "/events/latest", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res LatestEventRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
event, err = res.Event, res.Err()
|
||||
} else {
|
||||
req, err = c.NewRequest("GET", "/events/"+last, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res EventRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
event, err = res.Event, res.Err()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if event.More == 1 && numberOfMergedEvents < maxNumberOfMergedEvents {
|
||||
var moreEvents *Event
|
||||
if moreEvents, err = c.getEvent(event.EventID, numberOfMergedEvents+1); err != nil {
|
||||
return
|
||||
}
|
||||
event = mergeEvents(event, moreEvents)
|
||||
}
|
||||
func (c *client) GetEvent(ctx context.Context, eventID string) (event *Event, err error) {
|
||||
if eventID == "" {
|
||||
eventID = "latest"
|
||||
}
|
||||
|
||||
return event, err
|
||||
var res struct {
|
||||
*Event
|
||||
|
||||
More int
|
||||
}
|
||||
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetResult(&res).Get("/events/" + eventID)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// FIXME(conman): use mergeEvents() function.
|
||||
|
||||
return res.Event, nil
|
||||
}
|
||||
|
||||
// mergeEvents combines an old events and a new events object.
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
@ -31,32 +32,41 @@ import (
|
||||
)
|
||||
|
||||
func TestClient_GetEvent(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/latest"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testEventBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent("")
|
||||
event, err := c.GetEvent(context.TODO(), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testEvent, event)
|
||||
}
|
||||
|
||||
func TestClient_GetEvent_withID(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/"+testEvent.EventID))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testEventBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent(testEvent.EventID)
|
||||
event, err := c.GetEvent(context.TODO(), testEvent.EventID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testEvent, event)
|
||||
}
|
||||
|
||||
// We first call GetEvent with id of eventID1, which returns More=1 so we fetch with id eventID2.
|
||||
func TestClient_GetEvent_mergeEvents(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// FIXME(conman): Merging is currently not supported. Implement it and then enable this test again!
|
||||
func _TestClient_GetEvent_mergeEvents(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch r.URL.RequestURI() {
|
||||
case "/events/eventID1":
|
||||
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/eventID1"))
|
||||
@ -70,15 +80,16 @@ func TestClient_GetEvent_mergeEvents(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent("eventID1")
|
||||
event, err := c.GetEvent(context.TODO(), "eventID1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testEventMerged, event)
|
||||
}
|
||||
|
||||
func TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
|
||||
// FIXME(conman): Merging is currently not supported. Implement it and then enable this test again!
|
||||
func _TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
|
||||
numberOfCalls := 0
|
||||
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
numberOfCalls++
|
||||
|
||||
re := regexp.MustCompile(`/eventID([0-9]+)`)
|
||||
@ -93,18 +104,20 @@ func TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
|
||||
fmt.Println("")
|
||||
|
||||
body := strings.ReplaceAll(testEventBodyMore1, "eventID2", "eventID"+strconv.Itoa(eventID+1))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, body)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent("eventID1")
|
||||
event, err := c.GetEvent(context.TODO(), "eventID1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, maxNumberOfMergedEvents, numberOfCalls)
|
||||
require.Equal(t, 1, event.More)
|
||||
}
|
||||
|
||||
var (
|
||||
testEventMessageUpdateUnread = 0
|
||||
testEventMessageUpdateUnread = False
|
||||
|
||||
testEvent = &Event{
|
||||
EventID: "eventID1",
|
||||
|
||||
@ -18,110 +18,68 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// Import errors.
|
||||
const (
|
||||
ImportMessageTooLarge = 36022
|
||||
)
|
||||
const MaxImportMessageRequestLength = 10
|
||||
|
||||
// ImportReq is an import request.
|
||||
type ImportReq struct {
|
||||
// A list of messages that will be imported.
|
||||
Messages []*ImportMsgReq
|
||||
}
|
||||
|
||||
// WriteTo writes the import request to a multipart writer.
|
||||
func (req *ImportReq) WriteTo(w *multipart.Writer) (err error) {
|
||||
// Create Metadata field.
|
||||
mw, err := w.CreateFormField("Metadata")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Build metadata.
|
||||
metadata := map[string]*ImportMsgReq{}
|
||||
for i, msg := range req.Messages {
|
||||
name := strconv.Itoa(i)
|
||||
metadata[name] = msg
|
||||
}
|
||||
|
||||
// Write metadata.
|
||||
if err = json.NewEncoder(mw).Encode(metadata); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Write messages.
|
||||
for i, msg := range req.Messages {
|
||||
name := strconv.Itoa(i)
|
||||
|
||||
var fw io.Writer
|
||||
if fw, err = w.CreateFormFile(name, name+".eml"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = fw.Write(msg.Body); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Adding new line to properly fetch the whole body on the API side.
|
||||
// The reason is the bug in PHP: https://bugs.php.net/bug.php?id=75923
|
||||
// Messages generated by PM already have it but importing already
|
||||
// encrypted messages might not have it.
|
||||
if _, err = fw.Write([]byte("\r\n")); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ImportMsgReq is a request to import a message. All fields are optional except AddressID and Body.
|
||||
type ImportMsgReq struct {
|
||||
// The address where the message will be imported.
|
||||
AddressID string
|
||||
// The full MIME message.
|
||||
Body []byte `json:"-"`
|
||||
|
||||
// 0: read, 1: unread.
|
||||
Unread int
|
||||
// 1 if the message has been replied.
|
||||
IsReplied int
|
||||
// 1 if the message has been replied to all.
|
||||
IsRepliedAll int
|
||||
// 1 if the message has been forwarded.
|
||||
IsForwarded int
|
||||
// The time when the message was received as a Unix time.
|
||||
Time int64
|
||||
// The type of the imported message.
|
||||
Flags int64
|
||||
// The labels to apply to the imported message. Must contain at least one system label.
|
||||
LabelIDs []string
|
||||
Metadata *ImportMetadata // Metadata about the message to import.
|
||||
Message []byte // The raw RFC822 message.
|
||||
}
|
||||
|
||||
func (req ImportMsgReq) String() string {
|
||||
data, _ := json.Marshal(req)
|
||||
return string(data)
|
||||
}
|
||||
type ImportMsgReqs []*ImportMsgReq
|
||||
|
||||
// ImportRes is a response to an import request.
|
||||
type ImportRes struct {
|
||||
Res
|
||||
func (reqs ImportMsgReqs) buildMultipartFormData() ([]*resty.MultipartField, error) {
|
||||
var fields []*resty.MultipartField
|
||||
|
||||
Responses []struct {
|
||||
Name string
|
||||
Response struct {
|
||||
Res
|
||||
MessageID string
|
||||
}
|
||||
metadata := make(map[string]*ImportMetadata)
|
||||
|
||||
for i, req := range reqs {
|
||||
name := strconv.Itoa(i)
|
||||
|
||||
metadata[name] = req.Metadata
|
||||
|
||||
fields = append(fields, &resty.MultipartField{
|
||||
Param: name,
|
||||
FileName: name + ".eml",
|
||||
ContentType: "message/rfc822",
|
||||
Reader: bytes.NewReader(req.Message),
|
||||
})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fields = append(fields, &resty.MultipartField{
|
||||
Param: "Metadata",
|
||||
ContentType: "application/json",
|
||||
Reader: bytes.NewReader(b),
|
||||
})
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
// TODO: Add other metadata.
|
||||
type ImportMetadata struct {
|
||||
AddressID string
|
||||
Unread Boolean // 0: read, 1: unread.
|
||||
IsReplied Boolean // 1 if the message has been replied.
|
||||
IsRepliedAll Boolean // 1 if the message has been replied to all.
|
||||
IsForwarded Boolean // 1 if the message has been forwarded.
|
||||
Time int64 // The time when the message was received as a Unix time.
|
||||
Flags int64 // The type of the imported message.
|
||||
LabelIDs []string // The labels to apply to the imported message. Must contain at least one system label.
|
||||
}
|
||||
|
||||
// ImportMsgRes is a response to a single message import request.
|
||||
type ImportMsgRes struct {
|
||||
// The error encountered while importing the message, if any.
|
||||
Error error
|
||||
@ -130,41 +88,46 @@ type ImportMsgRes struct {
|
||||
}
|
||||
|
||||
// Import imports messages to the user's account.
|
||||
func (c *client) Import(reqs []*ImportMsgReq) (resps []*ImportMsgRes, err error) {
|
||||
importReq := &ImportReq{Messages: reqs}
|
||||
func (c *client) Import(ctx context.Context, reqs ImportMsgReqs) ([]*ImportMsgRes, error) {
|
||||
if len(reqs) > MaxImportMessageRequestLength {
|
||||
return nil, errors.New("request is too long")
|
||||
}
|
||||
|
||||
req, w, err := c.NewMultipartRequest("POST", "/mail/v4/messages/import")
|
||||
fields, err := reqs.buildMultipartFormData()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We will write the request as long as it is sent to the API.
|
||||
var importRes ImportRes
|
||||
done := make(chan error, 1)
|
||||
go (func() {
|
||||
done <- c.DoJSON(req, &importRes)
|
||||
})()
|
||||
|
||||
// Write the request.
|
||||
if err = importReq.WriteTo(w.Writer); err != nil {
|
||||
return
|
||||
}
|
||||
_ = w.Close()
|
||||
|
||||
if err = <-done; err != nil {
|
||||
return
|
||||
}
|
||||
if err = importRes.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resps = make([]*ImportMsgRes, len(importRes.Responses))
|
||||
for i, r := range importRes.Responses {
|
||||
resps[i] = &ImportMsgRes{
|
||||
Error: r.Response.Err(),
|
||||
MessageID: r.Response.MessageID,
|
||||
var res struct {
|
||||
Responses []struct {
|
||||
Name string
|
||||
Response struct {
|
||||
Error
|
||||
MessageID string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resps, err
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetMultipartFields(fields...).SetResult(&res).Post("/mail/v4/messages/import")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resps []*ImportMsgRes
|
||||
|
||||
for _, resp := range res.Responses {
|
||||
var err error
|
||||
|
||||
if resp.Response.Code != 1000 {
|
||||
err = resp.Response.Error
|
||||
}
|
||||
|
||||
resps = append(resps, &ImportMsgRes{
|
||||
Error: err,
|
||||
MessageID: resp.Response.MessageID,
|
||||
})
|
||||
}
|
||||
|
||||
return resps, nil
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -32,11 +33,13 @@ import (
|
||||
|
||||
var testImportReqs = []*ImportMsgReq{
|
||||
{
|
||||
AddressID: "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==",
|
||||
Body: []byte("Hello World!"),
|
||||
Unread: 0,
|
||||
Flags: FlagReceived | FlagImported,
|
||||
LabelIDs: []string{ArchiveLabel},
|
||||
Metadata: &ImportMetadata{
|
||||
AddressID: "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==",
|
||||
Unread: 0,
|
||||
Flags: FlagReceived | FlagImported,
|
||||
LabelIDs: []string{ArchiveLabel},
|
||||
},
|
||||
Message: []byte("Hello World!"),
|
||||
},
|
||||
}
|
||||
|
||||
@ -54,7 +57,7 @@ var testImportRes = &ImportMsgRes{
|
||||
}
|
||||
|
||||
func TestClient_Import(t *testing.T) { // nolint[funlen]
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/messages/import"))
|
||||
|
||||
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
@ -67,51 +70,13 @@ func TestClient_Import(t *testing.T) { // nolint[funlen]
|
||||
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
|
||||
// First part is metadata.
|
||||
// First part is message body.
|
||||
p, err := mr.NextPart()
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading first part of request body, got:", err)
|
||||
}
|
||||
|
||||
contentDisp, params, err := pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing part content disposition, got:", err)
|
||||
}
|
||||
if contentDisp != "form-data" {
|
||||
t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType)
|
||||
}
|
||||
if params["name"] != "Metadata" {
|
||||
t.Errorf("Invalid part name: expected %v but got %v", "Metadata", params["name"])
|
||||
}
|
||||
|
||||
metadata := map[string]*ImportMsgReq{}
|
||||
if err := json.NewDecoder(p).Decode(&metadata); err != nil {
|
||||
t.Error("Expected no error while parsing metadata json, got:", err)
|
||||
}
|
||||
|
||||
if len(metadata) != 1 {
|
||||
t.Errorf("Expected metadata to contain exactly one item, got %v", metadata)
|
||||
}
|
||||
|
||||
req := metadata["0"]
|
||||
if metadata["0"] == nil {
|
||||
t.Errorf("Expected metadata to contain one item indexed by 0, got %v", metadata)
|
||||
}
|
||||
|
||||
// No Body in metadata.
|
||||
expected := *testImportReqs[0]
|
||||
expected.Body = nil
|
||||
if !reflect.DeepEqual(&expected, req) {
|
||||
t.Errorf("Invalid message metadata: expected %v, got %v", &expected, req)
|
||||
}
|
||||
|
||||
// Second part is message body.
|
||||
p, err = mr.NextPart()
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading second part of request body, got:", err)
|
||||
}
|
||||
|
||||
contentDisp, params, err = pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
|
||||
contentDisp, params, err := pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing part content disposition, got:", err)
|
||||
}
|
||||
@ -127,8 +92,44 @@ func TestClient_Import(t *testing.T) { // nolint[funlen]
|
||||
t.Error("Expected no error while reading second part body, got:", err)
|
||||
}
|
||||
|
||||
if string(b) != string(testImportReqs[0].Body)+"\r\n" {
|
||||
t.Errorf("Invalid message body: expected %v but got %v", string(testImportReqs[0].Body), string(b))
|
||||
if string(b) != string(testImportReqs[0].Message) {
|
||||
t.Errorf("Invalid message body: expected %v but got %v", string(testImportReqs[0].Message), string(b))
|
||||
}
|
||||
|
||||
// Second part is metadata.
|
||||
p, err = mr.NextPart()
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading first part of request body, got:", err)
|
||||
}
|
||||
|
||||
contentDisp, params, err = pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing part content disposition, got:", err)
|
||||
}
|
||||
if contentDisp != "form-data" {
|
||||
t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType)
|
||||
}
|
||||
if params["name"] != "Metadata" {
|
||||
t.Errorf("Invalid part name: expected %v but got %v", "Metadata", params["name"])
|
||||
}
|
||||
|
||||
metadata := map[string]*ImportMetadata{}
|
||||
if err := json.NewDecoder(p).Decode(&metadata); err != nil {
|
||||
t.Error("Expected no error while parsing metadata json, got:", err)
|
||||
}
|
||||
|
||||
if len(metadata) != 1 {
|
||||
t.Errorf("Expected metadata to contain exactly one item, got %v", metadata)
|
||||
}
|
||||
|
||||
req := metadata["0"]
|
||||
if metadata["0"] == nil {
|
||||
t.Errorf("Expected metadata to contain one item indexed by 0, got %v", metadata)
|
||||
}
|
||||
|
||||
expected := *testImportReqs[0].Metadata
|
||||
if !reflect.DeepEqual(&expected, req) {
|
||||
t.Errorf("Invalid message metadata: expected %v, got %v", &expected, req)
|
||||
}
|
||||
|
||||
// No more parts.
|
||||
@ -137,11 +138,13 @@ func TestClient_Import(t *testing.T) { // nolint[funlen]
|
||||
t.Error("Expected no more parts but error was not EOF, got:", err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testImportBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
imported, err := c.Import(testImportReqs)
|
||||
imported, err := c.Import(context.TODO(), testImportReqs)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while importing, got:", err)
|
||||
}
|
||||
|
||||
103
pkg/pmapi/key.go
103
pkg/pmapi/key.go
@ -18,11 +18,10 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// Key flags.
|
||||
@ -31,84 +30,34 @@ const (
|
||||
UseToEncryptFlag
|
||||
)
|
||||
|
||||
type PublicKeyRes struct {
|
||||
Res
|
||||
|
||||
RecipientType int
|
||||
MIMEType string
|
||||
Keys []PublicKey
|
||||
}
|
||||
|
||||
type PublicKey struct {
|
||||
Flags int
|
||||
PublicKey string
|
||||
}
|
||||
|
||||
// PublicKeys returns the public keys of the given email addresses.
|
||||
func (c *client) PublicKeys(emails []string) (keys map[string]*crypto.Key, err error) {
|
||||
if len(emails) == 0 {
|
||||
err = fmt.Errorf("pmapi: cannot get public keys: no email address provided")
|
||||
return
|
||||
}
|
||||
|
||||
keys = make(map[string]*crypto.Key)
|
||||
|
||||
for _, email := range emails {
|
||||
email = url.QueryEscape(email)
|
||||
|
||||
var req *http.Request
|
||||
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var res PublicKeyRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, rawKey := range res.Keys {
|
||||
if rawKey.Flags&UseToEncryptFlag == UseToEncryptFlag {
|
||||
var key *crypto.Key
|
||||
|
||||
if key, err = crypto.NewKeyFromArmored(rawKey.PublicKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
keys[email] = key
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keys, err
|
||||
}
|
||||
type RecipientType int
|
||||
|
||||
const (
|
||||
RecipientInternal = 1
|
||||
RecipientExternal = 2
|
||||
RecipientTypeInternal RecipientType = iota + 1
|
||||
RecipientTypeExternal
|
||||
)
|
||||
|
||||
// GetPublicKeysForEmail returns all sending public keys for the given email address.
|
||||
func (c *client) GetPublicKeysForEmail(email string) (keys []PublicKey, internal bool, err error) {
|
||||
func (c *client) GetPublicKeysForEmail(ctx context.Context, email string) (keys []PublicKey, internal bool, err error) {
|
||||
email = url.QueryEscape(email)
|
||||
|
||||
var req *http.Request
|
||||
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
|
||||
return
|
||||
var res struct {
|
||||
Keys []PublicKey
|
||||
RecipientType RecipientType
|
||||
}
|
||||
|
||||
var res PublicKeyRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetResult(&res).SetQueryParam("Email", email).Get("/keys")
|
||||
}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
internal = res.RecipientType == RecipientInternal
|
||||
|
||||
for _, key := range res.Keys {
|
||||
if key.Flags&UseToEncryptFlag == UseToEncryptFlag {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return
|
||||
return res.Keys, res.RecipientType == RecipientTypeInternal, nil
|
||||
}
|
||||
|
||||
// KeySalt contains id and salt for key.
|
||||
@ -116,25 +65,17 @@ type KeySalt struct {
|
||||
ID, KeySalt string
|
||||
}
|
||||
|
||||
// KeySaltRes is used to unmarshal API response.
|
||||
type KeySaltRes struct {
|
||||
Res
|
||||
KeySalts []KeySalt
|
||||
}
|
||||
|
||||
// GetKeySalts sends request to get list of key salts (n.b. locked route).
|
||||
func (c *client) GetKeySalts() (keySalts []KeySalt, err error) {
|
||||
var req *http.Request
|
||||
if req, err = c.NewRequest("GET", "/keys/salts", nil); err != nil {
|
||||
return
|
||||
func (c *client) GetKeySalts(ctx context.Context) (keySalts []KeySalt, err error) {
|
||||
var res struct {
|
||||
KeySalts []KeySalt
|
||||
}
|
||||
|
||||
var res KeySaltRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetResult(&res).Get("/keys/salts")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keySalts = res.KeySalts
|
||||
|
||||
return
|
||||
return res.KeySalts, nil
|
||||
}
|
||||
|
||||
@ -18,8 +18,11 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// System labels.
|
||||
@ -92,100 +95,84 @@ type Label struct {
|
||||
Notify int
|
||||
}
|
||||
|
||||
type LabelListRes struct {
|
||||
Res
|
||||
Labels []*Label
|
||||
func (c *client) ListLabels(ctx context.Context) (labels []*Label, err error) {
|
||||
return c.ListLabelType(ctx, LabelTypeMailbox)
|
||||
}
|
||||
|
||||
func (c *client) ListLabels() (labels []*Label, err error) {
|
||||
return c.ListLabelType(LabelTypeMailbox)
|
||||
}
|
||||
|
||||
func (c *client) ListContactGroups() (labels []*Label, err error) {
|
||||
return c.ListLabelType(LabelTypeContactGroup)
|
||||
func (c *client) ListContactGroups(ctx context.Context) (labels []*Label, err error) {
|
||||
return c.ListLabelType(ctx, LabelTypeContactGroup)
|
||||
}
|
||||
|
||||
// ListLabelType lists all labels created by the user.
|
||||
func (c *client) ListLabelType(labelType int) (labels []*Label, err error) {
|
||||
req, err := c.NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil)
|
||||
if err != nil {
|
||||
return
|
||||
func (c *client) ListLabelType(ctx context.Context, labelType int) (labels []*Label, err error) {
|
||||
var res struct {
|
||||
Labels []*Label
|
||||
}
|
||||
|
||||
var res LabelListRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetQueryParam("Type", strconv.Itoa(labelType)).SetResult(&res).Get("/v4/labels")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
labels, err = res.Labels, res.Err()
|
||||
return
|
||||
return res.Labels, nil
|
||||
}
|
||||
|
||||
type LabelReq struct {
|
||||
*Label
|
||||
}
|
||||
|
||||
type LabelRes struct {
|
||||
Res
|
||||
Label *Label
|
||||
}
|
||||
|
||||
// CreateLabel creates a new label.
|
||||
func (c *client) CreateLabel(label *Label) (created *Label, err error) {
|
||||
func (c *client) CreateLabel(ctx context.Context, label *Label) (created *Label, err error) {
|
||||
if label.Name == "" {
|
||||
return nil, errors.New("name is required")
|
||||
}
|
||||
|
||||
labelReq := &LabelReq{label}
|
||||
req, err := c.NewJSONRequest("POST", "/labels", labelReq)
|
||||
if err != nil {
|
||||
return
|
||||
var res struct {
|
||||
Label *Label
|
||||
}
|
||||
|
||||
var res LabelRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(&LabelReq{
|
||||
Label: label,
|
||||
}).SetResult(&res).Post("/v4/labels")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created, err = res.Label, res.Err()
|
||||
return
|
||||
return res.Label, nil
|
||||
}
|
||||
|
||||
// UpdateLabel updates a label.
|
||||
func (c *client) UpdateLabel(label *Label) (updated *Label, err error) {
|
||||
func (c *client) UpdateLabel(ctx context.Context, label *Label) (updated *Label, err error) {
|
||||
if label.Name == "" {
|
||||
return nil, errors.New("name is required")
|
||||
}
|
||||
|
||||
labelReq := &LabelReq{label}
|
||||
req, err := c.NewJSONRequest("PUT", "/labels/"+label.ID, labelReq)
|
||||
if err != nil {
|
||||
return
|
||||
var res struct {
|
||||
Label *Label
|
||||
}
|
||||
|
||||
var res LabelRes
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(&LabelReq{
|
||||
Label: label,
|
||||
}).SetResult(&res).Put("/v4/labels/" + label.ID)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err = res.Label, res.Err()
|
||||
return
|
||||
return res.Label, nil
|
||||
}
|
||||
|
||||
// DeleteLabel deletes a label.
|
||||
func (c *client) DeleteLabel(id string) (err error) {
|
||||
req, err := c.NewRequest("DELETE", "/labels/"+id, nil)
|
||||
if err != nil {
|
||||
return
|
||||
func (c *client) DeleteLabel(ctx context.Context, labelID string) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.Delete("/v4/labels/" + labelID)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var res Res
|
||||
if err = c.DoJSON(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = res.Err()
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// LeastUsedColor is intended to return color for creating a new inbox or label.
|
||||
|
||||
@ -19,6 +19,7 @@ package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -90,14 +91,16 @@ const testDeleteLabelBody = `{
|
||||
`
|
||||
|
||||
func TestClient_ListLabels(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/labels?1"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/v4/labels?Type=1"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testLabelsBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
labels, err := c.ListLabels()
|
||||
labels, err := c.ListLabels(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while listing labels, got:", err)
|
||||
}
|
||||
@ -114,8 +117,8 @@ func TestClient_ListLabels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_CreateLabel(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/labels"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/v4/labels"))
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
_, err := body.ReadFrom(r.Body)
|
||||
@ -133,11 +136,13 @@ func TestClient_CreateLabel(t *testing.T) {
|
||||
t.Errorf("Invalid label request: expected %+v but got %+v", testLabelReq.Label, labelReq.Label)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testCreateLabelBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
created, err := c.CreateLabel(testLabelReq.Label)
|
||||
created, err := c.CreateLabel(context.TODO(), testLabelReq.Label)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while creating label, got:", err)
|
||||
}
|
||||
@ -148,18 +153,18 @@ func TestClient_CreateLabel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_CreateEmptyLabel(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
r.Fail(t, "API should not be called")
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
_, err := c.CreateLabel(&Label{})
|
||||
_, err := c.CreateLabel(context.TODO(), &Label{})
|
||||
r.EqualError(t, err, "name is required")
|
||||
}
|
||||
|
||||
func TestClient_UpdateLabel(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "PUT", "/labels/"+testLabelCreated.ID))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "PUT", "/v4/labels/"+testLabelCreated.ID))
|
||||
|
||||
var labelReq LabelReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&labelReq); err != nil {
|
||||
@ -169,11 +174,13 @@ func TestClient_UpdateLabel(t *testing.T) {
|
||||
t.Errorf("Invalid label request: expected %+v but got %+v", testLabelCreated, labelReq.Label)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testCreateLabelBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
updated, err := c.UpdateLabel(testLabelCreated)
|
||||
updated, err := c.UpdateLabel(context.TODO(), testLabelCreated)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while updating label, got:", err)
|
||||
}
|
||||
@ -184,24 +191,26 @@ func TestClient_UpdateLabel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_UpdateLabelToEmptyName(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
r.Fail(t, "API should not be called")
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
_, err := c.UpdateLabel(&Label{ID: "label"})
|
||||
_, err := c.UpdateLabel(context.TODO(), &Label{ID: "label"})
|
||||
r.EqualError(t, err, "name is required")
|
||||
}
|
||||
|
||||
func TestClient_DeleteLabel(t *testing.T) {
|
||||
s, c := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "DELETE", "/labels/"+testLabelCreated.ID))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "DELETE", "/v4/labels/"+testLabelCreated.ID))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testDeleteLabelBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
err := c.DeleteLabel(testLabelCreated.ID)
|
||||
err := c.DeleteLabel(context.TODO(), testLabelCreated.ID)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while deleting label, got:", err)
|
||||
}
|
||||
|
||||
127
pkg/pmapi/manager.go
Normal file
127
pkg/pmapi/manager.go
Normal file
@ -0,0 +1,127 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
rc *resty.Client
|
||||
|
||||
isDown bool
|
||||
locker sync.Locker
|
||||
observers []ConnectionObserver
|
||||
}
|
||||
|
||||
func newManager(cfg Config) *manager {
|
||||
m := &manager{
|
||||
rc: resty.New(),
|
||||
locker: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Set the API host.
|
||||
m.rc.SetHostURL(cfg.HostURL)
|
||||
|
||||
// Set static header values.
|
||||
m.rc.SetHeader("x-pm-appversion", cfg.AppVersion)
|
||||
|
||||
// Set middleware.
|
||||
m.rc.OnAfterResponse(catchAPIError)
|
||||
|
||||
// Configure retry mechanism.
|
||||
m.rc.SetRetryMaxWaitTime(time.Minute)
|
||||
m.rc.SetRetryAfter(catchRetryAfter)
|
||||
m.rc.AddRetryCondition(catchTooManyRequests)
|
||||
m.rc.AddRetryCondition(catchNoResponse)
|
||||
m.rc.AddRetryCondition(catchProxyAvailable)
|
||||
|
||||
// Determine what happens when requests succeed/fail.
|
||||
m.rc.OnAfterResponse(m.handleRequestSuccess)
|
||||
m.rc.OnError(m.handleRequestFailure)
|
||||
|
||||
// Set the data type of API errors.
|
||||
m.rc.SetError(&Error{})
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func New(cfg Config) Manager {
|
||||
return newManager(cfg)
|
||||
}
|
||||
|
||||
func (m *manager) SetLogger(logger resty.Logger) {
|
||||
m.rc.SetLogger(logger)
|
||||
m.rc.SetDebug(true)
|
||||
}
|
||||
|
||||
func (m *manager) SetTransport(transport http.RoundTripper) {
|
||||
m.rc.SetTransport(transport)
|
||||
}
|
||||
|
||||
func (m *manager) SetCookieJar(jar http.CookieJar) {
|
||||
m.rc.SetCookieJar(jar)
|
||||
}
|
||||
|
||||
func (m *manager) SetRetryCount(count int) {
|
||||
m.rc.SetRetryCount(count)
|
||||
}
|
||||
|
||||
func (m *manager) AddConnectionObserver(observer ConnectionObserver) {
|
||||
m.observers = append(m.observers, observer)
|
||||
}
|
||||
|
||||
func (m *manager) r(ctx context.Context) *resty.Request {
|
||||
return m.rc.R().SetContext(ctx)
|
||||
}
|
||||
|
||||
func (m *manager) handleRequestSuccess(_ *resty.Client, res *resty.Response) error {
|
||||
m.locker.Lock()
|
||||
defer m.locker.Unlock()
|
||||
|
||||
if !m.isDown {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We successfully got a response; connection must be up.
|
||||
|
||||
m.isDown = false
|
||||
|
||||
for _, observer := range m.observers {
|
||||
observer.OnUp()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) handleRequestFailure(req *resty.Request, err error) {
|
||||
m.locker.Lock()
|
||||
defer m.locker.Unlock()
|
||||
|
||||
if m.isDown {
|
||||
return
|
||||
}
|
||||
|
||||
if res, ok := err.(*resty.ResponseError); ok && res.Response.RawResponse != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// We didn't get any response; connection must be down.
|
||||
|
||||
m.isDown = true
|
||||
|
||||
for _, observer := range m.observers {
|
||||
observer.OnDown()
|
||||
}
|
||||
|
||||
go m.pingUntilSuccess()
|
||||
}
|
||||
|
||||
func (m *manager) pingUntilSuccess() {
|
||||
for m.testPing(context.Background()) != nil {
|
||||
time.Sleep(time.Second) // TODO: How long to sleep here?
|
||||
}
|
||||
}
|
||||
114
pkg/pmapi/manager_auth.go
Normal file
114
pkg/pmapi/manager_auth.go
Normal file
@ -0,0 +1,114 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/srp"
|
||||
)
|
||||
|
||||
func (m *manager) NewClient(uid, acc, ref string, exp time.Time) Client {
|
||||
return newClient(m, uid).withAuth(acc, ref, exp)
|
||||
}
|
||||
|
||||
func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Client, *Auth, error) {
|
||||
c := newClient(m, uid)
|
||||
|
||||
auth, err := m.authRefresh(ctx, uid, ref)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return c.withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
|
||||
}
|
||||
|
||||
func (m *manager) NewClientWithLogin(ctx context.Context, username, password string) (Client, *Auth, error) {
|
||||
info, err := m.getAuthInfo(ctx, GetAuthInfoReq{Username: username})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
srpAuth, err := srp.NewSrpAuth(info.Version, username, password, info.Salt, info.Modulus, info.ServerEphemeral)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proofs, err := srpAuth.GenerateSrpProofs(2048)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
auth, err := m.auth(ctx, AuthReq{
|
||||
Username: username,
|
||||
ClientProof: base64.StdEncoding.EncodeToString(proofs.ClientProof),
|
||||
ClientEphemeral: base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
|
||||
SRPSession: info.SRPSession,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return newClient(m, auth.UID).withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
|
||||
}
|
||||
|
||||
func (m *manager) getAuthModulus(ctx context.Context) (AuthModulus, error) {
|
||||
var res struct {
|
||||
AuthModulus
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetResult(&res).Get("/auth/modulus"); err != nil {
|
||||
return AuthModulus{}, err
|
||||
}
|
||||
|
||||
return res.AuthModulus, nil
|
||||
}
|
||||
|
||||
func (m *manager) getAuthInfo(ctx context.Context, req GetAuthInfoReq) (*AuthInfo, error) {
|
||||
var res struct {
|
||||
*AuthInfo
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/info"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.AuthInfo, nil
|
||||
}
|
||||
|
||||
func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) {
|
||||
var res struct {
|
||||
*Auth
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Auth, nil
|
||||
}
|
||||
|
||||
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*Auth, error) {
|
||||
var req = AuthRefreshReq{
|
||||
UID: uid,
|
||||
RefreshToken: ref,
|
||||
ResponseType: "token",
|
||||
GrantType: "refresh_token",
|
||||
RedirectURI: "https://protonmail.ch",
|
||||
State: randomString(32),
|
||||
}
|
||||
|
||||
var res struct {
|
||||
*Auth
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Auth, nil
|
||||
}
|
||||
|
||||
func expiresIn(seconds int64) time.Time {
|
||||
return time.Now().Add(time.Duration(seconds) * time.Second)
|
||||
}
|
||||
51
pkg/pmapi/manager_download.go
Normal file
51
pkg/pmapi/manager_download.go
Normal file
@ -0,0 +1,51 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
)
|
||||
|
||||
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
|
||||
// The file and its signature are verified using the given keyring `kr`.
|
||||
// If the file is verified successfully, it can be read from the returned reader.
|
||||
// TLS fingerprinting is used to verify that connections are only made to known servers.
|
||||
func (m *manager) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) {
|
||||
fb, err := m.fetchFile(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sb, err := m.fetchFile(sig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := kr.VerifyDetached(
|
||||
crypto.NewPlainMessage(fb),
|
||||
crypto.NewPGPSignature(sb),
|
||||
crypto.GetUnixTime(),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fb, nil
|
||||
}
|
||||
|
||||
func (m *manager) fetchFile(url string) ([]byte, error) {
|
||||
res, err := m.rc.R().SetDoNotParseResponse(true).Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(res.RawBody())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := res.RawBody().Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
11
pkg/pmapi/manager_metrics.go
Normal file
11
pkg/pmapi/manager_metrics.go
Normal file
@ -0,0 +1,11 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
func (m *manager) SendSimpleMetric(context.Context, string, string, string) error {
|
||||
// FIXME(conman): Implement.
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
11
pkg/pmapi/manager_ping.go
Normal file
11
pkg/pmapi/manager_ping.go
Normal file
@ -0,0 +1,11 @@
|
||||
package pmapi
|
||||
|
||||
import "context"
|
||||
|
||||
func (m *manager) testPing(ctx context.Context) error {
|
||||
if _, err := m.r(ctx).Get("/tests/ping"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
12
pkg/pmapi/manager_report.go
Normal file
12
pkg/pmapi/manager_report.go
Normal file
@ -0,0 +1,12 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Report sends request as json or multipart (if has attachment).
|
||||
func (m *manager) ReportBug(context.Context, ReportBugReq) error {
|
||||
// FIXME(conman): Implement.
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user