mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-15 14:56:42 +00:00
GODT-35: New pmapi client and manager using resty
This commit is contained in:
@ -23,7 +23,7 @@
|
||||
// - persistent settings
|
||||
// - event listener
|
||||
// - credentials store
|
||||
// - pmapi ClientManager
|
||||
// - pmapi Manager
|
||||
// In addition, the base initialises logging and reacts to command line arguments
|
||||
// which control the log verbosity and enable cpu/memory profiling.
|
||||
package base
|
||||
@ -85,7 +85,7 @@ type Base struct {
|
||||
Cache *cache.Cache
|
||||
Listener listener.Listener
|
||||
Creds *credentials.Store
|
||||
CM *pmapi.ClientManager
|
||||
CM pmapi.Manager
|
||||
CookieJar *cookies.Jar
|
||||
UserAgent *useragent.UserAgent
|
||||
Updater *updater.Updater
|
||||
@ -181,13 +181,26 @@ func New( // nolint[funlen]
|
||||
kc = keychain.NewMissingKeychain()
|
||||
}
|
||||
|
||||
// FIXME(conman): Customize config depending on build type (app version, host URL).
|
||||
cm := pmapi.New(pmapi.DefaultConfig)
|
||||
|
||||
// FIXME(conman): Should this be a real object, not just created via callbacks?
|
||||
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
|
||||
func() { listener.Emit(events.InternetOffEvent, "") },
|
||||
func() { listener.Emit(events.InternetOnEvent, "") },
|
||||
))
|
||||
|
||||
// FIXME(conman): Implement force upgrade observer.
|
||||
// apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||
|
||||
// FIXME(conman): Set up fancy round tripper with DoH/TLS checks etc.
|
||||
// cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
|
||||
|
||||
jar, err := cookies.NewCookieJar(settingsObj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm := pmapi.NewClientManager(getAPIConfig(configName, listener), userAgent)
|
||||
cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
|
||||
cm.SetCookieJar(jar)
|
||||
|
||||
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
|
||||
@ -375,13 +388,3 @@ func (b *Base) doTeardown() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAPIConfig(configName string, listener listener.Listener) *pmapi.ClientConfig {
|
||||
apiConfig := pmapi.GetAPIConfig(configName, constants.Version)
|
||||
|
||||
apiConfig.ConnectionOffHandler = func() { listener.Emit(events.InternetOffEvent, "") }
|
||||
apiConfig.ConnectionOnHandler = func() { listener.Emit(events.InternetOnEvent, "") }
|
||||
apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||
|
||||
return apiConfig
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
@ -44,7 +45,7 @@ type Bridge struct {
|
||||
|
||||
locations Locator
|
||||
settings SettingsProvider
|
||||
clientManager users.ClientManager
|
||||
clientManager pmapi.Manager
|
||||
updater Updater
|
||||
versioner Versioner
|
||||
}
|
||||
@ -56,7 +57,7 @@ func New(
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler users.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
clientManager users.ClientManager,
|
||||
clientManager pmapi.Manager,
|
||||
credStorer users.CredentialsStorer,
|
||||
updater Updater,
|
||||
versioner Versioner,
|
||||
@ -64,10 +65,11 @@ func New(
|
||||
// Allow DoH before starting the app if the user has previously set this setting.
|
||||
// This allows us to start even if protonmail is blocked.
|
||||
if s.GetBool(settings.AllowProxyKey) {
|
||||
clientManager.AllowProxy()
|
||||
// FIXME(conman): Support enable/disable of DoH.
|
||||
// clientManager.AllowProxy()
|
||||
}
|
||||
|
||||
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, clientManager, eventListener)
|
||||
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
|
||||
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, storeFactory, true)
|
||||
b := &Bridge{
|
||||
Users: u,
|
||||
@ -118,28 +120,15 @@ func (b *Bridge) heartbeat() {
|
||||
|
||||
// ReportBug reports a new bug from the user.
|
||||
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||
c := b.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
title := "[Bridge] Bug"
|
||||
report := pmapi.ReportReq{
|
||||
return b.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Browser: emailClient,
|
||||
Title: title,
|
||||
Title: "[Bridge] Bug",
|
||||
Description: description,
|
||||
Username: accountName,
|
||||
Email: address,
|
||||
}
|
||||
|
||||
if err := c.Report(report); err != nil {
|
||||
log.Error("Reporting bug failed: ", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Bug successfully reported")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetUpdateChannel returns currently set update channel.
|
||||
|
||||
@ -31,7 +31,6 @@ type storeFactory struct {
|
||||
cache Cacher
|
||||
sentryReporter *sentry.Reporter
|
||||
panicHandler users.PanicHandler
|
||||
clientManager users.ClientManager
|
||||
eventListener listener.Listener
|
||||
storeCache *store.Cache
|
||||
}
|
||||
@ -40,14 +39,12 @@ func newStoreFactory(
|
||||
cache Cacher,
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler users.PanicHandler,
|
||||
clientManager users.ClientManager,
|
||||
eventListener listener.Listener,
|
||||
) *storeFactory {
|
||||
return &storeFactory{
|
||||
cache: cache,
|
||||
sentryReporter: sentryReporter,
|
||||
panicHandler: panicHandler,
|
||||
clientManager: clientManager,
|
||||
eventListener: eventListener,
|
||||
storeCache: store.NewCache(cache.GetIMAPCachePath()),
|
||||
}
|
||||
@ -56,7 +53,7 @@ func newStoreFactory(
|
||||
// New creates new store for given user.
|
||||
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
|
||||
storePath := getUserStorePath(f.cache.GetDBDir(), user.ID())
|
||||
return store.New(f.sentryReporter, f.panicHandler, user, f.clientManager, f.eventListener, storePath, f.storeCache)
|
||||
return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache)
|
||||
}
|
||||
|
||||
// Remove removes all store files for given user.
|
||||
|
||||
@ -18,8 +18,10 @@
|
||||
package cliie
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -73,13 +75,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
return
|
||||
}
|
||||
|
||||
if auth.HasTwoFactor() {
|
||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||
if twoFactor == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = client.Auth2FA(twoFactor, auth)
|
||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
@ -87,7 +89,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
}
|
||||
|
||||
mailboxPassword := password
|
||||
if auth.HasMailboxPassword() {
|
||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||
}
|
||||
if mailboxPassword == "" {
|
||||
|
||||
@ -20,7 +20,6 @@ package cliie
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
@ -71,10 +70,13 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
||||
func (f *frontendCLI) processAPIError(err error) {
|
||||
log.Warn("API error: ", err)
|
||||
switch err {
|
||||
case pmapi.ErrAPINotReachable:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
// FIXME(conman): How to handle various API errors?
|
||||
/*
|
||||
case pmapi.ErrNoConnection:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
*/
|
||||
default:
|
||||
f.Println("Server error:", err.Error())
|
||||
}
|
||||
|
||||
@ -18,11 +18,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -120,13 +122,13 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
return
|
||||
}
|
||||
|
||||
if auth.HasTwoFactor() {
|
||||
if auth.TwoFA.Enabled == pmapi.TOTPEnabled {
|
||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||
if twoFactor == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = client.Auth2FA(twoFactor, auth)
|
||||
err = client.Auth2FA(context.TODO(), pmapi.Auth2FAReq{TwoFactorCode: twoFactor})
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
@ -134,7 +136,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
|
||||
}
|
||||
|
||||
mailboxPassword := password
|
||||
if auth.HasMailboxPassword() {
|
||||
if auth.PasswordMode == pmapi.TwoPasswordMode {
|
||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||
}
|
||||
if mailboxPassword == "" {
|
||||
|
||||
@ -20,7 +20,6 @@ package cli
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
@ -71,10 +70,13 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
||||
func (f *frontendCLI) processAPIError(err error) {
|
||||
log.Warn("API error: ", err)
|
||||
switch err {
|
||||
case pmapi.ErrAPINotReachable:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
// FIXME(conman): How to handle various API errors?
|
||||
/*
|
||||
case pmapi.ErrNoConnection:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
*/
|
||||
default:
|
||||
f.Println("Server error:", err.Error())
|
||||
}
|
||||
|
||||
@ -164,7 +164,7 @@ func (a *Accounts) showLoginError(err error, scope string) bool {
|
||||
return false
|
||||
}
|
||||
log.Warnf("%s: %v", scope, err)
|
||||
if err == pmapi.ErrAPINotReachable {
|
||||
if err == pmapi.ErrNoConnection {
|
||||
a.qml.SetConnectionStatus(false)
|
||||
SendNotification(a.qml, TabAccount, a.qml.CanNotReachAPI())
|
||||
a.qml.ProcessFinished()
|
||||
|
||||
@ -130,7 +130,7 @@ func (s *FrontendQt) showLoginError(err error, scope string) bool {
|
||||
return false
|
||||
}
|
||||
log.Warnf("%s: %v", scope, err)
|
||||
if err == pmapi.ErrAPINotReachable {
|
||||
if err == pmapi.ErrNoConnection {
|
||||
s.Qml.SetConnectionStatus(false)
|
||||
s.SendNotification(TabAccount, s.Qml.CanNotReachAPI())
|
||||
s.Qml.ProcessFinished()
|
||||
|
||||
@ -20,6 +20,7 @@ package importexport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/transfer"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users"
|
||||
@ -39,7 +40,7 @@ type ImportExport struct {
|
||||
locations Locator
|
||||
cache Cacher
|
||||
panicHandler users.PanicHandler
|
||||
clientManager users.ClientManager
|
||||
clientManager pmapi.Manager
|
||||
}
|
||||
|
||||
func New(
|
||||
@ -47,7 +48,7 @@ func New(
|
||||
cache Cacher,
|
||||
panicHandler users.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
clientManager users.ClientManager,
|
||||
clientManager pmapi.Manager,
|
||||
credStorer users.CredentialsStorer,
|
||||
) *ImportExport {
|
||||
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, &storeFactory{}, false)
|
||||
@ -64,57 +65,31 @@ func New(
|
||||
|
||||
// ReportBug reports a new bug from the user.
|
||||
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||
c := ie.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
title := "[Import-Export] Bug"
|
||||
report := pmapi.ReportReq{
|
||||
return ie.clientManager.ReportBug(context.TODO(), pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Browser: emailClient,
|
||||
Title: title,
|
||||
Title: "[Import-Export] Bug",
|
||||
Description: description,
|
||||
Username: accountName,
|
||||
Email: address,
|
||||
}
|
||||
|
||||
if err := c.Report(report); err != nil {
|
||||
log.Error("Reporting bug failed: ", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Bug successfully reported")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ReportFile submits import report file.
|
||||
func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address string, logdata []byte) error {
|
||||
c := ie.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
title := "[Import-Export] report file"
|
||||
description := "An Import-Export report from the user swam down the river."
|
||||
|
||||
report := pmapi.ReportReq{
|
||||
report := pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Description: description,
|
||||
Title: title,
|
||||
Description: "An Import-Export report from the user swam down the river.",
|
||||
Title: "[Import-Export] report file",
|
||||
Username: accountName,
|
||||
Email: address,
|
||||
}
|
||||
|
||||
report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
|
||||
|
||||
if err := c.Report(report); err != nil {
|
||||
log.Error("Sending report failed: ", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("Report successfully sent")
|
||||
|
||||
return nil
|
||||
return ie.clientManager.ReportBug(context.TODO(), report)
|
||||
}
|
||||
|
||||
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
|
||||
@ -187,5 +162,5 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
|
||||
log.WithError(err).Info("Address does not exist, using all addresses")
|
||||
}
|
||||
|
||||
return transfer.NewPMAPIProvider(ie.clientManager, user.ID(), addressID)
|
||||
return transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"strings"
|
||||
@ -28,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
type messageGetter interface {
|
||||
GetMessage(string) (*pmapi.Message, error)
|
||||
GetMessage(context.Context, string) (*pmapi.Message, error)
|
||||
}
|
||||
|
||||
type sendRecorderValue struct {
|
||||
@ -126,7 +127,7 @@ func (q *sendRecorder) isSendingOrSent(client messageGetter, hash string) (isSen
|
||||
return true, false
|
||||
}
|
||||
|
||||
message, err := client.GetMessage(value.messageID)
|
||||
message, err := client.GetMessage(context.TODO(), value.messageID)
|
||||
// Message could be deleted or there could be an internet issue or whatever,
|
||||
// so let's assume the message was not sent.
|
||||
if err != nil {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
@ -33,7 +34,7 @@ type testSendRecorderGetMessageMock struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *testSendRecorderGetMessageMock) GetMessage(messageID string) (*pmapi.Message, error) {
|
||||
func (m *testSendRecorderGetMessageMock) GetMessage(_ context.Context, messageID string) (*pmapi.Message, error) {
|
||||
return m.message, m.err
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -122,7 +123,7 @@ func (su *smtpUser) getSendPreferences(
|
||||
}
|
||||
|
||||
func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata, err error) {
|
||||
emails, err := su.client().GetContactEmailByEmail(recipient, 0, 1000)
|
||||
emails, err := su.client().GetContactEmailByEmail(context.TODO(), recipient, 0, 1000)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -134,7 +135,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
|
||||
}
|
||||
|
||||
var contact pmapi.Contact
|
||||
if contact, err = su.client().GetContactByID(email.ContactID); err != nil {
|
||||
if contact, err = su.client().GetContactByID(context.TODO(), email.ContactID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -150,7 +151,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
|
||||
}
|
||||
|
||||
func (su *smtpUser) getAPIKeyData(recipient string) (apiKeys []pmapi.PublicKey, isInternal bool, err error) {
|
||||
return su.client().GetPublicKeysForEmail(recipient)
|
||||
return su.client().GetPublicKeysForEmail(context.TODO(), recipient)
|
||||
}
|
||||
|
||||
// Discard currently processed message.
|
||||
@ -218,7 +219,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
|
||||
|
||||
messageReader = io.TeeReader(messageReader, b)
|
||||
|
||||
mailSettings, err := su.client().GetMailSettings()
|
||||
mailSettings, err := su.client().GetMailSettings(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -339,7 +340,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
|
||||
// can lead to sending the wrong message. Also clients do not necessarily
|
||||
// delete the old draft.
|
||||
if draftID != "" {
|
||||
if err := su.client().DeleteMessages([]string{draftID}); err != nil {
|
||||
if err := su.client().DeleteMessages(context.TODO(), []string{draftID}); err != nil {
|
||||
log.WithError(err).WithField("draftID", draftID).Warn("Original draft cannot be deleted")
|
||||
}
|
||||
}
|
||||
@ -393,7 +394,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
|
||||
return errors.New("error decoding subject message " + message.Header.Get("Subject"))
|
||||
}
|
||||
if !su.continueSendingUnencryptedMail(subject) {
|
||||
if err := su.client().DeleteMessages([]string{message.ID}); err != nil {
|
||||
if err := su.client().DeleteMessages(context.TODO(), []string{message.ID}); err != nil {
|
||||
log.WithError(err).Warn("Failed to delete canceled messages")
|
||||
}
|
||||
return errors.New("sending was canceled by user")
|
||||
@ -422,7 +423,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
|
||||
if su.addressID != "" {
|
||||
filter.AddressID = su.addressID
|
||||
}
|
||||
metadata, _, _ := su.client().ListMessages(filter)
|
||||
metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
|
||||
for _, m := range metadata {
|
||||
if m.IsDraft() {
|
||||
draftID = m.ID
|
||||
@ -442,7 +443,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
|
||||
if su.addressID != "" {
|
||||
filter.AddressID = su.addressID
|
||||
}
|
||||
metadata, _, _ := su.client().ListMessages(filter)
|
||||
metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
|
||||
// There can be two or messages with the same external ID and then we cannot
|
||||
// be sure which message should be parent. Better to not choose any.
|
||||
if len(metadata) == 1 {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
@ -80,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
|
||||
func (loop *eventLoop) setFirstEventID() (err error) {
|
||||
loop.log.Info("Setting first event ID")
|
||||
|
||||
event, err := loop.client().GetEvent("")
|
||||
event, err := loop.client().GetEvent(context.TODO(), "")
|
||||
if err != nil {
|
||||
loop.log.WithError(err).Error("Could not get latest event ID")
|
||||
return
|
||||
@ -221,7 +222,8 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
|
||||
// (e.g. no internet, ulimit reached etc.)
|
||||
defer func() {
|
||||
if errors.Cause(err) == pmapi.ErrAPINotReachable {
|
||||
// FIXME(conman): How to handle errors of different types?
|
||||
if errors.Is(err, pmapi.ErrNoConnection) {
|
||||
l.Warn("Internet unavailable")
|
||||
err = nil
|
||||
}
|
||||
@ -232,18 +234,20 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
err = nil
|
||||
}
|
||||
|
||||
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
||||
l.Warn("Need to upgrade application")
|
||||
err = nil
|
||||
}
|
||||
|
||||
_, errUnauthorized := errors.Cause(err).(*pmapi.ErrUnauthorized)
|
||||
// FIXME(conman): Handle force upgrade.
|
||||
/*
|
||||
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
||||
l.Warn("Need to upgrade application")
|
||||
err = nil
|
||||
}
|
||||
*/
|
||||
|
||||
if err == nil {
|
||||
loop.errCounter = 0
|
||||
}
|
||||
// All errors except Invalid Token (which is not possible to recover from) are ignored.
|
||||
if err != nil && !errUnauthorized && errors.Cause(err) != pmapi.ErrInvalidToken {
|
||||
|
||||
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
|
||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
||||
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
|
||||
loop.errCounter++
|
||||
if loop.errCounter == errMaxSentry {
|
||||
@ -264,7 +268,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
|
||||
loop.pollCounter++
|
||||
|
||||
var event *pmapi.Event
|
||||
if event, err = loop.client().GetEvent(loop.currentEventID); err != nil {
|
||||
if event, err = loop.client().GetEvent(context.TODO(), loop.currentEventID); err != nil {
|
||||
return false, errors.Wrap(err, "failed to get event")
|
||||
}
|
||||
|
||||
@ -461,12 +465,16 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
|
||||
|
||||
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
|
||||
|
||||
if msg, err = loop.client().GetMessage(message.ID); err != nil {
|
||||
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
|
||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||
err = nil
|
||||
continue
|
||||
}
|
||||
if msg, err = loop.client().GetMessage(context.TODO(), message.ID); err != nil {
|
||||
// FIXME(conman): How to handle error of this particular type?
|
||||
|
||||
/*
|
||||
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
|
||||
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
|
||||
err = nil
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
return errors.Wrap(err, "failed to get message from API for updating")
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/mail"
|
||||
"testing"
|
||||
"time"
|
||||
@ -39,15 +40,15 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
|
||||
// Doesn't matter which IDs are used.
|
||||
// This test is trying to see whether event loop will immediately process
|
||||
// next event if there is `More` of them.
|
||||
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
|
||||
EventID: "event50",
|
||||
More: 1,
|
||||
}, nil),
|
||||
m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
|
||||
EventID: "event70",
|
||||
More: 0,
|
||||
}, nil),
|
||||
m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
|
||||
EventID: "event71",
|
||||
More: 0,
|
||||
}, nil),
|
||||
@ -165,7 +166,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
|
||||
|
||||
func testEvent(t *testing.T, m *mocksForStore, event *pmapi.Event) {
|
||||
eventReceived := make(chan struct{})
|
||||
m.client.EXPECT().GetEvent("latestEventID").DoAndReturn(func(eventID string) (*pmapi.Event, error) {
|
||||
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").DoAndReturn(func(_ context.Context, eventID string) (*pmapi.Event, error) {
|
||||
defer close(eventReceived)
|
||||
return event, nil
|
||||
})
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -41,7 +43,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
|
||||
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
|
||||
// wrapping it.
|
||||
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
|
||||
msg, err := storeMailbox.client().GetMessage(apiID)
|
||||
msg, err := storeMailbox.client().GetMessage(context.TODO(), apiID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -58,15 +60,17 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
|
||||
}
|
||||
|
||||
importReqs := &pmapi.ImportMsgReq{
|
||||
AddressID: msg.AddressID,
|
||||
Body: body,
|
||||
Unread: msg.Unread,
|
||||
Flags: msg.Flags,
|
||||
Time: msg.Time,
|
||||
LabelIDs: labelIDs,
|
||||
Metadata: &pmapi.ImportMetadata{
|
||||
AddressID: msg.AddressID,
|
||||
Unread: msg.Unread,
|
||||
Flags: msg.Flags,
|
||||
Time: msg.Time,
|
||||
LabelIDs: labelIDs,
|
||||
},
|
||||
Message: body,
|
||||
}
|
||||
|
||||
res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs})
|
||||
res, err := storeMailbox.client().Import(context.TODO(), pmapi.ImportMsgReqs{importReqs})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -95,7 +99,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
|
||||
return ErrAllMailOpNotAllowed
|
||||
}
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// UnlabelMessages removes the label by calling an API.
|
||||
@ -108,7 +112,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
|
||||
return ErrAllMailOpNotAllowed
|
||||
}
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID)
|
||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// MarkMessagesRead marks the message read by calling an API.
|
||||
@ -135,7 +139,7 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return storeMailbox.client().MarkMessagesRead(ids)
|
||||
return storeMailbox.client().MarkMessagesRead(context.TODO(), ids)
|
||||
}
|
||||
|
||||
// MarkMessagesUnread marks the message unread by calling an API.
|
||||
@ -147,7 +151,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unread")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().MarkMessagesUnread(apiIDs)
|
||||
return storeMailbox.client().MarkMessagesUnread(context.TODO(), apiIDs)
|
||||
}
|
||||
|
||||
// MarkMessagesStarred adds the Starred label by calling an API.
|
||||
@ -160,7 +164,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as starred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().LabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesUnstarred removes the Starred label by calling an API.
|
||||
@ -173,7 +177,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unstarred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
return storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
|
||||
@ -257,11 +261,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
|
||||
}
|
||||
case pmapi.DraftLabel:
|
||||
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
|
||||
if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), apiIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), apiIDs, storeMailbox.labelID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -299,13 +303,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
|
||||
}
|
||||
}
|
||||
if len(messageIDsToUnlabel) > 0 {
|
||||
if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
if err := storeMailbox.client().UnlabelMessages(context.TODO(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
l.WithError(err).Warning("Cannot unlabel before deleting")
|
||||
}
|
||||
}
|
||||
if len(messageIDsToDelete) > 0 {
|
||||
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
|
||||
if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil {
|
||||
if err := storeMailbox.client().DeleteMessages(context.TODO(), messageIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser,ChangeNotifier)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -46,43 +46,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockBridgeUser is a mock of BridgeUser interface
|
||||
type MockBridgeUser struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -145,6 +108,20 @@ func (mr *MockBridgeUserMockRecorder) GetAddressID(arg0 interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddressID", reflect.TypeOf((*MockBridgeUser)(nil).GetAddressID), arg0)
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockBridgeUser) GetClient() pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient")
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockBridgeUserMockRecorder) GetClient() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockBridgeUser)(nil).GetClient))
|
||||
}
|
||||
|
||||
// GetPrimaryAddress mocks base method
|
||||
func (m *MockBridgeUser) GetPrimaryAddress() string {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@ -106,7 +107,6 @@ type Store struct {
|
||||
panicHandler PanicHandler
|
||||
eventLoop *eventLoop
|
||||
user BridgeUser
|
||||
clientManager ClientManager
|
||||
|
||||
log *logrus.Entry
|
||||
|
||||
@ -127,13 +127,12 @@ func New( // nolint[funlen]
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler PanicHandler,
|
||||
user BridgeUser,
|
||||
clientManager ClientManager,
|
||||
events listener.Listener,
|
||||
path string,
|
||||
cache *Cache,
|
||||
) (store *Store, err error) {
|
||||
if user == nil || clientManager == nil || events == nil || cache == nil {
|
||||
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache)
|
||||
if user == nil || events == nil || cache == nil {
|
||||
return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache)
|
||||
}
|
||||
|
||||
l := log.WithField("user", user.ID())
|
||||
@ -156,7 +155,6 @@ func New( // nolint[funlen]
|
||||
store = &Store{
|
||||
sentryReporter: sentryReporter,
|
||||
panicHandler: panicHandler,
|
||||
clientManager: clientManager,
|
||||
user: user,
|
||||
cache: cache,
|
||||
filePath: path,
|
||||
@ -274,13 +272,13 @@ func (store *Store) init(firstInit bool) (err error) {
|
||||
}
|
||||
|
||||
func (store *Store) client() pmapi.Client {
|
||||
return store.clientManager.GetClient(store.UserID())
|
||||
return store.user.GetClient()
|
||||
}
|
||||
|
||||
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
|
||||
// the API is unavailable for whatever reason it tries to fetch the labels locally.
|
||||
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
|
||||
if labels, err = store.client().ListLabels(); err != nil {
|
||||
if labels, err = store.client().ListLabels(context.TODO()); err != nil {
|
||||
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
|
||||
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
|
||||
store.log.WithError(err).Error("Cannot list local labels")
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@ -133,7 +134,6 @@ type mocksForStore struct {
|
||||
events *storemocks.MockListener
|
||||
user *storemocks.MockBridgeUser
|
||||
client *pmapimocks.MockClient
|
||||
clientManager *storemocks.MockClientManager
|
||||
panicHandler *storemocks.MockPanicHandler
|
||||
changeNotifier *storemocks.MockChangeNotifier
|
||||
store *Store
|
||||
@ -150,7 +150,6 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
|
||||
events: storemocks.NewMockListener(ctrl),
|
||||
user: storemocks.NewMockBridgeUser(ctrl),
|
||||
client: pmapimocks.NewMockClient(ctrl),
|
||||
clientManager: storemocks.NewMockClientManager(ctrl),
|
||||
panicHandler: storemocks.NewMockPanicHandler(ctrl),
|
||||
changeNotifier: storemocks.NewMockChangeNotifier(ctrl),
|
||||
}
|
||||
@ -182,30 +181,30 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
|
||||
mocks.user.EXPECT().IsConnected().Return(true)
|
||||
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
|
||||
|
||||
mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client)
|
||||
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
|
||||
|
||||
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
|
||||
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
|
||||
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
|
||||
})
|
||||
mocks.client.EXPECT().ListLabels().AnyTimes()
|
||||
mocks.client.EXPECT().CountMessages("")
|
||||
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
|
||||
mocks.client.EXPECT().CountMessages(gomock.Any(), "")
|
||||
|
||||
// Call to get latest event ID and then to process first event.
|
||||
eventAfterSyncRequested := make(chan struct{})
|
||||
mocks.client.EXPECT().GetEvent("").Return(&pmapi.Event{
|
||||
mocks.client.EXPECT().GetEvent(gomock.Any(), "").Return(&pmapi.Event{
|
||||
EventID: "firstEventID",
|
||||
}, nil)
|
||||
mocks.client.EXPECT().GetEvent("firstEventID").DoAndReturn(func(_ string) (*pmapi.Event, error) {
|
||||
mocks.client.EXPECT().GetEvent(gomock.Any(), "firstEventID").DoAndReturn(func(_ context.Context, _ string) (*pmapi.Event, error) {
|
||||
close(eventAfterSyncRequested)
|
||||
return &pmapi.Event{
|
||||
EventID: "latestEventID",
|
||||
}, nil
|
||||
})
|
||||
|
||||
mocks.client.EXPECT().ListMessages(gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes()
|
||||
mocks.client.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes()
|
||||
for _, msg := range msgs {
|
||||
mocks.client.EXPECT().GetMessage(msg.ID).Return(msg, nil).AnyTimes()
|
||||
mocks.client.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
}
|
||||
|
||||
var err error
|
||||
@ -213,7 +212,6 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
|
||||
nil, // Sentry reporter is not used under unit tests.
|
||||
mocks.panicHandler,
|
||||
mocks.user,
|
||||
mocks.clientManager,
|
||||
mocks.events,
|
||||
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
|
||||
mocks.cache,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
@ -39,10 +40,10 @@ type storeSynchronizer interface {
|
||||
}
|
||||
|
||||
type messageLister interface {
|
||||
ListMessages(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
|
||||
ListMessages(context.Context, *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
|
||||
}
|
||||
|
||||
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func() messageLister, syncState *syncState) error {
|
||||
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api messageLister, syncState *syncState) error {
|
||||
labelID := pmapi.AllMailLabel
|
||||
|
||||
// When the full sync starts (i.e. is not already in progress), we need to load
|
||||
@ -53,7 +54,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
|
||||
return errors.Wrap(err, "failed to load message IDs")
|
||||
}
|
||||
|
||||
if err := findIDRanges(labelID, api(), syncState); err != nil {
|
||||
if err := findIDRanges(labelID, api, syncState); err != nil {
|
||||
return errors.Wrap(err, "failed to load IDs ranges")
|
||||
}
|
||||
syncState.save()
|
||||
@ -71,7 +72,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
|
||||
defer panicHandler.HandlePanic()
|
||||
defer wg.Done()
|
||||
|
||||
err := syncBatch(labelID, store, api(), syncState, idRange, &shouldStop)
|
||||
err := syncBatch(labelID, store, api, syncState, idRange, &shouldStop)
|
||||
if err != nil {
|
||||
shouldStop = 1
|
||||
resultError = errors.Wrap(err, "failed to sync group")
|
||||
@ -147,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
|
||||
Limit: 1,
|
||||
}
|
||||
// If the page does not exist, an empty page instead of an error is returned.
|
||||
messages, total, err := api.ListMessages(filter)
|
||||
messages, total, err := api.ListMessages(context.TODO(), filter)
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
@ -189,7 +190,7 @@ func syncBatch( //nolint[funlen]
|
||||
|
||||
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
|
||||
|
||||
messages, _, err := api.ListMessages(filter)
|
||||
messages, _, err := api.ListMessages(context.TODO(), filter)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
@ -34,7 +35,7 @@ type mockLister struct {
|
||||
messageIDs []string
|
||||
}
|
||||
|
||||
func (m *mockLister) ListMessages(filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
|
||||
func (m *mockLister) ListMessages(_ context.Context, filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
|
||||
if m.err != nil {
|
||||
return nil, 0, m.err
|
||||
}
|
||||
@ -197,7 +198,7 @@ func TestSyncAllMail(t *testing.T) { //nolint[funlen]
|
||||
|
||||
syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Check all messages were created or updated.
|
||||
@ -245,7 +246,7 @@ func TestSyncAllMail_FailedListing(t *testing.T) {
|
||||
}
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.EqualError(t, err, "failed to sync group: failed to list messages: error")
|
||||
}
|
||||
|
||||
@ -264,7 +265,7 @@ func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) {
|
||||
}
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.EqualError(t, err, "failed to sync group: failed to create or update messages: error")
|
||||
}
|
||||
|
||||
|
||||
@ -23,10 +23,6 @@ type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
}
|
||||
|
||||
// BridgeUser is subset of bridge.User for use by the Store.
|
||||
type BridgeUser interface {
|
||||
ID() string
|
||||
@ -35,6 +31,7 @@ type BridgeUser interface {
|
||||
IsCombinedAddressMode() bool
|
||||
GetPrimaryAddress() string
|
||||
GetStoreAddresses() []string
|
||||
GetClient() pmapi.Client
|
||||
UpdateUser() error
|
||||
CloseAllConnections()
|
||||
CloseConnection(string)
|
||||
|
||||
@ -17,6 +17,8 @@
|
||||
|
||||
package store
|
||||
|
||||
import "context"
|
||||
|
||||
// UserID returns user ID.
|
||||
func (store *Store) UserID() string {
|
||||
return store.user.ID()
|
||||
@ -24,7 +26,7 @@ func (store *Store) UserID() string {
|
||||
|
||||
// GetSpace returns used and total space in bytes.
|
||||
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
apiUser, err := store.client().CurrentUser()
|
||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
@ -33,7 +35,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
|
||||
// GetMaxUpload returns max size of message + all attachments in bytes.
|
||||
func (store *Store) GetMaxUpload() (int64, error) {
|
||||
apiUser, err := store.client().CurrentUser()
|
||||
apiUser, err := store.client().CurrentUser(context.TODO())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -55,7 +56,7 @@ func (store *Store) createMailbox(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := store.client().CreateLabel(&pmapi.Label{
|
||||
_, err := store.client().CreateLabel(context.TODO(), &pmapi.Label{
|
||||
Name: name,
|
||||
Color: color,
|
||||
Exclusive: exclusive,
|
||||
@ -125,7 +126,7 @@ func (store *Store) leastUsedColor() string {
|
||||
func (store *Store) updateMailbox(labelID, newName, color string) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
_, err := store.client().UpdateLabel(&pmapi.Label{
|
||||
_, err := store.client().UpdateLabel(context.TODO(), &pmapi.Label{
|
||||
ID: labelID,
|
||||
Name: newName,
|
||||
Color: color,
|
||||
@ -142,15 +143,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
|
||||
var err error
|
||||
switch labelID {
|
||||
case pmapi.SpamLabel:
|
||||
err = store.client().EmptyFolder(pmapi.SpamLabel, addressID)
|
||||
err = store.client().EmptyFolder(context.TODO(), pmapi.SpamLabel, addressID)
|
||||
case pmapi.TrashLabel:
|
||||
err = store.client().EmptyFolder(pmapi.TrashLabel, addressID)
|
||||
err = store.client().EmptyFolder(context.TODO(), pmapi.TrashLabel, addressID)
|
||||
default:
|
||||
err = fmt.Errorf("cannot empty mailbox %v", labelID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return store.client().DeleteLabel(labelID)
|
||||
return store.client().DeleteLabel(context.TODO(), labelID)
|
||||
}
|
||||
|
||||
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
|
||||
@ -165,7 +166,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
labels, err := store.client().ListLabels()
|
||||
labels, err := store.client().ListLabels(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -57,7 +58,7 @@ func (store *Store) CreateDraft(
|
||||
}
|
||||
|
||||
draftAction := store.getDraftAction(message)
|
||||
draft, err := store.client().CreateDraft(message, parentID, draftAction)
|
||||
draft, err := store.client().CreateDraft(context.TODO(), message, parentID, draftAction)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create draft")
|
||||
}
|
||||
@ -69,7 +70,7 @@ func (store *Store) CreateDraft(
|
||||
for _, att := range attachments {
|
||||
att.attachment.MessageID = draft.ID
|
||||
|
||||
createdAttachment, err := store.client().CreateAttachment(att.attachment, att.encReader, att.sigReader)
|
||||
createdAttachment, err := store.client().CreateAttachment(context.TODO(), att.attachment, att.encReader, att.sigReader)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create attachment")
|
||||
}
|
||||
@ -183,7 +184,7 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
|
||||
// SendMessage sends the message.
|
||||
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
_, _, err := store.client().SendMessage(messageID, req)
|
||||
_, _, err := store.client().SendMessage(context.TODO(), messageID, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -127,12 +127,12 @@ func TestDeleteMessage(t *testing.T) {
|
||||
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
|
||||
}
|
||||
|
||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread int, labelIDs []string) { //nolint[unparam]
|
||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread pmapi.Boolean, labelIDs []string) { //nolint[unparam]
|
||||
msg := getTestMessage(id, subject, sender, unread, labelIDs)
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
|
||||
}
|
||||
|
||||
func getTestMessage(id, subject, sender string, unread int, labelIDs []string) *pmapi.Message {
|
||||
func getTestMessage(id, subject, sender string, unread pmapi.Boolean, labelIDs []string) *pmapi.Message {
|
||||
address := &mail.Address{Address: sender}
|
||||
return &pmapi.Message{
|
||||
ID: id,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@ -34,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
|
||||
|
||||
// updateCountsFromServer will download and set the counts.
|
||||
func (store *Store) updateCountsFromServer() error {
|
||||
counts, err := store.client().CountMessages("")
|
||||
counts, err := store.client().CountMessages(context.TODO(), "")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot update counts from server")
|
||||
}
|
||||
@ -152,7 +153,7 @@ func (store *Store) triggerSync() {
|
||||
|
||||
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
|
||||
|
||||
err := syncAllMail(store.panicHandler, store, func() messageLister { return store.client() }, syncState)
|
||||
err := syncAllMail(store.panicHandler, store, store.client(), syncState)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Store sync failed")
|
||||
store.syncCooldown.increaseWaitTime()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,ClientManager,IMAPClientProvider)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,IMAPClientProvider)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -7,7 +7,6 @@ package mocks
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
imap "github.com/emersion/go-imap"
|
||||
sasl "github.com/emersion/go-sasl"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@ -48,57 +47,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CheckConnection mocks base method
|
||||
func (m *MockClientManager) CheckConnection() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CheckConnection")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CheckConnection indicates an expected call of CheckConnection
|
||||
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockIMAPClientProvider is a mock of IMAPClientProvider interface
|
||||
type MockIMAPClientProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
|
||||
imapID "github.com/ProtonMail/go-imap-id"
|
||||
"github.com/ProtonMail/proton-bridge/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/emersion/go-imap"
|
||||
imapClient "github.com/emersion/go-imap/client"
|
||||
"github.com/emersion/go-sasl"
|
||||
@ -118,15 +117,19 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
|
||||
return previousErr
|
||||
}
|
||||
|
||||
err := pmapi.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
// FIXME(conman): This should register as connection observer.
|
||||
|
||||
err = p.reauth()
|
||||
/*
|
||||
err := pmapi.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
err := p.reauth()
|
||||
log.WithError(err).Debug("Reauth")
|
||||
if err != nil {
|
||||
time.Sleep(imapReconnectSleep)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
@ -34,25 +35,27 @@ const (
|
||||
|
||||
// PMAPIProvider implements import and export to/from ProtonMail server.
|
||||
type PMAPIProvider struct {
|
||||
clientManager ClientManager
|
||||
userID string
|
||||
addressID string
|
||||
keyRing *crypto.KeyRing
|
||||
builder *message.Builder
|
||||
client pmapi.Client
|
||||
userID string
|
||||
addressID string
|
||||
keyRing *crypto.KeyRing
|
||||
builder *message.Builder
|
||||
|
||||
nextImportRequests map[string]*pmapi.ImportMsgReq // Key is msg transfer ID.
|
||||
nextImportRequestsSize int
|
||||
|
||||
timeIt *timeIt
|
||||
|
||||
connection bool
|
||||
}
|
||||
|
||||
// NewPMAPIProvider returns new PMAPIProvider.
|
||||
func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*PMAPIProvider, error) {
|
||||
func NewPMAPIProvider(client pmapi.Client, userID, addressID string) (*PMAPIProvider, error) {
|
||||
provider := &PMAPIProvider{
|
||||
clientManager: clientManager,
|
||||
userID: userID,
|
||||
addressID: addressID,
|
||||
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
|
||||
client: client,
|
||||
userID: userID,
|
||||
addressID: addressID,
|
||||
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
|
||||
|
||||
nextImportRequests: map[string]*pmapi.ImportMsgReq{},
|
||||
nextImportRequestsSize: 0,
|
||||
@ -61,7 +64,7 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
|
||||
}
|
||||
|
||||
if addressID != "" {
|
||||
keyRing, err := clientManager.GetClient(userID).KeyRingForAddressID(addressID)
|
||||
keyRing, err := client.KeyRingForAddressID(addressID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get key ring")
|
||||
}
|
||||
@ -71,10 +74,6 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (p *PMAPIProvider) client() pmapi.Client {
|
||||
return p.clientManager.GetClient(p.userID)
|
||||
}
|
||||
|
||||
// ID returns identifier of current setup of PMAPI provider.
|
||||
// Identification is unique per user.
|
||||
func (p *PMAPIProvider) ID() string {
|
||||
@ -83,7 +82,7 @@ func (p *PMAPIProvider) ID() string {
|
||||
|
||||
// Mailboxes returns all available labels in ProtonMail account.
|
||||
func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox, error) {
|
||||
labels, err := p.client().ListLabels()
|
||||
labels, err := p.client.ListLabels(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -92,7 +91,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
|
||||
|
||||
emptyLabelsMap := map[string]bool{}
|
||||
if !includeEmpty {
|
||||
messagesCounts, err := p.client().CountMessages(p.addressID)
|
||||
messagesCounts, err := p.client.CountMessages(context.Background(), p.addressID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -120,7 +119,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
|
||||
ID: label.ID,
|
||||
Name: label.Name,
|
||||
Color: label.Color,
|
||||
IsExclusive: label.Exclusive == 1,
|
||||
IsExclusive: bool(label.Exclusive),
|
||||
})
|
||||
}
|
||||
return mailboxes, nil
|
||||
@ -160,10 +159,10 @@ func (l byFoldersLabels) Swap(i, j int) {
|
||||
|
||||
// Less sorts first folders, then labels, by user order.
|
||||
func (l byFoldersLabels) Less(i, j int) bool {
|
||||
if l[i].Exclusive == 1 && l[j].Exclusive == 0 {
|
||||
if l[i].Exclusive && !l[j].Exclusive {
|
||||
return true
|
||||
}
|
||||
if l[i].Exclusive == 0 && l[j].Exclusive == 1 {
|
||||
if !l[i].Exclusive && l[j].Exclusive {
|
||||
return false
|
||||
}
|
||||
return l[i].Order < l[j].Order
|
||||
|
||||
@ -157,7 +157,7 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
|
||||
|
||||
body, err := p.builder.NewJobWithOptions(
|
||||
context.Background(),
|
||||
p.client(),
|
||||
p.client,
|
||||
msg.ID,
|
||||
message.JobOptions{IgnoreDecryptionErrors: !skipEncryptedMessages},
|
||||
).GetResult()
|
||||
@ -169,14 +169,9 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
|
||||
return Message{Body: []byte(msg.Body)}, err
|
||||
}
|
||||
|
||||
unread := false
|
||||
if msg.Unread == 1 {
|
||||
unread = true
|
||||
}
|
||||
|
||||
return Message{
|
||||
ID: msgID,
|
||||
Unread: unread,
|
||||
Unread: bool(msg.Unread),
|
||||
Body: body,
|
||||
Sources: []Mailbox{rule.SourceMailbox},
|
||||
Targets: rule.TargetMailboxes,
|
||||
|
||||
@ -19,6 +19,7 @@ package transfer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -56,7 +57,7 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
|
||||
exclusive = 1
|
||||
}
|
||||
|
||||
label, err := p.client().CreateLabel(&pmapi.Label{
|
||||
label, err := p.client.CreateLabel(context.TODO(), &pmapi.Label{
|
||||
Name: mailbox.Name,
|
||||
Color: mailbox.Color,
|
||||
Exclusive: exclusive,
|
||||
@ -194,7 +195,7 @@ func (p *PMAPIProvider) transferMessage(rules transferRules, progress *Progress,
|
||||
return
|
||||
}
|
||||
|
||||
importMsgReqSize := len(importMsgReq.Body)
|
||||
importMsgReqSize := len(importMsgReq.Message)
|
||||
if p.nextImportRequestsSize+importMsgReqSize > pmapiImportBatchMaxSize || len(p.nextImportRequests) == pmapiImportBatchMaxItems {
|
||||
preparedImportRequestsCh <- p.nextImportRequests
|
||||
p.nextImportRequests = map[string]*pmapi.ImportMsgReq{}
|
||||
@ -226,9 +227,12 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
||||
}
|
||||
}
|
||||
|
||||
unread := 0
|
||||
var unread pmapi.Boolean
|
||||
|
||||
if msg.Unread {
|
||||
unread = 1
|
||||
unread = pmapi.True
|
||||
} else {
|
||||
unread = pmapi.False
|
||||
}
|
||||
|
||||
labelIDs := []string{}
|
||||
@ -243,12 +247,14 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
|
||||
}
|
||||
|
||||
return &pmapi.ImportMsgReq{
|
||||
AddressID: p.addressID,
|
||||
Body: body,
|
||||
Unread: unread,
|
||||
Time: message.Time,
|
||||
Flags: computeMessageFlags(message.Header),
|
||||
LabelIDs: labelIDs,
|
||||
Metadata: &pmapi.ImportMetadata{
|
||||
AddressID: p.addressID,
|
||||
Unread: unread,
|
||||
Time: message.Time,
|
||||
Flags: computeMessageFlags(message.Header),
|
||||
LabelIDs: labelIDs,
|
||||
},
|
||||
Message: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -293,7 +299,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
|
||||
}
|
||||
|
||||
importMsgIDs := []string{}
|
||||
importMsgRequests := []*pmapi.ImportMsgReq{}
|
||||
importMsgRequests := pmapi.ImportMsgReqs{}
|
||||
for msgID, req := range importRequests {
|
||||
importMsgIDs = append(importMsgIDs, msgID)
|
||||
importMsgRequests = append(importMsgRequests, req)
|
||||
@ -327,7 +333,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
|
||||
|
||||
func (p *PMAPIProvider) importMessage(msgSourceID string, progress *Progress, req *pmapi.ImportMsgReq) (importedID string, importedErr error) {
|
||||
progress.callWrap(func() error {
|
||||
results, err := p.importRequest(msgSourceID, []*pmapi.ImportMsgReq{req})
|
||||
results, err := p.importRequest(msgSourceID, pmapi.ImportMsgReqs{req})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to import messages")
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package transfer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@ -33,7 +34,7 @@ func TestPMAPIProviderMailboxes(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForExport(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@ -78,7 +79,7 @@ func TestPMAPIProviderTransferTo(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForExport(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -96,7 +97,7 @@ func TestPMAPIProviderTransferFrom(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForImport(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -114,7 +115,7 @@ func TestPMAPIProviderTransferFromDraft(t *testing.T) {
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
setupPMAPIClientExpectationForImportDraft(&m)
|
||||
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -133,9 +134,9 @@ func TestPMAPIProviderTransferFromTo(t *testing.T) {
|
||||
setupPMAPIClientExpectationForExport(&m)
|
||||
setupPMAPIClientExpectationForImport(&m)
|
||||
|
||||
source, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
source, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
target, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
|
||||
target, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
|
||||
r.NoError(t, err)
|
||||
|
||||
rules, rulesClose := newTestRules(t)
|
||||
@ -151,22 +152,22 @@ func setupPMAPIRules(rules transferRules) {
|
||||
|
||||
func setupPMAPIClientExpectationForExport(m *mocks) {
|
||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
|
||||
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2},
|
||||
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1},
|
||||
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1},
|
||||
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2},
|
||||
}, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any()).Return([]*pmapi.MessagesCount{
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
|
||||
{LabelID: "label1", Total: 10},
|
||||
{LabelID: "label2", Total: 0},
|
||||
{LabelID: "folder1", Total: 20},
|
||||
}, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{
|
||||
{ID: "msg1"},
|
||||
{ID: "msg2"},
|
||||
}, 2, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetMessage(gomock.Any()).DoAndReturn(func(msgID string) (*pmapi.Message, error) {
|
||||
m.pmapiClient.EXPECT().GetMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msgID string) (*pmapi.Message, error) {
|
||||
return &pmapi.Message{
|
||||
ID: msgID,
|
||||
Body: string(getTestMsgBody(msgID)),
|
||||
@ -177,11 +178,11 @@ func setupPMAPIClientExpectationForExport(m *mocks) {
|
||||
|
||||
func setupPMAPIClientExpectationForImport(m *mocks) {
|
||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().Import(gomock.Any()).DoAndReturn(func(requests []*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error) {
|
||||
m.pmapiClient.EXPECT().Import(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, requests pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
|
||||
results := []*pmapi.ImportMsgRes{}
|
||||
for _, request := range requests {
|
||||
for _, msgID := range []string{"msg1", "msg2"} {
|
||||
if bytes.Contains(request.Body, []byte(msgID)) {
|
||||
if bytes.Contains(request.Message, []byte(msgID)) {
|
||||
results = append(results, &pmapi.ImportMsgRes{MessageID: msgID, Error: nil})
|
||||
}
|
||||
}
|
||||
@ -192,7 +193,7 @@ func setupPMAPIClientExpectationForImport(m *mocks) {
|
||||
|
||||
func setupPMAPIClientExpectationForImportDraft(m *mocks) {
|
||||
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
|
||||
m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
|
||||
r.Equal(m.t, msg.Subject, "draft1")
|
||||
msg.ID = "draft1"
|
||||
return msg, nil
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
@ -57,13 +58,18 @@ func (p *PMAPIProvider) tryReconnect() error {
|
||||
return previousErr
|
||||
}
|
||||
|
||||
err := p.clientManager.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(pmapiReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
// FIXME(conman): This should register as a connection observer somehow...
|
||||
// Maybe the entire "provider" could register as an observer and pause if it is notified of dropped connection?
|
||||
|
||||
/*
|
||||
err := p.clientManager.CheckConnection()
|
||||
log.WithError(err).Debug("Connection check")
|
||||
if err != nil {
|
||||
time.Sleep(pmapiReconnectSleep)
|
||||
previousErr = err
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
break
|
||||
}
|
||||
@ -77,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
|
||||
p.timeIt.start("listing", key)
|
||||
defer p.timeIt.stop("listing", key)
|
||||
|
||||
messages, count, err = p.client().ListMessages(filter)
|
||||
messages, count, err = p.client.ListMessages(context.TODO(), filter)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -88,18 +94,18 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
|
||||
p.timeIt.start("download", msgID)
|
||||
defer p.timeIt.stop("download", msgID)
|
||||
|
||||
message, err = p.client().GetMessage(msgID)
|
||||
message, err = p.client.GetMessage(context.TODO(), msgID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (p *PMAPIProvider) importRequest(msgSourceID string, req []*pmapi.ImportMsgReq) (res []*pmapi.ImportMsgRes, err error) {
|
||||
func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReqs) (res []*pmapi.ImportMsgRes, err error) {
|
||||
err = p.ensureConnection(func() error {
|
||||
p.timeIt.start("upload", msgSourceID)
|
||||
defer p.timeIt.stop("upload", msgSourceID)
|
||||
|
||||
res, err = p.client().Import(req)
|
||||
res, err = p.client.Import(context.TODO(), req)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -110,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
|
||||
p.timeIt.start("upload", msgSourceID)
|
||||
defer p.timeIt.stop("upload", msgSourceID)
|
||||
|
||||
draft, err = p.client().CreateDraft(message, parent, action)
|
||||
draft, err = p.client.CreateDraft(context.TODO(), message, parent, action)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -123,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
|
||||
p.timeIt.start("upload", key)
|
||||
defer p.timeIt.stop("upload", key)
|
||||
|
||||
created, err = p.client().CreateAttachment(att, r, sig)
|
||||
created, err = p.client.CreateAttachment(context.TODO(), att, r, sig)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
||||
@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
transfermocks "github.com/ProtonMail/proton-bridge/internal/transfer/mocks"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
@ -33,10 +32,8 @@ type mocks struct {
|
||||
|
||||
ctrl *gomock.Controller
|
||||
panicHandler *transfermocks.MockPanicHandler
|
||||
clientManager *transfermocks.MockClientManager
|
||||
imapClientProvider *transfermocks.MockIMAPClientProvider
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
pmapiConfig *pmapi.ClientConfig
|
||||
|
||||
keyring *crypto.KeyRing
|
||||
}
|
||||
@ -49,15 +46,11 @@ func initMocks(t *testing.T) mocks {
|
||||
|
||||
ctrl: mockCtrl,
|
||||
panicHandler: transfermocks.NewMockPanicHandler(mockCtrl),
|
||||
clientManager: transfermocks.NewMockClientManager(mockCtrl),
|
||||
imapClientProvider: transfermocks.NewMockIMAPClientProvider(mockCtrl),
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
pmapiConfig: &pmapi.ClientConfig{},
|
||||
keyring: newTestKeyring(),
|
||||
}
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).AnyTimes()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
|
||||
@ -17,10 +17,6 @@
|
||||
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
@ -32,8 +28,3 @@ type MetricsManager interface {
|
||||
Cancel()
|
||||
Fail()
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
CheckConnection() error
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
@ -31,10 +32,6 @@ import (
|
||||
|
||||
var ErrManualUpdateRequired = errors.New("manual update is required")
|
||||
|
||||
type ClientProvider interface {
|
||||
GetAnonymousClient() pmapi.Client
|
||||
}
|
||||
|
||||
type Installer interface {
|
||||
InstallUpdate(*semver.Version, io.Reader) error
|
||||
}
|
||||
@ -46,7 +43,7 @@ type Settings interface {
|
||||
}
|
||||
|
||||
type Updater struct {
|
||||
cm ClientProvider
|
||||
cm pmapi.Manager
|
||||
installer Installer
|
||||
settings Settings
|
||||
kr *crypto.KeyRing
|
||||
@ -59,7 +56,7 @@ type Updater struct {
|
||||
}
|
||||
|
||||
func New(
|
||||
cm ClientProvider,
|
||||
cm pmapi.Manager,
|
||||
installer Installer,
|
||||
s Settings,
|
||||
kr *crypto.KeyRing,
|
||||
@ -87,13 +84,10 @@ func New(
|
||||
func (u *Updater) Check() (VersionInfo, error) {
|
||||
logrus.Info("Checking for updates")
|
||||
|
||||
client := u.cm.GetAnonymousClient()
|
||||
defer client.Logout()
|
||||
|
||||
r, err := client.DownloadAndVerify(
|
||||
b, err := u.cm.DownloadAndVerify(
|
||||
u.kr,
|
||||
u.getVersionFileURL(),
|
||||
u.getVersionFileURL()+".sig",
|
||||
u.kr,
|
||||
)
|
||||
if err != nil {
|
||||
return VersionInfo{}, err
|
||||
@ -101,7 +95,7 @@ func (u *Updater) Check() (VersionInfo, error) {
|
||||
|
||||
var versionMap VersionMap
|
||||
|
||||
if err := json.NewDecoder(r).Decode(&versionMap); err != nil {
|
||||
if err := json.Unmarshal(b, &versionMap); err != nil {
|
||||
return VersionInfo{}, err
|
||||
}
|
||||
|
||||
@ -141,15 +135,12 @@ func (u *Updater) InstallUpdate(update VersionInfo) error {
|
||||
return u.locker.doOnce(func() error {
|
||||
logrus.WithField("package", update.Package).Info("Installing update package")
|
||||
|
||||
client := u.cm.GetAnonymousClient()
|
||||
defer client.Logout()
|
||||
|
||||
r, err := client.DownloadAndVerify(update.Package, update.Package+".sig", u.kr)
|
||||
b, err := u.cm.DownloadAndVerify(u.kr, update.Package, update.Package+".sig")
|
||||
if err != nil {
|
||||
return errors.Wrap(ErrDownloadVerify, err.Error())
|
||||
}
|
||||
|
||||
if err := u.installer.InstallUpdate(update.Version, r); err != nil {
|
||||
if err := u.installer.InstallUpdate(update.Version, bytes.NewReader(b)); err != nil {
|
||||
return errors.Wrap(ErrInstall, err.Error())
|
||||
}
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
@ -29,7 +28,6 @@ import (
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -40,9 +38,9 @@ func TestCheck(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.1.0", false)
|
||||
updater := newTestUpdater(cm, "1.1.0", false)
|
||||
|
||||
versionMap := VersionMap{
|
||||
"stable": VersionInfo{
|
||||
@ -53,13 +51,11 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
updater.getVersionFileURL(),
|
||||
updater.getVersionFileURL()+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return(mustMarshal(t, versionMap), nil)
|
||||
|
||||
version, err := updater.Check()
|
||||
|
||||
@ -71,9 +67,9 @@ func TestCheckEarlyAccess(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.1.0", true)
|
||||
updater := newTestUpdater(cm, "1.1.0", true)
|
||||
|
||||
versionMap := VersionMap{
|
||||
"stable": VersionInfo{
|
||||
@ -90,13 +86,11 @@ func TestCheckEarlyAccess(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
updater.getVersionFileURL(),
|
||||
updater.getVersionFileURL()+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return(mustMarshal(t, versionMap), nil)
|
||||
|
||||
version, err := updater.Check()
|
||||
|
||||
@ -108,18 +102,16 @@ func TestCheckBadSignature(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.2.0", false)
|
||||
updater := newTestUpdater(cm, "1.2.0", false)
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
updater.getVersionFileURL(),
|
||||
updater.getVersionFileURL()+".sig",
|
||||
gomock.Any(),
|
||||
).Return(nil, errors.New("bad signature"))
|
||||
|
||||
client.EXPECT().Logout()
|
||||
|
||||
_, err := updater.Check()
|
||||
|
||||
assert.Error(t, err)
|
||||
@ -129,9 +121,9 @@ func TestIsUpdateApplicable(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
versionOld := VersionInfo{
|
||||
Version: semver.MustParse("1.3.0"),
|
||||
@ -165,9 +157,9 @@ func TestCanInstall(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
versionManual := VersionInfo{
|
||||
Version: semver.MustParse("1.5.0"),
|
||||
@ -192,9 +184,9 @@ func TestInstallUpdate(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
latestVersion := VersionInfo{
|
||||
Version: semver.MustParse("1.5.0"),
|
||||
@ -203,13 +195,11 @@ func TestInstallUpdate(t *testing.T) {
|
||||
RolloutProportion: 1.0,
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
latestVersion.Package,
|
||||
latestVersion.Package+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return([]byte("tgz_data_here"), nil)
|
||||
|
||||
err := updater.InstallUpdate(latestVersion)
|
||||
|
||||
@ -220,9 +210,9 @@ func TestInstallUpdateBadSignature(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
latestVersion := VersionInfo{
|
||||
Version: semver.MustParse("1.5.0"),
|
||||
@ -231,14 +221,12 @@ func TestInstallUpdateBadSignature(t *testing.T) {
|
||||
RolloutProportion: 1.0,
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
latestVersion.Package,
|
||||
latestVersion.Package+".sig",
|
||||
gomock.Any(),
|
||||
).Return(nil, errors.New("bad signature"))
|
||||
|
||||
client.EXPECT().Logout()
|
||||
|
||||
err := updater.InstallUpdate(latestVersion)
|
||||
|
||||
assert.Error(t, err)
|
||||
@ -248,9 +236,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
client := mocks.NewMockClient(c)
|
||||
cm := mocks.NewMockManager(c)
|
||||
|
||||
updater := newTestUpdater(client, "1.4.0", false)
|
||||
updater := newTestUpdater(cm, "1.4.0", false)
|
||||
|
||||
updater.installer = &fakeInstaller{delay: 2 * time.Second}
|
||||
|
||||
@ -261,13 +249,11 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
||||
RolloutProportion: 1.0,
|
||||
}
|
||||
|
||||
client.EXPECT().DownloadAndVerify(
|
||||
cm.EXPECT().DownloadAndVerify(
|
||||
gomock.Any(),
|
||||
latestVersion.Package,
|
||||
latestVersion.Package+".sig",
|
||||
gomock.Any(),
|
||||
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
|
||||
|
||||
client.EXPECT().Logout()
|
||||
).Return([]byte("tgz_data_here"), nil)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@ -288,9 +274,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *Updater {
|
||||
func newTestUpdater(manager *mocks.MockManager, curVer string, earlyAccess bool) *Updater {
|
||||
return New(
|
||||
&fakeClientProvider{client: client},
|
||||
manager,
|
||||
&fakeInstaller{},
|
||||
newFakeSettings(0.5, earlyAccess),
|
||||
nil,
|
||||
@ -299,14 +285,6 @@ func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *
|
||||
)
|
||||
}
|
||||
|
||||
type fakeClientProvider struct {
|
||||
client *mocks.MockClient
|
||||
}
|
||||
|
||||
func (p *fakeClientProvider) GetAnonymousClient() pmapi.Client {
|
||||
return p.client
|
||||
}
|
||||
|
||||
type fakeInstaller struct {
|
||||
bad bool
|
||||
delay time.Duration
|
||||
|
||||
@ -133,7 +133,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUnauthorized),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
|
||||
// getAPIUser() loads user info from API (e.g. userID).
|
||||
@ -149,3 +149,13 @@ func (s *Credentials) Logout() {
|
||||
func (s *Credentials) IsConnected() bool {
|
||||
return s.APIToken != "" && s.MailboxPassword != ""
|
||||
}
|
||||
|
||||
func (s *Credentials) SplitAPIToken() (string, string, error) {
|
||||
split := strings.Split(s.APIToken, ":")
|
||||
|
||||
if len(split) != 2 {
|
||||
return "", "", errors.New("malformed API token")
|
||||
}
|
||||
|
||||
return split[0], split[1], nil
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ func NewStore(keychain *keychain.Keychain) *Store {
|
||||
return &Store{secrets: keychain}
|
||||
}
|
||||
|
||||
func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails []string) (creds *Credentials, err error) {
|
||||
func (s *Store) Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
@ -49,10 +49,10 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
|
||||
"emails": emails,
|
||||
}).Trace("Adding new credentials")
|
||||
|
||||
creds = &Credentials{
|
||||
creds := &Credentials{
|
||||
UserID: userID,
|
||||
Name: userName,
|
||||
APIToken: apiToken,
|
||||
APIToken: uid + ":" + ref,
|
||||
MailboxPassword: mailboxPassword,
|
||||
IsHidden: false,
|
||||
}
|
||||
@ -72,82 +72,82 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
|
||||
creds.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
if err = s.saveCredentials(creds); err != nil {
|
||||
return
|
||||
if err := s.saveCredentials(creds); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return creds, err
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
func (s *Store) SwitchAddressMode(userID string) error {
|
||||
func (s *Store) SwitchAddressMode(userID string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.IsCombinedAddressMode = !credentials.IsCombinedAddressMode
|
||||
credentials.BridgePassword = generatePassword()
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateEmails(userID string, emails []string) error {
|
||||
func (s *Store) UpdateEmails(userID string, emails []string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.SetEmailList(emails)
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdatePassword(userID, password string) error {
|
||||
func (s *Store) UpdatePassword(userID, password string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.MailboxPassword = password
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateToken(userID, apiToken string) error {
|
||||
func (s *Store) UpdateToken(userID, uid, ref string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.APIToken = apiToken
|
||||
credentials.APIToken = uid + ":" + ref
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) Logout(userID string) error {
|
||||
func (s *Store) Logout(userID string) (*Credentials, error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials.Logout()
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
return credentials, s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
// List returns a list of usernames that have credentials stored.
|
||||
@ -249,7 +249,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
|
||||
}
|
||||
|
||||
// saveCredentials encrypts and saves password to the keychain store.
|
||||
func (s *Store) saveCredentials(credentials *Credentials) (err error) {
|
||||
func (s *Store) saveCredentials(credentials *Credentials) error {
|
||||
credentials.Version = keychain.Version
|
||||
|
||||
return s.secrets.Put(credentials.UserID, credentials.Marshal())
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,ClientManager,CredentialsStorer,StoreMaker)
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,CredentialsStorer,StoreMaker)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@ -9,7 +9,6 @@ import (
|
||||
|
||||
store "github.com/ProtonMail/proton-bridge/internal/store"
|
||||
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
@ -85,109 +84,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method
|
||||
func (m *MockClientManager) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy
|
||||
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// CheckConnection mocks base method
|
||||
func (m *MockClientManager) CheckConnection() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CheckConnection")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CheckConnection indicates an expected call of CheckConnection
|
||||
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method
|
||||
func (m *MockClientManager) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy
|
||||
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// GetAnonymousClient mocks base method
|
||||
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAnonymousClient")
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAnonymousClient indicates an expected call of GetAnonymousClient
|
||||
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
|
||||
}
|
||||
|
||||
// GetAuthUpdateChannel mocks base method
|
||||
func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthUpdateChannel")
|
||||
ret0, _ := ret[0].(chan pmapi.ClientAuth)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel
|
||||
func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel))
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// MockCredentialsStorer is a mock of CredentialsStorer interface
|
||||
type MockCredentialsStorer struct {
|
||||
ctrl *gomock.Controller
|
||||
@ -212,18 +108,18 @@ func (m *MockCredentialsStorer) EXPECT() *MockCredentialsStorerMockRecorder {
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3 string, arg4 []string) (*credentials.Credentials, error) {
|
||||
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3, arg4 string, arg5 []string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4)
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
}
|
||||
|
||||
// Delete mocks base method
|
||||
@ -271,11 +167,12 @@ func (mr *MockCredentialsStorerMockRecorder) List() *gomock.Call {
|
||||
}
|
||||
|
||||
// Logout mocks base method
|
||||
func (m *MockCredentialsStorer) Logout(arg0 string) error {
|
||||
func (m *MockCredentialsStorer) Logout(arg0 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Logout", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Logout indicates an expected call of Logout
|
||||
@ -285,11 +182,12 @@ func (mr *MockCredentialsStorerMockRecorder) Logout(arg0 interface{}) *gomock.Ca
|
||||
}
|
||||
|
||||
// SwitchAddressMode mocks base method
|
||||
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) error {
|
||||
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SwitchAddressMode", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SwitchAddressMode indicates an expected call of SwitchAddressMode
|
||||
@ -299,11 +197,12 @@ func (mr *MockCredentialsStorerMockRecorder) SwitchAddressMode(arg0 interface{})
|
||||
}
|
||||
|
||||
// UpdateEmails mocks base method
|
||||
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) error {
|
||||
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateEmails", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateEmails indicates an expected call of UpdateEmails
|
||||
@ -313,11 +212,12 @@ func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}
|
||||
}
|
||||
|
||||
// UpdatePassword mocks base method
|
||||
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error {
|
||||
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdatePassword indicates an expected call of UpdatePassword
|
||||
@ -327,17 +227,18 @@ func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface
|
||||
}
|
||||
|
||||
// UpdateToken mocks base method
|
||||
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
|
||||
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1, arg2 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateToken indicates an expected call of UpdateToken
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1 interface{}) *gomock.Call {
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// MockStoreMaker is a mock of StoreMaker interface
|
||||
|
||||
@ -20,14 +20,8 @@ package users
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/internal/store"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
type Configer interface {
|
||||
GetAppVersion() string
|
||||
GetAPIConfig() *pmapi.ClientConfig
|
||||
}
|
||||
|
||||
type Locator interface {
|
||||
Clear() error
|
||||
}
|
||||
@ -38,25 +32,16 @@ type PanicHandler interface {
|
||||
|
||||
type CredentialsStorer interface {
|
||||
List() (userIDs []string, err error)
|
||||
Add(userID, userName, apiToken, mailboxPassword string, emails []string) (*credentials.Credentials, error)
|
||||
Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*credentials.Credentials, error)
|
||||
Get(userID string) (*credentials.Credentials, error)
|
||||
SwitchAddressMode(userID string) error
|
||||
UpdateEmails(userID string, emails []string) error
|
||||
UpdatePassword(userID, password string) error
|
||||
UpdateToken(userID, apiToken string) error
|
||||
Logout(userID string) error
|
||||
SwitchAddressMode(userID string) (*credentials.Credentials, error)
|
||||
UpdateEmails(userID string, emails []string) (*credentials.Credentials, error)
|
||||
UpdatePassword(userID, password string) (*credentials.Credentials, error)
|
||||
UpdateToken(userID, uid, ref string) (*credentials.Credentials, error)
|
||||
Logout(userID string) (*credentials.Credentials, error)
|
||||
Delete(userID string) error
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
GetAnonymousClient() pmapi.Client
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
GetAuthUpdateChannel() chan pmapi.ClientAuth
|
||||
CheckConnection() error
|
||||
}
|
||||
|
||||
type StoreMaker interface {
|
||||
New(user store.BridgeUser) (*store.Store, error)
|
||||
Remove(userID string) error
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -36,11 +37,11 @@ var ErrLoggedOutUser = errors.New("account is logged out, use the app to login a
|
||||
|
||||
// User is a struct on top of API client and credentials store.
|
||||
type User struct {
|
||||
log *logrus.Entry
|
||||
panicHandler PanicHandler
|
||||
listener listener.Listener
|
||||
clientManager ClientManager
|
||||
credStorer CredentialsStorer
|
||||
log *logrus.Entry
|
||||
panicHandler PanicHandler
|
||||
listener listener.Listener
|
||||
client pmapi.Client
|
||||
credStorer CredentialsStorer
|
||||
|
||||
storeFactory StoreMaker
|
||||
store *store.Store
|
||||
@ -48,75 +49,76 @@ type User struct {
|
||||
userID string
|
||||
creds *credentials.Credentials
|
||||
|
||||
lock sync.RWMutex
|
||||
isAuthorized bool
|
||||
lock sync.RWMutex
|
||||
|
||||
useOnlyActiveAddresses bool
|
||||
}
|
||||
|
||||
// newUser creates a new user.
|
||||
// The user is initially disconnected and must be connected by calling connect().
|
||||
func newUser(
|
||||
panicHandler PanicHandler,
|
||||
userID string,
|
||||
eventListener listener.Listener,
|
||||
credStorer CredentialsStorer,
|
||||
clientManager ClientManager,
|
||||
storeFactory StoreMaker,
|
||||
) (u *User, err error) {
|
||||
useOnlyActiveAddresses bool,
|
||||
) (*User, *credentials.Credentials, error) {
|
||||
log := log.WithField("user", userID)
|
||||
|
||||
log.Debug("Creating or loading user")
|
||||
|
||||
creds, err := credStorer.Get(userID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load user credentials")
|
||||
return nil, nil, errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
|
||||
u = &User{
|
||||
log: log,
|
||||
panicHandler: panicHandler,
|
||||
listener: eventListener,
|
||||
credStorer: credStorer,
|
||||
clientManager: clientManager,
|
||||
storeFactory: storeFactory,
|
||||
userID: userID,
|
||||
creds: creds,
|
||||
}
|
||||
|
||||
return
|
||||
return &User{
|
||||
log: log,
|
||||
panicHandler: panicHandler,
|
||||
listener: eventListener,
|
||||
credStorer: credStorer,
|
||||
storeFactory: storeFactory,
|
||||
userID: userID,
|
||||
creds: creds,
|
||||
useOnlyActiveAddresses: useOnlyActiveAddresses,
|
||||
}, creds, nil
|
||||
}
|
||||
|
||||
func (u *User) client() pmapi.Client {
|
||||
return u.clientManager.GetClient(u.userID)
|
||||
}
|
||||
// connect connects a user. This includes
|
||||
// - providing it with an authorised API client
|
||||
// - loading its credentials from the credentials store
|
||||
// - loading and unlocking its PGP keys
|
||||
// - loading its store
|
||||
func (u *User) connect(ctx context.Context, client pmapi.Client, creds *credentials.Credentials) error {
|
||||
u.log.Info("Connecting user")
|
||||
|
||||
// init initialises a user. This includes reloading its credentials from the credentials store
|
||||
// (such as when logging out and back in, you need to reload the credentials because the new credentials will
|
||||
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
|
||||
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
|
||||
// something in the store changed).
|
||||
func (u *User) init() (err error) {
|
||||
u.log.Info("Initialising user")
|
||||
// Connected users have an API client.
|
||||
u.client = client
|
||||
|
||||
// Reload the user's credentials (if they log out and back in we need the new
|
||||
// version with the apitoken and mailbox password).
|
||||
creds, err := u.credStorer.Get(u.userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
// FIXME(conman): How to remove this auth handler when user is disconnected?
|
||||
u.client.AddAuthHandler(u.handleAuth)
|
||||
|
||||
// Save the latest credentials for the user.
|
||||
u.creds = creds
|
||||
|
||||
// Try to authorise the user if they aren't already authorised.
|
||||
// Note: we still allow users to set up accounts if the internet is off.
|
||||
if authErr := u.authorizeIfNecessary(false); authErr != nil {
|
||||
switch errors.Cause(authErr) {
|
||||
case pmapi.ErrAPINotReachable, pmapi.ErrUpgradeApplication, ErrLoggedOutUser:
|
||||
u.log.WithError(authErr).Warn("Could not authorize user")
|
||||
default:
|
||||
if logoutErr := u.logout(); logoutErr != nil {
|
||||
u.log.WithError(logoutErr).Warn("Could not logout user")
|
||||
}
|
||||
return errors.Wrap(authErr, "failed to authorize user")
|
||||
// Connected users have unlocked keys.
|
||||
// FIXME(conman): clients should always be authorized! This is a workaround to avoid a major refactor :(
|
||||
if u.creds.IsConnected() {
|
||||
if err := u.client.Unlock(ctx, []byte(u.creds.MailboxPassword)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Connected users have a store.
|
||||
if err := u.loadStore(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) loadStore() error {
|
||||
// Logged-out user keeps store running to access offline data.
|
||||
// Therefore it is necessary to close it before re-init.
|
||||
if u.store != nil {
|
||||
@ -125,93 +127,28 @@ func (u *User) init() (err error) {
|
||||
}
|
||||
u.store = nil
|
||||
}
|
||||
|
||||
store, err := u.storeFactory.New(u)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create store")
|
||||
}
|
||||
|
||||
u.store = store
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
|
||||
// If user is not already connected to the api auth channel (for example there was no internet during start),
|
||||
// it tries to connect it.
|
||||
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
|
||||
// If user is connected and has an auth channel, then perfect, nothing to do here.
|
||||
if u.creds.IsConnected() && u.isAuthorized {
|
||||
// The keyring unlock is triggered here to resolve state where apiClient
|
||||
// is authenticated (we have auth token) but it was not possible to download
|
||||
// and unlock the keys (internet not reachable).
|
||||
return u.unlockIfNecessary()
|
||||
}
|
||||
|
||||
if !u.creds.IsConnected() {
|
||||
err = ErrLoggedOutUser
|
||||
} else if err = u.authorizeAndUnlock(); err != nil {
|
||||
u.log.WithError(err).Error("Could not authorize and unlock user")
|
||||
|
||||
switch errors.Cause(err) {
|
||||
case pmapi.ErrUpgradeApplication, pmapi.ErrAPINotReachable: // Ignore these errors.
|
||||
default:
|
||||
if errLogout := u.credStorer.Logout(u.userID); errLogout != nil {
|
||||
u.log.WithField("err", errLogout).Error("Could not log user out from credentials store")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if emitEvent && err != nil &&
|
||||
errors.Cause(err) != pmapi.ErrUpgradeApplication &&
|
||||
errors.Cause(err) != pmapi.ErrAPINotReachable {
|
||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
|
||||
func (u *User) unlockIfNecessary() error {
|
||||
if u.client().IsUnlocked() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorizeAndUnlock tries to authorize the user with the API using the the user's APIToken.
|
||||
// If that succeeds, it tries to unlock the user's keys and addresses.
|
||||
func (u *User) authorizeAndUnlock() (err error) {
|
||||
if u.creds.APIToken == "" {
|
||||
u.log.Warn("Could not connect to API auth channel, have no API token")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh API auth")
|
||||
}
|
||||
|
||||
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) updateAuthToken(auth *pmapi.Auth) {
|
||||
func (u *User) handleAuth(auth *pmapi.Auth) error {
|
||||
u.log.Debug("User received auth")
|
||||
|
||||
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
|
||||
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||
return
|
||||
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to update refresh token in credentials store")
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
u.creds = creds
|
||||
|
||||
u.isAuthorized = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearStore removes the database.
|
||||
@ -248,7 +185,7 @@ func (u *User) closeStore() error {
|
||||
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
|
||||
// After proper refactor of SMTP and IMAP remove this method.
|
||||
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
|
||||
return u.client()
|
||||
return u.client
|
||||
}
|
||||
|
||||
// ID returns the user's userID.
|
||||
@ -272,6 +209,10 @@ func (u *User) IsConnected() bool {
|
||||
return u.creds.IsConnected()
|
||||
}
|
||||
|
||||
func (u *User) GetClient() pmapi.Client {
|
||||
return u.client
|
||||
}
|
||||
|
||||
// IsCombinedAddressMode returns whether user is set in combined or split mode.
|
||||
// Combined mode is the default mode and is what users typically need.
|
||||
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
|
||||
@ -345,7 +286,7 @@ func (u *User) GetAddressID(address string) (id string, err error) {
|
||||
return u.store.GetAddressID(address)
|
||||
}
|
||||
|
||||
addresses := u.client().Addresses()
|
||||
addresses := u.client.Addresses()
|
||||
pmapiAddress := addresses.ByEmail(address)
|
||||
if pmapiAddress != nil {
|
||||
return pmapiAddress.ID, nil
|
||||
@ -366,18 +307,21 @@ func (u *User) GetBridgePassword() string {
|
||||
// CheckBridgeLogin checks whether the user is logged in and the bridge
|
||||
// IMAP/SMTP password is correct.
|
||||
func (u *User) CheckBridgeLogin(password string) error {
|
||||
if isApplicationOutdated {
|
||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||
return pmapi.ErrUpgradeApplication
|
||||
}
|
||||
// FIXME(conman): Handle force upgrade?
|
||||
|
||||
/*
|
||||
if isApplicationOutdated {
|
||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||
return pmapi.ErrUpgradeApplication
|
||||
}
|
||||
*/
|
||||
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
// True here because users should be notified by popup of auth failure.
|
||||
if err := u.authorizeIfNecessary(true); err != nil {
|
||||
u.log.WithError(err).Error("Failed to authorize user")
|
||||
return err
|
||||
if !u.creds.IsConnected() {
|
||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||
return ErrLoggedOutUser
|
||||
}
|
||||
|
||||
return u.creds.CheckPassword(password)
|
||||
@ -388,60 +332,57 @@ func (u *User) UpdateUser() error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
if err := u.authorizeIfNecessary(true); err != nil {
|
||||
return errors.Wrap(err, "cannot update user")
|
||||
}
|
||||
|
||||
_, err := u.client().UpdateUser()
|
||||
_, err := u.client.UpdateUser(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = u.client().ReloadKeys([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
if err := u.client.ReloadKeys(context.TODO(), []byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to reload keys")
|
||||
}
|
||||
|
||||
emails := u.client().Addresses().ActiveEmails()
|
||||
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil {
|
||||
creds, err := u.credStorer.UpdateEmails(u.userID, u.client.Addresses().ActiveEmails())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
u.creds = creds
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SwitchAddressMode changes mode from combined to split and vice versa. The mode to switch to is determined by the
|
||||
// state of the user's credentials in the credentials store. See `IsCombinedAddressMode` for more details.
|
||||
func (u *User) SwitchAddressMode() (err error) {
|
||||
func (u *User) SwitchAddressMode() error {
|
||||
u.log.Trace("Switching user address mode")
|
||||
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
u.CloseAllConnections()
|
||||
|
||||
if u.store == nil {
|
||||
err = errors.New("store is not initialised")
|
||||
return
|
||||
return errors.New("store is not initialised")
|
||||
}
|
||||
|
||||
newAddressModeState := !u.IsCombinedAddressMode()
|
||||
|
||||
if err = u.store.UseCombinedMode(newAddressModeState); err != nil {
|
||||
u.log.WithError(err).Error("Could not switch store address mode")
|
||||
return
|
||||
if err := u.store.UseCombinedMode(newAddressModeState); err != nil {
|
||||
return errors.Wrap(err, "could not switch store address mode")
|
||||
}
|
||||
|
||||
if u.creds.IsCombinedAddressMode != newAddressModeState {
|
||||
if err = u.credStorer.SwitchAddressMode(u.userID); err != nil {
|
||||
u.log.WithError(err).Error("Could not switch credentials store address mode")
|
||||
return
|
||||
}
|
||||
if u.creds.IsCombinedAddressMode == newAddressModeState {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
creds, err := u.credStorer.SwitchAddressMode(u.userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not switch credentials store address mode")
|
||||
}
|
||||
|
||||
return err
|
||||
u.creds = creds
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logout is the same as Logout, but for internal purposes (logged out from
|
||||
@ -458,35 +399,37 @@ func (u *User) logout() error {
|
||||
u.listener.Emit(events.UserRefreshEvent, u.userID)
|
||||
}
|
||||
|
||||
u.isAuthorized = false
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Logout logs out the user from pmapi, the credentials store, the mail store, and tries to remove as much
|
||||
// sensitive data as possible.
|
||||
func (u *User) Logout() (err error) {
|
||||
func (u *User) Logout() error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
u.log.Debug("Logging out user")
|
||||
|
||||
if !u.creds.IsConnected() {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
u.client().Logout()
|
||||
// FIXME(conman): Do we delete API client now? Who cleans up? What about registered handlers?
|
||||
if err := u.client.AuthDelete(context.TODO()); err != nil {
|
||||
u.log.WithError(err).Warn("Failed to delete auth")
|
||||
}
|
||||
|
||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||
creds, err := u.credStorer.Logout(u.userID)
|
||||
if err != nil {
|
||||
u.log.WithError(err).Warn("Could not log user out from credentials store")
|
||||
|
||||
if err = u.credStorer.Delete(u.userID); err != nil {
|
||||
if err := u.credStorer.Delete(u.userID); err != nil {
|
||||
u.log.WithError(err).Error("Could not delete user from credentials store")
|
||||
}
|
||||
} else {
|
||||
u.creds = creds
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
|
||||
// Do not close whole store, just event loop. Some information might be needed offline (e.g. addressID)
|
||||
u.closeEventLoop()
|
||||
|
||||
@ -494,15 +437,7 @@ func (u *User) Logout() (err error) {
|
||||
|
||||
runtime.GC()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *User) refreshFromCredentials() {
|
||||
if credentials, err := u.credStorer.Get(u.userID); err != nil {
|
||||
log.WithError(err).Error("Cannot refresh user credentials")
|
||||
} else {
|
||||
u.creds = credentials
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) closeEventLoop() {
|
||||
|
||||
@ -19,12 +19,15 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@ -45,7 +48,7 @@ type Users struct {
|
||||
locations Locator
|
||||
panicHandler PanicHandler
|
||||
events listener.Listener
|
||||
clientManager ClientManager
|
||||
clientManager pmapi.Manager
|
||||
credStorer CredentialsStorer
|
||||
storeFactory StoreMaker
|
||||
|
||||
@ -62,16 +65,13 @@ type Users struct {
|
||||
useOnlyActiveAddresses bool
|
||||
|
||||
lock sync.RWMutex
|
||||
|
||||
// stopAll can be closed to stop all goroutines from looping (watchAppOutdated, watchAPIAuths, heartbeat etc).
|
||||
stopAll chan struct{}
|
||||
}
|
||||
|
||||
func New(
|
||||
locations Locator,
|
||||
panicHandler PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
clientManager ClientManager,
|
||||
clientManager pmapi.Manager,
|
||||
credStorer CredentialsStorer,
|
||||
storeFactory StoreMaker,
|
||||
useOnlyActiveAddresses bool,
|
||||
@ -87,98 +87,104 @@ func New(
|
||||
storeFactory: storeFactory,
|
||||
useOnlyActiveAddresses: useOnlyActiveAddresses,
|
||||
lock: sync.RWMutex{},
|
||||
stopAll: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAppOutdated()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAPIAuths()
|
||||
}()
|
||||
// FIXME(conman): Handle force upgrade events.
|
||||
/*
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAppOutdated()
|
||||
}()
|
||||
*/
|
||||
|
||||
if u.credStorer == nil {
|
||||
log.Error("No credentials store is available")
|
||||
} else if err := u.loadUsersFromCredentialsStore(); err != nil {
|
||||
} else if err := u.loadUsersFromCredentialsStore(context.TODO()); err != nil {
|
||||
log.WithError(err).Error("Could not load all users from credentials store")
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *Users) loadUsersFromCredentialsStore() (err error) {
|
||||
func (u *Users) loadUsersFromCredentialsStore(ctx context.Context) error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
userIDs, err := u.credStorer.List()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
l := log.WithField("user", userID)
|
||||
|
||||
user, newUserErr := newUser(u.panicHandler, userID, u.events, u.credStorer, u.clientManager, u.storeFactory)
|
||||
if newUserErr != nil {
|
||||
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
|
||||
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warn("Could not create user, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
u.users = append(u.users, user)
|
||||
|
||||
if initUserErr := user.init(); initUserErr != nil {
|
||||
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user")
|
||||
if creds.IsConnected() {
|
||||
if err := u.loadConnectedUser(ctx, user, creds); err != nil {
|
||||
logrus.WithError(err).Warn("Could not load connected user")
|
||||
}
|
||||
} else {
|
||||
logrus.Warn("User is disconnected and must be connected manually")
|
||||
|
||||
if err := u.loadDisconnectedUser(ctx, user, creds); err != nil {
|
||||
logrus.WithError(err).Warn("Could not load disconnected user")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *Users) watchAppOutdated() {
|
||||
ch := make(chan string)
|
||||
|
||||
u.events.Add(events.UpgradeApplicationEvent, ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
isApplicationOutdated = true
|
||||
u.closeAllConnections()
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
func (u *Users) loadDisconnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
||||
// FIXME(conman): We shouldn't be creating unauthorized clients... this is hacky, just to avoid huge refactor!
|
||||
return user.connect(ctx, u.clientManager.NewClient("", "", "", time.Time{}), creds)
|
||||
}
|
||||
|
||||
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
|
||||
func (u *Users) watchAPIAuths() {
|
||||
for {
|
||||
select {
|
||||
case auth := <-u.clientManager.GetAuthUpdateChannel():
|
||||
log.Debug("Users received auth from ClientManager")
|
||||
|
||||
user, ok := u.hasUser(auth.UserID)
|
||||
if !ok {
|
||||
log.WithField("userID", auth.UserID).Info("User not available for auth update")
|
||||
continue
|
||||
}
|
||||
|
||||
if auth.Auth != nil {
|
||||
user.updateAuthToken(auth.Auth)
|
||||
} else if err := user.logout(); err != nil {
|
||||
log.WithError(err).
|
||||
WithField("userID", auth.UserID).
|
||||
Error("User logout failed while watching API auths")
|
||||
}
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
|
||||
uid, ref, err := creds.SplitAPIToken()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not get user's refresh token")
|
||||
}
|
||||
|
||||
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
|
||||
if err != nil {
|
||||
// FIXME(conman): This is a problem... if we weren't able to create a new client due to internet,
|
||||
// we need to be able to retry later, but I deleted all the hacky "retry auth if necessary" stuff...
|
||||
return user.connect(ctx, u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
|
||||
}
|
||||
|
||||
// Update the user's credentials with the latest auth used to connect this user.
|
||||
if creds, err = u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
|
||||
return errors.Wrap(err, "could not create get user's refresh token")
|
||||
}
|
||||
|
||||
return user.connect(ctx, client, creds)
|
||||
}
|
||||
|
||||
func (u *Users) watchAppOutdated() {
|
||||
// FIXME(conman): handle force upgrade events.
|
||||
|
||||
/*
|
||||
ch := make(chan string)
|
||||
|
||||
u.events.Add(events.UpgradeApplicationEvent, ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
isApplicationOutdated = true
|
||||
u.closeAllConnections()
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
func (u *Users) closeAllConnections() {
|
||||
@ -192,63 +198,45 @@ func (u *Users) closeAllConnections() {
|
||||
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
|
||||
u.crashBandicoot(username)
|
||||
|
||||
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet.
|
||||
authClient = u.clientManager.GetAnonymousClient()
|
||||
|
||||
authInfo, err := authClient.AuthInfo(username)
|
||||
if err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
|
||||
return
|
||||
}
|
||||
|
||||
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
return u.clientManager.NewClientWithLogin(context.TODO(), username, password)
|
||||
}
|
||||
|
||||
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
||||
func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphrase string) (user *User, err error) { //nolint[funlen]
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("Login not finished; removing auth session")
|
||||
if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil {
|
||||
log.WithError(delAuthErr).Error("Failed to clear login session after unlock")
|
||||
}
|
||||
}
|
||||
// The anonymous client will be removed from list and authentication will not be deleted.
|
||||
authClient.Logout()
|
||||
}()
|
||||
|
||||
apiUser, hashedPassphrase, err := getAPIUser(authClient, mbPassphrase)
|
||||
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
|
||||
apiUser, passphrase, err := getAPIUser(context.TODO(), client, password)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to get API user")
|
||||
return
|
||||
return nil, errors.Wrap(err, "failed to get API user")
|
||||
}
|
||||
|
||||
log.Info("Got API user")
|
||||
if user, ok := u.hasUser(apiUser.ID); ok {
|
||||
if user.IsConnected() {
|
||||
if err := client.AuthDelete(context.TODO()); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to delete new auth session")
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if user, ok = u.hasUser(apiUser.ID); ok {
|
||||
if err = u.connectExistingUser(user, auth, hashedPassphrase); err != nil {
|
||||
log.WithError(err).Error("Failed to connect existing user")
|
||||
return
|
||||
return nil, errors.New("user is already connected")
|
||||
}
|
||||
} else {
|
||||
if err = u.addNewUser(apiUser, auth, hashedPassphrase); err != nil {
|
||||
log.WithError(err).Error("Failed to add new user")
|
||||
return
|
||||
|
||||
// Update the user's credentials with the latest auth used to connect this user.
|
||||
if _, err := u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
|
||||
// Update the password in case the user changed it.
|
||||
creds, err := u.credStorer.UpdatePassword(apiUser.ID, string(passphrase))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
|
||||
}
|
||||
|
||||
if err := user.connect(context.TODO(), client, creds); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to reconnect existing user")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Old credentials use username as key (user ID) which needs to be removed
|
||||
// once user logs in again with proper ID fetched from API.
|
||||
if _, ok := u.hasUser(apiUser.Name); ok {
|
||||
if err := u.DeleteUser(apiUser.Name, true); err != nil {
|
||||
log.WithError(err).Error("Failed to delete old user")
|
||||
}
|
||||
if err := u.addNewUser(context.TODO(), client, apiUser, auth, passphrase); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to add new user")
|
||||
}
|
||||
|
||||
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||
@ -256,107 +244,63 @@ func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphr
|
||||
return u.GetUser(apiUser.ID)
|
||||
}
|
||||
|
||||
// connectExistingUser connects an existing user.
|
||||
func (u *Users) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassphrase string) (err error) {
|
||||
if user.IsConnected() {
|
||||
return errors.New("user is already connected")
|
||||
}
|
||||
|
||||
log.Info("Connecting existing user")
|
||||
|
||||
// Update the user's password in the cred store in case they changed it.
|
||||
if err = u.credStorer.UpdatePassword(user.ID(), hashedPassphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to update password of user in credentials store")
|
||||
}
|
||||
|
||||
client := u.clientManager.GetClient(user.ID())
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh auth token of new client")
|
||||
}
|
||||
|
||||
if err = u.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to update token of user in credentials store")
|
||||
}
|
||||
|
||||
if err = user.init(); err != nil {
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// addNewUser adds a new user.
|
||||
func (u *Users) addNewUser(apiUser *pmapi.User, auth *pmapi.Auth, hashedPassphrase string) (err error) {
|
||||
func (u *Users) addNewUser(ctx context.Context, client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
client := u.clientManager.GetClient(apiUser.ID)
|
||||
var emails []string
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh token in new client")
|
||||
}
|
||||
|
||||
if apiUser, err = client.CurrentUser(); err != nil {
|
||||
return errors.Wrap(err, "failed to update API user")
|
||||
}
|
||||
|
||||
var emails []string //nolint[prealloc]
|
||||
if u.useOnlyActiveAddresses {
|
||||
emails = client.Addresses().ActiveEmails()
|
||||
} else {
|
||||
emails = client.Addresses().AllEmails()
|
||||
}
|
||||
|
||||
if _, err = u.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), hashedPassphrase, emails); err != nil {
|
||||
return errors.Wrap(err, "failed to add user to credentials store")
|
||||
if _, err := u.credStorer.Add(apiUser.ID, apiUser.Name, auth.UID, auth.RefreshToken, string(passphrase), emails); err != nil {
|
||||
return errors.Wrap(err, "failed to add user credentials to credentials store")
|
||||
}
|
||||
|
||||
user, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.clientManager, u.storeFactory)
|
||||
user, creds, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create user")
|
||||
return errors.Wrap(err, "failed to create new user")
|
||||
}
|
||||
|
||||
// The user needs to be part of the users list in order for it to receive an auth during initialisation.
|
||||
u.users = append(u.users, user)
|
||||
|
||||
if err = user.init(); err != nil {
|
||||
u.users = u.users[:len(u.users)-1]
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
if err := user.connect(ctx, client, creds); err != nil {
|
||||
return errors.Wrap(err, "failed to connect new user")
|
||||
}
|
||||
|
||||
if err := u.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)); err != nil {
|
||||
log.WithError(err).Error("Failed to send metric")
|
||||
}
|
||||
|
||||
return err
|
||||
u.users = append(u.users, user)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAPIUser(client pmapi.Client, mbPassphrase string) (user *pmapi.User, hashedPassphrase string, err error) {
|
||||
salt, err := client.AuthSalt()
|
||||
func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pmapi.User, []byte, error) {
|
||||
salt, err := client.AuthSalt(ctx)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not get salt")
|
||||
return nil, "", err
|
||||
return nil, nil, errors.Wrap(err, "failed to get salt")
|
||||
}
|
||||
|
||||
hashedPassphrase, err = pmapi.HashMailboxPassword(mbPassphrase, salt)
|
||||
passphrase, err := pmapi.HashMailboxPassword(password, salt)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not hash mailbox password")
|
||||
return nil, "", err
|
||||
return nil, nil, errors.Wrap(err, "failed to hash password")
|
||||
}
|
||||
|
||||
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
|
||||
if err = client.Unlock([]byte(hashedPassphrase)); err != nil {
|
||||
log.WithError(err).Error("Wrong mailbox password")
|
||||
return nil, "", ErrWrongMailboxPassword
|
||||
if err := client.Unlock(ctx, passphrase); err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to unlock client")
|
||||
}
|
||||
|
||||
if user, err = client.CurrentUser(); err != nil {
|
||||
log.WithError(err).Error("Could not load user data")
|
||||
return nil, "", err
|
||||
user, err := client.CurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to load user data")
|
||||
}
|
||||
|
||||
return user, hashedPassphrase, nil
|
||||
return user, passphrase, nil
|
||||
}
|
||||
|
||||
// GetUsers returns all added users into keychain (even logged out users).
|
||||
@ -452,11 +396,9 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
|
||||
|
||||
// SendMetric sends a metric. We don't want to return any errors, only log them.
|
||||
func (u *Users) SendMetric(m metrics.Metric) error {
|
||||
c := u.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
cat, act, lab := m.Get()
|
||||
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
|
||||
|
||||
if err := u.clientManager.SendSimpleMetric(context.Background(), string(cat), string(act), string(lab)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -472,24 +414,22 @@ func (u *Users) SendMetric(m metrics.Metric) error {
|
||||
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) AllowProxy() {
|
||||
u.clientManager.AllowProxy()
|
||||
// FIXME(conman): Support DoH.
|
||||
// u.apiManager.AllowProxy()
|
||||
}
|
||||
|
||||
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) DisallowProxy() {
|
||||
u.clientManager.DisallowProxy()
|
||||
// FIXME(conman): Support DoH.
|
||||
// u.apiManager.DisallowProxy()
|
||||
}
|
||||
|
||||
// CheckConnection returns whether there is an internet connection.
|
||||
// This should use the connection manager when it is eventually implemented.
|
||||
func (u *Users) CheckConnection() error {
|
||||
return u.clientManager.CheckConnection()
|
||||
}
|
||||
|
||||
// StopWatchers stops all goroutines.
|
||||
func (u *Users) StopWatchers() {
|
||||
close(u.stopAll)
|
||||
// FIXME(conman): Other parts of bridge that rely on this method should register as a connection observer.
|
||||
panic("TODO: register as a connection observer to get this information")
|
||||
}
|
||||
|
||||
// hasUser returns whether the struct currently has a user with ID `id`.
|
||||
|
||||
@ -20,8 +20,8 @@ package users
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
time "time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -49,20 +49,19 @@ func TestNewUsersWithDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Basically every call client has get client manager.
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.clientManager.EXPECT().NewClient("", "", "", time.Time{}).Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
|
||||
/*
|
||||
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
@ -132,6 +131,7 @@ func TestNewUsersFirstStart(t *testing.T) {
|
||||
|
||||
testNewUsers(t, m)
|
||||
}
|
||||
*/
|
||||
|
||||
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
|
||||
users := testNewUsers(t, m)
|
||||
|
||||
@ -48,18 +48,17 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
var (
|
||||
testAuth = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||
RefreshToken: "tok",
|
||||
}
|
||||
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||
RefreshToken: "reftok",
|
||||
UID: "uid",
|
||||
AccessToken: "acc",
|
||||
RefreshToken: "ref",
|
||||
}
|
||||
|
||||
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
Name: "username",
|
||||
Emails: "user@pm.me",
|
||||
APIToken: "token",
|
||||
APIToken: "uid:acc",
|
||||
MailboxPassword: "pass",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
@ -67,11 +66,12 @@ var (
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: true,
|
||||
}
|
||||
|
||||
testCredentialsSplit = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "users",
|
||||
Name: "usersname",
|
||||
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
||||
APIToken: "token",
|
||||
APIToken: "uid:acc",
|
||||
MailboxPassword: "pass",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
@ -79,6 +79,7 @@ var (
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: false,
|
||||
}
|
||||
|
||||
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
Name: "username",
|
||||
@ -92,6 +93,19 @@ var (
|
||||
IsCombinedAddressMode: true,
|
||||
}
|
||||
|
||||
testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "users",
|
||||
Name: "usersname",
|
||||
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
||||
APIToken: "",
|
||||
MailboxPassword: "",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
Timestamp: 123456789,
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: false,
|
||||
}
|
||||
|
||||
testPMAPIUser = &pmapi.User{ //nolint[gochecknoglobals]
|
||||
ID: "user",
|
||||
Name: "username",
|
||||
@ -130,12 +144,12 @@ type mocks struct {
|
||||
ctrl *gomock.Controller
|
||||
locator *usersmocks.MockLocator
|
||||
PanicHandler *usersmocks.MockPanicHandler
|
||||
clientManager *usersmocks.MockClientManager
|
||||
credentialsStore *usersmocks.MockCredentialsStorer
|
||||
storeMaker *usersmocks.MockStoreMaker
|
||||
eventListener *MockListener
|
||||
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
clientManager *pmapimocks.MockManager
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
|
||||
storeCache *store.Cache
|
||||
}
|
||||
@ -171,12 +185,12 @@ func initMocks(t *testing.T) mocks {
|
||||
ctrl: mockCtrl,
|
||||
locator: usersmocks.NewMockLocator(mockCtrl),
|
||||
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
|
||||
clientManager: usersmocks.NewMockClientManager(mockCtrl),
|
||||
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
|
||||
storeMaker: usersmocks.NewMockStoreMaker(mockCtrl),
|
||||
eventListener: NewMockListener(mockCtrl),
|
||||
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
clientManager: pmapimocks.NewMockManager(mockCtrl),
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
|
||||
storeCache: store.NewCache(cacheFile.Name()),
|
||||
}
|
||||
@ -189,7 +203,7 @@ func initMocks(t *testing.T) mocks {
|
||||
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
|
||||
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
|
||||
require.NoError(t, err, "could not get temporary file for store db")
|
||||
return store.New(sentryReporter, m.PanicHandler, user, m.clientManager, m.eventListener, dbFile.Name(), m.storeCache)
|
||||
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
|
||||
}).AnyTimes()
|
||||
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
|
||||
|
||||
@ -198,46 +212,42 @@ func initMocks(t *testing.T) mocks {
|
||||
|
||||
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
|
||||
// Events are asynchronous
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
|
||||
|
||||
// Init for user.
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.credentialsStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.credentialsStore.EXPECT().UpdateToken(testCredentials.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentials.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Init for users.
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.credentialsStore.EXPECT().Get(testCredentialsSplit.UserID).Return(testCredentialsSplit, nil),
|
||||
m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().AddAuthHandler(gomock.Any()),
|
||||
m.credentialsStore.EXPECT().UpdateToken(testCredentialsSplit.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentialsSplit, nil),
|
||||
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsSplit.UserID, testCredentialsSplit.MailboxPassword).Return(testCredentialsSplit, nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
|
||||
)
|
||||
|
||||
users := testNewUsers(t, m)
|
||||
|
||||
user, _ := users.GetUser("user")
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
user, _ = users.GetUser("user")
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
return users
|
||||
return testNewUsers(t, m)
|
||||
}
|
||||
|
||||
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
|
||||
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan pmapi.ClientAuth))
|
||||
// FIXME(conman): How to handle force upgrade?
|
||||
// m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
|
||||
users := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker, true)
|
||||
|
||||
@ -256,8 +266,8 @@ func TestClearData(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
// m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
// m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
@ -267,13 +277,11 @@ func TestClearData(t *testing.T) {
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
|
||||
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil)
|
||||
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("users").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil)
|
||||
m.pmapiClient.EXPECT().AuthDelete(gomock.Any())
|
||||
m.credentialsStore.EXPECT().Logout("users").Return(testCredentialsSplitDisconnected, nil)
|
||||
|
||||
m.locator.EXPECT().Clear()
|
||||
|
||||
@ -285,9 +293,9 @@ func TestClearData(t *testing.T) {
|
||||
func mockEventLoopNoAction(m mocks) {
|
||||
// Set up mocks for starting the store's event loop (in store.New).
|
||||
// The event loop runs in another goroutine so this might happen at any time.
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), "").Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().GetEvent(gomock.Any(), testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).AnyTimes()
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
}
|
||||
|
||||
func mockConnectedUser(m mocks) {
|
||||
@ -295,27 +303,13 @@ func mockConnectedUser(m mocks) {
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
// m.pmapiClient.EXPECT().AuthRefresh("uid:acc").Return(testAuthRefresh, nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// Set up mocks for store initialisation for the authorized user.
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
)
|
||||
}
|
||||
|
||||
// mockAuthUpdate simulates users calling UpdateAuthToken on the given user.
|
||||
// This would normally be done by users when it receives an auth from the ClientManager,
|
||||
// but as we don't have a full users instance here, we do this manually.
|
||||
func mockAuthUpdate(user *User, token string, m mocks) {
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":"+token).Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(token), nil),
|
||||
)
|
||||
|
||||
user.updateAuthToken(refreshWithToken(token))
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
// IsAuthorized returns whether the user has received an Auth from the API yet.
|
||||
func (u *User) IsAuthorized() bool {
|
||||
return u.isAuthorized
|
||||
}
|
||||
Reference in New Issue
Block a user