GODT-35: Finish all details and make tests pass

This commit is contained in:
Michal Horejsek
2021-03-11 14:37:15 +01:00
committed by Jakub
parent 2284e9ede1
commit 8109831c07
173 changed files with 4697 additions and 2897 deletions

View File

@ -181,21 +181,18 @@ func New( // nolint[funlen]
kc = keychain.NewMissingKeychain()
}
// FIXME(conman): Customize config depending on build type (app version, host URL).
cm := pmapi.New(pmapi.DefaultConfig)
cfg := pmapi.NewConfig(configName, constants.Version)
cfg.GetUserAgent = userAgent.String
cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") }
cm := pmapi.New(cfg)
// FIXME(conman): Should this be a real object, not just created via callbacks?
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
func() { listener.Emit(events.InternetOffEvent, "") },
func() { listener.Emit(events.InternetOnEvent, "") },
))
// FIXME(conman): Implement force upgrade observer.
// apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
// FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc.
// cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
jar, err := cookies.NewCookieJar(settingsObj)
if err != nil {
return nil, err
@ -341,6 +338,7 @@ func (b *Base) run(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc {
}
logging.SetLevel(c.String(flagLogLevel))
b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel)
logrus.
WithField("appName", b.Name).

View File

@ -65,8 +65,7 @@ func New(
// Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
if s.GetBool(settings.AllowProxyKey) {
// FIXME(conman): Support enable/disable of DoH.
// clientManager.AllowProxy()
clientManager.AllowProxy()
}
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
@ -120,7 +119,7 @@ func (b *Bridge) heartbeat() {
// ReportBug reports a new bug from the user.
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
return b.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Browser: emailClient,

View File

@ -21,7 +21,6 @@ import (
"context"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/abiosoft/ishell"
)
@ -75,13 +74,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return
}
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
if auth.HasTwoFactor() {
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
if twoFactor == "" {
return
}
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
err = client.Auth2FA(context.Background(), twoFactor)
if err != nil {
f.processAPIError(err)
return
@ -89,7 +88,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
}
mailboxPassword := password
if auth.PasswordMode == pmapi.TwoPasswordMode {
if auth.HasMailboxPassword() {
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
}
if mailboxPassword == "" {

View File

@ -84,11 +84,6 @@ func New( //nolint[funlen]
Aliases: []string{"u", "version", "v"},
Func: fe.checkUpdates,
})
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
Help: "check internet connection. (aliases: i, conn, connection)",
Aliases: []string{"i", "con", "connection"},
Func: fe.checkInternetConnection,
})
fe.AddCmd(checkCmd)
// Print info commands.
@ -177,13 +172,13 @@ func New( //nolint[funlen]
}
func (f *frontendCLI) watchEvents() {
errorCh := f.getEventChannel(events.ErrorEvent)
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
internetOffCh := f.getEventChannel(events.InternetOffEvent)
internetOnCh := f.getEventChannel(events.InternetOnEvent)
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
logoutCh := f.getEventChannel(events.LogoutEvent)
certIssue := f.getEventChannel(events.TLSCertIssue)
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
for {
select {
case errorDetails := <-errorCh:
@ -208,13 +203,6 @@ func (f *frontendCLI) watchEvents() {
}
}
func (f *frontendCLI) getEventChannel(event string) <-chan string {
ch := make(chan string)
f.eventListener.Add(event, ch)
f.eventListener.RetryEmit(event)
return ch
}
// Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop() error {
f.Print(`

View File

@ -29,14 +29,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
}
}
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
if f.ie.CheckConnection() == nil {
f.Println("Internet connection is available.")
} else {
f.Println("Can not contact the server, please check your internet connection.")
}
}
func (f *frontendCLI) printLogDir(c *ishell.Context) {
if path, err := f.locations.ProvideLogsPath(); err != nil {
f.Println("Failed to determine location of log files")

View File

@ -20,6 +20,7 @@ package cliie
import (
"strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/fatih/color"
)
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err)
switch err {
// FIXME(conman): How to handle various API errors?
/*
case pmapi.ErrNoConnection:
f.notifyInternetOff()
case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade()
*/
case pmapi.ErrNoConnection:
f.notifyInternetOff()
case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade()
default:
f.Println("Server error:", err.Error())
}

View File

@ -24,7 +24,6 @@ import (
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/abiosoft/ishell"
)
@ -122,13 +121,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return
}
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
if auth.HasTwoFactor() {
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
if twoFactor == "" {
return
}
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
err = client.Auth2FA(context.Background(), twoFactor)
if err != nil {
f.processAPIError(err)
return
@ -136,7 +135,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
}
mailboxPassword := password
if auth.PasswordMode == pmapi.TwoPasswordMode {
if auth.HasMailboxPassword() {
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
}
if mailboxPassword == "" {

View File

@ -157,15 +157,6 @@ func New( //nolint[funlen]
})
fe.AddCmd(updatesCmd)
// Check commands.
checkCmd := &ishell.Cmd{Name: "check", Help: "check internet connection or new version."}
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
Help: "check internet connection. (aliases: i, conn, connection)",
Aliases: []string{"i", "con", "connection"},
Func: fe.checkInternetConnection,
})
fe.AddCmd(checkCmd)
// Print info commands.
fe.AddCmd(&ishell.Cmd{Name: "log-dir",
Help: "print path to directory with logs. (aliases: log, logs)",
@ -228,14 +219,14 @@ func New( //nolint[funlen]
}
func (f *frontendCLI) watchEvents() {
errorCh := f.getEventChannel(events.ErrorEvent)
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
internetOffCh := f.getEventChannel(events.InternetOffEvent)
internetOnCh := f.getEventChannel(events.InternetOnEvent)
addressChangedCh := f.getEventChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
logoutCh := f.getEventChannel(events.LogoutEvent)
certIssue := f.getEventChannel(events.TLSCertIssue)
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
for {
select {
case errorDetails := <-errorCh:
@ -262,13 +253,6 @@ func (f *frontendCLI) watchEvents() {
}
}
func (f *frontendCLI) getEventChannel(event string) <-chan string {
ch := make(chan string)
f.eventListener.Add(event, ch)
f.eventListener.RetryEmit(event)
return ch
}
// Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop() error {
f.Print(`

View File

@ -39,14 +39,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
}
}
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
if f.bridge.CheckConnection() == nil {
f.Println("Internet connection is available.")
} else {
f.Println("Can not contact the server, please check your internet connection.")
}
}
func (f *frontendCLI) printLogDir(c *ishell.Context) {
if path, err := f.locations.ProvideLogsPath(); err != nil {
f.Println("Failed to determine location of log files")

View File

@ -20,6 +20,7 @@ package cli
import (
"strings"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/fatih/color"
)
@ -70,13 +71,10 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err)
switch err {
// FIXME(conman): How to handle various API errors?
/*
case pmapi.ErrNoConnection:
f.notifyInternetOff()
case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade()
*/
case pmapi.ErrNoConnection:
f.notifyInternetOff()
case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade()
default:
f.Println("Server error:", err.Error())
}

View File

@ -409,7 +409,6 @@ Dialog {
onShow: {
if (winMain.updateState==gui.enums.statusNoInternet) {
go.checkInternet()
if (winMain.updateState==gui.enums.statusNoInternet) {
go.notifyError(gui.enums.errNoInternet)
root.hide()

View File

@ -857,14 +857,12 @@ Dialog {
inputPort . checkIsANumber()
//emailProvider . currentIndex!=0
)) isOK = false
go.checkInternet()
if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this
errorPopup.show(qsTr("Please check your internet connection."))
return false
}
break
case 2: // loading structure
go.checkInternet()
if (winMain.updateState == gui.enums.statusNoInternet) {
errorPopup.show(qsTr("Please check your internet connection."))
return false
@ -949,7 +947,6 @@ Dialog {
onShow : {
root.clear()
if (winMain.updateState==gui.enums.statusNoInternet) {
go.checkInternet()
if (winMain.updateState==gui.enums.statusNoInternet) {
winMain.popupMessage.show(go.canNotReachAPI)
root.hide()

View File

@ -25,33 +25,12 @@ import ProtonUI 1.0
Rectangle {
id: root
property var iTry: 0
property var secLeft: 0
property var second: 1000 // convert millisecond to second
property var checkInterval: [ 5, 10, 30, 60, 120, 300, 600 ] // seconds
property bool isVisible: true
property var fontSize : 1.2 * Style.main.fontSize
color : "black"
state: "upToDate"
Timer {
id: retryInternet
interval: second
triggeredOnStart: false
repeat: true
onTriggered : {
secLeft--
if (secLeft <= 0) {
retryInternet.stop()
go.checkInternet()
if (iTry < checkInterval.length-1) {
iTry++
}
secLeft=checkInterval[iTry]
retryInternet.start()
}
}
}
Row {
id: messageRow
anchors.centerIn: root
@ -110,16 +89,12 @@ Rectangle {
case "internetCheck":
break;
case "noInternet" :
retryInternet.start()
secLeft=checkInterval[iTry]
break;
case "oldVersion":
break;
case "forceUpdate":
break;
case "upToDate":
iTry = 0
secLeft=checkInterval[iTry]
break;
case "updateRestart":
break;
@ -128,24 +103,6 @@ Rectangle {
default :
break;
}
if (root.state!="noInternet") {
retryInternet.stop()
}
}
function timeToRetry() {
if (secLeft==1){
return qsTr("a second", "time to wait till internet connection is retried")
} else if (secLeft<60){
return secLeft + " " + qsTr("seconds", "time to wait till internet connection is retried")
} else {
var leading = ""+secLeft%60
if (leading.length < 2) {
leading = "0" + leading
}
return Math.floor(secLeft/60) + ":" + leading
}
}
states: [
@ -194,23 +151,15 @@ Rectangle {
PropertyChanges {
target: message
color: Style.main.line
text: qsTr("Cannot contact server. Retrying in ", "displayed when the app is disconnected from the internet or server has problems")+timeToRetry()+"."
text: qsTr("Cannot contact server. Please wait...", "displayed when the app is disconnected from the internet or server has problems")
}
PropertyChanges {
target: linkText
visible: false
}
PropertyChanges {
target: actionText
visible: true
text: qsTr("Retry now", "click to try to connect to the internet when the app is disconnected from the internet")
onClicked: {
go.checkInternet()
}
}
PropertyChanges {
target: separatorText
visible: true
visible: false
text: "|"
}
PropertyChanges {

View File

@ -1331,10 +1331,6 @@ Window {
return (fname!="fail")
}
function checkInternet() {
// nothing to do
}
function loadImportReports(fname) {
console.log("load import reports for ", fname)
}

View File

@ -20,6 +20,7 @@
package qtcommon
import (
"context"
"fmt"
"strings"
"sync"
@ -207,7 +208,7 @@ func (a *Accounts) Auth2FA(twoFacAuth string) int {
if a.auth == nil || a.authClient == nil {
err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient)
} else {
err = a.authClient.Auth2FA(twoFacAuth, a.auth)
err = a.authClient.Auth2FA(context.Background(), twoFacAuth)
}
if a.showLoginError(err, "auth2FA") {

View File

@ -113,10 +113,3 @@ type Listener interface {
Add(string, chan<- string)
RetryEmit(string)
}
func MakeAndRegisterEvent(eventListener Listener, event string) <-chan string {
ch := make(chan string)
eventListener.Add(event, ch)
eventListener.RetryEmit(event)
return ch
}

View File

@ -143,16 +143,16 @@ func (f *FrontendQt) NotifySilentUpdateError(err error) {
}
func (f *FrontendQt) watchEvents() {
credentialsErrorCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.CredentialsErrorEvent)
internetOffCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOffEvent)
internetOnCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOnEvent)
secondInstanceCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.SecondInstanceEvent)
restartBridgeCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.RestartBridgeEvent)
addressChangedCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedEvent)
addressChangedLogoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedLogoutEvent)
logoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.LogoutEvent)
updateApplicationCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UpgradeApplicationEvent)
newUserCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UserRefreshEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
secondInstanceCh := f.eventListener.ProvideChannel(events.SecondInstanceEvent)
restartBridgeCh := f.eventListener.ProvideChannel(events.RestartBridgeEvent)
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
updateApplicationCh := f.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
newUserCh := f.eventListener.ProvideChannel(events.UserRefreshEvent)
for {
select {
case <-credentialsErrorCh:
@ -351,11 +351,6 @@ func (f *FrontendQt) sendBug(description, emailClient, address string) bool {
// }
//}
// checkInternet is almost idetical to bridge
func (f *FrontendQt) checkInternet() {
f.Qml.SetConnectionStatus(f.ie.CheckConnection() == nil)
}
func (f *FrontendQt) showError(code int, err error) {
f.Qml.SetErrorDescription(err.Error())
log.WithField("code", code).Errorln(err.Error())

View File

@ -77,8 +77,7 @@ type GoQMLInterface struct {
_ string `property:"credentialsNotRemoved"`
_ string `property:"versionCheckFailed"`
//
_ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"checkInternet"`
_ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"setToRestart"`
@ -189,8 +188,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
return f.programVersion
})
s.ConnectCheckInternet(f.checkInternet)
s.ConnectSetToRestart(f.restarter.SetToRestart)
s.ConnectLoadStructureForExport(f.LoadStructureForExport)

View File

@ -20,6 +20,7 @@
package qt
import (
"context"
"fmt"
"strings"
@ -173,7 +174,7 @@ func (s *FrontendQt) auth2FA(twoFacAuth string) int {
if s.auth == nil || s.authClient == nil {
err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient)
} else {
err = s.authClient.Auth2FA(twoFacAuth, s.auth)
err = s.authClient.Auth2FA(context.Background(), twoFacAuth)
}
if s.showLoginError(err, "auth2FA") {

View File

@ -191,20 +191,20 @@ func (s *FrontendQt) NotifySilentUpdateError(err error) {
func (s *FrontendQt) watchEvents() {
s.WaitUntilFrontendIsReady()
errorCh := s.getEventChannel(events.ErrorEvent)
credentialsErrorCh := s.getEventChannel(events.CredentialsErrorEvent)
outgoingNoEncCh := s.getEventChannel(events.OutgoingNoEncEvent)
noActiveKeyForRecipientCh := s.getEventChannel(events.NoActiveKeyForRecipientEvent)
internetOffCh := s.getEventChannel(events.InternetOffEvent)
internetOnCh := s.getEventChannel(events.InternetOnEvent)
secondInstanceCh := s.getEventChannel(events.SecondInstanceEvent)
restartBridgeCh := s.getEventChannel(events.RestartBridgeEvent)
addressChangedCh := s.getEventChannel(events.AddressChangedEvent)
addressChangedLogoutCh := s.getEventChannel(events.AddressChangedLogoutEvent)
logoutCh := s.getEventChannel(events.LogoutEvent)
updateApplicationCh := s.getEventChannel(events.UpgradeApplicationEvent)
newUserCh := s.getEventChannel(events.UserRefreshEvent)
certIssue := s.getEventChannel(events.TLSCertIssue)
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
outgoingNoEncCh := s.eventListener.ProvideChannel(events.OutgoingNoEncEvent)
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
internetOffCh := s.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := s.eventListener.ProvideChannel(events.InternetOnEvent)
secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
newUserCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
for {
select {
case errorDetails := <-errorCh:
@ -254,13 +254,6 @@ func (s *FrontendQt) watchEvents() {
}
}
func (s *FrontendQt) getEventChannel(event string) <-chan string {
ch := make(chan string)
s.eventListener.Add(event, ch)
s.eventListener.RetryEmit(event)
return ch
}
// Loop function for tests.
//
// It runs QtExecute in new thread with function returning itself after setup.
@ -653,10 +646,6 @@ func (s *FrontendQt) isSMTPSTARTTLS() bool {
return !s.settings.GetBool(settings.SMTPSSLKey)
}
func (s *FrontendQt) checkInternet() {
s.Qml.SetConnectionStatus(s.bridge.CheckConnection() == nil)
}
func (s *FrontendQt) switchAddressModeUser(iAccount int) {
defer s.Qml.ProcessFinished()
userID := s.Accounts.get(iAccount).UserID()

View File

@ -83,8 +83,7 @@ type GoQMLInterface struct {
_ float32 `property:"progress"`
_ string `property:"progressDescription"`
_ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"checkInternet"`
_ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"setToRestart"`
@ -205,8 +204,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
return f.programVer
})
s.ConnectCheckInternet(f.checkInternet)
s.ConnectSetToRestart(f.restarter.SetToRestart)
s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc)

View File

@ -55,7 +55,6 @@ type UserManager interface {
GetUser(query string) (User, error)
DeleteUser(userID string, clearCache bool) error
ClearData() error
CheckConnection() error
}
// User is an interface of user needed by frontend.

View File

@ -38,11 +38,10 @@ type bridgeUser interface {
IsCombinedAddressMode() bool
GetAddressID(address string) (string, error)
GetPrimaryAddress() string
UpdateUser() error
Logout() error
CloseConnection(address string)
GetStore() storeUserProvider
GetTemporaryPMAPIClient() pmapi.Client
GetClient() pmapi.Client
}
type bridgeWrap struct {

View File

@ -422,7 +422,7 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
if isStringInList(m.LabelIDs, pmapi.StarredLabel) {
messageFlagsMap[imap.FlaggedFlag] = true
}
if m.Unread == 0 {
if !m.Unread {
messageFlagsMap[imap.SeenFlag] = true
}
if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) {
@ -560,7 +560,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil, err
}
if storeMessage.Message().Unread == 1 {
if storeMessage.Message().Unread {
for section := range msg.Body {
// Peek means get messages without marking them as read.
// If client does not only ask for peek, we have to mark them as read.

View File

@ -93,7 +93,7 @@ func newIMAPUser(
// This method should eventually no longer be necessary. Everything should go via store.
func (iu *imapUser) client() pmapi.Client {
return iu.user.GetTemporaryPMAPIClient()
return iu.user.GetClient()
}
func (iu *imapUser) isSubscribed(labelID string) bool {

View File

@ -22,6 +22,7 @@ import (
"bytes"
"context"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/transfer"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
@ -40,6 +41,7 @@ type ImportExport struct {
locations Locator
cache Cacher
panicHandler users.PanicHandler
eventListener listener.Listener
clientManager pmapi.Manager
}
@ -59,13 +61,14 @@ func New(
locations: locations,
cache: cache,
panicHandler: panicHandler,
eventListener: eventListener,
clientManager: clientManager,
}
}
// ReportBug reports a new bug from the user.
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
return ie.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Browser: emailClient,
@ -89,7 +92,7 @@ func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address strin
report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
return ie.clientManager.ReportBug(context.TODO(), report)
return ie.clientManager.ReportBug(context.Background(), report)
}
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
@ -162,5 +165,23 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
log.WithError(err).Info("Address does not exist, using all addresses")
}
return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
provider, err := transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
if err != nil {
return nil, err
}
go func() {
internetOffCh := ie.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := ie.eventListener.ProvideChannel(events.InternetOnEvent)
for {
select {
case <-internetOffCh:
provider.SetConnectionDown()
case <-internetOnCh:
provider.SetConnectionUp()
}
}
}()
return provider, nil
}

View File

@ -31,7 +31,7 @@ type bridgeUser interface {
CheckBridgeLogin(password string) error
IsCombinedAddressMode() bool
GetAddressID(address string) (string, error)
GetTemporaryPMAPIClient() pmapi.Client
GetClient() pmapi.Client
GetStore() storeUserProvider
}

View File

@ -81,7 +81,7 @@ func newSMTPUser(
// This method should eventually no longer be necessary. Everything should go via store.
func (su *smtpUser) client() pmapi.Client {
return su.user.GetTemporaryPMAPIClient()
return su.user.GetClient()
}
// Send sends an email from the given address to the given addresses with the given body.

View File

@ -90,7 +90,7 @@ func getLabelPrefix(l *pmapi.Label) string {
switch {
case pmapi.IsSystemLabel(l.ID):
return ""
case l.Exclusive == 1:
case bool(l.Exclusive):
return UserFoldersPrefix
default:
return UserLabelsPrefix

View File

@ -37,8 +37,8 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
m.newStoreNoEvents(true)
m.store.SetChangeNotifier(m.changeNotifier)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
}
func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
@ -52,8 +52,8 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
m.newStoreNoEvents(true)
m.store.SetChangeNotifier(m.changeNotifier)
msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg2 := getTestMessage("msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
}
@ -63,8 +63,8 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2))
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1))

View File

@ -81,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
func (loop *eventLoop) setFirstEventID() (err error) {
loop.log.Info("Setting first event ID")
event, err := loop.client().GetEvent(context.TODO(), "")
event, err := loop.client().GetEvent(context.Background(), "")
if err != nil {
loop.log.WithError(err).Error("Could not get latest event ID")
return
@ -222,8 +222,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
// (e.g. no internet, ulimit reached etc.)
defer func() {
// FIXME(conman): How to handle errors of different types?
if errors.Is(err, pmapi.ErrNoConnection) {
if errors.Cause(err) == pmapi.ErrNoConnection {
l.Warn("Internet unavailable")
err = nil
}
@ -234,20 +233,17 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
err = nil
}
// FIXME(conman): Handle force upgrade.
/*
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
l.Warn("Need to upgrade application")
err = nil
}
*/
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
l.Warn("Need to upgrade application")
err = nil
}
if err == nil {
loop.errCounter = 0
}
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
if !errors.Is(err, pmapi.ErrUnauthorized) {
if err != nil && errors.Cause(err) != pmapi.ErrUnauthorized {
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
loop.errCounter++
if loop.errCounter == errMaxSentry {
@ -268,7 +264,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
loop.pollCounter++
var event *pmapi.Event
if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil {
if event, err = loop.client().GetEvent(context.Background(), loop.currentEventID); err != nil {
return false, errors.Wrap(err, "failed to get event")
}
@ -295,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
}
}
return event.More == 1, err
return bool(event.More), err
}
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
@ -354,7 +350,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
// Get old addresses for comparisons before updating user.
oldList := loop.client().Addresses()
if err = loop.user.UpdateUser(); err != nil {
if err = loop.user.UpdateUser(context.Background()); err != nil {
if logoutErr := loop.user.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Failed to logout user after failed update")
}
@ -465,16 +461,12 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil {
// FIXME(conman): How to handle error of this particular type?
/*
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil
continue
}
*/
if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil {
if _, ok := err.(pmapi.ErrUnprocessableEntity); ok {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil
continue
}
return errors.Wrap(err, "failed to get message from API for updating")
}

View File

@ -42,15 +42,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
// next event if there is `More` of them.
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
EventID: "event50",
More: 1,
More: true,
}, nil),
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
EventID: "event70",
More: 0,
More: false,
}, nil),
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
EventID: "event71",
More: 0,
More: false,
}, nil),
)
m.newStoreNoEvents(true)
@ -188,7 +188,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
msg := &pmapi.Message{
ID: "msg1",
Subject: "old",
Unread: 0,
Unread: false,
Flags: 10,
Sender: address1,
ToList: []*mail.Address{address2},
@ -200,7 +200,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
newMsg := &pmapi.Message{
ID: "msg1",
Subject: "new",
Unread: 1,
Unread: true,
Flags: 11,
Sender: address2,
ToList: []*mail.Address{address1},

View File

@ -129,17 +129,10 @@ func (mc *mailboxCounts) getPMLabel() *pmapi.Label {
Color: mc.Color,
Order: mc.Order,
Type: pmapi.LabelTypeMailbox,
Exclusive: mc.isExclusive(),
Exclusive: pmapi.Boolean(mc.IsFolder),
}
}
func (mc *mailboxCounts) isExclusive() int {
if mc.IsFolder {
return 1
}
return 0
}
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
// Don't forget about system folders.
@ -162,7 +155,7 @@ func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) er
mailbox.LabelName = label.Path
mailbox.Color = label.Color
mailbox.Order = label.Order
mailbox.IsFolder = label.Exclusive == 1
mailbox.IsFolder = bool(label.Exclusive)
// Write.
if err = mailbox.txWriteToBucket(countsBkt); err != nil {

View File

@ -75,7 +75,7 @@ func TestMailboxNames(t *testing.T) {
newLabel(100, "labelID1", "Label1"),
newLabel(1000, "folderID1", "Folder1"),
}
foldersAndLabels[1].Exclusive = 1
foldersAndLabels[1].Exclusive = true
for _, counts := range getSystemFolders() {
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())

View File

@ -37,10 +37,10 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{pmapi.AllMailLabel})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
@ -82,20 +82,20 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m.newStoreNoEvents(true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " externalID-non-pm-com "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = "<externalID@pm.me>"
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
// Not sure if this is a real-world scenario but we should be able to address this properly.
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))

View File

@ -18,8 +18,6 @@
package store
import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -43,7 +41,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
// wrapping it.
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID)
msg, err := storeMailbox.client().GetMessage(exposeContextForIMAP(), apiID)
if err != nil {
return nil, err
}
@ -70,7 +68,7 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
Message: body,
}
res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs})
res, err := storeMailbox.client().Import(exposeContextForIMAP(), pmapi.ImportMsgReqs{importReqs})
if err != nil {
return err
}
@ -99,7 +97,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed
}
defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
}
// UnlabelMessages removes the label by calling an API.
@ -112,7 +110,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed
}
defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
}
// MarkMessagesRead marks the message read by calling an API.
@ -132,14 +130,14 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
// Therefore we do not issue API update if the message is already read.
ids := []string{}
for _, apiID := range apiIDs {
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 {
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread {
ids = append(ids, apiID)
}
}
if len(ids) == 0 {
return nil
}
return storeMailbox.client().MarkMessagesRead(context.TODO(), ids)
return storeMailbox.client().MarkMessagesRead(exposeContextForIMAP(), ids)
}
// MarkMessagesUnread marks the message unread by calling an API.
@ -151,7 +149,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unread")
defer storeMailbox.pollNow()
return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs)
return storeMailbox.client().MarkMessagesUnread(exposeContextForIMAP(), apiIDs)
}
// MarkMessagesStarred adds the Starred label by calling an API.
@ -164,7 +162,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as starred")
defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
}
// MarkMessagesUnstarred removes the Starred label by calling an API.
@ -177,7 +175,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unstarred")
defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
}
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
@ -261,11 +259,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
}
case pmapi.DraftLabel:
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil {
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), apiIDs); err != nil {
return err
}
default:
if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil {
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID); err != nil {
return err
}
}
@ -303,13 +301,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
}
}
if len(messageIDsToUnlabel) > 0 {
if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
l.WithError(err).Warning("Cannot unlabel before deleting")
}
}
if len(messageIDsToDelete) > 0 {
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil {
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), messageIDsToDelete); err != nil {
return err
}
}

