GODT-35: Finish all details and make tests pass

This commit is contained in:
Michal Horejsek
2021-03-11 14:37:15 +01:00
committed by Jakub
parent 2284e9ede1
commit 8109831c07
173 changed files with 4697 additions and 2897 deletions

5
go.mod
View File

@ -40,7 +40,7 @@ require (
github.com/fatih/color v1.9.0 github.com/fatih/color v1.9.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/getsentry/sentry-go v0.8.0 github.com/getsentry/sentry-go v0.8.0
github.com/go-resty/resty/v2 v2.4.0 github.com/go-resty/resty/v2 v2.6.0
github.com/golang/mock v1.4.4 github.com/golang/mock v1.4.4
github.com/google/go-cmp v0.5.1 github.com/google/go-cmp v0.5.1
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
@ -50,6 +50,7 @@ require (
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
github.com/logrusorgru/aurora v2.0.3+incompatible github.com/logrusorgru/aurora v2.0.3+incompatible
github.com/mattn/go-runewidth v0.0.9 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/miekg/dns v1.1.41
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
github.com/olekukonko/tablewriter v0.0.4 // indirect github.com/olekukonko/tablewriter v0.0.4 // indirect
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
@ -63,7 +64,7 @@ require (
github.com/urfave/cli/v2 v2.2.0 github.com/urfave/cli/v2 v2.2.0
github.com/vmihailenco/msgpack/v5 v5.1.3 github.com/vmihailenco/msgpack/v5 v5.1.3
go.etcd.io/bbolt v1.3.5 go.etcd.io/bbolt v1.3.5
golang.org/x/net v0.0.0-20201224014010-6772e930b67b golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec
) )

18
go.sum
View File

@ -113,8 +113,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
github.com/go-resty/resty/v2 v2.4.0 h1:s6TItTLejEI+2mn98oijC5w/Rk2YU+OA6x0mnZN6r6k= github.com/go-resty/resty/v2 v2.6.0 h1:joIR5PNLM2EFqqESUjCMGXrWmXNHEU9CEiK813oKYS4=
github.com/go-resty/resty/v2 v2.4.0/go.mod h1:B88+xCTEwvfD94NOuE6GS1wMlnoKNY8eEiNizfNwOwA= github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
@ -195,6 +195,8 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -310,12 +312,16 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -330,6 +336,10 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 h1:cEhElsAv9LUt9ZUUocxzWe05oFLVd+AA2nstydTeI8g=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -181,21 +181,18 @@ func New( // nolint[funlen]
kc = keychain.NewMissingKeychain() kc = keychain.NewMissingKeychain()
} }
// FIXME(conman): Customize config depending on build type (app version, host URL). cfg := pmapi.NewConfig(configName, constants.Version)
cm := pmapi.New(pmapi.DefaultConfig) cfg.GetUserAgent = userAgent.String
cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") }
cm := pmapi.New(cfg)
// FIXME(conman): Should this be a real object, not just created via callbacks?
cm.AddConnectionObserver(pmapi.NewConnectionObserver( cm.AddConnectionObserver(pmapi.NewConnectionObserver(
func() { listener.Emit(events.InternetOffEvent, "") }, func() { listener.Emit(events.InternetOffEvent, "") },
func() { listener.Emit(events.InternetOnEvent, "") }, func() { listener.Emit(events.InternetOnEvent, "") },
)) ))
// FIXME(conman): Implement force upgrade observer.
// apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
// FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc.
// cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
jar, err := cookies.NewCookieJar(settingsObj) jar, err := cookies.NewCookieJar(settingsObj)
if err != nil { if err != nil {
return nil, err return nil, err
@ -341,6 +338,7 @@ func (b *Base) run(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc {
} }
logging.SetLevel(c.String(flagLogLevel)) logging.SetLevel(c.String(flagLogLevel))
b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel)
logrus. logrus.
WithField("appName", b.Name). WithField("appName", b.Name).

View File

@ -65,8 +65,7 @@ func New(
// Allow DoH before starting the app if the user has previously set this setting. // Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked. // This allows us to start even if protonmail is blocked.
if s.GetBool(settings.AllowProxyKey) { if s.GetBool(settings.AllowProxyKey) {
// FIXME(conman): Support enable/disable of DoH. clientManager.AllowProxy()
// clientManager.AllowProxy()
} }
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener) storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
@ -120,7 +119,7 @@ func (b *Bridge) heartbeat() {
// ReportBug reports a new bug from the user. // ReportBug reports a new bug from the user.
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error { func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{ return b.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
OS: osType, OS: osType,
OSVersion: osVersion, OSVersion: osVersion,
Browser: emailClient, Browser: emailClient,

View File

@ -21,7 +21,6 @@ import (
"context" "context"
"strings" "strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/abiosoft/ishell" "github.com/abiosoft/ishell"
) )
@ -75,13 +74,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return return
} }
if auth.TwoFA.Enabled == pmapi.TOTPEnabled { if auth.HasTwoFactor() {
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty) twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
if twoFactor == "" { if twoFactor == "" {
return return
} }
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor}) err = client.Auth2FA(context.Background(), twoFactor)
if err != nil { if err != nil {
f.processAPIError(err) f.processAPIError(err)
return return
@ -89,7 +88,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
} }
mailboxPassword := password mailboxPassword := password
if auth.PasswordMode == pmapi.TwoPasswordMode { if auth.HasMailboxPassword() {
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty) mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
} }
if mailboxPassword == "" { if mailboxPassword == "" {

View File

@ -84,11 +84,6 @@ func New( //nolint[funlen]
Aliases: []string{"u", "version", "v"}, Aliases: []string{"u", "version", "v"},
Func: fe.checkUpdates, Func: fe.checkUpdates,
}) })
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
Help: "check internet connection. (aliases: i, conn, connection)",
Aliases: []string{"i", "con", "connection"},
Func: fe.checkInternetConnection,
})
fe.AddCmd(checkCmd) fe.AddCmd(checkCmd)
// Print info commands. // Print info commands.
@ -177,13 +172,13 @@ func New( //nolint[funlen]
} }
func (f *frontendCLI) watchEvents() { func (f *frontendCLI) watchEvents() {
errorCh := f.getEventChannel(events.ErrorEvent) errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent) credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.getEventChannel(events.InternetOffEvent) internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.getEventChannel(events.InternetOnEvent) internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent) addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.getEventChannel(events.LogoutEvent) logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
certIssue := f.getEventChannel(events.TLSCertIssue) certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
for { for {
select { select {
case errorDetails := <-errorCh: case errorDetails := <-errorCh:
@ -208,13 +203,6 @@ func (f *frontendCLI) watchEvents() {
} }
} }
func (f *frontendCLI) getEventChannel(event string) <-chan string {
ch := make(chan string)
f.eventListener.Add(event, ch)
f.eventListener.RetryEmit(event)
return ch
}
// Loop starts the frontend loop with an interactive shell. // Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop() error { func (f *frontendCLI) Loop() error {
f.Print(` f.Print(`

View File

@ -29,14 +29,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
} }
} }
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
if f.ie.CheckConnection() == nil {
f.Println("Internet connection is available.")
} else {
f.Println("Can not contact the server, please check your internet connection.")
}
}
func (f *frontendCLI) printLogDir(c *ishell.Context) { func (f *frontendCLI) printLogDir(c *ishell.Context) {
if path, err := f.locations.ProvideLogsPath(); err != nil { if path, err := f.locations.ProvideLogsPath(); err != nil {
f.Println("Failed to determine location of log files") f.Println("Failed to determine location of log files")

View File

@ -20,6 +20,7 @@ package cliie
import ( import (
"strings" "strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/fatih/color" "github.com/fatih/color"
) )
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) { func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err) log.Warn("API error: ", err)
switch err { switch err {
// FIXME(conman): How to handle various API errors?
/*
case pmapi.ErrNoConnection: case pmapi.ErrNoConnection:
f.notifyInternetOff() f.notifyInternetOff()
case pmapi.ErrUpgradeApplication: case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade() f.notifyNeedUpgrade()
*/
default: default:
f.Println("Server error:", err.Error()) f.Println("Server error:", err.Error())
} }

View File

@ -24,7 +24,6 @@ import (
"github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/frontend/types" "github.com/ProtonMail/proton-bridge/internal/frontend/types"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/abiosoft/ishell" "github.com/abiosoft/ishell"
) )
@ -122,13 +121,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return return
} }
if auth.TwoFA.Enabled == pmapi.TOTPEnabled { if auth.HasTwoFactor() {
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty) twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
if twoFactor == "" { if twoFactor == "" {
return return
} }
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor}) err = client.Auth2FA(context.Background(), twoFactor)
if err != nil { if err != nil {
f.processAPIError(err) f.processAPIError(err)
return return
@ -136,7 +135,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
} }
mailboxPassword := password mailboxPassword := password
if auth.PasswordMode == pmapi.TwoPasswordMode { if auth.HasMailboxPassword() {
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty) mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
} }
if mailboxPassword == "" { if mailboxPassword == "" {

View File

@ -157,15 +157,6 @@ func New( //nolint[funlen]
}) })
fe.AddCmd(updatesCmd) fe.AddCmd(updatesCmd)
// Check commands.
checkCmd := &ishell.Cmd{Name: "check", Help: "check internet connection or new version."}
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
Help: "check internet connection. (aliases: i, conn, connection)",
Aliases: []string{"i", "con", "connection"},
Func: fe.checkInternetConnection,
})
fe.AddCmd(checkCmd)
// Print info commands. // Print info commands.
fe.AddCmd(&ishell.Cmd{Name: "log-dir", fe.AddCmd(&ishell.Cmd{Name: "log-dir",
Help: "print path to directory with logs. (aliases: log, logs)", Help: "print path to directory with logs. (aliases: log, logs)",
@ -228,14 +219,14 @@ func New( //nolint[funlen]
} }
func (f *frontendCLI) watchEvents() { func (f *frontendCLI) watchEvents() {
errorCh := f.getEventChannel(events.ErrorEvent) errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent) credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.getEventChannel(events.InternetOffEvent) internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.getEventChannel(events.InternetOnEvent) internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
addressChangedCh := f.getEventChannel(events.AddressChangedEvent) addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent) addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.getEventChannel(events.LogoutEvent) logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
certIssue := f.getEventChannel(events.TLSCertIssue) certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
for { for {
select { select {
case errorDetails := <-errorCh: case errorDetails := <-errorCh:
@ -262,13 +253,6 @@ func (f *frontendCLI) watchEvents() {
} }
} }
func (f *frontendCLI) getEventChannel(event string) <-chan string {
ch := make(chan string)
f.eventListener.Add(event, ch)
f.eventListener.RetryEmit(event)
return ch
}
// Loop starts the frontend loop with an interactive shell. // Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop() error { func (f *frontendCLI) Loop() error {
f.Print(` f.Print(`

View File

@ -39,14 +39,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
} }
} }
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
if f.bridge.CheckConnection() == nil {
f.Println("Internet connection is available.")
} else {
f.Println("Can not contact the server, please check your internet connection.")
}
}
func (f *frontendCLI) printLogDir(c *ishell.Context) { func (f *frontendCLI) printLogDir(c *ishell.Context) {
if path, err := f.locations.ProvideLogsPath(); err != nil { if path, err := f.locations.ProvideLogsPath(); err != nil {
f.Println("Failed to determine location of log files") f.Println("Failed to determine location of log files")

View File

@ -20,6 +20,7 @@ package cli
import ( import (
"strings" "strings"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/fatih/color" "github.com/fatih/color"
) )
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) { func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err) log.Warn("API error: ", err)
switch err { switch err {
// FIXME(conman): How to handle various API errors?
/*
case pmapi.ErrNoConnection: case pmapi.ErrNoConnection:
f.notifyInternetOff() f.notifyInternetOff()
case pmapi.ErrUpgradeApplication: case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade() f.notifyNeedUpgrade()
*/
default: default:
f.Println("Server error:", err.Error()) f.Println("Server error:", err.Error())
} }

View File

@ -409,7 +409,6 @@ Dialog {
onShow: { onShow: {
if (winMain.updateState==gui.enums.statusNoInternet) { if (winMain.updateState==gui.enums.statusNoInternet) {
go.checkInternet()
if (winMain.updateState==gui.enums.statusNoInternet) { if (winMain.updateState==gui.enums.statusNoInternet) {
go.notifyError(gui.enums.errNoInternet) go.notifyError(gui.enums.errNoInternet)
root.hide() root.hide()

View File

@ -857,14 +857,12 @@ Dialog {
inputPort . checkIsANumber() inputPort . checkIsANumber()
//emailProvider . currentIndex!=0 //emailProvider . currentIndex!=0
)) isOK = false )) isOK = false
go.checkInternet()
if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this
errorPopup.show(qsTr("Please check your internet connection.")) errorPopup.show(qsTr("Please check your internet connection."))
return false return false
} }
break break
case 2: // loading structure case 2: // loading structure
go.checkInternet()
if (winMain.updateState == gui.enums.statusNoInternet) { if (winMain.updateState == gui.enums.statusNoInternet) {
errorPopup.show(qsTr("Please check your internet connection.")) errorPopup.show(qsTr("Please check your internet connection."))
return false return false
@ -949,7 +947,6 @@ Dialog {
onShow : { onShow : {
root.clear() root.clear()
if (winMain.updateState==gui.enums.statusNoInternet) { if (winMain.updateState==gui.enums.statusNoInternet) {
go.checkInternet()
if (winMain.updateState==gui.enums.statusNoInternet) { if (winMain.updateState==gui.enums.statusNoInternet) {
winMain.popupMessage.show(go.canNotReachAPI) winMain.popupMessage.show(go.canNotReachAPI)
root.hide() root.hide()

View File

@ -25,33 +25,12 @@ import ProtonUI 1.0
Rectangle { Rectangle {
id: root id: root
property var iTry: 0 property var iTry: 0
property var secLeft: 0
property var second: 1000 // convert millisecond to second property var second: 1000 // convert millisecond to second
property var checkInterval: [ 5, 10, 30, 60, 120, 300, 600 ] // seconds
property bool isVisible: true property bool isVisible: true
property var fontSize : 1.2 * Style.main.fontSize property var fontSize : 1.2 * Style.main.fontSize
color : "black" color : "black"
state: "upToDate" state: "upToDate"
Timer {
id: retryInternet
interval: second
triggeredOnStart: false
repeat: true
onTriggered : {
secLeft--
if (secLeft <= 0) {
retryInternet.stop()
go.checkInternet()
if (iTry < checkInterval.length-1) {
iTry++
}
secLeft=checkInterval[iTry]
retryInternet.start()
}
}
}
Row { Row {
id: messageRow id: messageRow
anchors.centerIn: root anchors.centerIn: root
@ -110,16 +89,12 @@ Rectangle {
case "internetCheck": case "internetCheck":
break; break;
case "noInternet" : case "noInternet" :
retryInternet.start()
secLeft=checkInterval[iTry]
break; break;
case "oldVersion": case "oldVersion":
break; break;
case "forceUpdate": case "forceUpdate":
break; break;
case "upToDate": case "upToDate":
iTry = 0
secLeft=checkInterval[iTry]
break; break;
case "updateRestart": case "updateRestart":
break; break;
@ -128,24 +103,6 @@ Rectangle {
default : default :
break; break;
} }
if (root.state!="noInternet") {
retryInternet.stop()
}
}
function timeToRetry() {
if (secLeft==1){
return qsTr("a second", "time to wait till internet connection is retried")
} else if (secLeft<60){
return secLeft + " " + qsTr("seconds", "time to wait till internet connection is retried")
} else {
var leading = ""+secLeft%60
if (leading.length < 2) {
leading = "0" + leading
}
return Math.floor(secLeft/60) + ":" + leading
}
} }
states: [ states: [
@ -194,23 +151,15 @@ Rectangle {
PropertyChanges { PropertyChanges {
target: message target: message
color: Style.main.line color: Style.main.line
text: qsTr("Cannot contact server. Retrying in ", "displayed when the app is disconnected from the internet or server has problems")+timeToRetry()+"." text: qsTr("Cannot contact server. Please wait...", "displayed when the app is disconnected from the internet or server has problems")
} }
PropertyChanges { PropertyChanges {
target: linkText target: linkText
visible: false visible: false
} }
PropertyChanges {
target: actionText
visible: true
text: qsTr("Retry now", "click to try to connect to the internet when the app is disconnected from the internet")
onClicked: {
go.checkInternet()
}
}
PropertyChanges { PropertyChanges {
target: separatorText target: separatorText
visible: true visible: false
text: "|" text: "|"
} }
PropertyChanges { PropertyChanges {

View File

@ -1331,10 +1331,6 @@ Window {
return (fname!="fail") return (fname!="fail")
} }
function checkInternet() {
// nothing to do
}
function loadImportReports(fname) { function loadImportReports(fname) {
console.log("load import reports for ", fname) console.log("load import reports for ", fname)
} }

View File

@ -20,6 +20,7 @@
package qtcommon package qtcommon
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -207,7 +208,7 @@ func (a *Accounts) Auth2FA(twoFacAuth string) int {
if a.auth == nil || a.authClient == nil { if a.auth == nil || a.authClient == nil {
err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient) err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient)
} else { } else {
err = a.authClient.Auth2FA(twoFacAuth, a.auth) err = a.authClient.Auth2FA(context.Background(), twoFacAuth)
} }
if a.showLoginError(err, "auth2FA") { if a.showLoginError(err, "auth2FA") {

View File

@ -113,10 +113,3 @@ type Listener interface {
Add(string, chan<- string) Add(string, chan<- string)
RetryEmit(string) RetryEmit(string)
} }
func MakeAndRegisterEvent(eventListener Listener, event string) <-chan string {
ch := make(chan string)
eventListener.Add(event, ch)
eventListener.RetryEmit(event)
return ch
}

View File

@ -143,16 +143,16 @@ func (f *FrontendQt) NotifySilentUpdateError(err error) {
} }
func (f *FrontendQt) watchEvents() { func (f *FrontendQt) watchEvents() {
credentialsErrorCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.CredentialsErrorEvent) credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOffEvent) internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOnEvent) internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
secondInstanceCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.SecondInstanceEvent) secondInstanceCh := f.eventListener.ProvideChannel(events.SecondInstanceEvent)
restartBridgeCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.RestartBridgeEvent) restartBridgeCh := f.eventListener.ProvideChannel(events.RestartBridgeEvent)
addressChangedCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedEvent) addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedLogoutEvent) addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.LogoutEvent) logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
updateApplicationCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UpgradeApplicationEvent) updateApplicationCh := f.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
newUserCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UserRefreshEvent) newUserCh := f.eventListener.ProvideChannel(events.UserRefreshEvent)
for { for {
select { select {
case <-credentialsErrorCh: case <-credentialsErrorCh:
@ -351,11 +351,6 @@ func (f *FrontendQt) sendBug(description, emailClient, address string) bool {
// } // }
//} //}
// checkInternet is almost idetical to bridge
func (f *FrontendQt) checkInternet() {
f.Qml.SetConnectionStatus(f.ie.CheckConnection() == nil)
}
func (f *FrontendQt) showError(code int, err error) { func (f *FrontendQt) showError(code int, err error) {
f.Qml.SetErrorDescription(err.Error()) f.Qml.SetErrorDescription(err.Error())
log.WithField("code", code).Errorln(err.Error()) log.WithField("code", code).Errorln(err.Error())

View File

@ -78,7 +78,6 @@ type GoQMLInterface struct {
_ string `property:"versionCheckFailed"` _ string `property:"versionCheckFailed"`
// //
_ func(isAvailable bool) `signal:"setConnectionStatus"` _ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"checkInternet"`
_ func() `slot:"setToRestart"` _ func() `slot:"setToRestart"`
@ -189,8 +188,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
return f.programVersion return f.programVersion
}) })
s.ConnectCheckInternet(f.checkInternet)
s.ConnectSetToRestart(f.restarter.SetToRestart) s.ConnectSetToRestart(f.restarter.SetToRestart)
s.ConnectLoadStructureForExport(f.LoadStructureForExport) s.ConnectLoadStructureForExport(f.LoadStructureForExport)

View File

@ -20,6 +20,7 @@
package qt package qt
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -173,7 +174,7 @@ func (s *FrontendQt) auth2FA(twoFacAuth string) int {
if s.auth == nil || s.authClient == nil { if s.auth == nil || s.authClient == nil {
err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient) err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient)
} else { } else {
err = s.authClient.Auth2FA(twoFacAuth, s.auth) err = s.authClient.Auth2FA(context.Background(), twoFacAuth)
} }
if s.showLoginError(err, "auth2FA") { if s.showLoginError(err, "auth2FA") {

View File

@ -191,20 +191,20 @@ func (s *FrontendQt) NotifySilentUpdateError(err error) {
func (s *FrontendQt) watchEvents() { func (s *FrontendQt) watchEvents() {
s.WaitUntilFrontendIsReady() s.WaitUntilFrontendIsReady()
errorCh := s.getEventChannel(events.ErrorEvent) errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := s.getEventChannel(events.CredentialsErrorEvent) credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
outgoingNoEncCh := s.getEventChannel(events.OutgoingNoEncEvent) outgoingNoEncCh := s.eventListener.ProvideChannel(events.OutgoingNoEncEvent)
noActiveKeyForRecipientCh := s.getEventChannel(events.NoActiveKeyForRecipientEvent) noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
internetOffCh := s.getEventChannel(events.InternetOffEvent) internetOffCh := s.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := s.getEventChannel(events.InternetOnEvent) internetOnCh := s.eventListener.ProvideChannel(events.InternetOnEvent)
secondInstanceCh := s.getEventChannel(events.SecondInstanceEvent) secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
restartBridgeCh := s.getEventChannel(events.RestartBridgeEvent) restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
addressChangedCh := s.getEventChannel(events.AddressChangedEvent) addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := s.getEventChannel(events.AddressChangedLogoutEvent) addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := s.getEventChannel(events.LogoutEvent) logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
updateApplicationCh := s.getEventChannel(events.UpgradeApplicationEvent) updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
newUserCh := s.getEventChannel(events.UserRefreshEvent) newUserCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
certIssue := s.getEventChannel(events.TLSCertIssue) certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
for { for {
select { select {
case errorDetails := <-errorCh: case errorDetails := <-errorCh:
@ -254,13 +254,6 @@ func (s *FrontendQt) watchEvents() {
} }
} }
func (s *FrontendQt) getEventChannel(event string) <-chan string {
ch := make(chan string)
s.eventListener.Add(event, ch)
s.eventListener.RetryEmit(event)
return ch
}
// Loop function for tests. // Loop function for tests.
// //
// It runs QtExecute in new thread with function returning itself after setup. // It runs QtExecute in new thread with function returning itself after setup.
@ -653,10 +646,6 @@ func (s *FrontendQt) isSMTPSTARTTLS() bool {
return !s.settings.GetBool(settings.SMTPSSLKey) return !s.settings.GetBool(settings.SMTPSSLKey)
} }
func (s *FrontendQt) checkInternet() {
s.Qml.SetConnectionStatus(s.bridge.CheckConnection() == nil)
}
func (s *FrontendQt) switchAddressModeUser(iAccount int) { func (s *FrontendQt) switchAddressModeUser(iAccount int) {
defer s.Qml.ProcessFinished() defer s.Qml.ProcessFinished()
userID := s.Accounts.get(iAccount).UserID() userID := s.Accounts.get(iAccount).UserID()

View File

@ -84,7 +84,6 @@ type GoQMLInterface struct {
_ string `property:"progressDescription"` _ string `property:"progressDescription"`
_ func(isAvailable bool) `signal:"setConnectionStatus"` _ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"checkInternet"`
_ func() `slot:"setToRestart"` _ func() `slot:"setToRestart"`
@ -205,8 +204,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
return f.programVer return f.programVer
}) })
s.ConnectCheckInternet(f.checkInternet)
s.ConnectSetToRestart(f.restarter.SetToRestart) s.ConnectSetToRestart(f.restarter.SetToRestart)
s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc) s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc)

View File

@ -55,7 +55,6 @@ type UserManager interface {
GetUser(query string) (User, error) GetUser(query string) (User, error)
DeleteUser(userID string, clearCache bool) error DeleteUser(userID string, clearCache bool) error
ClearData() error ClearData() error
CheckConnection() error
} }
// User is an interface of user needed by frontend. // User is an interface of user needed by frontend.

View File

@ -38,11 +38,10 @@ type bridgeUser interface {
IsCombinedAddressMode() bool IsCombinedAddressMode() bool
GetAddressID(address string) (string, error) GetAddressID(address string) (string, error)
GetPrimaryAddress() string GetPrimaryAddress() string
UpdateUser() error
Logout() error Logout() error
CloseConnection(address string) CloseConnection(address string)
GetStore() storeUserProvider GetStore() storeUserProvider
GetTemporaryPMAPIClient() pmapi.Client GetClient() pmapi.Client
} }
type bridgeWrap struct { type bridgeWrap struct {

View File

@ -422,7 +422,7 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
if isStringInList(m.LabelIDs, pmapi.StarredLabel) { if isStringInList(m.LabelIDs, pmapi.StarredLabel) {
messageFlagsMap[imap.FlaggedFlag] = true messageFlagsMap[imap.FlaggedFlag] = true
} }
if m.Unread == 0 { if !m.Unread {
messageFlagsMap[imap.SeenFlag] = true messageFlagsMap[imap.SeenFlag] = true
} }
if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) { if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) {
@ -560,7 +560,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil, err return nil, err
} }
if storeMessage.Message().Unread == 1 { if storeMessage.Message().Unread {
for section := range msg.Body { for section := range msg.Body {
// Peek means get messages without marking them as read. // Peek means get messages without marking them as read.
// If client does not only ask for peek, we have to mark them as read. // If client does not only ask for peek, we have to mark them as read.

View File

@ -93,7 +93,7 @@ func newIMAPUser(
// This method should eventually no longer be necessary. Everything should go via store. // This method should eventually no longer be necessary. Everything should go via store.
func (iu *imapUser) client() pmapi.Client { func (iu *imapUser) client() pmapi.Client {
return iu.user.GetTemporaryPMAPIClient() return iu.user.GetClient()
} }
func (iu *imapUser) isSubscribed(labelID string) bool { func (iu *imapUser) isSubscribed(labelID string) bool {

View File

@ -22,6 +22,7 @@ import (
"bytes" "bytes"
"context" "context"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/transfer" "github.com/ProtonMail/proton-bridge/internal/transfer"
"github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
@ -40,6 +41,7 @@ type ImportExport struct {
locations Locator locations Locator
cache Cacher cache Cacher
panicHandler users.PanicHandler panicHandler users.PanicHandler
eventListener listener.Listener
clientManager pmapi.Manager clientManager pmapi.Manager
} }
@ -59,13 +61,14 @@ func New(
locations: locations, locations: locations,
cache: cache, cache: cache,
panicHandler: panicHandler, panicHandler: panicHandler,
eventListener: eventListener,
clientManager: clientManager, clientManager: clientManager,
} }
} }
// ReportBug reports a new bug from the user. // ReportBug reports a new bug from the user.
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error { func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{ return ie.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
OS: osType, OS: osType,
OSVersion: osVersion, OSVersion: osVersion,
Browser: emailClient, Browser: emailClient,
@ -89,7 +92,7 @@ func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address strin
report.AddAttachment("log", "report.log", bytes.NewReader(logdata)) report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
return ie.clientManager.ReportBug(context.TODO(), report) return ie.clientManager.ReportBug(context.Background(), report)
} }
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account. // GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
@ -162,5 +165,23 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
log.WithError(err).Info("Address does not exist, using all addresses") log.WithError(err).Info("Address does not exist, using all addresses")
} }
return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID) provider, err := transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
if err != nil {
return nil, err
}
go func() {
internetOffCh := ie.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := ie.eventListener.ProvideChannel(events.InternetOnEvent)
for {
select {
case <-internetOffCh:
provider.SetConnectionDown()
case <-internetOnCh:
provider.SetConnectionUp()
}
}
}()
return provider, nil
} }

View File

@ -31,7 +31,7 @@ type bridgeUser interface {
CheckBridgeLogin(password string) error CheckBridgeLogin(password string) error
IsCombinedAddressMode() bool IsCombinedAddressMode() bool
GetAddressID(address string) (string, error) GetAddressID(address string) (string, error)
GetTemporaryPMAPIClient() pmapi.Client GetClient() pmapi.Client
GetStore() storeUserProvider GetStore() storeUserProvider
} }

View File

@ -81,7 +81,7 @@ func newSMTPUser(
// This method should eventually no longer be necessary. Everything should go via store. // This method should eventually no longer be necessary. Everything should go via store.
func (su *smtpUser) client() pmapi.Client { func (su *smtpUser) client() pmapi.Client {
return su.user.GetTemporaryPMAPIClient() return su.user.GetClient()
} }
// Send sends an email from the given address to the given addresses with the given body. // Send sends an email from the given address to the given addresses with the given body.

View File

@ -90,7 +90,7 @@ func getLabelPrefix(l *pmapi.Label) string {
switch { switch {
case pmapi.IsSystemLabel(l.ID): case pmapi.IsSystemLabel(l.ID):
return "" return ""
case l.Exclusive == 1: case bool(l.Exclusive):
return UserFoldersPrefix return UserFoldersPrefix
default: default:
return UserLabelsPrefix return UserLabelsPrefix

View File

@ -37,8 +37,8 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
m.store.SetChangeNotifier(m.changeNotifier) m.store.SetChangeNotifier(m.changeNotifier)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
} }
func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) { func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
@ -52,8 +52,8 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
m.store.SetChangeNotifier(m.changeNotifier) m.store.SetChangeNotifier(m.changeNotifier)
msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) msg2 := getTestMessage("msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2})) require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
} }
@ -63,8 +63,8 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2)) m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2))
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1)) m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1))

View File

@ -81,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
func (loop *eventLoop) setFirstEventID() (err error) { func (loop *eventLoop) setFirstEventID() (err error) {
loop.log.Info("Setting first event ID") loop.log.Info("Setting first event ID")
event, err := loop.client().GetEvent(context.TODO(), "") event, err := loop.client().GetEvent(context.Background(), "")
if err != nil { if err != nil {
loop.log.WithError(err).Error("Could not get latest event ID") loop.log.WithError(err).Error("Could not get latest event ID")
return return
@ -222,8 +222,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually // We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
// (e.g. no internet, ulimit reached etc.) // (e.g. no internet, ulimit reached etc.)
defer func() { defer func() {
// FIXME(conman): How to handle errors of different types? if errors.Cause(err) == pmapi.ErrNoConnection {
if errors.Is(err, pmapi.ErrNoConnection) {
l.Warn("Internet unavailable") l.Warn("Internet unavailable")
err = nil err = nil
} }
@ -234,20 +233,17 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
err = nil err = nil
} }
// FIXME(conman): Handle force upgrade.
/*
if errors.Cause(err) == pmapi.ErrUpgradeApplication { if errors.Cause(err) == pmapi.ErrUpgradeApplication {
l.Warn("Need to upgrade application") l.Warn("Need to upgrade application")
err = nil err = nil
} }
*/
if err == nil { if err == nil {
loop.errCounter = 0 loop.errCounter = 0
} }
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored. // All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
if !errors.Is(err, pmapi.ErrUnauthorized) { if err != nil && errors.Cause(err) != pmapi.ErrUnauthorized {
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped") l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
loop.errCounter++ loop.errCounter++
if loop.errCounter == errMaxSentry { if loop.errCounter == errMaxSentry {
@ -268,7 +264,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
loop.pollCounter++ loop.pollCounter++
var event *pmapi.Event var event *pmapi.Event
if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil { if event, err = loop.client().GetEvent(context.Background(), loop.currentEventID); err != nil {
return false, errors.Wrap(err, "failed to get event") return false, errors.Wrap(err, "failed to get event")
} }
@ -295,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
} }
} }
return event.More == 1, err return bool(event.More), err
} }
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) { func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
@ -354,7 +350,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
// Get old addresses for comparisons before updating user. // Get old addresses for comparisons before updating user.
oldList := loop.client().Addresses() oldList := loop.client().Addresses()
if err = loop.user.UpdateUser(); err != nil { if err = loop.user.UpdateUser(context.Background()); err != nil {
if logoutErr := loop.user.Logout(); logoutErr != nil { if logoutErr := loop.user.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Failed to logout user after failed update") log.WithError(logoutErr).Error("Failed to logout user after failed update")
} }
@ -465,16 +461,12 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...") msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil { if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil {
// FIXME(conman): How to handle error of this particular type? if _, ok := err.(pmapi.ErrUnprocessableEntity); ok {
/*
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API") msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil err = nil
continue continue
} }
*/
return errors.Wrap(err, "failed to get message from API for updating") return errors.Wrap(err, "failed to get message from API for updating")
} }

View File

@ -42,15 +42,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
// next event if there is `More` of them. // next event if there is `More` of them.
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{ m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
EventID: "event50", EventID: "event50",
More: 1, More: true,
}, nil), }, nil),
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{ m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
EventID: "event70", EventID: "event70",
More: 0, More: false,
}, nil), }, nil),
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{ m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
EventID: "event71", EventID: "event71",
More: 0, More: false,
}, nil), }, nil),
) )
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
@ -188,7 +188,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
msg := &pmapi.Message{ msg := &pmapi.Message{
ID: "msg1", ID: "msg1",
Subject: "old", Subject: "old",
Unread: 0, Unread: false,
Flags: 10, Flags: 10,
Sender: address1, Sender: address1,
ToList: []*mail.Address{address2}, ToList: []*mail.Address{address2},
@ -200,7 +200,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
newMsg := &pmapi.Message{ newMsg := &pmapi.Message{
ID: "msg1", ID: "msg1",
Subject: "new", Subject: "new",
Unread: 1, Unread: true,
Flags: 11, Flags: 11,
Sender: address2, Sender: address2,
ToList: []*mail.Address{address1}, ToList: []*mail.Address{address1},

View File

@ -129,17 +129,10 @@ func (mc *mailboxCounts) getPMLabel() *pmapi.Label {
Color: mc.Color, Color: mc.Color,
Order: mc.Order, Order: mc.Order,
Type: pmapi.LabelTypeMailbox, Type: pmapi.LabelTypeMailbox,
Exclusive: mc.isExclusive(), Exclusive: pmapi.Boolean(mc.IsFolder),
} }
} }
func (mc *mailboxCounts) isExclusive() int {
if mc.IsFolder {
return 1
}
return 0
}
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts. // createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error { func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
// Don't forget about system folders. // Don't forget about system folders.
@ -162,7 +155,7 @@ func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) er
mailbox.LabelName = label.Path mailbox.LabelName = label.Path
mailbox.Color = label.Color mailbox.Color = label.Color
mailbox.Order = label.Order mailbox.Order = label.Order
mailbox.IsFolder = label.Exclusive == 1 mailbox.IsFolder = bool(label.Exclusive)
// Write. // Write.
if err = mailbox.txWriteToBucket(countsBkt); err != nil { if err = mailbox.txWriteToBucket(countsBkt); err != nil {

View File

@ -75,7 +75,7 @@ func TestMailboxNames(t *testing.T) {
newLabel(100, "labelID1", "Label1"), newLabel(100, "labelID1", "Label1"),
newLabel(1000, "folderID1", "Folder1"), newLabel(1000, "folderID1", "Folder1"),
} }
foldersAndLabels[1].Exclusive = 1 foldersAndLabels[1].Exclusive = true
for _, counts := range getSystemFolders() { for _, counts := range getSystemFolders() {
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel()) foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())

View File

@ -37,10 +37,10 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{pmapi.AllMailLabel})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"}) checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
@ -82,20 +82,20 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " externalID-non-pm-com " tstMsg.ExternalID = " externalID-non-pm-com "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = "<externalID@pm.me>" tstMsg.ExternalID = "<externalID@pm.me>"
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}} tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
// Not sure if this is a real-world scenario but we should be able to address this properly. // Not sure if this is a real-world scenario but we should be able to address this properly.
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > " tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))

