mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 12:46:46 +00:00
GODT-35: Finish all details and make tests pass
This commit is contained in:
5
go.mod
5
go.mod
@ -40,7 +40,7 @@ require (
|
|||||||
github.com/fatih/color v1.9.0
|
github.com/fatih/color v1.9.0
|
||||||
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
|
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
|
||||||
github.com/getsentry/sentry-go v0.8.0
|
github.com/getsentry/sentry-go v0.8.0
|
||||||
github.com/go-resty/resty/v2 v2.4.0
|
github.com/go-resty/resty/v2 v2.6.0
|
||||||
github.com/golang/mock v1.4.4
|
github.com/golang/mock v1.4.4
|
||||||
github.com/google/go-cmp v0.5.1
|
github.com/google/go-cmp v0.5.1
|
||||||
github.com/google/uuid v1.1.1
|
github.com/google/uuid v1.1.1
|
||||||
@ -50,6 +50,7 @@ require (
|
|||||||
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
|
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
|
||||||
github.com/logrusorgru/aurora v2.0.3+incompatible
|
github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||||
github.com/mattn/go-runewidth v0.0.9 // indirect
|
github.com/mattn/go-runewidth v0.0.9 // indirect
|
||||||
|
github.com/miekg/dns v1.1.41
|
||||||
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
||||||
github.com/olekukonko/tablewriter v0.0.4 // indirect
|
github.com/olekukonko/tablewriter v0.0.4 // indirect
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
@ -63,7 +64,7 @@ require (
|
|||||||
github.com/urfave/cli/v2 v2.2.0
|
github.com/urfave/cli/v2 v2.2.0
|
||||||
github.com/vmihailenco/msgpack/v5 v5.1.3
|
github.com/vmihailenco/msgpack/v5 v5.1.3
|
||||||
go.etcd.io/bbolt v1.3.5
|
go.etcd.io/bbolt v1.3.5
|
||||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
|
||||||
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec
|
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
18
go.sum
18
go.sum
@ -113,8 +113,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK
|
|||||||
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
|
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
|
||||||
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
||||||
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
|
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
|
||||||
github.com/go-resty/resty/v2 v2.4.0 h1:s6TItTLejEI+2mn98oijC5w/Rk2YU+OA6x0mnZN6r6k=
|
github.com/go-resty/resty/v2 v2.6.0 h1:joIR5PNLM2EFqqESUjCMGXrWmXNHEU9CEiK813oKYS4=
|
||||||
github.com/go-resty/resty/v2 v2.4.0/go.mod h1:B88+xCTEwvfD94NOuE6GS1wMlnoKNY8eEiNizfNwOwA=
|
github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q=
|
||||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||||
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||||
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||||
@ -195,6 +195,8 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m
|
|||||||
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
|
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
|
||||||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||||
|
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
|
||||||
|
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
|
||||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@ -310,12 +312,16 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
|
||||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
|
||||||
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
|
||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
|
||||||
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@ -330,6 +336,10 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 h1:cEhElsAv9LUt9ZUUocxzWe05oFLVd+AA2nstydTeI8g=
|
||||||
|
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c=
|
||||||
|
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
|||||||
@ -181,21 +181,18 @@ func New( // nolint[funlen]
|
|||||||
kc = keychain.NewMissingKeychain()
|
kc = keychain.NewMissingKeychain()
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(conman): Customize config depending on build type (app version, host URL).
|
cfg := pmapi.NewConfig(configName, constants.Version)
|
||||||
cm := pmapi.New(pmapi.DefaultConfig)
|
cfg.GetUserAgent = userAgent.String
|
||||||
|
cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||||
|
cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") }
|
||||||
|
|
||||||
|
cm := pmapi.New(cfg)
|
||||||
|
|
||||||
// FIXME(conman): Should this be a real object, not just created via callbacks?
|
|
||||||
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
|
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
|
||||||
func() { listener.Emit(events.InternetOffEvent, "") },
|
func() { listener.Emit(events.InternetOffEvent, "") },
|
||||||
func() { listener.Emit(events.InternetOnEvent, "") },
|
func() { listener.Emit(events.InternetOnEvent, "") },
|
||||||
))
|
))
|
||||||
|
|
||||||
// FIXME(conman): Implement force upgrade observer.
|
|
||||||
// apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
|
||||||
|
|
||||||
// FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc.
|
|
||||||
// cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
|
|
||||||
|
|
||||||
jar, err := cookies.NewCookieJar(settingsObj)
|
jar, err := cookies.NewCookieJar(settingsObj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -341,6 +338,7 @@ func (b *Base) run(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logging.SetLevel(c.String(flagLogLevel))
|
logging.SetLevel(c.String(flagLogLevel))
|
||||||
|
b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel)
|
||||||
|
|
||||||
logrus.
|
logrus.
|
||||||
WithField("appName", b.Name).
|
WithField("appName", b.Name).
|
||||||
|
|||||||
@ -65,8 +65,7 @@ func New(
|
|||||||
// Allow DoH before starting the app if the user has previously set this setting.
|
// Allow DoH before starting the app if the user has previously set this setting.
|
||||||
// This allows us to start even if protonmail is blocked.
|
// This allows us to start even if protonmail is blocked.
|
||||||
if s.GetBool(settings.AllowProxyKey) {
|
if s.GetBool(settings.AllowProxyKey) {
|
||||||
// FIXME(conman): Support enable/disable of DoH.
|
clientManager.AllowProxy()
|
||||||
// clientManager.AllowProxy()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
|
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
|
||||||
@ -120,7 +119,7 @@ func (b *Bridge) heartbeat() {
|
|||||||
|
|
||||||
// ReportBug reports a new bug from the user.
|
// ReportBug reports a new bug from the user.
|
||||||
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||||
return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
return b.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
|
||||||
OS: osType,
|
OS: osType,
|
||||||
OSVersion: osVersion,
|
OSVersion: osVersion,
|
||||||
Browser: emailClient,
|
Browser: emailClient,
|
||||||
|
|||||||
@ -21,7 +21,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
|
||||||
"github.com/abiosoft/ishell"
|
"github.com/abiosoft/ishell"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,13 +74,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
if auth.HasTwoFactor() {
|
||||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||||
if twoFactor == "" {
|
if twoFactor == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
err = client.Auth2FA(context.Background(), twoFactor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.processAPIError(err)
|
f.processAPIError(err)
|
||||||
return
|
return
|
||||||
@ -89,7 +88,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
|||||||
}
|
}
|
||||||
|
|
||||||
mailboxPassword := password
|
mailboxPassword := password
|
||||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
if auth.HasMailboxPassword() {
|
||||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||||
}
|
}
|
||||||
if mailboxPassword == "" {
|
if mailboxPassword == "" {
|
||||||
|
|||||||
@ -84,11 +84,6 @@ func New( //nolint[funlen]
|
|||||||
Aliases: []string{"u", "version", "v"},
|
Aliases: []string{"u", "version", "v"},
|
||||||
Func: fe.checkUpdates,
|
Func: fe.checkUpdates,
|
||||||
})
|
})
|
||||||
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
|
|
||||||
Help: "check internet connection. (aliases: i, conn, connection)",
|
|
||||||
Aliases: []string{"i", "con", "connection"},
|
|
||||||
Func: fe.checkInternetConnection,
|
|
||||||
})
|
|
||||||
fe.AddCmd(checkCmd)
|
fe.AddCmd(checkCmd)
|
||||||
|
|
||||||
// Print info commands.
|
// Print info commands.
|
||||||
@ -177,13 +172,13 @@ func New( //nolint[funlen]
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *frontendCLI) watchEvents() {
|
func (f *frontendCLI) watchEvents() {
|
||||||
errorCh := f.getEventChannel(events.ErrorEvent)
|
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
|
||||||
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
|
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||||
internetOffCh := f.getEventChannel(events.InternetOffEvent)
|
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||||
internetOnCh := f.getEventChannel(events.InternetOnEvent)
|
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||||
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
|
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||||
logoutCh := f.getEventChannel(events.LogoutEvent)
|
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
|
||||||
certIssue := f.getEventChannel(events.TLSCertIssue)
|
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case errorDetails := <-errorCh:
|
case errorDetails := <-errorCh:
|
||||||
@ -208,13 +203,6 @@ func (f *frontendCLI) watchEvents() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *frontendCLI) getEventChannel(event string) <-chan string {
|
|
||||||
ch := make(chan string)
|
|
||||||
f.eventListener.Add(event, ch)
|
|
||||||
f.eventListener.RetryEmit(event)
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop starts the frontend loop with an interactive shell.
|
// Loop starts the frontend loop with an interactive shell.
|
||||||
func (f *frontendCLI) Loop() error {
|
func (f *frontendCLI) Loop() error {
|
||||||
f.Print(`
|
f.Print(`
|
||||||
|
|||||||
@ -29,14 +29,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
|
|
||||||
if f.ie.CheckConnection() == nil {
|
|
||||||
f.Println("Internet connection is available.")
|
|
||||||
} else {
|
|
||||||
f.Println("Can not contact the server, please check your internet connection.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
||||||
if path, err := f.locations.ProvideLogsPath(); err != nil {
|
if path, err := f.locations.ProvideLogsPath(); err != nil {
|
||||||
f.Println("Failed to determine location of log files")
|
f.Println("Failed to determine location of log files")
|
||||||
|
|||||||
@ -20,6 +20,7 @@ package cliie
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
|||||||
func (f *frontendCLI) processAPIError(err error) {
|
func (f *frontendCLI) processAPIError(err error) {
|
||||||
log.Warn("API error: ", err)
|
log.Warn("API error: ", err)
|
||||||
switch err {
|
switch err {
|
||||||
// FIXME(conman): How to handle various API errors?
|
|
||||||
/*
|
|
||||||
case pmapi.ErrNoConnection:
|
case pmapi.ErrNoConnection:
|
||||||
f.notifyInternetOff()
|
f.notifyInternetOff()
|
||||||
case pmapi.ErrUpgradeApplication:
|
case pmapi.ErrUpgradeApplication:
|
||||||
f.notifyNeedUpgrade()
|
f.notifyNeedUpgrade()
|
||||||
*/
|
|
||||||
default:
|
default:
|
||||||
f.Println("Server error:", err.Error())
|
f.Println("Server error:", err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||||
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
|
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
|
||||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
|
||||||
"github.com/abiosoft/ishell"
|
"github.com/abiosoft/ishell"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -122,13 +121,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
if auth.HasTwoFactor() {
|
||||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||||
if twoFactor == "" {
|
if twoFactor == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
err = client.Auth2FA(context.Background(), twoFactor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.processAPIError(err)
|
f.processAPIError(err)
|
||||||
return
|
return
|
||||||
@ -136,7 +135,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
|||||||
}
|
}
|
||||||
|
|
||||||
mailboxPassword := password
|
mailboxPassword := password
|
||||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
if auth.HasMailboxPassword() {
|
||||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||||
}
|
}
|
||||||
if mailboxPassword == "" {
|
if mailboxPassword == "" {
|
||||||
|
|||||||
@ -157,15 +157,6 @@ func New( //nolint[funlen]
|
|||||||
})
|
})
|
||||||
fe.AddCmd(updatesCmd)
|
fe.AddCmd(updatesCmd)
|
||||||
|
|
||||||
// Check commands.
|
|
||||||
checkCmd := &ishell.Cmd{Name: "check", Help: "check internet connection or new version."}
|
|
||||||
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
|
|
||||||
Help: "check internet connection. (aliases: i, conn, connection)",
|
|
||||||
Aliases: []string{"i", "con", "connection"},
|
|
||||||
Func: fe.checkInternetConnection,
|
|
||||||
})
|
|
||||||
fe.AddCmd(checkCmd)
|
|
||||||
|
|
||||||
// Print info commands.
|
// Print info commands.
|
||||||
fe.AddCmd(&ishell.Cmd{Name: "log-dir",
|
fe.AddCmd(&ishell.Cmd{Name: "log-dir",
|
||||||
Help: "print path to directory with logs. (aliases: log, logs)",
|
Help: "print path to directory with logs. (aliases: log, logs)",
|
||||||
@ -228,14 +219,14 @@ func New( //nolint[funlen]
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *frontendCLI) watchEvents() {
|
func (f *frontendCLI) watchEvents() {
|
||||||
errorCh := f.getEventChannel(events.ErrorEvent)
|
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
|
||||||
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
|
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||||
internetOffCh := f.getEventChannel(events.InternetOffEvent)
|
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||||
internetOnCh := f.getEventChannel(events.InternetOnEvent)
|
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||||
addressChangedCh := f.getEventChannel(events.AddressChangedEvent)
|
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||||
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
|
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||||
logoutCh := f.getEventChannel(events.LogoutEvent)
|
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
|
||||||
certIssue := f.getEventChannel(events.TLSCertIssue)
|
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case errorDetails := <-errorCh:
|
case errorDetails := <-errorCh:
|
||||||
@ -262,13 +253,6 @@ func (f *frontendCLI) watchEvents() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *frontendCLI) getEventChannel(event string) <-chan string {
|
|
||||||
ch := make(chan string)
|
|
||||||
f.eventListener.Add(event, ch)
|
|
||||||
f.eventListener.RetryEmit(event)
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop starts the frontend loop with an interactive shell.
|
// Loop starts the frontend loop with an interactive shell.
|
||||||
func (f *frontendCLI) Loop() error {
|
func (f *frontendCLI) Loop() error {
|
||||||
f.Print(`
|
f.Print(`
|
||||||
|
|||||||
@ -39,14 +39,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
|
|
||||||
if f.bridge.CheckConnection() == nil {
|
|
||||||
f.Println("Internet connection is available.")
|
|
||||||
} else {
|
|
||||||
f.Println("Can not contact the server, please check your internet connection.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
||||||
if path, err := f.locations.ProvideLogsPath(); err != nil {
|
if path, err := f.locations.ProvideLogsPath(); err != nil {
|
||||||
f.Println("Failed to determine location of log files")
|
f.Println("Failed to determine location of log files")
|
||||||
|
|||||||
@ -20,6 +20,7 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
|||||||
func (f *frontendCLI) processAPIError(err error) {
|
func (f *frontendCLI) processAPIError(err error) {
|
||||||
log.Warn("API error: ", err)
|
log.Warn("API error: ", err)
|
||||||
switch err {
|
switch err {
|
||||||
// FIXME(conman): How to handle various API errors?
|
|
||||||
/*
|
|
||||||
case pmapi.ErrNoConnection:
|
case pmapi.ErrNoConnection:
|
||||||
f.notifyInternetOff()
|
f.notifyInternetOff()
|
||||||
case pmapi.ErrUpgradeApplication:
|
case pmapi.ErrUpgradeApplication:
|
||||||
f.notifyNeedUpgrade()
|
f.notifyNeedUpgrade()
|
||||||
*/
|
|
||||||
default:
|
default:
|
||||||
f.Println("Server error:", err.Error())
|
f.Println("Server error:", err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -409,7 +409,6 @@ Dialog {
|
|||||||
|
|
||||||
onShow: {
|
onShow: {
|
||||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||||
go.checkInternet()
|
|
||||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||||
go.notifyError(gui.enums.errNoInternet)
|
go.notifyError(gui.enums.errNoInternet)
|
||||||
root.hide()
|
root.hide()
|
||||||
|
|||||||
@ -857,14 +857,12 @@ Dialog {
|
|||||||
inputPort . checkIsANumber()
|
inputPort . checkIsANumber()
|
||||||
//emailProvider . currentIndex!=0
|
//emailProvider . currentIndex!=0
|
||||||
)) isOK = false
|
)) isOK = false
|
||||||
go.checkInternet()
|
|
||||||
if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this
|
if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this
|
||||||
errorPopup.show(qsTr("Please check your internet connection."))
|
errorPopup.show(qsTr("Please check your internet connection."))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case 2: // loading structure
|
case 2: // loading structure
|
||||||
go.checkInternet()
|
|
||||||
if (winMain.updateState == gui.enums.statusNoInternet) {
|
if (winMain.updateState == gui.enums.statusNoInternet) {
|
||||||
errorPopup.show(qsTr("Please check your internet connection."))
|
errorPopup.show(qsTr("Please check your internet connection."))
|
||||||
return false
|
return false
|
||||||
@ -949,7 +947,6 @@ Dialog {
|
|||||||
onShow : {
|
onShow : {
|
||||||
root.clear()
|
root.clear()
|
||||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||||
go.checkInternet()
|
|
||||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||||
winMain.popupMessage.show(go.canNotReachAPI)
|
winMain.popupMessage.show(go.canNotReachAPI)
|
||||||
root.hide()
|
root.hide()
|
||||||
|
|||||||
@ -25,33 +25,12 @@ import ProtonUI 1.0
|
|||||||
Rectangle {
|
Rectangle {
|
||||||
id: root
|
id: root
|
||||||
property var iTry: 0
|
property var iTry: 0
|
||||||
property var secLeft: 0
|
|
||||||
property var second: 1000 // convert millisecond to second
|
property var second: 1000 // convert millisecond to second
|
||||||
property var checkInterval: [ 5, 10, 30, 60, 120, 300, 600 ] // seconds
|
|
||||||
property bool isVisible: true
|
property bool isVisible: true
|
||||||
property var fontSize : 1.2 * Style.main.fontSize
|
property var fontSize : 1.2 * Style.main.fontSize
|
||||||
color : "black"
|
color : "black"
|
||||||
state: "upToDate"
|
state: "upToDate"
|
||||||
|
|
||||||
Timer {
|
|
||||||
id: retryInternet
|
|
||||||
interval: second
|
|
||||||
triggeredOnStart: false
|
|
||||||
repeat: true
|
|
||||||
onTriggered : {
|
|
||||||
secLeft--
|
|
||||||
if (secLeft <= 0) {
|
|
||||||
retryInternet.stop()
|
|
||||||
go.checkInternet()
|
|
||||||
if (iTry < checkInterval.length-1) {
|
|
||||||
iTry++
|
|
||||||
}
|
|
||||||
secLeft=checkInterval[iTry]
|
|
||||||
retryInternet.start()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Row {
|
Row {
|
||||||
id: messageRow
|
id: messageRow
|
||||||
anchors.centerIn: root
|
anchors.centerIn: root
|
||||||
@ -110,16 +89,12 @@ Rectangle {
|
|||||||
case "internetCheck":
|
case "internetCheck":
|
||||||
break;
|
break;
|
||||||
case "noInternet" :
|
case "noInternet" :
|
||||||
retryInternet.start()
|
|
||||||
secLeft=checkInterval[iTry]
|
|
||||||
break;
|
break;
|
||||||
case "oldVersion":
|
case "oldVersion":
|
||||||
break;
|
break;
|
||||||
case "forceUpdate":
|
case "forceUpdate":
|
||||||
break;
|
break;
|
||||||
case "upToDate":
|
case "upToDate":
|
||||||
iTry = 0
|
|
||||||
secLeft=checkInterval[iTry]
|
|
||||||
break;
|
break;
|
||||||
case "updateRestart":
|
case "updateRestart":
|
||||||
break;
|
break;
|
||||||
@ -128,24 +103,6 @@ Rectangle {
|
|||||||
default :
|
default :
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (root.state!="noInternet") {
|
|
||||||
retryInternet.stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function timeToRetry() {
|
|
||||||
if (secLeft==1){
|
|
||||||
return qsTr("a second", "time to wait till internet connection is retried")
|
|
||||||
} else if (secLeft<60){
|
|
||||||
return secLeft + " " + qsTr("seconds", "time to wait till internet connection is retried")
|
|
||||||
} else {
|
|
||||||
var leading = ""+secLeft%60
|
|
||||||
if (leading.length < 2) {
|
|
||||||
leading = "0" + leading
|
|
||||||
}
|
|
||||||
return Math.floor(secLeft/60) + ":" + leading
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
states: [
|
states: [
|
||||||
@ -194,23 +151,15 @@ Rectangle {
|
|||||||
PropertyChanges {
|
PropertyChanges {
|
||||||
target: message
|
target: message
|
||||||
color: Style.main.line
|
color: Style.main.line
|
||||||
text: qsTr("Cannot contact server. Retrying in ", "displayed when the app is disconnected from the internet or server has problems")+timeToRetry()+"."
|
text: qsTr("Cannot contact server. Please wait...", "displayed when the app is disconnected from the internet or server has problems")
|
||||||
}
|
}
|
||||||
PropertyChanges {
|
PropertyChanges {
|
||||||
target: linkText
|
target: linkText
|
||||||
visible: false
|
visible: false
|
||||||
}
|
}
|
||||||
PropertyChanges {
|
|
||||||
target: actionText
|
|
||||||
visible: true
|
|
||||||
text: qsTr("Retry now", "click to try to connect to the internet when the app is disconnected from the internet")
|
|
||||||
onClicked: {
|
|
||||||
go.checkInternet()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
PropertyChanges {
|
PropertyChanges {
|
||||||
target: separatorText
|
target: separatorText
|
||||||
visible: true
|
visible: false
|
||||||
text: "|"
|
text: "|"
|
||||||
}
|
}
|
||||||
PropertyChanges {
|
PropertyChanges {
|
||||||
|
|||||||
@ -1331,10 +1331,6 @@ Window {
|
|||||||
return (fname!="fail")
|
return (fname!="fail")
|
||||||
}
|
}
|
||||||
|
|
||||||
function checkInternet() {
|
|
||||||
// nothing to do
|
|
||||||
}
|
|
||||||
|
|
||||||
function loadImportReports(fname) {
|
function loadImportReports(fname) {
|
||||||
console.log("load import reports for ", fname)
|
console.log("load import reports for ", fname)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@
|
|||||||
package qtcommon
|
package qtcommon
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -207,7 +208,7 @@ func (a *Accounts) Auth2FA(twoFacAuth string) int {
|
|||||||
if a.auth == nil || a.authClient == nil {
|
if a.auth == nil || a.authClient == nil {
|
||||||
err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient)
|
err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient)
|
||||||
} else {
|
} else {
|
||||||
err = a.authClient.Auth2FA(twoFacAuth, a.auth)
|
err = a.authClient.Auth2FA(context.Background(), twoFacAuth)
|
||||||
}
|
}
|
||||||
|
|
||||||
if a.showLoginError(err, "auth2FA") {
|
if a.showLoginError(err, "auth2FA") {
|
||||||
|
|||||||
@ -113,10 +113,3 @@ type Listener interface {
|
|||||||
Add(string, chan<- string)
|
Add(string, chan<- string)
|
||||||
RetryEmit(string)
|
RetryEmit(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeAndRegisterEvent(eventListener Listener, event string) <-chan string {
|
|
||||||
ch := make(chan string)
|
|
||||||
eventListener.Add(event, ch)
|
|
||||||
eventListener.RetryEmit(event)
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|||||||
@ -143,16 +143,16 @@ func (f *FrontendQt) NotifySilentUpdateError(err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *FrontendQt) watchEvents() {
|
func (f *FrontendQt) watchEvents() {
|
||||||
credentialsErrorCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.CredentialsErrorEvent)
|
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||||
internetOffCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOffEvent)
|
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||||
internetOnCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOnEvent)
|
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||||
secondInstanceCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.SecondInstanceEvent)
|
secondInstanceCh := f.eventListener.ProvideChannel(events.SecondInstanceEvent)
|
||||||
restartBridgeCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.RestartBridgeEvent)
|
restartBridgeCh := f.eventListener.ProvideChannel(events.RestartBridgeEvent)
|
||||||
addressChangedCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedEvent)
|
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||||
addressChangedLogoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedLogoutEvent)
|
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||||
logoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.LogoutEvent)
|
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
|
||||||
updateApplicationCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UpgradeApplicationEvent)
|
updateApplicationCh := f.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
|
||||||
newUserCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UserRefreshEvent)
|
newUserCh := f.eventListener.ProvideChannel(events.UserRefreshEvent)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-credentialsErrorCh:
|
case <-credentialsErrorCh:
|
||||||
@ -351,11 +351,6 @@ func (f *FrontendQt) sendBug(description, emailClient, address string) bool {
|
|||||||
// }
|
// }
|
||||||
//}
|
//}
|
||||||
|
|
||||||
// checkInternet is almost idetical to bridge
|
|
||||||
func (f *FrontendQt) checkInternet() {
|
|
||||||
f.Qml.SetConnectionStatus(f.ie.CheckConnection() == nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *FrontendQt) showError(code int, err error) {
|
func (f *FrontendQt) showError(code int, err error) {
|
||||||
f.Qml.SetErrorDescription(err.Error())
|
f.Qml.SetErrorDescription(err.Error())
|
||||||
log.WithField("code", code).Errorln(err.Error())
|
log.WithField("code", code).Errorln(err.Error())
|
||||||
|
|||||||
@ -78,7 +78,6 @@ type GoQMLInterface struct {
|
|||||||
_ string `property:"versionCheckFailed"`
|
_ string `property:"versionCheckFailed"`
|
||||||
//
|
//
|
||||||
_ func(isAvailable bool) `signal:"setConnectionStatus"`
|
_ func(isAvailable bool) `signal:"setConnectionStatus"`
|
||||||
_ func() `slot:"checkInternet"`
|
|
||||||
|
|
||||||
_ func() `slot:"setToRestart"`
|
_ func() `slot:"setToRestart"`
|
||||||
|
|
||||||
@ -189,8 +188,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
|
|||||||
return f.programVersion
|
return f.programVersion
|
||||||
})
|
})
|
||||||
|
|
||||||
s.ConnectCheckInternet(f.checkInternet)
|
|
||||||
|
|
||||||
s.ConnectSetToRestart(f.restarter.SetToRestart)
|
s.ConnectSetToRestart(f.restarter.SetToRestart)
|
||||||
|
|
||||||
s.ConnectLoadStructureForExport(f.LoadStructureForExport)
|
s.ConnectLoadStructureForExport(f.LoadStructureForExport)
|
||||||
|
|||||||
@ -20,6 +20,7 @@
|
|||||||
package qt
|
package qt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -173,7 +174,7 @@ func (s *FrontendQt) auth2FA(twoFacAuth string) int {
|
|||||||
if s.auth == nil || s.authClient == nil {
|
if s.auth == nil || s.authClient == nil {
|
||||||
err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient)
|
err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient)
|
||||||
} else {
|
} else {
|
||||||
err = s.authClient.Auth2FA(twoFacAuth, s.auth)
|
err = s.authClient.Auth2FA(context.Background(), twoFacAuth)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.showLoginError(err, "auth2FA") {
|
if s.showLoginError(err, "auth2FA") {
|
||||||
|
|||||||
@ -191,20 +191,20 @@ func (s *FrontendQt) NotifySilentUpdateError(err error) {
|
|||||||
func (s *FrontendQt) watchEvents() {
|
func (s *FrontendQt) watchEvents() {
|
||||||
s.WaitUntilFrontendIsReady()
|
s.WaitUntilFrontendIsReady()
|
||||||
|
|
||||||
errorCh := s.getEventChannel(events.ErrorEvent)
|
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
|
||||||
credentialsErrorCh := s.getEventChannel(events.CredentialsErrorEvent)
|
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||||
outgoingNoEncCh := s.getEventChannel(events.OutgoingNoEncEvent)
|
outgoingNoEncCh := s.eventListener.ProvideChannel(events.OutgoingNoEncEvent)
|
||||||
noActiveKeyForRecipientCh := s.getEventChannel(events.NoActiveKeyForRecipientEvent)
|
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
|
||||||
internetOffCh := s.getEventChannel(events.InternetOffEvent)
|
internetOffCh := s.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||||
internetOnCh := s.getEventChannel(events.InternetOnEvent)
|
internetOnCh := s.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||||
secondInstanceCh := s.getEventChannel(events.SecondInstanceEvent)
|
secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
|
||||||
restartBridgeCh := s.getEventChannel(events.RestartBridgeEvent)
|
restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
|
||||||
addressChangedCh := s.getEventChannel(events.AddressChangedEvent)
|
addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||||
addressChangedLogoutCh := s.getEventChannel(events.AddressChangedLogoutEvent)
|
addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||||
logoutCh := s.getEventChannel(events.LogoutEvent)
|
logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
|
||||||
updateApplicationCh := s.getEventChannel(events.UpgradeApplicationEvent)
|
updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
|
||||||
newUserCh := s.getEventChannel(events.UserRefreshEvent)
|
newUserCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
|
||||||
certIssue := s.getEventChannel(events.TLSCertIssue)
|
certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case errorDetails := <-errorCh:
|
case errorDetails := <-errorCh:
|
||||||
@ -254,13 +254,6 @@ func (s *FrontendQt) watchEvents() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FrontendQt) getEventChannel(event string) <-chan string {
|
|
||||||
ch := make(chan string)
|
|
||||||
s.eventListener.Add(event, ch)
|
|
||||||
s.eventListener.RetryEmit(event)
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop function for tests.
|
// Loop function for tests.
|
||||||
//
|
//
|
||||||
// It runs QtExecute in new thread with function returning itself after setup.
|
// It runs QtExecute in new thread with function returning itself after setup.
|
||||||
@ -653,10 +646,6 @@ func (s *FrontendQt) isSMTPSTARTTLS() bool {
|
|||||||
return !s.settings.GetBool(settings.SMTPSSLKey)
|
return !s.settings.GetBool(settings.SMTPSSLKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FrontendQt) checkInternet() {
|
|
||||||
s.Qml.SetConnectionStatus(s.bridge.CheckConnection() == nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FrontendQt) switchAddressModeUser(iAccount int) {
|
func (s *FrontendQt) switchAddressModeUser(iAccount int) {
|
||||||
defer s.Qml.ProcessFinished()
|
defer s.Qml.ProcessFinished()
|
||||||
userID := s.Accounts.get(iAccount).UserID()
|
userID := s.Accounts.get(iAccount).UserID()
|
||||||
|
|||||||
@ -84,7 +84,6 @@ type GoQMLInterface struct {
|
|||||||
_ string `property:"progressDescription"`
|
_ string `property:"progressDescription"`
|
||||||
|
|
||||||
_ func(isAvailable bool) `signal:"setConnectionStatus"`
|
_ func(isAvailable bool) `signal:"setConnectionStatus"`
|
||||||
_ func() `slot:"checkInternet"`
|
|
||||||
|
|
||||||
_ func() `slot:"setToRestart"`
|
_ func() `slot:"setToRestart"`
|
||||||
|
|
||||||
@ -205,8 +204,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
|
|||||||
return f.programVer
|
return f.programVer
|
||||||
})
|
})
|
||||||
|
|
||||||
s.ConnectCheckInternet(f.checkInternet)
|
|
||||||
|
|
||||||
s.ConnectSetToRestart(f.restarter.SetToRestart)
|
s.ConnectSetToRestart(f.restarter.SetToRestart)
|
||||||
|
|
||||||
s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc)
|
s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc)
|
||||||
|
|||||||
@ -55,7 +55,6 @@ type UserManager interface {
|
|||||||
GetUser(query string) (User, error)
|
GetUser(query string) (User, error)
|
||||||
DeleteUser(userID string, clearCache bool) error
|
DeleteUser(userID string, clearCache bool) error
|
||||||
ClearData() error
|
ClearData() error
|
||||||
CheckConnection() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// User is an interface of user needed by frontend.
|
// User is an interface of user needed by frontend.
|
||||||
|
|||||||
@ -38,11 +38,10 @@ type bridgeUser interface {
|
|||||||
IsCombinedAddressMode() bool
|
IsCombinedAddressMode() bool
|
||||||
GetAddressID(address string) (string, error)
|
GetAddressID(address string) (string, error)
|
||||||
GetPrimaryAddress() string
|
GetPrimaryAddress() string
|
||||||
UpdateUser() error
|
|
||||||
Logout() error
|
Logout() error
|
||||||
CloseConnection(address string)
|
CloseConnection(address string)
|
||||||
GetStore() storeUserProvider
|
GetStore() storeUserProvider
|
||||||
GetTemporaryPMAPIClient() pmapi.Client
|
GetClient() pmapi.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
type bridgeWrap struct {
|
type bridgeWrap struct {
|
||||||
|
|||||||
@ -422,7 +422,7 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
|
|||||||
if isStringInList(m.LabelIDs, pmapi.StarredLabel) {
|
if isStringInList(m.LabelIDs, pmapi.StarredLabel) {
|
||||||
messageFlagsMap[imap.FlaggedFlag] = true
|
messageFlagsMap[imap.FlaggedFlag] = true
|
||||||
}
|
}
|
||||||
if m.Unread == 0 {
|
if !m.Unread {
|
||||||
messageFlagsMap[imap.SeenFlag] = true
|
messageFlagsMap[imap.SeenFlag] = true
|
||||||
}
|
}
|
||||||
if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) {
|
if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) {
|
||||||
@ -560,7 +560,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if storeMessage.Message().Unread == 1 {
|
if storeMessage.Message().Unread {
|
||||||
for section := range msg.Body {
|
for section := range msg.Body {
|
||||||
// Peek means get messages without marking them as read.
|
// Peek means get messages without marking them as read.
|
||||||
// If client does not only ask for peek, we have to mark them as read.
|
// If client does not only ask for peek, we have to mark them as read.
|
||||||
|
|||||||
@ -93,7 +93,7 @@ func newIMAPUser(
|
|||||||
|
|
||||||
// This method should eventually no longer be necessary. Everything should go via store.
|
// This method should eventually no longer be necessary. Everything should go via store.
|
||||||
func (iu *imapUser) client() pmapi.Client {
|
func (iu *imapUser) client() pmapi.Client {
|
||||||
return iu.user.GetTemporaryPMAPIClient()
|
return iu.user.GetClient()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (iu *imapUser) isSubscribed(labelID string) bool {
|
func (iu *imapUser) isSubscribed(labelID string) bool {
|
||||||
|
|||||||
@ -22,6 +22,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/internal/transfer"
|
"github.com/ProtonMail/proton-bridge/internal/transfer"
|
||||||
"github.com/ProtonMail/proton-bridge/internal/users"
|
"github.com/ProtonMail/proton-bridge/internal/users"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
@ -40,6 +41,7 @@ type ImportExport struct {
|
|||||||
locations Locator
|
locations Locator
|
||||||
cache Cacher
|
cache Cacher
|
||||||
panicHandler users.PanicHandler
|
panicHandler users.PanicHandler
|
||||||
|
eventListener listener.Listener
|
||||||
clientManager pmapi.Manager
|
clientManager pmapi.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,13 +61,14 @@ func New(
|
|||||||
locations: locations,
|
locations: locations,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
panicHandler: panicHandler,
|
panicHandler: panicHandler,
|
||||||
|
eventListener: eventListener,
|
||||||
clientManager: clientManager,
|
clientManager: clientManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReportBug reports a new bug from the user.
|
// ReportBug reports a new bug from the user.
|
||||||
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||||
return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
return ie.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
|
||||||
OS: osType,
|
OS: osType,
|
||||||
OSVersion: osVersion,
|
OSVersion: osVersion,
|
||||||
Browser: emailClient,
|
Browser: emailClient,
|
||||||
@ -89,7 +92,7 @@ func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address strin
|
|||||||
|
|
||||||
report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
|
report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
|
||||||
|
|
||||||
return ie.clientManager.ReportBug(context.TODO(), report)
|
return ie.clientManager.ReportBug(context.Background(), report)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
|
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
|
||||||
@ -162,5 +165,23 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
|
|||||||
log.WithError(err).Info("Address does not exist, using all addresses")
|
log.WithError(err).Info("Address does not exist, using all addresses")
|
||||||
}
|
}
|
||||||
|
|
||||||
return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
|
provider, err := transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
internetOffCh := ie.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||||
|
internetOnCh := ie.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-internetOffCh:
|
||||||
|
provider.SetConnectionDown()
|
||||||
|
case <-internetOnCh:
|
||||||
|
provider.SetConnectionUp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,7 +31,7 @@ type bridgeUser interface {
|
|||||||
CheckBridgeLogin(password string) error
|
CheckBridgeLogin(password string) error
|
||||||
IsCombinedAddressMode() bool
|
IsCombinedAddressMode() bool
|
||||||
GetAddressID(address string) (string, error)
|
GetAddressID(address string) (string, error)
|
||||||
GetTemporaryPMAPIClient() pmapi.Client
|
GetClient() pmapi.Client
|
||||||
GetStore() storeUserProvider
|
GetStore() storeUserProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -81,7 +81,7 @@ func newSMTPUser(
|
|||||||
|
|
||||||
// This method should eventually no longer be necessary. Everything should go via store.
|
// This method should eventually no longer be necessary. Everything should go via store.
|
||||||
func (su *smtpUser) client() pmapi.Client {
|
func (su *smtpUser) client() pmapi.Client {
|
||||||
return su.user.GetTemporaryPMAPIClient()
|
return su.user.GetClient()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends an email from the given address to the given addresses with the given body.
|
// Send sends an email from the given address to the given addresses with the given body.
|
||||||
|
|||||||
@ -90,7 +90,7 @@ func getLabelPrefix(l *pmapi.Label) string {
|
|||||||
switch {
|
switch {
|
||||||
case pmapi.IsSystemLabel(l.ID):
|
case pmapi.IsSystemLabel(l.ID):
|
||||||
return ""
|
return ""
|
||||||
case l.Exclusive == 1:
|
case bool(l.Exclusive):
|
||||||
return UserFoldersPrefix
|
return UserFoldersPrefix
|
||||||
default:
|
default:
|
||||||
return UserLabelsPrefix
|
return UserLabelsPrefix
|
||||||
|
|||||||
@ -37,8 +37,8 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
|
|||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
m.store.SetChangeNotifier(m.changeNotifier)
|
m.store.SetChangeNotifier(m.changeNotifier)
|
||||||
|
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
|
func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
|
||||||
@ -52,8 +52,8 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
|
|||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
m.store.SetChangeNotifier(m.changeNotifier)
|
m.store.SetChangeNotifier(m.changeNotifier)
|
||||||
|
|
||||||
msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
msg2 := getTestMessage("msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
|
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,8 +63,8 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
|
|||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
|
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
|
|
||||||
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2))
|
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2))
|
||||||
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1))
|
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1))
|
||||||
|
|||||||
@ -81,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
|
|||||||
func (loop *eventLoop) setFirstEventID() (err error) {
|
func (loop *eventLoop) setFirstEventID() (err error) {
|
||||||
loop.log.Info("Setting first event ID")
|
loop.log.Info("Setting first event ID")
|
||||||
|
|
||||||
event, err := loop.client().GetEvent(context.TODO(), "")
|
event, err := loop.client().GetEvent(context.Background(), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
loop.log.WithError(err).Error("Could not get latest event ID")
|
loop.log.WithError(err).Error("Could not get latest event ID")
|
||||||
return
|
return
|
||||||
@ -222,8 +222,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
|||||||
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
|
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
|
||||||
// (e.g. no internet, ulimit reached etc.)
|
// (e.g. no internet, ulimit reached etc.)
|
||||||
defer func() {
|
defer func() {
|
||||||
// FIXME(conman): How to handle errors of different types?
|
if errors.Cause(err) == pmapi.ErrNoConnection {
|
||||||
if errors.Is(err, pmapi.ErrNoConnection) {
|
|
||||||
l.Warn("Internet unavailable")
|
l.Warn("Internet unavailable")
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
@ -234,20 +233,17 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
|||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(conman): Handle force upgrade.
|
|
||||||
/*
|
|
||||||
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
||||||
l.Warn("Need to upgrade application")
|
l.Warn("Need to upgrade application")
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
loop.errCounter = 0
|
loop.errCounter = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
|
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
|
||||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
if err != nil && errors.Cause(err) != pmapi.ErrUnauthorized {
|
||||||
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
|
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
|
||||||
loop.errCounter++
|
loop.errCounter++
|
||||||
if loop.errCounter == errMaxSentry {
|
if loop.errCounter == errMaxSentry {
|
||||||
@ -268,7 +264,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
|||||||
loop.pollCounter++
|
loop.pollCounter++
|
||||||
|
|
||||||
var event *pmapi.Event
|
var event *pmapi.Event
|
||||||
if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil {
|
if event, err = loop.client().GetEvent(context.Background(), loop.currentEventID); err != nil {
|
||||||
return false, errors.Wrap(err, "failed to get event")
|
return false, errors.Wrap(err, "failed to get event")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return event.More == 1, err
|
return bool(event.More), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
|
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
|
||||||
@ -354,7 +350,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
|
|||||||
// Get old addresses for comparisons before updating user.
|
// Get old addresses for comparisons before updating user.
|
||||||
oldList := loop.client().Addresses()
|
oldList := loop.client().Addresses()
|
||||||
|
|
||||||
if err = loop.user.UpdateUser(); err != nil {
|
if err = loop.user.UpdateUser(context.Background()); err != nil {
|
||||||
if logoutErr := loop.user.Logout(); logoutErr != nil {
|
if logoutErr := loop.user.Logout(); logoutErr != nil {
|
||||||
log.WithError(logoutErr).Error("Failed to logout user after failed update")
|
log.WithError(logoutErr).Error("Failed to logout user after failed update")
|
||||||
}
|
}
|
||||||
@ -465,16 +461,12 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
|
|||||||
|
|
||||||
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
|
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
|
||||||
|
|
||||||
if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil {
|
if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil {
|
||||||
// FIXME(conman): How to handle error of this particular type?
|
if _, ok := err.(pmapi.ErrUnprocessableEntity); ok {
|
||||||
|
|
||||||
/*
|
|
||||||
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
|
|
||||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||||
err = nil
|
err = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
return errors.Wrap(err, "failed to get message from API for updating")
|
return errors.Wrap(err, "failed to get message from API for updating")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,15 +42,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
|
|||||||
// next event if there is `More` of them.
|
// next event if there is `More` of them.
|
||||||
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
|
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
|
||||||
EventID: "event50",
|
EventID: "event50",
|
||||||
More: 1,
|
More: true,
|
||||||
}, nil),
|
}, nil),
|
||||||
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
|
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
|
||||||
EventID: "event70",
|
EventID: "event70",
|
||||||
More: 0,
|
More: false,
|
||||||
}, nil),
|
}, nil),
|
||||||
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
|
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
|
||||||
EventID: "event71",
|
EventID: "event71",
|
||||||
More: 0,
|
More: false,
|
||||||
}, nil),
|
}, nil),
|
||||||
)
|
)
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
@ -188,7 +188,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
|
|||||||
msg := &pmapi.Message{
|
msg := &pmapi.Message{
|
||||||
ID: "msg1",
|
ID: "msg1",
|
||||||
Subject: "old",
|
Subject: "old",
|
||||||
Unread: 0,
|
Unread: false,
|
||||||
Flags: 10,
|
Flags: 10,
|
||||||
Sender: address1,
|
Sender: address1,
|
||||||
ToList: []*mail.Address{address2},
|
ToList: []*mail.Address{address2},
|
||||||
@ -200,7 +200,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
|
|||||||
newMsg := &pmapi.Message{
|
newMsg := &pmapi.Message{
|
||||||
ID: "msg1",
|
ID: "msg1",
|
||||||
Subject: "new",
|
Subject: "new",
|
||||||
Unread: 1,
|
Unread: true,
|
||||||
Flags: 11,
|
Flags: 11,
|
||||||
Sender: address2,
|
Sender: address2,
|
||||||
ToList: []*mail.Address{address1},
|
ToList: []*mail.Address{address1},
|
||||||
|
|||||||
@ -129,17 +129,10 @@ func (mc *mailboxCounts) getPMLabel() *pmapi.Label {
|
|||||||
Color: mc.Color,
|
Color: mc.Color,
|
||||||
Order: mc.Order,
|
Order: mc.Order,
|
||||||
Type: pmapi.LabelTypeMailbox,
|
Type: pmapi.LabelTypeMailbox,
|
||||||
Exclusive: mc.isExclusive(),
|
Exclusive: pmapi.Boolean(mc.IsFolder),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mc *mailboxCounts) isExclusive() int {
|
|
||||||
if mc.IsFolder {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
|
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
|
||||||
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
|
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
|
||||||
// Don't forget about system folders.
|
// Don't forget about system folders.
|
||||||
@ -162,7 +155,7 @@ func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) er
|
|||||||
mailbox.LabelName = label.Path
|
mailbox.LabelName = label.Path
|
||||||
mailbox.Color = label.Color
|
mailbox.Color = label.Color
|
||||||
mailbox.Order = label.Order
|
mailbox.Order = label.Order
|
||||||
mailbox.IsFolder = label.Exclusive == 1
|
mailbox.IsFolder = bool(label.Exclusive)
|
||||||
|
|
||||||
// Write.
|
// Write.
|
||||||
if err = mailbox.txWriteToBucket(countsBkt); err != nil {
|
if err = mailbox.txWriteToBucket(countsBkt); err != nil {
|
||||||
|
|||||||
@ -75,7 +75,7 @@ func TestMailboxNames(t *testing.T) {
|
|||||||
newLabel(100, "labelID1", "Label1"),
|
newLabel(100, "labelID1", "Label1"),
|
||||||
newLabel(1000, "folderID1", "Folder1"),
|
newLabel(1000, "folderID1", "Folder1"),
|
||||||
}
|
}
|
||||||
foldersAndLabels[1].Exclusive = 1
|
foldersAndLabels[1].Exclusive = true
|
||||||
|
|
||||||
for _, counts := range getSystemFolders() {
|
for _, counts := range getSystemFolders() {
|
||||||
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())
|
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())
|
||||||
|
|||||||
@ -37,10 +37,10 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
|
|||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
|
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
|
|
||||||
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
||||||
|
|
||||||
@ -82,20 +82,20 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
|
|||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
|
|
||||||
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||||
|
|
||||||
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||||
tstMsg.ExternalID = " externalID-non-pm-com "
|
tstMsg.ExternalID = " externalID-non-pm-com "
|
||||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||||
|
|
||||||
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||||
tstMsg.ExternalID = "<externalID@pm.me>"
|
tstMsg.ExternalID = "<externalID@pm.me>"
|
||||||
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
|
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
|
||||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||||
|
|
||||||
// Not sure if this is a real-world scenario but we should be able to address this properly.
|
// Not sure if this is a real-world scenario but we should be able to address this properly.
|
||||||
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||||
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
|
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
|
||||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||||
|
|
||||||
|
|||||||
@ -18,8 +18,6 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -43,7 +41,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
|
|||||||
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
|
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
|
||||||
// wrapping it.
|
// wrapping it.
|
||||||
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
|
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
|
||||||
msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID)
|
msg, err := storeMailbox.client().GetMessage(exposeContextForIMAP(), apiID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -70,7 +68,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
|
|||||||
Message: body,
|
Message: body,
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs})
|
res, err := storeMailbox.client().Import(exposeContextForIMAP(), pmapi.ImportMsgReqs{importReqs})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -99,7 +97,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
|
|||||||
return ErrAllMailOpNotAllowed
|
return ErrAllMailOpNotAllowed
|
||||||
}
|
}
|
||||||
defer storeMailbox.pollNow()
|
defer storeMailbox.pollNow()
|
||||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnlabelMessages removes the label by calling an API.
|
// UnlabelMessages removes the label by calling an API.
|
||||||
@ -112,7 +110,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
|
|||||||
return ErrAllMailOpNotAllowed
|
return ErrAllMailOpNotAllowed
|
||||||
}
|
}
|
||||||
defer storeMailbox.pollNow()
|
defer storeMailbox.pollNow()
|
||||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkMessagesRead marks the message read by calling an API.
|
// MarkMessagesRead marks the message read by calling an API.
|
||||||
@ -132,14 +130,14 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
|
|||||||
// Therefore we do not issue API update if the message is already read.
|
// Therefore we do not issue API update if the message is already read.
|
||||||
ids := []string{}
|
ids := []string{}
|
||||||
for _, apiID := range apiIDs {
|
for _, apiID := range apiIDs {
|
||||||
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 {
|
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread {
|
||||||
ids = append(ids, apiID)
|
ids = append(ids, apiID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return storeMailbox.client().MarkMessagesRead(context.TODO(), ids)
|
return storeMailbox.client().MarkMessagesRead(exposeContextForIMAP(), ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkMessagesUnread marks the message unread by calling an API.
|
// MarkMessagesUnread marks the message unread by calling an API.
|
||||||
@ -151,7 +149,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
|
|||||||
"mailbox": storeMailbox.Name,
|
"mailbox": storeMailbox.Name,
|
||||||
}).Trace("Marking messages as unread")
|
}).Trace("Marking messages as unread")
|
||||||
defer storeMailbox.pollNow()
|
defer storeMailbox.pollNow()
|
||||||
return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs)
|
return storeMailbox.client().MarkMessagesUnread(exposeContextForIMAP(), apiIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkMessagesStarred adds the Starred label by calling an API.
|
// MarkMessagesStarred adds the Starred label by calling an API.
|
||||||
@ -164,7 +162,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
|
|||||||
"mailbox": storeMailbox.Name,
|
"mailbox": storeMailbox.Name,
|
||||||
}).Trace("Marking messages as starred")
|
}).Trace("Marking messages as starred")
|
||||||
defer storeMailbox.pollNow()
|
defer storeMailbox.pollNow()
|
||||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkMessagesUnstarred removes the Starred label by calling an API.
|
// MarkMessagesUnstarred removes the Starred label by calling an API.
|
||||||
@ -177,7 +175,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
|
|||||||
"mailbox": storeMailbox.Name,
|
"mailbox": storeMailbox.Name,
|
||||||
}).Trace("Marking messages as unstarred")
|
}).Trace("Marking messages as unstarred")
|
||||||
defer storeMailbox.pollNow()
|
defer storeMailbox.pollNow()
|
||||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
|
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
|
||||||
@ -261,11 +259,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
|
|||||||
}
|
}
|
||||||
case pmapi.DraftLabel:
|
case pmapi.DraftLabel:
|
||||||
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
|
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
|
||||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil {
|
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), apiIDs); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil {
|
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -303,13 +301,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(messageIDsToUnlabel) > 0 {
|
if len(messageIDsToUnlabel) > 0 {
|
||||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||||
l.WithError(err).Warning("Cannot unlabel before deleting")
|
l.WithError(err).Warning("Cannot unlabel before deleting")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(messageIDsToDelete) > 0 {
|
if len(messageIDsToDelete) > 0 {
|
||||||
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
|
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
|
||||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil {
|
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), messageIDsToDelete); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,10 +5,10 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
reflect "reflect"
|
context "context"
|
||||||
|
|
||||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
reflect "reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockPanicHandler is a mock of PanicHandler interface
|
// MockPanicHandler is a mock of PanicHandler interface
|
||||||
@ -207,17 +207,17 @@ func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUser mocks base method
|
// UpdateUser mocks base method
|
||||||
func (m *MockBridgeUser) UpdateUser() error {
|
func (m *MockBridgeUser) UpdateUser(arg0 context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "UpdateUser")
|
ret := m.ctrl.Call(m, "UpdateUser", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUser indicates an expected call of UpdateUser
|
// UpdateUser indicates an expected call of UpdateUser
|
||||||
func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call {
|
func (mr *MockBridgeUserMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockChangeNotifier is a mock of ChangeNotifier interface
|
// MockChangeNotifier is a mock of ChangeNotifier interface
|
||||||
|
|||||||
@ -5,10 +5,9 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
time "time"
|
time "time"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockListener is a mock of Listener interface
|
// MockListener is a mock of Listener interface
|
||||||
@ -58,6 +57,20 @@ func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideChannel mocks base method
|
||||||
|
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
|
||||||
|
ret0, _ := ret[0].(<-chan string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideChannel indicates an expected call of ProvideChannel
|
||||||
|
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
// Remove mocks base method
|
// Remove mocks base method
|
||||||
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
|
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@ -101,6 +101,18 @@ var (
|
|||||||
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
|
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// exposeContextForIMAP should be replaced once with context passed
|
||||||
|
// as an argument from IMAP package and IMAP library should cancel
|
||||||
|
// context when IMAP client cancels the request.
|
||||||
|
func exposeContextForIMAP() context.Context {
|
||||||
|
return context.TODO()
|
||||||
|
}
|
||||||
|
|
||||||
|
// exposeContextForSMTP is the same as above but for SMTP.
|
||||||
|
func exposeContextForSMTP() context.Context {
|
||||||
|
return context.TODO()
|
||||||
|
}
|
||||||
|
|
||||||
// Store is local user storage, which handles the synchronization between IMAP and PM API.
|
// Store is local user storage, which handles the synchronization between IMAP and PM API.
|
||||||
type Store struct {
|
type Store struct {
|
||||||
sentryReporter *sentry.Reporter
|
sentryReporter *sentry.Reporter
|
||||||
@ -278,7 +290,7 @@ func (store *Store) client() pmapi.Client {
|
|||||||
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
|
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
|
||||||
// the API is unavailable for whatever reason it tries to fetch the labels locally.
|
// the API is unavailable for whatever reason it tries to fetch the labels locally.
|
||||||
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
|
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
|
||||||
if labels, err = store.client().ListLabels(context.TODO()); err != nil {
|
if labels, err = store.client().ListLabels(context.Background()); err != nil {
|
||||||
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
|
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
|
||||||
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
|
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
|
||||||
store.log.WithError(err).Error("Cannot list local labels")
|
store.log.WithError(err).Error("Cannot list local labels")
|
||||||
|
|||||||
@ -184,8 +184,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
|
|||||||
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
|
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
|
||||||
|
|
||||||
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
|
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
|
||||||
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
|
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
|
||||||
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
|
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
|
||||||
})
|
})
|
||||||
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
|
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
|
||||||
mocks.client.EXPECT().CountMessages(gomock.Any(), "")
|
mocks.client.EXPECT().CountMessages(gomock.Any(), "")
|
||||||
|
|||||||
@ -148,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
|
|||||||
Limit: 1,
|
Limit: 1,
|
||||||
}
|
}
|
||||||
// If the page does not exist, an empty page instead of an error is returned.
|
// If the page does not exist, an empty page instead of an error is returned.
|
||||||
messages, total, err := api.ListMessages(context.TODO(), filter)
|
messages, total, err := api.ListMessages(context.Background(), filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, errors.Wrap(err, "failed to list messages")
|
return "", 0, errors.Wrap(err, "failed to list messages")
|
||||||
}
|
}
|
||||||
@ -190,7 +190,7 @@ func syncBatch( //nolint[funlen]
|
|||||||
|
|
||||||
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
|
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
|
||||||
|
|
||||||
messages, _, err := api.ListMessages(context.TODO(), filter)
|
messages, _, err := api.ListMessages(context.Background(), filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to list messages")
|
return errors.Wrap(err, "failed to list messages")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,7 +17,11 @@
|
|||||||
|
|
||||||
package store
|
package store
|
||||||
|
|
||||||
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
|
)
|
||||||
|
|
||||||
type PanicHandler interface {
|
type PanicHandler interface {
|
||||||
HandlePanic()
|
HandlePanic()
|
||||||
@ -32,7 +36,7 @@ type BridgeUser interface {
|
|||||||
GetPrimaryAddress() string
|
GetPrimaryAddress() string
|
||||||
GetStoreAddresses() []string
|
GetStoreAddresses() []string
|
||||||
GetClient() pmapi.Client
|
GetClient() pmapi.Client
|
||||||
UpdateUser() error
|
UpdateUser(context.Context) error
|
||||||
CloseAllConnections()
|
CloseAllConnections()
|
||||||
CloseConnection(string)
|
CloseConnection(string)
|
||||||
Logout() error
|
Logout() error
|
||||||
|
|||||||
@ -17,8 +17,6 @@
|
|||||||
|
|
||||||
package store
|
package store
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
// UserID returns user ID.
|
// UserID returns user ID.
|
||||||
func (store *Store) UserID() string {
|
func (store *Store) UserID() string {
|
||||||
return store.user.ID()
|
return store.user.ID()
|
||||||
@ -26,7 +24,7 @@ func (store *Store) UserID() string {
|
|||||||
|
|
||||||
// GetSpace returns used and total space in bytes.
|
// GetSpace returns used and total space in bytes.
|
||||||
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
@ -35,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
|||||||
|
|
||||||
// GetMaxUpload returns max size of message + all attachments in bytes.
|
// GetMaxUpload returns max size of message + all attachments in bytes.
|
||||||
func (store *Store) GetMaxUpload() (int64, error) {
|
func (store *Store) GetMaxUpload() (int64, error) {
|
||||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -147,7 +147,7 @@ func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (er
|
|||||||
// filterAddresses filters out inactive addresses and ensures the original address is listed first.
|
// filterAddresses filters out inactive addresses and ensures the original address is listed first.
|
||||||
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
|
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
|
||||||
for _, address := range addressList {
|
for _, address := range addressList {
|
||||||
if address.Receive != pmapi.CanReceive {
|
if !address.Receive {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,6 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -39,14 +38,14 @@ func (store *Store) createMailbox(name string) error {
|
|||||||
|
|
||||||
color := store.leastUsedColor()
|
color := store.leastUsedColor()
|
||||||
|
|
||||||
var exclusive int
|
var exclusive bool
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(name, UserLabelsPrefix):
|
case strings.HasPrefix(name, UserLabelsPrefix):
|
||||||
name = strings.TrimPrefix(name, UserLabelsPrefix)
|
name = strings.TrimPrefix(name, UserLabelsPrefix)
|
||||||
exclusive = 0
|
exclusive = false
|
||||||
case strings.HasPrefix(name, UserFoldersPrefix):
|
case strings.HasPrefix(name, UserFoldersPrefix):
|
||||||
name = strings.TrimPrefix(name, UserFoldersPrefix)
|
name = strings.TrimPrefix(name, UserFoldersPrefix)
|
||||||
exclusive = 1
|
exclusive = true
|
||||||
default:
|
default:
|
||||||
// Ideally we would throw an error here, but then Outlook for
|
// Ideally we would throw an error here, but then Outlook for
|
||||||
// macOS keeps trying to make an IMAP Drafts folder and popping
|
// macOS keeps trying to make an IMAP Drafts folder and popping
|
||||||
@ -56,10 +55,10 @@ func (store *Store) createMailbox(name string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{
|
_, err := store.client().CreateLabel(exposeContextForIMAP(), &pmapi.Label{
|
||||||
Name: name,
|
Name: name,
|
||||||
Color: color,
|
Color: color,
|
||||||
Exclusive: exclusive,
|
Exclusive: pmapi.Boolean(exclusive),
|
||||||
Type: pmapi.LabelTypeMailbox,
|
Type: pmapi.LabelTypeMailbox,
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
@ -126,7 +125,7 @@ func (store *Store) leastUsedColor() string {
|
|||||||
func (store *Store) updateMailbox(labelID, newName, color string) error {
|
func (store *Store) updateMailbox(labelID, newName, color string) error {
|
||||||
defer store.eventLoop.pollNow()
|
defer store.eventLoop.pollNow()
|
||||||
|
|
||||||
_, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{
|
_, err := store.client().UpdateLabel(exposeContextForIMAP(), &pmapi.Label{
|
||||||
ID: labelID,
|
ID: labelID,
|
||||||
Name: newName,
|
Name: newName,
|
||||||
Color: color,
|
Color: color,
|
||||||
@ -143,15 +142,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
|
|||||||
var err error
|
var err error
|
||||||
switch labelID {
|
switch labelID {
|
||||||
case pmapi.SpamLabel:
|
case pmapi.SpamLabel:
|
||||||
err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID)
|
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.SpamLabel, addressID)
|
||||||
case pmapi.TrashLabel:
|
case pmapi.TrashLabel:
|
||||||
err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID)
|
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.TrashLabel, addressID)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("cannot empty mailbox %v", labelID)
|
err = fmt.Errorf("cannot empty mailbox %v", labelID)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return store.client().DeleteLabel(context.TODO(), labelID)
|
return store.client().DeleteLabel(exposeContextForIMAP(), labelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
|
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
|
||||||
@ -166,7 +165,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
labels, err := store.client().ListLabels(context.TODO())
|
labels, err := store.client().ListLabels(exposeContextForIMAP())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,7 +19,6 @@ package store
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@ -58,7 +57,7 @@ func (store *Store) CreateDraft(
|
|||||||
}
|
}
|
||||||
|
|
||||||
draftAction := store.getDraftAction(message)
|
draftAction := store.getDraftAction(message)
|
||||||
draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction)
|
draft, err := store.client().CreateDraft(exposeContextForSMTP(), message, parentID, draftAction)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "failed to create draft")
|
return nil, nil, errors.Wrap(err, "failed to create draft")
|
||||||
}
|
}
|
||||||
@ -70,7 +69,7 @@ func (store *Store) CreateDraft(
|
|||||||
for _, att := range attachments {
|
for _, att := range attachments {
|
||||||
att.attachment.MessageID = draft.ID
|
att.attachment.MessageID = draft.ID
|
||||||
|
|
||||||
createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader)
|
createdAttachment, err := store.client().CreateAttachment(exposeContextForSMTP(), att.attachment, att.encReader, att.sigReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "failed to create attachment")
|
return nil, nil, errors.Wrap(err, "failed to create attachment")
|
||||||
}
|
}
|
||||||
@ -184,7 +183,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
|
|||||||
// SendMessage sends the message.
|
// SendMessage sends the message.
|
||||||
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
|
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
|
||||||
defer store.eventLoop.pollNow()
|
defer store.eventLoop.pollNow()
|
||||||
_, _, err := store.client().SendMessage(context.TODO(), messageID, req)
|
_, _, err := store.client().SendMessage(exposeContextForSMTP(), messageID, req)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
a "github.com/stretchr/testify/assert"
|
a "github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -34,10 +35,10 @@ func TestGetAllMessageIDs(t *testing.T) {
|
|||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
|
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{})
|
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{})
|
||||||
|
|
||||||
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
||||||
}
|
}
|
||||||
@ -47,7 +48,7 @@ func TestGetMessageFromDB(t *testing.T) {
|
|||||||
defer clear()
|
defer clear()
|
||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
|
|
||||||
tests := []struct{ msgID, wantErr string }{
|
tests := []struct{ msgID, wantErr string }{
|
||||||
{"msg1", ""},
|
{"msg1", ""},
|
||||||
@ -72,7 +73,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
|
|||||||
defer clear()
|
defer clear()
|
||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
|
|
||||||
msg, err := m.store.getMessageFromDB("msg1")
|
msg, err := m.store.getMessageFromDB("msg1")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@ -104,7 +105,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
|
|||||||
a.Equal(t, wantHeader, msg.Header)
|
a.Equal(t, wantHeader, msg.Header)
|
||||||
|
|
||||||
// Check calculated data are not overridden by reinsert.
|
// Check calculated data are not overridden by reinsert.
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
|
|
||||||
msg, err = m.store.getMessageFromDB("msg1")
|
msg, err = m.store.getMessageFromDB("msg1")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@ -118,8 +119,8 @@ func TestDeleteMessage(t *testing.T) {
|
|||||||
defer clear()
|
defer clear()
|
||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||||
|
|
||||||
require.Nil(t, m.store.deleteMessageEvent("msg1"))
|
require.Nil(t, m.store.deleteMessageEvent("msg1"))
|
||||||
|
|
||||||
@ -127,17 +128,17 @@ func TestDeleteMessage(t *testing.T) {
|
|||||||
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
|
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
|
||||||
}
|
}
|
||||||
|
|
||||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam]
|
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
|
||||||
msg := getTestMessage(id, subject, sender, unread, labelIDs)
|
msg := getTestMessage(id, subject, sender, unread, labelIDs)
|
||||||
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
|
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message {
|
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
|
||||||
address := &mail.Address{Address: sender}
|
address := &mail.Address{Address: sender}
|
||||||
return &pmapi.Message{
|
return &pmapi.Message{
|
||||||
ID: id,
|
ID: id,
|
||||||
Subject: subject,
|
Subject: subject,
|
||||||
Unread: unread,
|
Unread: pmapi.Boolean(unread),
|
||||||
Sender: address,
|
Sender: address,
|
||||||
ToList: []*mail.Address{address},
|
ToList: []*mail.Address{address},
|
||||||
LabelIDs: labelIDs,
|
LabelIDs: labelIDs,
|
||||||
@ -162,7 +163,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
|
|||||||
defer clear()
|
defer clear()
|
||||||
|
|
||||||
m.newStoreNoEvents(false)
|
m.newStoreNoEvents(false)
|
||||||
m.client.EXPECT().CurrentUser().Return(&pmapi.User{
|
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
|
||||||
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
|
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
@ -181,7 +182,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
|
|||||||
defer clear()
|
defer clear()
|
||||||
|
|
||||||
m.newStoreNoEvents(false)
|
m.newStoreNoEvents(false)
|
||||||
m.client.EXPECT().CurrentUser().Return(&pmapi.User{
|
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
|
||||||
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
|
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
|
|||||||
@ -35,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
|
|||||||
|
|
||||||
// updateCountsFromServer will download and set the counts.
|
// updateCountsFromServer will download and set the counts.
|
||||||
func (store *Store) updateCountsFromServer() error {
|
func (store *Store) updateCountsFromServer() error {
|
||||||
counts, err := store.client().CountMessages(context.TODO(), "")
|
counts, err := store.client().CountMessages(context.Background(), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "cannot update counts from server")
|
return errors.Wrap(err, "cannot update counts from server")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,8 +31,8 @@ func TestLoadSaveSyncState(t *testing.T) {
|
|||||||
defer clear()
|
defer clear()
|
||||||
|
|
||||||
m.newStoreNoEvents(true)
|
m.newStoreNoEvents(true)
|
||||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||||
|
|
||||||
// Clear everything.
|
// Clear everything.
|
||||||
|
|
||||||
|
|||||||
@ -5,11 +5,10 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
imap "github.com/emersion/go-imap"
|
imap "github.com/emersion/go-imap"
|
||||||
sasl "github.com/emersion/go-sasl"
|
sasl "github.com/emersion/go-sasl"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
reflect "reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockPanicHandler is a mock of PanicHandler interface
|
// MockPanicHandler is a mock of PanicHandler interface
|
||||||
|
|||||||
@ -19,7 +19,9 @@ package transfer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -37,6 +39,8 @@ const (
|
|||||||
imapRetries = 10
|
imapRetries = 10
|
||||||
imapReconnectTimeout = 30 * time.Minute
|
imapReconnectTimeout = 30 * time.Minute
|
||||||
imapReconnectSleep = time.Minute
|
imapReconnectSleep = time.Minute
|
||||||
|
|
||||||
|
protonStatusURL = "http://protonstatus.com/vpn_status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type imapErrorLogger struct {
|
type imapErrorLogger struct {
|
||||||
@ -117,19 +121,15 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
|
|||||||
return previousErr
|
return previousErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(conman): This should register as connection observer.
|
err := checkConnection()
|
||||||
|
|
||||||
/*
|
|
||||||
err := pmapi.CheckConnection()
|
|
||||||
log.WithError(err).Debug("Connection check")
|
log.WithError(err).Debug("Connection check")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
time.Sleep(imapReconnectSleep)
|
time.Sleep(imapReconnectSleep)
|
||||||
previousErr = err
|
previousErr = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
err := p.reauth()
|
err = p.reauth()
|
||||||
log.WithError(err).Debug("Reauth")
|
log.WithError(err).Debug("Reauth")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
time.Sleep(imapReconnectSleep)
|
time.Sleep(imapReconnectSleep)
|
||||||
@ -289,3 +289,23 @@ func (p *IMAPProvider) fetchHelper(uid bool, ensureSelectedIn string, seqSet *im
|
|||||||
return err
|
return err
|
||||||
}, ensureSelectedIn)
|
}, ensureSelectedIn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkConnection returns an error if there is no internet connection.
|
||||||
|
// Note we don't want to use client manager because it only reports connection
|
||||||
|
// issues with API; we are only interested here whether we can reach
|
||||||
|
// third-party IMAP servers.
|
||||||
|
func checkConnection() error {
|
||||||
|
client := &http.Client{Timeout: time.Second * 10}
|
||||||
|
|
||||||
|
resp, err := client.Get(protonStatusURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return fmt.Errorf("HTTP status code %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@ -52,15 +52,10 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
|
|||||||
return Mailbox{}, errors.New("mailbox is already created")
|
return Mailbox{}, errors.New("mailbox is already created")
|
||||||
}
|
}
|
||||||
|
|
||||||
exclusive := 0
|
label, err := p.client.CreateLabel(context.Background(), &pmapi.Label{
|
||||||
if mailbox.IsExclusive {
|
|
||||||
exclusive = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{
|
|
||||||
Name: mailbox.Name,
|
Name: mailbox.Name,
|
||||||
Color: mailbox.Color,
|
Color: mailbox.Color,
|
||||||
Exclusive: exclusive,
|
Exclusive: pmapi.Boolean(mailbox.IsExclusive),
|
||||||
Type: pmapi.LabelTypeMailbox,
|
Type: pmapi.LabelTypeMailbox,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -126,7 +121,7 @@ func (p *PMAPIProvider) importDraft(msg Message, globalMailbox *Mailbox) (string
|
|||||||
}
|
}
|
||||||
|
|
||||||
if message.Sender == nil {
|
if message.Sender == nil {
|
||||||
mainAddress := p.client().Addresses().Main()
|
mainAddress := p.client.Addresses().Main()
|
||||||
message.Sender = &mail.Address{
|
message.Sender = &mail.Address{
|
||||||
Name: mainAddress.DisplayName,
|
Name: mainAddress.DisplayName,
|
||||||
Address: mainAddress.Email,
|
Address: mainAddress.Email,
|
||||||
@ -227,14 +222,6 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var unread pmapi.Boolean
|
|
||||||
|
|
||||||
if msg.Unread {
|
|
||||||
unread = pmapi.True
|
|
||||||
} else {
|
|
||||||
unread = pmapi.False
|
|
||||||
}
|
|
||||||
|
|
||||||
labelIDs := []string{}
|
labelIDs := []string{}
|
||||||
for _, target := range msg.Targets {
|
for _, target := range msg.Targets {
|
||||||
// Frontend should not set All Mail to Rules, but to be sure...
|
// Frontend should not set All Mail to Rules, but to be sure...
|
||||||
@ -249,7 +236,7 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
|||||||
return &pmapi.ImportMsgReq{
|
return &pmapi.ImportMsgReq{
|
||||||
Metadata: &pmapi.ImportMetadata{
|
Metadata: &pmapi.ImportMetadata{
|
||||||
AddressID: p.addressID,
|
AddressID: p.addressID,
|
||||||
Unread: unread,
|
Unread: pmapi.Boolean(msg.Unread),
|
||||||
Time: message.Time,
|
Time: message.Time,
|
||||||
Flags: computeMessageFlags(message.Header),
|
Flags: computeMessageFlags(message.Header),
|
||||||
LabelIDs: labelIDs,
|
LabelIDs: labelIDs,
|
||||||
|
|||||||
@ -153,10 +153,10 @@ func setupPMAPIRules(rules transferRules) {
|
|||||||
func setupPMAPIClientExpectationForExport(m *mocks) {
|
func setupPMAPIClientExpectationForExport(m *mocks) {
|
||||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
|
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
|
||||||
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2},
|
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: false, Order: 2},
|
||||||
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1},
|
{ID: "label2", Name: "Bar", Color: "green", Exclusive: false, Order: 1},
|
||||||
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1},
|
{ID: "folder1", Name: "One", Color: "red", Exclusive: true, Order: 1},
|
||||||
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2},
|
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: true, Order: 2},
|
||||||
}, nil).AnyTimes()
|
}, nil).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
|
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
|
||||||
{LabelID: "label1", Total: 10},
|
{LabelID: "label1", Total: 10},
|
||||||
|
|||||||
@ -30,9 +30,17 @@ import (
|
|||||||
const (
|
const (
|
||||||
pmapiRetries = 10
|
pmapiRetries = 10
|
||||||
pmapiReconnectTimeout = 30 * time.Minute
|
pmapiReconnectTimeout = 30 * time.Minute
|
||||||
pmapiReconnectSleep = time.Minute
|
pmapiReconnectSleep = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (p *PMAPIProvider) SetConnectionUp() {
|
||||||
|
p.connection = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PMAPIProvider) SetConnectionDown() {
|
||||||
|
p.connection = false
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PMAPIProvider) ensureConnection(callback func() error) error {
|
func (p *PMAPIProvider) ensureConnection(callback func() error) error {
|
||||||
var callErr error
|
var callErr error
|
||||||
for i := 1; i <= pmapiRetries; i++ {
|
for i := 1; i <= pmapiRetries; i++ {
|
||||||
@ -58,18 +66,10 @@ func (p *PMAPIProvider) tryReconnect() error {
|
|||||||
return previousErr
|
return previousErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(conman): This should register as a connection observer somehow...
|
if !p.connection {
|
||||||
// Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection?
|
|
||||||
|
|
||||||
/*
|
|
||||||
err := p.clientManager.CheckConnection()
|
|
||||||
log.WithError(err).Debug("Connection check")
|
|
||||||
if err != nil {
|
|
||||||
time.Sleep(pmapiReconnectSleep)
|
time.Sleep(pmapiReconnectSleep)
|
||||||
previousErr = err
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -83,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
|
|||||||
p.timeIt.start("listing", key)
|
p.timeIt.start("listing", key)
|
||||||
defer p.timeIt.stop("listing", key)
|
defer p.timeIt.stop("listing", key)
|
||||||
|
|
||||||
messages, count, err = p.client.ListMessages(context.TODO(), filter)
|
messages, count, err = p.client.ListMessages(context.Background(), filter)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -94,7 +94,7 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
|
|||||||
p.timeIt.start("download", msgID)
|
p.timeIt.start("download", msgID)
|
||||||
defer p.timeIt.stop("download", msgID)
|
defer p.timeIt.stop("download", msgID)
|
||||||
|
|
||||||
message, err = p.client.GetMessage(context.TODO(), msgID)
|
message, err = p.client.GetMessage(context.Background(), msgID)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -105,7 +105,7 @@ func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReq
|
|||||||
p.timeIt.start("upload", msgSourceID)
|
p.timeIt.start("upload", msgSourceID)
|
||||||
defer p.timeIt.stop("upload", msgSourceID)
|
defer p.timeIt.stop("upload", msgSourceID)
|
||||||
|
|
||||||
res, err = p.client.Import(context.TODO(), req)
|
res, err = p.client.Import(context.Background(), req)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -116,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
|
|||||||
p.timeIt.start("upload", msgSourceID)
|
p.timeIt.start("upload", msgSourceID)
|
||||||
defer p.timeIt.stop("upload", msgSourceID)
|
defer p.timeIt.stop("upload", msgSourceID)
|
||||||
|
|
||||||
draft, err = p.client.CreateDraft(context.TODO(), message, parent, action)
|
draft, err = p.client.CreateDraft(context.Background(), message, parent, action)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -129,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
|
|||||||
p.timeIt.start("upload", key)
|
p.timeIt.start("upload", key)
|
||||||
defer p.timeIt.stop("upload", key)
|
defer p.timeIt.stop("upload", key)
|
||||||
|
|
||||||
created, err = p.client.CreateAttachment(context.TODO(), att, r, sig)
|
created, err = p.client.CreateAttachment(context.Background(), att, r, sig)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|||||||
@ -28,6 +28,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Masterminds/semver/v3"
|
"github.com/Masterminds/semver/v3"
|
||||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -274,7 +275,7 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater {
|
func newTestUpdater(manager pmapi.Manager, curVer string, earlyAccess bool) *Updater {
|
||||||
return New(
|
return New(
|
||||||
manager,
|
manager,
|
||||||
&fakeInstaller{},
|
&fakeInstaller{},
|
||||||
|
|||||||
@ -1,251 +0,0 @@
|
|||||||
// Copyright (c) 2021 Proton Technologies AG
|
|
||||||
//
|
|
||||||
// This file is part of ProtonMail Bridge.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
|
||||||
// it under the terms of the GNU General Public License as published by
|
|
||||||
// the Free Software Foundation, either version 3 of the License, or
|
|
||||||
// (at your option) any later version.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
|
||||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
// GNU General Public License for more details.
|
|
||||||
//
|
|
||||||
// You should have received a copy of the GNU General Public License
|
|
||||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
package users
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUpdateUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUser(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
|
|
||||||
m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
|
|
||||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert.NoError(t, user.UpdateUser())
|
|
||||||
|
|
||||||
waitForEvents()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserSwitchAddressMode(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUser(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
assert.True(t, user.store.IsCombinedMode())
|
|
||||||
assert.True(t, user.creds.IsCombinedAddressMode)
|
|
||||||
assert.True(t, user.IsCombinedAddressMode())
|
|
||||||
waitForEvents()
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert.NoError(t, user.SwitchAddressMode())
|
|
||||||
assert.False(t, user.store.IsCombinedMode())
|
|
||||||
assert.False(t, user.creds.IsCombinedAddressMode)
|
|
||||||
assert.False(t, user.IsCombinedAddressMode())
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
)
|
|
||||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
|
||||||
|
|
||||||
assert.NoError(t, user.SwitchAddressMode())
|
|
||||||
assert.True(t, user.store.IsCombinedMode())
|
|
||||||
assert.True(t, user.creds.IsCombinedAddressMode)
|
|
||||||
assert.True(t, user.IsCombinedAddressMode())
|
|
||||||
|
|
||||||
waitForEvents()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogoutUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUserForLogout(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().Logout().Return(),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
|
||||||
|
|
||||||
err := user.Logout()
|
|
||||||
|
|
||||||
waitForEvents()
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogoutUserFailsLogout(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUserForLogout(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().Logout().Return(),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
|
|
||||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
|
||||||
)
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
|
||||||
|
|
||||||
err := user.Logout()
|
|
||||||
waitForEvents()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckBridgeLoginOK(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUser(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
|
|
||||||
|
|
||||||
waitForEvents()
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckBridgeLoginTwiceOK(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUser(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().IsUnlocked().Return(true),
|
|
||||||
)
|
|
||||||
|
|
||||||
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
|
|
||||||
waitForEvents()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = user.CheckBridgeLogin(testCredentials.BridgePassword)
|
|
||||||
waitForEvents()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUser(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
|
|
||||||
|
|
||||||
isApplicationOutdated = true
|
|
||||||
|
|
||||||
err := user.CheckBridgeLogin("any-pass")
|
|
||||||
waitForEvents()
|
|
||||||
assert.Equal(t, pmapi.ErrUpgradeApplication, err)
|
|
||||||
|
|
||||||
isApplicationOutdated = false
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
|
||||||
|
|
||||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
|
|
||||||
gomock.InOrder(
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
err = user.init()
|
|
||||||
assert.Error(t, err)
|
|
||||||
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
|
||||||
|
|
||||||
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
|
|
||||||
waitForEvents()
|
|
||||||
assert.Equal(t, ErrLoggedOutUser, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckBridgeLoginBadPassword(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUser(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
err := user.CheckBridgeLogin("wrong!")
|
|
||||||
waitForEvents()
|
|
||||||
assert.Equal(t, "backend/credentials: incorrect password", err.Error())
|
|
||||||
}
|
|
||||||
@ -1,112 +0,0 @@
|
|||||||
// Copyright (c) 2021 Proton Technologies AG
|
|
||||||
//
|
|
||||||
// This file is part of ProtonMail Bridge.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
|
||||||
// it under the terms of the GNU General Public License as published by
|
|
||||||
// the Free Software Foundation, either version 3 of the License, or
|
|
||||||
// (at your option) any later version.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
|
||||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
// GNU General Public License for more details.
|
|
||||||
//
|
|
||||||
// You should have received a copy of the GNU General Public License
|
|
||||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
package users
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
a "github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewUserNoCredentialsStore(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
|
|
||||||
|
|
||||||
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
|
||||||
a.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewUserAuthRefreshFails(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().Logout(),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
checkNewUserHasCredentials(testCredentialsDisconnected, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewUserUnlockFails(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(errors.New("bad password")),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().Logout(),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
checkNewUserHasCredentials(testCredentialsDisconnected, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
mockConnectedUser(m)
|
|
||||||
mockEventLoopNoAction(m)
|
|
||||||
|
|
||||||
checkNewUserHasCredentials(testCredentials, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) {
|
|
||||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
_ = user.init()
|
|
||||||
|
|
||||||
waitForEvents()
|
|
||||||
|
|
||||||
a.Equal(m.t, creds, user.creds)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen]
|
|
||||||
a.Fail(t, "not implemented")
|
|
||||||
}
|
|
||||||
@ -1,89 +0,0 @@
|
|||||||
// Copyright (c) 2021 Proton Technologies AG
|
|
||||||
//
|
|
||||||
// This file is part of ProtonMail Bridge.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
|
||||||
// it under the terms of the GNU General Public License as published by
|
|
||||||
// the Free Software Foundation, either version 3 of the License, or
|
|
||||||
// (at your option) any later version.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
|
||||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
// GNU General Public License for more details.
|
|
||||||
//
|
|
||||||
// You should have received a copy of the GNU General Public License
|
|
||||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
package users
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// testNewUser sets up a new, authorised user.
|
|
||||||
func testNewUser(m mocks) *User {
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
mockConnectedUser(m)
|
|
||||||
mockEventLoopNoAction(m)
|
|
||||||
|
|
||||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
|
||||||
assert.NoError(m.t, err)
|
|
||||||
|
|
||||||
err = user.init()
|
|
||||||
assert.NoError(m.t, err)
|
|
||||||
|
|
||||||
mockAuthUpdate(user, "reftok", m)
|
|
||||||
|
|
||||||
return user
|
|
||||||
}
|
|
||||||
|
|
||||||
func testNewUserForLogout(m mocks) *User {
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
mockConnectedUser(m)
|
|
||||||
mockEventLoopNoAction(m)
|
|
||||||
|
|
||||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
|
||||||
assert.NoError(m.t, err)
|
|
||||||
|
|
||||||
err = user.init()
|
|
||||||
assert.NoError(m.t, err)
|
|
||||||
|
|
||||||
return user
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanUpUserData(u *User) {
|
|
||||||
_ = u.clearStore()
|
|
||||||
}
|
|
||||||
|
|
||||||
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
|
|
||||||
assert.Fail(t, "not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClearStoreWithStore(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUserForLogout(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
require.Nil(t, user.store.Close())
|
|
||||||
user.store = nil
|
|
||||||
assert.Nil(t, user.clearStore())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClearStoreWithoutStore(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
user := testNewUserForLogout(m)
|
|
||||||
defer cleanUpUserData(user)
|
|
||||||
|
|
||||||
assert.NotNil(t, user.store)
|
|
||||||
assert.Nil(t, user.clearStore())
|
|
||||||
}
|
|
||||||
@ -1,143 +0,0 @@
|
|||||||
// Copyright (c) 2021 Proton Technologies AG
|
|
||||||
//
|
|
||||||
// This file is part of ProtonMail Bridge.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
|
||||||
// it under the terms of the GNU General Public License as published by
|
|
||||||
// the Free Software Foundation, either version 3 of the License, or
|
|
||||||
// (at your option) any later version.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
|
||||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
// GNU General Public License for more details.
|
|
||||||
//
|
|
||||||
// You should have received a copy of the GNU General Public License
|
|
||||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
package users
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetNoUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetUserByID(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
checkUsersGetUser(t, m, "user", 0, "")
|
|
||||||
checkUsersGetUser(t, m, "users", 1, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetUserByName(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
checkUsersGetUser(t, m, "username", 0, "")
|
|
||||||
checkUsersGetUser(t, m, "usersname", 1, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetUserByEmail(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
checkUsersGetUser(t, m, "user@pm.me", 0, "")
|
|
||||||
checkUsersGetUser(t, m, "users@pm.me", 1, "")
|
|
||||||
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
|
|
||||||
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeleteUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
users := testNewUsersWithUsers(t, m)
|
|
||||||
defer cleanUpUsersData(users)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().Logout().Return(),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
|
||||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
|
||||||
|
|
||||||
err := users.DeleteUser("user", true)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, 1, len(users.users))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Even when logout fails, delete is done.
|
|
||||||
func TestDeleteUserWithFailingLogout(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
users := testNewUsersWithUsers(t, m)
|
|
||||||
defer cleanUpUsersData(users)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().Logout().Return(),
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
|
|
||||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")),
|
|
||||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
|
||||||
|
|
||||||
err := users.DeleteUser("user", true)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, 1, len(users.users))
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
|
|
||||||
users := testNewUsersWithUsers(t, m)
|
|
||||||
defer cleanUpUsersData(users)
|
|
||||||
|
|
||||||
user, err := users.GetUser(query)
|
|
||||||
waitForEvents()
|
|
||||||
|
|
||||||
if expectedError != "" {
|
|
||||||
assert.Equal(m.t, expectedError, err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(m.t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var expectedUser *User
|
|
||||||
if index >= 0 {
|
|
||||||
expectedUser = users.users[index]
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(m.t, expectedUser, user)
|
|
||||||
}
|
|
||||||
@ -1,219 +0,0 @@
|
|||||||
// Copyright (c) 2021 Proton Technologies AG
|
|
||||||
//
|
|
||||||
// This file is part of ProtonMail Bridge.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
|
||||||
// it under the terms of the GNU General Public License as published by
|
|
||||||
// the Free Software Foundation, either version 3 of the License, or
|
|
||||||
// (at your option) any later version.
|
|
||||||
//
|
|
||||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
|
||||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
// GNU General Public License for more details.
|
|
||||||
//
|
|
||||||
// You should have received a copy of the GNU General Public License
|
|
||||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
package users
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
|
||||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
// Init users with no user from keychain.
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
|
||||||
|
|
||||||
// Set up mocks for FinishLogin.
|
|
||||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked")),
|
|
||||||
m.pmapiClient.EXPECT().DeleteAuth(),
|
|
||||||
m.pmapiClient.EXPECT().Logout(),
|
|
||||||
)
|
|
||||||
|
|
||||||
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
|
|
||||||
}
|
|
||||||
|
|
||||||
func refreshWithToken(token string) *pmapi.Auth {
|
|
||||||
return &pmapi.Auth{
|
|
||||||
RefreshToken: token,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func credentialsWithToken(token string) *credentials.Credentials {
|
|
||||||
tmp := &credentials.Credentials{}
|
|
||||||
*tmp = *testCredentials
|
|
||||||
tmp.APIToken = token
|
|
||||||
return tmp
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUsersFinishLoginNewUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
// Basically every call client has get client manager
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
// users.New() finds no users in keychain.
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
|
||||||
|
|
||||||
// getAPIUser() loads user info from API (e.g. userID).
|
|
||||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
|
||||||
|
|
||||||
// addNewUser()
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
|
|
||||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
|
||||||
|
|
||||||
// user.init() in addNewUser
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
|
||||||
|
|
||||||
// store.New() in user.init
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
|
|
||||||
// Emit event for new user and send metrics.
|
|
||||||
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
|
|
||||||
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)),
|
|
||||||
m.pmapiClient.EXPECT().Logout(),
|
|
||||||
|
|
||||||
// Reload account list in GUI.
|
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
|
||||||
|
|
||||||
// defer logout anonymous
|
|
||||||
m.pmapiClient.EXPECT().Logout(),
|
|
||||||
)
|
|
||||||
|
|
||||||
mockEventLoopNoAction(m)
|
|
||||||
|
|
||||||
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
|
|
||||||
|
|
||||||
mockAuthUpdate(user, "afterCredentials", m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
loggedOutCreds := *testCredentials
|
|
||||||
loggedOutCreds.APIToken = ""
|
|
||||||
loggedOutCreds.MailboxPassword = ""
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
// users.New() finds one existing user in keychain.
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
|
||||||
|
|
||||||
// newUser()
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
|
||||||
|
|
||||||
// user.init()
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
|
||||||
|
|
||||||
// store.New() in user.init
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
|
||||||
|
|
||||||
// getAPIUser() loads user info from API (e.g. userID).
|
|
||||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
|
||||||
|
|
||||||
// connectExistingUser()
|
|
||||||
m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
|
|
||||||
m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil),
|
|
||||||
|
|
||||||
// user.init() in connectExistingUser
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
|
||||||
|
|
||||||
// store.New() in user.init
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
|
|
||||||
// Reload account list in GUI.
|
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
|
||||||
|
|
||||||
// defer logout anonymous
|
|
||||||
m.pmapiClient.EXPECT().Logout(),
|
|
||||||
)
|
|
||||||
|
|
||||||
mockEventLoopNoAction(m)
|
|
||||||
|
|
||||||
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
|
|
||||||
|
|
||||||
mockAuthUpdate(user, "afterCredentials", m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUsersFinishLoginConnectedUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
|
||||||
|
|
||||||
mockConnectedUser(m)
|
|
||||||
mockEventLoopNoAction(m)
|
|
||||||
|
|
||||||
users := testNewUsers(t, m)
|
|
||||||
defer cleanUpUsersData(users)
|
|
||||||
|
|
||||||
// Then, try to log in again...
|
|
||||||
gomock.InOrder(
|
|
||||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
|
||||||
m.pmapiClient.EXPECT().DeleteAuth(),
|
|
||||||
m.pmapiClient.EXPECT().Logout(),
|
|
||||||
)
|
|
||||||
|
|
||||||
_, err := users.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword)
|
|
||||||
assert.Equal(t, "user is already connected", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User {
|
|
||||||
users := testNewUsers(t, m)
|
|
||||||
defer cleanUpUsersData(users)
|
|
||||||
|
|
||||||
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
|
|
||||||
|
|
||||||
waitForEvents()
|
|
||||||
|
|
||||||
assert.Equal(t, expectedErr, err)
|
|
||||||
|
|
||||||
if expectedUserID != "" {
|
|
||||||
assert.Equal(t, expectedUserID, user.ID())
|
|
||||||
assert.Equal(t, 1, len(users.users))
|
|
||||||
assert.Equal(t, expectedUserID, users.users[0].ID())
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, (*User)(nil), user)
|
|
||||||
assert.Equal(t, 0, len(users.users))
|
|
||||||
}
|
|
||||||
|
|
||||||
return user
|
|
||||||
}
|
|
||||||
@ -233,7 +233,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
|
|||||||
|
|
||||||
_, secret, err := s.secrets.Get(userID)
|
_, secret, err := s.secrets.Get(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Could not get credentials from native keychain")
|
log.WithError(err).Warn("Could not get credentials from native keychain")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -26,8 +26,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
r "github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const testSep = "\n"
|
const testSep = "\n"
|
||||||
@ -249,26 +248,26 @@ func TestMarshalFormats(t *testing.T) {
|
|||||||
log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt))
|
log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt))
|
||||||
|
|
||||||
output := testCredentials{APIToken: "refresh"}
|
output := testCredentials{APIToken: "refresh"}
|
||||||
require.NoError(t, output.UnmarshalStrings(secretStrings))
|
r.NoError(t, output.UnmarshalStrings(secretStrings))
|
||||||
log.Infof("strings out %#v \n", output)
|
log.Infof("strings out %#v \n", output)
|
||||||
require.True(t, input.IsSame(&output), "strings out not same")
|
r.True(t, input.IsSame(&output), "strings out not same")
|
||||||
|
|
||||||
output = testCredentials{APIToken: "refresh"}
|
output = testCredentials{APIToken: "refresh"}
|
||||||
require.NoError(t, output.UnmarshalGob(secretGob))
|
r.NoError(t, output.UnmarshalGob(secretGob))
|
||||||
log.Infof("gob out %#v\n \n", output)
|
log.Infof("gob out %#v\n \n", output)
|
||||||
assert.Equal(t, input, output)
|
r.Equal(t, input, output)
|
||||||
|
|
||||||
output = testCredentials{APIToken: "refresh"}
|
output = testCredentials{APIToken: "refresh"}
|
||||||
require.NoError(t, output.FromJSON(secretJSON))
|
r.NoError(t, output.FromJSON(secretJSON))
|
||||||
log.Infof("json out %#v \n", output)
|
log.Infof("json out %#v \n", output)
|
||||||
require.True(t, input.IsSame(&output), "json out not same")
|
r.True(t, input.IsSame(&output), "json out not same")
|
||||||
|
|
||||||
/*
|
/*
|
||||||
// Simple Fscanf not working!
|
// Simple Fscanf not working!
|
||||||
output = testCredentials{APIToken: "refresh"}
|
output = testCredentials{APIToken: "refresh"}
|
||||||
require.NoError(t, output.UnmarshalFmt(secretFmt))
|
r.NoError(t, output.UnmarshalFmt(secretFmt))
|
||||||
log.Infof("fmt out %#v \n", output)
|
log.Infof("fmt out %#v \n", output)
|
||||||
require.True(t, input.IsSame(&output), "fmt out not same")
|
r.True(t, input.IsSame(&output), "fmt out not same")
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -291,7 +290,7 @@ func TestMarshal(t *testing.T) {
|
|||||||
log.Infof("secret %#v %d\n", secret, len(secret))
|
log.Infof("secret %#v %d\n", secret, len(secret))
|
||||||
|
|
||||||
output := Credentials{APIToken: "refresh"}
|
output := Credentials{APIToken: "refresh"}
|
||||||
require.NoError(t, output.Unmarshal(secret))
|
r.NoError(t, output.Unmarshal(secret))
|
||||||
log.Infof("output %#v\n", output)
|
log.Infof("output %#v\n", output)
|
||||||
assert.Equal(t, input, output)
|
r.Equal(t, input, output)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,107 +0,0 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: ./listener/listener.go
|
|
||||||
|
|
||||||
// Package users is a generated GoMock package.
|
|
||||||
package users
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
time "time"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockListener is a mock of Listener interface
|
|
||||||
type MockListener struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockListenerMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockListenerMockRecorder is the mock recorder for MockListener
|
|
||||||
type MockListenerMockRecorder struct {
|
|
||||||
mock *MockListener
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockListener creates a new mock instance
|
|
||||||
func NewMockListener(ctrl *gomock.Controller) *MockListener {
|
|
||||||
mock := &MockListener{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockListenerMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use
|
|
||||||
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLimit mocks base method
|
|
||||||
func (m *MockListener) SetLimit(eventName string, limit time.Duration) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetLimit", eventName, limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLimit indicates an expected call of SetLimit
|
|
||||||
func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add mocks base method
|
|
||||||
func (m *MockListener) Add(eventName string, channel chan<- string) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "Add", eventName, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add indicates an expected call of Add
|
|
||||||
func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove mocks base method
|
|
||||||
func (m *MockListener) Remove(eventName string, channel chan<- string) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "Remove", eventName, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove indicates an expected call of Remove
|
|
||||||
func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit mocks base method
|
|
||||||
func (m *MockListener) Emit(eventName, data string) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "Emit", eventName, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit indicates an expected call of Emit
|
|
||||||
func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBuffer mocks base method
|
|
||||||
func (m *MockListener) SetBuffer(eventName string) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetBuffer", eventName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBuffer indicates an expected call of SetBuffer
|
|
||||||
func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RetryEmit mocks base method
|
|
||||||
func (m *MockListener) RetryEmit(eventName string) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "RetryEmit", eventName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RetryEmit indicates an expected call of RetryEmit
|
|
||||||
func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName)
|
|
||||||
}
|
|
||||||
120
internal/users/mocks/listener_mocks.go
Normal file
120
internal/users/mocks/listener_mocks.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener)
|
||||||
|
|
||||||
|
// Package mocks is a generated GoMock package.
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
reflect "reflect"
|
||||||
|
time "time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockListener is a mock of Listener interface
|
||||||
|
type MockListener struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockListenerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockListenerMockRecorder is the mock recorder for MockListener
|
||||||
|
type MockListenerMockRecorder struct {
|
||||||
|
mock *MockListener
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockListener creates a new mock instance
|
||||||
|
func NewMockListener(ctrl *gomock.Controller) *MockListener {
|
||||||
|
mock := &MockListener{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockListenerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add mocks base method
|
||||||
|
func (m *MockListener) Add(arg0 string, arg1 chan<- string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Add", arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add indicates an expected call of Add
|
||||||
|
func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit mocks base method
|
||||||
|
func (m *MockListener) Emit(arg0, arg1 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Emit", arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit indicates an expected call of Emit
|
||||||
|
func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideChannel mocks base method
|
||||||
|
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
|
||||||
|
ret0, _ := ret[0].(<-chan string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideChannel indicates an expected call of ProvideChannel
|
||||||
|
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove mocks base method
|
||||||
|
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Remove", arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove indicates an expected call of Remove
|
||||||
|
func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryEmit mocks base method
|
||||||
|
func (m *MockListener) RetryEmit(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "RetryEmit", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryEmit indicates an expected call of RetryEmit
|
||||||
|
func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBuffer mocks base method
|
||||||
|
func (m *MockListener) SetBuffer(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetBuffer", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBuffer indicates an expected call of SetBuffer
|
||||||
|
func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLimit mocks base method
|
||||||
|
func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetLimit", arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLimit indicates an expected call of SetLimit
|
||||||
|
func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1)
|
||||||
|
}
|
||||||
@ -5,11 +5,10 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
store "github.com/ProtonMail/proton-bridge/internal/store"
|
store "github.com/ProtonMail/proton-bridge/internal/store"
|
||||||
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
reflect "reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockLocator is a mock of Locator interface
|
// MockLocator is a mock of Locator interface
|
||||||
|
|||||||
@ -89,29 +89,25 @@ func newUser(
|
|||||||
// - providing it with an authorised API client
|
// - providing it with an authorised API client
|
||||||
// - loading its credentials from the credentials store
|
// - loading its credentials from the credentials store
|
||||||
// - loading and unlocking its PGP keys
|
// - loading and unlocking its PGP keys
|
||||||
// - loading its store
|
// - loading its store.
|
||||||
func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error {
|
func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) error {
|
||||||
u.log.Info("Connecting user")
|
u.log.Info("Connecting user")
|
||||||
|
|
||||||
// Connected users have an API client.
|
// Connected users have an API client.
|
||||||
u.client = client
|
u.client = client
|
||||||
|
|
||||||
// FIXME(conman): How to remove this auth handler when user is disconnected?
|
u.client.AddAuthRefreshHandler(u.handleAuthRefresh)
|
||||||
u.client.AddAuthHandler(u.handleAuth)
|
|
||||||
|
|
||||||
// Save the latest credentials for the user.
|
// Save the latest credentials for the user.
|
||||||
u.creds = creds
|
u.creds = creds
|
||||||
|
|
||||||
// Connected users have unlocked keys.
|
// Connected users have unlocked keys.
|
||||||
// FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :(
|
if err := u.unlockIfNecessary(); err != nil {
|
||||||
if u.creds.IsConnected() {
|
|
||||||
if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Connected users have a store.
|
// Connected users have a store.
|
||||||
if err := u.loadStore(); err != nil {
|
if err := u.loadStore(); err != nil { //nolint[revive] easier to read
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,17 +134,25 @@ func (u *User) loadStore() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) handleAuth(auth *pmapi.Auth) error {
|
func (u *User) handleAuthRefresh(auth *pmapi.AuthRefresh) {
|
||||||
u.log.Debug("User received auth")
|
u.log.Debug("User received auth refresh update")
|
||||||
|
|
||||||
|
if auth == nil {
|
||||||
|
if err := u.logout(); err != nil {
|
||||||
|
log.WithError(err).
|
||||||
|
WithField("userID", u.userID).
|
||||||
|
Error("User logout failed while watching API auths")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
|
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to update refresh token in credentials store")
|
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
u.creds = creds
|
u.creds = creds
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// clearStore removes the database.
|
// clearStore removes the database.
|
||||||
@ -181,13 +185,6 @@ func (u *User) closeStore() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTemporaryPMAPIClient returns an authorised PMAPI client.
|
|
||||||
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
|
|
||||||
// After proper refactor of SMTP and IMAP remove this method.
|
|
||||||
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
|
|
||||||
return u.client
|
|
||||||
}
|
|
||||||
|
|
||||||
// ID returns the user's userID.
|
// ID returns the user's userID.
|
||||||
func (u *User) ID() string {
|
func (u *User) ID() string {
|
||||||
return u.userID
|
return u.userID
|
||||||
@ -210,9 +207,43 @@ func (u *User) IsConnected() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) GetClient() pmapi.Client {
|
func (u *User) GetClient() pmapi.Client {
|
||||||
|
if err := u.unlockIfNecessary(); err != nil {
|
||||||
|
u.log.WithError(err).Error("Failed to unlock user")
|
||||||
|
}
|
||||||
return u.client
|
return u.client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
|
||||||
|
func (u *User) unlockIfNecessary() error {
|
||||||
|
if !u.creds.IsConnected() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.client.IsUnlocked() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unlockIfNecessary is called with every access to underlying pmapi
|
||||||
|
// client. Unlock should only finish unlocking when connection is back up.
|
||||||
|
// That means it should try it fast enough and not retry if connection
|
||||||
|
// is still down.
|
||||||
|
err := u.client.Unlock(pmapi.ContextWithoutRetry(context.Background()), []byte(u.creds.MailboxPassword))
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch errors.Cause(err) {
|
||||||
|
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
|
||||||
|
u.log.WithError(err).Warn("Could not unlock user")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if logoutErr := u.logout(); logoutErr != nil {
|
||||||
|
u.log.WithError(logoutErr).Warn("Could not logout user")
|
||||||
|
}
|
||||||
|
return errors.Wrap(err, "failed to unlock user")
|
||||||
|
}
|
||||||
|
|
||||||
// IsCombinedAddressMode returns whether user is set in combined or split mode.
|
// IsCombinedAddressMode returns whether user is set in combined or split mode.
|
||||||
// Combined mode is the default mode and is what users typically need.
|
// Combined mode is the default mode and is what users typically need.
|
||||||
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
|
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
|
||||||
@ -307,14 +338,10 @@ func (u *User) GetBridgePassword() string {
|
|||||||
// CheckBridgeLogin checks whether the user is logged in and the bridge
|
// CheckBridgeLogin checks whether the user is logged in and the bridge
|
||||||
// IMAP/SMTP password is correct.
|
// IMAP/SMTP password is correct.
|
||||||
func (u *User) CheckBridgeLogin(password string) error {
|
func (u *User) CheckBridgeLogin(password string) error {
|
||||||
// FIXME(conman): Handle force upgrade?
|
|
||||||
|
|
||||||
/*
|
|
||||||
if isApplicationOutdated {
|
if isApplicationOutdated {
|
||||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||||
return pmapi.ErrUpgradeApplication
|
return pmapi.ErrUpgradeApplication
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
u.lock.RLock()
|
u.lock.RLock()
|
||||||
defer u.lock.RUnlock()
|
defer u.lock.RUnlock()
|
||||||
@ -328,16 +355,16 @@ func (u *User) CheckBridgeLogin(password string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUser updates user details from API and saves to the credentials.
|
// UpdateUser updates user details from API and saves to the credentials.
|
||||||
func (u *User) UpdateUser() error {
|
func (u *User) UpdateUser(ctx context.Context) error {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
defer u.lock.Unlock()
|
defer u.lock.Unlock()
|
||||||
|
|
||||||
_, err := u.client.UpdateUser(context.TODO())
|
_, err := u.client.UpdateUser(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil {
|
if err := u.client.ReloadKeys(ctx, []byte(u.creds.MailboxPassword)); err != nil {
|
||||||
return errors.Wrap(err, "failed to reload keys")
|
return errors.Wrap(err, "failed to reload keys")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -414,8 +441,7 @@ func (u *User) Logout() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers?
|
if err := u.client.AuthDelete(context.Background()); err != nil {
|
||||||
if err := u.client.AuthDelete(context.TODO()); err != nil {
|
|
||||||
u.log.WithError(err).Warn("Failed to delete auth")
|
u.log.WithError(err).Warn("Failed to delete auth")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
195
internal/users/user_credentials_test.go
Normal file
195
internal/users/user_credentials_test.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().UpdateUser(gomock.Any()).Return(nil, nil),
|
||||||
|
m.pmapiClient.EXPECT().ReloadKeys(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||||
|
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||||
|
|
||||||
|
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
r.NoError(t, user.UpdateUser(context.Background()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserSwitchAddressMode(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
// Ignore any sync on background.
|
||||||
|
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||||
|
|
||||||
|
// Check initial state.
|
||||||
|
r.True(t, user.store.IsCombinedMode())
|
||||||
|
r.True(t, user.creds.IsCombinedAddressMode)
|
||||||
|
r.True(t, user.IsCombinedAddressMode())
|
||||||
|
|
||||||
|
// Mock change to split mode.
|
||||||
|
gomock.InOrder(
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||||
|
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||||
|
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||||
|
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||||
|
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentialsSplit, nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check switch to split mode.
|
||||||
|
r.NoError(t, user.SwitchAddressMode())
|
||||||
|
r.False(t, user.store.IsCombinedMode())
|
||||||
|
r.False(t, user.creds.IsCombinedAddressMode)
|
||||||
|
r.False(t, user.IsCombinedAddressMode())
|
||||||
|
|
||||||
|
// MOck change to combined mode.
|
||||||
|
gomock.InOrder(
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
|
||||||
|
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||||
|
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||||
|
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||||
|
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentials, nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check switch to combined mode.
|
||||||
|
r.NoError(t, user.SwitchAddressMode())
|
||||||
|
r.True(t, user.store.IsCombinedMode())
|
||||||
|
r.True(t, user.creds.IsCombinedAddressMode)
|
||||||
|
r.True(t, user.IsCombinedAddressMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogoutUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||||
|
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||||
|
)
|
||||||
|
|
||||||
|
err := user.Logout()
|
||||||
|
r.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogoutUserFailsLogout(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||||
|
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
|
||||||
|
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||||
|
)
|
||||||
|
|
||||||
|
err := user.Logout()
|
||||||
|
r.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBridgeLogin(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
|
||||||
|
r.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
|
||||||
|
|
||||||
|
isApplicationOutdated = true
|
||||||
|
|
||||||
|
err := user.CheckBridgeLogin("any-pass")
|
||||||
|
r.Equal(t, pmapi.ErrUpgradeApplication, err)
|
||||||
|
|
||||||
|
isApplicationOutdated = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
// Mock init of user.
|
||||||
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||||
|
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
|
||||||
|
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
||||||
|
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||||
|
|
||||||
|
// Mock CheckBridgeLogin.
|
||||||
|
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
|
||||||
|
)
|
||||||
|
|
||||||
|
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||||
|
r.NoError(t, err)
|
||||||
|
|
||||||
|
err = user.connect(m.pmapiClient, testCredentialsDisconnected)
|
||||||
|
r.Error(t, err)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
|
||||||
|
r.Equal(t, ErrLoggedOutUser, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBridgeLoginBadPassword(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
err := user.CheckBridgeLogin("wrong!")
|
||||||
|
r.EqualError(t, err, "backend/credentials: incorrect password")
|
||||||
|
}
|
||||||
88
internal/users/user_new_test.go
Normal file
88
internal/users/user_new_test.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewUserNoCredentialsStore(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
|
||||||
|
|
||||||
|
_, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||||
|
r.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewUserUnlockFails(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
// Init of user.
|
||||||
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||||
|
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
|
||||||
|
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
||||||
|
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("bad password")),
|
||||||
|
|
||||||
|
// Handle of unlock error.
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||||
|
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||||
|
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
|
||||||
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
||||||
|
)
|
||||||
|
|
||||||
|
checkNewUserHasCredentials(m, "failed to unlock user: bad password", testCredentialsDisconnected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
|
mockInitConnectedUser(m)
|
||||||
|
mockEventLoopNoAction(m)
|
||||||
|
|
||||||
|
checkNewUserHasCredentials(m, "", testCredentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkNewUserHasCredentials(m mocks, wantErr string, wantCreds *credentials.Credentials) {
|
||||||
|
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||||
|
r.NoError(m.t, err)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
err = user.connect(m.pmapiClient, testCredentials)
|
||||||
|
if wantErr == "" {
|
||||||
|
r.NoError(m.t, err)
|
||||||
|
} else {
|
||||||
|
r.EqualError(m.t, err, wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Equal(m.t, wantCreds, user.creds)
|
||||||
|
|
||||||
|
waitForEvents()
|
||||||
|
}
|
||||||
51
internal/users/user_store_test.go
Normal file
51
internal/users/user_store_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
|
||||||
|
r.Fail(t, "not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearStoreWithStore(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
r.Nil(t, user.store.Close())
|
||||||
|
user.store = nil
|
||||||
|
r.Nil(t, user.clearStore())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearStoreWithoutStore(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
user := testNewUser(m)
|
||||||
|
defer cleanUpUserData(user)
|
||||||
|
|
||||||
|
r.NotNil(t, user.store)
|
||||||
|
r.Nil(t, user.clearStore())
|
||||||
|
}
|
||||||
41
internal/users/user_test.go
Normal file
41
internal/users/user_test.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testNewUser sets up a new, authorised user.
|
||||||
|
func testNewUser(m mocks) *User {
|
||||||
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
|
mockInitConnectedUser(m)
|
||||||
|
mockEventLoopNoAction(m)
|
||||||
|
|
||||||
|
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||||
|
r.NoError(m.t, err)
|
||||||
|
|
||||||
|
err = user.connect(m.pmapiClient, creds)
|
||||||
|
r.NoError(m.t, err)
|
||||||
|
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanUpUserData(u *User) {
|
||||||
|
_ = u.clearStore()
|
||||||
|
}
|
||||||
@ -89,24 +89,42 @@ func New(
|
|||||||
lock: sync.RWMutex{},
|
lock: sync.RWMutex{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(conman): Handle force upgrade events.
|
|
||||||
/*
|
|
||||||
go func() {
|
go func() {
|
||||||
defer panicHandler.HandlePanic()
|
defer panicHandler.HandlePanic()
|
||||||
u.watchAppOutdated()
|
u.watchEvents()
|
||||||
}()
|
}()
|
||||||
*/
|
|
||||||
|
|
||||||
if u.credStorer == nil {
|
if u.credStorer == nil {
|
||||||
log.Error("No credentials store is available")
|
log.Error("No credentials store is available")
|
||||||
} else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil {
|
} else if err := u.loadUsersFromCredentialsStore(); err != nil {
|
||||||
log.WithError(err).Error("Could not load all users from credentials store")
|
log.WithError(err).Error("Could not load all users from credentials store")
|
||||||
}
|
}
|
||||||
|
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
func (u *Users) watchEvents() {
|
||||||
|
upgradeCh := u.events.ProvideChannel(events.UpgradeApplicationEvent)
|
||||||
|
internetOnCh := u.events.ProvideChannel(events.InternetOnEvent)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-upgradeCh:
|
||||||
|
isApplicationOutdated = true
|
||||||
|
u.closeAllConnections()
|
||||||
|
case <-internetOnCh:
|
||||||
|
for _, user := range u.users {
|
||||||
|
if user.store == nil {
|
||||||
|
if err := user.loadStore(); err != nil {
|
||||||
|
log.WithError(err).Error("Failed to load store after reconnecting")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Users) loadUsersFromCredentialsStore() error {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
defer u.lock.Unlock()
|
defer u.lock.Unlock()
|
||||||
|
|
||||||
@ -116,23 +134,26 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, userID := range userIDs {
|
for _, userID := range userIDs {
|
||||||
|
l := log.WithField("user", userID)
|
||||||
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
|
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Warn("Could not create user, skipping")
|
l.WithError(err).Warn("Could not create user, skipping")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
u.users = append(u.users, user)
|
u.users = append(u.users, user)
|
||||||
|
|
||||||
if creds.IsConnected() {
|
if creds.IsConnected() {
|
||||||
if err := u.loadConnectedUser(ctx, user, creds); err != nil {
|
// If there is no connection, we don't want to retry. Load should
|
||||||
logrus.WithError(err).Warn("Could not load connected user")
|
// happen fast enough to not block GUI. When connection is back up,
|
||||||
|
// watchEvents and unlockIfNecessary will finish user init later.
|
||||||
|
if err := u.loadConnectedUser(pmapi.ContextWithoutRetry(context.Background()), user, creds); err != nil {
|
||||||
|
l.WithError(err).Warn("Could not load connected user")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logrus.Warn("User is disconnected and must be connected manually")
|
l.Warn("User is disconnected and must be connected manually")
|
||||||
|
if err := user.connect(u.clientManager.NewClient("", "", "", time.Time{}), creds); err != nil {
|
||||||
if err := u.loadDisconnectedUser(ctx, user, creds); err != nil {
|
l.WithError(err).Warn("Could not load disconnected user")
|
||||||
logrus.WithError(err).Warn("Could not load disconnected user")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -140,11 +161,6 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
|
||||||
// FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor!
|
|
||||||
return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
||||||
uid, ref, err := creds.SplitAPIToken()
|
uid, ref, err := creds.SplitAPIToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -153,38 +169,27 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden
|
|||||||
|
|
||||||
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
|
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// FIXME(conman): This is a problem... if we weren't able to create a new client due to internet,
|
// When client cannot be refreshed right away due to no connection,
|
||||||
// we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff...
|
// we create client which will refresh automatically when possible.
|
||||||
return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
|
connectErr := user.connect(u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
|
||||||
|
|
||||||
|
switch errors.Cause(err) {
|
||||||
|
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
|
||||||
|
return connectErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if logoutErr := user.logout(); logoutErr != nil {
|
||||||
|
logrus.WithError(logoutErr).Warn("Could not logout user")
|
||||||
|
}
|
||||||
|
return errors.Wrap(err, "could not refresh token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the user's credentials with the latest auth used to connect this user.
|
// Update the user's credentials with the latest auth used to connect this user.
|
||||||
if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
|
if creds, err = u.credStorer.UpdateToken(creds.UserID, auth.UID, auth.RefreshToken); err != nil {
|
||||||
return errors.Wrap(err, "could not create get user's refresh token")
|
return errors.Wrap(err, "could not create get user's refresh token")
|
||||||
}
|
}
|
||||||
|
|
||||||
return user.connect(ctx, client, creds)
|
return user.connect(client, creds)
|
||||||
}
|
|
||||||
|
|
||||||
func (u *Users) watchAppOutdated() {
|
|
||||||
// FIXME(conman): handle force upgrade events.
|
|
||||||
|
|
||||||
/*
|
|
||||||
ch := make(chan string)
|
|
||||||
|
|
||||||
u.events.Add(events.UpgradeApplicationEvent, ch)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ch:
|
|
||||||
isApplicationOutdated = true
|
|
||||||
u.closeAllConnections()
|
|
||||||
|
|
||||||
case <-u.stopAll:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Users) closeAllConnections() {
|
func (u *Users) closeAllConnections() {
|
||||||
@ -198,19 +203,19 @@ func (u *Users) closeAllConnections() {
|
|||||||
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
|
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
|
||||||
u.crashBandicoot(username)
|
u.crashBandicoot(username)
|
||||||
|
|
||||||
return u.clientManager.NewClientWithLogin(context.TODO(), username, password)
|
return u.clientManager.NewClientWithLogin(context.Background(), username, password)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
||||||
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
|
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
|
||||||
apiUser, passphrase, err := getAPIUser(context.TODO(), client, password)
|
apiUser, passphrase, err := getAPIUser(context.Background(), client, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to get API user")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if user, ok := u.hasUser(apiUser.ID); ok {
|
if user, ok := u.hasUser(apiUser.ID); ok {
|
||||||
if user.IsConnected() {
|
if user.IsConnected() {
|
||||||
if err := client.AuthDelete(context.TODO()); err != nil {
|
if err := client.AuthDelete(context.Background()); err != nil {
|
||||||
logrus.WithError(err).Warn("Failed to delete new auth session")
|
logrus.WithError(err).Warn("Failed to delete new auth session")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,14 +233,16 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
|
|||||||
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
|
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.connect(context.TODO(), client, creds); err != nil {
|
if err := user.connect(client, creds); err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to reconnect existing user")
|
return nil, errors.Wrap(err, "failed to reconnect existing user")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil {
|
if err := u.addNewUser(client, apiUser, auth, passphrase); err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to add new user")
|
return nil, errors.Wrap(err, "failed to add new user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,7 +252,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addNewUser adds a new user.
|
// addNewUser adds a new user.
|
||||||
func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
|
func (u *Users) addNewUser(client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
defer u.lock.Unlock()
|
defer u.lock.Unlock()
|
||||||
|
|
||||||
@ -266,7 +273,7 @@ func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pm
|
|||||||
return errors.Wrap(err, "failed to create new user")
|
return errors.Wrap(err, "failed to create new user")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.connect(ctx, client, creds); err != nil {
|
if err := user.connect(client, creds); err != nil {
|
||||||
return errors.Wrap(err, "failed to connect new user")
|
return errors.Wrap(err, "failed to connect new user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,7 +299,7 @@ func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pma
|
|||||||
|
|
||||||
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
|
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
|
||||||
if err := client.Unlock(ctx, passphrase); err != nil {
|
if err := client.Unlock(ctx, passphrase); err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "failed to unlock client")
|
return nil, nil, ErrWrongMailboxPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := client.CurrentUser(ctx)
|
user, err := client.CurrentUser(ctx)
|
||||||
@ -414,22 +421,13 @@ func (u *Users) SendMetric(m metrics.Metric) error {
|
|||||||
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
|
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
|
||||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||||
func (u *Users) AllowProxy() {
|
func (u *Users) AllowProxy() {
|
||||||
// FIXME(conman): Support DoH.
|
u.clientManager.AllowProxy()
|
||||||
// u.apiManager.AllowProxy()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
|
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
|
||||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||||
func (u *Users) DisallowProxy() {
|
func (u *Users) DisallowProxy() {
|
||||||
// FIXME(conman): Support DoH.
|
u.clientManager.DisallowProxy()
|
||||||
// u.apiManager.DisallowProxy()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckConnection returns whether there is an internet connection.
|
|
||||||
// This should use the connection manager when it is eventually implemented.
|
|
||||||
func (u *Users) CheckConnection() error {
|
|
||||||
// FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer.
|
|
||||||
panic("TODO: register as a connection observer to get this information")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasUser returns whether the struct currently has a user with ID `id`.
|
// hasUser returns whether the struct currently has a user with ID `id`.
|
||||||
|
|||||||
49
internal/users/users_clear_test.go
Normal file
49
internal/users/users_clear_test.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClearData(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
users := testNewUsersWithUsers(t, m)
|
||||||
|
defer cleanUpUsersData(users)
|
||||||
|
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
|
||||||
|
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||||
|
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
||||||
|
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||||
|
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
|
||||||
|
|
||||||
|
m.locator.EXPECT().Clear()
|
||||||
|
|
||||||
|
r.NoError(t, users.ClearData())
|
||||||
|
}
|
||||||
69
internal/users/users_delete_test.go
Normal file
69
internal/users/users_delete_test.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeleteUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
users := testNewUsersWithUsers(t, m)
|
||||||
|
defer cleanUpUsersData(users)
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||||
|
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
|
||||||
|
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||||
|
)
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
|
|
||||||
|
err := users.DeleteUser("user", true)
|
||||||
|
r.NoError(t, err)
|
||||||
|
r.Equal(t, 1, len(users.users))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Even when logout fails, delete is done.
|
||||||
|
func TestDeleteUserWithFailingLogout(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
users := testNewUsersWithUsers(t, m)
|
||||||
|
defer cleanUpUsersData(users)
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||||
|
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
|
||||||
|
// Once called from user.Logout after failed creds.Logout as fallback, and once at the end of users.Logout.
|
||||||
|
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||||
|
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
|
|
||||||
|
err := users.DeleteUser("user", true)
|
||||||
|
r.NoError(t, err)
|
||||||
|
r.Equal(t, 1, len(users.users))
|
||||||
|
}
|
||||||
76
internal/users/users_get_test.go
Normal file
76
internal/users/users_get_test.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetNoUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserByID(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
checkUsersGetUser(t, m, "user", 0, "")
|
||||||
|
checkUsersGetUser(t, m, "users", 1, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserByName(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
checkUsersGetUser(t, m, "username", 0, "")
|
||||||
|
checkUsersGetUser(t, m, "usersname", 1, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserByEmail(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
checkUsersGetUser(t, m, "user@pm.me", 0, "")
|
||||||
|
checkUsersGetUser(t, m, "users@pm.me", 1, "")
|
||||||
|
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
|
||||||
|
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
|
||||||
|
users := testNewUsersWithUsers(t, m)
|
||||||
|
defer cleanUpUsersData(users)
|
||||||
|
|
||||||
|
user, err := users.GetUser(query)
|
||||||
|
|
||||||
|
if expectedError != "" {
|
||||||
|
r.EqualError(m.t, err, expectedError)
|
||||||
|
} else {
|
||||||
|
r.NoError(m.t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedUser *User
|
||||||
|
if index >= 0 {
|
||||||
|
expectedUser = users.users[index]
|
||||||
|
}
|
||||||
|
r.Equal(m.t, expectedUser, user)
|
||||||
|
}
|
||||||
132
internal/users/users_login_test.go
Normal file
132
internal/users/users_login_test.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
// Init users with no user from keychain.
|
||||||
|
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||||
|
|
||||||
|
// Set up mocks for FinishLogin.
|
||||||
|
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil)
|
||||||
|
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked"))
|
||||||
|
|
||||||
|
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsersFinishLoginNewUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
// Init users with no user from keychain.
|
||||||
|
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||||
|
|
||||||
|
mockAddingConnectedUser(m)
|
||||||
|
mockEventLoopNoAction(m)
|
||||||
|
|
||||||
|
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
|
||||||
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentials.UserID)
|
||||||
|
|
||||||
|
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, testCredentials.UserID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
// Mock loading disconnected user.
|
||||||
|
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
|
||||||
|
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
|
||||||
|
|
||||||
|
// Mock process of FinishLogin of already added user.
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
|
||||||
|
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||||
|
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUserDisconnected, nil),
|
||||||
|
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
|
||||||
|
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
|
||||||
|
)
|
||||||
|
mockInitConnectedUser(m)
|
||||||
|
mockEventLoopNoAction(m)
|
||||||
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
|
||||||
|
|
||||||
|
authRefresh := &pmapi.Auth{
|
||||||
|
UserID: testCredentialsDisconnected.UserID,
|
||||||
|
AuthRefresh: pmapi.AuthRefresh{
|
||||||
|
UID: "uid",
|
||||||
|
AccessToken: "acc",
|
||||||
|
RefreshToken: "ref",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
checkUsersFinishLogin(t, m, authRefresh, testCredentials.MailboxPassword, testCredentialsDisconnected.UserID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsersFinishLoginConnectedUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
// Mock loading connected user.
|
||||||
|
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
|
||||||
|
mockLoadingConnectedUser(m, testCredentials)
|
||||||
|
mockEventLoopNoAction(m)
|
||||||
|
|
||||||
|
// Mock process of FinishLogin of already connected user.
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
|
||||||
|
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||||
|
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
users := testNewUsers(t, m)
|
||||||
|
defer cleanUpUsersData(users)
|
||||||
|
|
||||||
|
_, err := users.FinishLogin(m.pmapiClient, testAuthRefresh, testCredentials.MailboxPassword)
|
||||||
|
r.EqualError(t, err, "user is already connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) {
|
||||||
|
users := testNewUsers(t, m)
|
||||||
|
defer cleanUpUsersData(users)
|
||||||
|
|
||||||
|
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
|
||||||
|
|
||||||
|
r.Equal(t, expectedErr, err)
|
||||||
|
|
||||||
|
if expectedUserID != "" {
|
||||||
|
r.Equal(t, expectedUserID, user.ID())
|
||||||
|
r.Equal(t, 1, len(users.users))
|
||||||
|
r.Equal(t, expectedUserID, users.users[0].ID())
|
||||||
|
} else {
|
||||||
|
r.Equal(t, (*User)(nil), user)
|
||||||
|
r.Equal(t, 0, len(users.users))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -22,9 +22,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
time "time"
|
time "time"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
r "github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewUsersNoKeychain(t *testing.T) {
|
func TestNewUsersNoKeychain(t *testing.T) {
|
||||||
@ -32,7 +33,6 @@ func TestNewUsersNoKeychain(t *testing.T) {
|
|||||||
defer m.ctrl.Finish()
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain"))
|
m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain"))
|
||||||
|
|
||||||
checkUsersNew(t, m, []*credentials.Credentials{})
|
checkUsersNew(t, m, []*credentials.Credentials{})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,108 +41,73 @@ func TestNewUsersWithoutUsersInCredentialsStore(t *testing.T) {
|
|||||||
defer m.ctrl.Finish()
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||||
|
|
||||||
checkUsersNew(t, m, []*credentials.Credentials{})
|
checkUsersNew(t, m, []*credentials.Credentials{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewUsersWithDisconnectedUser(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
|
||||||
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
|
|
||||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
|
||||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
|
||||||
m := initMocks(t)
|
|
||||||
defer m.ctrl.Finish()
|
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
|
||||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
|
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
|
||||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
|
||||||
m.pmapiClient.EXPECT().Logout()
|
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
|
||||||
|
|
||||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewUsersWithConnectedUser(t *testing.T) {
|
func TestNewUsersWithConnectedUser(t *testing.T) {
|
||||||
m := initMocks(t)
|
m := initMocks(t)
|
||||||
defer m.ctrl.Finish()
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
mockLoadingConnectedUser(m, testCredentials)
|
||||||
|
|
||||||
mockConnectedUser(m)
|
|
||||||
mockEventLoopNoAction(m)
|
mockEventLoopNoAction(m)
|
||||||
|
|
||||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
|
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewUsersWithDisconnectedUser(t *testing.T) {
|
||||||
|
m := initMocks(t)
|
||||||
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
|
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
|
||||||
|
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
|
||||||
|
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||||
|
}
|
||||||
|
|
||||||
// Tests two users with different states and checks also the order from
|
// Tests two users with different states and checks also the order from
|
||||||
// credentials store is kept also in array of users.
|
// credentials store is kept also in array of users.
|
||||||
func TestNewUsersWithUsers(t *testing.T) {
|
func TestNewUsersWithUsers(t *testing.T) {
|
||||||
m := initMocks(t)
|
m := initMocks(t)
|
||||||
defer m.ctrl.Finish()
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil)
|
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
|
||||||
|
mockLoadingConnectedUser(m, testCredentials)
|
||||||
gomock.InOrder(
|
|
||||||
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
|
|
||||||
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
|
|
||||||
// Set up mocks for store initialisation for the unauth user.
|
|
||||||
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
|
|
||||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
|
||||||
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
|
||||||
)
|
|
||||||
|
|
||||||
mockConnectedUser(m)
|
|
||||||
|
|
||||||
mockEventLoopNoAction(m)
|
mockEventLoopNoAction(m)
|
||||||
|
|
||||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
|
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewUsersFirstStart(t *testing.T) {
|
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||||
m := initMocks(t)
|
m := initMocks(t)
|
||||||
defer m.ctrl.Finish()
|
defer m.ctrl.Finish()
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token"))
|
||||||
|
m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient)
|
||||||
|
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
|
||||||
|
m.pmapiClient.EXPECT().IsUnlocked().Return(false)
|
||||||
|
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("not authorized"))
|
||||||
|
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||||
|
|
||||||
testNewUsers(t, m)
|
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||||
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||||
|
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
||||||
|
|
||||||
|
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||||
|
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||||
|
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||||
|
|
||||||
|
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
|
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
|
||||||
users := testNewUsers(t, m)
|
users := testNewUsers(t, m)
|
||||||
defer cleanUpUsersData(users)
|
defer cleanUpUsersData(users)
|
||||||
|
|
||||||
assert.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
|
r.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
|
||||||
|
|
||||||
credentials := []*credentials.Credentials{}
|
credentials := []*credentials.Credentials{}
|
||||||
for _, user := range users.users {
|
for _, user := range users.users {
|
||||||
credentials = append(credentials, user.creds)
|
credentials = append(credentials, user.creds)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(m.t, expectedCredentials, credentials)
|
r.Equal(m.t, expectedCredentials, credentials)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,8 +33,9 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/require"
|
r "github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@ -49,9 +50,12 @@ func TestMain(m *testing.M) {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
|
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||||
|
UserID: "user",
|
||||||
|
AuthRefresh: pmapi.AuthRefresh{
|
||||||
UID: "uid",
|
UID: "uid",
|
||||||
AccessToken: "acc",
|
AccessToken: "acc",
|
||||||
RefreshToken: "ref",
|
RefreshToken: "ref",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
|
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||||
@ -81,7 +85,7 @@ var (
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||||
UserID: "user",
|
UserID: "userDisconnected",
|
||||||
Name: "username",
|
Name: "username",
|
||||||
Emails: "user@pm.me",
|
Emails: "user@pm.me",
|
||||||
APIToken: "",
|
APIToken: "",
|
||||||
@ -94,7 +98,7 @@ var (
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||||
UserID: "users",
|
UserID: "usersDisconnected",
|
||||||
Name: "usersname",
|
Name: "usersname",
|
||||||
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
||||||
APIToken: "",
|
APIToken: "",
|
||||||
@ -111,17 +115,22 @@ var (
|
|||||||
Name: "username",
|
Name: "username",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
testPMAPIUserDisconnected = &pmapi.User{ //nolint[gochecknoglobals]
|
||||||
|
ID: "userDisconnected",
|
||||||
|
Name: "username",
|
||||||
|
}
|
||||||
|
|
||||||
testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
|
testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
|
||||||
ID: "testAddressID",
|
ID: "testAddressID",
|
||||||
Type: pmapi.OriginalAddress,
|
Type: pmapi.OriginalAddress,
|
||||||
Email: "user@pm.me",
|
Email: "user@pm.me",
|
||||||
Receive: pmapi.CanReceive,
|
Receive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
|
testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
|
||||||
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: pmapi.CanReceive, Type: pmapi.OriginalAddress},
|
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: true, Type: pmapi.OriginalAddress},
|
||||||
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
|
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: true, Type: pmapi.AliasAddress},
|
||||||
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
|
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: true, Type: pmapi.AliasAddress},
|
||||||
}
|
}
|
||||||
|
|
||||||
testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals]
|
testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals]
|
||||||
@ -129,15 +138,6 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func waitForEvents() {
|
|
||||||
// Wait for goroutine to add listener.
|
|
||||||
// E.g. calling login to invoke firstsync event. Functions can end sooner than
|
|
||||||
// goroutines call the listener mock. We need to wait a little bit before the end of
|
|
||||||
// the test to capture all event calls. This allows us to detect whether there were
|
|
||||||
// missing calls, or perhaps whether something was called too many times.
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
type mocks struct {
|
type mocks struct {
|
||||||
t *testing.T
|
t *testing.T
|
||||||
|
|
||||||
@ -146,7 +146,7 @@ type mocks struct {
|
|||||||
PanicHandler *usersmocks.MockPanicHandler
|
PanicHandler *usersmocks.MockPanicHandler
|
||||||
credentialsStore *usersmocks.MockCredentialsStorer
|
credentialsStore *usersmocks.MockCredentialsStorer
|
||||||
storeMaker *usersmocks.MockStoreMaker
|
storeMaker *usersmocks.MockStoreMaker
|
||||||
eventListener *MockListener
|
eventListener *usersmocks.MockListener
|
||||||
|
|
||||||
clientManager *pmapimocks.MockManager
|
clientManager *pmapimocks.MockManager
|
||||||
pmapiClient *pmapimocks.MockClient
|
pmapiClient *pmapimocks.MockClient
|
||||||
@ -154,6 +154,48 @@ type mocks struct {
|
|||||||
storeCache *store.Cache
|
storeCache *store.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initMocks(t *testing.T) mocks {
|
||||||
|
var mockCtrl *gomock.Controller
|
||||||
|
if os.Getenv("VERBOSITY") == "trace" {
|
||||||
|
mockCtrl = gomock.NewController(&fullStackReporter{t})
|
||||||
|
} else {
|
||||||
|
mockCtrl = gomock.NewController(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
|
||||||
|
r.NoError(t, err, "could not get temporary file for store cache")
|
||||||
|
|
||||||
|
m := mocks{
|
||||||
|
t: t,
|
||||||
|
|
||||||
|
ctrl: mockCtrl,
|
||||||
|
locator: usersmocks.NewMockLocator(mockCtrl),
|
||||||
|
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
|
||||||
|
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
|
||||||
|
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
|
||||||
|
eventListener: usersmocks.NewMockListener(mockCtrl),
|
||||||
|
|
||||||
|
clientManager: pmapimocks.NewMockManager(mockCtrl),
|
||||||
|
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||||
|
|
||||||
|
storeCache: store.NewCache(cacheFile.Name()),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called during clean-up.
|
||||||
|
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
|
||||||
|
|
||||||
|
// Set up store factory.
|
||||||
|
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
|
||||||
|
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
|
||||||
|
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
|
||||||
|
r.NoError(t, err, "could not get temporary file for store db")
|
||||||
|
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
|
||||||
|
}).AnyTimes()
|
||||||
|
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
type fullStackReporter struct {
|
type fullStackReporter struct {
|
||||||
T testing.TB
|
T testing.TB
|
||||||
}
|
}
|
||||||
@ -168,86 +210,18 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
|
|||||||
fr.T.FailNow()
|
fr.T.FailNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
func initMocks(t *testing.T) mocks {
|
|
||||||
var mockCtrl *gomock.Controller
|
|
||||||
if os.Getenv("VERBOSITY") == "trace" {
|
|
||||||
mockCtrl = gomock.NewController(&fullStackReporter{t})
|
|
||||||
} else {
|
|
||||||
mockCtrl = gomock.NewController(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
|
|
||||||
require.NoError(t, err, "could not get temporary file for store cache")
|
|
||||||
|
|
||||||
m := mocks{
|
|
||||||
t: t,
|
|
||||||
|
|
||||||
ctrl: mockCtrl,
|
|
||||||
locator: usersmocks.NewMockLocator(mockCtrl),
|
|
||||||
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
|
|
||||||
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
|
|
||||||
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
|
|
||||||
eventListener: NewMockListener(mockCtrl),
|
|
||||||
|
|
||||||
clientManager: pmapimocks.NewMockManager(mockCtrl),
|
|
||||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
|
||||||
|
|
||||||
storeCache: store.NewCache(cacheFile.Name()),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Called during clean-up.
|
|
||||||
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
|
|
||||||
|
|
||||||
// Set up store factory.
|
|
||||||
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
|
|
||||||
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
|
|
||||||
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
|
|
||||||
require.NoError(t, err, "could not get temporary file for store db")
|
|
||||||
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
|
|
||||||
}).AnyTimes()
|
|
||||||
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
|
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
|
||||||
// Events are asynchronous
|
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
|
||||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2)
|
mockLoadingConnectedUser(m, testCredentials)
|
||||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
|
mockLoadingConnectedUser(m, testCredentialsSplit)
|
||||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
|
mockEventLoopNoAction(m)
|
||||||
|
|
||||||
gomock.InOrder(
|
|
||||||
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
|
|
||||||
|
|
||||||
// Init for user.
|
|
||||||
m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
|
|
||||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
|
||||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
|
||||||
m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
|
|
||||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
|
|
||||||
// Init for users.
|
|
||||||
m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil),
|
|
||||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
|
||||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
|
||||||
m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil),
|
|
||||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil),
|
|
||||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
|
||||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
|
|
||||||
)
|
|
||||||
|
|
||||||
return testNewUsers(t, m)
|
return testNewUsers(t, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
|
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
|
||||||
// FIXME(conman): How to handle force upgrade?
|
m.eventListener.EXPECT().ProvideChannel(events.UpgradeApplicationEvent)
|
||||||
// m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
m.eventListener.EXPECT().ProvideChannel(events.InternetOnEvent)
|
||||||
|
|
||||||
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
|
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
|
||||||
|
|
||||||
@ -256,38 +230,84 @@ func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
|
|||||||
return users
|
return users
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func waitForEvents() {
|
||||||
|
// Wait for goroutine to add listener.
|
||||||
|
// E.g. calling login to invoke firstsync event. Functions can end sooner than
|
||||||
|
// goroutines call the listener mock. We need to wait a little bit before the end of
|
||||||
|
// the test to capture all event calls. This allows us to detect whether there were
|
||||||
|
// missing calls, or perhaps whether something was called too many times.
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
func cleanUpUsersData(b *Users) {
|
func cleanUpUsersData(b *Users) {
|
||||||
for _, user := range b.users {
|
for _, user := range b.users {
|
||||||
_ = user.clearStore()
|
_ = user.clearStore()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClearData(t *testing.T) {
|
func mockAddingConnectedUser(m mocks) {
|
||||||
m := initMocks(t)
|
gomock.InOrder(
|
||||||
defer m.ctrl.Finish()
|
// Mock of users.FinishLogin.
|
||||||
|
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
|
||||||
|
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||||
|
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
|
||||||
|
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||||
|
m.credentialsStore.EXPECT().Add("user", "username", testAuthRefresh.UID, testAuthRefresh.RefreshToken, testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
|
||||||
|
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||||
|
)
|
||||||
|
|
||||||
// m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
mockInitConnectedUser(m)
|
||||||
// m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
}
|
||||||
|
|
||||||
users := testNewUsersWithUsers(t, m)
|
func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
|
||||||
defer cleanUpUsersData(users)
|
authRefresh := &pmapi.AuthRefresh{
|
||||||
|
UID: "uid",
|
||||||
|
AccessToken: "acc",
|
||||||
|
RefreshToken: "ref",
|
||||||
|
}
|
||||||
|
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
gomock.InOrder(
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
|
// Mock of users.loadUsersFromCredentialsStore.
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
|
m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
|
||||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
|
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, authRefresh, nil),
|
||||||
|
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
|
||||||
|
)
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
mockInitConnectedUser(m)
|
||||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
}
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
func mockInitConnectedUser(m mocks) {
|
||||||
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
|
// Mock of user initialisation.
|
||||||
|
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
|
||||||
|
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
|
||||||
|
|
||||||
m.locator.EXPECT().Clear()
|
// Mock of store initialisation.
|
||||||
|
gomock.InOrder(
|
||||||
|
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||||
|
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||||
|
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
require.NoError(t, users.ClearData())
|
func mockLoadingDisconnectedUser(m mocks, creds *credentials.Credentials) {
|
||||||
|
gomock.InOrder(
|
||||||
|
// Mock of users.loadUsersFromCredentialsStore.
|
||||||
|
m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
|
||||||
|
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
|
||||||
|
)
|
||||||
|
|
||||||
waitForEvents()
|
mockInitDisconnectedUser(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockInitDisconnectedUser(m mocks) {
|
||||||
|
gomock.InOrder(
|
||||||
|
// Mock of user initialisation.
|
||||||
|
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
|
||||||
|
|
||||||
|
// Mock of store initialisation for the unauthorized user.
|
||||||
|
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
||||||
|
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockEventLoopNoAction(m mocks) {
|
func mockEventLoopNoAction(m mocks) {
|
||||||
@ -297,19 +317,3 @@ func mockEventLoopNoAction(m mocks) {
|
|||||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockConnectedUser(m mocks) {
|
|
||||||
gomock.InOrder(
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
|
|
||||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
|
||||||
// m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil),
|
|
||||||
|
|
||||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
|
||||||
|
|
||||||
// Set up mocks for store initialisation for the authorized user.
|
|
||||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
|
||||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -105,6 +105,10 @@ func (h *macOSHelper) Get(secretURL string) (string, string, error) {
|
|||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(results) == 0 {
|
||||||
|
return "", "", errors.New("no result")
|
||||||
|
}
|
||||||
|
|
||||||
if len(results) != 1 {
|
if len(results) != 1 {
|
||||||
return "", "", errors.New("ambiguous results")
|
return "", "", errors.New("ambiguous results")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,6 +29,7 @@ var log = logrus.WithField("pkg", "bridgeUtils/listener") //nolint[gochecknoglob
|
|||||||
// Listener has a list of channels watching for updates.
|
// Listener has a list of channels watching for updates.
|
||||||
type Listener interface {
|
type Listener interface {
|
||||||
SetLimit(eventName string, limit time.Duration)
|
SetLimit(eventName string, limit time.Duration)
|
||||||
|
ProvideChannel(eventName string) <-chan string
|
||||||
Add(eventName string, channel chan<- string)
|
Add(eventName string, channel chan<- string)
|
||||||
Remove(eventName string, channel chan<- string)
|
Remove(eventName string, channel chan<- string)
|
||||||
Emit(eventName string, data string)
|
Emit(eventName string, data string)
|
||||||
@ -69,6 +70,15 @@ func (l *listener) SetLimit(eventName string, limit time.Duration) {
|
|||||||
l.limits[eventName] = limit
|
l.limits[eventName] = limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideChannel creates new channel, adds it to listener and sends to it
|
||||||
|
// bufferent events.
|
||||||
|
func (l *listener) ProvideChannel(eventName string) <-chan string {
|
||||||
|
ch := make(chan string)
|
||||||
|
l.Add(eventName, ch)
|
||||||
|
l.RetryEmit(eventName)
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
// Add adds an event listener.
|
// Add adds an event listener.
|
||||||
func (l *listener) Add(eventName string, channel chan<- string) {
|
func (l *listener) Add(eventName string, channel chan<- string) {
|
||||||
l.lock.Lock()
|
l.lock.Lock()
|
||||||
|
|||||||
@ -97,7 +97,7 @@ func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachW
|
|||||||
}
|
}
|
||||||
|
|
||||||
func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) {
|
func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) {
|
||||||
msg, err := req.api.GetMessage(req.messageID)
|
msg, err := req.api.GetMessage(req.ctx, req.messageID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@ -109,7 +109,7 @@ func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
process := func(value interface{}) (interface{}, error) {
|
process := func(value interface{}) (interface{}, error) {
|
||||||
rc, err := req.api.GetAttachment(value.(string))
|
rc, err := req.api.GetAttachment(req.ctx, value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,10 +43,10 @@ func newTestFetcher(
|
|||||||
) Fetcher {
|
) Fetcher {
|
||||||
f := mocks.NewMockFetcher(m)
|
f := mocks.NewMockFetcher(m)
|
||||||
|
|
||||||
f.EXPECT().GetMessage(msg.ID).Return(msg, nil)
|
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
|
||||||
|
|
||||||
for i, att := range msg.Attachments {
|
for i, att := range msg.Attachments {
|
||||||
f.EXPECT().GetAttachment(att.ID).Return(newTestReadCloser(attData[i]), nil)
|
f.EXPECT().GetAttachment(gomock.Any(), att.ID).Return(newTestReadCloser(attData[i]), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(kr, nil)
|
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(kr, nil)
|
||||||
|
|||||||
@ -1230,7 +1230,7 @@ func TestBuildFetchMessageFail(t *testing.T) {
|
|||||||
|
|
||||||
// Pretend the message cannot be fetched.
|
// Pretend the message cannot be fetched.
|
||||||
f := mocks.NewMockFetcher(m)
|
f := mocks.NewMockFetcher(m)
|
||||||
f.EXPECT().GetMessage(msg.ID).Return(nil, errors.New("oops"))
|
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(nil, errors.New("oops"))
|
||||||
|
|
||||||
// The job should fail, returning an error and a nil result.
|
// The job should fail, returning an error and a nil result.
|
||||||
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
||||||
@ -1251,8 +1251,8 @@ func TestBuildFetchAttachmentFail(t *testing.T) {
|
|||||||
|
|
||||||
// Pretend the attachment cannot be fetched.
|
// Pretend the attachment cannot be fetched.
|
||||||
f := mocks.NewMockFetcher(m)
|
f := mocks.NewMockFetcher(m)
|
||||||
f.EXPECT().GetMessage(msg.ID).Return(msg, nil)
|
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
|
||||||
f.EXPECT().GetAttachment(msg.Attachments[0].ID).Return(nil, errors.New("oops"))
|
f.EXPECT().GetAttachment(gomock.Any(), msg.Attachments[0].ID).Return(nil, errors.New("oops"))
|
||||||
|
|
||||||
// The job should fail, returning an error and a nil result.
|
// The job should fail, returning an error and a nil result.
|
||||||
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
||||||
@ -1272,7 +1272,7 @@ func TestBuildNoSuchKeyRing(t *testing.T) {
|
|||||||
|
|
||||||
// Pretend there is no available keyring.
|
// Pretend there is no available keyring.
|
||||||
f := mocks.NewMockFetcher(m)
|
f := mocks.NewMockFetcher(m)
|
||||||
f.EXPECT().GetMessage(msg.ID).Return(msg, nil)
|
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
|
||||||
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops"))
|
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops"))
|
||||||
|
|
||||||
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
||||||
|
|||||||
@ -31,7 +31,7 @@ const (
|
|||||||
|
|
||||||
// GetFlags returns imap flags from pmapi message attributes.
|
// GetFlags returns imap flags from pmapi message attributes.
|
||||||
func GetFlags(m *pmapi.Message) (flags []string) {
|
func GetFlags(m *pmapi.Message) (flags []string) {
|
||||||
if m.Unread == 0 {
|
if !m.Unread {
|
||||||
flags = append(flags, imap.SeenFlag)
|
flags = append(flags, imap.SeenFlag)
|
||||||
}
|
}
|
||||||
if !m.Has(pmapi.FlagSent) && !m.Has(pmapi.FlagReceived) {
|
if !m.Has(pmapi.FlagSent) && !m.Has(pmapi.FlagReceived) {
|
||||||
@ -68,11 +68,11 @@ func ParseFlags(m *pmapi.Message, flags []string) {
|
|||||||
m.Flags = pmapi.FlagReceived
|
m.Flags = pmapi.FlagReceived
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Unread = 1
|
m.Unread = true
|
||||||
for _, f := range flags {
|
for _, f := range flags {
|
||||||
switch f {
|
switch f {
|
||||||
case imap.SeenFlag:
|
case imap.SeenFlag:
|
||||||
m.Unread = 0
|
m.Unread = false
|
||||||
case imap.DraftFlag:
|
case imap.DraftFlag:
|
||||||
m.Flags &= ^pmapi.FlagSent
|
m.Flags &= ^pmapi.FlagSent
|
||||||
m.Flags &= ^pmapi.FlagReceived
|
m.Flags &= ^pmapi.FlagReceived
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context "context"
|
||||||
io "io"
|
io "io"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
@ -37,33 +38,33 @@ func (m *MockFetcher) EXPECT() *MockFetcherMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAttachment mocks base method
|
// GetAttachment mocks base method
|
||||||
func (m *MockFetcher) GetAttachment(arg0 string) (io.ReadCloser, error) {
|
func (m *MockFetcher) GetAttachment(arg0 context.Context, arg1 string) (io.ReadCloser, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetAttachment", arg0)
|
ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1)
|
||||||
ret0, _ := ret[0].(io.ReadCloser)
|
ret0, _ := ret[0].(io.ReadCloser)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAttachment indicates an expected call of GetAttachment
|
// GetAttachment indicates an expected call of GetAttachment
|
||||||
func (mr *MockFetcherMockRecorder) GetAttachment(arg0 interface{}) *gomock.Call {
|
func (mr *MockFetcherMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessage mocks base method
|
// GetMessage mocks base method
|
||||||
func (m *MockFetcher) GetMessage(arg0 string) (*pmapi.Message, error) {
|
func (m *MockFetcher) GetMessage(arg0 context.Context, arg1 string) (*pmapi.Message, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetMessage", arg0)
|
ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
|
||||||
ret0, _ := ret[0].(*pmapi.Message)
|
ret0, _ := ret[0].(*pmapi.Message)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessage indicates an expected call of GetMessage
|
// GetMessage indicates an expected call of GetMessage
|
||||||
func (mr *MockFetcherMockRecorder) GetMessage(arg0 interface{}) *gomock.Call {
|
func (mr *MockFetcherMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeyRingForAddressID mocks base method
|
// KeyRingForAddressID mocks base method
|
||||||
|
|||||||
@ -191,7 +191,7 @@ func DecodeHeader(raw string) (decoded string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncodeHeader using quoted printable and utf8
|
// EncodeHeader using quoted printable and utf8.
|
||||||
func EncodeHeader(s string) string {
|
func EncodeHeader(s string) string {
|
||||||
return mime.QEncoding.Encode("utf-8", s)
|
return mime.QEncoding.Encode("utf-8", s)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,7 +19,6 @@ package pmmime
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
//"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
|||||||
@ -32,12 +32,6 @@ const (
|
|||||||
EnabledAddress
|
EnabledAddress
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address receive values.
|
|
||||||
const (
|
|
||||||
CannotReceive = iota
|
|
||||||
CanReceive
|
|
||||||
)
|
|
||||||
|
|
||||||
// Address HasKeys values.
|
// Address HasKeys values.
|
||||||
const (
|
const (
|
||||||
MissingKeys = iota
|
MissingKeys = iota
|
||||||
@ -66,7 +60,7 @@ type Address struct {
|
|||||||
DomainID string
|
DomainID string
|
||||||
Email string
|
Email string
|
||||||
Send int
|
Send int
|
||||||
Receive int
|
Receive Boolean
|
||||||
Status int
|
Status int
|
||||||
Order int `json:",omitempty"`
|
Order int `json:",omitempty"`
|
||||||
Type int
|
Type int
|
||||||
@ -103,7 +97,7 @@ func (l AddressList) AllEmails() (addresses []string) {
|
|||||||
// ActiveEmails returns only active emails.
|
// ActiveEmails returns only active emails.
|
||||||
func (l AddressList) ActiveEmails() (addresses []string) {
|
func (l AddressList) ActiveEmails() (addresses []string) {
|
||||||
for _, a := range l {
|
for _, a := range l {
|
||||||
if a.Receive == CanReceive {
|
if a.Receive {
|
||||||
addresses = append(addresses, a.Email)
|
addresses = append(addresses, a.Email)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -175,8 +169,19 @@ func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err e
|
|||||||
return res.Addresses, nil
|
return res.Addresses, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) (err error) {
|
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) error {
|
||||||
panic("TODO")
|
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||||
|
return r.SetBody(&struct {
|
||||||
|
AddressIDs []string
|
||||||
|
}{
|
||||||
|
AddressIDs: addressIDs,
|
||||||
|
}).Put("/addresses/order")
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.UpdateUser(ctx)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Addresses returns the addresses stored in the client object itself rather than fetching from the API.
|
// Addresses returns the addresses stored in the client object itself rather than fetching from the API.
|
||||||
@ -185,24 +190,22 @@ func (c *client) Addresses() AddressList {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// unlockAddresses unlocks all keys for all addresses of current user.
|
// unlockAddresses unlocks all keys for all addresses of current user.
|
||||||
func (c *client) unlockAddress(passphrase []byte, address *Address) (err error) {
|
func (c *client) unlockAddress(passphrase []byte, address *Address) error {
|
||||||
if address == nil {
|
if address == nil {
|
||||||
return errors.New("address data is missing")
|
return errors.New("address data is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if address.HasKeys == MissingKeys {
|
if address.HasKeys == MissingKeys {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var kr *crypto.KeyRing
|
kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
|
||||||
|
if err != nil {
|
||||||
if kr, err = address.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil {
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.addrKeyRing[address.ID] = kr
|
c.addrKeyRing[address.ID] = kr
|
||||||
|
return nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
|
func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
|
||||||
|
|||||||
@ -20,6 +20,8 @@ package pmapi
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testAddressList = AddressList{
|
var testAddressList = AddressList{
|
||||||
@ -46,39 +48,29 @@ var testAddressList = AddressList{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||||
Ok(tb, checkMethodAndPath(r, "GET", "/addresses"))
|
r.NoError(tb, checkMethodAndPath(req, "GET", "/addresses"))
|
||||||
Ok(tb, isAuthReq(r, testUID, testAccessToken))
|
r.NoError(tb, isAuthReq(req, testUID, testAccessToken))
|
||||||
return "addresses/get_response.json"
|
return "addresses/get_response.json"
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddressList(t *testing.T) {
|
func TestAddressList(t *testing.T) {
|
||||||
input := "1"
|
input := "1"
|
||||||
addr := testAddressList.ByID(input)
|
addr := testAddressList.ByID(input)
|
||||||
if addr != testAddressList[0] {
|
r.Equal(t, testAddressList[0], addr)
|
||||||
t.Errorf("ById(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[0], addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
input = "42"
|
input = "42"
|
||||||
addr = testAddressList.ByID(input)
|
addr = testAddressList.ByID(input)
|
||||||
if addr != nil {
|
r.Nil(t, addr)
|
||||||
t.Errorf("ById expected nil for %s but have : %v\n", input, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
input = "root@protonmail.com"
|
input = "root@protonmail.com"
|
||||||
addr = testAddressList.ByEmail(input)
|
addr = testAddressList.ByEmail(input)
|
||||||
if addr != testAddressList[2] {
|
r.Equal(t, testAddressList[2], addr)
|
||||||
t.Errorf("ByEmail(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[2], addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
input = "idontexist@protonmail.com"
|
input = "idontexist@protonmail.com"
|
||||||
addr = testAddressList.ByEmail(input)
|
addr = testAddressList.ByEmail(input)
|
||||||
if addr != nil {
|
r.Nil(t, addr)
|
||||||
t.Errorf("ByEmail expected nil for %s but have : %v\n", input, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr = testAddressList.Main()
|
addr = testAddressList.Main()
|
||||||
if addr != testAddressList[1] {
|
r.Equal(t, testAddressList[1], addr)
|
||||||
t.Errorf("Main() expected:\n%v\n but have:\n%v\n", testAddressList[1], addr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,7 +23,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
|
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
@ -138,44 +137,6 @@ func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io.
|
|||||||
return signAttachment(kr, att)
|
return signAttachment(kr, att)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) {
|
|
||||||
// Create metadata fields.
|
|
||||||
if err = w.WriteField("Filename", att.Name); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = w.WriteField("MessageID", att.MessageID); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = w.WriteField("MIMEType", att.MIMEType); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = w.WriteField("ContentID", att.ContentID); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// And send attachment data.
|
|
||||||
ff, err := w.CreateFormFile("DataPacket", "DataPacket.pgp")
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err = io.Copy(ff, r); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// And send attachment data.
|
|
||||||
sigff, err := w.CreateFormFile("Signature", "Signature.pgp")
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = io.Copy(sigff, sig); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
|
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
|
||||||
//
|
//
|
||||||
// The returned created attachment contains the new attachment ID and its size.
|
// The returned created attachment contains the new attachment ID and its size.
|
||||||
|
|||||||
@ -28,13 +28,13 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
pmmime "github.com/ProtonMail/proton-bridge/pkg/mime"
|
pmmime "github.com/ProtonMail/proton-bridge/pkg/mime"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
a "github.com/stretchr/testify/assert"
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testAttachment = &Attachment{
|
var testAttachment = &Attachment{
|
||||||
@ -77,65 +77,40 @@ const testCreateAttachmentBody = `{
|
|||||||
"Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="}
|
"Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
const testDeleteAttachmentBody = `{
|
|
||||||
"Code": 1000
|
|
||||||
}`
|
|
||||||
|
|
||||||
func TestAttachment_UnmarshalJSON(t *testing.T) {
|
func TestAttachment_UnmarshalJSON(t *testing.T) {
|
||||||
att := new(Attachment)
|
att := new(Attachment)
|
||||||
if err := json.Unmarshal([]byte(testAttachmentJSON), att); err != nil {
|
err := json.Unmarshal([]byte(testAttachmentJSON), att)
|
||||||
t.Fatal("Expected no error while unmarshaling JSON, got:", err)
|
r.NoError(t, err)
|
||||||
}
|
|
||||||
|
|
||||||
att.MessageID = testAttachment.MessageID // This isn't in the JSON object
|
att.MessageID = testAttachment.MessageID // This isn't in the JSON object
|
||||||
|
|
||||||
if !reflect.DeepEqual(testAttachment, att) {
|
r.Equal(t, testAttachment, att)
|
||||||
t.Errorf("Invalid attachment: expected %+v but got %+v", testAttachment, att)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_CreateAttachment(t *testing.T) {
|
func TestClient_CreateAttachment(t *testing.T) {
|
||||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments"))
|
r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/attachments"))
|
||||||
|
|
||||||
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
|
contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type"))
|
||||||
if err != nil {
|
r.NoError(t, err)
|
||||||
t.Error("Expected no error while parsing request content type, got:", err)
|
r.Equal(t, "multipart/form-data", contentType)
|
||||||
}
|
|
||||||
if contentType != "multipart/form-data" {
|
|
||||||
t.Errorf("Invalid request content type: expected %v but got %v", "multipart/form-data", contentType)
|
|
||||||
}
|
|
||||||
|
|
||||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
mr := multipart.NewReader(req.Body, params["boundary"])
|
||||||
form, err := mr.ReadForm(10 * 1024)
|
form, err := mr.ReadForm(10 * 1024)
|
||||||
if err != nil {
|
r.NoError(t, err)
|
||||||
t.Error("Expected no error while parsing request form, got:", err)
|
defer r.NoError(t, form.RemoveAll())
|
||||||
}
|
|
||||||
defer Ok(t, form.RemoveAll())
|
|
||||||
|
|
||||||
if form.Value["Filename"][0] != testAttachment.Name {
|
r.Equal(t, testAttachment.Name, form.Value["Filename"][0])
|
||||||
t.Errorf("Invalid attachment filename: expected %v but got %v", testAttachment.Name, form.Value["Filename"][0])
|
r.Equal(t, testAttachment.MessageID, form.Value["MessageID"][0])
|
||||||
}
|
r.Equal(t, testAttachment.MIMEType, form.Value["MIMEType"][0])
|
||||||
if form.Value["MessageID"][0] != testAttachment.MessageID {
|
|
||||||
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MessageID, form.Value["MessageID"][0])
|
|
||||||
}
|
|
||||||
if form.Value["MIMEType"][0] != testAttachment.MIMEType {
|
|
||||||
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MIMEType, form.Value["MIMEType"][0])
|
|
||||||
}
|
|
||||||
|
|
||||||
dataFile, err := form.File["DataPacket"][0].Open()
|
dataFile, err := form.File["DataPacket"][0].Open()
|
||||||
if err != nil {
|
r.NoError(t, err)
|
||||||
t.Error("Expected no error while opening packets file, got:", err)
|
defer r.NoError(t, dataFile.Close())
|
||||||
}
|
|
||||||
defer Ok(t, dataFile.Close())
|
|
||||||
|
|
||||||
b, err := ioutil.ReadAll(dataFile)
|
b, err := ioutil.ReadAll(dataFile)
|
||||||
if err != nil {
|
r.NoError(t, err)
|
||||||
t.Error("Expected no error while reading packets file, got:", err)
|
r.Equal(t, testAttachmentCleartext, string(b))
|
||||||
}
|
|
||||||
if string(b) != testAttachmentCleartext {
|
|
||||||
t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
@ -143,50 +118,39 @@ func TestClient_CreateAttachment(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
|
reader := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
|
||||||
created, err := c.CreateAttachment(context.TODO(), testAttachment, r, strings.NewReader(""))
|
created, err := c.CreateAttachment(context.Background(), testAttachment, reader, strings.NewReader(""))
|
||||||
if err != nil {
|
r.NoError(t, err)
|
||||||
t.Fatal("Expected no error while creating attachment, got:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if created.ID != testAttachment.ID {
|
r.Equal(t, testAttachment.ID, created.ID)
|
||||||
t.Errorf("Invalid attachment id: expected %v but got %v", testAttachment.ID, created.ID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_GetAttachment(t *testing.T) {
|
func TestClient_GetAttachment(t *testing.T) {
|
||||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID))
|
r.NoError(t, checkMethodAndPath(req, "GET", "/mail/v4/attachments/"+testAttachment.ID))
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
fmt.Fprint(w, testAttachmentCleartext)
|
fmt.Fprint(w, testAttachmentCleartext)
|
||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
r, err := c.GetAttachment(context.TODO(), testAttachment.ID)
|
att, err := c.GetAttachment(context.Background(), testAttachment.ID)
|
||||||
if err != nil {
|
r.NoError(t, err)
|
||||||
t.Fatal("Expected no error while getting attachment, got:", err)
|
defer att.Close() //nolint[errcheck]
|
||||||
}
|
|
||||||
defer r.Close() //nolint[errcheck]
|
|
||||||
|
|
||||||
// In reality, r contains encrypted data
|
// In reality, r contains encrypted data
|
||||||
b, err := ioutil.ReadAll(r)
|
b, err := ioutil.ReadAll(att)
|
||||||
if err != nil {
|
r.NoError(t, err)
|
||||||
t.Fatal("Expected no error while reading attachment, got:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(b) != testAttachmentCleartext {
|
r.Equal(t, testAttachmentCleartext, string(b))
|
||||||
t.Errorf("Invalid attachment data: expected %q but got %q", testAttachmentCleartext, string(b))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAttachment_Encrypt(t *testing.T) {
|
func TestAttachment_Encrypt(t *testing.T) {
|
||||||
data := bytes.NewBufferString(testAttachmentCleartext)
|
data := bytes.NewBufferString(testAttachmentCleartext)
|
||||||
r, err := testAttachment.Encrypt(testPublicKeyRing, data)
|
r, err := testAttachment.Encrypt(testPublicKeyRing, data)
|
||||||
assert.Nil(t, err)
|
a.Nil(t, err)
|
||||||
b, err := ioutil.ReadAll(r)
|
b, err := ioutil.ReadAll(r)
|
||||||
assert.Nil(t, err)
|
a.Nil(t, err)
|
||||||
|
|
||||||
// Result is always different, so the best way is to test it by decrypting again.
|
// Result is always different, so the best way is to test it by decrypting again.
|
||||||
// Another test for decrypting will help us to be sure it's working.
|
// Another test for decrypting will help us to be sure it's working.
|
||||||
@ -202,8 +166,8 @@ func TestAttachment_Decrypt(t *testing.T) {
|
|||||||
|
|
||||||
func decryptAndCheck(t *testing.T, data io.Reader) {
|
func decryptAndCheck(t *testing.T, data io.Reader) {
|
||||||
r, err := testAttachment.Decrypt(data, testPrivateKeyRing)
|
r, err := testAttachment.Decrypt(data, testPrivateKeyRing)
|
||||||
assert.Nil(t, err)
|
a.Nil(t, err)
|
||||||
b, err := ioutil.ReadAll(r)
|
b, err := ioutil.ReadAll(r)
|
||||||
assert.Nil(t, err)
|
a.Nil(t, err)
|
||||||
assert.Equal(t, testAttachmentCleartext, string(b))
|
a.Equal(t, testAttachmentCleartext, string(b))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,3 +1,20 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package pmapi
|
package pmapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -6,15 +23,117 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *client) Auth2FA(ctx context.Context, req Auth2FAReq) error {
|
type AuthModulus struct {
|
||||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
Modulus string
|
||||||
return r.SetBody(req).Post("/auth/2fa")
|
ModulusID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetAuthInfoReq struct {
|
||||||
|
Username string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthInfo struct {
|
||||||
|
Version int
|
||||||
|
Modulus string
|
||||||
|
ServerEphemeral string
|
||||||
|
Salt string
|
||||||
|
SRPSession string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TwoFAInfo struct {
|
||||||
|
Enabled TwoFAStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
|
||||||
|
return twoFAInfo.Enabled > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type TwoFAStatus int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TwoFADisabled TwoFAStatus = iota
|
||||||
|
TOTPEnabled
|
||||||
|
U2FEnabled
|
||||||
|
TOTPAndU2FEnabled
|
||||||
|
)
|
||||||
|
|
||||||
|
type PasswordMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
OnePasswordMode PasswordMode = iota + 1
|
||||||
|
TwoPasswordMode
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthReq struct {
|
||||||
|
Username string
|
||||||
|
ClientProof string
|
||||||
|
ClientEphemeral string
|
||||||
|
SRPSession string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthRefresh struct {
|
||||||
|
UID string
|
||||||
|
AccessToken string
|
||||||
|
RefreshToken string
|
||||||
|
ExpiresIn int64
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Auth struct {
|
||||||
|
AuthRefresh
|
||||||
|
|
||||||
|
UserID string
|
||||||
|
ServerProof string
|
||||||
|
PasswordMode PasswordMode
|
||||||
|
TwoFA *TwoFAInfo `json:"2FA,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Auth) HasTwoFactor() bool {
|
||||||
|
if a.TwoFA == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.TwoFA.hasTwoFactor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Auth) HasMailboxPassword() bool {
|
||||||
|
return a.PasswordMode == TwoPasswordMode
|
||||||
|
}
|
||||||
|
|
||||||
|
type auth2FAReq struct {
|
||||||
|
TwoFactorCode string
|
||||||
|
}
|
||||||
|
|
||||||
|
type authRefreshReq struct {
|
||||||
|
UID string
|
||||||
|
RefreshToken string
|
||||||
|
ResponseType string
|
||||||
|
GrantType string
|
||||||
|
RedirectURI string
|
||||||
|
State string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) Auth2FA(ctx context.Context, twoFactorCode string) error {
|
||||||
|
// 2FA is called during login procedure during which refresh token should
|
||||||
|
// be valid, therefore, no refresh is needed if there is an error.
|
||||||
|
ctx = ContextWithoutAuthRefresh(ctx)
|
||||||
|
|
||||||
|
if res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||||
|
return r.SetBody(auth2FAReq{TwoFactorCode: twoFactorCode}).Post("/auth/2fa")
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
if res != nil {
|
||||||
|
switch res.StatusCode() {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return ErrBad2FACode
|
||||||
|
case http.StatusUnprocessableEntity:
|
||||||
|
return ErrBad2FACodeTryAgain
|
||||||
|
}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,9 +148,7 @@ func (c *client) AuthDelete(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
|
c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
|
||||||
|
c.sendAuthRefresh(nil)
|
||||||
// FIXME(conman): should we perhaps signal via AuthHandler that the auth was deleted?
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,7 +171,7 @@ func (c *client) AuthSalt(ctx context.Context) (string, error) {
|
|||||||
return "", errors.New("no matching salt found")
|
return "", errors.New("no matching salt found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) AddAuthHandler(handler AuthHandler) {
|
func (c *client) AddAuthRefreshHandler(handler AuthRefreshHandler) {
|
||||||
c.authHandlers = append(c.authHandlers, handler)
|
c.authHandlers = append(c.authHandlers, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,23 +179,35 @@ func (c *client) authRefresh(ctx context.Context) error {
|
|||||||
c.authLocker.Lock()
|
c.authLocker.Lock()
|
||||||
defer c.authLocker.Unlock()
|
defer c.authLocker.Unlock()
|
||||||
|
|
||||||
auth, err := c.req.authRefresh(ctx, c.uid, c.ref)
|
if c.ref == "" {
|
||||||
|
return ErrUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err != ErrNoConnection {
|
||||||
|
c.sendAuthRefresh(nil)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.acc = auth.AccessToken
|
c.acc = auth.AccessToken
|
||||||
c.ref = auth.RefreshToken
|
c.ref = auth.RefreshToken
|
||||||
|
c.exp = expiresIn(auth.ExpiresIn)
|
||||||
|
|
||||||
for _, handler := range c.authHandlers {
|
c.sendAuthRefresh(auth)
|
||||||
if err := handler(auth); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *client) sendAuthRefresh(auth *AuthRefresh) {
|
||||||
|
for _, handler := range c.authHandlers {
|
||||||
|
go handler(auth)
|
||||||
|
}
|
||||||
|
if auth == nil {
|
||||||
|
c.authHandlers = []AuthRefreshHandler{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func randomString(length int) string {
|
func randomString(length int) string {
|
||||||
noise := make([]byte, length)
|
noise := make([]byte, length)
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +1,40 @@
|
|||||||
package pmapi_test
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package pmapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
a "github.com/stretchr/testify/assert"
|
||||||
|
r "github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAutomaticAuthRefresh(t *testing.T) {
|
func TestAutomaticAuthRefresh(t *testing.T) {
|
||||||
var wantAuth = &pmapi.Auth{
|
var wantAuthRefresh = &AuthRefresh{
|
||||||
UID: "testUID",
|
UID: "testUID",
|
||||||
AccessToken: "testAcc",
|
AccessToken: "testAcc",
|
||||||
RefreshToken: "testRef",
|
RefreshToken: "testRef",
|
||||||
|
ExpiresIn: 100,
|
||||||
}
|
}
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@ -24,7 +42,7 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
|||||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -35,28 +53,28 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
|||||||
|
|
||||||
ts := httptest.NewServer(mux)
|
ts := httptest.NewServer(mux)
|
||||||
|
|
||||||
var gotAuth *pmapi.Auth
|
var gotAuthRefresh *AuthRefresh
|
||||||
|
|
||||||
// Create a new client.
|
c := New(Config{HostURL: ts.URL}).
|
||||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
|
||||||
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
||||||
|
|
||||||
// Register an auth handler.
|
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
|
||||||
|
|
||||||
// Make a request with an access token that already expired one second ago.
|
// Make a request with an access token that already expired one second ago.
|
||||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
_, err := c.GetAddresses(context.Background())
|
||||||
t.Fatal("got unexpected error", err)
|
r.NoError(t, err)
|
||||||
}
|
|
||||||
|
|
||||||
// The auth callback should have been called.
|
// The auth callback should have been called.
|
||||||
if *gotAuth != *wantAuth {
|
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||||
t.Fatal("got unexpected auth", gotAuth)
|
|
||||||
}
|
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
|
||||||
|
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
|
||||||
|
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
|
||||||
|
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test401AuthRefresh(t *testing.T) {
|
func Test401AuthRefresh(t *testing.T) {
|
||||||
var wantAuth = &pmapi.Auth{
|
var wantAuthRefresh = &AuthRefresh{
|
||||||
UID: "testUID",
|
UID: "testUID",
|
||||||
AccessToken: "testAcc",
|
AccessToken: "testAcc",
|
||||||
RefreshToken: "testRef",
|
RefreshToken: "testRef",
|
||||||
@ -67,7 +85,7 @@ func Test401AuthRefresh(t *testing.T) {
|
|||||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -86,24 +104,21 @@ func Test401AuthRefresh(t *testing.T) {
|
|||||||
|
|
||||||
ts := httptest.NewServer(mux)
|
ts := httptest.NewServer(mux)
|
||||||
|
|
||||||
var gotAuth *pmapi.Auth
|
var gotAuthRefresh *AuthRefresh
|
||||||
|
|
||||||
// Create a new client.
|
// Create a new client.
|
||||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
c := New(Config{HostURL: ts.URL}).
|
||||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||||
|
|
||||||
// Register an auth handler.
|
// Register an auth handler.
|
||||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||||
|
|
||||||
// The first request will fail with 401, triggering a refresh and retry.
|
// The first request will fail with 401, triggering a refresh and retry.
|
||||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
_, err := c.GetAddresses(context.Background())
|
||||||
t.Fatal("got unexpected error", err)
|
r.NoError(t, err)
|
||||||
}
|
|
||||||
|
|
||||||
// The auth callback should have been called.
|
// The auth callback should have been called.
|
||||||
if *gotAuth != *wantAuth {
|
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||||
t.Fatal("got unexpected auth", gotAuth)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test401RevokedAuth(t *testing.T) {
|
func Test401RevokedAuth(t *testing.T) {
|
||||||
@ -119,17 +134,57 @@ func Test401RevokedAuth(t *testing.T) {
|
|||||||
|
|
||||||
ts := httptest.NewServer(mux)
|
ts := httptest.NewServer(mux)
|
||||||
|
|
||||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
c := New(Config{HostURL: ts.URL}).
|
||||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||||
|
|
||||||
// The request will fail with 401, triggering a refresh.
|
// The request will fail with 401, triggering a refresh.
|
||||||
// The retry will also fail with 401, returning an error.
|
// The retry will also fail with 401, returning an error.
|
||||||
_, err := c.GetAddresses(context.Background())
|
_, err := c.GetAddresses(context.Background())
|
||||||
if err == nil {
|
r.EqualError(t, err, ErrUnauthorized.Error())
|
||||||
t.Fatal("expected error, instead got", err)
|
}
|
||||||
}
|
|
||||||
|
func TestAuth2FA(t *testing.T) {
|
||||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
twoFACode := "code"
|
||||||
t.Fatal("expected error to be ErrUnauthorized, instead got", err)
|
|
||||||
}
|
finish, c := newTestClientCallbacks(t,
|
||||||
|
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||||
|
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||||
|
|
||||||
|
var twoFAreq auth2FAReq
|
||||||
|
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||||
|
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
|
||||||
|
|
||||||
|
return "/auth/2fa/post_response.json"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
defer finish()
|
||||||
|
|
||||||
|
err := c.Auth2FA(context.Background(), twoFACode)
|
||||||
|
r.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuth2FA_Fail(t *testing.T) {
|
||||||
|
finish, c := newTestClientCallbacks(t,
|
||||||
|
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||||
|
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||||
|
return "/auth/2fa/post_401_bad_password.json"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
defer finish()
|
||||||
|
|
||||||
|
err := c.Auth2FA(context.Background(), "code")
|
||||||
|
r.Equal(t, ErrBad2FACode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuth2FA_Retry(t *testing.T) {
|
||||||
|
finish, c := newTestClientCallbacks(t,
|
||||||
|
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||||
|
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||||
|
return "/auth/2fa/post_422_bad_password.json"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
defer finish()
|
||||||
|
|
||||||
|
err := c.Auth2FA(context.Background(), "code")
|
||||||
|
r.Equal(t, ErrBad2FACodeTryAgain, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,72 +0,0 @@
|
|||||||
package pmapi
|
|
||||||
|
|
||||||
type AuthModulus struct {
|
|
||||||
Modulus string
|
|
||||||
ModulusID string
|
|
||||||
}
|
|
||||||
|
|
||||||
type GetAuthInfoReq struct {
|
|
||||||
Username string
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthInfo struct {
|
|
||||||
Version int
|
|
||||||
Modulus string
|
|
||||||
ServerEphemeral string
|
|
||||||
Salt string
|
|
||||||
SRPSession string
|
|
||||||
}
|
|
||||||
|
|
||||||
type TwoFAInfo struct {
|
|
||||||
Enabled TwoFAStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
type TwoFAStatus int
|
|
||||||
|
|
||||||
const (
|
|
||||||
TwoFADisabled TwoFAStatus = iota
|
|
||||||
TOTPEnabled
|
|
||||||
// TODO: Support UTF
|
|
||||||
)
|
|
||||||
|
|
||||||
type PasswordMode int
|
|
||||||
|
|
||||||
const (
|
|
||||||
OnePasswordMode PasswordMode = iota + 1
|
|
||||||
TwoPasswordMode
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthReq struct {
|
|
||||||
Username string
|
|
||||||
ClientProof string
|
|
||||||
ClientEphemeral string
|
|
||||||
SRPSession string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Auth struct {
|
|
||||||
UserID string
|
|
||||||
|
|
||||||
UID string
|
|
||||||
AccessToken string
|
|
||||||
RefreshToken string
|
|
||||||
ExpiresIn int64
|
|
||||||
|
|
||||||
Scope string
|
|
||||||
ServerProof string
|
|
||||||
|
|
||||||
TwoFA TwoFAInfo `json:"2FA"`
|
|
||||||
PasswordMode PasswordMode
|
|
||||||
}
|
|
||||||
|
|
||||||
type Auth2FAReq struct {
|
|
||||||
TwoFactorCode string
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthRefreshReq struct {
|
|
||||||
UID string
|
|
||||||
RefreshToken string
|
|
||||||
ResponseType string
|
|
||||||
GrantType string
|
|
||||||
RedirectURI string
|
|
||||||
State string
|
|
||||||
}
|
|
||||||
41
pkg/pmapi/boolean.go
Normal file
41
pkg/pmapi/boolean.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package pmapi
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
type Boolean bool
|
||||||
|
|
||||||
|
func (boolean *Boolean) UnmarshalJSON(b []byte) error {
|
||||||
|
var value int
|
||||||
|
err := json.Unmarshal(b, &value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*boolean = Boolean(value == 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (boolean Boolean) MarshalJSON() ([]byte, error) {
|
||||||
|
var value int
|
||||||
|
if boolean {
|
||||||
|
value = 1
|
||||||
|
}
|
||||||
|
return json.Marshal(value)
|
||||||
|
}
|
||||||
@ -25,15 +25,14 @@ import (
|
|||||||
|
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// client is a client of the protonmail API. It implements the Client interface.
|
// client is a client of the protonmail API. It implements the Client interface.
|
||||||
type client struct {
|
type client struct {
|
||||||
req requester
|
manager clientManager
|
||||||
|
|
||||||
uid, acc, ref string
|
uid, acc, ref string
|
||||||
authHandlers []AuthHandler
|
authHandlers []AuthRefreshHandler
|
||||||
authLocker sync.RWMutex
|
authLocker sync.RWMutex
|
||||||
|
|
||||||
user *User
|
user *User
|
||||||
@ -45,9 +44,9 @@ type client struct {
|
|||||||
exp time.Time
|
exp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(req requester, uid string) *client {
|
func newClient(manager clientManager, uid string) *client {
|
||||||
return &client{
|
return &client{
|
||||||
req: req,
|
manager: manager,
|
||||||
uid: uid,
|
uid: uid,
|
||||||
addrKeyRing: make(map[string]*crypto.KeyRing),
|
addrKeyRing: make(map[string]*crypto.KeyRing),
|
||||||
keyRingLock: &sync.RWMutex{},
|
keyRingLock: &sync.RWMutex{},
|
||||||
@ -63,7 +62,7 @@ func (c *client) withAuth(acc, ref string, exp time.Time) *client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) r(ctx context.Context) (*resty.Request, error) {
|
func (c *client) r(ctx context.Context) (*resty.Request, error) {
|
||||||
r := c.req.r(ctx)
|
r := c.manager.r(ctx)
|
||||||
|
|
||||||
if c.uid != "" {
|
if c.uid != "" {
|
||||||
r.SetHeader("x-pm-uid", c.uid)
|
r.SetHeader("x-pm-uid", c.uid)
|
||||||
@ -91,30 +90,23 @@ func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Respons
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := wrapRestyError(fn(r))
|
res, err := wrapNoConnection(fn(r))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if res.StatusCode() != http.StatusUnauthorized {
|
if res.StatusCode() != http.StatusUnauthorized {
|
||||||
return nil, err
|
// Return also response so caller has more options to decide what to do.
|
||||||
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !isAuthRefreshDisabled(ctx) {
|
||||||
if err := c.authRefresh(ctx); err != nil {
|
if err := c.authRefresh(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return wrapRestyError(fn(r))
|
return wrapNoConnection(fn(r))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) {
|
|
||||||
if err, ok := err.(*resty.ResponseError); ok {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.RawResponse != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, errors.Wrap(ErrNoConnection, err.Error())
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,3 +1,20 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package pmapi
|
package pmapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -12,8 +29,6 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
|
|||||||
c.keyRingLock.Lock()
|
c.keyRingLock.Lock()
|
||||||
defer c.keyRingLock.Unlock()
|
defer c.keyRingLock.Unlock()
|
||||||
|
|
||||||
// FIXME(conman): Should this be done as part of NewClient somehow?
|
|
||||||
|
|
||||||
return c.unlock(ctx, passphrase)
|
return c.unlock(ctx, passphrase)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,6 +80,15 @@ func (c *client) clearKeys() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) IsUnlocked() bool {
|
func (c *client) IsUnlocked() bool {
|
||||||
// FIXME(conman): Better way to check? we don't currently check address keys.
|
if c.userKeyRing == nil {
|
||||||
return c.userKeyRing != nil
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, address := range c.addresses {
|
||||||
|
if address.HasKeys != MissingKeys && c.addrKeyRing[address.ID] == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,10 +27,10 @@ import (
|
|||||||
|
|
||||||
// Client defines the interface of a PMAPI client.
|
// Client defines the interface of a PMAPI client.
|
||||||
type Client interface {
|
type Client interface {
|
||||||
Auth2FA(context.Context, Auth2FAReq) error
|
Auth2FA(context.Context, string) error
|
||||||
AuthSalt(ctx context.Context) (string, error)
|
AuthSalt(ctx context.Context) (string, error)
|
||||||
AuthDelete(context.Context) error
|
AuthDelete(context.Context) error
|
||||||
AddAuthHandler(AuthHandler)
|
AddAuthRefreshHandler(AuthRefreshHandler)
|
||||||
|
|
||||||
CurrentUser(ctx context.Context) (*User, error)
|
CurrentUser(ctx context.Context) (*User, error)
|
||||||
UpdateUser(ctx context.Context) (*User, error)
|
UpdateUser(ctx context.Context) (*User, error)
|
||||||
@ -75,9 +75,9 @@ type Client interface {
|
|||||||
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
|
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthHandler func(*Auth) error
|
type AuthRefreshHandler func(*AuthRefresh)
|
||||||
|
|
||||||
type requester interface {
|
type clientManager interface {
|
||||||
r(context.Context) *resty.Request
|
r(context.Context) *resty.Request
|
||||||
authRefresh(context.Context, string, string) (*Auth, error)
|
authRefresh(context.Context, string, string) (*AuthRefresh, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,72 @@
|
|||||||
|
// Copyright (c) 2021 Proton Technologies AG
|
||||||
|
//
|
||||||
|
// This file is part of ProtonMail Bridge.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
package pmapi
|
package pmapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
// HostURL is the base URL of API.
|
||||||
HostURL string
|
HostURL string
|
||||||
|
|
||||||
|
// AppVersion sets version to headers of each request.
|
||||||
AppVersion string
|
AppVersion string
|
||||||
|
|
||||||
|
// UserAgent sets user agent to headers of each request.
|
||||||
|
// Used only if GetUserAgent is not set.
|
||||||
|
UserAgent string
|
||||||
|
|
||||||
|
// GetUserAgent is dynamic version of UserAgent.
|
||||||
|
// Overrides UserAgent.
|
||||||
|
GetUserAgent func() string
|
||||||
|
|
||||||
|
// UpgradeApplicationHandler is used to notify when there is a force upgrade.
|
||||||
|
UpgradeApplicationHandler func()
|
||||||
|
|
||||||
|
// TLSIssueHandler is used to notify when there is a TLS issue.
|
||||||
|
TLSIssueHandler func()
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultConfig = Config{
|
func NewConfig(appVersionName, appVersion string) Config {
|
||||||
HostURL: "https://api.protonmail.ch",
|
return Config{
|
||||||
AppVersion: "Other",
|
HostURL: getRootURL(),
|
||||||
|
AppVersion: getAPIOS() + strings.Title(appVersionName) + "_" + appVersion,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) getUserAgent() string {
|
||||||
|
if c.GetUserAgent == nil {
|
||||||
|
return c.UserAgent
|
||||||
|
}
|
||||||
|
return c.GetUserAgent()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAPIOS returns actual operating system.
|
||||||
|
func getAPIOS() string {
|
||||||
|
switch os := runtime.GOOS; os {
|
||||||
|
case "darwin": // nolint: goconst
|
||||||
|
return "macOS"
|
||||||
|
case "linux":
|
||||||
|
return "Linux"
|
||||||
|
case "windows":
|
||||||
|
return "Windows"
|
||||||
|
}
|
||||||
|
return "Linux"
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user