View File

@ -5,10 +5,10 @@
package mocks
import (
reflect "reflect"
context "context"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockPanicHandler is a mock of PanicHandler interface
@ -207,17 +207,17 @@ func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call {
}
// UpdateUser mocks base method
func (m *MockBridgeUser) UpdateUser() error {
func (m *MockBridgeUser) UpdateUser(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUser")
ret := m.ctrl.Call(m, "UpdateUser", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateUser indicates an expected call of UpdateUser
func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call {
func (mr *MockBridgeUserMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser), arg0)
}
// MockChangeNotifier is a mock of ChangeNotifier interface

View File

@ -5,10 +5,9 @@
package mocks
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockListener is a mock of Listener interface
@ -58,6 +57,20 @@ func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
}
// ProvideChannel mocks base method
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
ret0, _ := ret[0].(<-chan string)
return ret0
}
// ProvideChannel indicates an expected call of ProvideChannel
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
}
// Remove mocks base method
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()

View File

@ -101,6 +101,18 @@ var (
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
)
// exposeContextForIMAP should be replaced once with context passed
// as an argument from IMAP package and IMAP library should cancel
// context when IMAP client cancels the request.
func exposeContextForIMAP() context.Context {
return context.TODO()
}
// exposeContextForSMTP is the same as above but for SMTP.
func exposeContextForSMTP() context.Context {
return context.TODO()
}
// Store is local user storage, which handles the synchronization between IMAP and PM API.
type Store struct {
sentryReporter *sentry.Reporter
@ -278,7 +290,7 @@ func (store *Store) client() pmapi.Client {
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
// the API is unavailable for whatever reason it tries to fetch the labels locally.
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
if labels, err = store.client().ListLabels(context.TODO()); err != nil {
if labels, err = store.client().ListLabels(context.Background()); err != nil {
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
store.log.WithError(err).Error("Cannot list local labels")

View File

@ -184,8 +184,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
})
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
mocks.client.EXPECT().CountMessages(gomock.Any(), "")

View File

@ -148,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
Limit: 1,
}
// If the page does not exist, an empty page instead of an error is returned.
messages, total, err := api.ListMessages(context.TODO(), filter)
messages, total, err := api.ListMessages(context.Background(), filter)
if err != nil {
return "", 0, errors.Wrap(err, "failed to list messages")
}
@ -190,7 +190,7 @@ func syncBatch( //nolint[funlen]
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
messages, _, err := api.ListMessages(context.TODO(), filter)
messages, _, err := api.ListMessages(context.Background(), filter)
if err != nil {
return errors.Wrap(err, "failed to list messages")
}

View File

@ -17,7 +17,11 @@
package store
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type PanicHandler interface {
HandlePanic()
@ -32,7 +36,7 @@ type BridgeUser interface {
GetPrimaryAddress() string
GetStoreAddresses() []string
GetClient() pmapi.Client
UpdateUser() error
UpdateUser(context.Context) error
CloseAllConnections()
CloseConnection(string)
Logout() error

View File

@ -17,8 +17,6 @@
package store
import "context"
// UserID returns user ID.
func (store *Store) UserID() string {
return store.user.ID()
@ -26,7 +24,7 @@ func (store *Store) UserID() string {
// GetSpace returns used and total space in bytes.
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
apiUser, err := store.client().CurrentUser(context.TODO())
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
if err != nil {
return 0, 0, err
}
@ -35,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
// GetMaxUpload returns max size of message + all attachments in bytes.
func (store *Store) GetMaxUpload() (int64, error) {
apiUser, err := store.client().CurrentUser(context.TODO())
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
if err != nil {
return 0, err
}

View File

@ -147,7 +147,7 @@ func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (er
// filterAddresses filters out inactive addresses and ensures the original address is listed first.
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
for _, address := range addressList {
if address.Receive != pmapi.CanReceive {
if !address.Receive {
continue
}

View File

@ -18,7 +18,6 @@
package store
import (
"context"
"fmt"
"strings"
@ -39,14 +38,14 @@ func (store *Store) createMailbox(name string) error {
color := store.leastUsedColor()
var exclusive int
var exclusive bool
switch {
case strings.HasPrefix(name, UserLabelsPrefix):
name = strings.TrimPrefix(name, UserLabelsPrefix)
exclusive = 0
exclusive = false
case strings.HasPrefix(name, UserFoldersPrefix):
name = strings.TrimPrefix(name, UserFoldersPrefix)
exclusive = 1
exclusive = true
default:
// Ideally we would throw an error here, but then Outlook for
// macOS keeps trying to make an IMAP Drafts folder and popping
@ -56,10 +55,10 @@ func (store *Store) createMailbox(name string) error {
return nil
}
_, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{
_, err := store.client().CreateLabel(exposeContextForIMAP(), &pmapi.Label{
Name: name,
Color: color,
Exclusive: exclusive,
Exclusive: pmapi.Boolean(exclusive),
Type: pmapi.LabelTypeMailbox,
})
return err
@ -126,7 +125,7 @@ func (store *Store) leastUsedColor() string {
func (store *Store) updateMailbox(labelID, newName, color string) error {
defer store.eventLoop.pollNow()
_, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{
_, err := store.client().UpdateLabel(exposeContextForIMAP(), &pmapi.Label{
ID: labelID,
Name: newName,
Color: color,
@ -143,15 +142,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
var err error
switch labelID {
case pmapi.SpamLabel:
err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID)
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.SpamLabel, addressID)
case pmapi.TrashLabel:
err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID)
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.TrashLabel, addressID)
default:
err = fmt.Errorf("cannot empty mailbox %v", labelID)
}
return err
}
return store.client().DeleteLabel(context.TODO(), labelID)
return store.client().DeleteLabel(exposeContextForIMAP(), labelID)
}
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
@ -166,7 +165,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
return nil
}
labels, err := store.client().ListLabels(context.TODO())
labels, err := store.client().ListLabels(exposeContextForIMAP())
if err != nil {
return err
}

View File

@ -19,7 +19,6 @@ package store
import (
"bytes"
"context"
"encoding/json"
"io"
"io/ioutil"
@ -58,7 +57,7 @@ func (store *Store) CreateDraft(
}
draftAction := store.getDraftAction(message)
draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction)
draft, err := store.client().CreateDraft(exposeContextForSMTP(), message, parentID, draftAction)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create draft")
}
@ -70,7 +69,7 @@ func (store *Store) CreateDraft(
for _, att := range attachments {
att.attachment.MessageID = draft.ID
createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader)
createdAttachment, err := store.client().CreateAttachment(exposeContextForSMTP(), att.attachment, att.encReader, att.sigReader)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create attachment")
}
@ -184,7 +183,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
// SendMessage sends the message.
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
defer store.eventLoop.pollNow()
_, _, err := store.client().SendMessage(context.TODO(), messageID, req)
_, _, err := store.client().SendMessage(exposeContextForSMTP(), messageID, req)
return err
}

View File

@ -24,6 +24,7 @@ import (
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -34,10 +35,10 @@ func TestGetAllMessageIDs(t *testing.T) {
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
}
@ -47,7 +48,7 @@ func TestGetMessageFromDB(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{
{"msg1", ""},
@ -72,7 +73,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1")
require.Nil(t, err)
@ -104,7 +105,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
a.Equal(t, wantHeader, msg.Header)
// Check calculated data are not overridden by reinsert.
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
@ -118,8 +119,8 @@ func TestDeleteMessage(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.deleteMessageEvent("msg1"))
@ -127,17 +128,17 @@ func TestDeleteMessage(t *testing.T) {
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
}
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam]
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs)
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
}
func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message {
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
address := &mail.Address{Address: sender}
return &pmapi.Message{
ID: id,
Subject: subject,
Unread: unread,
Unread: pmapi.Boolean(unread),
Sender: address,
ToList: []*mail.Address{address},
LabelIDs: labelIDs,
@ -162,7 +163,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
defer clear()
m.newStoreNoEvents(false)
m.client.EXPECT().CurrentUser().Return(&pmapi.User{
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
}, nil)
@ -181,7 +182,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
defer clear()
m.newStoreNoEvents(false)
m.client.EXPECT().CurrentUser().Return(&pmapi.User{
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
}, nil)

View File

@ -35,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
// updateCountsFromServer will download and set the counts.
func (store *Store) updateCountsFromServer() error {
counts, err := store.client().CountMessages(context.TODO(), "")
counts, err := store.client().CountMessages(context.Background(), "")
if err != nil {
return errors.Wrap(err, "cannot update counts from server")
}

View File

@ -31,8 +31,8 @@ func TestLoadSaveSyncState(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
// Clear everything.

View File

@ -5,11 +5,10 @@
package mocks
import (
reflect "reflect"
imap "github.com/emersion/go-imap"
sasl "github.com/emersion/go-sasl"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockPanicHandler is a mock of PanicHandler interface

View File

@ -19,7 +19,9 @@ package transfer
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"time"
@ -37,6 +39,8 @@ const (
imapRetries = 10
imapReconnectTimeout = 30 * time.Minute
imapReconnectSleep = time.Minute
protonStatusURL = "http://protonstatus.com/vpn_status"
)
type imapErrorLogger struct {
@ -117,19 +121,15 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
return previousErr
}
// FIXME(conman): This should register as connection observer.
err := checkConnection()
log.WithError(err).Debug("Connection check")
if err != nil {
time.Sleep(imapReconnectSleep)
previousErr = err
continue
}
/*
err := pmapi.CheckConnection()
log.WithError(err).Debug("Connection check")
if err != nil {
time.Sleep(imapReconnectSleep)
previousErr = err
continue
}
*/
err := p.reauth()
err = p.reauth()
log.WithError(err).Debug("Reauth")
if err != nil {
time.Sleep(imapReconnectSleep)
@ -289,3 +289,23 @@ func (p *IMAPProvider) fetchHelper(uid bool, ensureSelectedIn string, seqSet *im
return err
}, ensureSelectedIn)
}
// checkConnection returns an error if there is no internet connection.
// Note we don't want to use client manager because it only reports connection
// issues with API; we are only interested here whether we can reach
// third-party IMAP servers.
func checkConnection() error {
client := &http.Client{Timeout: time.Second * 10}
resp, err := client.Get(protonStatusURL)
if err != nil {
return err
}
_ = resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("HTTP status code %d", resp.StatusCode)
}
return nil
}

View File

@ -52,15 +52,10 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
return Mailbox{}, errors.New("mailbox is already created")
}
exclusive := 0
if mailbox.IsExclusive {
exclusive = 1
}
label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{
label, err := p.client.CreateLabel(context.Background(), &pmapi.Label{
Name: mailbox.Name,
Color: mailbox.Color,
Exclusive: exclusive,
Exclusive: pmapi.Boolean(mailbox.IsExclusive),
Type: pmapi.LabelTypeMailbox,
})
if err != nil {
@ -126,7 +121,7 @@ func (p *PMAPIProvider) importDraft(msg Message, globalMailbox *Mailbox) (string
}
if message.Sender == nil {
mainAddress := p.client().Addresses().Main()
mainAddress := p.client.Addresses().Main()
message.Sender = &mail.Address{
Name: mainAddress.DisplayName,
Address: mainAddress.Email,
@ -227,14 +222,6 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
}
}
var unread pmapi.Boolean
if msg.Unread {
unread = pmapi.True
} else {
unread = pmapi.False
}
labelIDs := []string{}
for _, target := range msg.Targets {
// Frontend should not set All Mail to Rules, but to be sure...
@ -249,7 +236,7 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
return &pmapi.ImportMsgReq{
Metadata: &pmapi.ImportMetadata{
AddressID: p.addressID,
Unread: unread,
Unread: pmapi.Boolean(msg.Unread),
Time: message.Time,
Flags: computeMessageFlags(message.Header),
LabelIDs: labelIDs,

View File

@ -153,10 +153,10 @@ func setupPMAPIRules(rules transferRules) {
func setupPMAPIClientExpectationForExport(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2},
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1},
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1},
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2},
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: false, Order: 2},
{ID: "label2", Name: "Bar", Color: "green", Exclusive: false, Order: 1},
{ID: "folder1", Name: "One", Color: "red", Exclusive: true, Order: 1},
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: true, Order: 2},
}, nil).AnyTimes()
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
{LabelID: "label1", Total: 10},

View File

@ -30,9 +30,17 @@ import (
const (
pmapiRetries = 10
pmapiReconnectTimeout = 30 * time.Minute
pmapiReconnectSleep = time.Minute
pmapiReconnectSleep = 10 * time.Second
)
func (p *PMAPIProvider) SetConnectionUp() {
p.connection = true
}
func (p *PMAPIProvider) SetConnectionDown() {
p.connection = false
}
func (p *PMAPIProvider) ensureConnection(callback func() error) error {
var callErr error
for i := 1; i <= pmapiRetries; i++ {
@ -58,18 +66,10 @@ func (p *PMAPIProvider) tryReconnect() error {
return previousErr
}
// FIXME(conman): This should register as a connection observer somehow...
// Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection?
/*
err := p.clientManager.CheckConnection()
log.WithError(err).Debug("Connection check")
if err != nil {
time.Sleep(pmapiReconnectSleep)
previousErr = err
continue
}
*/
if !p.connection {
time.Sleep(pmapiReconnectSleep)
continue
}
break
}
@ -83,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
p.timeIt.start("listing", key)
defer p.timeIt.stop("listing", key)
messages, count, err = p.client.ListMessages(context.TODO(), filter)
messages, count, err = p.client.ListMessages(context.Background(), filter)
return err
})
return
@ -94,7 +94,7 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
p.timeIt.start("download", msgID)
defer p.timeIt.stop("download", msgID)
message, err = p.client.GetMessage(context.TODO(), msgID)
message, err = p.client.GetMessage(context.Background(), msgID)
return err
})
return
@ -105,7 +105,7 @@ func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReq
p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID)
res, err = p.client.Import(context.TODO(), req)
res, err = p.client.Import(context.Background(), req)
return err
})
return
@ -116,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID)
draft, err = p.client.CreateDraft(context.TODO(), message, parent, action)
draft, err = p.client.CreateDraft(context.Background(), message, parent, action)
return err
})
return
@ -129,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
p.timeIt.start("upload", key)
defer p.timeIt.stop("upload", key)
created, err = p.client.CreateAttachment(context.TODO(), att, r, sig)
created, err = p.client.CreateAttachment(context.Background(), att, r, sig)
return err
})
return

View File

@ -28,6 +28,7 @@ import (
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
@ -274,7 +275,7 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
wg.Wait()
}
func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater {
func newTestUpdater(manager pmapi.Manager, curVer string, earlyAccess bool) *Updater {
return New(
manager,
&fakeInstaller{},

View File

@ -1,251 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
)
assert.NoError(t, user.UpdateUser())
waitForEvents()
}
func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil),
)
assert.NoError(t, user.SwitchAddressMode())
assert.False(t, user.store.IsCombinedMode())
assert.False(t, user.creds.IsCombinedAddressMode)
assert.False(t, user.IsCombinedAddressMode())
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
assert.NoError(t, user.SwitchAddressMode())
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
}
func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
}
func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginOK(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginTwiceOK(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().IsUnlocked().Return(true),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
err = user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
isApplicationOutdated = true
err := user.CheckBridgeLogin("any-pass")
waitForEvents()
assert.Equal(t, pmapi.ErrUpgradeApplication, err)
isApplicationOutdated = false
}
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(t, err)
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
err = user.init()
assert.Error(t, err)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
waitForEvents()
assert.Equal(t, ErrLoggedOutUser, err)
}
func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin("wrong!")
waitForEvents()
assert.Equal(t, "backend/credentials: incorrect password", err.Error())
}