View File

@ -18,8 +18,6 @@
package store package store
import ( import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -43,7 +41,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message // FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
// wrapping it. // wrapping it.
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) { func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID) msg, err := storeMailbox.client().GetMessage(exposeContextForIMAP(), apiID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +68,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
Message: body, Message: body,
} }
res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs}) res, err := storeMailbox.client().Import(exposeContextForIMAP(), pmapi.ImportMsgReqs{importReqs})
if err != nil { if err != nil {
return err return err
} }
@ -99,7 +97,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed return ErrAllMailOpNotAllowed
} }
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID) return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
} }
// UnlabelMessages removes the label by calling an API. // UnlabelMessages removes the label by calling an API.
@ -112,7 +110,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed return ErrAllMailOpNotAllowed
} }
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID) return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
} }
// MarkMessagesRead marks the message read by calling an API. // MarkMessagesRead marks the message read by calling an API.
@ -132,14 +130,14 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
// Therefore we do not issue API update if the message is already read. // Therefore we do not issue API update if the message is already read.
ids := []string{} ids := []string{}
for _, apiID := range apiIDs { for _, apiID := range apiIDs {
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 { if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread {
ids = append(ids, apiID) ids = append(ids, apiID)
} }
} }
if len(ids) == 0 { if len(ids) == 0 {
return nil return nil
} }
return storeMailbox.client().MarkMessagesRead(context.TODO(), ids) return storeMailbox.client().MarkMessagesRead(exposeContextForIMAP(), ids)
} }
// MarkMessagesUnread marks the message unread by calling an API. // MarkMessagesUnread marks the message unread by calling an API.
@ -151,7 +149,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as unread") }).Trace("Marking messages as unread")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs) return storeMailbox.client().MarkMessagesUnread(exposeContextForIMAP(), apiIDs)
} }
// MarkMessagesStarred adds the Starred label by calling an API. // MarkMessagesStarred adds the Starred label by calling an API.
@ -164,7 +162,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as starred") }).Trace("Marking messages as starred")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel) return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
} }
// MarkMessagesUnstarred removes the Starred label by calling an API. // MarkMessagesUnstarred removes the Starred label by calling an API.
@ -177,7 +175,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
"mailbox": storeMailbox.Name, "mailbox": storeMailbox.Name,
}).Trace("Marking messages as unstarred") }).Trace("Marking messages as unstarred")
defer storeMailbox.pollNow() defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel) return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
} }
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API // MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
@ -261,11 +259,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
} }
case pmapi.DraftLabel: case pmapi.DraftLabel:
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts") storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil { if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), apiIDs); err != nil {
return err return err
} }
default: default:
if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil { if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID); err != nil {
return err return err
} }
} }
@ -303,13 +301,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
} }
} }
if len(messageIDsToUnlabel) > 0 { if len(messageIDsToUnlabel) > 0 {
if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil { if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
l.WithError(err).Warning("Cannot unlabel before deleting") l.WithError(err).Warning("Cannot unlabel before deleting")
} }
} }
if len(messageIDsToDelete) > 0 { if len(messageIDsToDelete) > 0 {
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages") storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil { if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), messageIDsToDelete); err != nil {
return err return err
} }
} }

