mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
GODT-35: Finish all details and make tests pass
This commit is contained in:
5
go.mod
5
go.mod
@ -40,7 +40,7 @@ require (
|
||||
github.com/fatih/color v1.9.0
|
||||
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
|
||||
github.com/getsentry/sentry-go v0.8.0
|
||||
github.com/go-resty/resty/v2 v2.4.0
|
||||
github.com/go-resty/resty/v2 v2.6.0
|
||||
github.com/golang/mock v1.4.4
|
||||
github.com/google/go-cmp v0.5.1
|
||||
github.com/google/uuid v1.1.1
|
||||
@ -50,6 +50,7 @@ require (
|
||||
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
|
||||
github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||
github.com/mattn/go-runewidth v0.0.9 // indirect
|
||||
github.com/miekg/dns v1.1.41
|
||||
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
||||
github.com/olekukonko/tablewriter v0.0.4 // indirect
|
||||
github.com/pkg/errors v0.9.1
|
||||
@ -63,7 +64,7 @@ require (
|
||||
github.com/urfave/cli/v2 v2.2.0
|
||||
github.com/vmihailenco/msgpack/v5 v5.1.3
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
|
||||
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec
|
||||
)
|
||||
|
||||
|
||||
18
go.sum
18
go.sum
@ -113,8 +113,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK
|
||||
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
|
||||
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
||||
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
|
||||
github.com/go-resty/resty/v2 v2.4.0 h1:s6TItTLejEI+2mn98oijC5w/Rk2YU+OA6x0mnZN6r6k=
|
||||
github.com/go-resty/resty/v2 v2.4.0/go.mod h1:B88+xCTEwvfD94NOuE6GS1wMlnoKNY8eEiNizfNwOwA=
|
||||
github.com/go-resty/resty/v2 v2.6.0 h1:joIR5PNLM2EFqqESUjCMGXrWmXNHEU9CEiK813oKYS4=
|
||||
github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q=
|
||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||
@ -195,6 +195,8 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m
|
||||
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
|
||||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
|
||||
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@ -310,12 +312,16 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw=
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@ -330,6 +336,10 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 h1:cEhElsAv9LUt9ZUUocxzWe05oFLVd+AA2nstydTeI8g=
|
||||
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
||||
@ -181,21 +181,18 @@ func New( // nolint[funlen]
|
||||
kc = keychain.NewMissingKeychain()
|
||||
}
|
||||
|
||||
// FIXME(conman): Customize config depending on build type (app version, host URL).
|
||||
cm := pmapi.New(pmapi.DefaultConfig)
|
||||
cfg := pmapi.NewConfig(configName, constants.Version)
|
||||
cfg.GetUserAgent = userAgent.String
|
||||
cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||
cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") }
|
||||
|
||||
cm := pmapi.New(cfg)
|
||||
|
||||
// FIXME(conman): Should this be a real object, not just created via callbacks?
|
||||
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
|
||||
func() { listener.Emit(events.InternetOffEvent, "") },
|
||||
func() { listener.Emit(events.InternetOnEvent, "") },
|
||||
))
|
||||
|
||||
// FIXME(conman): Implement force upgrade observer.
|
||||
// apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||
|
||||
// FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc.
|
||||
// cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
|
||||
|
||||
jar, err := cookies.NewCookieJar(settingsObj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -341,6 +338,7 @@ func (b *Base) run(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc {
|
||||
}
|
||||
|
||||
logging.SetLevel(c.String(flagLogLevel))
|
||||
b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel)
|
||||
|
||||
logrus.
|
||||
WithField("appName", b.Name).
|
||||
|
||||
@ -65,8 +65,7 @@ func New(
|
||||
// Allow DoH before starting the app if the user has previously set this setting.
|
||||
// This allows us to start even if protonmail is blocked.
|
||||
if s.GetBool(settings.AllowProxyKey) {
|
||||
// FIXME(conman): Support enable/disable of DoH.
|
||||
// clientManager.AllowProxy()
|
||||
clientManager.AllowProxy()
|
||||
}
|
||||
|
||||
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
|
||||
@ -120,7 +119,7 @@ func (b *Bridge) heartbeat() {
|
||||
|
||||
// ReportBug reports a new bug from the user.
|
||||
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||
return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
||||
return b.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Browser: emailClient,
|
||||
|
||||
@ -21,7 +21,6 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -75,13 +74,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
return
|
||||
}
|
||||
|
||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
||||
if auth.HasTwoFactor() {
|
||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||
if twoFactor == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
||||
err = client.Auth2FA(context.Background(), twoFactor)
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
@ -89,7 +88,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
}
|
||||
|
||||
mailboxPassword := password
|
||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
||||
if auth.HasMailboxPassword() {
|
||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||
}
|
||||
if mailboxPassword == "" {
|
||||
|
||||
@ -84,11 +84,6 @@ func New( //nolint[funlen]
|
||||
Aliases: []string{"u", "version", "v"},
|
||||
Func: fe.checkUpdates,
|
||||
})
|
||||
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
|
||||
Help: "check internet connection. (aliases: i, conn, connection)",
|
||||
Aliases: []string{"i", "con", "connection"},
|
||||
Func: fe.checkInternetConnection,
|
||||
})
|
||||
fe.AddCmd(checkCmd)
|
||||
|
||||
// Print info commands.
|
||||
@ -177,13 +172,13 @@ func New( //nolint[funlen]
|
||||
}
|
||||
|
||||
func (f *frontendCLI) watchEvents() {
|
||||
errorCh := f.getEventChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
|
||||
internetOffCh := f.getEventChannel(events.InternetOffEvent)
|
||||
internetOnCh := f.getEventChannel(events.InternetOnEvent)
|
||||
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := f.getEventChannel(events.LogoutEvent)
|
||||
certIssue := f.getEventChannel(events.TLSCertIssue)
|
||||
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
|
||||
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||
for {
|
||||
select {
|
||||
case errorDetails := <-errorCh:
|
||||
@ -208,13 +203,6 @@ func (f *frontendCLI) watchEvents() {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) getEventChannel(event string) <-chan string {
|
||||
ch := make(chan string)
|
||||
f.eventListener.Add(event, ch)
|
||||
f.eventListener.RetryEmit(event)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Loop starts the frontend loop with an interactive shell.
|
||||
func (f *frontendCLI) Loop() error {
|
||||
f.Print(`
|
||||
|
||||
@ -29,14 +29,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
|
||||
if f.ie.CheckConnection() == nil {
|
||||
f.Println("Internet connection is available.")
|
||||
} else {
|
||||
f.Println("Can not contact the server, please check your internet connection.")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
||||
if path, err := f.locations.ProvideLogsPath(); err != nil {
|
||||
f.Println("Failed to determine location of log files")
|
||||
|
||||
@ -20,6 +20,7 @@ package cliie
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
||||
func (f *frontendCLI) processAPIError(err error) {
|
||||
log.Warn("API error: ", err)
|
||||
switch err {
|
||||
// FIXME(conman): How to handle various API errors?
|
||||
/*
|
||||
case pmapi.ErrNoConnection:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
*/
|
||||
default:
|
||||
f.Println("Server error:", err.Error())
|
||||
}
|
||||
|
||||
@ -24,7 +24,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -122,13 +121,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
return
|
||||
}
|
||||
|
||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
||||
if auth.HasTwoFactor() {
|
||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||
if twoFactor == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
||||
err = client.Auth2FA(context.Background(), twoFactor)
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
@ -136,7 +135,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
}
|
||||
|
||||
mailboxPassword := password
|
||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
||||
if auth.HasMailboxPassword() {
|
||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||
}
|
||||
if mailboxPassword == "" {
|
||||
|
||||
@ -157,15 +157,6 @@ func New( //nolint[funlen]
|
||||
})
|
||||
fe.AddCmd(updatesCmd)
|
||||
|
||||
// Check commands.
|
||||
checkCmd := &ishell.Cmd{Name: "check", Help: "check internet connection or new version."}
|
||||
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
|
||||
Help: "check internet connection. (aliases: i, conn, connection)",
|
||||
Aliases: []string{"i", "con", "connection"},
|
||||
Func: fe.checkInternetConnection,
|
||||
})
|
||||
fe.AddCmd(checkCmd)
|
||||
|
||||
// Print info commands.
|
||||
fe.AddCmd(&ishell.Cmd{Name: "log-dir",
|
||||
Help: "print path to directory with logs. (aliases: log, logs)",
|
||||
@ -228,14 +219,14 @@ func New( //nolint[funlen]
|
||||
}
|
||||
|
||||
func (f *frontendCLI) watchEvents() {
|
||||
errorCh := f.getEventChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
|
||||
internetOffCh := f.getEventChannel(events.InternetOffEvent)
|
||||
internetOnCh := f.getEventChannel(events.InternetOnEvent)
|
||||
addressChangedCh := f.getEventChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := f.getEventChannel(events.LogoutEvent)
|
||||
certIssue := f.getEventChannel(events.TLSCertIssue)
|
||||
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
|
||||
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||
for {
|
||||
select {
|
||||
case errorDetails := <-errorCh:
|
||||
@ -262,13 +253,6 @@ func (f *frontendCLI) watchEvents() {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) getEventChannel(event string) <-chan string {
|
||||
ch := make(chan string)
|
||||
f.eventListener.Add(event, ch)
|
||||
f.eventListener.RetryEmit(event)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Loop starts the frontend loop with an interactive shell.
|
||||
func (f *frontendCLI) Loop() error {
|
||||
f.Print(`
|
||||
|
||||
@ -39,14 +39,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
|
||||
if f.bridge.CheckConnection() == nil {
|
||||
f.Println("Internet connection is available.")
|
||||
} else {
|
||||
f.Println("Can not contact the server, please check your internet connection.")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
||||
if path, err := f.locations.ProvideLogsPath(); err != nil {
|
||||
f.Println("Failed to determine location of log files")
|
||||
|
||||
@ -20,6 +20,7 @@ package cli
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
||||
func (f *frontendCLI) processAPIError(err error) {
|
||||
log.Warn("API error: ", err)
|
||||
switch err {
|
||||
// FIXME(conman): How to handle various API errors?
|
||||
/*
|
||||
case pmapi.ErrNoConnection:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
*/
|
||||
default:
|
||||
f.Println("Server error:", err.Error())
|
||||
}
|
||||
|
||||
@ -409,7 +409,6 @@ Dialog {
|
||||
|
||||
onShow: {
|
||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||
go.checkInternet()
|
||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||
go.notifyError(gui.enums.errNoInternet)
|
||||
root.hide()
|
||||
|
||||
@ -857,14 +857,12 @@ Dialog {
|
||||
inputPort . checkIsANumber()
|
||||
//emailProvider . currentIndex!=0
|
||||
)) isOK = false
|
||||
go.checkInternet()
|
||||
if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this
|
||||
errorPopup.show(qsTr("Please check your internet connection."))
|
||||
return false
|
||||
}
|
||||
break
|
||||
case 2: // loading structure
|
||||
go.checkInternet()
|
||||
if (winMain.updateState == gui.enums.statusNoInternet) {
|
||||
errorPopup.show(qsTr("Please check your internet connection."))
|
||||
return false
|
||||
@ -949,7 +947,6 @@ Dialog {
|
||||
onShow : {
|
||||
root.clear()
|
||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||
go.checkInternet()
|
||||
if (winMain.updateState==gui.enums.statusNoInternet) {
|
||||
winMain.popupMessage.show(go.canNotReachAPI)
|
||||
root.hide()
|
||||
|
||||
@ -25,33 +25,12 @@ import ProtonUI 1.0
|
||||
Rectangle {
|
||||
id: root
|
||||
property var iTry: 0
|
||||
property var secLeft: 0
|
||||
property var second: 1000 // convert millisecond to second
|
||||
property var checkInterval: [ 5, 10, 30, 60, 120, 300, 600 ] // seconds
|
||||
property bool isVisible: true
|
||||
property var fontSize : 1.2 * Style.main.fontSize
|
||||
color : "black"
|
||||
state: "upToDate"
|
||||
|
||||
Timer {
|
||||
id: retryInternet
|
||||
interval: second
|
||||
triggeredOnStart: false
|
||||
repeat: true
|
||||
onTriggered : {
|
||||
secLeft--
|
||||
if (secLeft <= 0) {
|
||||
retryInternet.stop()
|
||||
go.checkInternet()
|
||||
if (iTry < checkInterval.length-1) {
|
||||
iTry++
|
||||
}
|
||||
secLeft=checkInterval[iTry]
|
||||
retryInternet.start()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Row {
|
||||
id: messageRow
|
||||
anchors.centerIn: root
|
||||
@ -110,16 +89,12 @@ Rectangle {
|
||||
case "internetCheck":
|
||||
break;
|
||||
case "noInternet" :
|
||||
retryInternet.start()
|
||||
secLeft=checkInterval[iTry]
|
||||
break;
|
||||
case "oldVersion":
|
||||
break;
|
||||
case "forceUpdate":
|
||||
break;
|
||||
case "upToDate":
|
||||
iTry = 0
|
||||
secLeft=checkInterval[iTry]
|
||||
break;
|
||||
case "updateRestart":
|
||||
break;
|
||||
@ -128,24 +103,6 @@ Rectangle {
|
||||
default :
|
||||
break;
|
||||
}
|
||||
|
||||
if (root.state!="noInternet") {
|
||||
retryInternet.stop()
|
||||
}
|
||||
}
|
||||
|
||||
function timeToRetry() {
|
||||
if (secLeft==1){
|
||||
return qsTr("a second", "time to wait till internet connection is retried")
|
||||
} else if (secLeft<60){
|
||||
return secLeft + " " + qsTr("seconds", "time to wait till internet connection is retried")
|
||||
} else {
|
||||
var leading = ""+secLeft%60
|
||||
if (leading.length < 2) {
|
||||
leading = "0" + leading
|
||||
}
|
||||
return Math.floor(secLeft/60) + ":" + leading
|
||||
}
|
||||
}
|
||||
|
||||
states: [
|
||||
@ -194,23 +151,15 @@ Rectangle {
|
||||
PropertyChanges {
|
||||
target: message
|
||||
color: Style.main.line
|
||||
text: qsTr("Cannot contact server. Retrying in ", "displayed when the app is disconnected from the internet or server has problems")+timeToRetry()+"."
|
||||
text: qsTr("Cannot contact server. Please wait...", "displayed when the app is disconnected from the internet or server has problems")
|
||||
}
|
||||
PropertyChanges {
|
||||
target: linkText
|
||||
visible: false
|
||||
}
|
||||
PropertyChanges {
|
||||
target: actionText
|
||||
visible: true
|
||||
text: qsTr("Retry now", "click to try to connect to the internet when the app is disconnected from the internet")
|
||||
onClicked: {
|
||||
go.checkInternet()
|
||||
}
|
||||
}
|
||||
PropertyChanges {
|
||||
target: separatorText
|
||||
visible: true
|
||||
visible: false
|
||||
text: "|"
|
||||
}
|
||||
PropertyChanges {
|
||||
|
||||
@ -1331,10 +1331,6 @@ Window {
|
||||
return (fname!="fail")
|
||||
}
|
||||
|
||||
function checkInternet() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
function loadImportReports(fname) {
|
||||
console.log("load import reports for ", fname)
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
package qtcommon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -207,7 +208,7 @@ func (a *Accounts) Auth2FA(twoFacAuth string) int {
|
||||
if a.auth == nil || a.authClient == nil {
|
||||
err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient)
|
||||
} else {
|
||||
err = a.authClient.Auth2FA(twoFacAuth, a.auth)
|
||||
err = a.authClient.Auth2FA(context.Background(), twoFacAuth)
|
||||
}
|
||||
|
||||
if a.showLoginError(err, "auth2FA") {
|
||||
|
||||
@ -113,10 +113,3 @@ type Listener interface {
|
||||
Add(string, chan<- string)
|
||||
RetryEmit(string)
|
||||
}
|
||||
|
||||
func MakeAndRegisterEvent(eventListener Listener, event string) <-chan string {
|
||||
ch := make(chan string)
|
||||
eventListener.Add(event, ch)
|
||||
eventListener.RetryEmit(event)
|
||||
return ch
|
||||
}
|
||||
|
||||
@ -143,16 +143,16 @@ func (f *FrontendQt) NotifySilentUpdateError(err error) {
|
||||
}
|
||||
|
||||
func (f *FrontendQt) watchEvents() {
|
||||
credentialsErrorCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.CredentialsErrorEvent)
|
||||
internetOffCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOffEvent)
|
||||
internetOnCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOnEvent)
|
||||
secondInstanceCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.SecondInstanceEvent)
|
||||
restartBridgeCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.RestartBridgeEvent)
|
||||
addressChangedCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedLogoutEvent)
|
||||
logoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.LogoutEvent)
|
||||
updateApplicationCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UpgradeApplicationEvent)
|
||||
newUserCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UserRefreshEvent)
|
||||
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||
secondInstanceCh := f.eventListener.ProvideChannel(events.SecondInstanceEvent)
|
||||
restartBridgeCh := f.eventListener.ProvideChannel(events.RestartBridgeEvent)
|
||||
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
|
||||
updateApplicationCh := f.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
|
||||
newUserCh := f.eventListener.ProvideChannel(events.UserRefreshEvent)
|
||||
for {
|
||||
select {
|
||||
case <-credentialsErrorCh:
|
||||
@ -351,11 +351,6 @@ func (f *FrontendQt) sendBug(description, emailClient, address string) bool {
|
||||
// }
|
||||
//}
|
||||
|
||||
// checkInternet is almost idetical to bridge
|
||||
func (f *FrontendQt) checkInternet() {
|
||||
f.Qml.SetConnectionStatus(f.ie.CheckConnection() == nil)
|
||||
}
|
||||
|
||||
func (f *FrontendQt) showError(code int, err error) {
|
||||
f.Qml.SetErrorDescription(err.Error())
|
||||
log.WithField("code", code).Errorln(err.Error())
|
||||
|
||||
@ -78,7 +78,6 @@ type GoQMLInterface struct {
|
||||
_ string `property:"versionCheckFailed"`
|
||||
//
|
||||
_ func(isAvailable bool) `signal:"setConnectionStatus"`
|
||||
_ func() `slot:"checkInternet"`
|
||||
|
||||
_ func() `slot:"setToRestart"`
|
||||
|
||||
@ -189,8 +188,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
|
||||
return f.programVersion
|
||||
})
|
||||
|
||||
s.ConnectCheckInternet(f.checkInternet)
|
||||
|
||||
s.ConnectSetToRestart(f.restarter.SetToRestart)
|
||||
|
||||
s.ConnectLoadStructureForExport(f.LoadStructureForExport)
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
package qt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -173,7 +174,7 @@ func (s *FrontendQt) auth2FA(twoFacAuth string) int {
|
||||
if s.auth == nil || s.authClient == nil {
|
||||
err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient)
|
||||
} else {
|
||||
err = s.authClient.Auth2FA(twoFacAuth, s.auth)
|
||||
err = s.authClient.Auth2FA(context.Background(), twoFacAuth)
|
||||
}
|
||||
|
||||
if s.showLoginError(err, "auth2FA") {
|
||||
|
||||
@ -191,20 +191,20 @@ func (s *FrontendQt) NotifySilentUpdateError(err error) {
|
||||
func (s *FrontendQt) watchEvents() {
|
||||
s.WaitUntilFrontendIsReady()
|
||||
|
||||
errorCh := s.getEventChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := s.getEventChannel(events.CredentialsErrorEvent)
|
||||
outgoingNoEncCh := s.getEventChannel(events.OutgoingNoEncEvent)
|
||||
noActiveKeyForRecipientCh := s.getEventChannel(events.NoActiveKeyForRecipientEvent)
|
||||
internetOffCh := s.getEventChannel(events.InternetOffEvent)
|
||||
internetOnCh := s.getEventChannel(events.InternetOnEvent)
|
||||
secondInstanceCh := s.getEventChannel(events.SecondInstanceEvent)
|
||||
restartBridgeCh := s.getEventChannel(events.RestartBridgeEvent)
|
||||
addressChangedCh := s.getEventChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := s.getEventChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := s.getEventChannel(events.LogoutEvent)
|
||||
updateApplicationCh := s.getEventChannel(events.UpgradeApplicationEvent)
|
||||
newUserCh := s.getEventChannel(events.UserRefreshEvent)
|
||||
certIssue := s.getEventChannel(events.TLSCertIssue)
|
||||
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
outgoingNoEncCh := s.eventListener.ProvideChannel(events.OutgoingNoEncEvent)
|
||||
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
|
||||
internetOffCh := s.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||
internetOnCh := s.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||
secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
|
||||
restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
|
||||
addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
|
||||
updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
|
||||
newUserCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
|
||||
certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||
for {
|
||||
select {
|
||||
case errorDetails := <-errorCh:
|
||||
@ -254,13 +254,6 @@ func (s *FrontendQt) watchEvents() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FrontendQt) getEventChannel(event string) <-chan string {
|
||||
ch := make(chan string)
|
||||
s.eventListener.Add(event, ch)
|
||||
s.eventListener.RetryEmit(event)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Loop function for tests.
|
||||
//
|
||||
// It runs QtExecute in new thread with function returning itself after setup.
|
||||
@ -653,10 +646,6 @@ func (s *FrontendQt) isSMTPSTARTTLS() bool {
|
||||
return !s.settings.GetBool(settings.SMTPSSLKey)
|
||||
}
|
||||
|
||||
func (s *FrontendQt) checkInternet() {
|
||||
s.Qml.SetConnectionStatus(s.bridge.CheckConnection() == nil)
|
||||
}
|
||||
|
||||
func (s *FrontendQt) switchAddressModeUser(iAccount int) {
|
||||
defer s.Qml.ProcessFinished()
|
||||
userID := s.Accounts.get(iAccount).UserID()
|
||||
|
||||
@ -84,7 +84,6 @@ type GoQMLInterface struct {
|
||||
_ string `property:"progressDescription"`
|
||||
|
||||
_ func(isAvailable bool) `signal:"setConnectionStatus"`
|
||||
_ func() `slot:"checkInternet"`
|
||||
|
||||
_ func() `slot:"setToRestart"`
|
||||
|
||||
@ -205,8 +204,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
|
||||
return f.programVer
|
||||
})
|
||||
|
||||
s.ConnectCheckInternet(f.checkInternet)
|
||||
|
||||
s.ConnectSetToRestart(f.restarter.SetToRestart)
|
||||
|
||||
s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc)
|
||||
|
||||
@ -55,7 +55,6 @@ type UserManager interface {
|
||||
GetUser(query string) (User, error)
|
||||
DeleteUser(userID string, clearCache bool) error
|
||||
ClearData() error
|
||||
CheckConnection() error
|
||||
}
|
||||
|
||||
// User is an interface of user needed by frontend.
|
||||
|
||||
@ -38,11 +38,10 @@ type bridgeUser interface {
|
||||
IsCombinedAddressMode() bool
|
||||
GetAddressID(address string) (string, error)
|
||||
GetPrimaryAddress() string
|
||||
UpdateUser() error
|
||||
Logout() error
|
||||
CloseConnection(address string)
|
||||
GetStore() storeUserProvider
|
||||
GetTemporaryPMAPIClient() pmapi.Client
|
||||
GetClient() pmapi.Client
|
||||
}
|
||||
|
||||
type bridgeWrap struct {
|
||||
|
||||
@ -422,7 +422,7 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
|
||||
if isStringInList(m.LabelIDs, pmapi.StarredLabel) {
|
||||
messageFlagsMap[imap.FlaggedFlag] = true
|
||||
}
|
||||
if m.Unread == 0 {
|
||||
if !m.Unread {
|
||||
messageFlagsMap[imap.SeenFlag] = true
|
||||
}
|
||||
if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) {
|
||||
@ -560,7 +560,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if storeMessage.Message().Unread == 1 {
|
||||
if storeMessage.Message().Unread {
|
||||
for section := range msg.Body {
|
||||
// Peek means get messages without marking them as read.
|
||||
// If client does not only ask for peek, we have to mark them as read.
|
||||
|
||||
@ -93,7 +93,7 @@ func newIMAPUser(
|
||||
|
||||
// This method should eventually no longer be necessary. Everything should go via store.
|
||||
func (iu *imapUser) client() pmapi.Client {
|
||||
return iu.user.GetTemporaryPMAPIClient()
|
||||
return iu.user.GetClient()
|
||||
}
|
||||
|
||||
func (iu *imapUser) isSubscribed(labelID string) bool {
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/transfer"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
@ -40,6 +41,7 @@ type ImportExport struct {
|
||||
locations Locator
|
||||
cache Cacher
|
||||
panicHandler users.PanicHandler
|
||||
eventListener listener.Listener
|
||||
clientManager pmapi.Manager
|
||||
}
|
||||
|
||||
@ -59,13 +61,14 @@ func New(
|
||||
locations: locations,
|
||||
cache: cache,
|
||||
panicHandler: panicHandler,
|
||||
eventListener: eventListener,
|
||||
clientManager: clientManager,
|
||||
}
|
||||
}
|
||||
|
||||
// ReportBug reports a new bug from the user.
|
||||
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||
return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
||||
return ie.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Browser: emailClient,
|
||||
@ -89,7 +92,7 @@ func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address strin
|
||||
|
||||
report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
|
||||
|
||||
return ie.clientManager.ReportBug(context.TODO(), report)
|
||||
return ie.clientManager.ReportBug(context.Background(), report)
|
||||
}
|
||||
|
||||
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
|
||||
@ -162,5 +165,23 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
|
||||
log.WithError(err).Info("Address does not exist, using all addresses")
|
||||
}
|
||||
|
||||
return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
|
||||
provider, err := transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
internetOffCh := ie.eventListener.ProvideChannel(events.InternetOffEvent)
|
||||
internetOnCh := ie.eventListener.ProvideChannel(events.InternetOnEvent)
|
||||
for {
|
||||
select {
|
||||
case <-internetOffCh:
|
||||
provider.SetConnectionDown()
|
||||
case <-internetOnCh:
|
||||
provider.SetConnectionUp()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
@ -31,7 +31,7 @@ type bridgeUser interface {
|
||||
CheckBridgeLogin(password string) error
|
||||
IsCombinedAddressMode() bool
|
||||
GetAddressID(address string) (string, error)
|
||||
GetTemporaryPMAPIClient() pmapi.Client
|
||||
GetClient() pmapi.Client
|
||||
GetStore() storeUserProvider
|
||||
}
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ func newSMTPUser(
|
||||
|
||||
// This method should eventually no longer be necessary. Everything should go via store.
|
||||
func (su *smtpUser) client() pmapi.Client {
|
||||
return su.user.GetTemporaryPMAPIClient()
|
||||
return su.user.GetClient()
|
||||
}
|
||||
|
||||
// Send sends an email from the given address to the given addresses with the given body.
|
||||
|
||||
@ -90,7 +90,7 @@ func getLabelPrefix(l *pmapi.Label) string {
|
||||
switch {
|
||||
case pmapi.IsSystemLabel(l.ID):
|
||||
return ""
|
||||
case l.Exclusive == 1:
|
||||
case bool(l.Exclusive):
|
||||
return UserFoldersPrefix
|
||||
default:
|
||||
return UserLabelsPrefix
|
||||
|
||||
@ -37,8 +37,8 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
|
||||
m.newStoreNoEvents(true)
|
||||
m.store.SetChangeNotifier(m.changeNotifier)
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
}
|
||||
|
||||
func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
|
||||
@ -52,8 +52,8 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
|
||||
m.newStoreNoEvents(true)
|
||||
m.store.SetChangeNotifier(m.changeNotifier)
|
||||
|
||||
msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
msg2 := getTestMessage("msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
|
||||
}
|
||||
|
||||
@ -63,8 +63,8 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
|
||||
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2))
|
||||
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1))
|
||||
|
||||
@ -81,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
|
||||
func (loop *eventLoop) setFirstEventID() (err error) {
|
||||
loop.log.Info("Setting first event ID")
|
||||
|
||||
event, err := loop.client().GetEvent(context.TODO(), "")
|
||||
event, err := loop.client().GetEvent(context.Background(), "")
|
||||
if err != nil {
|
||||
loop.log.WithError(err).Error("Could not get latest event ID")
|
||||
return
|
||||
@ -222,8 +222,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
|
||||
// (e.g. no internet, ulimit reached etc.)
|
||||
defer func() {
|
||||
// FIXME(conman): How to handle errors of different types?
|
||||
if errors.Is(err, pmapi.ErrNoConnection) {
|
||||
if errors.Cause(err) == pmapi.ErrNoConnection {
|
||||
l.Warn("Internet unavailable")
|
||||
err = nil
|
||||
}
|
||||
@ -234,20 +233,17 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
err = nil
|
||||
}
|
||||
|
||||
// FIXME(conman): Handle force upgrade.
|
||||
/*
|
||||
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
||||
l.Warn("Need to upgrade application")
|
||||
err = nil
|
||||
}
|
||||
*/
|
||||
|
||||
if err == nil {
|
||||
loop.errCounter = 0
|
||||
}
|
||||
|
||||
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
|
||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
||||
if err != nil && errors.Cause(err) != pmapi.ErrUnauthorized {
|
||||
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
|
||||
loop.errCounter++
|
||||
if loop.errCounter == errMaxSentry {
|
||||
@ -268,7 +264,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
loop.pollCounter++
|
||||
|
||||
var event *pmapi.Event
|
||||
if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil {
|
||||
if event, err = loop.client().GetEvent(context.Background(), loop.currentEventID); err != nil {
|
||||
return false, errors.Wrap(err, "failed to get event")
|
||||
}
|
||||
|
||||
@ -295,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
}
|
||||
}
|
||||
|
||||
return event.More == 1, err
|
||||
return bool(event.More), err
|
||||
}
|
||||
|
||||
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
|
||||
@ -354,7 +350,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
|
||||
// Get old addresses for comparisons before updating user.
|
||||
oldList := loop.client().Addresses()
|
||||
|
||||
if err = loop.user.UpdateUser(); err != nil {
|
||||
if err = loop.user.UpdateUser(context.Background()); err != nil {
|
||||
if logoutErr := loop.user.Logout(); logoutErr != nil {
|
||||
log.WithError(logoutErr).Error("Failed to logout user after failed update")
|
||||
}
|
||||
@ -465,16 +461,12 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
|
||||
|
||||
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
|
||||
|
||||
if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil {
|
||||
// FIXME(conman): How to handle error of this particular type?
|
||||
|
||||
/*
|
||||
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
|
||||
if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil {
|
||||
if _, ok := err.(pmapi.ErrUnprocessableEntity); ok {
|
||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||
err = nil
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
return errors.Wrap(err, "failed to get message from API for updating")
|
||||
}
|
||||
|
||||
@ -42,15 +42,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
|
||||
// next event if there is `More` of them.
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
|
||||
EventID: "event50",
|
||||
More: 1,
|
||||
More: true,
|
||||
}, nil),
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
|
||||
EventID: "event70",
|
||||
More: 0,
|
||||
More: false,
|
||||
}, nil),
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
|
||||
EventID: "event71",
|
||||
More: 0,
|
||||
More: false,
|
||||
}, nil),
|
||||
)
|
||||
m.newStoreNoEvents(true)
|
||||
@ -188,7 +188,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
|
||||
msg := &pmapi.Message{
|
||||
ID: "msg1",
|
||||
Subject: "old",
|
||||
Unread: 0,
|
||||
Unread: false,
|
||||
Flags: 10,
|
||||
Sender: address1,
|
||||
ToList: []*mail.Address{address2},
|
||||
@ -200,7 +200,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
|
||||
newMsg := &pmapi.Message{
|
||||
ID: "msg1",
|
||||
Subject: "new",
|
||||
Unread: 1,
|
||||
Unread: true,
|
||||
Flags: 11,
|
||||
Sender: address2,
|
||||
ToList: []*mail.Address{address1},
|
||||
|
||||
@ -129,17 +129,10 @@ func (mc *mailboxCounts) getPMLabel() *pmapi.Label {
|
||||
Color: mc.Color,
|
||||
Order: mc.Order,
|
||||
Type: pmapi.LabelTypeMailbox,
|
||||
Exclusive: mc.isExclusive(),
|
||||
Exclusive: pmapi.Boolean(mc.IsFolder),
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mailboxCounts) isExclusive() int {
|
||||
if mc.IsFolder {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
|
||||
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
|
||||
// Don't forget about system folders.
|
||||
@ -162,7 +155,7 @@ func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) er
|
||||
mailbox.LabelName = label.Path
|
||||
mailbox.Color = label.Color
|
||||
mailbox.Order = label.Order
|
||||
mailbox.IsFolder = label.Exclusive == 1
|
||||
mailbox.IsFolder = bool(label.Exclusive)
|
||||
|
||||
// Write.
|
||||
if err = mailbox.txWriteToBucket(countsBkt); err != nil {
|
||||
|
||||
@ -75,7 +75,7 @@ func TestMailboxNames(t *testing.T) {
|
||||
newLabel(100, "labelID1", "Label1"),
|
||||
newLabel(1000, "folderID1", "Folder1"),
|
||||
}
|
||||
foldersAndLabels[1].Exclusive = 1
|
||||
foldersAndLabels[1].Exclusive = true
|
||||
|
||||
for _, counts := range getSystemFolders() {
|
||||
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())
|
||||
|
||||
@ -37,10 +37,10 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
|
||||
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
||||
|
||||
@ -82,20 +82,20 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg.ExternalID = " externalID-non-pm-com "
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg.ExternalID = "<externalID@pm.me>"
|
||||
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
// Not sure if this is a real-world scenario but we should be able to address this properly.
|
||||
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
|
||||
@ -18,8 +18,6 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -43,7 +41,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
|
||||
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
|
||||
// wrapping it.
|
||||
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
|
||||
msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID)
|
||||
msg, err := storeMailbox.client().GetMessage(exposeContextForIMAP(), apiID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -70,7 +68,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
|
||||
Message: body,
|
||||
}
|
||||
|
||||
res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs})
|
||||
res, err := storeMailbox.client().Import(exposeContextForIMAP(), pmapi.ImportMsgReqs{importReqs})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -99,7 +97,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
|
||||
return ErrAllMailOpNotAllowed
|
||||
}
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// UnlabelMessages removes the label by calling an API.
|
||||
@ -112,7 +110,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
|
||||
return ErrAllMailOpNotAllowed
|
||||
}
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// MarkMessagesRead marks the message read by calling an API.
|
||||
@ -132,14 +130,14 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
|
||||
// Therefore we do not issue API update if the message is already read.
|
||||
ids := []string{}
|
||||
for _, apiID := range apiIDs {
|
||||
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 {
|
||||
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread {
|
||||
ids = append(ids, apiID)
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return storeMailbox.client().MarkMessagesRead(context.TODO(), ids)
|
||||
return storeMailbox.client().MarkMessagesRead(exposeContextForIMAP(), ids)
|
||||
}
|
||||
|
||||
// MarkMessagesUnread marks the message unread by calling an API.
|
||||
@ -151,7 +149,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unread")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs)
|
||||
return storeMailbox.client().MarkMessagesUnread(exposeContextForIMAP(), apiIDs)
|
||||
}
|
||||
|
||||
// MarkMessagesStarred adds the Starred label by calling an API.
|
||||
@ -164,7 +162,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as starred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesUnstarred removes the Starred label by calling an API.
|
||||
@ -177,7 +175,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unstarred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
|
||||
@ -261,11 +259,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
|
||||
}
|
||||
case pmapi.DraftLabel:
|
||||
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
|
||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), apiIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -303,13 +301,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
|
||||
}
|
||||
}
|
||||
if len(messageIDsToUnlabel) > 0 {
|
||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
l.WithError(err).Warning("Cannot unlabel before deleting")
|
||||
}
|
||||
}
|
||||
if len(messageIDsToDelete) > 0 {
|
||||
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
|
||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), messageIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,10 +5,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
context "context"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockPanicHandler is a mock of PanicHandler interface
|
||||
@ -207,17 +207,17 @@ func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call {
|
||||
}
|
||||
|
||||
// UpdateUser mocks base method
|
||||
func (m *MockBridgeUser) UpdateUser() error {
|
||||
func (m *MockBridgeUser) UpdateUser(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUser")
|
||||
ret := m.ctrl.Call(m, "UpdateUser", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateUser indicates an expected call of UpdateUser
|
||||
func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call {
|
||||
func (mr *MockBridgeUserMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser), arg0)
|
||||
}
|
||||
|
||||
// MockChangeNotifier is a mock of ChangeNotifier interface
|
||||
|
||||
@ -5,10 +5,9 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockListener is a mock of Listener interface
|
||||
@ -58,6 +57,20 @@ func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
|
||||
}
|
||||
|
||||
// ProvideChannel mocks base method
|
||||
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
|
||||
ret0, _ := ret[0].(<-chan string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ProvideChannel indicates an expected call of ProvideChannel
|
||||
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -101,6 +101,18 @@ var (
|
||||
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
// exposeContextForIMAP should be replaced once with context passed
|
||||
// as an argument from IMAP package and IMAP library should cancel
|
||||
// context when IMAP client cancels the request.
|
||||
func exposeContextForIMAP() context.Context {
|
||||
return context.TODO()
|
||||
}
|
||||
|
||||
// exposeContextForSMTP is the same as above but for SMTP.
|
||||
func exposeContextForSMTP() context.Context {
|
||||
return context.TODO()
|
||||
}
|
||||
|
||||
// Store is local user storage, which handles the synchronization between IMAP and PM API.
|
||||
type Store struct {
|
||||
sentryReporter *sentry.Reporter
|
||||
@ -278,7 +290,7 @@ func (store *Store) client() pmapi.Client {
|
||||
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
|
||||
// the API is unavailable for whatever reason it tries to fetch the labels locally.
|
||||
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
|
||||
if labels, err = store.client().ListLabels(context.TODO()); err != nil {
|
||||
if labels, err = store.client().ListLabels(context.Background()); err != nil {
|
||||
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
|
||||
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
|
||||
store.log.WithError(err).Error("Cannot list local labels")
|
||||
|
||||
@ -184,8 +184,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
|
||||
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
|
||||
|
||||
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
|
||||
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
|
||||
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
|
||||
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
|
||||
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
|
||||
})
|
||||
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
|
||||
mocks.client.EXPECT().CountMessages(gomock.Any(), "")
|
||||
|
||||
@ -148,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
|
||||
Limit: 1,
|
||||
}
|
||||
// If the page does not exist, an empty page instead of an error is returned.
|
||||
messages, total, err := api.ListMessages(context.TODO(), filter)
|
||||
messages, total, err := api.ListMessages(context.Background(), filter)
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
@ -190,7 +190,7 @@ func syncBatch( //nolint[funlen]
|
||||
|
||||
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
|
||||
|
||||
messages, _, err := api.ListMessages(context.TODO(), filter)
|
||||
messages, _, err := api.ListMessages(context.Background(), filter)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
|
||||
@ -17,7 +17,11 @@
|
||||
|
||||
package store
|
||||
|
||||
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
type PanicHandler interface {
|
||||
HandlePanic()
|
||||
@ -32,7 +36,7 @@ type BridgeUser interface {
|
||||
GetPrimaryAddress() string
|
||||
GetStoreAddresses() []string
|
||||
GetClient() pmapi.Client
|
||||
UpdateUser() error
|
||||
UpdateUser(context.Context) error
|
||||
CloseAllConnections()
|
||||
CloseConnection(string)
|
||||
Logout() error
|
||||
|
||||
@ -17,8 +17,6 @@
|
||||
|
||||
package store
|
||||
|
||||
import "context"
|
||||
|
||||
// UserID returns user ID.
|
||||
func (store *Store) UserID() string {
|
||||
return store.user.ID()
|
||||
@ -26,7 +24,7 @@ func (store *Store) UserID() string {
|
||||
|
||||
// GetSpace returns used and total space in bytes.
|
||||
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
||||
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
@ -35,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
|
||||
// GetMaxUpload returns max size of message + all attachments in bytes.
|
||||
func (store *Store) GetMaxUpload() (int64, error) {
|
||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
||||
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -147,7 +147,7 @@ func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (er
|
||||
// filterAddresses filters out inactive addresses and ensures the original address is listed first.
|
||||
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
|
||||
for _, address := range addressList {
|
||||
if address.Receive != pmapi.CanReceive {
|
||||
if !address.Receive {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -39,14 +38,14 @@ func (store *Store) createMailbox(name string) error {
|
||||
|
||||
color := store.leastUsedColor()
|
||||
|
||||
var exclusive int
|
||||
var exclusive bool
|
||||
switch {
|
||||
case strings.HasPrefix(name, UserLabelsPrefix):
|
||||
name = strings.TrimPrefix(name, UserLabelsPrefix)
|
||||
exclusive = 0
|
||||
exclusive = false
|
||||
case strings.HasPrefix(name, UserFoldersPrefix):
|
||||
name = strings.TrimPrefix(name, UserFoldersPrefix)
|
||||
exclusive = 1
|
||||
exclusive = true
|
||||
default:
|
||||
// Ideally we would throw an error here, but then Outlook for
|
||||
// macOS keeps trying to make an IMAP Drafts folder and popping
|
||||
@ -56,10 +55,10 @@ func (store *Store) createMailbox(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{
|
||||
_, err := store.client().CreateLabel(exposeContextForIMAP(), &pmapi.Label{
|
||||
Name: name,
|
||||
Color: color,
|
||||
Exclusive: exclusive,
|
||||
Exclusive: pmapi.Boolean(exclusive),
|
||||
Type: pmapi.LabelTypeMailbox,
|
||||
})
|
||||
return err
|
||||
@ -126,7 +125,7 @@ func (store *Store) leastUsedColor() string {
|
||||
func (store *Store) updateMailbox(labelID, newName, color string) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
_, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{
|
||||
_, err := store.client().UpdateLabel(exposeContextForIMAP(), &pmapi.Label{
|
||||
ID: labelID,
|
||||
Name: newName,
|
||||
Color: color,
|
||||
@ -143,15 +142,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
|
||||
var err error
|
||||
switch labelID {
|
||||
case pmapi.SpamLabel:
|
||||
err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID)
|
||||
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.SpamLabel, addressID)
|
||||
case pmapi.TrashLabel:
|
||||
err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID)
|
||||
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.TrashLabel, addressID)
|
||||
default:
|
||||
err = fmt.Errorf("cannot empty mailbox %v", labelID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return store.client().DeleteLabel(context.TODO(), labelID)
|
||||
return store.client().DeleteLabel(exposeContextForIMAP(), labelID)
|
||||
}
|
||||
|
||||
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
|
||||
@ -166,7 +165,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
labels, err := store.client().ListLabels(context.TODO())
|
||||
labels, err := store.client().ListLabels(exposeContextForIMAP())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -19,7 +19,6 @@ package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -58,7 +57,7 @@ func (store *Store) CreateDraft(
|
||||
}
|
||||
|
||||
draftAction := store.getDraftAction(message)
|
||||
draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction)
|
||||
draft, err := store.client().CreateDraft(exposeContextForSMTP(), message, parentID, draftAction)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create draft")
|
||||
}
|
||||
@ -70,7 +69,7 @@ func (store *Store) CreateDraft(
|
||||
for _, att := range attachments {
|
||||
att.attachment.MessageID = draft.ID
|
||||
|
||||
createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader)
|
||||
createdAttachment, err := store.client().CreateAttachment(exposeContextForSMTP(), att.attachment, att.encReader, att.sigReader)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create attachment")
|
||||
}
|
||||
@ -184,7 +183,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
|
||||
// SendMessage sends the message.
|
||||
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
_, _, err := store.client().SendMessage(context.TODO(), messageID, req)
|
||||
_, _, err := store.client().SendMessage(exposeContextForSMTP(), messageID, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/golang/mock/gomock"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -34,10 +35,10 @@ func TestGetAllMessageIDs(t *testing.T) {
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{})
|
||||
|
||||
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
||||
}
|
||||
@ -47,7 +48,7 @@ func TestGetMessageFromDB(t *testing.T) {
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
|
||||
tests := []struct{ msgID, wantErr string }{
|
||||
{"msg1", ""},
|
||||
@ -72,7 +73,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
|
||||
msg, err := m.store.getMessageFromDB("msg1")
|
||||
require.Nil(t, err)
|
||||
@ -104,7 +105,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
|
||||
a.Equal(t, wantHeader, msg.Header)
|
||||
|
||||
// Check calculated data are not overridden by reinsert.
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
|
||||
msg, err = m.store.getMessageFromDB("msg1")
|
||||
require.Nil(t, err)
|
||||
@ -118,8 +119,8 @@ func TestDeleteMessage(t *testing.T) {
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
|
||||
|
||||
require.Nil(t, m.store.deleteMessageEvent("msg1"))
|
||||
|
||||
@ -127,17 +128,17 @@ func TestDeleteMessage(t *testing.T) {
|
||||
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
|
||||
}
|
||||
|
||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam]
|
||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
|
||||
msg := getTestMessage(id, subject, sender, unread, labelIDs)
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
|
||||
}
|
||||
|
||||
func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message {
|
||||
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
|
||||
address := &mail.Address{Address: sender}
|
||||
return &pmapi.Message{
|
||||
ID: id,
|
||||
Subject: subject,
|
||||
Unread: unread,
|
||||
Unread: pmapi.Boolean(unread),
|
||||
Sender: address,
|
||||
ToList: []*mail.Address{address},
|
||||
LabelIDs: labelIDs,
|
||||
@ -162,7 +163,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(false)
|
||||
m.client.EXPECT().CurrentUser().Return(&pmapi.User{
|
||||
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
|
||||
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
|
||||
}, nil)
|
||||
|
||||
@ -181,7 +182,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(false)
|
||||
m.client.EXPECT().CurrentUser().Return(&pmapi.User{
|
||||
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
|
||||
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
|
||||
}, nil)
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
|
||||
|
||||
// updateCountsFromServer will download and set the counts.
|
||||
func (store *Store) updateCountsFromServer() error {
|
||||
counts, err := store.client().CountMessages(context.TODO(), "")
|
||||
counts, err := store.client().CountMessages(context.Background(), "")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot update counts from server")
|
||||
}
|
||||
|
||||
@ -31,8 +31,8 @@ func TestLoadSaveSyncState(t *testing.T) {
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
|
||||
// Clear everything.
|
||||
|
||||
|
||||
@ -5,11 +5,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
imap "github.com/emersion/go-imap"
|
||||
sasl "github.com/emersion/go-sasl"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockPanicHandler is a mock of PanicHandler interface
|
||||
|
||||
@ -19,7 +19,9 @@ package transfer
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -37,6 +39,8 @@ const (
|
||||
imapRetries = 10
|
||||
imapReconnectTimeout = 30 * time.Minute
|
||||
imapReconnectSleep = time.Minute
|
||||
|
||||
protonStatusURL = "http://protonstatus.com/vpn_status"
|
||||
)
|
||||
|
||||
type imapErrorLogger struct {
|
||||
@ -117,19 +121,15 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
|
||||
return previousErr
|
||||
}
|
||||
|
||||
// FIXME(conman): This should register as connection observer.
|
||||
|
||||
/*
|
||||
err := pmapi.CheckConnection()
|
||||
err := checkConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
err := p.reauth()
|
||||
err = p.reauth()
|
||||
log.WithError(err).Debug("Reauth")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
@ -289,3 +289,23 @@ func (p *IMAPProvider) fetchHelper(uid bool, ensureSelectedIn string, seqSet *im
|
||||
return err
|
||||
}, ensureSelectedIn)
|
||||
}
|
||||
|
||||
// checkConnection returns an error if there is no internet connection.
|
||||
// Note we don't want to use client manager because it only reports connection
|
||||
// issues with API; we are only interested here whether we can reach
|
||||
// third-party IMAP servers.
|
||||
func checkConnection() error {
|
||||
client := &http.Client{Timeout: time.Second * 10}
|
||||
|
||||
resp, err := client.Get(protonStatusURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("HTTP status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -52,15 +52,10 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
|
||||
return Mailbox{}, errors.New("mailbox is already created")
|
||||
}
|
||||
|
||||
exclusive := 0
|
||||
if mailbox.IsExclusive {
|
||||
exclusive = 1
|
||||
}
|
||||
|
||||
label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{
|
||||
label, err := p.client.CreateLabel(context.Background(), &pmapi.Label{
|
||||
Name: mailbox.Name,
|
||||
Color: mailbox.Color,
|
||||
Exclusive: exclusive,
|
||||
Exclusive: pmapi.Boolean(mailbox.IsExclusive),
|
||||
Type: pmapi.LabelTypeMailbox,
|
||||
})
|
||||
if err != nil {
|
||||
@ -126,7 +121,7 @@ func (p *PMAPIProvider) importDraft(msg Message, globalMailbox *Mailbox) (string
|
||||
}
|
||||
|
||||
if message.Sender == nil {
|
||||
mainAddress := p.client().Addresses().Main()
|
||||
mainAddress := p.client.Addresses().Main()
|
||||
message.Sender = &mail.Address{
|
||||
Name: mainAddress.DisplayName,
|
||||
Address: mainAddress.Email,
|
||||
@ -227,14 +222,6 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
||||
}
|
||||
}
|
||||
|
||||
var unread pmapi.Boolean
|
||||
|
||||
if msg.Unread {
|
||||
unread = pmapi.True
|
||||
} else {
|
||||
unread = pmapi.False
|
||||
}
|
||||
|
||||
labelIDs := []string{}
|
||||
for _, target := range msg.Targets {
|
||||
// Frontend should not set All Mail to Rules, but to be sure...
|
||||
@ -249,7 +236,7 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
||||
return &pmapi.ImportMsgReq{
|
||||
Metadata: &pmapi.ImportMetadata{
|
||||
AddressID: p.addressID,
|
||||
Unread: unread,
|
||||
Unread: pmapi.Boolean(msg.Unread),
|
||||
Time: message.Time,
|
||||
Flags: computeMessageFlags(message.Header),
|
||||
LabelIDs: labelIDs,
|
||||
|
||||
@ -153,10 +153,10 @@ func setupPMAPIRules(rules transferRules) {
|
||||
func setupPMAPIClientExpectationForExport(m *mocks) {
|
||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
|
||||
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2},
|
||||
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1},
|
||||
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1},
|
||||
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2},
|
||||
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: false, Order: 2},
|
||||
{ID: "label2", Name: "Bar", Color: "green", Exclusive: false, Order: 1},
|
||||
{ID: "folder1", Name: "One", Color: "red", Exclusive: true, Order: 1},
|
||||
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: true, Order: 2},
|
||||
}, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
|
||||
{LabelID: "label1", Total: 10},
|
||||
|
||||
@ -30,9 +30,17 @@ import (
|
||||
const (
|
||||
pmapiRetries = 10
|
||||
pmapiReconnectTimeout = 30 * time.Minute
|
||||
pmapiReconnectSleep = time.Minute
|
||||
pmapiReconnectSleep = 10 * time.Second
|
||||
)
|
||||
|
||||
func (p *PMAPIProvider) SetConnectionUp() {
|
||||
p.connection = true
|
||||
}
|
||||
|
||||
func (p *PMAPIProvider) SetConnectionDown() {
|
||||
p.connection = false
|
||||
}
|
||||
|
||||
func (p *PMAPIProvider) ensureConnection(callback func() error) error {
|
||||
var callErr error
|
||||
for i := 1; i <= pmapiRetries; i++ {
|
||||
@ -58,18 +66,10 @@ func (p *PMAPIProvider) tryReconnect() error {
|
||||
return previousErr
|
||||
}
|
||||
|
||||
// FIXME(conman): This should register as a connection observer somehow...
|
||||
// Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection?
|
||||
|
||||
/*
|
||||
err := p.clientManager.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
if !p.connection {
|
||||
time.Sleep(pmapiReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
break
|
||||
}
|
||||
@ -83,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
|
||||
p.timeIt.start("listing", key)
|
||||
defer p.timeIt.stop("listing", key)
|
||||
|
||||
messages, count, err = p.client.ListMessages(context.TODO(), filter)
|
||||
messages, count, err = p.client.ListMessages(context.Background(), filter)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -94,7 +94,7 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
|
||||
p.timeIt.start("download", msgID)
|
||||
defer p.timeIt.stop("download", msgID)
|
||||
|
||||
message, err = p.client.GetMessage(context.TODO(), msgID)
|
||||
message, err = p.client.GetMessage(context.Background(), msgID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -105,7 +105,7 @@ func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReq
|
||||
p.timeIt.start("upload", msgSourceID)
|
||||
defer p.timeIt.stop("upload", msgSourceID)
|
||||
|
||||
res, err = p.client.Import(context.TODO(), req)
|
||||
res, err = p.client.Import(context.Background(), req)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -116,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
|
||||
p.timeIt.start("upload", msgSourceID)
|
||||
defer p.timeIt.stop("upload", msgSourceID)
|
||||
|
||||
draft, err = p.client.CreateDraft(context.TODO(), message, parent, action)
|
||||
draft, err = p.client.CreateDraft(context.Background(), message, parent, action)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -129,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
|
||||
p.timeIt.start("upload", key)
|
||||
defer p.timeIt.stop("upload", key)
|
||||
|
||||
created, err = p.client.CreateAttachment(context.TODO(), att, r, sig)
|
||||
created, err = p.client.CreateAttachment(context.Background(), att, r, sig)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
||||
@ -28,6 +28,7 @@ import (
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -274,7 +275,7 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater {
|
||||
func newTestUpdater(manager pmapi.Manager, curVer string, earlyAccess bool) *Updater {
|
||||
return New(
|
||||
manager,
|
||||
&fakeInstaller{},
|
||||
|
||||
@ -1,251 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
|
||||
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
|
||||
)
|
||||
|
||||
assert.NoError(t, user.UpdateUser())
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
func TestUserSwitchAddressMode(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
assert.True(t, user.store.IsCombinedMode())
|
||||
assert.True(t, user.creds.IsCombinedAddressMode)
|
||||
assert.True(t, user.IsCombinedAddressMode())
|
||||
waitForEvents()
|
||||
|
||||
gomock.InOrder(
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil),
|
||||
)
|
||||
|
||||
assert.NoError(t, user.SwitchAddressMode())
|
||||
assert.False(t, user.store.IsCombinedMode())
|
||||
assert.False(t, user.creds.IsCombinedAddressMode)
|
||||
assert.False(t, user.IsCombinedAddressMode())
|
||||
|
||||
gomock.InOrder(
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
|
||||
assert.NoError(t, user.SwitchAddressMode())
|
||||
assert.True(t, user.store.IsCombinedMode())
|
||||
assert.True(t, user.creds.IsCombinedAddressMode)
|
||||
assert.True(t, user.IsCombinedAddressMode())
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
func TestLogoutUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := user.Logout()
|
||||
|
||||
waitForEvents()
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLogoutUserFailsLogout(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := user.Logout()
|
||||
waitForEvents()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginOK(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
)
|
||||
|
||||
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginTwiceOK(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(true),
|
||||
)
|
||||
|
||||
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
|
||||
waitForEvents()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = user.CheckBridgeLogin(testCredentials.BridgePassword)
|
||||
waitForEvents()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
|
||||
|
||||
isApplicationOutdated = true
|
||||
|
||||
err := user.CheckBridgeLogin("any-pass")
|
||||
waitForEvents()
|
||||
assert.Equal(t, pmapi.ErrUpgradeApplication, err)
|
||||
|
||||
isApplicationOutdated = false
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
err = user.init()
|
||||
assert.Error(t, err)
|
||||
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
|
||||
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
|
||||
waitForEvents()
|
||||
assert.Equal(t, ErrLoggedOutUser, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginBadPassword(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
)
|
||||
|
||||
err := user.CheckBridgeLogin("wrong!")
|
||||
waitForEvents()
|
||||
assert.Equal(t, "backend/credentials: incorrect password", err.Error())
|
||||
}
|
||||
@ -1,112 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewUserNoCredentialsStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
|
||||
|
||||
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
||||
a.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewUserAuthRefreshFails(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(testCredentialsDisconnected, m)
|
||||
}
|
||||
|
||||
func TestNewUserUnlockFails(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(errors.New("bad password")),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(testCredentialsDisconnected, m)
|
||||
}
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
mockConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
checkNewUserHasCredentials(testCredentials, m)
|
||||
}
|
||||
|
||||
func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) {
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
_ = user.init()
|
||||
|
||||
waitForEvents()
|
||||
|
||||
a.Equal(m.t, creds, user.creds)
|
||||
}
|
||||
|
||||
func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen]
|
||||
a.Fail(t, "not implemented")
|
||||
}
|
||||
@ -1,89 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testNewUser sets up a new, authorised user.
|
||||
func testNewUser(m mocks) *User {
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
mockConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
err = user.init()
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func testNewUserForLogout(m mocks) *User {
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
mockConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
err = user.init()
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func cleanUpUserData(u *User) {
|
||||
_ = u.clearStore()
|
||||
}
|
||||
|
||||
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
|
||||
assert.Fail(t, "not implemented")
|
||||
}
|
||||
|
||||
func TestClearStoreWithStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
require.Nil(t, user.store.Close())
|
||||
user.store = nil
|
||||
assert.Nil(t, user.clearStore())
|
||||
}
|
||||
|
||||
func TestClearStoreWithoutStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
assert.NotNil(t, user.store)
|
||||
assert.Nil(t, user.clearStore())
|
||||
}
|
||||
@ -1,143 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetNoUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
|
||||
}
|
||||
|
||||
func TestGetUserByID(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "user", 0, "")
|
||||
checkUsersGetUser(t, m, "users", 1, "")
|
||||
}
|
||||
|
||||
func TestGetUserByName(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "username", 0, "")
|
||||
checkUsersGetUser(t, m, "usersname", 1, "")
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "user@pm.me", 0, "")
|
||||
checkUsersGetUser(t, m, "users@pm.me", 1, "")
|
||||
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
|
||||
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := users.DeleteUser("user", true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(users.users))
|
||||
}
|
||||
|
||||
// Even when logout fails, delete is done.
|
||||
func TestDeleteUserWithFailingLogout(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := users.DeleteUser("user", true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(users.users))
|
||||
}
|
||||
|
||||
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
user, err := users.GetUser(query)
|
||||
waitForEvents()
|
||||
|
||||
if expectedError != "" {
|
||||
assert.Equal(m.t, expectedError, err.Error())
|
||||
} else {
|
||||
assert.NoError(m.t, err)
|
||||
}
|
||||
|
||||
var expectedUser *User
|
||||
if index >= 0 {
|
||||
expectedUser = users.users[index]
|
||||
}
|
||||
|
||||
assert.Equal(m.t, expectedUser, user)
|
||||
}
|
||||
@ -1,219 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
gomock.InOrder(
|
||||
// Init users with no user from keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
||||
|
||||
// Set up mocks for FinishLogin.
|
||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked")),
|
||||
m.pmapiClient.EXPECT().DeleteAuth(),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
|
||||
}
|
||||
|
||||
func refreshWithToken(token string) *pmapi.Auth {
|
||||
return &pmapi.Auth{
|
||||
RefreshToken: token,
|
||||
}
|
||||
}
|
||||
|
||||
func credentialsWithToken(token string) *credentials.Credentials {
|
||||
tmp := &credentials.Credentials{}
|
||||
*tmp = *testCredentials
|
||||
tmp.APIToken = token
|
||||
return tmp
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginNewUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Basically every call client has get client manager
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
// users.New() finds no users in keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
||||
|
||||
// getAPIUser() loads user info from API (e.g. userID).
|
||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
|
||||
// addNewUser()
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
||||
|
||||
// user.init() in addNewUser
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Emit event for new user and send metrics.
|
||||
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
|
||||
// Reload account list in GUI.
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
||||
|
||||
// defer logout anonymous
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
|
||||
|
||||
mockAuthUpdate(user, "afterCredentials", m)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
loggedOutCreds := *testCredentials
|
||||
loggedOutCreds.APIToken = ""
|
||||
loggedOutCreds.MailboxPassword = ""
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
// users.New() finds one existing user in keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
||||
|
||||
// newUser()
|
||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
||||
|
||||
// user.init()
|
||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
|
||||
// getAPIUser() loads user info from API (e.g. userID).
|
||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
|
||||
// connectExistingUser()
|
||||
m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil),
|
||||
|
||||
// user.init() in connectExistingUser
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Reload account list in GUI.
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
||||
|
||||
// defer logout anonymous
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
|
||||
|
||||
mockAuthUpdate(user, "afterCredentials", m)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginConnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
|
||||
mockConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
// Then, try to log in again...
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
m.pmapiClient.EXPECT().DeleteAuth(),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
_, err := users.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword)
|
||||
assert.Equal(t, "user is already connected", err.Error())
|
||||
}
|
||||
|
||||
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User {
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
assert.Equal(t, expectedErr, err)
|
||||
|
||||
if expectedUserID != "" {
|
||||
assert.Equal(t, expectedUserID, user.ID())
|
||||
assert.Equal(t, 1, len(users.users))
|
||||
assert.Equal(t, expectedUserID, users.users[0].ID())
|
||||
} else {
|
||||
assert.Equal(t, (*User)(nil), user)
|
||||
assert.Equal(t, 0, len(users.users))
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
@ -233,7 +233,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
|
||||
|
||||
_, secret, err := s.secrets.Get(userID)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not get credentials from native keychain")
|
||||
log.WithError(err).Warn("Could not get credentials from native keychain")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -26,8 +26,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testSep = "\n"
|
||||
@ -249,26 +248,26 @@ func TestMarshalFormats(t *testing.T) {
|
||||
log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt))
|
||||
|
||||
output := testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.UnmarshalStrings(secretStrings))
|
||||
r.NoError(t, output.UnmarshalStrings(secretStrings))
|
||||
log.Infof("strings out %#v \n", output)
|
||||
require.True(t, input.IsSame(&output), "strings out not same")
|
||||
r.True(t, input.IsSame(&output), "strings out not same")
|
||||
|
||||
output = testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.UnmarshalGob(secretGob))
|
||||
r.NoError(t, output.UnmarshalGob(secretGob))
|
||||
log.Infof("gob out %#v\n \n", output)
|
||||
assert.Equal(t, input, output)
|
||||
r.Equal(t, input, output)
|
||||
|
||||
output = testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.FromJSON(secretJSON))
|
||||
r.NoError(t, output.FromJSON(secretJSON))
|
||||
log.Infof("json out %#v \n", output)
|
||||
require.True(t, input.IsSame(&output), "json out not same")
|
||||
r.True(t, input.IsSame(&output), "json out not same")
|
||||
|
||||
/*
|
||||
// Simple Fscanf not working!
|
||||
output = testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.UnmarshalFmt(secretFmt))
|
||||
r.NoError(t, output.UnmarshalFmt(secretFmt))
|
||||
log.Infof("fmt out %#v \n", output)
|
||||
require.True(t, input.IsSame(&output), "fmt out not same")
|
||||
r.True(t, input.IsSame(&output), "fmt out not same")
|
||||
*/
|
||||
}
|
||||
|
||||
@ -291,7 +290,7 @@ func TestMarshal(t *testing.T) {
|
||||
log.Infof("secret %#v %d\n", secret, len(secret))
|
||||
|
||||
output := Credentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.Unmarshal(secret))
|
||||
r.NoError(t, output.Unmarshal(secret))
|
||||
log.Infof("output %#v\n", output)
|
||||
assert.Equal(t, input, output)
|
||||
r.Equal(t, input, output)
|
||||
}
|
||||
|
||||
@ -1,107 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./listener/listener.go
|
||||
|
||||
// Package users is a generated GoMock package.
|
||||
package users
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockListener is a mock of Listener interface
|
||||
type MockListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockListenerMockRecorder
|
||||
}
|
||||
|
||||
// MockListenerMockRecorder is the mock recorder for MockListener
|
||||
type MockListenerMockRecorder struct {
|
||||
mock *MockListener
|
||||
}
|
||||
|
||||
// NewMockListener creates a new mock instance
|
||||
func NewMockListener(ctrl *gomock.Controller) *MockListener {
|
||||
mock := &MockListener{ctrl: ctrl}
|
||||
mock.recorder = &MockListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// SetLimit mocks base method
|
||||
func (m *MockListener) SetLimit(eventName string, limit time.Duration) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetLimit", eventName, limit)
|
||||
}
|
||||
|
||||
// SetLimit indicates an expected call of SetLimit
|
||||
func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit)
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockListener) Add(eventName string, channel chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Add", eventName, channel)
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockListener) Remove(eventName string, channel chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Remove", eventName, channel)
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel)
|
||||
}
|
||||
|
||||
// Emit mocks base method
|
||||
func (m *MockListener) Emit(eventName, data string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Emit", eventName, data)
|
||||
}
|
||||
|
||||
// Emit indicates an expected call of Emit
|
||||
func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data)
|
||||
}
|
||||
|
||||
// SetBuffer mocks base method
|
||||
func (m *MockListener) SetBuffer(eventName string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetBuffer", eventName)
|
||||
}
|
||||
|
||||
// SetBuffer indicates an expected call of SetBuffer
|
||||
func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName)
|
||||
}
|
||||
|
||||
// RetryEmit mocks base method
|
||||
func (m *MockListener) RetryEmit(eventName string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RetryEmit", eventName)
|
||||
}
|
||||
|
||||
// RetryEmit indicates an expected call of RetryEmit
|
||||
func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName)
|
||||
}
|
||||
120
internal/users/mocks/listener_mocks.go
Normal file
120
internal/users/mocks/listener_mocks.go
Normal file
@ -0,0 +1,120 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
)
|
||||
|
||||
// MockListener is a mock of Listener interface
|
||||
type MockListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockListenerMockRecorder
|
||||
}
|
||||
|
||||
// MockListenerMockRecorder is the mock recorder for MockListener
|
||||
type MockListenerMockRecorder struct {
|
||||
mock *MockListener
|
||||
}
|
||||
|
||||
// NewMockListener creates a new mock instance
|
||||
func NewMockListener(ctrl *gomock.Controller) *MockListener {
|
||||
mock := &MockListener{ctrl: ctrl}
|
||||
mock.recorder = &MockListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockListener) Add(arg0 string, arg1 chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Add", arg0, arg1)
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1)
|
||||
}
|
||||
|
||||
// Emit mocks base method
|
||||
func (m *MockListener) Emit(arg0, arg1 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Emit", arg0, arg1)
|
||||
}
|
||||
|
||||
// Emit indicates an expected call of Emit
|
||||
func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
|
||||
}
|
||||
|
||||
// ProvideChannel mocks base method
|
||||
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
|
||||
ret0, _ := ret[0].(<-chan string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ProvideChannel indicates an expected call of ProvideChannel
|
||||
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Remove", arg0, arg1)
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1)
|
||||
}
|
||||
|
||||
// RetryEmit mocks base method
|
||||
func (m *MockListener) RetryEmit(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RetryEmit", arg0)
|
||||
}
|
||||
|
||||
// RetryEmit indicates an expected call of RetryEmit
|
||||
func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0)
|
||||
}
|
||||
|
||||
// SetBuffer mocks base method
|
||||
func (m *MockListener) SetBuffer(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetBuffer", arg0)
|
||||
}
|
||||
|
||||
// SetBuffer indicates an expected call of SetBuffer
|
||||
func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0)
|
||||
}
|
||||
|
||||
// SetLimit mocks base method
|
||||
func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetLimit", arg0, arg1)
|
||||
}
|
||||
|
||||
// SetLimit indicates an expected call of SetLimit
|
||||
func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1)
|
||||
}
|
||||
@ -5,11 +5,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
store "github.com/ProtonMail/proton-bridge/internal/store"
|
||||
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockLocator is a mock of Locator interface
|
||||
|
||||
@ -89,29 +89,25 @@ func newUser(
|
||||
// - providing it with an authorised API client
|
||||
// - loading its credentials from the credentials store
|
||||
// - loading and unlocking its PGP keys
|
||||
// - loading its store
|
||||
func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error {
|
||||
// - loading its store.
|
||||
func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) error {
|
||||
u.log.Info("Connecting user")
|
||||
|
||||
// Connected users have an API client.
|
||||
u.client = client
|
||||
|
||||
// FIXME(conman): How to remove this auth handler when user is disconnected?
|
||||
u.client.AddAuthHandler(u.handleAuth)
|
||||
u.client.AddAuthRefreshHandler(u.handleAuthRefresh)
|
||||
|
||||
// Save the latest credentials for the user.
|
||||
u.creds = creds
|
||||
|
||||
// Connected users have unlocked keys.
|
||||
// FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :(
|
||||
if u.creds.IsConnected() {
|
||||
if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil {
|
||||
if err := u.unlockIfNecessary(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Connected users have a store.
|
||||
if err := u.loadStore(); err != nil {
|
||||
if err := u.loadStore(); err != nil { //nolint[revive] easier to read
|
||||
return err
|
||||
}
|
||||
|
||||
@ -138,17 +134,25 @@ func (u *User) loadStore() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) handleAuth(auth *pmapi.Auth) error {
|
||||
u.log.Debug("User received auth")
|
||||
func (u *User) handleAuthRefresh(auth *pmapi.AuthRefresh) {
|
||||
u.log.Debug("User received auth refresh update")
|
||||
|
||||
if auth == nil {
|
||||
if err := u.logout(); err != nil {
|
||||
log.WithError(err).
|
||||
WithField("userID", u.userID).
|
||||
Error("User logout failed while watching API auths")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to update refresh token in credentials store")
|
||||
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||
return
|
||||
}
|
||||
|
||||
u.creds = creds
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearStore removes the database.
|
||||
@ -181,13 +185,6 @@ func (u *User) closeStore() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTemporaryPMAPIClient returns an authorised PMAPI client.
|
||||
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
|
||||
// After proper refactor of SMTP and IMAP remove this method.
|
||||
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
|
||||
return u.client
|
||||
}
|
||||
|
||||
// ID returns the user's userID.
|
||||
func (u *User) ID() string {
|
||||
return u.userID
|
||||
@ -210,9 +207,43 @@ func (u *User) IsConnected() bool {
|
||||
}
|
||||
|
||||
func (u *User) GetClient() pmapi.Client {
|
||||
if err := u.unlockIfNecessary(); err != nil {
|
||||
u.log.WithError(err).Error("Failed to unlock user")
|
||||
}
|
||||
return u.client
|
||||
}
|
||||
|
||||
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
|
||||
func (u *User) unlockIfNecessary() error {
|
||||
if !u.creds.IsConnected() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.client.IsUnlocked() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// unlockIfNecessary is called with every access to underlying pmapi
|
||||
// client. Unlock should only finish unlocking when connection is back up.
|
||||
// That means it should try it fast enough and not retry if connection
|
||||
// is still down.
|
||||
err := u.client.Unlock(pmapi.ContextWithoutRetry(context.Background()), []byte(u.creds.MailboxPassword))
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch errors.Cause(err) {
|
||||
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
|
||||
u.log.WithError(err).Warn("Could not unlock user")
|
||||
return nil
|
||||
}
|
||||
|
||||
if logoutErr := u.logout(); logoutErr != nil {
|
||||
u.log.WithError(logoutErr).Warn("Could not logout user")
|
||||
}
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
// IsCombinedAddressMode returns whether user is set in combined or split mode.
|
||||
// Combined mode is the default mode and is what users typically need.
|
||||
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
|
||||
@ -307,14 +338,10 @@ func (u *User) GetBridgePassword() string {
|
||||
// CheckBridgeLogin checks whether the user is logged in and the bridge
|
||||
// IMAP/SMTP password is correct.
|
||||
func (u *User) CheckBridgeLogin(password string) error {
|
||||
// FIXME(conman): Handle force upgrade?
|
||||
|
||||
/*
|
||||
if isApplicationOutdated {
|
||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||
return pmapi.ErrUpgradeApplication
|
||||
}
|
||||
*/
|
||||
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
@ -328,16 +355,16 @@ func (u *User) CheckBridgeLogin(password string) error {
|
||||
}
|
||||
|
||||
// UpdateUser updates user details from API and saves to the credentials.
|
||||
func (u *User) UpdateUser() error {
|
||||
func (u *User) UpdateUser(ctx context.Context) error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
_, err := u.client.UpdateUser(context.TODO())
|
||||
_, err := u.client.UpdateUser(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil {
|
||||
if err := u.client.ReloadKeys(ctx, []byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to reload keys")
|
||||
}
|
||||
|
||||
@ -414,8 +441,7 @@ func (u *User) Logout() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers?
|
||||
if err := u.client.AuthDelete(context.TODO()); err != nil {
|
||||
if err := u.client.AuthDelete(context.Background()); err != nil {
|
||||
u.log.WithError(err).Warn("Failed to delete auth")
|
||||
}
|
||||
|
||||
|
||||
195
internal/users/user_credentials_test.go
Normal file
195
internal/users/user_credentials_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/pkg/errors"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().UpdateUser(gomock.Any()).Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().ReloadKeys(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
|
||||
)
|
||||
|
||||
r.NoError(t, user.UpdateUser(context.Background()))
|
||||
}
|
||||
|
||||
func TestUserSwitchAddressMode(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
// Ignore any sync on background.
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
|
||||
// Check initial state.
|
||||
r.True(t, user.store.IsCombinedMode())
|
||||
r.True(t, user.creds.IsCombinedAddressMode)
|
||||
r.True(t, user.IsCombinedAddressMode())
|
||||
|
||||
// Mock change to split mode.
|
||||
gomock.InOrder(
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentialsSplit, nil),
|
||||
)
|
||||
|
||||
// Check switch to split mode.
|
||||
r.NoError(t, user.SwitchAddressMode())
|
||||
r.False(t, user.store.IsCombinedMode())
|
||||
r.False(t, user.creds.IsCombinedAddressMode)
|
||||
r.False(t, user.IsCombinedAddressMode())
|
||||
|
||||
// MOck change to combined mode.
|
||||
gomock.InOrder(
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentials, nil),
|
||||
)
|
||||
|
||||
// Check switch to combined mode.
|
||||
r.NoError(t, user.SwitchAddressMode())
|
||||
r.True(t, user.store.IsCombinedMode())
|
||||
r.True(t, user.creds.IsCombinedAddressMode)
|
||||
r.True(t, user.IsCombinedAddressMode())
|
||||
}
|
||||
|
||||
func TestLogoutUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||
)
|
||||
|
||||
err := user.Logout()
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLogoutUserFailsLogout(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||
)
|
||||
|
||||
err := user.Logout()
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLogin(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
|
||||
|
||||
isApplicationOutdated = true
|
||||
|
||||
err := user.CheckBridgeLogin("any-pass")
|
||||
r.Equal(t, pmapi.ErrUpgradeApplication, err)
|
||||
|
||||
isApplicationOutdated = false
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
gomock.InOrder(
|
||||
// Mock init of user.
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
|
||||
// Mock CheckBridgeLogin.
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
|
||||
)
|
||||
|
||||
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||
r.NoError(t, err)
|
||||
|
||||
err = user.connect(m.pmapiClient, testCredentialsDisconnected)
|
||||
r.Error(t, err)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
|
||||
r.Equal(t, ErrLoggedOutUser, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginBadPassword(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
err := user.CheckBridgeLogin("wrong!")
|
||||
r.EqualError(t, err, "backend/credentials: incorrect password")
|
||||
}
|
||||
88
internal/users/user_new_test.go
Normal file
88
internal/users/user_new_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewUserNoCredentialsStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
|
||||
|
||||
_, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewUserUnlockFails(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
gomock.InOrder(
|
||||
// Init of user.
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("bad password")),
|
||||
|
||||
// Handle of unlock error.
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(m, "failed to unlock user: bad password", testCredentialsDisconnected)
|
||||
}
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
mockInitConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
checkNewUserHasCredentials(m, "", testCredentials)
|
||||
}
|
||||
|
||||
func checkNewUserHasCredentials(m mocks, wantErr string, wantCreds *credentials.Credentials) {
|
||||
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||
r.NoError(m.t, err)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
err = user.connect(m.pmapiClient, testCredentials)
|
||||
if wantErr == "" {
|
||||
r.NoError(m.t, err)
|
||||
} else {
|
||||
r.EqualError(m.t, err, wantErr)
|
||||
}
|
||||
|
||||
r.Equal(m.t, wantCreds, user.creds)
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
51
internal/users/user_store_test.go
Normal file
51
internal/users/user_store_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
|
||||
r.Fail(t, "not implemented")
|
||||
}
|
||||
|
||||
func TestClearStoreWithStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
r.Nil(t, user.store.Close())
|
||||
user.store = nil
|
||||
r.Nil(t, user.clearStore())
|
||||
}
|
||||
|
||||
func TestClearStoreWithoutStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
r.NotNil(t, user.store)
|
||||
r.Nil(t, user.clearStore())
|
||||
}
|
||||
41
internal/users/user_test.go
Normal file
41
internal/users/user_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testNewUser sets up a new, authorised user.
|
||||
func testNewUser(m mocks) *User {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
mockInitConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
|
||||
r.NoError(m.t, err)
|
||||
|
||||
err = user.connect(m.pmapiClient, creds)
|
||||
r.NoError(m.t, err)
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func cleanUpUserData(u *User) {
|
||||
_ = u.clearStore()
|
||||
}
|
||||
@ -89,24 +89,42 @@ func New(
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
// FIXME(conman): Handle force upgrade events.
|
||||
/*
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAppOutdated()
|
||||
u.watchEvents()
|
||||
}()
|
||||
*/
|
||||
|
||||
if u.credStorer == nil {
|
||||
log.Error("No credentials store is available")
|
||||
} else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil {
|
||||
} else if err := u.loadUsersFromCredentialsStore(); err != nil {
|
||||
log.WithError(err).Error("Could not load all users from credentials store")
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
||||
func (u *Users) watchEvents() {
|
||||
upgradeCh := u.events.ProvideChannel(events.UpgradeApplicationEvent)
|
||||
internetOnCh := u.events.ProvideChannel(events.InternetOnEvent)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-upgradeCh:
|
||||
isApplicationOutdated = true
|
||||
u.closeAllConnections()
|
||||
case <-internetOnCh:
|
||||
for _, user := range u.users {
|
||||
if user.store == nil {
|
||||
if err := user.loadStore(); err != nil {
|
||||
log.WithError(err).Error("Failed to load store after reconnecting")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Users) loadUsersFromCredentialsStore() error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
@ -116,23 +134,26 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
l := log.WithField("user", userID)
|
||||
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warn("Could not create user, skipping")
|
||||
l.WithError(err).Warn("Could not create user, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
u.users = append(u.users, user)
|
||||
|
||||
if creds.IsConnected() {
|
||||
if err := u.loadConnectedUser(ctx, user, creds); err != nil {
|
||||
logrus.WithError(err).Warn("Could not load connected user")
|
||||
// If there is no connection, we don't want to retry. Load should
|
||||
// happen fast enough to not block GUI. When connection is back up,
|
||||
// watchEvents and unlockIfNecessary will finish user init later.
|
||||
if err := u.loadConnectedUser(pmapi.ContextWithoutRetry(context.Background()), user, creds); err != nil {
|
||||
l.WithError(err).Warn("Could not load connected user")
|
||||
}
|
||||
} else {
|
||||
logrus.Warn("User is disconnected and must be connected manually")
|
||||
|
||||
if err := u.loadDisconnectedUser(ctx, user, creds); err != nil {
|
||||
logrus.WithError(err).Warn("Could not load disconnected user")
|
||||
l.Warn("User is disconnected and must be connected manually")
|
||||
if err := user.connect(u.clientManager.NewClient("", "", "", time.Time{}), creds); err != nil {
|
||||
l.WithError(err).Warn("Could not load disconnected user")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -140,11 +161,6 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
||||
// FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor!
|
||||
return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds)
|
||||
}
|
||||
|
||||
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
||||
uid, ref, err := creds.SplitAPIToken()
|
||||
if err != nil {
|
||||
@ -153,38 +169,27 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden
|
||||
|
||||
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
|
||||
if err != nil {
|
||||
// FIXME(conman): This is a problem... if we weren't able to create a new client due to internet,
|
||||
// we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff...
|
||||
return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
|
||||
// When client cannot be refreshed right away due to no connection,
|
||||
// we create client which will refresh automatically when possible.
|
||||
connectErr := user.connect(u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
|
||||
|
||||
switch errors.Cause(err) {
|
||||
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
|
||||
return connectErr
|
||||
}
|
||||
|
||||
if logoutErr := user.logout(); logoutErr != nil {
|
||||
logrus.WithError(logoutErr).Warn("Could not logout user")
|
||||
}
|
||||
return errors.Wrap(err, "could not refresh token")
|
||||
}
|
||||
|
||||
// Update the user's credentials with the latest auth used to connect this user.
|
||||
if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
|
||||
if creds, err = u.credStorer.UpdateToken(creds.UserID, auth.UID, auth.RefreshToken); err != nil {
|
||||
return errors.Wrap(err, "could not create get user's refresh token")
|
||||
}
|
||||
|
||||
return user.connect(ctx, client, creds)
|
||||
}
|
||||
|
||||
func (u *Users) watchAppOutdated() {
|
||||
// FIXME(conman): handle force upgrade events.
|
||||
|
||||
/*
|
||||
ch := make(chan string)
|
||||
|
||||
u.events.Add(events.UpgradeApplicationEvent, ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
isApplicationOutdated = true
|
||||
u.closeAllConnections()
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
*/
|
||||
return user.connect(client, creds)
|
||||
}
|
||||
|
||||
func (u *Users) closeAllConnections() {
|
||||
@ -198,19 +203,19 @@ func (u *Users) closeAllConnections() {
|
||||
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
|
||||
u.crashBandicoot(username)
|
||||
|
||||
return u.clientManager.NewClientWithLogin(context.TODO(), username, password)
|
||||
return u.clientManager.NewClientWithLogin(context.Background(), username, password)
|
||||
}
|
||||
|
||||
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
||||
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
|
||||
apiUser, passphrase, err := getAPIUser(context.TODO(), client, password)
|
||||
apiUser, passphrase, err := getAPIUser(context.Background(), client, password)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get API user")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user, ok := u.hasUser(apiUser.ID); ok {
|
||||
if user.IsConnected() {
|
||||
if err := client.AuthDelete(context.TODO()); err != nil {
|
||||
if err := client.AuthDelete(context.Background()); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to delete new auth session")
|
||||
}
|
||||
|
||||
@ -228,14 +233,16 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
|
||||
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
|
||||
}
|
||||
|
||||
if err := user.connect(context.TODO(), client, creds); err != nil {
|
||||
if err := user.connect(client, creds); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to reconnect existing user")
|
||||
}
|
||||
|
||||
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil {
|
||||
if err := u.addNewUser(client, apiUser, auth, passphrase); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to add new user")
|
||||
}
|
||||
|
||||
@ -245,7 +252,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
|
||||
}
|
||||
|
||||
// addNewUser adds a new user.
|
||||
func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
|
||||
func (u *Users) addNewUser(client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
@ -266,7 +273,7 @@ func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pm
|
||||
return errors.Wrap(err, "failed to create new user")
|
||||
}
|
||||
|
||||
if err := user.connect(ctx, client, creds); err != nil {
|
||||
if err := user.connect(client, creds); err != nil {
|
||||
return errors.Wrap(err, "failed to connect new user")
|
||||
}
|
||||
|
||||
@ -292,7 +299,7 @@ func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pma
|
||||
|
||||
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
|
||||
if err := client.Unlock(ctx, passphrase); err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to unlock client")
|
||||
return nil, nil, ErrWrongMailboxPassword
|
||||
}
|
||||
|
||||
user, err := client.CurrentUser(ctx)
|
||||
@ -414,22 +421,13 @@ func (u *Users) SendMetric(m metrics.Metric) error {
|
||||
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) AllowProxy() {
|
||||
// FIXME(conman): Support DoH.
|
||||
// u.apiManager.AllowProxy()
|
||||
u.clientManager.AllowProxy()
|
||||
}
|
||||
|
||||
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) DisallowProxy() {
|
||||
// FIXME(conman): Support DoH.
|
||||
// u.apiManager.DisallowProxy()
|
||||
}
|
||||
|
||||
// CheckConnection returns whether there is an internet connection.
|
||||
// This should use the connection manager when it is eventually implemented.
|
||||
func (u *Users) CheckConnection() error {
|
||||
// FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer.
|
||||
panic("TODO: register as a connection observer to get this information")
|
||||
u.clientManager.DisallowProxy()
|
||||
}
|
||||
|
||||
// hasUser returns whether the struct currently has a user with ID `id`.
|
||||
|
||||
49
internal/users/users_clear_test.go
Normal file
49
internal/users/users_clear_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClearData(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
|
||||
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
||||
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
|
||||
|
||||
m.locator.EXPECT().Clear()
|
||||
|
||||
r.NoError(t, users.ClearData())
|
||||
}
|
||||
69
internal/users/users_delete_test.go
Normal file
69
internal/users/users_delete_test.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := users.DeleteUser("user", true)
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, 1, len(users.users))
|
||||
}
|
||||
|
||||
// Even when logout fails, delete is done.
|
||||
func TestDeleteUserWithFailingLogout(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
|
||||
// Once called from user.Logout after failed creds.Logout as fallback, and once at the end of users.Logout.
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := users.DeleteUser("user", true)
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, 1, len(users.users))
|
||||
}
|
||||
76
internal/users/users_get_test.go
Normal file
76
internal/users/users_get_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetNoUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
|
||||
}
|
||||
|
||||
func TestGetUserByID(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
checkUsersGetUser(t, m, "user", 0, "")
|
||||
checkUsersGetUser(t, m, "users", 1, "")
|
||||
}
|
||||
|
||||
func TestGetUserByName(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
checkUsersGetUser(t, m, "username", 0, "")
|
||||
checkUsersGetUser(t, m, "usersname", 1, "")
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
checkUsersGetUser(t, m, "user@pm.me", 0, "")
|
||||
checkUsersGetUser(t, m, "users@pm.me", 1, "")
|
||||
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
|
||||
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
|
||||
}
|
||||
|
||||
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
user, err := users.GetUser(query)
|
||||
|
||||
if expectedError != "" {
|
||||
r.EqualError(m.t, err, expectedError)
|
||||
} else {
|
||||
r.NoError(m.t, err)
|
||||
}
|
||||
|
||||
var expectedUser *User
|
||||
if index >= 0 {
|
||||
expectedUser = users.users[index]
|
||||
}
|
||||
r.Equal(m.t, expectedUser, user)
|
||||
}
|
||||
132
internal/users/users_login_test.go
Normal file
132
internal/users/users_login_test.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/pkg/errors"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Init users with no user from keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||
|
||||
// Set up mocks for FinishLogin.
|
||||
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil)
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked"))
|
||||
|
||||
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginNewUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Init users with no user from keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||
|
||||
mockAddingConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentials.UserID)
|
||||
|
||||
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, testCredentials.UserID, nil)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Mock loading disconnected user.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
|
||||
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
|
||||
|
||||
// Mock process of FinishLogin of already added user.
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUserDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
|
||||
)
|
||||
mockInitConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
|
||||
|
||||
authRefresh := &pmapi.Auth{
|
||||
UserID: testCredentialsDisconnected.UserID,
|
||||
AuthRefresh: pmapi.AuthRefresh{
|
||||
UID: "uid",
|
||||
AccessToken: "acc",
|
||||
RefreshToken: "ref",
|
||||
},
|
||||
}
|
||||
checkUsersFinishLogin(t, m, authRefresh, testCredentials.MailboxPassword, testCredentialsDisconnected.UserID, nil)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginConnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Mock loading connected user.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
|
||||
mockLoadingConnectedUser(m, testCredentials)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
// Mock process of FinishLogin of already connected user.
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
|
||||
)
|
||||
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
_, err := users.FinishLogin(m.pmapiClient, testAuthRefresh, testCredentials.MailboxPassword)
|
||||
r.EqualError(t, err, "user is already connected")
|
||||
}
|
||||
|
||||
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) {
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
|
||||
|
||||
r.Equal(t, expectedErr, err)
|
||||
|
||||
if expectedUserID != "" {
|
||||
r.Equal(t, expectedUserID, user.ID())
|
||||
r.Equal(t, 1, len(users.users))
|
||||
r.Equal(t, expectedUserID, users.users[0].ID())
|
||||
} else {
|
||||
r.Equal(t, (*User)(nil), user)
|
||||
r.Equal(t, 0, len(users.users))
|
||||
}
|
||||
}
|
||||
@ -22,9 +22,10 @@ import (
|
||||
"testing"
|
||||
time "time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewUsersNoKeychain(t *testing.T) {
|
||||
@ -32,7 +33,6 @@ func TestNewUsersNoKeychain(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain"))
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{})
|
||||
}
|
||||
|
||||
@ -41,108 +41,73 @@ func TestNewUsersWithoutUsersInCredentialsStore(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{})
|
||||
}
|
||||
|
||||
func TestNewUsersWithDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
|
||||
/*
|
||||
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
|
||||
func TestNewUsersWithConnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
|
||||
mockConnectedUser(m)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
|
||||
mockLoadingConnectedUser(m, testCredentials)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
|
||||
}
|
||||
|
||||
func TestNewUsersWithDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
|
||||
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
|
||||
// Tests two users with different states and checks also the order from
|
||||
// credentials store is kept also in array of users.
|
||||
func TestNewUsersWithUsers(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
|
||||
// Set up mocks for store initialisation for the unauth user.
|
||||
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
mockConnectedUser(m)
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
|
||||
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
|
||||
mockLoadingConnectedUser(m, testCredentials)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
|
||||
}
|
||||
|
||||
func TestNewUsersFirstStart(t *testing.T) {
|
||||
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token"))
|
||||
m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient)
|
||||
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(false)
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("not authorized"))
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
|
||||
testNewUsers(t, m)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
*/
|
||||
|
||||
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
assert.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
|
||||
r.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
|
||||
|
||||
credentials := []*credentials.Credentials{}
|
||||
for _, user := range users.users {
|
||||
credentials = append(credentials, user.creds)
|
||||
}
|
||||
|
||||
assert.Equal(m.t, expectedCredentials, credentials)
|
||||
r.Equal(m.t, expectedCredentials, credentials)
|
||||
}
|
||||
|
||||
@ -33,8 +33,9 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@ -49,9 +50,12 @@ func TestMain(m *testing.M) {
|
||||
|
||||
var (
|
||||
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
AuthRefresh: pmapi.AuthRefresh{
|
||||
UID: "uid",
|
||||
AccessToken: "acc",
|
||||
RefreshToken: "ref",
|
||||
},
|
||||
}
|
||||
|
||||
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
@ -81,7 +85,7 @@ var (
|
||||
}
|
||||
|
||||
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
UserID: "userDisconnected",
|
||||
Name: "username",
|
||||
Emails: "user@pm.me",
|
||||
APIToken: "",
|
||||
@ -94,7 +98,7 @@ var (
|
||||
}
|
||||
|
||||
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "users",
|
||||
UserID: "usersDisconnected",
|
||||
Name: "usersname",
|
||||
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
||||
APIToken: "",
|
||||
@ -111,17 +115,22 @@ var (
|
||||
Name: "username",
|
||||
}
|
||||
|
||||
testPMAPIUserDisconnected = &pmapi.User{ //nolint[gochecknoglobals]
|
||||
ID: "userDisconnected",
|
||||
Name: "username",
|
||||
}
|
||||
|
||||
testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
|
||||
ID: "testAddressID",
|
||||
Type: pmapi.OriginalAddress,
|
||||
Email: "user@pm.me",
|
||||
Receive: pmapi.CanReceive,
|
||||
Receive: true,
|
||||
}
|
||||
|
||||
testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
|
||||
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: pmapi.CanReceive, Type: pmapi.OriginalAddress},
|
||||
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
|
||||
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
|
||||
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: true, Type: pmapi.OriginalAddress},
|
||||
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: true, Type: pmapi.AliasAddress},
|
||||
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: true, Type: pmapi.AliasAddress},
|
||||
}
|
||||
|
||||
testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals]
|
||||
@ -129,15 +138,6 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func waitForEvents() {
|
||||
// Wait for goroutine to add listener.
|
||||
// E.g. calling login to invoke firstsync event. Functions can end sooner than
|
||||
// goroutines call the listener mock. We need to wait a little bit before the end of
|
||||
// the test to capture all event calls. This allows us to detect whether there were
|
||||
// missing calls, or perhaps whether something was called too many times.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
type mocks struct {
|
||||
t *testing.T
|
||||
|
||||
@ -146,7 +146,7 @@ type mocks struct {
|
||||
PanicHandler *usersmocks.MockPanicHandler
|
||||
credentialsStore *usersmocks.MockCredentialsStorer
|
||||
storeMaker *usersmocks.MockStoreMaker
|
||||
eventListener *MockListener
|
||||
eventListener *usersmocks.MockListener
|
||||
|
||||
clientManager *pmapimocks.MockManager
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
@ -154,6 +154,48 @@ type mocks struct {
|
||||
storeCache *store.Cache
|
||||
}
|
||||
|
||||
func initMocks(t *testing.T) mocks {
|
||||
var mockCtrl *gomock.Controller
|
||||
if os.Getenv("VERBOSITY") == "trace" {
|
||||
mockCtrl = gomock.NewController(&fullStackReporter{t})
|
||||
} else {
|
||||
mockCtrl = gomock.NewController(t)
|
||||
}
|
||||
|
||||
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
|
||||
r.NoError(t, err, "could not get temporary file for store cache")
|
||||
|
||||
m := mocks{
|
||||
t: t,
|
||||
|
||||
ctrl: mockCtrl,
|
||||
locator: usersmocks.NewMockLocator(mockCtrl),
|
||||
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
|
||||
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
|
||||
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
|
||||
eventListener: usersmocks.NewMockListener(mockCtrl),
|
||||
|
||||
clientManager: pmapimocks.NewMockManager(mockCtrl),
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
|
||||
storeCache: store.NewCache(cacheFile.Name()),
|
||||
}
|
||||
|
||||
// Called during clean-up.
|
||||
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
|
||||
|
||||
// Set up store factory.
|
||||
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
|
||||
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
|
||||
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
|
||||
r.NoError(t, err, "could not get temporary file for store db")
|
||||
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
|
||||
}).AnyTimes()
|
||||
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
type fullStackReporter struct {
|
||||
T testing.TB
|
||||
}
|
||||
@ -168,86 +210,18 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
|
||||
fr.T.FailNow()
|
||||
}
|
||||
|
||||
func initMocks(t *testing.T) mocks {
|
||||
var mockCtrl *gomock.Controller
|
||||
if os.Getenv("VERBOSITY") == "trace" {
|
||||
mockCtrl = gomock.NewController(&fullStackReporter{t})
|
||||
} else {
|
||||
mockCtrl = gomock.NewController(t)
|
||||
}
|
||||
|
||||
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
|
||||
require.NoError(t, err, "could not get temporary file for store cache")
|
||||
|
||||
m := mocks{
|
||||
t: t,
|
||||
|
||||
ctrl: mockCtrl,
|
||||
locator: usersmocks.NewMockLocator(mockCtrl),
|
||||
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
|
||||
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
|
||||
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
|
||||
eventListener: NewMockListener(mockCtrl),
|
||||
|
||||
clientManager: pmapimocks.NewMockManager(mockCtrl),
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
|
||||
storeCache: store.NewCache(cacheFile.Name()),
|
||||
}
|
||||
|
||||
// Called during clean-up.
|
||||
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
|
||||
|
||||
// Set up store factory.
|
||||
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
|
||||
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
|
||||
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
|
||||
require.NoError(t, err, "could not get temporary file for store db")
|
||||
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
|
||||
}).AnyTimes()
|
||||
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
|
||||
// Events are asynchronous
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
|
||||
|
||||
// Init for user.
|
||||
m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Init for users.
|
||||
m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil),
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil),
|
||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
|
||||
)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
|
||||
mockLoadingConnectedUser(m, testCredentials)
|
||||
mockLoadingConnectedUser(m, testCredentialsSplit)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
return testNewUsers(t, m)
|
||||
}
|
||||
|
||||
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
|
||||
// FIXME(conman): How to handle force upgrade?
|
||||
// m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
m.eventListener.EXPECT().ProvideChannel(events.UpgradeApplicationEvent)
|
||||
m.eventListener.EXPECT().ProvideChannel(events.InternetOnEvent)
|
||||
|
||||
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
|
||||
|
||||
@ -256,38 +230,84 @@ func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
|
||||
return users
|
||||
}
|
||||
|
||||
func waitForEvents() {
|
||||
// Wait for goroutine to add listener.
|
||||
// E.g. calling login to invoke firstsync event. Functions can end sooner than
|
||||
// goroutines call the listener mock. We need to wait a little bit before the end of
|
||||
// the test to capture all event calls. This allows us to detect whether there were
|
||||
// missing calls, or perhaps whether something was called too many times.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func cleanUpUsersData(b *Users) {
|
||||
for _, user := range b.users {
|
||||
_ = user.clearStore()
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearData(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
func mockAddingConnectedUser(m mocks) {
|
||||
gomock.InOrder(
|
||||
// Mock of users.FinishLogin.
|
||||
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
m.credentialsStore.EXPECT().Add("user", "username", testAuthRefresh.UID, testAuthRefresh.RefreshToken, testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
)
|
||||
|
||||
// m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
// m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
mockInitConnectedUser(m)
|
||||
}
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
|
||||
authRefresh := &pmapi.AuthRefresh{
|
||||
UID: "uid",
|
||||
AccessToken: "acc",
|
||||
RefreshToken: "ref",
|
||||
}
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
|
||||
gomock.InOrder(
|
||||
// Mock of users.loadUsersFromCredentialsStore.
|
||||
m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, authRefresh, nil),
|
||||
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
|
||||
)
|
||||
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
||||
mockInitConnectedUser(m)
|
||||
}
|
||||
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
|
||||
func mockInitConnectedUser(m mocks) {
|
||||
// Mock of user initialisation.
|
||||
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
|
||||
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
|
||||
|
||||
m.locator.EXPECT().Clear()
|
||||
// Mock of store initialisation.
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
)
|
||||
}
|
||||
|
||||
require.NoError(t, users.ClearData())
|
||||
func mockLoadingDisconnectedUser(m mocks, creds *credentials.Credentials) {
|
||||
gomock.InOrder(
|
||||
// Mock of users.loadUsersFromCredentialsStore.
|
||||
m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
|
||||
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
|
||||
)
|
||||
|
||||
waitForEvents()
|
||||
mockInitDisconnectedUser(m)
|
||||
}
|
||||
|
||||
func mockInitDisconnectedUser(m mocks) {
|
||||
gomock.InOrder(
|
||||
// Mock of user initialisation.
|
||||
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
|
||||
|
||||
// Mock of store initialisation for the unauthorized user.
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
}
|
||||
|
||||
func mockEventLoopNoAction(m mocks) {
|
||||
@ -297,19 +317,3 @@ func mockEventLoopNoAction(m mocks) {
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
}
|
||||
|
||||
func mockConnectedUser(m mocks) {
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
// m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// Set up mocks for store initialisation for the authorized user.
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
)
|
||||
}
|
||||
|
||||
@ -105,6 +105,10 @@ func (h *macOSHelper) Get(secretURL string) (string, string, error) {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "", "", errors.New("no result")
|
||||
}
|
||||
|
||||
if len(results) != 1 {
|
||||
return "", "", errors.New("ambiguous results")
|
||||
}
|
||||
|
||||
@ -29,6 +29,7 @@ var log = logrus.WithField("pkg", "bridgeUtils/listener") //nolint[gochecknoglob
|
||||
// Listener has a list of channels watching for updates.
|
||||
type Listener interface {
|
||||
SetLimit(eventName string, limit time.Duration)
|
||||
ProvideChannel(eventName string) <-chan string
|
||||
Add(eventName string, channel chan<- string)
|
||||
Remove(eventName string, channel chan<- string)
|
||||
Emit(eventName string, data string)
|
||||
@ -69,6 +70,15 @@ func (l *listener) SetLimit(eventName string, limit time.Duration) {
|
||||
l.limits[eventName] = limit
|
||||
}
|
||||
|
||||
// ProvideChannel creates new channel, adds it to listener and sends to it
|
||||
// bufferent events.
|
||||
func (l *listener) ProvideChannel(eventName string) <-chan string {
|
||||
ch := make(chan string)
|
||||
l.Add(eventName, ch)
|
||||
l.RetryEmit(eventName)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Add adds an event listener.
|
||||
func (l *listener) Add(eventName string, channel chan<- string) {
|
||||
l.lock.Lock()
|
||||
|
||||
@ -97,7 +97,7 @@ func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachW
|
||||
}
|
||||
|
||||
func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) {
|
||||
msg, err := req.api.GetMessage(req.messageID)
|
||||
msg, err := req.api.GetMessage(req.ctx, req.messageID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -109,7 +109,7 @@ func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, er
|
||||
}
|
||||
|
||||
process := func(value interface{}) (interface{}, error) {
|
||||
rc, err := req.api.GetAttachment(value.(string))
|
||||
rc, err := req.api.GetAttachment(req.ctx, value.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -43,10 +43,10 @@ func newTestFetcher(
|
||||
) Fetcher {
|
||||
f := mocks.NewMockFetcher(m)
|
||||
|
||||
f.EXPECT().GetMessage(msg.ID).Return(msg, nil)
|
||||
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
|
||||
|
||||
for i, att := range msg.Attachments {
|
||||
f.EXPECT().GetAttachment(att.ID).Return(newTestReadCloser(attData[i]), nil)
|
||||
f.EXPECT().GetAttachment(gomock.Any(), att.ID).Return(newTestReadCloser(attData[i]), nil)
|
||||
}
|
||||
|
||||
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(kr, nil)
|
||||
|
||||
@ -1230,7 +1230,7 @@ func TestBuildFetchMessageFail(t *testing.T) {
|
||||
|
||||
// Pretend the message cannot be fetched.
|
||||
f := mocks.NewMockFetcher(m)
|
||||
f.EXPECT().GetMessage(msg.ID).Return(nil, errors.New("oops"))
|
||||
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(nil, errors.New("oops"))
|
||||
|
||||
// The job should fail, returning an error and a nil result.
|
||||
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
||||
@ -1251,8 +1251,8 @@ func TestBuildFetchAttachmentFail(t *testing.T) {
|
||||
|
||||
// Pretend the attachment cannot be fetched.
|
||||
f := mocks.NewMockFetcher(m)
|
||||
f.EXPECT().GetMessage(msg.ID).Return(msg, nil)
|
||||
f.EXPECT().GetAttachment(msg.Attachments[0].ID).Return(nil, errors.New("oops"))
|
||||
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
|
||||
f.EXPECT().GetAttachment(gomock.Any(), msg.Attachments[0].ID).Return(nil, errors.New("oops"))
|
||||
|
||||
// The job should fail, returning an error and a nil result.
|
||||
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
||||
@ -1272,7 +1272,7 @@ func TestBuildNoSuchKeyRing(t *testing.T) {
|
||||
|
||||
// Pretend there is no available keyring.
|
||||
f := mocks.NewMockFetcher(m)
|
||||
f.EXPECT().GetMessage(msg.ID).Return(msg, nil)
|
||||
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
|
||||
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops"))
|
||||
|
||||
res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
|
||||
|
||||
@ -31,7 +31,7 @@ const (
|
||||
|
||||
// GetFlags returns imap flags from pmapi message attributes.
|
||||
func GetFlags(m *pmapi.Message) (flags []string) {
|
||||
if m.Unread == 0 {
|
||||
if !m.Unread {
|
||||
flags = append(flags, imap.SeenFlag)
|
||||
}
|
||||
if !m.Has(pmapi.FlagSent) && !m.Has(pmapi.FlagReceived) {
|
||||
@ -68,11 +68,11 @@ func ParseFlags(m *pmapi.Message, flags []string) {
|
||||
m.Flags = pmapi.FlagReceived
|
||||
}
|
||||
|
||||
m.Unread = 1
|
||||
m.Unread = true
|
||||
for _, f := range flags {
|
||||
switch f {
|
||||
case imap.SeenFlag:
|
||||
m.Unread = 0
|
||||
m.Unread = false
|
||||
case imap.DraftFlag:
|
||||
m.Flags &= ^pmapi.FlagSent
|
||||
m.Flags &= ^pmapi.FlagReceived
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
|
||||
@ -37,33 +38,33 @@ func (m *MockFetcher) EXPECT() *MockFetcherMockRecorder {
|
||||
}
|
||||
|
||||
// GetAttachment mocks base method
|
||||
func (m *MockFetcher) GetAttachment(arg0 string) (io.ReadCloser, error) {
|
||||
func (m *MockFetcher) GetAttachment(arg0 context.Context, arg1 string) (io.ReadCloser, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAttachment", arg0)
|
||||
ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1)
|
||||
ret0, _ := ret[0].(io.ReadCloser)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAttachment indicates an expected call of GetAttachment
|
||||
func (mr *MockFetcherMockRecorder) GetAttachment(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockFetcherMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetMessage mocks base method
|
||||
func (m *MockFetcher) GetMessage(arg0 string) (*pmapi.Message, error) {
|
||||
func (m *MockFetcher) GetMessage(arg0 context.Context, arg1 string) (*pmapi.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMessage", arg0)
|
||||
ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
|
||||
ret0, _ := ret[0].(*pmapi.Message)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMessage indicates an expected call of GetMessage
|
||||
func (mr *MockFetcherMockRecorder) GetMessage(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockFetcherMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0, arg1)
|
||||
}
|
||||
|
||||
// KeyRingForAddressID mocks base method
|
||||
|
||||
@ -191,7 +191,7 @@ func DecodeHeader(raw string) (decoded string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeHeader using quoted printable and utf8
|
||||
// EncodeHeader using quoted printable and utf8.
|
||||
func EncodeHeader(s string) string {
|
||||
return mime.QEncoding.Encode("utf-8", s)
|
||||
}
|
||||
|
||||
@ -19,7 +19,6 @@ package pmmime
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
//"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
||||
@ -32,12 +32,6 @@ const (
|
||||
EnabledAddress
|
||||
)
|
||||
|
||||
// Address receive values.
|
||||
const (
|
||||
CannotReceive = iota
|
||||
CanReceive
|
||||
)
|
||||
|
||||
// Address HasKeys values.
|
||||
const (
|
||||
MissingKeys = iota
|
||||
@ -66,7 +60,7 @@ type Address struct {
|
||||
DomainID string
|
||||
Email string
|
||||
Send int
|
||||
Receive int
|
||||
Receive Boolean
|
||||
Status int
|
||||
Order int `json:",omitempty"`
|
||||
Type int
|
||||
@ -103,7 +97,7 @@ func (l AddressList) AllEmails() (addresses []string) {
|
||||
// ActiveEmails returns only active emails.
|
||||
func (l AddressList) ActiveEmails() (addresses []string) {
|
||||
for _, a := range l {
|
||||
if a.Receive == CanReceive {
|
||||
if a.Receive {
|
||||
addresses = append(addresses, a.Email)
|
||||
}
|
||||
}
|
||||
@ -175,8 +169,19 @@ func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err e
|
||||
return res.Addresses, nil
|
||||
}
|
||||
|
||||
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) (err error) {
|
||||
panic("TODO")
|
||||
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(&struct {
|
||||
AddressIDs []string
|
||||
}{
|
||||
AddressIDs: addressIDs,
|
||||
}).Put("/addresses/order")
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := c.UpdateUser(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Addresses returns the addresses stored in the client object itself rather than fetching from the API.
|
||||
@ -185,24 +190,22 @@ func (c *client) Addresses() AddressList {
|
||||
}
|
||||
|
||||
// unlockAddresses unlocks all keys for all addresses of current user.
|
||||
func (c *client) unlockAddress(passphrase []byte, address *Address) (err error) {
|
||||
func (c *client) unlockAddress(passphrase []byte, address *Address) error {
|
||||
if address == nil {
|
||||
return errors.New("address data is missing")
|
||||
}
|
||||
|
||||
if address.HasKeys == MissingKeys {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
var kr *crypto.KeyRing
|
||||
|
||||
if kr, err = address.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil {
|
||||
return
|
||||
kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.addrKeyRing[address.ID] = kr
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
|
||||
|
||||
@ -20,6 +20,8 @@ package pmapi
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testAddressList = AddressList{
|
||||
@ -46,39 +48,29 @@ var testAddressList = AddressList{
|
||||
},
|
||||
}
|
||||
|
||||
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(tb, checkMethodAndPath(r, "GET", "/addresses"))
|
||||
Ok(tb, isAuthReq(r, testUID, testAccessToken))
|
||||
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(tb, checkMethodAndPath(req, "GET", "/addresses"))
|
||||
r.NoError(tb, isAuthReq(req, testUID, testAccessToken))
|
||||
return "addresses/get_response.json"
|
||||
}
|
||||
|
||||
func TestAddressList(t *testing.T) {
|
||||
input := "1"
|
||||
addr := testAddressList.ByID(input)
|
||||
if addr != testAddressList[0] {
|
||||
t.Errorf("ById(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[0], addr)
|
||||
}
|
||||
r.Equal(t, testAddressList[0], addr)
|
||||
|
||||
input = "42"
|
||||
addr = testAddressList.ByID(input)
|
||||
if addr != nil {
|
||||
t.Errorf("ById expected nil for %s but have : %v\n", input, addr)
|
||||
}
|
||||
r.Nil(t, addr)
|
||||
|
||||
input = "root@protonmail.com"
|
||||
addr = testAddressList.ByEmail(input)
|
||||
if addr != testAddressList[2] {
|
||||
t.Errorf("ByEmail(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[2], addr)
|
||||
}
|
||||
r.Equal(t, testAddressList[2], addr)
|
||||
|
||||
input = "idontexist@protonmail.com"
|
||||
addr = testAddressList.ByEmail(input)
|
||||
if addr != nil {
|
||||
t.Errorf("ByEmail expected nil for %s but have : %v\n", input, addr)
|
||||
}
|
||||
r.Nil(t, addr)
|
||||
|
||||
addr = testAddressList.Main()
|
||||
if addr != testAddressList[1] {
|
||||
t.Errorf("Main() expected:\n%v\n but have:\n%v\n", testAddressList[1], addr)
|
||||
}
|
||||
r.Equal(t, testAddressList[1], addr)
|
||||
}
|
||||
|
||||
@ -23,7 +23,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
@ -138,44 +137,6 @@ func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io.
|
||||
return signAttachment(kr, att)
|
||||
}
|
||||
|
||||
func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) {
|
||||
// Create metadata fields.
|
||||
if err = w.WriteField("Filename", att.Name); err != nil {
|
||||
return
|
||||
}
|
||||
if err = w.WriteField("MessageID", att.MessageID); err != nil {
|
||||
return
|
||||
}
|
||||
if err = w.WriteField("MIMEType", att.MIMEType); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = w.WriteField("ContentID", att.ContentID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// And send attachment data.
|
||||
ff, err := w.CreateFormFile("DataPacket", "DataPacket.pgp")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = io.Copy(ff, r); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// And send attachment data.
|
||||
sigff, err := w.CreateFormFile("Signature", "Signature.pgp")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = io.Copy(sigff, sig); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
|
||||
//
|
||||
// The returned created attachment contains the new attachment ID and its size.
|
||||
|
||||
@ -28,13 +28,13 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pmmime "github.com/ProtonMail/proton-bridge/pkg/mime"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testAttachment = &Attachment{
|
||||
@ -77,65 +77,40 @@ const testCreateAttachmentBody = `{
|
||||
"Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="}
|
||||
}`
|
||||
|
||||
const testDeleteAttachmentBody = `{
|
||||
"Code": 1000
|
||||
}`
|
||||
|
||||
func TestAttachment_UnmarshalJSON(t *testing.T) {
|
||||
att := new(Attachment)
|
||||
if err := json.Unmarshal([]byte(testAttachmentJSON), att); err != nil {
|
||||
t.Fatal("Expected no error while unmarshaling JSON, got:", err)
|
||||
}
|
||||
err := json.Unmarshal([]byte(testAttachmentJSON), att)
|
||||
r.NoError(t, err)
|
||||
|
||||
att.MessageID = testAttachment.MessageID // This isn't in the JSON object
|
||||
|
||||
if !reflect.DeepEqual(testAttachment, att) {
|
||||
t.Errorf("Invalid attachment: expected %+v but got %+v", testAttachment, att)
|
||||
}
|
||||
r.Equal(t, testAttachment, att)
|
||||
}
|
||||
|
||||
func TestClient_CreateAttachment(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/attachments"))
|
||||
|
||||
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing request content type, got:", err)
|
||||
}
|
||||
if contentType != "multipart/form-data" {
|
||||
t.Errorf("Invalid request content type: expected %v but got %v", "multipart/form-data", contentType)
|
||||
}
|
||||
contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type"))
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, "multipart/form-data", contentType)
|
||||
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
mr := multipart.NewReader(req.Body, params["boundary"])
|
||||
form, err := mr.ReadForm(10 * 1024)
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing request form, got:", err)
|
||||
}
|
||||
defer Ok(t, form.RemoveAll())
|
||||
r.NoError(t, err)
|
||||
defer r.NoError(t, form.RemoveAll())
|
||||
|
||||
if form.Value["Filename"][0] != testAttachment.Name {
|
||||
t.Errorf("Invalid attachment filename: expected %v but got %v", testAttachment.Name, form.Value["Filename"][0])
|
||||
}
|
||||
if form.Value["MessageID"][0] != testAttachment.MessageID {
|
||||
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MessageID, form.Value["MessageID"][0])
|
||||
}
|
||||
if form.Value["MIMEType"][0] != testAttachment.MIMEType {
|
||||
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MIMEType, form.Value["MIMEType"][0])
|
||||
}
|
||||
r.Equal(t, testAttachment.Name, form.Value["Filename"][0])
|
||||
r.Equal(t, testAttachment.MessageID, form.Value["MessageID"][0])
|
||||
r.Equal(t, testAttachment.MIMEType, form.Value["MIMEType"][0])
|
||||
|
||||
dataFile, err := form.File["DataPacket"][0].Open()
|
||||
if err != nil {
|
||||
t.Error("Expected no error while opening packets file, got:", err)
|
||||
}
|
||||
defer Ok(t, dataFile.Close())
|
||||
r.NoError(t, err)
|
||||
defer r.NoError(t, dataFile.Close())
|
||||
|
||||
b, err := ioutil.ReadAll(dataFile)
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading packets file, got:", err)
|
||||
}
|
||||
if string(b) != testAttachmentCleartext {
|
||||
t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b))
|
||||
}
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testAttachmentCleartext, string(b))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@ -143,50 +118,39 @@ func TestClient_CreateAttachment(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
|
||||
created, err := c.CreateAttachment(context.TODO(), testAttachment, r, strings.NewReader(""))
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while creating attachment, got:", err)
|
||||
}
|
||||
reader := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
|
||||
created, err := c.CreateAttachment(context.Background(), testAttachment, reader, strings.NewReader(""))
|
||||
r.NoError(t, err)
|
||||
|
||||
if created.ID != testAttachment.ID {
|
||||
t.Errorf("Invalid attachment id: expected %v but got %v", testAttachment.ID, created.ID)
|
||||
}
|
||||
r.Equal(t, testAttachment.ID, created.ID)
|
||||
}
|
||||
|
||||
func TestClient_GetAttachment(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/mail/v4/attachments/"+testAttachment.ID))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testAttachmentCleartext)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
r, err := c.GetAttachment(context.TODO(), testAttachment.ID)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting attachment, got:", err)
|
||||
}
|
||||
defer r.Close() //nolint[errcheck]
|
||||
att, err := c.GetAttachment(context.Background(), testAttachment.ID)
|
||||
r.NoError(t, err)
|
||||
defer att.Close() //nolint[errcheck]
|
||||
|
||||
// In reality, r contains encrypted data
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while reading attachment, got:", err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(att)
|
||||
r.NoError(t, err)
|
||||
|
||||
if string(b) != testAttachmentCleartext {
|
||||
t.Errorf("Invalid attachment data: expected %q but got %q", testAttachmentCleartext, string(b))
|
||||
}
|
||||
r.Equal(t, testAttachmentCleartext, string(b))
|
||||
}
|
||||
|
||||
func TestAttachment_Encrypt(t *testing.T) {
|
||||
data := bytes.NewBufferString(testAttachmentCleartext)
|
||||
r, err := testAttachment.Encrypt(testPublicKeyRing, data)
|
||||
assert.Nil(t, err)
|
||||
a.Nil(t, err)
|
||||
b, err := ioutil.ReadAll(r)
|
||||
assert.Nil(t, err)
|
||||
a.Nil(t, err)
|
||||
|
||||
// Result is always different, so the best way is to test it by decrypting again.
|
||||
// Another test for decrypting will help us to be sure it's working.
|
||||
@ -202,8 +166,8 @@ func TestAttachment_Decrypt(t *testing.T) {
|
||||
|
||||
func decryptAndCheck(t *testing.T, data io.Reader) {
|
||||
r, err := testAttachment.Decrypt(data, testPrivateKeyRing)
|
||||
assert.Nil(t, err)
|
||||
a.Nil(t, err)
|
||||
b, err := ioutil.ReadAll(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testAttachmentCleartext, string(b))
|
||||
a.Nil(t, err)
|
||||
a.Equal(t, testAttachmentCleartext, string(b))
|
||||
}
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
@ -6,15 +23,117 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
func (c *client) Auth2FA(ctx context.Context, req Auth2FAReq) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(req).Post("/auth/2fa")
|
||||
type AuthModulus struct {
|
||||
Modulus string
|
||||
ModulusID string
|
||||
}
|
||||
|
||||
type GetAuthInfoReq struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
type AuthInfo struct {
|
||||
Version int
|
||||
Modulus string
|
||||
ServerEphemeral string
|
||||
Salt string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type TwoFAInfo struct {
|
||||
Enabled TwoFAStatus
|
||||
}
|
||||
|
||||
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
|
||||
return twoFAInfo.Enabled > 0
|
||||
}
|
||||
|
||||
type TwoFAStatus int
|
||||
|
||||
const (
|
||||
TwoFADisabled TwoFAStatus = iota
|
||||
TOTPEnabled
|
||||
U2FEnabled
|
||||
TOTPAndU2FEnabled
|
||||
)
|
||||
|
||||
type PasswordMode int
|
||||
|
||||
const (
|
||||
OnePasswordMode PasswordMode = iota + 1
|
||||
TwoPasswordMode
|
||||
)
|
||||
|
||||
type AuthReq struct {
|
||||
Username string
|
||||
ClientProof string
|
||||
ClientEphemeral string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type AuthRefresh struct {
|
||||
UID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresIn int64
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
AuthRefresh
|
||||
|
||||
UserID string
|
||||
ServerProof string
|
||||
PasswordMode PasswordMode
|
||||
TwoFA *TwoFAInfo `json:"2FA,omitempty"`
|
||||
}
|
||||
|
||||
func (a Auth) HasTwoFactor() bool {
|
||||
if a.TwoFA == nil {
|
||||
return false
|
||||
}
|
||||
return a.TwoFA.hasTwoFactor()
|
||||
}
|
||||
|
||||
func (a Auth) HasMailboxPassword() bool {
|
||||
return a.PasswordMode == TwoPasswordMode
|
||||
}
|
||||
|
||||
type auth2FAReq struct {
|
||||
TwoFactorCode string
|
||||
}
|
||||
|
||||
type authRefreshReq struct {
|
||||
UID string
|
||||
RefreshToken string
|
||||
ResponseType string
|
||||
GrantType string
|
||||
RedirectURI string
|
||||
State string
|
||||
}
|
||||
|
||||
func (c *client) Auth2FA(ctx context.Context, twoFactorCode string) error {
|
||||
// 2FA is called during login procedure during which refresh token should
|
||||
// be valid, therefore, no refresh is needed if there is an error.
|
||||
ctx = ContextWithoutAuthRefresh(ctx)
|
||||
|
||||
if res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(auth2FAReq{TwoFactorCode: twoFactorCode}).Post("/auth/2fa")
|
||||
}); err != nil {
|
||||
if res != nil {
|
||||
switch res.StatusCode() {
|
||||
case http.StatusUnauthorized:
|
||||
return ErrBad2FACode
|
||||
case http.StatusUnprocessableEntity:
|
||||
return ErrBad2FACodeTryAgain
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@ -29,9 +148,7 @@ func (c *client) AuthDelete(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
|
||||
|
||||
// FIXME(conman): should we perhaps signal via AuthHandler that the auth was deleted?
|
||||
|
||||
c.sendAuthRefresh(nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -54,7 +171,7 @@ func (c *client) AuthSalt(ctx context.Context) (string, error) {
|
||||
return "", errors.New("no matching salt found")
|
||||
}
|
||||
|
||||
func (c *client) AddAuthHandler(handler AuthHandler) {
|
||||
func (c *client) AddAuthRefreshHandler(handler AuthRefreshHandler) {
|
||||
c.authHandlers = append(c.authHandlers, handler)
|
||||
}
|
||||
|
||||
@ -62,23 +179,35 @@ func (c *client) authRefresh(ctx context.Context) error {
|
||||
c.authLocker.Lock()
|
||||
defer c.authLocker.Unlock()
|
||||
|
||||
auth, err := c.req.authRefresh(ctx, c.uid, c.ref)
|
||||
if c.ref == "" {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
|
||||
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
|
||||
if err != nil {
|
||||
if err != ErrNoConnection {
|
||||
c.sendAuthRefresh(nil)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.acc = auth.AccessToken
|
||||
c.ref = auth.RefreshToken
|
||||
c.exp = expiresIn(auth.ExpiresIn)
|
||||
|
||||
for _, handler := range c.authHandlers {
|
||||
if err := handler(auth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.sendAuthRefresh(auth)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) sendAuthRefresh(auth *AuthRefresh) {
|
||||
for _, handler := range c.authHandlers {
|
||||
go handler(auth)
|
||||
}
|
||||
if auth == nil {
|
||||
c.authHandlers = []AuthRefreshHandler{}
|
||||
}
|
||||
}
|
||||
|
||||
func randomString(length int) string {
|
||||
noise := make([]byte, length)
|
||||
|
||||
|
||||
@ -1,22 +1,40 @@
|
||||
package pmapi_test
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
ExpiresIn: 100,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@ -24,7 +42,7 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
@ -35,28 +53,28 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// Make a request with an access token that already expired one second ago.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
|
||||
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
|
||||
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
|
||||
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
|
||||
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
|
||||
}
|
||||
|
||||
func Test401AuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
@ -67,7 +85,7 @@ func Test401AuthRefresh(t *testing.T) {
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
@ -86,24 +104,21 @@ func Test401AuthRefresh(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// The first request will fail with 401, triggering a refresh and retry.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
}
|
||||
|
||||
func Test401RevokedAuth(t *testing.T) {
|
||||
@ -119,17 +134,57 @@ func Test401RevokedAuth(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
r.EqualError(t, err, ErrUnauthorized.Error())
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
||||
t.Fatal("expected error to be ErrUnauthorized, instead got", err)
|
||||
func TestAuth2FA(t *testing.T) {
|
||||
twoFACode := "code"
|
||||
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
|
||||
var twoFAreq auth2FAReq
|
||||
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
|
||||
|
||||
return "/auth/2fa/post_response.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), twoFACode)
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Fail(t *testing.T) {
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_401_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACode, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Retry(t *testing.T) {
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_422_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACodeTryAgain, err)
|
||||
}
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
package pmapi
|
||||
|
||||
type AuthModulus struct {
|
||||
Modulus string
|
||||
ModulusID string
|
||||
}
|
||||
|
||||
type GetAuthInfoReq struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
type AuthInfo struct {
|
||||
Version int
|
||||
Modulus string
|
||||
ServerEphemeral string
|
||||
Salt string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type TwoFAInfo struct {
|
||||
Enabled TwoFAStatus
|
||||
}
|
||||
|
||||
type TwoFAStatus int
|
||||
|
||||
const (
|
||||
TwoFADisabled TwoFAStatus = iota
|
||||
TOTPEnabled
|
||||
// TODO: Support UTF
|
||||
)
|
||||
|
||||
type PasswordMode int
|
||||
|
||||
const (
|
||||
OnePasswordMode PasswordMode = iota + 1
|
||||
TwoPasswordMode
|
||||
)
|
||||
|
||||
type AuthReq struct {
|
||||
Username string
|
||||
ClientProof string
|
||||
ClientEphemeral string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
UserID string
|
||||
|
||||
UID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresIn int64
|
||||
|
||||
Scope string
|
||||
ServerProof string
|
||||
|
||||
TwoFA TwoFAInfo `json:"2FA"`
|
||||
PasswordMode PasswordMode
|
||||
}
|
||||
|
||||
type Auth2FAReq struct {
|
||||
TwoFactorCode string
|
||||
}
|
||||
|
||||
type AuthRefreshReq struct {
|
||||
UID string
|
||||
RefreshToken string
|
||||
ResponseType string
|
||||
GrantType string
|
||||
RedirectURI string
|
||||
State string
|
||||
}
|
||||
41
pkg/pmapi/boolean.go
Normal file
41
pkg/pmapi/boolean.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Boolean bool
|
||||
|
||||
func (boolean *Boolean) UnmarshalJSON(b []byte) error {
|
||||
var value int
|
||||
err := json.Unmarshal(b, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*boolean = Boolean(value == 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (boolean Boolean) MarshalJSON() ([]byte, error) {
|
||||
var value int
|
||||
if boolean {
|
||||
value = 1
|
||||
}
|
||||
return json.Marshal(value)
|
||||
}
|
||||
@ -25,15 +25,14 @@ import (
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// client is a client of the protonmail API. It implements the Client interface.
|
||||
type client struct {
|
||||
req requester
|
||||
manager clientManager
|
||||
|
||||
uid, acc, ref string
|
||||
authHandlers []AuthHandler
|
||||
authHandlers []AuthRefreshHandler
|
||||
authLocker sync.RWMutex
|
||||
|
||||
user *User
|
||||
@ -45,9 +44,9 @@ type client struct {
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
func newClient(req requester, uid string) *client {
|
||||
func newClient(manager clientManager, uid string) *client {
|
||||
return &client{
|
||||
req: req,
|
||||
manager: manager,
|
||||
uid: uid,
|
||||
addrKeyRing: make(map[string]*crypto.KeyRing),
|
||||
keyRingLock: &sync.RWMutex{},
|
||||
@ -63,7 +62,7 @@ func (c *client) withAuth(acc, ref string, exp time.Time) *client {
|
||||
}
|
||||
|
||||
func (c *client) r(ctx context.Context) (*resty.Request, error) {
|
||||
r := c.req.r(ctx)
|
||||
r := c.manager.r(ctx)
|
||||
|
||||
if c.uid != "" {
|
||||
r.SetHeader("x-pm-uid", c.uid)
|
||||
@ -91,30 +90,23 @@ func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Respons
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := wrapRestyError(fn(r))
|
||||
res, err := wrapNoConnection(fn(r))
|
||||
if err != nil {
|
||||
if res.StatusCode() != http.StatusUnauthorized {
|
||||
return nil, err
|
||||
// Return also response so caller has more options to decide what to do.
|
||||
return res, err
|
||||
}
|
||||
|
||||
if !isAuthRefreshDisabled(ctx) {
|
||||
if err := c.authRefresh(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapRestyError(fn(r))
|
||||
return wrapNoConnection(fn(r))
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) {
|
||||
if err, ok := err.(*resty.ResponseError); ok {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.RawResponse != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, errors.Wrap(ErrNoConnection, err.Error())
|
||||
}
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
@ -12,8 +29,6 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
// FIXME(conman): Should this be done as part of NewClient somehow?
|
||||
|
||||
return c.unlock(ctx, passphrase)
|
||||
}
|
||||
|
||||
@ -65,6 +80,15 @@ func (c *client) clearKeys() {
|
||||
}
|
||||
|
||||
func (c *client) IsUnlocked() bool {
|
||||
// FIXME(conman): Better way to check? we don't currently check address keys.
|
||||
return c.userKeyRing != nil
|
||||
if c.userKeyRing == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, address := range c.addresses {
|
||||
if address.HasKeys != MissingKeys && c.addrKeyRing[address.ID] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@ -27,10 +27,10 @@ import (
|
||||
|
||||
// Client defines the interface of a PMAPI client.
|
||||
type Client interface {
|
||||
Auth2FA(context.Context, Auth2FAReq) error
|
||||
Auth2FA(context.Context, string) error
|
||||
AuthSalt(ctx context.Context) (string, error)
|
||||
AuthDelete(context.Context) error
|
||||
AddAuthHandler(AuthHandler)
|
||||
AddAuthRefreshHandler(AuthRefreshHandler)
|
||||
|
||||
CurrentUser(ctx context.Context) (*User, error)
|
||||
UpdateUser(ctx context.Context) (*User, error)
|
||||
@ -75,9 +75,9 @@ type Client interface {
|
||||
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
|
||||
}
|
||||
|
||||
type AuthHandler func(*Auth) error
|
||||
type AuthRefreshHandler func(*AuthRefresh)
|
||||
|
||||
type requester interface {
|
||||
type clientManager interface {
|
||||
r(context.Context) *resty.Request
|
||||
authRefresh(context.Context, string, string) (*Auth, error)
|
||||
authRefresh(context.Context, string, string) (*AuthRefresh, error)
|
||||
}
|
||||
|
||||
@ -1,11 +1,72 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// HostURL is the base URL of API.
|
||||
HostURL string
|
||||
|
||||
// AppVersion sets version to headers of each request.
|
||||
AppVersion string
|
||||
|
||||
// UserAgent sets user agent to headers of each request.
|
||||
// Used only if GetUserAgent is not set.
|
||||
UserAgent string
|
||||
|
||||
// GetUserAgent is dynamic version of UserAgent.
|
||||
// Overrides UserAgent.
|
||||
GetUserAgent func() string
|
||||
|
||||
// UpgradeApplicationHandler is used to notify when there is a force upgrade.
|
||||
UpgradeApplicationHandler func()
|
||||
|
||||
// TLSIssueHandler is used to notify when there is a TLS issue.
|
||||
TLSIssueHandler func()
|
||||
}
|
||||
|
||||
var DefaultConfig = Config{
|
||||
HostURL: "https://api.protonmail.ch",
|
||||
AppVersion: "Other",
|
||||
func NewConfig(appVersionName, appVersion string) Config {
|
||||
return Config{
|
||||
HostURL: getRootURL(),
|
||||
AppVersion: getAPIOS() + strings.Title(appVersionName) + "_" + appVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) getUserAgent() string {
|
||||
if c.GetUserAgent == nil {
|
||||
return c.UserAgent
|
||||
}
|
||||
return c.GetUserAgent()
|
||||
}
|
||||
|
||||
// getAPIOS returns actual operating system.
|
||||
func getAPIOS() string {
|
||||
switch os := runtime.GOOS; os {
|
||||
case "darwin": // nolint: goconst
|
||||
return "macOS"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
}
|
||||
return "Linux"
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user