View File

@ -1,112 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert"
)
func TestNewUserNoCredentialsStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
a.Error(t, err)
}
func TestNewUserAuthRefreshFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
}
func TestNewUserUnlockFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(errors.New("bad password")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
}
func TestNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(testCredentials, m)
}
func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
defer cleanUpUserData(user)
_ = user.init()
waitForEvents()
a.Equal(m.t, creds, user.creds)
}
func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen]
a.Fail(t, "not implemented")
}

View File

@ -1,89 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(m.t, err)
err = user.init()
assert.NoError(m.t, err)
mockAuthUpdate(user, "reftok", m)
return user
}
func testNewUserForLogout(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(m.t, err)
err = user.init()
assert.NoError(m.t, err)
return user
}
func cleanUpUserData(u *User) {
_ = u.clearStore()
}
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
assert.Fail(t, "not implemented")
}
func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
require.Nil(t, user.store.Close())
user.store = nil
assert.Nil(t, user.clearStore())
}
func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
assert.NotNil(t, user.store)
assert.Nil(t, user.clearStore())
}

View File

@ -1,143 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestGetNoUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
}
func TestGetUserByID(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "user", 0, "")
checkUsersGetUser(t, m, "users", 1, "")
}
func TestGetUserByName(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "username", 0, "")
checkUsersGetUser(t, m, "usersname", 1, "")
}
func TestGetUserByEmail(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "user@pm.me", 0, "")
checkUsersGetUser(t, m, "users@pm.me", 1, "")
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
}
func TestDeleteUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
assert.NoError(t, err)
assert.Equal(t, 1, len(users.users))
}
// Even when logout fails, delete is done.
func TestDeleteUserWithFailingLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
assert.NoError(t, err)
assert.Equal(t, 1, len(users.users))
}
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.GetUser(query)
waitForEvents()
if expectedError != "" {
assert.Equal(m.t, expectedError, err.Error())
} else {
assert.NoError(m.t, err)
}
var expectedUser *User
if index >= 0 {
expectedUser = users.users[index]
}
assert.Equal(m.t, expectedUser, user)
}