View File

@ -5,10 +5,10 @@
package mocks package mocks
import ( import (
reflect "reflect" context "context"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
) )
// MockPanicHandler is a mock of PanicHandler interface // MockPanicHandler is a mock of PanicHandler interface
@ -207,17 +207,17 @@ func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call {
} }
// UpdateUser mocks base method // UpdateUser mocks base method
func (m *MockBridgeUser) UpdateUser() error { func (m *MockBridgeUser) UpdateUser(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUser") ret := m.ctrl.Call(m, "UpdateUser", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// UpdateUser indicates an expected call of UpdateUser // UpdateUser indicates an expected call of UpdateUser
func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call { func (mr *MockBridgeUserMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser), arg0)
} }
// MockChangeNotifier is a mock of ChangeNotifier interface // MockChangeNotifier is a mock of ChangeNotifier interface

View File

@ -5,10 +5,9 @@
package mocks package mocks
import ( import (
gomock "github.com/golang/mock/gomock"
reflect "reflect" reflect "reflect"
time "time" time "time"
gomock "github.com/golang/mock/gomock"
) )
// MockListener is a mock of Listener interface // MockListener is a mock of Listener interface
@ -58,6 +57,20 @@ func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
} }
// ProvideChannel mocks base method
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
ret0, _ := ret[0].(<-chan string)
return ret0
}
// ProvideChannel indicates an expected call of ProvideChannel
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
}
// Remove mocks base method // Remove mocks base method
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) { func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -101,6 +101,18 @@ var (
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals] ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
) )
// exposeContextForIMAP should be replaced once with context passed
// as an argument from IMAP package and IMAP library should cancel
// context when IMAP client cancels the request.
func exposeContextForIMAP() context.Context {
return context.TODO()
}
// exposeContextForSMTP is the same as above but for SMTP.
func exposeContextForSMTP() context.Context {
return context.TODO()
}
// Store is local user storage, which handles the synchronization between IMAP and PM API. // Store is local user storage, which handles the synchronization between IMAP and PM API.
type Store struct { type Store struct {
sentryReporter *sentry.Reporter sentryReporter *sentry.Reporter
@ -278,7 +290,7 @@ func (store *Store) client() pmapi.Client {
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if // initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
// the API is unavailable for whatever reason it tries to fetch the labels locally. // the API is unavailable for whatever reason it tries to fetch the labels locally.
func (store *Store) initCounts() (labels []*pmapi.Label, err error) { func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
if labels, err = store.client().ListLabels(context.TODO()); err != nil { if labels, err = store.client().ListLabels(context.Background()); err != nil {
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.") store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
if labels, err = store.getLabelsFromLocalStorage(); err != nil { if labels, err = store.getLabelsFromLocalStorage(); err != nil {
store.log.WithError(err).Error("Cannot list local labels") store.log.WithError(err).Error("Cannot list local labels")

View File

@ -184,8 +184,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client) mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{ mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive}, {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive}, {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
}) })
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes() mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
mocks.client.EXPECT().CountMessages(gomock.Any(), "") mocks.client.EXPECT().CountMessages(gomock.Any(), "")

View File

@ -148,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
Limit: 1, Limit: 1,
} }
// If the page does not exist, an empty page instead of an error is returned. // If the page does not exist, an empty page instead of an error is returned.
messages, total, err := api.ListMessages(context.TODO(), filter) messages, total, err := api.ListMessages(context.Background(), filter)
if err != nil { if err != nil {
return "", 0, errors.Wrap(err, "failed to list messages") return "", 0, errors.Wrap(err, "failed to list messages")
} }
@ -190,7 +190,7 @@ func syncBatch( //nolint[funlen]
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page") log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
messages, _, err := api.ListMessages(context.TODO(), filter) messages, _, err := api.ListMessages(context.Background(), filter)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to list messages") return errors.Wrap(err, "failed to list messages")
} }

View File

@ -17,7 +17,11 @@
package store package store
import "github.com/ProtonMail/proton-bridge/pkg/pmapi" import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type PanicHandler interface { type PanicHandler interface {
HandlePanic() HandlePanic()
@ -32,7 +36,7 @@ type BridgeUser interface {
GetPrimaryAddress() string GetPrimaryAddress() string
GetStoreAddresses() []string GetStoreAddresses() []string
GetClient() pmapi.Client GetClient() pmapi.Client
UpdateUser() error UpdateUser(context.Context) error
CloseAllConnections() CloseAllConnections()
CloseConnection(string) CloseConnection(string)
Logout() error Logout() error

View File

@ -17,8 +17,6 @@
package store package store
import "context"
// UserID returns user ID. // UserID returns user ID.
func (store *Store) UserID() string { func (store *Store) UserID() string {
return store.user.ID() return store.user.ID()
@ -26,7 +24,7 @@ func (store *Store) UserID() string {
// GetSpace returns used and total space in bytes. // GetSpace returns used and total space in bytes.
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) { func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
apiUser, err := store.client().CurrentUser(context.TODO()) apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
@ -35,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
// GetMaxUpload returns max size of message + all attachments in bytes. // GetMaxUpload returns max size of message + all attachments in bytes.
func (store *Store) GetMaxUpload() (int64, error) { func (store *Store) GetMaxUpload() (int64, error) {
apiUser, err := store.client().CurrentUser(context.TODO()) apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -147,7 +147,7 @@ func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (er
// filterAddresses filters out inactive addresses and ensures the original address is listed first. // filterAddresses filters out inactive addresses and ensures the original address is listed first.
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) { func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
for _, address := range addressList { for _, address := range addressList {
if address.Receive != pmapi.CanReceive { if !address.Receive {
continue continue
} }

View File

@ -18,7 +18,6 @@
package store package store
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -39,14 +38,14 @@ func (store *Store) createMailbox(name string) error {
color := store.leastUsedColor() color := store.leastUsedColor()
var exclusive int var exclusive bool
switch { switch {
case strings.HasPrefix(name, UserLabelsPrefix): case strings.HasPrefix(name, UserLabelsPrefix):
name = strings.TrimPrefix(name, UserLabelsPrefix) name = strings.TrimPrefix(name, UserLabelsPrefix)
exclusive = 0 exclusive = false
case strings.HasPrefix(name, UserFoldersPrefix): case strings.HasPrefix(name, UserFoldersPrefix):
name = strings.TrimPrefix(name, UserFoldersPrefix) name = strings.TrimPrefix(name, UserFoldersPrefix)
exclusive = 1 exclusive = true
default: default:
// Ideally we would throw an error here, but then Outlook for // Ideally we would throw an error here, but then Outlook for
// macOS keeps trying to make an IMAP Drafts folder and popping // macOS keeps trying to make an IMAP Drafts folder and popping
@ -56,10 +55,10 @@ func (store *Store) createMailbox(name string) error {
return nil return nil
} }
_, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{ _, err := store.client().CreateLabel(exposeContextForIMAP(), &pmapi.Label{
Name: name, Name: name,
Color: color, Color: color,
Exclusive: exclusive, Exclusive: pmapi.Boolean(exclusive),
Type: pmapi.LabelTypeMailbox, Type: pmapi.LabelTypeMailbox,
}) })
return err return err
@ -126,7 +125,7 @@ func (store *Store) leastUsedColor() string {
func (store *Store) updateMailbox(labelID, newName, color string) error { func (store *Store) updateMailbox(labelID, newName, color string) error {
defer store.eventLoop.pollNow() defer store.eventLoop.pollNow()
_, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{ _, err := store.client().UpdateLabel(exposeContextForIMAP(), &pmapi.Label{
ID: labelID, ID: labelID,
Name: newName, Name: newName,
Color: color, Color: color,
@ -143,15 +142,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
var err error var err error
switch labelID { switch labelID {
case pmapi.SpamLabel: case pmapi.SpamLabel:
err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID) err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.SpamLabel, addressID)
case pmapi.TrashLabel: case pmapi.TrashLabel:
err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID) err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.TrashLabel, addressID)
default: default:
err = fmt.Errorf("cannot empty mailbox %v", labelID) err = fmt.Errorf("cannot empty mailbox %v", labelID)
} }
return err return err
} }
return store.client().DeleteLabel(context.TODO(), labelID) return store.client().DeleteLabel(exposeContextForIMAP(), labelID)
} }
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error { func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
@ -166,7 +165,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
return nil return nil
} }
labels, err := store.client().ListLabels(context.TODO()) labels, err := store.client().ListLabels(exposeContextForIMAP())
if err != nil { if err != nil {
return err return err
} }

View File

@ -19,7 +19,6 @@ package store
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil" "io/ioutil"
@ -58,7 +57,7 @@ func (store *Store) CreateDraft(
} }
draftAction := store.getDraftAction(message) draftAction := store.getDraftAction(message)
draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction) draft, err := store.client().CreateDraft(exposeContextForSMTP(), message, parentID, draftAction)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "failed to create draft") return nil, nil, errors.Wrap(err, "failed to create draft")
} }
@ -70,7 +69,7 @@ func (store *Store) CreateDraft(
for _, att := range attachments { for _, att := range attachments {
att.attachment.MessageID = draft.ID att.attachment.MessageID = draft.ID
createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader) createdAttachment, err := store.client().CreateAttachment(exposeContextForSMTP(), att.attachment, att.encReader, att.sigReader)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "failed to create attachment") return nil, nil, errors.Wrap(err, "failed to create attachment")
} }
@ -184,7 +183,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
// SendMessage sends the message. // SendMessage sends the message.
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error { func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
defer store.eventLoop.pollNow() defer store.eventLoop.pollNow()
_, _, err := store.client().SendMessage(context.TODO(), messageID, req) _, _, err := store.client().SendMessage(exposeContextForSMTP(), messageID, req)
return err return err
} }

View File

@ -24,6 +24,7 @@ import (
"testing" "testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert" a "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -34,10 +35,10 @@ func TestGetAllMessageIDs(t *testing.T) {
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{}) insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"}) checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
} }
@ -47,7 +48,7 @@ func TestGetMessageFromDB(t *testing.T) {
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{ tests := []struct{ msgID, wantErr string }{
{"msg1", ""}, {"msg1", ""},
@ -72,7 +73,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1") msg, err := m.store.getMessageFromDB("msg1")
require.Nil(t, err) require.Nil(t, err)
@ -104,7 +105,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
a.Equal(t, wantHeader, msg.Header) a.Equal(t, wantHeader, msg.Header)
// Check calculated data are not overridden by reinsert. // Check calculated data are not overridden by reinsert.
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err = m.store.getMessageFromDB("msg1") msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err) require.Nil(t, err)
@ -118,8 +119,8 @@ func TestDeleteMessage(t *testing.T) {
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.deleteMessageEvent("msg1")) require.Nil(t, m.store.deleteMessageEvent("msg1"))
@ -127,17 +128,17 @@ func TestDeleteMessage(t *testing.T) {
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}}) checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
} }
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam] func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs) msg := getTestMessage(id, subject, sender, unread, labelIDs)
require.Nil(t, m.store.createOrUpdateMessageEvent(msg)) require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
} }
func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message { func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
address := &mail.Address{Address: sender} address := &mail.Address{Address: sender}
return &pmapi.Message{ return &pmapi.Message{
ID: id, ID: id,
Subject: subject, Subject: subject,
Unread: unread, Unread: pmapi.Boolean(unread),
Sender: address, Sender: address,
ToList: []*mail.Address{address}, ToList: []*mail.Address{address},
LabelIDs: labelIDs, LabelIDs: labelIDs,
@ -162,7 +163,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
defer clear() defer clear()
m.newStoreNoEvents(false) m.newStoreNoEvents(false)
m.client.EXPECT().CurrentUser().Return(&pmapi.User{ m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+. MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
}, nil) }, nil)
@ -181,7 +182,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
defer clear() defer clear()
m.newStoreNoEvents(false) m.newStoreNoEvents(false)
m.client.EXPECT().CurrentUser().Return(&pmapi.User{ m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+. MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
}, nil) }, nil)

View File

@ -35,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
// updateCountsFromServer will download and set the counts. // updateCountsFromServer will download and set the counts.
func (store *Store) updateCountsFromServer() error { func (store *Store) updateCountsFromServer() error {
counts, err := store.client().CountMessages(context.TODO(), "") counts, err := store.client().CountMessages(context.Background(), "")
if err != nil { if err != nil {
return errors.Wrap(err, "cannot update counts from server") return errors.Wrap(err, "cannot update counts from server")
} }

View File

@ -31,8 +31,8 @@ func TestLoadSaveSyncState(t *testing.T) {
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
// Clear everything. // Clear everything.

View File

@ -5,11 +5,10 @@
package mocks package mocks
import ( import (
reflect "reflect"
imap "github.com/emersion/go-imap" imap "github.com/emersion/go-imap"
sasl "github.com/emersion/go-sasl" sasl "github.com/emersion/go-sasl"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
) )
// MockPanicHandler is a mock of PanicHandler interface // MockPanicHandler is a mock of PanicHandler interface

View File

@ -19,7 +19,9 @@ package transfer
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"net/http"
"strings" "strings"
"time" "time"
@ -37,6 +39,8 @@ const (
imapRetries = 10 imapRetries = 10
imapReconnectTimeout = 30 * time.Minute imapReconnectTimeout = 30 * time.Minute
imapReconnectSleep = time.Minute imapReconnectSleep = time.Minute
protonStatusURL = "http://protonstatus.com/vpn_status"
) )
type imapErrorLogger struct { type imapErrorLogger struct {
@ -117,19 +121,15 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
return previousErr return previousErr
} }
// FIXME(conman): This should register as connection observer. err := checkConnection()
/*
err := pmapi.CheckConnection()
log.WithError(err).Debug("Connection check") log.WithError(err).Debug("Connection check")
if err != nil { if err != nil {
time.Sleep(imapReconnectSleep) time.Sleep(imapReconnectSleep)
previousErr = err previousErr = err
continue continue
} }
*/
err := p.reauth() err = p.reauth()
log.WithError(err).Debug("Reauth") log.WithError(err).Debug("Reauth")
if err != nil { if err != nil {
time.Sleep(imapReconnectSleep) time.Sleep(imapReconnectSleep)
@ -289,3 +289,23 @@ func (p *IMAPProvider) fetchHelper(uid bool, ensureSelectedIn string, seqSet *im
return err return err
}, ensureSelectedIn) }, ensureSelectedIn)
} }
// checkConnection returns an error if there is no internet connection.
// Note we don't want to use client manager because it only reports connection
// issues with API; we are only interested here whether we can reach
// third-party IMAP servers.
func checkConnection() error {
client := &http.Client{Timeout: time.Second * 10}
resp, err := client.Get(protonStatusURL)
if err != nil {
return err
}
_ = resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("HTTP status code %d", resp.StatusCode)
}
return nil
}

View File

@ -52,15 +52,10 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
return Mailbox{}, errors.New("mailbox is already created") return Mailbox{}, errors.New("mailbox is already created")
} }
exclusive := 0 label, err := p.client.CreateLabel(context.Background(), &pmapi.Label{
if mailbox.IsExclusive {
exclusive = 1
}
label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{
Name: mailbox.Name, Name: mailbox.Name,
Color: mailbox.Color, Color: mailbox.Color,
Exclusive: exclusive, Exclusive: pmapi.Boolean(mailbox.IsExclusive),
Type: pmapi.LabelTypeMailbox, Type: pmapi.LabelTypeMailbox,
}) })
if err != nil { if err != nil {
@ -126,7 +121,7 @@ func (p *PMAPIProvider) importDraft(msg Message, globalMailbox *Mailbox) (string
} }
if message.Sender == nil { if message.Sender == nil {
mainAddress := p.client().Addresses().Main() mainAddress := p.client.Addresses().Main()
message.Sender = &mail.Address{ message.Sender = &mail.Address{
Name: mainAddress.DisplayName, Name: mainAddress.DisplayName,
Address: mainAddress.Email, Address: mainAddress.Email,
@ -227,14 +222,6 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
} }
} }
var unread pmapi.Boolean
if msg.Unread {
unread = pmapi.True
} else {
unread = pmapi.False
}
labelIDs := []string{} labelIDs := []string{}
for _, target := range msg.Targets { for _, target := range msg.Targets {
// Frontend should not set All Mail to Rules, but to be sure... // Frontend should not set All Mail to Rules, but to be sure...
@ -249,7 +236,7 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
return &pmapi.ImportMsgReq{ return &pmapi.ImportMsgReq{
Metadata: &pmapi.ImportMetadata{ Metadata: &pmapi.ImportMetadata{
AddressID: p.addressID, AddressID: p.addressID,
Unread: unread, Unread: pmapi.Boolean(msg.Unread),
Time: message.Time, Time: message.Time,
Flags: computeMessageFlags(message.Header), Flags: computeMessageFlags(message.Header),
LabelIDs: labelIDs, LabelIDs: labelIDs,

View File

@ -153,10 +153,10 @@ func setupPMAPIRules(rules transferRules) {
func setupPMAPIClientExpectationForExport(m *mocks) { func setupPMAPIClientExpectationForExport(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes() m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{ m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2}, {ID: "label1", Name: "Foo", Color: "blue", Exclusive: false, Order: 2},
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1}, {ID: "label2", Name: "Bar", Color: "green", Exclusive: false, Order: 1},
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1}, {ID: "folder1", Name: "One", Color: "red", Exclusive: true, Order: 1},
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2}, {ID: "folder2", Name: "Two", Color: "orange", Exclusive: true, Order: 2},
}, nil).AnyTimes() }, nil).AnyTimes()
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{ m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
{LabelID: "label1", Total: 10}, {LabelID: "label1", Total: 10},

View File

@ -30,9 +30,17 @@ import (
const ( const (
pmapiRetries = 10 pmapiRetries = 10
pmapiReconnectTimeout = 30 * time.Minute pmapiReconnectTimeout = 30 * time.Minute
pmapiReconnectSleep = time.Minute pmapiReconnectSleep = 10 * time.Second
) )
func (p *PMAPIProvider) SetConnectionUp() {
p.connection = true
}
func (p *PMAPIProvider) SetConnectionDown() {
p.connection = false
}
func (p *PMAPIProvider) ensureConnection(callback func() error) error { func (p *PMAPIProvider) ensureConnection(callback func() error) error {
var callErr error var callErr error
for i := 1; i <= pmapiRetries; i++ { for i := 1; i <= pmapiRetries; i++ {
@ -58,18 +66,10 @@ func (p *PMAPIProvider) tryReconnect() error {
return previousErr return previousErr
} }
// FIXME(conman): This should register as a connection observer somehow... if !p.connection {
// Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection?
/*
err := p.clientManager.CheckConnection()
log.WithError(err).Debug("Connection check")
if err != nil {
time.Sleep(pmapiReconnectSleep) time.Sleep(pmapiReconnectSleep)
previousErr = err
continue continue
} }
*/
break break
} }
@ -83,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
p.timeIt.start("listing", key) p.timeIt.start("listing", key)
defer p.timeIt.stop("listing", key) defer p.timeIt.stop("listing", key)
messages, count, err = p.client.ListMessages(context.TODO(), filter) messages, count, err = p.client.ListMessages(context.Background(), filter)
return err return err
}) })
return return
@ -94,7 +94,7 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
p.timeIt.start("download", msgID) p.timeIt.start("download", msgID)
defer p.timeIt.stop("download", msgID) defer p.timeIt.stop("download", msgID)
message, err = p.client.GetMessage(context.TODO(), msgID) message, err = p.client.GetMessage(context.Background(), msgID)
return err return err
}) })
return return
@ -105,7 +105,7 @@ func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReq
p.timeIt.start("upload", msgSourceID) p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID) defer p.timeIt.stop("upload", msgSourceID)
res, err = p.client.Import(context.TODO(), req) res, err = p.client.Import(context.Background(), req)
return err return err
}) })
return return
@ -116,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
p.timeIt.start("upload", msgSourceID) p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID) defer p.timeIt.stop("upload", msgSourceID)
draft, err = p.client.CreateDraft(context.TODO(), message, parent, action) draft, err = p.client.CreateDraft(context.Background(), message, parent, action)
return err return err
}) })
return return
@ -129,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
p.timeIt.start("upload", key) p.timeIt.start("upload", key)
defer p.timeIt.stop("upload", key) defer p.timeIt.stop("upload", key)
created, err = p.client.CreateAttachment(context.TODO(), att, r, sig) created, err = p.client.CreateAttachment(context.Background(), att, r, sig)
return err return err
}) })
return return

View File

@ -28,6 +28,7 @@ import (
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -274,7 +275,7 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
wg.Wait() wg.Wait()
} }
func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater { func newTestUpdater(manager pmapi.Manager, curVer string, earlyAccess bool) *Updater {
return New( return New(
manager, manager,
&fakeInstaller{}, &fakeInstaller{},

View File

@ -1,251 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
)
assert.NoError(t, user.UpdateUser())
waitForEvents()
}
func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil),
)
assert.NoError(t, user.SwitchAddressMode())
assert.False(t, user.store.IsCombinedMode())
assert.False(t, user.creds.IsCombinedAddressMode)
assert.False(t, user.IsCombinedAddressMode())
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
assert.NoError(t, user.SwitchAddressMode())
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
}
func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
}
func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginOK(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginTwiceOK(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().IsUnlocked().Return(true),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
err = user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
isApplicationOutdated = true
err := user.CheckBridgeLogin("any-pass")
waitForEvents()
assert.Equal(t, pmapi.ErrUpgradeApplication, err)
isApplicationOutdated = false
}
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(t, err)
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
err = user.init()
assert.Error(t, err)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
waitForEvents()
assert.Equal(t, ErrLoggedOutUser, err)
}
func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin("wrong!")
waitForEvents()
assert.Equal(t, "backend/credentials: incorrect password", err.Error())
}

View File

@ -1,112 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert"
)
func TestNewUserNoCredentialsStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
a.Error(t, err)
}
func TestNewUserAuthRefreshFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
}
func TestNewUserUnlockFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(errors.New("bad password")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
}
func TestNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(testCredentials, m)
}
func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
defer cleanUpUserData(user)
_ = user.init()
waitForEvents()
a.Equal(m.t, creds, user.creds)
}
func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen]
a.Fail(t, "not implemented")
}

View File

@ -1,89 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(m.t, err)
err = user.init()
assert.NoError(m.t, err)
mockAuthUpdate(user, "reftok", m)
return user
}
func testNewUserForLogout(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(m.t, err)
err = user.init()
assert.NoError(m.t, err)
return user
}
func cleanUpUserData(u *User) {
_ = u.clearStore()
}
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
assert.Fail(t, "not implemented")
}
func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
require.Nil(t, user.store.Close())
user.store = nil
assert.Nil(t, user.clearStore())
}
func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
assert.NotNil(t, user.store)
assert.Nil(t, user.clearStore())
}

View File

@ -1,143 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestGetNoUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
}
func TestGetUserByID(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "user", 0, "")
checkUsersGetUser(t, m, "users", 1, "")
}
func TestGetUserByName(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "username", 0, "")
checkUsersGetUser(t, m, "usersname", 1, "")
}
func TestGetUserByEmail(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "user@pm.me", 0, "")
checkUsersGetUser(t, m, "users@pm.me", 1, "")
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
}
func TestDeleteUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
assert.NoError(t, err)
assert.Equal(t, 1, len(users.users))
}
// Even when logout fails, delete is done.
func TestDeleteUserWithFailingLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
assert.NoError(t, err)
assert.Equal(t, 1, len(users.users))
}
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.GetUser(query)
waitForEvents()
if expectedError != "" {
assert.Equal(m.t, expectedError, err.Error())
} else {
assert.NoError(m.t, err)
}
var expectedUser *User
if index >= 0 {
expectedUser = users.users[index]
}
assert.Equal(m.t, expectedUser, user)
}

View File

@ -1,219 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// Set up mocks for FinishLogin.
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked")),
m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout(),
)
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
}
func refreshWithToken(token string) *pmapi.Auth {
return &pmapi.Auth{
RefreshToken: token,
}
}
func credentialsWithToken(token string) *credentials.Credentials {
tmp := &credentials.Credentials{}
*tmp = *testCredentials
tmp.APIToken = token
return tmp
}
func TestUsersFinishLoginNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Basically every call client has get client manager
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// users.New() finds no users in keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// getAPIUser() loads user info from API (e.g. userID).
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// addNewUser()
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
// user.init() in addNewUser
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Emit event for new user and send metrics.
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)),
m.pmapiClient.EXPECT().Logout(),
// Reload account list in GUI.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
// defer logout anonymous
m.pmapiClient.EXPECT().Logout(),
)
mockEventLoopNoAction(m)
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
mockAuthUpdate(user, "afterCredentials", m)
}
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
loggedOutCreds := *testCredentials
loggedOutCreds.APIToken = ""
loggedOutCreds.MailboxPassword = ""
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// users.New() finds one existing user in keychain.
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
// newUser()
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// user.init()
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized),
m.pmapiClient.EXPECT().Addresses().Return(nil),
// getAPIUser() loads user info from API (e.g. userID).
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// connectExistingUser()
m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil),
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil),
// user.init() in connectExistingUser
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Reload account list in GUI.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
// defer logout anonymous
m.pmapiClient.EXPECT().Logout(),
)
mockEventLoopNoAction(m)
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
mockAuthUpdate(user, "afterCredentials", m)
}
func TestUsersFinishLoginConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
mockConnectedUser(m)
mockEventLoopNoAction(m)
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
// Then, try to log in again...
gomock.InOrder(
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout(),
)
_, err := users.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword)
assert.Equal(t, "user is already connected", err.Error())
}
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User {
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
waitForEvents()
assert.Equal(t, expectedErr, err)
if expectedUserID != "" {
assert.Equal(t, expectedUserID, user.ID())
assert.Equal(t, 1, len(users.users))
assert.Equal(t, expectedUserID, users.users[0].ID())
} else {
assert.Equal(t, (*User)(nil), user)
assert.Equal(t, 0, len(users.users))
}
return user
}

View File

@ -233,7 +233,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
_, secret, err := s.secrets.Get(userID) _, secret, err := s.secrets.Get(userID)
if err != nil { if err != nil {
log.WithError(err).Error("Could not get credentials from native keychain") log.WithError(err).Warn("Could not get credentials from native keychain")
return return
} }

View File