View File

@ -1,219 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// Set up mocks for FinishLogin.
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked")),
m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout(),
)
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
}
func refreshWithToken(token string) *pmapi.Auth {
return &pmapi.Auth{
RefreshToken: token,
}
}
func credentialsWithToken(token string) *credentials.Credentials {
tmp := &credentials.Credentials{}
*tmp = *testCredentials
tmp.APIToken = token
return tmp
}
func TestUsersFinishLoginNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Basically every call client has get client manager
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// users.New() finds no users in keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// getAPIUser() loads user info from API (e.g. userID).
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// addNewUser()
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
// user.init() in addNewUser
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Emit event for new user and send metrics.
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)),
m.pmapiClient.EXPECT().Logout(),
// Reload account list in GUI.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
// defer logout anonymous
m.pmapiClient.EXPECT().Logout(),
)
mockEventLoopNoAction(m)
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
mockAuthUpdate(user, "afterCredentials", m)
}
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
loggedOutCreds := *testCredentials
loggedOutCreds.APIToken = ""
loggedOutCreds.MailboxPassword = ""
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// users.New() finds one existing user in keychain.
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
// newUser()
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// user.init()
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized),
m.pmapiClient.EXPECT().Addresses().Return(nil),
// getAPIUser() loads user info from API (e.g. userID).
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// connectExistingUser()
m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil),
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil),
// user.init() in connectExistingUser
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Reload account list in GUI.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
// defer logout anonymous
m.pmapiClient.EXPECT().Logout(),
)
mockEventLoopNoAction(m)
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
mockAuthUpdate(user, "afterCredentials", m)
}
func TestUsersFinishLoginConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
mockConnectedUser(m)
mockEventLoopNoAction(m)
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
// Then, try to log in again...
gomock.InOrder(
m.pmapiClient.EXPECT().AuthSalt().Return("", nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout(),
)
_, err := users.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword)
assert.Equal(t, "user is already connected", err.Error())
}
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User {
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
waitForEvents()
assert.Equal(t, expectedErr, err)
if expectedUserID != "" {
assert.Equal(t, expectedUserID, user.ID())
assert.Equal(t, 1, len(users.users))
assert.Equal(t, expectedUserID, users.users[0].ID())
} else {
assert.Equal(t, (*User)(nil), user)
assert.Equal(t, 0, len(users.users))
}
return user
}