@ -26,8 +26,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" r "github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
) )
const testSep = "\n" const testSep = "\n"
@ -249,26 +248,26 @@ func TestMarshalFormats(t *testing.T) {
log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt)) log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt))
output := testCredentials{APIToken: "refresh"} output := testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalStrings(secretStrings)) r.NoError(t, output.UnmarshalStrings(secretStrings))
log.Infof("strings out %#v \n", output) log.Infof("strings out %#v \n", output)
require.True(t, input.IsSame(&output), "strings out not same") r.True(t, input.IsSame(&output), "strings out not same")
output = testCredentials{APIToken: "refresh"} output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalGob(secretGob)) r.NoError(t, output.UnmarshalGob(secretGob))
log.Infof("gob out %#v\n \n", output) log.Infof("gob out %#v\n \n", output)
assert.Equal(t, input, output) r.Equal(t, input, output)
output = testCredentials{APIToken: "refresh"} output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.FromJSON(secretJSON)) r.NoError(t, output.FromJSON(secretJSON))
log.Infof("json out %#v \n", output) log.Infof("json out %#v \n", output)
require.True(t, input.IsSame(&output), "json out not same") r.True(t, input.IsSame(&output), "json out not same")
/* /*
// Simple Fscanf not working! // Simple Fscanf not working!
output = testCredentials{APIToken: "refresh"} output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalFmt(secretFmt)) r.NoError(t, output.UnmarshalFmt(secretFmt))
log.Infof("fmt out %#v \n", output) log.Infof("fmt out %#v \n", output)
require.True(t, input.IsSame(&output), "fmt out not same") r.True(t, input.IsSame(&output), "fmt out not same")
*/ */
} }
@ -291,7 +290,7 @@ func TestMarshal(t *testing.T) {
log.Infof("secret %#v %d\n", secret, len(secret)) log.Infof("secret %#v %d\n", secret, len(secret))
output := Credentials{APIToken: "refresh"} output := Credentials{APIToken: "refresh"}
require.NoError(t, output.Unmarshal(secret)) r.NoError(t, output.Unmarshal(secret))
log.Infof("output %#v\n", output) log.Infof("output %#v\n", output)
assert.Equal(t, input, output) r.Equal(t, input, output)
} }

View File

@ -1,107 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./listener/listener.go
// Package users is a generated GoMock package.
package users
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(eventName string, limit time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", eventName, limit)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit)
}
// Add mocks base method
func (m *MockListener) Add(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", eventName, channel)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel)
}
// Remove mocks base method
func (m *MockListener) Remove(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", eventName, channel)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel)
}
// Emit mocks base method
func (m *MockListener) Emit(eventName, data string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", eventName, data)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", eventName)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", eventName)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName)
}

View File

@ -0,0 +1,120 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener)
// Package mocks is a generated GoMock package.
package mocks
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// Add mocks base method
func (m *MockListener) Add(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0, arg1)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1)
}
// Emit mocks base method
func (m *MockListener) Emit(arg0, arg1 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", arg0, arg1)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
}
// ProvideChannel mocks base method
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
ret0, _ := ret[0].(<-chan string)
return ret0
}
// ProvideChannel indicates an expected call of ProvideChannel
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
}
// Remove mocks base method
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", arg0, arg1)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", arg0)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", arg0)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0)
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", arg0, arg1)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1)
}

View File

@ -5,11 +5,10 @@
package mocks package mocks
import ( import (
reflect "reflect"
store "github.com/ProtonMail/proton-bridge/internal/store" store "github.com/ProtonMail/proton-bridge/internal/store"
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials" credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
) )
// MockLocator is a mock of Locator interface // MockLocator is a mock of Locator interface

View File

@ -89,29 +89,25 @@ func newUser(
// - providing it with an authorised API client // - providing it with an authorised API client
// - loading its credentials from the credentials store // - loading its credentials from the credentials store
// - loading and unlocking its PGP keys // - loading and unlocking its PGP keys
// - loading its store // - loading its store.
func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error { func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) error {
u.log.Info("Connecting user") u.log.Info("Connecting user")
// Connected users have an API client. // Connected users have an API client.
u.client = client u.client = client
// FIXME(conman): How to remove this auth handler when user is disconnected? u.client.AddAuthRefreshHandler(u.handleAuthRefresh)
u.client.AddAuthHandler(u.handleAuth)
// Save the latest credentials for the user. // Save the latest credentials for the user.
u.creds = creds u.creds = creds
// Connected users have unlocked keys. // Connected users have unlocked keys.
// FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :( if err := u.unlockIfNecessary(); err != nil {
if u.creds.IsConnected() {
if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil {
return err return err
} }
}
// Connected users have a store. // Connected users have a store.
if err := u.loadStore(); err != nil { if err := u.loadStore(); err != nil { //nolint[revive] easier to read
return err return err
} }
@ -138,17 +134,25 @@ func (u *User) loadStore() error {
return nil return nil
} }
func (u *User) handleAuth(auth *pmapi.Auth) error { func (u *User) handleAuthRefresh(auth *pmapi.AuthRefresh) {
u.log.Debug("User received auth") u.log.Debug("User received auth refresh update")
if auth == nil {
if err := u.logout(); err != nil {
log.WithError(err).
WithField("userID", u.userID).
Error("User logout failed while watching API auths")
}
return
}
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken) creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to update refresh token in credentials store") u.log.WithError(err).Error("Failed to update refresh token in credentials store")
return
} }
u.creds = creds u.creds = creds
return nil
} }
// clearStore removes the database. // clearStore removes the database.
@ -181,13 +185,6 @@ func (u *User) closeStore() error {
return nil return nil
} }
// GetTemporaryPMAPIClient returns an authorised PMAPI client.
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
// After proper refactor of SMTP and IMAP remove this method.
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
return u.client
}
// ID returns the user's userID. // ID returns the user's userID.
func (u *User) ID() string { func (u *User) ID() string {
return u.userID return u.userID
@ -210,9 +207,43 @@ func (u *User) IsConnected() bool {
} }
func (u *User) GetClient() pmapi.Client { func (u *User) GetClient() pmapi.Client {
if err := u.unlockIfNecessary(); err != nil {
u.log.WithError(err).Error("Failed to unlock user")
}
return u.client return u.client
} }
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
func (u *User) unlockIfNecessary() error {
if !u.creds.IsConnected() {
return nil
}
if u.client.IsUnlocked() {
return nil
}
// unlockIfNecessary is called with every access to underlying pmapi
// client. Unlock should only finish unlocking when connection is back up.
// That means it should try it fast enough and not retry if connection
// is still down.
err := u.client.Unlock(pmapi.ContextWithoutRetry(context.Background()), []byte(u.creds.MailboxPassword))
if err == nil {
return nil
}
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
u.log.WithError(err).Warn("Could not unlock user")
return nil
}
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(err, "failed to unlock user")
}
// IsCombinedAddressMode returns whether user is set in combined or split mode. // IsCombinedAddressMode returns whether user is set in combined or split mode.
// Combined mode is the default mode and is what users typically need. // Combined mode is the default mode and is what users typically need.
// Split mode is mostly for outlook as it cannot handle sending e-mails from an // Split mode is mostly for outlook as it cannot handle sending e-mails from an
@ -307,14 +338,10 @@ func (u *User) GetBridgePassword() string {
// CheckBridgeLogin checks whether the user is logged in and the bridge // CheckBridgeLogin checks whether the user is logged in and the bridge
// IMAP/SMTP password is correct. // IMAP/SMTP password is correct.
func (u *User) CheckBridgeLogin(password string) error { func (u *User) CheckBridgeLogin(password string) error {
// FIXME(conman): Handle force upgrade?
/*
if isApplicationOutdated { if isApplicationOutdated {
u.listener.Emit(events.UpgradeApplicationEvent, "") u.listener.Emit(events.UpgradeApplicationEvent, "")
return pmapi.ErrUpgradeApplication return pmapi.ErrUpgradeApplication
} }
*/
u.lock.RLock() u.lock.RLock()
defer u.lock.RUnlock() defer u.lock.RUnlock()
@ -328,16 +355,16 @@ func (u *User) CheckBridgeLogin(password string) error {
} }
// UpdateUser updates user details from API and saves to the credentials. // UpdateUser updates user details from API and saves to the credentials.
func (u *User) UpdateUser() error { func (u *User) UpdateUser(ctx context.Context) error {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
_, err := u.client.UpdateUser(context.TODO()) _, err := u.client.UpdateUser(ctx)
if err != nil { if err != nil {
return err return err
} }
if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil { if err := u.client.ReloadKeys(ctx, []byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to reload keys") return errors.Wrap(err, "failed to reload keys")
} }
@ -414,8 +441,7 @@ func (u *User) Logout() error {
return nil return nil
} }
// FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers? if err := u.client.AuthDelete(context.Background()); err != nil {
if err := u.client.AuthDelete(context.TODO()); err != nil {
u.log.WithError(err).Warn("Failed to delete auth") u.log.WithError(err).Warn("Failed to delete auth")
} }

View File

@ -0,0 +1,195 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"context"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
r "github.com/stretchr/testify/require"
)
func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().UpdateUser(gomock.Any()).Return(nil, nil),
m.pmapiClient.EXPECT().ReloadKeys(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
)
r.NoError(t, user.UpdateUser(context.Background()))
}
func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
// Ignore any sync on background.
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
// Check initial state.
r.True(t, user.store.IsCombinedMode())
r.True(t, user.creds.IsCombinedAddressMode)
r.True(t, user.IsCombinedAddressMode())
// Mock change to split mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentialsSplit, nil),
)
// Check switch to split mode.
r.NoError(t, user.SwitchAddressMode())
r.False(t, user.store.IsCombinedMode())
r.False(t, user.creds.IsCombinedAddressMode)
r.False(t, user.IsCombinedAddressMode())
// MOck change to combined mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentials, nil),
)
// Check switch to combined mode.
r.NoError(t, user.SwitchAddressMode())
r.True(t, user.store.IsCombinedMode())
r.True(t, user.creds.IsCombinedAddressMode)
r.True(t, user.IsCombinedAddressMode())
}
func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
)
err := user.Logout()
r.NoError(t, err)
}
func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
)
err := user.Logout()
r.NoError(t, err)
}
func TestCheckBridgeLogin(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
r.NoError(t, err)
}
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
isApplicationOutdated = true
err := user.CheckBridgeLogin("any-pass")
r.Equal(t, pmapi.ErrUpgradeApplication, err)
isApplicationOutdated = false
}
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
// Mock init of user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
// Mock CheckBridgeLogin.
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
)
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(t, err)
err = user.connect(m.pmapiClient, testCredentialsDisconnected)
r.Error(t, err)
defer cleanUpUserData(user)
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
r.Equal(t, ErrLoggedOutUser, err)
}
func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin("wrong!")
r.EqualError(t, err, "backend/credentials: incorrect password")
}

View File

@ -0,0 +1,88 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
func TestNewUserNoCredentialsStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.Error(t, err)
}
func TestNewUserUnlockFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
// Init of user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("bad password")),
// Handle of unlock error.
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
)
checkNewUserHasCredentials(m, "failed to unlock user: bad password", testCredentialsDisconnected)
}
func TestNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(m, "", testCredentials)
}
func checkNewUserHasCredentials(m mocks, wantErr string, wantCreds *credentials.Credentials) {
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(m.t, err)
defer cleanUpUserData(user)
err = user.connect(m.pmapiClient, testCredentials)
if wantErr == "" {
r.NoError(m.t, err)
} else {
r.EqualError(m.t, err, wantErr)
}
r.Equal(m.t, wantCreds, user.creds)
waitForEvents()
}

View File

@ -0,0 +1,51 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
r "github.com/stretchr/testify/require"
)
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
r.Fail(t, "not implemented")
}
func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
r.Nil(t, user.store.Close())
user.store = nil
r.Nil(t, user.clearStore())
}
func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
r.NotNil(t, user.store)
r.Nil(t, user.clearStore())
}

View File

@ -0,0 +1,41 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
r "github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(m.t, err)
err = user.connect(m.pmapiClient, creds)
r.NoError(m.t, err)
return user
}
func cleanUpUserData(u *User) {
_ = u.clearStore()
}

View File

@ -89,24 +89,42 @@ func New(
lock: sync.RWMutex{}, lock: sync.RWMutex{},
} }
// FIXME(conman): Handle force upgrade events.
/*
go func() { go func() {
defer panicHandler.HandlePanic() defer panicHandler.HandlePanic()
u.watchAppOutdated() u.watchEvents()
}() }()
*/
if u.credStorer == nil { if u.credStorer == nil {
log.Error("No credentials store is available") log.Error("No credentials store is available")
} else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil { } else if err := u.loadUsersFromCredentialsStore(); err != nil {
log.WithError(err).Error("Could not load all users from credentials store") log.WithError(err).Error("Could not load all users from credentials store")
} }
return u return u
} }
func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error { func (u *Users) watchEvents() {
upgradeCh := u.events.ProvideChannel(events.UpgradeApplicationEvent)
internetOnCh := u.events.ProvideChannel(events.InternetOnEvent)
for {
select {
case <-upgradeCh:
isApplicationOutdated = true
u.closeAllConnections()
case <-internetOnCh:
for _, user := range u.users {
if user.store == nil {
if err := user.loadStore(); err != nil {
log.WithError(err).Error("Failed to load store after reconnecting")
}
}
}
}
}
}
func (u *Users) loadUsersFromCredentialsStore() error {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
@ -116,23 +134,26 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
} }
for _, userID := range userIDs { for _, userID := range userIDs {
l := log.WithField("user", userID)
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses) user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
if err != nil { if err != nil {
logrus.WithError(err).Warn("Could not create user, skipping") l.WithError(err).Warn("Could not create user, skipping")
continue continue
} }
u.users = append(u.users, user) u.users = append(u.users, user)
if creds.IsConnected() { if creds.IsConnected() {
if err := u.loadConnectedUser(ctx, user, creds); err != nil { // If there is no connection, we don't want to retry. Load should
logrus.WithError(err).Warn("Could not load connected user") // happen fast enough to not block GUI. When connection is back up,
// watchEvents and unlockIfNecessary will finish user init later.
if err := u.loadConnectedUser(pmapi.ContextWithoutRetry(context.Background()), user, creds); err != nil {
l.WithError(err).Warn("Could not load connected user")
} }
} else { } else {
logrus.Warn("User is disconnected and must be connected manually") l.Warn("User is disconnected and must be connected manually")
if err := user.connect(u.clientManager.NewClient("", "", "", time.Time{}), creds); err != nil {
if err := u.loadDisconnectedUser(ctx, user, creds); err != nil { l.WithError(err).Warn("Could not load disconnected user")
logrus.WithError(err).Warn("Could not load disconnected user")
} }
} }
} }
@ -140,11 +161,6 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
return err return err
} }
func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
// FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor!
return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds)
}
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error { func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
uid, ref, err := creds.SplitAPIToken() uid, ref, err := creds.SplitAPIToken()
if err != nil { if err != nil {
@ -153,38 +169,27 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref) client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
if err != nil { if err != nil {
// FIXME(conman): This is a problem... if we weren't able to create a new client due to internet, // When client cannot be refreshed right away due to no connection,
// we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff... // we create client which will refresh automatically when possible.
return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds) connectErr := user.connect(u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
return connectErr
}
if logoutErr := user.logout(); logoutErr != nil {
logrus.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(err, "could not refresh token")
} }
// Update the user's credentials with the latest auth used to connect this user. // Update the user's credentials with the latest auth used to connect this user.
if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil { if creds, err = u.credStorer.UpdateToken(creds.UserID, auth.UID, auth.RefreshToken); err != nil {
return errors.Wrap(err, "could not create get user's refresh token") return errors.Wrap(err, "could not create get user's refresh token")
} }
return user.connect(ctx, client, creds) return user.connect(client, creds)
}
func (u *Users) watchAppOutdated() {
// FIXME(conman): handle force upgrade events.
/*
ch := make(chan string)
u.events.Add(events.UpgradeApplicationEvent, ch)
for {
select {
case <-ch:
isApplicationOutdated = true
u.closeAllConnections()
case <-u.stopAll:
return
}
}
*/
} }
func (u *Users) closeAllConnections() { func (u *Users) closeAllConnections() {
@ -198,19 +203,19 @@ func (u *Users) closeAllConnections() {
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) { func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
u.crashBandicoot(username) u.crashBandicoot(username)
return u.clientManager.NewClientWithLogin(context.TODO(), username, password) return u.clientManager.NewClientWithLogin(context.Background(), username, password)
} }
// FinishLogin finishes the login procedure and adds the user into the credentials store. // FinishLogin finishes the login procedure and adds the user into the credentials store.
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen] func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
apiUser, passphrase, err := getAPIUser(context.TODO(), client, password) apiUser, passphrase, err := getAPIUser(context.Background(), client, password)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get API user") return nil, err
} }
if user, ok := u.hasUser(apiUser.ID); ok { if user, ok := u.hasUser(apiUser.ID); ok {
if user.IsConnected() { if user.IsConnected() {
if err := client.AuthDelete(context.TODO()); err != nil { if err := client.AuthDelete(context.Background()); err != nil {
logrus.WithError(err).Warn("Failed to delete new auth session") logrus.WithError(err).Warn("Failed to delete new auth session")
} }
@ -228,14 +233,16 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
return nil, errors.Wrap(err, "failed to update password of user in credentials store") return nil, errors.Wrap(err, "failed to update password of user in credentials store")
} }
if err := user.connect(context.TODO(), client, creds); err != nil { if err := user.connect(client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user") return nil, errors.Wrap(err, "failed to reconnect existing user")
} }
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
return user, nil return user, nil
} }
if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil { if err := u.addNewUser(client, apiUser, auth, passphrase); err != nil {
return nil, errors.Wrap(err, "failed to add new user") return nil, errors.Wrap(err, "failed to add new user")
} }
@ -245,7 +252,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
} }
// addNewUser adds a new user. // addNewUser adds a new user.
func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error { func (u *Users) addNewUser(client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
@ -266,7 +273,7 @@ func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pm
return errors.Wrap(err, "failed to create new user") return errors.Wrap(err, "failed to create new user")
} }
if err := user.connect(ctx, client, creds); err != nil { if err := user.connect(client, creds); err != nil {
return errors.Wrap(err, "failed to connect new user") return errors.Wrap(err, "failed to connect new user")
} }
@ -292,7 +299,7 @@ func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pma
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong. // We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
if err := client.Unlock(ctx, passphrase); err != nil { if err := client.Unlock(ctx, passphrase); err != nil {
return nil, nil, errors.Wrap(err, "failed to unlock client") return nil, nil, ErrWrongMailboxPassword
} }
user, err := client.CurrentUser(ctx) user, err := client.CurrentUser(ctx)
@ -414,22 +421,13 @@ func (u *Users) SendMetric(m metrics.Metric) error {
// AllowProxy instructs the app to use DoH to access an API proxy if necessary. // AllowProxy instructs the app to use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup). // It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) AllowProxy() { func (u *Users) AllowProxy() {
// FIXME(conman): Support DoH. u.clientManager.AllowProxy()
// u.apiManager.AllowProxy()
} }
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary. // DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup). // It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) DisallowProxy() { func (u *Users) DisallowProxy() {
// FIXME(conman): Support DoH. u.clientManager.DisallowProxy()
// u.apiManager.DisallowProxy()
}
// CheckConnection returns whether there is an internet connection.
// This should use the connection manager when it is eventually implemented.
func (u *Users) CheckConnection() error {
// FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer.
panic("TODO: register as a connection observer to get this information")
} }
// hasUser returns whether the struct currently has a user with ID `id`. // hasUser returns whether the struct currently has a user with ID `id`.