View File

@ -233,7 +233,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
_, secret, err := s.secrets.Get(userID)
if err != nil {
log.WithError(err).Error("Could not get credentials from native keychain")
log.WithError(err).Warn("Could not get credentials from native keychain")
return
}

View File

@ -26,8 +26,7 @@ import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
r "github.com/stretchr/testify/require"
)
const testSep = "\n"
@ -249,26 +248,26 @@ func TestMarshalFormats(t *testing.T) {
log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt))
output := testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalStrings(secretStrings))
r.NoError(t, output.UnmarshalStrings(secretStrings))
log.Infof("strings out %#v \n", output)
require.True(t, input.IsSame(&output), "strings out not same")
r.True(t, input.IsSame(&output), "strings out not same")
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalGob(secretGob))
r.NoError(t, output.UnmarshalGob(secretGob))
log.Infof("gob out %#v\n \n", output)
assert.Equal(t, input, output)
r.Equal(t, input, output)
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.FromJSON(secretJSON))
r.NoError(t, output.FromJSON(secretJSON))
log.Infof("json out %#v \n", output)
require.True(t, input.IsSame(&output), "json out not same")
r.True(t, input.IsSame(&output), "json out not same")
/*
// Simple Fscanf not working!
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalFmt(secretFmt))
r.NoError(t, output.UnmarshalFmt(secretFmt))
log.Infof("fmt out %#v \n", output)
require.True(t, input.IsSame(&output), "fmt out not same")
r.True(t, input.IsSame(&output), "fmt out not same")
*/
}
@ -291,7 +290,7 @@ func TestMarshal(t *testing.T) {
log.Infof("secret %#v %d\n", secret, len(secret))
output := Credentials{APIToken: "refresh"}
require.NoError(t, output.Unmarshal(secret))
r.NoError(t, output.Unmarshal(secret))
log.Infof("output %#v\n", output)
assert.Equal(t, input, output)
r.Equal(t, input, output)
}

View File

@ -1,107 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./listener/listener.go
// Package users is a generated GoMock package.
package users
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(eventName string, limit time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", eventName, limit)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit)
}
// Add mocks base method
func (m *MockListener) Add(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", eventName, channel)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel)
}
// Remove mocks base method
func (m *MockListener) Remove(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", eventName, channel)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel)
}
// Emit mocks base method
func (m *MockListener) Emit(eventName, data string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", eventName, data)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", eventName)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", eventName)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName)
}

View File

@ -0,0 +1,120 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener)
// Package mocks is a generated GoMock package.
package mocks
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// Add mocks base method
func (m *MockListener) Add(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0, arg1)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1)
}
// Emit mocks base method
func (m *MockListener) Emit(arg0, arg1 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", arg0, arg1)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
}
// ProvideChannel mocks base method
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
ret0, _ := ret[0].(<-chan string)
return ret0
}
// ProvideChannel indicates an expected call of ProvideChannel
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
}
// Remove mocks base method
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", arg0, arg1)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", arg0)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", arg0)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0)
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", arg0, arg1)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1)
}

View File

@ -5,11 +5,10 @@
package mocks
import (
reflect "reflect"
store "github.com/ProtonMail/proton-bridge/internal/store"
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockLocator is a mock of Locator interface

View File

@ -89,29 +89,25 @@ func newUser(
// - providing it with an authorised API client
// - loading its credentials from the credentials store
// - loading and unlocking its PGP keys
// - loading its store
func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error {
// - loading its store.
func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) error {
u.log.Info("Connecting user")
// Connected users have an API client.
u.client = client
// FIXME(conman): How to remove this auth handler when user is disconnected?
u.client.AddAuthHandler(u.handleAuth)
u.client.AddAuthRefreshHandler(u.handleAuthRefresh)
// Save the latest credentials for the user.
u.creds = creds
// Connected users have unlocked keys.
// FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :(
if u.creds.IsConnected() {
if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil {
return err
}
if err := u.unlockIfNecessary(); err != nil {
return err
}
// Connected users have a store.
if err := u.loadStore(); err != nil {
if err := u.loadStore(); err != nil { //nolint[revive] easier to read
return err
}
@ -138,17 +134,25 @@ func (u *User) loadStore() error {
return nil
}
func (u *User) handleAuth(auth *pmapi.Auth) error {
u.log.Debug("User received auth")
func (u *User) handleAuthRefresh(auth *pmapi.AuthRefresh) {
u.log.Debug("User received auth refresh update")
if auth == nil {
if err := u.logout(); err != nil {
log.WithError(err).
WithField("userID", u.userID).
Error("User logout failed while watching API auths")
}
return
}
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
if err != nil {
return errors.Wrap(err, "failed to update refresh token in credentials store")
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
return
}
u.creds = creds
return nil
}
// clearStore removes the database.
@ -181,13 +185,6 @@ func (u *User) closeStore() error {
return nil
}
// GetTemporaryPMAPIClient returns an authorised PMAPI client.
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
// After proper refactor of SMTP and IMAP remove this method.
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
return u.client
}
// ID returns the user's userID.
func (u *User) ID() string {
return u.userID
@ -210,9 +207,43 @@ func (u *User) IsConnected() bool {
}
func (u *User) GetClient() pmapi.Client {
if err := u.unlockIfNecessary(); err != nil {
u.log.WithError(err).Error("Failed to unlock user")
}
return u.client
}
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
func (u *User) unlockIfNecessary() error {
if !u.creds.IsConnected() {
return nil
}
if u.client.IsUnlocked() {
return nil
}
// unlockIfNecessary is called with every access to underlying pmapi
// client. Unlock should only finish unlocking when connection is back up.
// That means it should try it fast enough and not retry if connection
// is still down.
err := u.client.Unlock(pmapi.ContextWithoutRetry(context.Background()), []byte(u.creds.MailboxPassword))
if err == nil {
return nil
}
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
u.log.WithError(err).Warn("Could not unlock user")
return nil
}
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(err, "failed to unlock user")
}
// IsCombinedAddressMode returns whether user is set in combined or split mode.
// Combined mode is the default mode and is what users typically need.
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
@ -307,14 +338,10 @@ func (u *User) GetBridgePassword() string {
// CheckBridgeLogin checks whether the user is logged in and the bridge
// IMAP/SMTP password is correct.
func (u *User) CheckBridgeLogin(password string) error {
// FIXME(conman): Handle force upgrade?
/*
if isApplicationOutdated {
u.listener.Emit(events.UpgradeApplicationEvent, "")
return pmapi.ErrUpgradeApplication
}
*/
if isApplicationOutdated {
u.listener.Emit(events.UpgradeApplicationEvent, "")
return pmapi.ErrUpgradeApplication
}
u.lock.RLock()
defer u.lock.RUnlock()
@ -328,16 +355,16 @@ func (u *User) CheckBridgeLogin(password string) error {
}
// UpdateUser updates user details from API and saves to the credentials.
func (u *User) UpdateUser() error {
func (u *User) UpdateUser(ctx context.Context) error {
u.lock.Lock()
defer u.lock.Unlock()
_, err := u.client.UpdateUser(context.TODO())
_, err := u.client.UpdateUser(ctx)
if err != nil {
return err
}
if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil {
if err := u.client.ReloadKeys(ctx, []byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to reload keys")
}
@ -414,8 +441,7 @@ func (u *User) Logout() error {
return nil
}
// FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers?
if err := u.client.AuthDelete(context.TODO()); err != nil {
if err := u.client.AuthDelete(context.Background()); err != nil {
u.log.WithError(err).Warn("Failed to delete auth")
}

View File

@ -0,0 +1,195 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"context"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
r "github.com/stretchr/testify/require"
)
func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().UpdateUser(gomock.Any()).Return(nil, nil),
m.pmapiClient.EXPECT().ReloadKeys(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
)
r.NoError(t, user.UpdateUser(context.Background()))
}
func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
// Ignore any sync on background.
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
// Check initial state.
r.True(t, user.store.IsCombinedMode())
r.True(t, user.creds.IsCombinedAddressMode)
r.True(t, user.IsCombinedAddressMode())
// Mock change to split mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentialsSplit, nil),
)
// Check switch to split mode.
r.NoError(t, user.SwitchAddressMode())
r.False(t, user.store.IsCombinedMode())
r.False(t, user.creds.IsCombinedAddressMode)
r.False(t, user.IsCombinedAddressMode())
// MOck change to combined mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentials, nil),
)
// Check switch to combined mode.
r.NoError(t, user.SwitchAddressMode())
r.True(t, user.store.IsCombinedMode())
r.True(t, user.creds.IsCombinedAddressMode)
r.True(t, user.IsCombinedAddressMode())
}
func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
)
err := user.Logout()
r.NoError(t, err)
}
func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
)
err := user.Logout()
r.NoError(t, err)
}
func TestCheckBridgeLogin(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
r.NoError(t, err)
}
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
isApplicationOutdated = true
err := user.CheckBridgeLogin("any-pass")
r.Equal(t, pmapi.ErrUpgradeApplication, err)
isApplicationOutdated = false
}
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
// Mock init of user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
// Mock CheckBridgeLogin.
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
)
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(t, err)
err = user.connect(m.pmapiClient, testCredentialsDisconnected)
r.Error(t, err)
defer cleanUpUserData(user)
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
r.Equal(t, ErrLoggedOutUser, err)
}
func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin("wrong!")
r.EqualError(t, err, "backend/credentials: incorrect password")
}