View File

@ -0,0 +1,49 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
func TestClearData(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
m.locator.EXPECT().Clear()
r.NoError(t, users.ClearData())
}

View File

@ -0,0 +1,69 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
func TestDeleteUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
r.NoError(t, err)
r.Equal(t, 1, len(users.users))
}
// Even when logout fails, delete is done.
func TestDeleteUserWithFailingLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
// Once called from user.Logout after failed creds.Logout as fallback, and once at the end of users.Logout.
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
r.NoError(t, err)
r.Equal(t, 1, len(users.users))
}

View File

@ -0,0 +1,76 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
r "github.com/stretchr/testify/require"
)
func TestGetNoUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
}
func TestGetUserByID(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "user", 0, "")
checkUsersGetUser(t, m, "users", 1, "")
}
func TestGetUserByName(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "username", 0, "")
checkUsersGetUser(t, m, "usersname", 1, "")
}
func TestGetUserByEmail(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "user@pm.me", 0, "")
checkUsersGetUser(t, m, "users@pm.me", 1, "")
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
}
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.GetUser(query)
if expectedError != "" {
r.EqualError(m.t, err, expectedError)
} else {
r.NoError(m.t, err)
}
var expectedUser *User
if index >= 0 {
expectedUser = users.users[index]
}
r.Equal(m.t, expectedUser, user)
}

View File

@ -0,0 +1,132 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
r "github.com/stretchr/testify/require"
)
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
// Set up mocks for FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil)
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked"))
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
}
func TestUsersFinishLoginNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
mockAddingConnectedUser(m)
mockEventLoopNoAction(m)
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentials.UserID)
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, testCredentials.UserID, nil)
}
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Mock loading disconnected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
// Mock process of FinishLogin of already added user.
gomock.InOrder(
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUserDisconnected, nil),
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
authRefresh := &pmapi.Auth{
UserID: testCredentialsDisconnected.UserID,
AuthRefresh: pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
RefreshToken: "ref",
},
}
checkUsersFinishLogin(t, m, authRefresh, testCredentials.MailboxPassword, testCredentialsDisconnected.UserID, nil)
}
func TestUsersFinishLoginConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Mock loading connected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockEventLoopNoAction(m)
// Mock process of FinishLogin of already connected user.
gomock.InOrder(
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
)
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
_, err := users.FinishLogin(m.pmapiClient, testAuthRefresh, testCredentials.MailboxPassword)
r.EqualError(t, err, "user is already connected")
}
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) {
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
r.Equal(t, expectedErr, err)
if expectedUserID != "" {
r.Equal(t, expectedUserID, user.ID())
r.Equal(t, 1, len(users.users))
r.Equal(t, expectedUserID, users.users[0].ID())
} else {
r.Equal(t, (*User)(nil), user)
r.Equal(t, 0, len(users.users))
}
}

View File

@ -22,9 +22,10 @@ import (
"testing" "testing"
time "time" time "time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials" "github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" r "github.com/stretchr/testify/require"
) )
func TestNewUsersNoKeychain(t *testing.T) { func TestNewUsersNoKeychain(t *testing.T) {
@ -32,7 +33,6 @@ func TestNewUsersNoKeychain(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain")) m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain"))
checkUsersNew(t, m, []*credentials.Credentials{}) checkUsersNew(t, m, []*credentials.Credentials{})
} }
@ -41,108 +41,73 @@ func TestNewUsersWithoutUsersInCredentialsStore(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, nil) m.credentialsStore.EXPECT().List().Return([]string{}, nil)
checkUsersNew(t, m, []*credentials.Credentials{}) checkUsersNew(t, m, []*credentials.Credentials{})
} }
func TestNewUsersWithDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
/*
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout()
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
func TestNewUsersWithConnectedUser(t *testing.T) { func TestNewUsersWithConnectedUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil) mockLoadingConnectedUser(m, testCredentials)
mockConnectedUser(m)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentials}) checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
} }
func TestNewUsersWithDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
// Tests two users with different states and checks also the order from // Tests two users with different states and checks also the order from
// credentials store is kept also in array of users. // credentials store is kept also in array of users.
func TestNewUsersWithUsers(t *testing.T) { func TestNewUsersWithUsers(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil) mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
mockLoadingConnectedUser(m, testCredentials)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
// Set up mocks for store initialisation for the unauth user.
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
mockConnectedUser(m)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials}) checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
} }
func TestNewUsersFirstStart(t *testing.T) { func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, nil) m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token"))
m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient)
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(false)
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("not authorized"))
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
testNewUsers(t, m) m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
} }
*/
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) { func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
users := testNewUsers(t, m) users := testNewUsers(t, m)
defer cleanUpUsersData(users) defer cleanUpUsersData(users)
assert.Equal(m.t, len(expectedCredentials), len(users.GetUsers())) r.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
credentials := []*credentials.Credentials{} credentials := []*credentials.Credentials{}
for _, user := range users.users { for _, user := range users.users {
credentials = append(credentials, user.creds) credentials = append(credentials, user.creds)
} }
assert.Equal(m.t, expectedCredentials, credentials) r.Equal(m.t, expectedCredentials, credentials)
} }

View File

@ -33,8 +33,9 @@ import (
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" r "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -49,9 +50,12 @@ func TestMain(m *testing.M) {
var ( var (
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals] testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
UserID: "user",
AuthRefresh: pmapi.AuthRefresh{
UID: "uid", UID: "uid",
AccessToken: "acc", AccessToken: "acc",
RefreshToken: "ref", RefreshToken: "ref",
},
} }
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals] testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
@ -81,7 +85,7 @@ var (
} }
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals] testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "user", UserID: "userDisconnected",
Name: "username", Name: "username",
Emails: "user@pm.me", Emails: "user@pm.me",
APIToken: "", APIToken: "",
@ -94,7 +98,7 @@ var (
} }
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals] testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "users", UserID: "usersDisconnected",
Name: "usersname", Name: "usersname",
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me", Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
APIToken: "", APIToken: "",
@ -111,17 +115,22 @@ var (
Name: "username", Name: "username",
} }
testPMAPIUserDisconnected = &pmapi.User{ //nolint[gochecknoglobals]
ID: "userDisconnected",
Name: "username",
}
testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals] testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
ID: "testAddressID", ID: "testAddressID",
Type: pmapi.OriginalAddress, Type: pmapi.OriginalAddress,
Email: "user@pm.me", Email: "user@pm.me",
Receive: pmapi.CanReceive, Receive: true,
} }
testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals] testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: pmapi.CanReceive, Type: pmapi.OriginalAddress}, {ID: "usersAddress1ID", Email: "users@pm.me", Receive: true, Type: pmapi.OriginalAddress},
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress}, {ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: true, Type: pmapi.AliasAddress},
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress}, {ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: true, Type: pmapi.AliasAddress},
} }
testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals] testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals]
@ -129,15 +138,6 @@ var (
} }
) )
func waitForEvents() {
// Wait for goroutine to add listener.
// E.g. calling login to invoke firstsync event. Functions can end sooner than
// goroutines call the listener mock. We need to wait a little bit before the end of
// the test to capture all event calls. This allows us to detect whether there were
// missing calls, or perhaps whether something was called too many times.
time.Sleep(100 * time.Millisecond)
}
type mocks struct { type mocks struct {
t *testing.T t *testing.T
@ -146,7 +146,7 @@ type mocks struct {
PanicHandler *usersmocks.MockPanicHandler PanicHandler *usersmocks.MockPanicHandler
credentialsStore *usersmocks.MockCredentialsStorer credentialsStore *usersmocks.MockCredentialsStorer
storeMaker *usersmocks.MockStoreMaker storeMaker *usersmocks.MockStoreMaker
eventListener *MockListener eventListener *usersmocks.MockListener
clientManager *pmapimocks.MockManager clientManager *pmapimocks.MockManager
pmapiClient *pmapimocks.MockClient pmapiClient *pmapimocks.MockClient
@ -154,6 +154,48 @@ type mocks struct {
storeCache *store.Cache storeCache *store.Cache
} }
func initMocks(t *testing.T) mocks {
var mockCtrl *gomock.Controller
if os.Getenv("VERBOSITY") == "trace" {
mockCtrl = gomock.NewController(&fullStackReporter{t})
} else {
mockCtrl = gomock.NewController(t)
}
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
r.NoError(t, err, "could not get temporary file for store cache")
m := mocks{
t: t,
ctrl: mockCtrl,
locator: usersmocks.NewMockLocator(mockCtrl),
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
eventListener: usersmocks.NewMockListener(mockCtrl),
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
}
// Called during clean-up.
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
// Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
r.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
}).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
return m
}
type fullStackReporter struct { type fullStackReporter struct {
T testing.TB T testing.TB
} }
@ -168,86 +210,18 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
fr.T.FailNow() fr.T.FailNow()
} }
func initMocks(t *testing.T) mocks {
var mockCtrl *gomock.Controller
if os.Getenv("VERBOSITY") == "trace" {
mockCtrl = gomock.NewController(&fullStackReporter{t})
} else {
mockCtrl = gomock.NewController(t)
}
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
require.NoError(t, err, "could not get temporary file for store cache")
m := mocks{
t: t,
ctrl: mockCtrl,
locator: usersmocks.NewMockLocator(mockCtrl),
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
eventListener: NewMockListener(mockCtrl),
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
}
// Called during clean-up.
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
// Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
require.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
}).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
return m
}
func testNewUsersWithUsers(t *testing.T, m mocks) *Users { func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
// Events are asynchronous m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2) mockLoadingConnectedUser(m, testCredentials)
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2) mockLoadingConnectedUser(m, testCredentialsSplit)
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2) mockEventLoopNoAction(m)
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
// Init for user.
m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Init for users.
m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil),
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
)
return testNewUsers(t, m) return testNewUsers(t, m)
} }
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam] func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
// FIXME(conman): How to handle force upgrade? m.eventListener.EXPECT().ProvideChannel(events.UpgradeApplicationEvent)
// m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) m.eventListener.EXPECT().ProvideChannel(events.InternetOnEvent)
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true) users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
@ -256,38 +230,84 @@ func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
return users return users
} }
func waitForEvents() {
// Wait for goroutine to add listener.
// E.g. calling login to invoke firstsync event. Functions can end sooner than
// goroutines call the listener mock. We need to wait a little bit before the end of
// the test to capture all event calls. This allows us to detect whether there were
// missing calls, or perhaps whether something was called too many times.
time.Sleep(100 * time.Millisecond)
}
func cleanUpUsersData(b *Users) { func cleanUpUsersData(b *Users) {
for _, user := range b.users { for _, user := range b.users {
_ = user.clearStore() _ = user.clearStore()
} }
} }
func TestClearData(t *testing.T) { func mockAddingConnectedUser(m mocks) {
m := initMocks(t) gomock.InOrder(
defer m.ctrl.Finish() // Mock of users.FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().Add("user", "username", testAuthRefresh.UID, testAuthRefresh.RefreshToken, testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
// m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1) mockInitConnectedUser(m)
// m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1) }
users := testNewUsersWithUsers(t, m) func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
defer cleanUpUsersData(users) authRefresh := &pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
RefreshToken: "ref",
}
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me") gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me") // Mock of users.loadUsersFromCredentialsStore.
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me") m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me") m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, authRefresh, nil),
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
)
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) mockInitConnectedUser(m)
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil) }
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()) func mockInitConnectedUser(m mocks) {
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil) // Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
m.locator.EXPECT().Clear() // Mock of store initialisation.
gomock.InOrder(
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
)
}
require.NoError(t, users.ClearData()) func mockLoadingDisconnectedUser(m mocks, creds *credentials.Credentials) {
gomock.InOrder(
// Mock of users.loadUsersFromCredentialsStore.
m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
)
waitForEvents() mockInitDisconnectedUser(m)
}
func mockInitDisconnectedUser(m mocks) {
gomock.InOrder(
// Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
// Mock of store initialisation for the unauthorized user.
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
} }
func mockEventLoopNoAction(m mocks) { func mockEventLoopNoAction(m mocks) {
@ -297,19 +317,3 @@ func mockEventLoopNoAction(m mocks) {
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes() m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes() m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
} }
func mockConnectedUser(m mocks) {
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
// m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
// Set up mocks for store initialisation for the authorized user.
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
)
}

View File

@ -105,6 +105,10 @@ func (h *macOSHelper) Get(secretURL string) (string, string, error) {
return "", "", err return "", "", err
} }
if len(results) == 0 {
return "", "", errors.New("no result")
}
if len(results) != 1 { if len(results) != 1 {
return "", "", errors.New("ambiguous results") return "", "", errors.New("ambiguous results")
} }

View File