View File

@ -0,0 +1,88 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
func TestNewUserNoCredentialsStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.Error(t, err)
}
func TestNewUserUnlockFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
// Init of user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("bad password")),
// Handle of unlock error.
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
)
checkNewUserHasCredentials(m, "failed to unlock user: bad password", testCredentialsDisconnected)
}
func TestNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(m, "", testCredentials)
}
func checkNewUserHasCredentials(m mocks, wantErr string, wantCreds *credentials.Credentials) {
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(m.t, err)
defer cleanUpUserData(user)
err = user.connect(m.pmapiClient, testCredentials)
if wantErr == "" {
r.NoError(m.t, err)
} else {
r.EqualError(m.t, err, wantErr)
}
r.Equal(m.t, wantCreds, user.creds)
waitForEvents()
}

View File

@ -0,0 +1,51 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
r "github.com/stretchr/testify/require"
)
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
r.Fail(t, "not implemented")
}
func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
r.Nil(t, user.store.Close())
user.store = nil
r.Nil(t, user.clearStore())
}
func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
r.NotNil(t, user.store)
r.Nil(t, user.clearStore())
}

View File

@ -0,0 +1,41 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
r "github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(m.t, err)
err = user.connect(m.pmapiClient, creds)
r.NoError(m.t, err)
return user
}
func cleanUpUserData(u *User) {
_ = u.clearStore()
}

View File

@ -89,24 +89,42 @@ func New(
lock: sync.RWMutex{},
}
// FIXME(conman): Handle force upgrade events.
/*
go func() {
defer panicHandler.HandlePanic()
u.watchAppOutdated()
}()
*/
go func() {
defer panicHandler.HandlePanic()
u.watchEvents()
}()
if u.credStorer == nil {
log.Error("No credentials store is available")
} else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil {
} else if err := u.loadUsersFromCredentialsStore(); err != nil {
log.WithError(err).Error("Could not load all users from credentials store")
}
return u
}
func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
func (u *Users) watchEvents() {
upgradeCh := u.events.ProvideChannel(events.UpgradeApplicationEvent)
internetOnCh := u.events.ProvideChannel(events.InternetOnEvent)
for {
select {
case <-upgradeCh:
isApplicationOutdated = true
u.closeAllConnections()
case <-internetOnCh:
for _, user := range u.users {
if user.store == nil {
if err := user.loadStore(); err != nil {
log.WithError(err).Error("Failed to load store after reconnecting")
}
}
}
}
}
}
func (u *Users) loadUsersFromCredentialsStore() error {
u.lock.Lock()
defer u.lock.Unlock()
@ -116,23 +134,26 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
}
for _, userID := range userIDs {
l := log.WithField("user", userID)
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
if err != nil {
logrus.WithError(err).Warn("Could not create user, skipping")
l.WithError(err).Warn("Could not create user, skipping")
continue
}
u.users = append(u.users, user)
if creds.IsConnected() {
if err := u.loadConnectedUser(ctx, user, creds); err != nil {
logrus.WithError(err).Warn("Could not load connected user")
// If there is no connection, we don't want to retry. Load should
// happen fast enough to not block GUI. When connection is back up,
// watchEvents and unlockIfNecessary will finish user init later.
if err := u.loadConnectedUser(pmapi.ContextWithoutRetry(context.Background()), user, creds); err != nil {
l.WithError(err).Warn("Could not load connected user")
}
} else {
logrus.Warn("User is disconnected and must be connected manually")
if err := u.loadDisconnectedUser(ctx, user, creds); err != nil {
logrus.WithError(err).Warn("Could not load disconnected user")
l.Warn("User is disconnected and must be connected manually")
if err := user.connect(u.clientManager.NewClient("", "", "", time.Time{}), creds); err != nil {
l.WithError(err).Warn("Could not load disconnected user")
}
}
}
@ -140,11 +161,6 @@ func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
return err
}
func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
// FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor!
return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds)
}
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
uid, ref, err := creds.SplitAPIToken()
if err != nil {
@ -153,38 +169,27 @@ func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *creden
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
if err != nil {
// FIXME(conman): This is a problem... if we weren't able to create a new client due to internet,
// we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff...
return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
// When client cannot be refreshed right away due to no connection,
// we create client which will refresh automatically when possible.
connectErr := user.connect(u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
return connectErr
}
if logoutErr := user.logout(); logoutErr != nil {
logrus.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(err, "could not refresh token")
}
// Update the user's credentials with the latest auth used to connect this user.
if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
if creds, err = u.credStorer.UpdateToken(creds.UserID, auth.UID, auth.RefreshToken); err != nil {
return errors.Wrap(err, "could not create get user's refresh token")
}
return user.connect(ctx, client, creds)
}
func (u *Users) watchAppOutdated() {
// FIXME(conman): handle force upgrade events.
/*
ch := make(chan string)
u.events.Add(events.UpgradeApplicationEvent, ch)
for {
select {
case <-ch:
isApplicationOutdated = true
u.closeAllConnections()
case <-u.stopAll:
return
}
}
*/
return user.connect(client, creds)
}
func (u *Users) closeAllConnections() {
@ -198,19 +203,19 @@ func (u *Users) closeAllConnections() {
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
u.crashBandicoot(username)
return u.clientManager.NewClientWithLogin(context.TODO(), username, password)
return u.clientManager.NewClientWithLogin(context.Background(), username, password)
}
// FinishLogin finishes the login procedure and adds the user into the credentials store.
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
apiUser, passphrase, err := getAPIUser(context.TODO(), client, password)
apiUser, passphrase, err := getAPIUser(context.Background(), client, password)
if err != nil {
return nil, errors.Wrap(err, "failed to get API user")
return nil, err
}
if user, ok := u.hasUser(apiUser.ID); ok {
if user.IsConnected() {
if err := client.AuthDelete(context.TODO()); err != nil {
if err := client.AuthDelete(context.Background()); err != nil {
logrus.WithError(err).Warn("Failed to delete new auth session")
}
@ -228,14 +233,16 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
}
if err := user.connect(context.TODO(), client, creds); err != nil {
if err := user.connect(client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user")
}
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
return user, nil
}
if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil {
if err := u.addNewUser(client, apiUser, auth, passphrase); err != nil {
return nil, errors.Wrap(err, "failed to add new user")
}
@ -245,7 +252,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password stri
}
// addNewUser adds a new user.
func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
func (u *Users) addNewUser(client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
u.lock.Lock()
defer u.lock.Unlock()
@ -266,7 +273,7 @@ func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pm
return errors.Wrap(err, "failed to create new user")
}
if err := user.connect(ctx, client, creds); err != nil {
if err := user.connect(client, creds); err != nil {
return errors.Wrap(err, "failed to connect new user")
}
@ -292,7 +299,7 @@ func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pma
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
if err := client.Unlock(ctx, passphrase); err != nil {
return nil, nil, errors.Wrap(err, "failed to unlock client")
return nil, nil, ErrWrongMailboxPassword
}
user, err := client.CurrentUser(ctx)
@ -414,22 +421,13 @@ func (u *Users) SendMetric(m metrics.Metric) error {
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) AllowProxy() {
// FIXME(conman): Support DoH.
// u.apiManager.AllowProxy()
u.clientManager.AllowProxy()
}
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) DisallowProxy() {
// FIXME(conman): Support DoH.
// u.apiManager.DisallowProxy()
}
// CheckConnection returns whether there is an internet connection.
// This should use the connection manager when it is eventually implemented.
func (u *Users) CheckConnection() error {
// FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer.
panic("TODO: register as a connection observer to get this information")
u.clientManager.DisallowProxy()
}
// hasUser returns whether the struct currently has a user with ID `id`.

View File

@ -0,0 +1,49 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
func TestClearData(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
m.locator.EXPECT().Clear()
r.NoError(t, users.ClearData())
}

View File

@ -0,0 +1,69 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
r "github.com/stretchr/testify/require"
)
func TestDeleteUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
r.NoError(t, err)
r.Equal(t, 1, len(users.users))
}
// Even when logout fails, delete is done.
func TestDeleteUserWithFailingLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
// Once called from user.Logout after failed creds.Logout as fallback, and once at the end of users.Logout.
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
r.NoError(t, err)
r.Equal(t, 1, len(users.users))
}

View File

@ -0,0 +1,76 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
r "github.com/stretchr/testify/require"
)
func TestGetNoUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
}
func TestGetUserByID(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "user", 0, "")
checkUsersGetUser(t, m, "users", 1, "")
}
func TestGetUserByName(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "username", 0, "")
checkUsersGetUser(t, m, "usersname", 1, "")
}
func TestGetUserByEmail(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
checkUsersGetUser(t, m, "user@pm.me", 0, "")
checkUsersGetUser(t, m, "users@pm.me", 1, "")
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
}
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.GetUser(query)
if expectedError != "" {
r.EqualError(m.t, err, expectedError)
} else {
r.NoError(m.t, err)
}
var expectedUser *User
if index >= 0 {
expectedUser = users.users[index]
}
r.Equal(m.t, expectedUser, user)
}

View File

@ -0,0 +1,132 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
r "github.com/stretchr/testify/require"
)
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
// Set up mocks for FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil)
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("no keys could be unlocked"))
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, "", ErrWrongMailboxPassword)
}
func TestUsersFinishLoginNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
mockAddingConnectedUser(m)
mockEventLoopNoAction(m)
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentials.UserID)
checkUsersFinishLogin(t, m, testAuthRefresh, testCredentials.MailboxPassword, testCredentials.UserID, nil)
}
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Mock loading disconnected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
// Mock process of FinishLogin of already added user.
gomock.InOrder(
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUserDisconnected, nil),
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
authRefresh := &pmapi.Auth{
UserID: testCredentialsDisconnected.UserID,
AuthRefresh: pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
RefreshToken: "ref",
},
}
checkUsersFinishLogin(t, m, authRefresh, testCredentials.MailboxPassword, testCredentialsDisconnected.UserID, nil)
}
func TestUsersFinishLoginConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Mock loading connected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockEventLoopNoAction(m)
// Mock process of FinishLogin of already connected user.
gomock.InOrder(
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
)
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
_, err := users.FinishLogin(m.pmapiClient, testAuthRefresh, testCredentials.MailboxPassword)
r.EqualError(t, err, "user is already connected")
}
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) {
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
r.Equal(t, expectedErr, err)
if expectedUserID != "" {
r.Equal(t, expectedUserID, user.ID())
r.Equal(t, 1, len(users.users))
r.Equal(t, expectedUserID, users.users[0].ID())
} else {
r.Equal(t, (*User)(nil), user)
r.Equal(t, 0, len(users.users))
}
}

View File

@ -22,9 +22,10 @@ import (
"testing"
time "time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
)
func TestNewUsersNoKeychain(t *testing.T) {
@ -32,7 +33,6 @@ func TestNewUsersNoKeychain(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain"))
checkUsersNew(t, m, []*credentials.Credentials{})
}
@ -41,108 +41,73 @@ func TestNewUsersWithoutUsersInCredentialsStore(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
checkUsersNew(t, m, []*credentials.Credentials{})
}
func TestNewUsersWithDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
/*
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout()
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
func TestNewUsersWithConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
mockConnectedUser(m)
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
}
func TestNewUsersWithDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
// Tests two users with different states and checks also the order from
// credentials store is kept also in array of users.
func TestNewUsersWithUsers(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
// Set up mocks for store initialisation for the unauth user.
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
mockConnectedUser(m)
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
mockLoadingConnectedUser(m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
}
func TestNewUsersFirstStart(t *testing.T) {
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(nil, nil, errors.New("bad token"))
m.clientManager.EXPECT().NewClient("uid", "", "acc", time.Time{}).Return(m.pmapiClient)
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(false)
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("not authorized"))
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
testNewUsers(t, m)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
*/
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
assert.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
r.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
credentials := []*credentials.Credentials{}
for _, user := range users.users {
credentials = append(credentials, user.creds)
}
assert.Equal(m.t, expectedCredentials, credentials)
r.Equal(m.t, expectedCredentials, credentials)
}

View File

@ -33,8 +33,9 @@ import (
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
r "github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
@ -49,9 +50,12 @@ func TestMain(m *testing.M) {
var (
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
UID: "uid",
AccessToken: "acc",
RefreshToken: "ref",
UserID: "user",
AuthRefresh: pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
RefreshToken: "ref",
},
}
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
@ -81,7 +85,7 @@ var (
}
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "user",
UserID: "userDisconnected",
Name: "username",
Emails: "user@pm.me",
APIToken: "",
@ -94,7 +98,7 @@ var (
}
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "users",
UserID: "usersDisconnected",
Name: "usersname",
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
APIToken: "",
@ -111,17 +115,22 @@ var (
Name: "username",
}
testPMAPIUserDisconnected = &pmapi.User{ //nolint[gochecknoglobals]
ID: "userDisconnected",
Name: "username",
}
testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
ID: "testAddressID",
Type: pmapi.OriginalAddress,
Email: "user@pm.me",
Receive: pmapi.CanReceive,
Receive: true,
}
testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: pmapi.CanReceive, Type: pmapi.OriginalAddress},
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: true, Type: pmapi.OriginalAddress},
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: true, Type: pmapi.AliasAddress},
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: true, Type: pmapi.AliasAddress},
}
testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals]
@ -129,15 +138,6 @@ var (
}
)
func waitForEvents() {
// Wait for goroutine to add listener.
// E.g. calling login to invoke firstsync event. Functions can end sooner than
// goroutines call the listener mock. We need to wait a little bit before the end of
// the test to capture all event calls. This allows us to detect whether there were
// missing calls, or perhaps whether something was called too many times.
time.Sleep(100 * time.Millisecond)
}
type mocks struct {
t *testing.T
@ -146,7 +146,7 @@ type mocks struct {
PanicHandler *usersmocks.MockPanicHandler
credentialsStore *usersmocks.MockCredentialsStorer
storeMaker *usersmocks.MockStoreMaker
eventListener *MockListener
eventListener *usersmocks.MockListener
clientManager *pmapimocks.MockManager
pmapiClient *pmapimocks.MockClient
@ -154,6 +154,48 @@ type mocks struct {
storeCache *store.Cache
}
func initMocks(t *testing.T) mocks {
var mockCtrl *gomock.Controller
if os.Getenv("VERBOSITY") == "trace" {
mockCtrl = gomock.NewController(&fullStackReporter{t})
} else {
mockCtrl = gomock.NewController(t)
}
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
r.NoError(t, err, "could not get temporary file for store cache")
m := mocks{
t: t,
ctrl: mockCtrl,
locator: usersmocks.NewMockLocator(mockCtrl),
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
eventListener: usersmocks.NewMockListener(mockCtrl),
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
}
// Called during clean-up.
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
// Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
r.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
}).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
return m
}
type fullStackReporter struct {
T testing.TB
}
@ -168,86 +210,18 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
fr.T.FailNow()
}
func initMocks(t *testing.T) mocks {
var mockCtrl *gomock.Controller
if os.Getenv("VERBOSITY") == "trace" {
mockCtrl = gomock.NewController(&fullStackReporter{t})
} else {
mockCtrl = gomock.NewController(t)
}
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
require.NoError(t, err, "could not get temporary file for store cache")
m := mocks{
t: t,
ctrl: mockCtrl,
locator: usersmocks.NewMockLocator(mockCtrl),
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
eventListener: NewMockListener(mockCtrl),
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
}
// Called during clean-up.
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
// Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
require.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
}).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
return m
}
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
// Events are asynchronous
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
// Init for user.
m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Init for users.
m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil),
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
)
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(m, testCredentialsSplit)
mockEventLoopNoAction(m)
return testNewUsers(t, m)
}
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
// FIXME(conman): How to handle force upgrade?
// m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
m.eventListener.EXPECT().ProvideChannel(events.UpgradeApplicationEvent)
m.eventListener.EXPECT().ProvideChannel(events.InternetOnEvent)
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
@ -256,38 +230,84 @@ func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
return users
}
func waitForEvents() {
// Wait for goroutine to add listener.
// E.g. calling login to invoke firstsync event. Functions can end sooner than
// goroutines call the listener mock. We need to wait a little bit before the end of
// the test to capture all event calls. This allows us to detect whether there were
// missing calls, or perhaps whether something was called too many times.
time.Sleep(100 * time.Millisecond)
}
func cleanUpUsersData(b *Users) {
for _, user := range b.users {
_ = user.clearStore()
}
}
func TestClearData(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
func mockAddingConnectedUser(m mocks) {
gomock.InOrder(
// Mock of users.FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().Add("user", "username", testAuthRefresh.UID, testAuthRefresh.RefreshToken, testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
// m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
// m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
mockInitConnectedUser(m)
}
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
authRefresh := &pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
RefreshToken: "ref",
}
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
gomock.InOrder(
// Mock of users.loadUsersFromCredentialsStore.
m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, authRefresh, nil),
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
)
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
mockInitConnectedUser(m)
}
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
func mockInitConnectedUser(m mocks) {
// Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
m.locator.EXPECT().Clear()
// Mock of store initialisation.
gomock.InOrder(
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
)
}
require.NoError(t, users.ClearData())
func mockLoadingDisconnectedUser(m mocks, creds *credentials.Credentials) {
gomock.InOrder(
// Mock of users.loadUsersFromCredentialsStore.
m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil),
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
)
waitForEvents()
mockInitDisconnectedUser(m)
}
func mockInitDisconnectedUser(m mocks) {
gomock.InOrder(
// Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
// Mock of store initialisation for the unauthorized user.
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
}
func mockEventLoopNoAction(m mocks) {
@ -297,19 +317,3 @@ func mockEventLoopNoAction(m mocks) {
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
}
func mockConnectedUser(m mocks) {
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
// m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
// Set up mocks for store initialisation for the authorized user.
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
)
}