@ -29,6 +29,7 @@ var log = logrus.WithField("pkg", "bridgeUtils/listener") //nolint[gochecknoglob
// Listener has a list of channels watching for updates. // Listener has a list of channels watching for updates.
type Listener interface { type Listener interface {
SetLimit(eventName string, limit time.Duration) SetLimit(eventName string, limit time.Duration)
ProvideChannel(eventName string) <-chan string
Add(eventName string, channel chan<- string) Add(eventName string, channel chan<- string)
Remove(eventName string, channel chan<- string) Remove(eventName string, channel chan<- string)
Emit(eventName string, data string) Emit(eventName string, data string)
@ -69,6 +70,15 @@ func (l *listener) SetLimit(eventName string, limit time.Duration) {
l.limits[eventName] = limit l.limits[eventName] = limit
} }
// ProvideChannel creates new channel, adds it to listener and sends to it
// bufferent events.
func (l *listener) ProvideChannel(eventName string) <-chan string {
ch := make(chan string)
l.Add(eventName, ch)
l.RetryEmit(eventName)
return ch
}
// Add adds an event listener. // Add adds an event listener.
func (l *listener) Add(eventName string, channel chan<- string) { func (l *listener) Add(eventName string, channel chan<- string) {
l.lock.Lock() l.lock.Lock()

View File

@ -97,7 +97,7 @@ func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachW
} }
func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) { func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) {
msg, err := req.api.GetMessage(req.messageID) msg, err := req.api.GetMessage(req.ctx, req.messageID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -109,7 +109,7 @@ func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, er
} }
process := func(value interface{}) (interface{}, error) { process := func(value interface{}) (interface{}, error) {
rc, err := req.api.GetAttachment(value.(string)) rc, err := req.api.GetAttachment(req.ctx, value.(string))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -43,10 +43,10 @@ func newTestFetcher(
) Fetcher { ) Fetcher {
f := mocks.NewMockFetcher(m) f := mocks.NewMockFetcher(m)
f.EXPECT().GetMessage(msg.ID).Return(msg, nil) f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
for i, att := range msg.Attachments { for i, att := range msg.Attachments {
f.EXPECT().GetAttachment(att.ID).Return(newTestReadCloser(attData[i]), nil) f.EXPECT().GetAttachment(gomock.Any(), att.ID).Return(newTestReadCloser(attData[i]), nil)
} }
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(kr, nil) f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(kr, nil)

View File

@ -1230,7 +1230,7 @@ func TestBuildFetchMessageFail(t *testing.T) {
// Pretend the message cannot be fetched. // Pretend the message cannot be fetched.
f := mocks.NewMockFetcher(m) f := mocks.NewMockFetcher(m)
f.EXPECT().GetMessage(msg.ID).Return(nil, errors.New("oops")) f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(nil, errors.New("oops"))
// The job should fail, returning an error and a nil result. // The job should fail, returning an error and a nil result.
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
@ -1251,8 +1251,8 @@ func TestBuildFetchAttachmentFail(t *testing.T) {
// Pretend the attachment cannot be fetched. // Pretend the attachment cannot be fetched.
f := mocks.NewMockFetcher(m) f := mocks.NewMockFetcher(m)
f.EXPECT().GetMessage(msg.ID).Return(msg, nil) f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
f.EXPECT().GetAttachment(msg.Attachments[0].ID).Return(nil, errors.New("oops")) f.EXPECT().GetAttachment(gomock.Any(), msg.Attachments[0].ID).Return(nil, errors.New("oops"))
// The job should fail, returning an error and a nil result. // The job should fail, returning an error and a nil result.
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
@ -1272,7 +1272,7 @@ func TestBuildNoSuchKeyRing(t *testing.T) {
// Pretend there is no available keyring. // Pretend there is no available keyring.
f := mocks.NewMockFetcher(m) f := mocks.NewMockFetcher(m)
f.EXPECT().GetMessage(msg.ID).Return(msg, nil) f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops")) f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops"))
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()

View File

@ -31,7 +31,7 @@ const (
// GetFlags returns imap flags from pmapi message attributes. // GetFlags returns imap flags from pmapi message attributes.
func GetFlags(m *pmapi.Message) (flags []string) { func GetFlags(m *pmapi.Message) (flags []string) {
if m.Unread == 0 { if !m.Unread {
flags = append(flags, imap.SeenFlag) flags = append(flags, imap.SeenFlag)
} }
if !m.Has(pmapi.FlagSent) && !m.Has(pmapi.FlagReceived) { if !m.Has(pmapi.FlagSent) && !m.Has(pmapi.FlagReceived) {
@ -68,11 +68,11 @@ func ParseFlags(m *pmapi.Message, flags []string) {
m.Flags = pmapi.FlagReceived m.Flags = pmapi.FlagReceived
} }
m.Unread = 1 m.Unread = true
for _, f := range flags { for _, f := range flags {
switch f { switch f {
case imap.SeenFlag: case imap.SeenFlag:
m.Unread = 0 m.Unread = false
case imap.DraftFlag: case imap.DraftFlag:
m.Flags &= ^pmapi.FlagSent m.Flags &= ^pmapi.FlagSent
m.Flags &= ^pmapi.FlagReceived m.Flags &= ^pmapi.FlagReceived

View File

@ -5,6 +5,7 @@
package mocks package mocks
import ( import (
context "context"
io "io" io "io"
reflect "reflect" reflect "reflect"
@ -37,33 +38,33 @@ func (m *MockFetcher) EXPECT() *MockFetcherMockRecorder {
} }
// GetAttachment mocks base method // GetAttachment mocks base method
func (m *MockFetcher) GetAttachment(arg0 string) (io.ReadCloser, error) { func (m *MockFetcher) GetAttachment(arg0 context.Context, arg1 string) (io.ReadCloser, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAttachment", arg0) ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1)
ret0, _ := ret[0].(io.ReadCloser) ret0, _ := ret[0].(io.ReadCloser)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetAttachment indicates an expected call of GetAttachment // GetAttachment indicates an expected call of GetAttachment
func (mr *MockFetcherMockRecorder) GetAttachment(arg0 interface{}) *gomock.Call { func (mr *MockFetcherMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0, arg1)
} }
// GetMessage mocks base method // GetMessage mocks base method
func (m *MockFetcher) GetMessage(arg0 string) (*pmapi.Message, error) { func (m *MockFetcher) GetMessage(arg0 context.Context, arg1 string) (*pmapi.Message, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessage", arg0) ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Message) ret0, _ := ret[0].(*pmapi.Message)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetMessage indicates an expected call of GetMessage // GetMessage indicates an expected call of GetMessage
func (mr *MockFetcherMockRecorder) GetMessage(arg0 interface{}) *gomock.Call { func (mr *MockFetcherMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0, arg1)
} }
// KeyRingForAddressID mocks base method // KeyRingForAddressID mocks base method

View File

@ -191,7 +191,7 @@ func DecodeHeader(raw string) (decoded string, err error) {
return return
} }
// EncodeHeader using quoted printable and utf8 // EncodeHeader using quoted printable and utf8.
func EncodeHeader(s string) string { func EncodeHeader(s string) string {
return mime.QEncoding.Encode("utf-8", s) return mime.QEncoding.Encode("utf-8", s)
} }

View File

@ -19,7 +19,6 @@ package pmmime
import ( import (
"bytes" "bytes"
//"fmt"
"strings" "strings"
"testing" "testing"

View File

@ -32,12 +32,6 @@ const (
EnabledAddress EnabledAddress
) )
// Address receive values.
const (
CannotReceive = iota
CanReceive
)
// Address HasKeys values. // Address HasKeys values.
const ( const (
MissingKeys = iota MissingKeys = iota
@ -66,7 +60,7 @@ type Address struct {
DomainID string DomainID string
Email string Email string
Send int Send int
Receive int Receive Boolean
Status int Status int
Order int `json:",omitempty"` Order int `json:",omitempty"`
Type int Type int
@ -103,7 +97,7 @@ func (l AddressList) AllEmails() (addresses []string) {
// ActiveEmails returns only active emails. // ActiveEmails returns only active emails.
func (l AddressList) ActiveEmails() (addresses []string) { func (l AddressList) ActiveEmails() (addresses []string) {
for _, a := range l { for _, a := range l {
if a.Receive == CanReceive { if a.Receive {
addresses = append(addresses, a.Email) addresses = append(addresses, a.Email)
} }
} }
@ -175,8 +169,19 @@ func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err e
return res.Addresses, nil return res.Addresses, nil
} }
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) (err error) { func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) error {
panic("TODO") if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(&struct {
AddressIDs []string
}{
AddressIDs: addressIDs,
}).Put("/addresses/order")
}); err != nil {
return err
}
_, err := c.UpdateUser(ctx)
return err
} }
// Addresses returns the addresses stored in the client object itself rather than fetching from the API. // Addresses returns the addresses stored in the client object itself rather than fetching from the API.
@ -185,24 +190,22 @@ func (c *client) Addresses() AddressList {
} }
// unlockAddresses unlocks all keys for all addresses of current user. // unlockAddresses unlocks all keys for all addresses of current user.
func (c *client) unlockAddress(passphrase []byte, address *Address) (err error) { func (c *client) unlockAddress(passphrase []byte, address *Address) error {
if address == nil { if address == nil {
return errors.New("address data is missing") return errors.New("address data is missing")
} }
if address.HasKeys == MissingKeys { if address.HasKeys == MissingKeys {
return return nil
} }
var kr *crypto.KeyRing kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
if err != nil {
if kr, err = address.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil { return err
return
} }
c.addrKeyRing[address.ID] = kr c.addrKeyRing[address.ID] = kr
return nil
return
} }
func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) { func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {

View File

@ -20,6 +20,8 @@ package pmapi
import ( import (
"net/http" "net/http"
"testing" "testing"
r "github.com/stretchr/testify/require"
) )
var testAddressList = AddressList{ var testAddressList = AddressList{
@ -46,39 +48,29 @@ var testAddressList = AddressList{
}, },
} }
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, r *http.Request) string { func routeGetAddresses(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
Ok(tb, checkMethodAndPath(r, "GET", "/addresses")) r.NoError(tb, checkMethodAndPath(req, "GET", "/addresses"))
Ok(tb, isAuthReq(r, testUID, testAccessToken)) r.NoError(tb, isAuthReq(req, testUID, testAccessToken))
return "addresses/get_response.json" return "addresses/get_response.json"
} }
func TestAddressList(t *testing.T) { func TestAddressList(t *testing.T) {
input := "1" input := "1"
addr := testAddressList.ByID(input) addr := testAddressList.ByID(input)
if addr != testAddressList[0] { r.Equal(t, testAddressList[0], addr)
t.Errorf("ById(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[0], addr)
}
input = "42" input = "42"
addr = testAddressList.ByID(input) addr = testAddressList.ByID(input)
if addr != nil { r.Nil(t, addr)
t.Errorf("ById expected nil for %s but have : %v\n", input, addr)
}
input = "root@protonmail.com" input = "root@protonmail.com"
addr = testAddressList.ByEmail(input) addr = testAddressList.ByEmail(input)
if addr != testAddressList[2] { r.Equal(t, testAddressList[2], addr)
t.Errorf("ByEmail(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[2], addr)
}
input = "idontexist@protonmail.com" input = "idontexist@protonmail.com"
addr = testAddressList.ByEmail(input) addr = testAddressList.ByEmail(input)
if addr != nil { r.Nil(t, addr)
t.Errorf("ByEmail expected nil for %s but have : %v\n", input, addr)
}
addr = testAddressList.Main() addr = testAddressList.Main()
if addr != testAddressList[1] { r.Equal(t, testAddressList[1], addr)
t.Errorf("Main() expected:\n%v\n but have:\n%v\n", testAddressList[1], addr)
}
} }

View File

@ -23,7 +23,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"mime/multipart"
"net/textproto" "net/textproto"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
@ -138,44 +137,6 @@ func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io.
return signAttachment(kr, att) return signAttachment(kr, att)
} }
func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) {
// Create metadata fields.
if err = w.WriteField("Filename", att.Name); err != nil {
return
}
if err = w.WriteField("MessageID", att.MessageID); err != nil {
return
}
if err = w.WriteField("MIMEType", att.MIMEType); err != nil {
return
}
if err = w.WriteField("ContentID", att.ContentID); err != nil {
return
}
// And send attachment data.
ff, err := w.CreateFormFile("DataPacket", "DataPacket.pgp")
if err != nil {
return
}
if _, err = io.Copy(ff, r); err != nil {
return
}
// And send attachment data.
sigff, err := w.CreateFormFile("Signature", "Signature.pgp")
if err != nil {
return
}
if _, err = io.Copy(sigff, sig); err != nil {
return
}
return err
}
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID. // CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
// //
// The returned created attachment contains the new attachment ID and its size. // The returned created attachment contains the new attachment ID and its size.

View File

@ -28,13 +28,13 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/textproto" "net/textproto"
"reflect"
"strings" "strings"
"testing" "testing"
pmmime "github.com/ProtonMail/proton-bridge/pkg/mime" pmmime "github.com/ProtonMail/proton-bridge/pkg/mime"
"github.com/stretchr/testify/assert" a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
) )
var testAttachment = &Attachment{ var testAttachment = &Attachment{
@ -77,65 +77,40 @@ const testCreateAttachmentBody = `{
"Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="} "Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="}
}` }`
const testDeleteAttachmentBody = `{
"Code": 1000
}`
func TestAttachment_UnmarshalJSON(t *testing.T) { func TestAttachment_UnmarshalJSON(t *testing.T) {
att := new(Attachment) att := new(Attachment)
if err := json.Unmarshal([]byte(testAttachmentJSON), att); err != nil { err := json.Unmarshal([]byte(testAttachmentJSON), att)
t.Fatal("Expected no error while unmarshaling JSON, got:", err) r.NoError(t, err)
}
att.MessageID = testAttachment.MessageID // This isn't in the JSON object att.MessageID = testAttachment.MessageID // This isn't in the JSON object
if !reflect.DeepEqual(testAttachment, att) { r.Equal(t, testAttachment, att)
t.Errorf("Invalid attachment: expected %+v but got %+v", testAttachment, att)
}
} }
func TestClient_CreateAttachment(t *testing.T) { func TestClient_CreateAttachment(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments")) r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/attachments"))
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type")) contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type"))
if err != nil { r.NoError(t, err)
t.Error("Expected no error while parsing request content type, got:", err) r.Equal(t, "multipart/form-data", contentType)
}
if contentType != "multipart/form-data" {
t.Errorf("Invalid request content type: expected %v but got %v", "multipart/form-data", contentType)
}
mr := multipart.NewReader(r.Body, params["boundary"]) mr := multipart.NewReader(req.Body, params["boundary"])
form, err := mr.ReadForm(10 * 1024) form, err := mr.ReadForm(10 * 1024)
if err != nil { r.NoError(t, err)
t.Error("Expected no error while parsing request form, got:", err) defer r.NoError(t, form.RemoveAll())
}
defer Ok(t, form.RemoveAll())
if form.Value["Filename"][0] != testAttachment.Name { r.Equal(t, testAttachment.Name, form.Value["Filename"][0])
t.Errorf("Invalid attachment filename: expected %v but got %v", testAttachment.Name, form.Value["Filename"][0]) r.Equal(t, testAttachment.MessageID, form.Value["MessageID"][0])
} r.Equal(t, testAttachment.MIMEType, form.Value["MIMEType"][0])
if form.Value["MessageID"][0] != testAttachment.MessageID {
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MessageID, form.Value["MessageID"][0])
}
if form.Value["MIMEType"][0] != testAttachment.MIMEType {
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MIMEType, form.Value["MIMEType"][0])
}
dataFile, err := form.File["DataPacket"][0].Open() dataFile, err := form.File["DataPacket"][0].Open()
if err != nil { r.NoError(t, err)
t.Error("Expected no error while opening packets file, got:", err) defer r.NoError(t, dataFile.Close())
}
defer Ok(t, dataFile.Close())
b, err := ioutil.ReadAll(dataFile) b, err := ioutil.ReadAll(dataFile)
if err != nil { r.NoError(t, err)
t.Error("Expected no error while reading packets file, got:", err) r.Equal(t, testAttachmentCleartext, string(b))
}
if string(b) != testAttachmentCleartext {
t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b))
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -143,50 +118,39 @@ func TestClient_CreateAttachment(t *testing.T) {
})) }))
defer s.Close() defer s.Close()
r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted reader := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
created, err := c.CreateAttachment(context.TODO(), testAttachment, r, strings.NewReader("")) created, err := c.CreateAttachment(context.Background(), testAttachment, reader, strings.NewReader(""))
if err != nil { r.NoError(t, err)
t.Fatal("Expected no error while creating attachment, got:", err)
}
if created.ID != testAttachment.ID { r.Equal(t, testAttachment.ID, created.ID)
t.Errorf("Invalid attachment id: expected %v but got %v", testAttachment.ID, created.ID)
}
} }
func TestClient_GetAttachment(t *testing.T) { func TestClient_GetAttachment(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID)) r.NoError(t, checkMethodAndPath(req, "GET", "/mail/v4/attachments/"+testAttachment.ID))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testAttachmentCleartext) fmt.Fprint(w, testAttachmentCleartext)
})) }))
defer s.Close() defer s.Close()
r, err := c.GetAttachment(context.TODO(), testAttachment.ID) att, err := c.GetAttachment(context.Background(), testAttachment.ID)
if err != nil { r.NoError(t, err)
t.Fatal("Expected no error while getting attachment, got:", err) defer att.Close() //nolint[errcheck]
}
defer r.Close() //nolint[errcheck]
// In reality, r contains encrypted data // In reality, r contains encrypted data
b, err := ioutil.ReadAll(r) b, err := ioutil.ReadAll(att)
if err != nil { r.NoError(t, err)
t.Fatal("Expected no error while reading attachment, got:", err)
}
if string(b) != testAttachmentCleartext { r.Equal(t, testAttachmentCleartext, string(b))
t.Errorf("Invalid attachment data: expected %q but got %q", testAttachmentCleartext, string(b))
}
} }
func TestAttachment_Encrypt(t *testing.T) { func TestAttachment_Encrypt(t *testing.T) {
data := bytes.NewBufferString(testAttachmentCleartext) data := bytes.NewBufferString(testAttachmentCleartext)
r, err := testAttachment.Encrypt(testPublicKeyRing, data) r, err := testAttachment.Encrypt(testPublicKeyRing, data)
assert.Nil(t, err) a.Nil(t, err)
b, err := ioutil.ReadAll(r) b, err := ioutil.ReadAll(r)
assert.Nil(t, err) a.Nil(t, err)
// Result is always different, so the best way is to test it by decrypting again. // Result is always different, so the best way is to test it by decrypting again.
// Another test for decrypting will help us to be sure it's working. // Another test for decrypting will help us to be sure it's working.
@ -202,8 +166,8 @@ func TestAttachment_Decrypt(t *testing.T) {
func decryptAndCheck(t *testing.T, data io.Reader) { func decryptAndCheck(t *testing.T, data io.Reader) {
r, err := testAttachment.Decrypt(data, testPrivateKeyRing) r, err := testAttachment.Decrypt(data, testPrivateKeyRing)
assert.Nil(t, err) a.Nil(t, err)
b, err := ioutil.ReadAll(r) b, err := ioutil.ReadAll(r)
assert.Nil(t, err) a.Nil(t, err)
assert.Equal(t, testAttachmentCleartext, string(b)) a.Equal(t, testAttachmentCleartext, string(b))
} }

View File

@ -1,3 +1,20 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi package pmapi
import ( import (
@ -6,15 +23,117 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"io" "io"
"net/http"
"time" "time"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
) )
func (c *client) Auth2FA(ctx context.Context, req Auth2FAReq) error { type AuthModulus struct {
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { Modulus string
return r.SetBody(req).Post("/auth/2fa") ModulusID string
}
type GetAuthInfoReq struct {
Username string
}
type AuthInfo struct {
Version int
Modulus string
ServerEphemeral string
Salt string
SRPSession string
}
type TwoFAInfo struct {
Enabled TwoFAStatus
}
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
return twoFAInfo.Enabled > 0
}
type TwoFAStatus int
const (
TwoFADisabled TwoFAStatus = iota
TOTPEnabled
U2FEnabled
TOTPAndU2FEnabled
)
type PasswordMode int
const (
OnePasswordMode PasswordMode = iota + 1
TwoPasswordMode
)
type AuthReq struct {
Username string
ClientProof string
ClientEphemeral string
SRPSession string
}
type AuthRefresh struct {
UID string
AccessToken string
RefreshToken string
ExpiresIn int64
Scopes []string
}
type Auth struct {
AuthRefresh
UserID string
ServerProof string
PasswordMode PasswordMode
TwoFA *TwoFAInfo `json:"2FA,omitempty"`
}
func (a Auth) HasTwoFactor() bool {
if a.TwoFA == nil {
return false
}
return a.TwoFA.hasTwoFactor()
}
func (a Auth) HasMailboxPassword() bool {
return a.PasswordMode == TwoPasswordMode
}
type auth2FAReq struct {
TwoFactorCode string
}
type authRefreshReq struct {
UID string
RefreshToken string
ResponseType string
GrantType string
RedirectURI string
State string
}
func (c *client) Auth2FA(ctx context.Context, twoFactorCode string) error {
// 2FA is called during login procedure during which refresh token should
// be valid, therefore, no refresh is needed if there is an error.
ctx = ContextWithoutAuthRefresh(ctx)
if res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(auth2FAReq{TwoFactorCode: twoFactorCode}).Post("/auth/2fa")
}); err != nil { }); err != nil {
if res != nil {
switch res.StatusCode() {
case http.StatusUnauthorized:
return ErrBad2FACode
case http.StatusUnprocessableEntity:
return ErrBad2FACodeTryAgain
}
}
return err return err
} }
@ -29,9 +148,7 @@ func (c *client) AuthDelete(ctx context.Context) error {
} }
c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{} c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
c.sendAuthRefresh(nil)
// FIXME(conman): should we perhaps signal via AuthHandler that the auth was deleted?
return nil return nil
} }
@ -54,7 +171,7 @@ func (c *client) AuthSalt(ctx context.Context) (string, error) {
return "", errors.New("no matching salt found") return "", errors.New("no matching salt found")
} }
func (c *client) AddAuthHandler(handler AuthHandler) { func (c *client) AddAuthRefreshHandler(handler AuthRefreshHandler) {
c.authHandlers = append(c.authHandlers, handler) c.authHandlers = append(c.authHandlers, handler)
} }
@ -62,23 +179,35 @@ func (c *client) authRefresh(ctx context.Context) error {
c.authLocker.Lock() c.authLocker.Lock()
defer c.authLocker.Unlock() defer c.authLocker.Unlock()
auth, err := c.req.authRefresh(ctx, c.uid, c.ref) if c.ref == "" {
return ErrUnauthorized
}
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
if err != nil { if err != nil {
if err != ErrNoConnection {
c.sendAuthRefresh(nil)
}
return err return err
} }
c.acc = auth.AccessToken c.acc = auth.AccessToken
c.ref = auth.RefreshToken c.ref = auth.RefreshToken
c.exp = expiresIn(auth.ExpiresIn)
for _, handler := range c.authHandlers { c.sendAuthRefresh(auth)
if err := handler(auth); err != nil {
return err
}
}
return nil return nil
} }
func (c *client) sendAuthRefresh(auth *AuthRefresh) {
for _, handler := range c.authHandlers {
go handler(auth)
}
if auth == nil {
c.authHandlers = []AuthRefreshHandler{}
}
}
func randomString(length int) string { func randomString(length int) string {
noise := make([]byte, length) noise := make([]byte, length)

View File

@ -1,22 +1,40 @@
package pmapi_test // Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time" "time"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
) )
func TestAutomaticAuthRefresh(t *testing.T) { func TestAutomaticAuthRefresh(t *testing.T) {
var wantAuth = &pmapi.Auth{ var wantAuthRefresh = &AuthRefresh{
UID: "testUID", UID: "testUID",
AccessToken: "testAcc", AccessToken: "testAcc",
RefreshToken: "testRef", RefreshToken: "testRef",
ExpiresIn: 100,
} }
mux := http.NewServeMux() mux := http.NewServeMux()
@ -24,7 +42,7 @@ func TestAutomaticAuthRefresh(t *testing.T) {
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(wantAuth); err != nil { if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
panic(err) panic(err)
} }
}) })
@ -35,28 +53,28 @@ func TestAutomaticAuthRefresh(t *testing.T) {
ts := httptest.NewServer(mux) ts := httptest.NewServer(mux)
var gotAuth *pmapi.Auth var gotAuthRefresh *AuthRefresh
// Create a new client. c := New(Config{HostURL: ts.URL}).
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second)) NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
// Register an auth handler. c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
// Make a request with an access token that already expired one second ago. // Make a request with an access token that already expired one second ago.
if _, err := c.GetAddresses(context.Background()); err != nil { _, err := c.GetAddresses(context.Background())
t.Fatal("got unexpected error", err) r.NoError(t, err)
}
// The auth callback should have been called. // The auth callback should have been called.
if *gotAuth != *wantAuth { a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
t.Fatal("got unexpected auth", gotAuth)
} cl := c.(*client) //nolint[forcetypeassert] we want to panic here
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
} }
func Test401AuthRefresh(t *testing.T) { func Test401AuthRefresh(t *testing.T) {
var wantAuth = &pmapi.Auth{ var wantAuthRefresh = &AuthRefresh{
UID: "testUID", UID: "testUID",
AccessToken: "testAcc", AccessToken: "testAcc",
RefreshToken: "testRef", RefreshToken: "testRef",
@ -67,7 +85,7 @@ func Test401AuthRefresh(t *testing.T) {
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(wantAuth); err != nil { if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
panic(err) panic(err)
} }
}) })
@ -86,24 +104,21 @@ func Test401AuthRefresh(t *testing.T) {
ts := httptest.NewServer(mux) ts := httptest.NewServer(mux)
var gotAuth *pmapi.Auth var gotAuthRefresh *AuthRefresh
// Create a new client. // Create a new client.
c := pmapi.New(pmapi.Config{HostURL: ts.URL}). c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour)) NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
// Register an auth handler. // Register an auth handler.
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil }) c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
// The first request will fail with 401, triggering a refresh and retry. // The first request will fail with 401, triggering a refresh and retry.
if _, err := c.GetAddresses(context.Background()); err != nil { _, err := c.GetAddresses(context.Background())
t.Fatal("got unexpected error", err) r.NoError(t, err)
}
// The auth callback should have been called. // The auth callback should have been called.
if *gotAuth != *wantAuth { r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
t.Fatal("got unexpected auth", gotAuth)
}
} }
func Test401RevokedAuth(t *testing.T) { func Test401RevokedAuth(t *testing.T) {
@ -119,17 +134,57 @@ func Test401RevokedAuth(t *testing.T) {
ts := httptest.NewServer(mux) ts := httptest.NewServer(mux)
c := pmapi.New(pmapi.Config{HostURL: ts.URL}). c := New(Config{HostURL: ts.URL}).
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour)) NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh. // The request will fail with 401, triggering a refresh.
// The retry will also fail with 401, returning an error. // The retry will also fail with 401, returning an error.
_, err := c.GetAddresses(context.Background()) _, err := c.GetAddresses(context.Background())
if err == nil { r.EqualError(t, err, ErrUnauthorized.Error())
t.Fatal("expected error, instead got", err) }
}
func TestAuth2FA(t *testing.T) {
if !errors.Is(err, pmapi.ErrUnauthorized) { twoFACode := "code"
t.Fatal("expected error to be ErrUnauthorized, instead got", err)
} finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
var twoFAreq auth2FAReq
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
return "/auth/2fa/post_response.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), twoFACode)
r.NoError(t, err)
}
func TestAuth2FA_Fail(t *testing.T) {
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_401_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(t, ErrBad2FACode, err)
}
func TestAuth2FA_Retry(t *testing.T) {
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_422_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(t, ErrBad2FACodeTryAgain, err)
} }

View File

@ -1,72 +0,0 @@
package pmapi
type AuthModulus struct {
Modulus string
ModulusID string
}
type GetAuthInfoReq struct {
Username string
}
type AuthInfo struct {
Version int
Modulus string
ServerEphemeral string
Salt string
SRPSession string
}
type TwoFAInfo struct {
Enabled TwoFAStatus
}
type TwoFAStatus int
const (
TwoFADisabled TwoFAStatus = iota
TOTPEnabled
// TODO: Support UTF
)
type PasswordMode int
const (
OnePasswordMode PasswordMode = iota + 1
TwoPasswordMode
)
type AuthReq struct {
Username string
ClientProof string
ClientEphemeral string
SRPSession string
}
type Auth struct {
UserID string
UID string
AccessToken string
RefreshToken string
ExpiresIn int64
Scope string
ServerProof string
TwoFA TwoFAInfo `json:"2FA"`
PasswordMode PasswordMode
}
type Auth2FAReq struct {
TwoFactorCode string
}
type AuthRefreshReq struct {
UID string
RefreshToken string
ResponseType string
GrantType string
RedirectURI string
State string
}

41
pkg/pmapi/boolean.go Normal file
View File

@ -0,0 +1,41 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import "encoding/json"
type Boolean bool
func (boolean *Boolean) UnmarshalJSON(b []byte) error {
var value int
err := json.Unmarshal(b, &value)
if err != nil {
return err
}
*boolean = Boolean(value == 1)
return nil
}
func (boolean Boolean) MarshalJSON() ([]byte, error) {
var value int
if boolean {
value = 1
}
return json.Marshal(value)
}

View File

@ -25,15 +25,14 @@ import (
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"github.com/pkg/errors"
) )
// client is a client of the protonmail API. It implements the Client interface. // client is a client of the protonmail API. It implements the Client interface.
type client struct { type client struct {
req requester manager clientManager
uid, acc, ref string uid, acc, ref string
authHandlers []AuthHandler authHandlers []AuthRefreshHandler
authLocker sync.RWMutex authLocker sync.RWMutex
user *User user *User
@ -45,9 +44,9 @@ type client struct {
exp time.Time exp time.Time
} }
func newClient(req requester, uid string) *client { func newClient(manager clientManager, uid string) *client {
return &client{ return &client{
req: req, manager: manager,
uid: uid, uid: uid,
addrKeyRing: make(map[string]*crypto.KeyRing), addrKeyRing: make(map[string]*crypto.KeyRing),
keyRingLock: &sync.RWMutex{}, keyRingLock: &sync.RWMutex{},
@ -63,7 +62,7 @@ func (c *client) withAuth(acc, ref string, exp time.Time) *client {
} }
func (c *client) r(ctx context.Context) (*resty.Request, error) { func (c *client) r(ctx context.Context) (*resty.Request, error) {
r := c.req.r(ctx) r := c.manager.r(ctx)
if c.uid != "" { if c.uid != "" {
r.SetHeader("x-pm-uid", c.uid) r.SetHeader("x-pm-uid", c.uid)
@ -91,30 +90,23 @@ func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Respons
return nil, err return nil, err
} }
res, err := wrapRestyError(fn(r)) res, err := wrapNoConnection(fn(r))
if err != nil { if err != nil {
if res.StatusCode() != http.StatusUnauthorized { if res.StatusCode() != http.StatusUnauthorized {
return nil, err // Return also response so caller has more options to decide what to do.
return res, err
} }
if !isAuthRefreshDisabled(ctx) {
if err := c.authRefresh(ctx); err != nil { if err := c.authRefresh(ctx); err != nil {
return nil, err return nil, err
} }
return wrapRestyError(fn(r)) return wrapNoConnection(fn(r))
}
return res, err
} }
return res, nil return res, nil
} }
func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) {
if err, ok := err.(*resty.ResponseError); ok {
return res, err
}
if res.RawResponse != nil {
return res, err
}
return res, errors.Wrap(ErrNoConnection, err.Error())
}

View File

@ -1,3 +1,20 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi package pmapi
import ( import (
@ -12,8 +29,6 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
c.keyRingLock.Lock() c.keyRingLock.Lock()
defer c.keyRingLock.Unlock() defer c.keyRingLock.Unlock()
// FIXME(conman): Should this be done as part of NewClient somehow?
return c.unlock(ctx, passphrase) return c.unlock(ctx, passphrase)
} }
@ -65,6 +80,15 @@ func (c *client) clearKeys() {
} }
func (c *client) IsUnlocked() bool { func (c *client) IsUnlocked() bool {
// FIXME(conman): Better way to check? we don't currently check address keys. if c.userKeyRing == nil {
return c.userKeyRing != nil return false
}
for _, address := range c.addresses {
if address.HasKeys != MissingKeys && c.addrKeyRing[address.ID] == nil {
return false
}
}
return true
} }

View File

@ -27,10 +27,10 @@ import (
// Client defines the interface of a PMAPI client. // Client defines the interface of a PMAPI client.
type Client interface { type Client interface {
Auth2FA(context.Context, Auth2FAReq) error Auth2FA(context.Context, string) error
AuthSalt(ctx context.Context) (string, error) AuthSalt(ctx context.Context) (string, error)
AuthDelete(context.Context) error AuthDelete(context.Context) error
AddAuthHandler(AuthHandler) AddAuthRefreshHandler(AuthRefreshHandler)
CurrentUser(ctx context.Context) (*User, error) CurrentUser(ctx context.Context) (*User, error)
UpdateUser(ctx context.Context) (*User, error) UpdateUser(ctx context.Context) (*User, error)
@ -75,9 +75,9 @@ type Client interface {
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error) GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
} }
type AuthHandler func(*Auth) error type AuthRefreshHandler func(*AuthRefresh)
type requester interface { type clientManager interface {
r(context.Context) *resty.Request r(context.Context) *resty.Request
authRefresh(context.Context, string, string) (*Auth, error) authRefresh(context.Context, string, string) (*AuthRefresh, error)
} }

View File

@ -1,11 +1,72 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi package pmapi
import (
"runtime"
"strings"
)
type Config struct { type Config struct {
// HostURL is the base URL of API.
HostURL string HostURL string
// AppVersion sets version to headers of each request.
AppVersion string AppVersion string
// UserAgent sets user agent to headers of each request.
// Used only if GetUserAgent is not set.
UserAgent string
// GetUserAgent is dynamic version of UserAgent.
// Overrides UserAgent.
GetUserAgent func() string
// UpgradeApplicationHandler is used to notify when there is a force upgrade.
UpgradeApplicationHandler func()
// TLSIssueHandler is used to notify when there is a TLS issue.
TLSIssueHandler func()
} }
var DefaultConfig = Config{ func NewConfig(appVersionName, appVersion string) Config {
HostURL: "https://api.protonmail.ch", return Config{
AppVersion: "Other", HostURL: getRootURL(),
AppVersion: getAPIOS() + strings.Title(appVersionName) + "_" + appVersion,
}
}
func (c *Config) getUserAgent() string {
if c.GetUserAgent == nil {
return c.UserAgent
}
return c.GetUserAgent()
}
// getAPIOS returns actual operating system.
func getAPIOS() string {
switch os := runtime.GOOS; os {
case "darwin": // nolint: goconst
return "macOS"
case "linux":
return "Linux"
case "windows":
return "Windows"
}
return "Linux"
} }

Some files were not shown because too many files have changed in this diff Show More