mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-15 14:56:42 +00:00
Renamed bridge to general users and keep bridge only for bridge stuff
This commit is contained in:
136
internal/users/credentials/credentials.go
Normal file
136
internal/users/credentials/credentials.go
Normal file
@ -0,0 +1,136 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package credentials implements our struct stored in keychain.
|
||||
// Store struct is kind of like a database client.
|
||||
// Credentials struct is kind of like one record from the database.
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const sep = "\x00"
|
||||
|
||||
var (
|
||||
log = logrus.WithField("pkg", "credentials") //nolint[gochecknoglobals]
|
||||
|
||||
ErrWrongFormat = errors.New("backend/creds: malformed password")
|
||||
)
|
||||
|
||||
type Credentials struct {
|
||||
UserID, // Do not marshal; used as a key.
|
||||
Name,
|
||||
Emails,
|
||||
APIToken,
|
||||
MailboxPassword,
|
||||
BridgePassword,
|
||||
Version string
|
||||
Timestamp int64
|
||||
IsHidden, // Deprecated.
|
||||
IsCombinedAddressMode bool
|
||||
}
|
||||
|
||||
func (s *Credentials) Marshal() string {
|
||||
items := []string{
|
||||
s.Name, // 0
|
||||
s.Emails, // 1
|
||||
s.APIToken, // 2
|
||||
s.MailboxPassword, // 3
|
||||
s.BridgePassword, // 4
|
||||
s.Version, // 5
|
||||
"", // 6
|
||||
"", // 7
|
||||
"", // 8
|
||||
}
|
||||
|
||||
items[6] = fmt.Sprint(s.Timestamp)
|
||||
|
||||
if s.IsHidden {
|
||||
items[7] = "1"
|
||||
}
|
||||
|
||||
if s.IsCombinedAddressMode {
|
||||
items[8] = "1"
|
||||
}
|
||||
|
||||
str := strings.Join(items, sep)
|
||||
return base64.StdEncoding.EncodeToString([]byte(str))
|
||||
}
|
||||
|
||||
func (s *Credentials) Unmarshal(secret string) error {
|
||||
b, err := base64.StdEncoding.DecodeString(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items := strings.Split(string(b), sep)
|
||||
|
||||
if len(items) != 9 {
|
||||
return ErrWrongFormat
|
||||
}
|
||||
|
||||
s.Name = items[0]
|
||||
s.Emails = items[1]
|
||||
s.APIToken = items[2]
|
||||
s.MailboxPassword = items[3]
|
||||
s.BridgePassword = items[4]
|
||||
s.Version = items[5]
|
||||
if _, err = fmt.Sscan(items[6], &s.Timestamp); err != nil {
|
||||
s.Timestamp = 0
|
||||
}
|
||||
if s.IsHidden = false; items[7] == "1" {
|
||||
s.IsHidden = true
|
||||
}
|
||||
if s.IsCombinedAddressMode = false; items[8] == "1" {
|
||||
s.IsCombinedAddressMode = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Credentials) SetEmailList(list []string) {
|
||||
s.Emails = strings.Join(list, ";")
|
||||
}
|
||||
|
||||
func (s *Credentials) EmailList() []string {
|
||||
return strings.Split(s.Emails, ";")
|
||||
}
|
||||
|
||||
func (s *Credentials) CheckPassword(password string) error {
|
||||
if subtle.ConstantTimeCompare([]byte(s.BridgePassword), []byte(password)) != 1 {
|
||||
log.WithFields(logrus.Fields{
|
||||
"userID": s.UserID,
|
||||
}).Debug("Incorrect bridge password")
|
||||
|
||||
return fmt.Errorf("backend/credentials: incorrect password")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Credentials) Logout() {
|
||||
s.APIToken = ""
|
||||
s.MailboxPassword = ""
|
||||
}
|
||||
|
||||
func (s *Credentials) IsConnected() bool {
|
||||
return s.APIToken != "" && s.MailboxPassword != ""
|
||||
}
|
||||
39
internal/users/credentials/crypto.go
Normal file
39
internal/users/credentials/crypto.go
Normal file
@ -0,0 +1,39 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
)
|
||||
|
||||
const keySize = 16
|
||||
|
||||
// generateKey generates a new random key.
|
||||
func generateKey() []byte {
|
||||
key := make([]byte, keySize)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func generatePassword() string {
|
||||
return base64.RawURLEncoding.EncodeToString(generateKey())
|
||||
}
|
||||
330
internal/users/credentials/store.go
Normal file
330
internal/users/credentials/store.go
Normal file
@ -0,0 +1,330 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/keychain"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var storeLocker = sync.RWMutex{} //nolint[gochecknoglobals]
|
||||
|
||||
// Store is an encrypted credentials store.
|
||||
type Store struct {
|
||||
secrets *keychain.Access
|
||||
}
|
||||
|
||||
// NewStore creates a new encrypted credentials store.
|
||||
func NewStore(appName string) (*Store, error) {
|
||||
secrets, err := keychain.NewAccess(appName)
|
||||
return &Store{
|
||||
secrets: secrets,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails []string) (creds *Credentials, err error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"user": userID,
|
||||
"username": userName,
|
||||
"emails": emails,
|
||||
}).Trace("Adding new credentials")
|
||||
|
||||
if err = s.checkKeychain(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
creds = &Credentials{
|
||||
UserID: userID,
|
||||
Name: userName,
|
||||
APIToken: apiToken,
|
||||
MailboxPassword: mailboxPassword,
|
||||
IsHidden: false,
|
||||
}
|
||||
|
||||
creds.SetEmailList(emails)
|
||||
|
||||
var has bool
|
||||
if has, err = s.has(userID); err != nil {
|
||||
log.WithField("userID", userID).WithError(err).Error("Could not check if user credentials already exist")
|
||||
return
|
||||
}
|
||||
|
||||
if has {
|
||||
log.Info("Updating credentials of existing user")
|
||||
currentCredentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
creds.BridgePassword = currentCredentials.BridgePassword
|
||||
creds.IsCombinedAddressMode = currentCredentials.IsCombinedAddressMode
|
||||
creds.Timestamp = currentCredentials.Timestamp
|
||||
} else {
|
||||
log.Info("Generating credentials for new user")
|
||||
creds.BridgePassword = generatePassword()
|
||||
creds.IsCombinedAddressMode = true
|
||||
creds.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
if err = s.saveCredentials(creds); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return creds, err
|
||||
}
|
||||
|
||||
func (s *Store) SwitchAddressMode(userID string) error {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credentials.IsCombinedAddressMode = !credentials.IsCombinedAddressMode
|
||||
credentials.BridgePassword = generatePassword()
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateEmails(userID string, emails []string) error {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credentials.SetEmailList(emails)
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdatePassword(userID, password string) error {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credentials.MailboxPassword = password
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateToken(userID, apiToken string) error {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credentials.APIToken = apiToken
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
func (s *Store) Logout(userID string) error {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
credentials, err := s.get(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credentials.Logout()
|
||||
|
||||
return s.saveCredentials(credentials)
|
||||
}
|
||||
|
||||
// List returns a list of usernames that have credentials stored.
|
||||
func (s *Store) List() (userIDs []string, err error) {
|
||||
storeLocker.RLock()
|
||||
defer storeLocker.RUnlock()
|
||||
|
||||
log.Trace("Listing credentials in credentials store")
|
||||
|
||||
if err = s.checkKeychain(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var allUserIDs []string
|
||||
if allUserIDs, err = s.secrets.List(); err != nil {
|
||||
log.WithError(err).Error("Could not list credentials")
|
||||
return
|
||||
}
|
||||
|
||||
credentialList := []*Credentials{}
|
||||
for _, userID := range allUserIDs {
|
||||
creds, getErr := s.get(userID)
|
||||
if getErr != nil {
|
||||
log.WithField("userID", userID).WithError(getErr).Warn("Failed to get credentials")
|
||||
continue
|
||||
}
|
||||
|
||||
if creds.Timestamp == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
credentialList = append(credentialList, creds)
|
||||
}
|
||||
|
||||
sort.Slice(credentialList, func(i, j int) bool {
|
||||
return credentialList[i].Timestamp < credentialList[j].Timestamp
|
||||
})
|
||||
|
||||
for _, credentials := range credentialList {
|
||||
userIDs = append(userIDs, credentials.UserID)
|
||||
}
|
||||
|
||||
return userIDs, err
|
||||
}
|
||||
|
||||
func (s *Store) GetAndCheckPassword(userID, password string) (creds *Credentials, err error) {
|
||||
storeLocker.RLock()
|
||||
defer storeLocker.RUnlock()
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"userID": userID,
|
||||
}).Debug("Checking bridge password")
|
||||
|
||||
credentials, err := s.Get(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := credentials.CheckPassword(password); err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"userID": userID,
|
||||
"err": err,
|
||||
}).Debug("Incorrect bridge password")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func (s *Store) Get(userID string) (creds *Credentials, err error) {
|
||||
storeLocker.RLock()
|
||||
defer storeLocker.RUnlock()
|
||||
|
||||
var has bool
|
||||
if has, err = s.has(userID); err != nil {
|
||||
log.WithError(err).Error("Could not check for credentials")
|
||||
return
|
||||
}
|
||||
|
||||
if !has {
|
||||
err = errors.New("no credentials found for given userID")
|
||||
return
|
||||
}
|
||||
|
||||
return s.get(userID)
|
||||
}
|
||||
|
||||
func (s *Store) has(userID string) (has bool, err error) {
|
||||
if err = s.checkKeychain(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var ids []string
|
||||
if ids, err = s.secrets.List(); err != nil {
|
||||
log.WithError(err).Error("Could not list credentials")
|
||||
return
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
if id == userID {
|
||||
has = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Store) get(userID string) (creds *Credentials, err error) {
|
||||
log := log.WithField("user", userID)
|
||||
|
||||
if err = s.checkKeychain(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
secret, err := s.secrets.Get(userID)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not get credentials from native keychain")
|
||||
return
|
||||
}
|
||||
|
||||
credentials := &Credentials{UserID: userID}
|
||||
if err = credentials.Unmarshal(secret); err != nil {
|
||||
err = fmt.Errorf("backend/credentials: malformed secret: %v", err)
|
||||
_ = s.secrets.Delete(userID)
|
||||
log.WithError(err).Error("Could not unmarshal secret")
|
||||
return
|
||||
}
|
||||
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
// saveCredentials encrypts and saves password to the keychain store.
|
||||
func (s *Store) saveCredentials(credentials *Credentials) (err error) {
|
||||
if err = s.checkKeychain(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
credentials.Version = keychain.KeychainVersion
|
||||
|
||||
return s.secrets.Put(credentials.UserID, credentials.Marshal())
|
||||
}
|
||||
|
||||
func (s *Store) checkKeychain() (err error) {
|
||||
if s.secrets == nil {
|
||||
err = keychain.ErrNoKeychainInstalled
|
||||
log.WithError(err).Error("Store is unusable")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Delete removes credentials from the store.
|
||||
func (s *Store) Delete(userID string) (err error) {
|
||||
storeLocker.Lock()
|
||||
defer storeLocker.Unlock()
|
||||
|
||||
if err = s.checkKeychain(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return s.secrets.Delete(userID)
|
||||
}
|
||||
297
internal/users/credentials/store_test.go
Normal file
297
internal/users/credentials/store_test.go
Normal file
@ -0,0 +1,297 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testSep = "\n"
|
||||
const secretFormat = "%v" + testSep + // UserID,
|
||||
"%v" + testSep + // Name,
|
||||
"%v" + testSep + // Emails,
|
||||
"%v" + testSep + // APIToken,
|
||||
"%v" + testSep + // Mailbox,
|
||||
"%v" + testSep + // BridgePassword,
|
||||
"%v" + testSep + // Version string
|
||||
"%v" + testSep + // Timestamp,
|
||||
"%v" + testSep + // IsHidden,
|
||||
"%v" // IsCombinedAddressMode
|
||||
|
||||
// the best would be to run this test on mac, win, and linux separately
|
||||
|
||||
type testCredentials struct {
|
||||
UserID,
|
||||
Name,
|
||||
Emails,
|
||||
APIToken,
|
||||
Mailbox,
|
||||
BridgePassword,
|
||||
Version string
|
||||
Timestamp int64
|
||||
IsHidden,
|
||||
IsCombinedAddressMode bool
|
||||
}
|
||||
|
||||
func init() { //nolint[gochecknoinits]
|
||||
gob.Register(testCredentials{})
|
||||
}
|
||||
|
||||
func (s *testCredentials) MarshalGob() string {
|
||||
buf := bytes.Buffer{}
|
||||
enc := gob.NewEncoder(&buf)
|
||||
if err := enc.Encode(s); err != nil {
|
||||
return ""
|
||||
}
|
||||
fmt.Printf("MarshalGob: %#v\n", buf.String())
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *testCredentials) Clear() {
|
||||
s.UserID = ""
|
||||
s.Name = ""
|
||||
s.Emails = ""
|
||||
s.APIToken = ""
|
||||
s.Mailbox = ""
|
||||
s.BridgePassword = ""
|
||||
s.Version = ""
|
||||
s.Timestamp = 0
|
||||
s.IsHidden = false
|
||||
s.IsCombinedAddressMode = false
|
||||
}
|
||||
|
||||
func (s *testCredentials) UnmarshalGob(secret string) error {
|
||||
s.Clear()
|
||||
b, err := base64.StdEncoding.DecodeString(secret)
|
||||
if err != nil {
|
||||
fmt.Println("decode base64", b)
|
||||
return err
|
||||
}
|
||||
buf := bytes.NewBuffer(b)
|
||||
dec := gob.NewDecoder(buf)
|
||||
if err = dec.Decode(s); err != nil {
|
||||
fmt.Println("decode gob", b, buf.Bytes())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testCredentials) ToJSON() string {
|
||||
if b, err := json.Marshal(s); err == nil {
|
||||
fmt.Printf("MarshalJSON: %#v\n", string(b))
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *testCredentials) FromJSON(secret string) error {
|
||||
b, err := base64.StdEncoding.DecodeString(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = json.Unmarshal(b, s); err == nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *testCredentials) MarshalFmt() string {
|
||||
buf := bytes.Buffer{}
|
||||
fmt.Fprintf(
|
||||
&buf, secretFormat,
|
||||
s.UserID,
|
||||
s.Name,
|
||||
s.Emails,
|
||||
s.APIToken,
|
||||
s.Mailbox,
|
||||
s.BridgePassword,
|
||||
s.Version,
|
||||
s.Timestamp,
|
||||
s.IsHidden,
|
||||
s.IsCombinedAddressMode,
|
||||
)
|
||||
fmt.Printf("MarshalFmt: %#v\n", buf.String())
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *testCredentials) UnmarshalFmt(secret string) error {
|
||||
b, err := base64.StdEncoding.DecodeString(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := bytes.NewBuffer(b)
|
||||
fmt.Println("decode fmt", b, buf.Bytes())
|
||||
_, err = fmt.Fscanf(
|
||||
buf, secretFormat,
|
||||
&s.UserID,
|
||||
&s.Name,
|
||||
&s.Emails,
|
||||
&s.APIToken,
|
||||
&s.Mailbox,
|
||||
&s.BridgePassword,
|
||||
&s.Version,
|
||||
&s.Timestamp,
|
||||
&s.IsHidden,
|
||||
&s.IsCombinedAddressMode,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testCredentials) MarshalStrings() string { // this is the most space efficient
|
||||
items := []string{
|
||||
s.UserID, // 0
|
||||
s.Name, // 1
|
||||
s.Emails, // 2
|
||||
s.APIToken, // 3
|
||||
s.Mailbox, // 4
|
||||
s.BridgePassword, // 5
|
||||
s.Version, // 6
|
||||
}
|
||||
items = append(items, fmt.Sprint(s.Timestamp)) // 7
|
||||
|
||||
if s.IsHidden { // 8
|
||||
items = append(items, "1")
|
||||
} else {
|
||||
items = append(items, "")
|
||||
}
|
||||
|
||||
if s.IsCombinedAddressMode { // 9
|
||||
items = append(items, "1")
|
||||
} else {
|
||||
items = append(items, "")
|
||||
}
|
||||
|
||||
str := strings.Join(items, sep)
|
||||
|
||||
fmt.Printf("MarshalJoin: %#v\n", str)
|
||||
return base64.StdEncoding.EncodeToString([]byte(str))
|
||||
}
|
||||
|
||||
func (s *testCredentials) UnmarshalStrings(secret string) error {
|
||||
b, err := base64.StdEncoding.DecodeString(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items := strings.Split(string(b), sep)
|
||||
if len(items) != 10 {
|
||||
return ErrWrongFormat
|
||||
}
|
||||
|
||||
s.UserID = items[0]
|
||||
s.Name = items[1]
|
||||
s.Emails = items[2]
|
||||
s.APIToken = items[3]
|
||||
s.Mailbox = items[4]
|
||||
s.BridgePassword = items[5]
|
||||
s.Version = items[6]
|
||||
if _, err = fmt.Sscanf(items[7], "%d", &s.Timestamp); err != nil {
|
||||
s.Timestamp = 0
|
||||
}
|
||||
if s.IsHidden = false; items[8] == "1" {
|
||||
s.IsHidden = true
|
||||
}
|
||||
if s.IsCombinedAddressMode = false; items[9] == "1" {
|
||||
s.IsCombinedAddressMode = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testCredentials) IsSame(rhs *testCredentials) bool {
|
||||
return s.Name == rhs.Name &&
|
||||
s.Emails == rhs.Emails &&
|
||||
s.APIToken == rhs.APIToken &&
|
||||
s.Mailbox == rhs.Mailbox &&
|
||||
s.BridgePassword == rhs.BridgePassword &&
|
||||
s.Version == rhs.Version &&
|
||||
s.Timestamp == rhs.Timestamp &&
|
||||
s.IsHidden == rhs.IsHidden &&
|
||||
s.IsCombinedAddressMode == rhs.IsCombinedAddressMode
|
||||
}
|
||||
|
||||
func TestMarshalFormats(t *testing.T) {
|
||||
input := testCredentials{UserID: "007", Emails: "ja@pm.me;jakub@cu.th", Timestamp: 152469263742, IsHidden: true}
|
||||
fmt.Printf("input %#v\n", input)
|
||||
|
||||
secretStrings := input.MarshalStrings()
|
||||
fmt.Printf("secretStrings %#v %d\n", secretStrings, len(secretStrings))
|
||||
secretGob := input.MarshalGob()
|
||||
fmt.Printf("secretGob %#v %d\n", secretGob, len(secretGob))
|
||||
secretJSON := input.ToJSON()
|
||||
fmt.Printf("secretJSON %#v %d\n", secretJSON, len(secretJSON))
|
||||
secretFmt := input.MarshalFmt()
|
||||
fmt.Printf("secretFmt %#v %d\n", secretFmt, len(secretFmt))
|
||||
|
||||
output := testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.UnmarshalStrings(secretStrings))
|
||||
fmt.Printf("strings out %#v \n", output)
|
||||
require.True(t, input.IsSame(&output), "strings out not same")
|
||||
|
||||
output = testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.UnmarshalGob(secretGob))
|
||||
fmt.Printf("gob out %#v\n \n", output)
|
||||
assert.Equal(t, input, output)
|
||||
|
||||
output = testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.FromJSON(secretJSON))
|
||||
fmt.Printf("json out %#v \n", output)
|
||||
require.True(t, input.IsSame(&output), "json out not same")
|
||||
|
||||
/*
|
||||
// Simple Fscanf not working!
|
||||
output = testCredentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.UnmarshalFmt(secretFmt))
|
||||
fmt.Printf("fmt out %#v \n", output)
|
||||
require.True(t, input.IsSame(&output), "fmt out not same")
|
||||
*/
|
||||
}
|
||||
|
||||
func TestMarshal(t *testing.T) {
|
||||
input := Credentials{
|
||||
UserID: "",
|
||||
Name: "007",
|
||||
Emails: "ja@pm.me;aj@cus.tom",
|
||||
APIToken: "sdfdsfsdfsdfsdf",
|
||||
MailboxPassword: "cdcdcdcd",
|
||||
BridgePassword: "wew123",
|
||||
Version: "k11",
|
||||
Timestamp: 152469263742,
|
||||
IsHidden: true,
|
||||
IsCombinedAddressMode: false,
|
||||
}
|
||||
fmt.Printf("input %#v\n", input)
|
||||
|
||||
secret := input.Marshal()
|
||||
fmt.Printf("secret %#v %d\n", secret, len(secret))
|
||||
|
||||
output := Credentials{APIToken: "refresh"}
|
||||
require.NoError(t, output.Unmarshal(secret))
|
||||
fmt.Printf("output %#v\n", output)
|
||||
assert.Equal(t, input, output)
|
||||
}
|
||||
107
internal/users/mock_listener.go
Normal file
107
internal/users/mock_listener.go
Normal file
@ -0,0 +1,107 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./listener/listener.go
|
||||
|
||||
// Package users is a generated GoMock package.
|
||||
package users
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockListener is a mock of Listener interface
|
||||
type MockListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockListenerMockRecorder
|
||||
}
|
||||
|
||||
// MockListenerMockRecorder is the mock recorder for MockListener
|
||||
type MockListenerMockRecorder struct {
|
||||
mock *MockListener
|
||||
}
|
||||
|
||||
// NewMockListener creates a new mock instance
|
||||
func NewMockListener(ctrl *gomock.Controller) *MockListener {
|
||||
mock := &MockListener{ctrl: ctrl}
|
||||
mock.recorder = &MockListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// SetLimit mocks base method
|
||||
func (m *MockListener) SetLimit(eventName string, limit time.Duration) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetLimit", eventName, limit)
|
||||
}
|
||||
|
||||
// SetLimit indicates an expected call of SetLimit
|
||||
func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit)
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockListener) Add(eventName string, channel chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Add", eventName, channel)
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockListener) Remove(eventName string, channel chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Remove", eventName, channel)
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel)
|
||||
}
|
||||
|
||||
// Emit mocks base method
|
||||
func (m *MockListener) Emit(eventName, data string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Emit", eventName, data)
|
||||
}
|
||||
|
||||
// Emit indicates an expected call of Emit
|
||||
func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data)
|
||||
}
|
||||
|
||||
// SetBuffer mocks base method
|
||||
func (m *MockListener) SetBuffer(eventName string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetBuffer", eventName)
|
||||
}
|
||||
|
||||
// SetBuffer indicates an expected call of SetBuffer
|
||||
func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName)
|
||||
}
|
||||
|
||||
// RetryEmit mocks base method
|
||||
func (m *MockListener) RetryEmit(eventName string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RetryEmit", eventName)
|
||||
}
|
||||
|
||||
// RetryEmit indicates an expected call of RetryEmit
|
||||
func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName)
|
||||
}
|
||||
485
internal/users/mocks/mocks.go
Normal file
485
internal/users/mocks/mocks.go
Normal file
@ -0,0 +1,485 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockConfiger is a mock of Configer interface
|
||||
type MockConfiger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConfigerMockRecorder
|
||||
}
|
||||
|
||||
// MockConfigerMockRecorder is the mock recorder for MockConfiger
|
||||
type MockConfigerMockRecorder struct {
|
||||
mock *MockConfiger
|
||||
}
|
||||
|
||||
// NewMockConfiger creates a new mock instance
|
||||
func NewMockConfiger(ctrl *gomock.Controller) *MockConfiger {
|
||||
mock := &MockConfiger{ctrl: ctrl}
|
||||
mock.recorder = &MockConfigerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockConfiger) EXPECT() *MockConfigerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ClearData mocks base method
|
||||
func (m *MockConfiger) ClearData() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearData")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearData indicates an expected call of ClearData
|
||||
func (mr *MockConfigerMockRecorder) ClearData() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearData", reflect.TypeOf((*MockConfiger)(nil).ClearData))
|
||||
}
|
||||
|
||||
// GetAPIConfig mocks base method
|
||||
func (m *MockConfiger) GetAPIConfig() *pmapi.ClientConfig {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAPIConfig")
|
||||
ret0, _ := ret[0].(*pmapi.ClientConfig)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAPIConfig indicates an expected call of GetAPIConfig
|
||||
func (mr *MockConfigerMockRecorder) GetAPIConfig() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIConfig", reflect.TypeOf((*MockConfiger)(nil).GetAPIConfig))
|
||||
}
|
||||
|
||||
// GetDBDir mocks base method
|
||||
func (m *MockConfiger) GetDBDir() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDBDir")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetDBDir indicates an expected call of GetDBDir
|
||||
func (mr *MockConfigerMockRecorder) GetDBDir() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDBDir", reflect.TypeOf((*MockConfiger)(nil).GetDBDir))
|
||||
}
|
||||
|
||||
// GetIMAPCachePath mocks base method
|
||||
func (m *MockConfiger) GetIMAPCachePath() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetIMAPCachePath")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetIMAPCachePath indicates an expected call of GetIMAPCachePath
|
||||
func (mr *MockConfigerMockRecorder) GetIMAPCachePath() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIMAPCachePath", reflect.TypeOf((*MockConfiger)(nil).GetIMAPCachePath))
|
||||
}
|
||||
|
||||
// GetVersion mocks base method
|
||||
func (m *MockConfiger) GetVersion() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetVersion")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetVersion indicates an expected call of GetVersion
|
||||
func (mr *MockConfigerMockRecorder) GetVersion() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockConfiger)(nil).GetVersion))
|
||||
}
|
||||
|
||||
// MockPreferenceProvider is a mock of PreferenceProvider interface
|
||||
type MockPreferenceProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPreferenceProviderMockRecorder
|
||||
}
|
||||
|
||||
// MockPreferenceProviderMockRecorder is the mock recorder for MockPreferenceProvider
|
||||
type MockPreferenceProviderMockRecorder struct {
|
||||
mock *MockPreferenceProvider
|
||||
}
|
||||
|
||||
// NewMockPreferenceProvider creates a new mock instance
|
||||
func NewMockPreferenceProvider(ctrl *gomock.Controller) *MockPreferenceProvider {
|
||||
mock := &MockPreferenceProvider{ctrl: ctrl}
|
||||
mock.recorder = &MockPreferenceProviderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockPreferenceProvider) EXPECT() *MockPreferenceProviderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
func (m *MockPreferenceProvider) Get(arg0 string) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
func (mr *MockPreferenceProviderMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPreferenceProvider)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// GetBool mocks base method
|
||||
func (m *MockPreferenceProvider) GetBool(arg0 string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetBool", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetBool indicates an expected call of GetBool
|
||||
func (mr *MockPreferenceProviderMockRecorder) GetBool(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBool", reflect.TypeOf((*MockPreferenceProvider)(nil).GetBool), arg0)
|
||||
}
|
||||
|
||||
// GetInt mocks base method
|
||||
func (m *MockPreferenceProvider) GetInt(arg0 string) int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetInt", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetInt indicates an expected call of GetInt
|
||||
func (mr *MockPreferenceProviderMockRecorder) GetInt(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInt", reflect.TypeOf((*MockPreferenceProvider)(nil).GetInt), arg0)
|
||||
}
|
||||
|
||||
// Set mocks base method
|
||||
func (m *MockPreferenceProvider) Set(arg0, arg1 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Set", arg0, arg1)
|
||||
}
|
||||
|
||||
// Set indicates an expected call of Set
|
||||
func (mr *MockPreferenceProviderMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockPreferenceProvider)(nil).Set), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockPanicHandler is a mock of PanicHandler interface
|
||||
type MockPanicHandler struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPanicHandlerMockRecorder
|
||||
}
|
||||
|
||||
// MockPanicHandlerMockRecorder is the mock recorder for MockPanicHandler
|
||||
type MockPanicHandlerMockRecorder struct {
|
||||
mock *MockPanicHandler
|
||||
}
|
||||
|
||||
// NewMockPanicHandler creates a new mock instance
|
||||
func NewMockPanicHandler(ctrl *gomock.Controller) *MockPanicHandler {
|
||||
mock := &MockPanicHandler{ctrl: ctrl}
|
||||
mock.recorder = &MockPanicHandlerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockPanicHandler) EXPECT() *MockPanicHandlerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// HandlePanic mocks base method
|
||||
func (m *MockPanicHandler) HandlePanic() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "HandlePanic")
|
||||
}
|
||||
|
||||
// HandlePanic indicates an expected call of HandlePanic
|
||||
func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockClientManager is a mock of ClientManager interface
|
||||
type MockClientManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
|
||||
type MockClientManagerMockRecorder struct {
|
||||
mock *MockClientManager
|
||||
}
|
||||
|
||||
// NewMockClientManager creates a new mock instance
|
||||
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
|
||||
mock := &MockClientManager{ctrl: ctrl}
|
||||
mock.recorder = &MockClientManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method
|
||||
func (m *MockClientManager) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy
|
||||
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// CheckConnection mocks base method
|
||||
func (m *MockClientManager) CheckConnection() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CheckConnection")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CheckConnection indicates an expected call of CheckConnection
|
||||
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method
|
||||
func (m *MockClientManager) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy
|
||||
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// GetAnonymousClient mocks base method
|
||||
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAnonymousClient")
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAnonymousClient indicates an expected call of GetAnonymousClient
|
||||
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
|
||||
}
|
||||
|
||||
// GetAuthUpdateChannel mocks base method
|
||||
func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthUpdateChannel")
|
||||
ret0, _ := ret[0].(chan pmapi.ClientAuth)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel
|
||||
func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel))
|
||||
}
|
||||
|
||||
// GetClient mocks base method
|
||||
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClient", arg0)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClient indicates an expected call of GetClient
|
||||
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
|
||||
}
|
||||
|
||||
// SetUserAgent mocks base method
|
||||
func (m *MockClientManager) SetUserAgent(arg0, arg1, arg2 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetUserAgent", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetUserAgent indicates an expected call of SetUserAgent
|
||||
func (mr *MockClientManagerMockRecorder) SetUserAgent(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserAgent", reflect.TypeOf((*MockClientManager)(nil).SetUserAgent), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// MockCredentialsStorer is a mock of CredentialsStorer interface
|
||||
type MockCredentialsStorer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCredentialsStorerMockRecorder
|
||||
}
|
||||
|
||||
// MockCredentialsStorerMockRecorder is the mock recorder for MockCredentialsStorer
|
||||
type MockCredentialsStorerMockRecorder struct {
|
||||
mock *MockCredentialsStorer
|
||||
}
|
||||
|
||||
// NewMockCredentialsStorer creates a new mock instance
|
||||
func NewMockCredentialsStorer(ctrl *gomock.Controller) *MockCredentialsStorer {
|
||||
mock := &MockCredentialsStorer{ctrl: ctrl}
|
||||
mock.recorder = &MockCredentialsStorerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCredentialsStorer) EXPECT() *MockCredentialsStorerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3 string, arg4 []string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
||||
|
||||
// Delete mocks base method
|
||||
func (m *MockCredentialsStorer) Delete(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete
|
||||
func (mr *MockCredentialsStorerMockRecorder) Delete(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockCredentialsStorer)(nil).Delete), arg0)
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
func (m *MockCredentialsStorer) Get(arg0 string) (*credentials.Credentials, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", arg0)
|
||||
ret0, _ := ret[0].(*credentials.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
func (mr *MockCredentialsStorerMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCredentialsStorer)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// List mocks base method
|
||||
func (m *MockCredentialsStorer) List() ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List")
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List
|
||||
func (mr *MockCredentialsStorerMockRecorder) List() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockCredentialsStorer)(nil).List))
|
||||
}
|
||||
|
||||
// Logout mocks base method
|
||||
func (m *MockCredentialsStorer) Logout(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Logout", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Logout indicates an expected call of Logout
|
||||
func (mr *MockCredentialsStorerMockRecorder) Logout(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logout", reflect.TypeOf((*MockCredentialsStorer)(nil).Logout), arg0)
|
||||
}
|
||||
|
||||
// SwitchAddressMode mocks base method
|
||||
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SwitchAddressMode", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SwitchAddressMode indicates an expected call of SwitchAddressMode
|
||||
func (mr *MockCredentialsStorerMockRecorder) SwitchAddressMode(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SwitchAddressMode", reflect.TypeOf((*MockCredentialsStorer)(nil).SwitchAddressMode), arg0)
|
||||
}
|
||||
|
||||
// UpdateEmails mocks base method
|
||||
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateEmails", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateEmails indicates an expected call of UpdateEmails
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEmails", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateEmails), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpdatePassword mocks base method
|
||||
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdatePassword indicates an expected call of UpdatePassword
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdatePassword), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpdateToken mocks base method
|
||||
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateToken indicates an expected call of UpdateToken
|
||||
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1)
|
||||
}
|
||||
64
internal/users/types.go
Normal file
64
internal/users/types.go
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
type Configer interface {
|
||||
ClearData() error
|
||||
GetDBDir() string
|
||||
GetVersion() string
|
||||
GetIMAPCachePath() string
|
||||
GetAPIConfig() *pmapi.ClientConfig
|
||||
}
|
||||
|
||||
type PreferenceProvider interface {
|
||||
Get(key string) string
|
||||
GetBool(key string) bool
|
||||
GetInt(key string) int
|
||||
Set(key string, value string)
|
||||
}
|
||||
|
||||
type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
|
||||
type CredentialsStorer interface {
|
||||
List() (userIDs []string, err error)
|
||||
Add(userID, userName, apiToken, mailboxPassword string, emails []string) (*credentials.Credentials, error)
|
||||
Get(userID string) (*credentials.Credentials, error)
|
||||
SwitchAddressMode(userID string) error
|
||||
UpdateEmails(userID string, emails []string) error
|
||||
UpdatePassword(userID, password string) error
|
||||
UpdateToken(userID, apiToken string) error
|
||||
Logout(userID string) error
|
||||
Delete(userID string) error
|
||||
}
|
||||
|
||||
type ClientManager interface {
|
||||
GetClient(userID string) pmapi.Client
|
||||
GetAnonymousClient() pmapi.Client
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
GetAuthUpdateChannel() chan pmapi.ClientAuth
|
||||
CheckConnection() error
|
||||
SetUserAgent(clientName, clientVersion, os string)
|
||||
}
|
||||
588
internal/users/user.go
Normal file
588
internal/users/user.go
Normal file
@ -0,0 +1,588 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/store"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
imapBackend "github.com/emersion/go-imap/backend"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ErrLoggedOutUser is sent to IMAP and SMTP if user exists, password is OK but user is logged out from the app.
|
||||
var ErrLoggedOutUser = errors.New("account is logged out, use the app to login again")
|
||||
|
||||
// User is a struct on top of API client and credentials store.
|
||||
type User struct {
|
||||
log *logrus.Entry
|
||||
panicHandler PanicHandler
|
||||
listener listener.Listener
|
||||
clientManager ClientManager
|
||||
credStorer CredentialsStorer
|
||||
|
||||
imapUpdatesChannel chan imapBackend.Update
|
||||
|
||||
store *store.Store
|
||||
storeCache *store.Cache
|
||||
storePath string
|
||||
|
||||
userID string
|
||||
creds *credentials.Credentials
|
||||
|
||||
lock sync.RWMutex
|
||||
isAuthorized bool
|
||||
|
||||
unlockingKeyringLock sync.Mutex
|
||||
wasKeyringUnlocked bool
|
||||
}
|
||||
|
||||
// newUser creates a new user.
|
||||
func newUser(
|
||||
panicHandler PanicHandler,
|
||||
userID string,
|
||||
eventListener listener.Listener,
|
||||
credStorer CredentialsStorer,
|
||||
clientManager ClientManager,
|
||||
storeCache *store.Cache,
|
||||
storeDir string,
|
||||
) (u *User, err error) {
|
||||
log := log.WithField("user", userID)
|
||||
log.Debug("Creating or loading user")
|
||||
|
||||
creds, err := credStorer.Get(userID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
|
||||
u = &User{
|
||||
log: log,
|
||||
panicHandler: panicHandler,
|
||||
listener: eventListener,
|
||||
credStorer: credStorer,
|
||||
clientManager: clientManager,
|
||||
storeCache: storeCache,
|
||||
storePath: getUserStorePath(storeDir, userID),
|
||||
userID: userID,
|
||||
creds: creds,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (u *User) client() pmapi.Client {
|
||||
return u.clientManager.GetClient(u.userID)
|
||||
}
|
||||
|
||||
// init initialises a user. This includes reloading its credentials from the credentials store
|
||||
// (such as when logging out and back in, you need to reload the credentials because the new credentials will
|
||||
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
|
||||
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
|
||||
// something in the store changed).
|
||||
func (u *User) init(idleUpdates chan imapBackend.Update) (err error) {
|
||||
u.unlockingKeyringLock.Lock()
|
||||
u.wasKeyringUnlocked = false
|
||||
u.unlockingKeyringLock.Unlock()
|
||||
|
||||
u.log.Info("Initialising user")
|
||||
|
||||
// Reload the user's credentials (if they log out and back in we need the new
|
||||
// version with the apitoken and mailbox password).
|
||||
creds, err := u.credStorer.Get(u.userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to load user credentials")
|
||||
}
|
||||
u.creds = creds
|
||||
|
||||
// Try to authorise the user if they aren't already authorised.
|
||||
// Note: we still allow users to set up accounts if the internet is off.
|
||||
if authErr := u.authorizeIfNecessary(false); authErr != nil {
|
||||
switch errors.Cause(authErr) {
|
||||
case pmapi.ErrAPINotReachable, pmapi.ErrUpgradeApplication, ErrLoggedOutUser:
|
||||
u.log.WithError(authErr).Warn("Could not authorize user")
|
||||
default:
|
||||
if logoutErr := u.logout(); logoutErr != nil {
|
||||
u.log.WithError(logoutErr).Warn("Could not logout user")
|
||||
}
|
||||
return errors.Wrap(authErr, "failed to authorize user")
|
||||
}
|
||||
}
|
||||
|
||||
// Logged-out user keeps store running to access offline data.
|
||||
// Therefore it is necessary to close it before re-init.
|
||||
if u.store != nil {
|
||||
if err := u.store.Close(); err != nil {
|
||||
log.WithError(err).Error("Not able to close store")
|
||||
}
|
||||
u.store = nil
|
||||
}
|
||||
store, err := store.New(u.panicHandler, u, u.clientManager, u.listener, u.storePath, u.storeCache)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create store")
|
||||
}
|
||||
u.store = store
|
||||
|
||||
// Save the imap updates channel here so it can be set later when imap connects.
|
||||
u.imapUpdatesChannel = idleUpdates
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *User) SetIMAPIdleUpdateChannel() {
|
||||
if u.store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.store.SetIMAPUpdateChannel(u.imapUpdatesChannel)
|
||||
}
|
||||
|
||||
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
|
||||
// If user is not already connected to the api auth channel (for example there was no internet during start),
|
||||
// it tries to connect it.
|
||||
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
|
||||
// If user is connected and has an auth channel, then perfect, nothing to do here.
|
||||
if u.creds.IsConnected() && u.isAuthorized {
|
||||
// The keyring unlock is triggered here to resolve state where apiClient
|
||||
// is authenticated (we have auth token) but it was not possible to download
|
||||
// and unlock the keys (internet not reachable).
|
||||
return u.unlockIfNecessary()
|
||||
}
|
||||
|
||||
if !u.creds.IsConnected() {
|
||||
err = ErrLoggedOutUser
|
||||
} else if err = u.authorizeAndUnlock(); err != nil {
|
||||
u.log.WithError(err).Error("Could not authorize and unlock user")
|
||||
|
||||
switch errors.Cause(err) {
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||
|
||||
case pmapi.ErrAPINotReachable:
|
||||
u.listener.Emit(events.InternetOffEvent, "")
|
||||
|
||||
default:
|
||||
if errLogout := u.credStorer.Logout(u.userID); errLogout != nil {
|
||||
u.log.WithField("err", errLogout).Error("Could not log user out from credentials store")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if emitEvent && err != nil &&
|
||||
errors.Cause(err) != pmapi.ErrUpgradeApplication &&
|
||||
errors.Cause(err) != pmapi.ErrAPINotReachable {
|
||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
|
||||
func (u *User) unlockIfNecessary() error {
|
||||
u.unlockingKeyringLock.Lock()
|
||||
defer u.unlockingKeyringLock.Unlock()
|
||||
|
||||
if u.wasKeyringUnlocked {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
if err := u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user addresses")
|
||||
}
|
||||
|
||||
u.wasKeyringUnlocked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorizeAndUnlock tries to authorize the user with the API using the the user's APIToken.
|
||||
// If that succeeds, it tries to unlock the user's keys and addresses.
|
||||
func (u *User) authorizeAndUnlock() (err error) {
|
||||
if u.creds.APIToken == "" {
|
||||
u.log.Warn("Could not connect to API auth channel, have no API token")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh API auth")
|
||||
}
|
||||
|
||||
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
if err = u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user addresses")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) updateAuthToken(auth *pmapi.Auth) {
|
||||
u.log.Debug("User received auth")
|
||||
|
||||
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
|
||||
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||
return
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
|
||||
u.isAuthorized = true
|
||||
}
|
||||
|
||||
// clearStore removes the database.
|
||||
func (u *User) clearStore() error {
|
||||
u.log.Trace("Clearing user store")
|
||||
|
||||
if u.store != nil {
|
||||
if err := u.store.Remove(); err != nil {
|
||||
return errors.Wrap(err, "failed to remove store")
|
||||
}
|
||||
} else {
|
||||
u.log.Warn("Store is not initialized: cleaning up store files manually")
|
||||
if err := store.RemoveStore(u.storeCache, u.storePath, u.userID); err != nil {
|
||||
return errors.Wrap(err, "failed to remove store manually")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeStore just closes the store without deleting it.
|
||||
func (u *User) closeStore() error {
|
||||
u.log.Trace("Closing user store")
|
||||
|
||||
if u.store != nil {
|
||||
if err := u.store.Close(); err != nil {
|
||||
return errors.Wrap(err, "failed to close store")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUserStorePath returns the file path of the store database for the given userID.
|
||||
func getUserStorePath(storeDir string, userID string) (path string) {
|
||||
fileName := fmt.Sprintf("mailbox-%v.db", userID)
|
||||
return filepath.Join(storeDir, fileName)
|
||||
}
|
||||
|
||||
// GetTemporaryPMAPIClient returns an authorised PMAPI client.
|
||||
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
|
||||
// After proper refactor of SMTP and IMAP remove this method.
|
||||
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
|
||||
return u.client()
|
||||
}
|
||||
|
||||
// ID returns the user's userID.
|
||||
func (u *User) ID() string {
|
||||
return u.userID
|
||||
}
|
||||
|
||||
// Username returns the user's username as found in the user's credentials.
|
||||
func (u *User) Username() string {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
return u.creds.Name
|
||||
}
|
||||
|
||||
// IsConnected returns whether user is logged in.
|
||||
func (u *User) IsConnected() bool {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
return u.creds.IsConnected()
|
||||
}
|
||||
|
||||
// IsCombinedAddressMode returns whether user is set in combined or split mode.
|
||||
// Combined mode is the default mode and is what users typically need.
|
||||
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
|
||||
// address other than the primary one.
|
||||
func (u *User) IsCombinedAddressMode() bool {
|
||||
if u.store != nil {
|
||||
return u.store.IsCombinedMode()
|
||||
}
|
||||
|
||||
return u.creds.IsCombinedAddressMode
|
||||
}
|
||||
|
||||
// GetPrimaryAddress returns the user's original address (which is
|
||||
// not necessarily the same as the primary address, because a primary address
|
||||
// might be an alias and be in position one).
|
||||
func (u *User) GetPrimaryAddress() string {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
return u.creds.EmailList()[0]
|
||||
}
|
||||
|
||||
// GetStoreAddresses returns all addresses used by the store (so in combined mode,
|
||||
// that's just the original address, but in split mode, that's all active addresses).
|
||||
func (u *User) GetStoreAddresses() []string {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
if u.IsCombinedAddressMode() {
|
||||
return u.creds.EmailList()[:1]
|
||||
}
|
||||
|
||||
return u.creds.EmailList()
|
||||
}
|
||||
|
||||
// getStoreAddresses returns a user's used addresses (with the original address in first place).
|
||||
func (u *User) getStoreAddresses() []string { // nolint[unused]
|
||||
addrInfo, err := u.store.GetAddressInfo()
|
||||
if err != nil {
|
||||
u.log.WithError(err).Error("Failed getting address info from store")
|
||||
return nil
|
||||
}
|
||||
|
||||
addresses := []string{}
|
||||
for _, addr := range addrInfo {
|
||||
addresses = append(addresses, addr.Address)
|
||||
}
|
||||
|
||||
if u.IsCombinedAddressMode() {
|
||||
return addresses[:1]
|
||||
}
|
||||
|
||||
return addresses
|
||||
}
|
||||
|
||||
// GetAddresses returns list of all addresses.
|
||||
func (u *User) GetAddresses() []string {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
return u.creds.EmailList()
|
||||
}
|
||||
|
||||
// GetAddressID returns the API ID of the given address.
|
||||
func (u *User) GetAddressID(address string) (id string, err error) {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
address = strings.ToLower(address)
|
||||
|
||||
if u.store == nil {
|
||||
err = errors.New("store is not initialised")
|
||||
return
|
||||
}
|
||||
|
||||
return u.store.GetAddressID(address)
|
||||
}
|
||||
|
||||
// GetBridgePassword returns bridge password. This is not a password of the PM
|
||||
// account, but generated password for local purposes to not use a PM account
|
||||
// in the clients (such as Thunderbird).
|
||||
func (u *User) GetBridgePassword() string {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
return u.creds.BridgePassword
|
||||
}
|
||||
|
||||
// CheckBridgeLogin checks whether the user is logged in and the bridge
|
||||
// IMAP/SMTP password is correct.
|
||||
func (u *User) CheckBridgeLogin(password string) error {
|
||||
if isApplicationOutdated {
|
||||
u.listener.Emit(events.UpgradeApplicationEvent, "")
|
||||
return pmapi.ErrUpgradeApplication
|
||||
}
|
||||
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
// True here because users should be notified by popup of auth failure.
|
||||
if err := u.authorizeIfNecessary(true); err != nil {
|
||||
u.log.WithError(err).Error("Failed to authorize user")
|
||||
return err
|
||||
}
|
||||
|
||||
return u.creds.CheckPassword(password)
|
||||
}
|
||||
|
||||
// UpdateUser updates user details from API and saves to the credentials.
|
||||
func (u *User) UpdateUser() error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
if err := u.authorizeIfNecessary(true); err != nil {
|
||||
return errors.Wrap(err, "cannot update user")
|
||||
}
|
||||
|
||||
_, err := u.client().UpdateUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
emails := u.client().Addresses().ActiveEmails()
|
||||
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SwitchAddressMode changes mode from combined to split and vice versa. The mode to switch to is determined by the
|
||||
// state of the user's credentials in the credentials store. See `IsCombinedAddressMode` for more details.
|
||||
func (u *User) SwitchAddressMode() (err error) {
|
||||
u.log.Trace("Switching user address mode")
|
||||
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
u.closeAllConnections()
|
||||
|
||||
if u.store == nil {
|
||||
err = errors.New("store is not initialised")
|
||||
return
|
||||
}
|
||||
|
||||
newAddressModeState := !u.IsCombinedAddressMode()
|
||||
|
||||
if err = u.store.UseCombinedMode(newAddressModeState); err != nil {
|
||||
u.log.WithError(err).Error("Could not switch store address mode")
|
||||
return
|
||||
}
|
||||
|
||||
if u.creds.IsCombinedAddressMode != newAddressModeState {
|
||||
if err = u.credStorer.SwitchAddressMode(u.userID); err != nil {
|
||||
u.log.WithError(err).Error("Could not switch credentials store address mode")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// logout is the same as Logout, but for internal purposes (logged out from
|
||||
// the server) which emits LogoutEvent to notify other parts of the app.
|
||||
func (u *User) logout() error {
|
||||
u.lock.Lock()
|
||||
wasConnected := u.creds.IsConnected()
|
||||
u.lock.Unlock()
|
||||
|
||||
err := u.Logout()
|
||||
|
||||
if wasConnected {
|
||||
u.listener.Emit(events.LogoutEvent, u.userID)
|
||||
u.listener.Emit(events.UserRefreshEvent, u.userID)
|
||||
}
|
||||
|
||||
u.isAuthorized = false
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Logout logs out the user from pmapi, the credentials store, the mail store, and tries to remove as much
|
||||
// sensitive data as possible.
|
||||
func (u *User) Logout() (err error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
u.log.Debug("Logging out user")
|
||||
|
||||
if !u.creds.IsConnected() {
|
||||
return
|
||||
}
|
||||
|
||||
u.unlockingKeyringLock.Lock()
|
||||
u.wasKeyringUnlocked = false
|
||||
u.unlockingKeyringLock.Unlock()
|
||||
|
||||
u.client().Logout()
|
||||
|
||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||
u.log.WithError(err).Warn("Could not log user out from credentials store")
|
||||
|
||||
if err = u.credStorer.Delete(u.userID); err != nil {
|
||||
u.log.WithError(err).Error("Could not delete user from credentials store")
|
||||
}
|
||||
}
|
||||
|
||||
u.refreshFromCredentials()
|
||||
|
||||
// Do not close whole store, just event loop. Some information might be needed offline (e.g. addressID)
|
||||
u.closeEventLoop()
|
||||
|
||||
u.closeAllConnections()
|
||||
|
||||
runtime.GC()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *User) refreshFromCredentials() {
|
||||
if credentials, err := u.credStorer.Get(u.userID); err != nil {
|
||||
log.WithError(err).Error("Cannot refresh user credentials")
|
||||
} else {
|
||||
u.creds = credentials
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) closeEventLoop() {
|
||||
if u.store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.store.CloseEventLoop()
|
||||
}
|
||||
|
||||
// closeAllConnections calls CloseConnection for all users addresses.
|
||||
func (u *User) closeAllConnections() {
|
||||
for _, address := range u.creds.EmailList() {
|
||||
u.CloseConnection(address)
|
||||
}
|
||||
|
||||
if u.store != nil {
|
||||
u.store.SetIMAPUpdateChannel(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseConnection emits closeConnection event on `address` which should close all active connection.
|
||||
func (u *User) CloseConnection(address string) {
|
||||
u.listener.Emit(events.CloseConnectionEvent, address)
|
||||
}
|
||||
|
||||
func (u *User) GetStore() *store.Store {
|
||||
return u.store
|
||||
}
|
||||
234
internal/users/user_credentials_test.go
Normal file
234
internal/users/user_credentials_test.go
Normal file
@ -0,0 +1,234 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
|
||||
|
||||
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
|
||||
)
|
||||
|
||||
assert.NoError(t, user.UpdateUser())
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
func TestUserSwitchAddressMode(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
assert.True(t, user.store.IsCombinedMode())
|
||||
assert.True(t, user.creds.IsCombinedAddressMode)
|
||||
assert.True(t, user.IsCombinedAddressMode())
|
||||
waitForEvents()
|
||||
|
||||
gomock.InOrder(
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil),
|
||||
)
|
||||
|
||||
assert.NoError(t, user.SwitchAddressMode())
|
||||
assert.False(t, user.store.IsCombinedMode())
|
||||
assert.False(t, user.creds.IsCombinedAddressMode)
|
||||
assert.False(t, user.IsCombinedAddressMode())
|
||||
|
||||
gomock.InOrder(
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
|
||||
assert.NoError(t, user.SwitchAddressMode())
|
||||
assert.True(t, user.store.IsCombinedMode())
|
||||
assert.True(t, user.creds.IsCombinedAddressMode)
|
||||
assert.True(t, user.IsCombinedAddressMode())
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
func TestLogoutUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := user.Logout()
|
||||
|
||||
waitForEvents()
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLogoutUserFailsLogout(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := user.Logout()
|
||||
waitForEvents()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginOK(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
|
||||
)
|
||||
|
||||
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
|
||||
|
||||
isApplicationOutdated = true
|
||||
|
||||
err := user.CheckBridgeLogin("any-pass")
|
||||
waitForEvents()
|
||||
assert.Equal(t, pmapi.ErrUpgradeApplication, err)
|
||||
|
||||
isApplicationOutdated = false
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
|
||||
user, err := newUser(
|
||||
m.PanicHandler, "user",
|
||||
m.eventListener, m.credentialsStore,
|
||||
m.clientManager, m.storeCache, "/tmp",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
err = user.init(nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
|
||||
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
|
||||
waitForEvents()
|
||||
assert.Equal(t, ErrLoggedOutUser, err)
|
||||
}
|
||||
|
||||
func TestCheckBridgeLoginBadPassword(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUser(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
|
||||
)
|
||||
|
||||
err := user.CheckBridgeLogin("wrong!")
|
||||
waitForEvents()
|
||||
assert.Equal(t, "backend/credentials: incorrect password", err.Error())
|
||||
}
|
||||
177
internal/users/user_new_test.go
Normal file
177
internal/users/user_new_test.go
Normal file
@ -0,0 +1,177 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewUserNoCredentialsStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
|
||||
|
||||
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
a.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewUserAppOutdated(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication),
|
||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(testCredentials, m)
|
||||
}
|
||||
|
||||
func TestNewUserNoInternetConnection(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable),
|
||||
m.eventListener.EXPECT().Emit(events.InternetOffEvent, ""),
|
||||
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrAPINotReachable),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(nil, pmapi.ErrAPINotReachable).AnyTimes(),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(testCredentials, m)
|
||||
}
|
||||
|
||||
func TestNewUserAuthRefreshFails(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(testCredentialsDisconnected, m)
|
||||
}
|
||||
|
||||
func TestNewUserUnlockFails(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password")),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(testCredentialsDisconnected, m)
|
||||
}
|
||||
|
||||
func TestNewUserUnlockAddressesFails(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(errors.New("bad password")),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
)
|
||||
|
||||
checkNewUserHasCredentials(testCredentialsDisconnected, m)
|
||||
}
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
mockConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
checkNewUserHasCredentials(testCredentials, m)
|
||||
}
|
||||
|
||||
func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) {
|
||||
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
_ = user.init(nil)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
a.Equal(m.t, creds, user.creds)
|
||||
}
|
||||
|
||||
func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen]
|
||||
a.Fail(t, "not implemented")
|
||||
}
|
||||
101
internal/users/user_test.go
Normal file
101
internal/users/user_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testNewUser sets up a new, authorised user.
|
||||
func testNewUser(m mocks) *User {
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
mockConnectedUser(m)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1),
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
|
||||
)
|
||||
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
err = user.init(nil)
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func testNewUserForLogout(m mocks) *User {
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
mockConnectedUser(m)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1),
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
|
||||
)
|
||||
|
||||
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
err = user.init(nil)
|
||||
assert.NoError(m.t, err)
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func cleanUpUserData(u *User) {
|
||||
_ = u.clearStore()
|
||||
}
|
||||
|
||||
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
|
||||
assert.Fail(t, "not implemented")
|
||||
}
|
||||
|
||||
func TestClearStoreWithStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
require.Nil(t, user.store.Close())
|
||||
user.store = nil
|
||||
assert.Nil(t, user.clearStore())
|
||||
}
|
||||
|
||||
func TestClearStoreWithoutStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
user := testNewUserForLogout(m)
|
||||
defer cleanUpUserData(user)
|
||||
|
||||
assert.NotNil(t, user.store)
|
||||
assert.Nil(t, user.clearStore())
|
||||
}
|
||||
537
internal/users/users.go
Normal file
537
internal/users/users.go
Normal file
@ -0,0 +1,537 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package users provides core business logic providing API over credentials store and PM API.
|
||||
package users
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
||||
"github.com/ProtonMail/proton-bridge/internal/store"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
imapBackend "github.com/emersion/go-imap/backend"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pkg/errors"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
log = logrus.WithField("pkg", "users") //nolint[gochecknoglobals]
|
||||
isApplicationOutdated = false //nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
// Users is a struct handling users.
|
||||
type Users struct {
|
||||
config Configer
|
||||
pref PreferenceProvider
|
||||
panicHandler PanicHandler
|
||||
events listener.Listener
|
||||
clientManager ClientManager
|
||||
credStorer CredentialsStorer
|
||||
storeCache *store.Cache
|
||||
|
||||
// users is a list of accounts that have been added to the app.
|
||||
// They are stored sorted in the credentials store in the order
|
||||
// that they were added to the app chronologically.
|
||||
// People are used to that and so we preserve that ordering here.
|
||||
users []*User
|
||||
|
||||
// idleUpdates is a channel which the imap backend listens to and which it uses
|
||||
// to send idle updates to the mail client (eg thunderbird).
|
||||
// The user stores should send idle updates on this channel.
|
||||
idleUpdates chan imapBackend.Update
|
||||
|
||||
lock sync.RWMutex
|
||||
|
||||
// stopAll can be closed to stop all goroutines from looping (watchAppOutdated, watchAPIAuths, heartbeat etc).
|
||||
stopAll chan struct{}
|
||||
}
|
||||
|
||||
func New(
|
||||
config Configer,
|
||||
pref PreferenceProvider,
|
||||
panicHandler PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
clientManager ClientManager,
|
||||
credStorer CredentialsStorer,
|
||||
) *Users {
|
||||
log.Trace("Creating new users")
|
||||
|
||||
u := &Users{
|
||||
config: config,
|
||||
pref: pref,
|
||||
panicHandler: panicHandler,
|
||||
events: eventListener,
|
||||
clientManager: clientManager,
|
||||
credStorer: credStorer,
|
||||
storeCache: store.NewCache(config.GetIMAPCachePath()),
|
||||
idleUpdates: make(chan imapBackend.Update),
|
||||
lock: sync.RWMutex{},
|
||||
stopAll: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Allow DoH before starting the app if the user has previously set this setting.
|
||||
// This allows us to start even if protonmail is blocked.
|
||||
if pref.GetBool(preferences.AllowProxyKey) {
|
||||
u.AllowProxy()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAppOutdated()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
u.watchAPIAuths()
|
||||
}()
|
||||
|
||||
go u.heartbeat()
|
||||
|
||||
if u.credStorer == nil {
|
||||
log.Error("No credentials store is available")
|
||||
} else if err := u.loadUsersFromCredentialsStore(); err != nil {
|
||||
log.WithError(err).Error("Could not load all users from credentials store")
|
||||
}
|
||||
|
||||
if pref.GetBool(preferences.FirstStartKey) {
|
||||
u.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(config.GetVersion())))
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// heartbeat sends a heartbeat signal once a day.
|
||||
func (u *Users) heartbeat() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
next, err := strconv.ParseInt(u.pref.Get(preferences.NextHeartbeatKey), 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
nextTime := time.Unix(next, 0)
|
||||
if time.Now().After(nextTime) {
|
||||
u.SendMetric(metrics.New(metrics.Heartbeat, metrics.Daily, metrics.NoLabel))
|
||||
nextTime = nextTime.Add(24 * time.Hour)
|
||||
u.pref.Set(preferences.NextHeartbeatKey, strconv.FormatInt(nextTime.Unix(), 10))
|
||||
}
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Users) loadUsersFromCredentialsStore() (err error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
userIDs, err := u.credStorer.List()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
l := log.WithField("user", userID)
|
||||
|
||||
user, newUserErr := newUser(u.panicHandler, userID, u.events, u.credStorer, u.clientManager, u.storeCache, u.config.GetDBDir())
|
||||
if newUserErr != nil {
|
||||
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
u.users = append(u.users, user)
|
||||
|
||||
if initUserErr := user.init(u.idleUpdates); initUserErr != nil {
|
||||
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user")
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *Users) watchAppOutdated() {
|
||||
ch := make(chan string)
|
||||
|
||||
u.events.Add(events.UpgradeApplicationEvent, ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
isApplicationOutdated = true
|
||||
u.closeAllConnections()
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
|
||||
func (u *Users) watchAPIAuths() {
|
||||
for {
|
||||
select {
|
||||
case auth := <-u.clientManager.GetAuthUpdateChannel():
|
||||
log.Debug("Users received auth from ClientManager")
|
||||
|
||||
user, ok := u.hasUser(auth.UserID)
|
||||
if !ok {
|
||||
log.WithField("userID", auth.UserID).Info("User not available for auth update")
|
||||
continue
|
||||
}
|
||||
|
||||
if auth.Auth != nil {
|
||||
user.updateAuthToken(auth.Auth)
|
||||
} else if err := user.logout(); err != nil {
|
||||
log.WithError(err).
|
||||
WithField("userID", auth.UserID).
|
||||
Error("User logout failed while watching API auths")
|
||||
}
|
||||
|
||||
case <-u.stopAll:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Users) closeAllConnections() {
|
||||
for _, user := range u.users {
|
||||
user.closeAllConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// Login authenticates a user.
|
||||
// The login flow:
|
||||
// * Authenticate user:
|
||||
// client, auth, err := users.Authenticate(username, password)
|
||||
//
|
||||
// * In case user `auth.HasTwoFactor()`, ask for it and fully authenticate the user.
|
||||
// auth2FA, err := client.Auth2FA(twoFactorCode)
|
||||
//
|
||||
// * In case user `auth.HasMailboxPassword()`, ask for it, otherwise use `password`
|
||||
// and then finish the login procedure.
|
||||
// user, err := users.FinishLogin(client, auth, mailboxPassword)
|
||||
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
|
||||
u.crashBandicoot(username)
|
||||
|
||||
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet.
|
||||
authClient = u.clientManager.GetAnonymousClient()
|
||||
|
||||
authInfo, err := authClient.AuthInfo(username)
|
||||
if err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
|
||||
return
|
||||
}
|
||||
|
||||
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
|
||||
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FinishLogin finishes the login procedure and adds the user into the credentials store.
|
||||
// See `Login` for more details of the login flow.
|
||||
func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen]
|
||||
defer func() {
|
||||
if err == pmapi.ErrUpgradeApplication {
|
||||
u.events.Emit(events.UpgradeApplicationEvent, "")
|
||||
}
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("Login not finished; removing auth session")
|
||||
if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil {
|
||||
log.WithError(delAuthErr).Error("Failed to clear login session after unlock")
|
||||
}
|
||||
}
|
||||
// The anonymous client will be removed from list and authentication will not be deleted.
|
||||
authClient.Logout()
|
||||
}()
|
||||
|
||||
apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to get API user")
|
||||
return
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if user, ok = u.hasUser(apiUser.ID); ok {
|
||||
if err = u.connectExistingUser(user, auth, hashedPassword); err != nil {
|
||||
log.WithError(err).Error("Failed to connect existing user")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err = u.addNewUser(apiUser, auth, hashedPassword); err != nil {
|
||||
log.WithError(err).Error("Failed to add new user")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
|
||||
|
||||
return u.GetUser(apiUser.ID)
|
||||
}
|
||||
|
||||
// connectExistingUser connects an existing user.
|
||||
func (u *Users) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassword string) (err error) {
|
||||
if user.IsConnected() {
|
||||
return errors.New("user is already connected")
|
||||
}
|
||||
|
||||
// Update the user's password in the cred store in case they changed it.
|
||||
if err = u.credStorer.UpdatePassword(user.ID(), hashedPassword); err != nil {
|
||||
return errors.Wrap(err, "failed to update password of user in credentials store")
|
||||
}
|
||||
|
||||
client := u.clientManager.GetClient(user.ID())
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh auth token of new client")
|
||||
}
|
||||
|
||||
if err = u.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to update token of user in credentials store")
|
||||
}
|
||||
|
||||
if err = user.init(u.idleUpdates); err != nil {
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// addNewUser adds a new user.
|
||||
func (u *Users) addNewUser(apiUser *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
client := u.clientManager.GetClient(apiUser.ID)
|
||||
|
||||
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
|
||||
return errors.Wrap(err, "failed to refresh token in new client")
|
||||
}
|
||||
|
||||
if apiUser, err = client.CurrentUser(); err != nil {
|
||||
return errors.Wrap(err, "failed to update API user")
|
||||
}
|
||||
|
||||
activeEmails := client.Addresses().ActiveEmails()
|
||||
|
||||
if _, err = u.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), hashedPassword, activeEmails); err != nil {
|
||||
return errors.Wrap(err, "failed to add user to credentials store")
|
||||
}
|
||||
|
||||
user, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.clientManager, u.storeCache, u.config.GetDBDir())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create user")
|
||||
}
|
||||
|
||||
// The user needs to be part of the users list in order for it to receive an auth during initialisation.
|
||||
u.users = append(u.users, user)
|
||||
|
||||
if err = user.init(u.idleUpdates); err != nil {
|
||||
u.users = u.users[:len(u.users)-1]
|
||||
return errors.Wrap(err, "failed to initialise user")
|
||||
}
|
||||
|
||||
u.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func getAPIUser(client pmapi.Client, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) {
|
||||
hashedPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not hash mailbox password")
|
||||
return
|
||||
}
|
||||
|
||||
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
|
||||
if _, err = client.Unlock(hashedPassword); err != nil {
|
||||
log.WithError(err).Error("Wrong mailbox password")
|
||||
return
|
||||
}
|
||||
|
||||
if user, err = client.CurrentUser(); err != nil {
|
||||
log.WithError(err).Error("Could not load API user")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetUsers returns all added users into keychain (even logged out users).
|
||||
func (u *Users) GetUsers() []*User {
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
return u.users
|
||||
}
|
||||
|
||||
// GetUser returns a user by `query` which is compared to users' ID, username or any attached e-mail address.
|
||||
func (u *Users) GetUser(query string) (*User, error) {
|
||||
u.crashBandicoot(query)
|
||||
|
||||
u.lock.RLock()
|
||||
defer u.lock.RUnlock()
|
||||
|
||||
for _, user := range u.users {
|
||||
if strings.EqualFold(user.ID(), query) || strings.EqualFold(user.Username(), query) {
|
||||
return user, nil
|
||||
}
|
||||
for _, address := range user.GetAddresses() {
|
||||
if strings.EqualFold(address, query) {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("user " + query + " not found")
|
||||
}
|
||||
|
||||
// ClearData closes all connections (to release db files and so on) and clears all data.
|
||||
func (u *Users) ClearData() error {
|
||||
var result *multierror.Error
|
||||
for _, user := range u.users {
|
||||
if err := user.Logout(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := user.closeStore(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if err := u.config.ClearData(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// DeleteUser deletes user completely; it logs user out from the API, stops any
|
||||
// active connection, deletes from credentials store and removes from the Bridge struct.
|
||||
func (u *Users) DeleteUser(userID string, clearStore bool) error {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
log := log.WithField("user", userID)
|
||||
|
||||
for idx, user := range u.users {
|
||||
if user.ID() == userID {
|
||||
if err := user.Logout(); err != nil {
|
||||
log.WithError(err).Error("Cannot logout user")
|
||||
// We can try to continue to remove the user.
|
||||
// Token will still be valid, but will expire eventually.
|
||||
}
|
||||
|
||||
if err := user.closeStore(); err != nil {
|
||||
log.WithError(err).Error("Failed to close user store")
|
||||
}
|
||||
if clearStore {
|
||||
// Clear cache after closing connections (done in logout).
|
||||
if err := user.clearStore(); err != nil {
|
||||
log.WithError(err).Error("Failed to clear user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := u.credStorer.Delete(userID); err != nil {
|
||||
log.WithError(err).Error("Cannot remove user")
|
||||
return err
|
||||
}
|
||||
u.users = append(u.users[:idx], u.users[idx+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("user " + userID + " not found")
|
||||
}
|
||||
|
||||
// SendMetric sends a metric. We don't want to return any errors, only log them.
|
||||
func (u *Users) SendMetric(m metrics.Metric) {
|
||||
c := u.clientManager.GetAnonymousClient()
|
||||
defer c.Logout()
|
||||
|
||||
cat, act, lab := m.Get()
|
||||
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
|
||||
log.Error("Sending metric failed: ", err)
|
||||
}
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"cat": cat,
|
||||
"act": act,
|
||||
"lab": lab,
|
||||
}).Debug("Metric successfully sent")
|
||||
}
|
||||
|
||||
// GetIMAPUpdatesChannel sets the channel on which idle events should be sent.
|
||||
func (u *Users) GetIMAPUpdatesChannel() chan imapBackend.Update {
|
||||
if u.idleUpdates == nil {
|
||||
log.Warn("IMAP updates channel is nil")
|
||||
}
|
||||
|
||||
return u.idleUpdates
|
||||
}
|
||||
|
||||
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) AllowProxy() {
|
||||
u.clientManager.AllowProxy()
|
||||
}
|
||||
|
||||
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (u *Users) DisallowProxy() {
|
||||
u.clientManager.DisallowProxy()
|
||||
}
|
||||
|
||||
// CheckConnection returns whether there is an internet connection.
|
||||
// This should use the connection manager when it is eventually implemented.
|
||||
func (u *Users) CheckConnection() error {
|
||||
return u.clientManager.CheckConnection()
|
||||
}
|
||||
|
||||
// StopWatchers stops all goroutines.
|
||||
func (u *Users) StopWatchers() {
|
||||
close(u.stopAll)
|
||||
}
|
||||
|
||||
// hasUser returns whether the struct currently has a user with ID `id`.
|
||||
func (u *Users) hasUser(id string) (user *User, ok bool) {
|
||||
for _, u := range u.users {
|
||||
if u.ID() == id {
|
||||
user, ok = u, true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// "Easter egg" for testing purposes.
|
||||
func (u *Users) crashBandicoot(username string) {
|
||||
if username == "crash@bandicoot" {
|
||||
panic("Your wish is my command… I crash!")
|
||||
}
|
||||
}
|
||||
143
internal/users/users_actions_test.go
Normal file
143
internal/users/users_actions_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetNoUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
|
||||
}
|
||||
|
||||
func TestGetUserByID(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "user", 0, "")
|
||||
checkUsersGetUser(t, m, "users", 1, "")
|
||||
}
|
||||
|
||||
func TestGetUserByName(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "username", 0, "")
|
||||
checkUsersGetUser(t, m, "usersname", 1, "")
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
checkUsersGetUser(t, m, "user@pm.me", 0, "")
|
||||
checkUsersGetUser(t, m, "users@pm.me", 1, "")
|
||||
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
|
||||
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := users.DeleteUser("user", true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(users.users))
|
||||
}
|
||||
|
||||
// Even when logout fails, delete is done.
|
||||
func TestDeleteUserWithFailingLogout(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Logout().Return(),
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")),
|
||||
m.credentialsStore.EXPECT().Delete("user").Return(nil),
|
||||
)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
err := users.DeleteUser("user", true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(users.users))
|
||||
}
|
||||
|
||||
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
user, err := users.GetUser(query)
|
||||
waitForEvents()
|
||||
|
||||
if expectedError != "" {
|
||||
assert.Equal(m.t, expectedError, err.Error())
|
||||
} else {
|
||||
assert.NoError(m.t, err)
|
||||
}
|
||||
|
||||
var expectedUser *User
|
||||
if index >= 0 {
|
||||
expectedUser = users.users[index]
|
||||
}
|
||||
|
||||
assert.Equal(m.t, expectedUser, user)
|
||||
}
|
||||
239
internal/users/users_login_test.go
Normal file
239
internal/users/users_login_test.go
Normal file
@ -0,0 +1,239 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
err := errors.New("bad password")
|
||||
gomock.InOrder(
|
||||
// Init users with no user from keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
||||
|
||||
// Set up mocks for FinishLogin.
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err),
|
||||
m.pmapiClient.EXPECT().DeleteAuth(),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginUpgradeApplication(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
err := errors.New("Cannot logout when upgrade needed")
|
||||
gomock.InOrder(
|
||||
// Init users with no user from keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
||||
|
||||
// Set up mocks for FinishLogin.
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, pmapi.ErrUpgradeApplication),
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""),
|
||||
m.pmapiClient.EXPECT().DeleteAuth().Return(err),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", pmapi.ErrUpgradeApplication)
|
||||
}
|
||||
|
||||
func refreshWithToken(token string) *pmapi.Auth {
|
||||
return &pmapi.Auth{
|
||||
RefreshToken: token,
|
||||
KeySalt: "", // No salting in tests.
|
||||
}
|
||||
}
|
||||
|
||||
func credentialsWithToken(token string) *credentials.Credentials {
|
||||
tmp := &credentials.Credentials{}
|
||||
*tmp = *testCredentials
|
||||
tmp.APIToken = token
|
||||
return tmp
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginNewUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Basically every call client has get client manager
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
// users.New() finds no users in keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
||||
|
||||
// getAPIUser() loads user info from API (e.g. userID).
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
|
||||
// addNewUser()
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
||||
|
||||
// user.init() in addNewUser
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Emit event for new user and send metrics.
|
||||
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
|
||||
// Reload account list in GUI.
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
||||
|
||||
// defer logout anonymous
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
|
||||
|
||||
mockAuthUpdate(user, "afterCredentials", m)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
loggedOutCreds := *testCredentials
|
||||
loggedOutCreds.APIToken = ""
|
||||
loggedOutCreds.MailboxPassword = ""
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
// users.New() finds one existing user in keychain.
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
||||
|
||||
// newUser()
|
||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
||||
|
||||
// user.init()
|
||||
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
|
||||
// getAPIUser() loads user info from API (e.g. userID).
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
|
||||
// connectExistingUser()
|
||||
m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil),
|
||||
|
||||
// user.init() in connectExistingUser
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// store.New() in user.init
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Reload account list in GUI.
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
|
||||
|
||||
// defer logout anonymous
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
|
||||
|
||||
mockAuthUpdate(user, "afterCredentials", m)
|
||||
}
|
||||
|
||||
func TestUsersFinishLoginConnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
|
||||
mockConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
// Then, try to log in again...
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
|
||||
m.pmapiClient.EXPECT().DeleteAuth(),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
_, err := users.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword)
|
||||
assert.Equal(t, "user is already connected", err.Error())
|
||||
}
|
||||
|
||||
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User {
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
assert.Equal(t, expectedErr, err)
|
||||
|
||||
if expectedUserID != "" {
|
||||
assert.Equal(t, expectedUserID, user.ID())
|
||||
assert.Equal(t, 1, len(users.users))
|
||||
assert.Equal(t, expectedUserID, users.users[0].ID())
|
||||
} else {
|
||||
assert.Equal(t, (*User)(nil), user)
|
||||
assert.Equal(t, 0, len(users.users))
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
188
internal/users/users_new_test.go
Normal file
188
internal/users/users_new_test.go
Normal file
@ -0,0 +1,188 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewUsersNoKeychain(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain"))
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{})
|
||||
}
|
||||
|
||||
func TestNewUsersWithoutUsersInCredentialsStore(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{})
|
||||
}
|
||||
|
||||
func TestNewUsersWithDisconnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
// Basically every call client has get client manager.
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
|
||||
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
|
||||
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
|
||||
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
|
||||
}
|
||||
|
||||
func mockConnectedUser(m mocks) {
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
|
||||
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
|
||||
// Set up mocks for store initialisation for the authorized user.
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
)
|
||||
}
|
||||
|
||||
// mockAuthUpdate simulates users calling UpdateAuthToken on the given user.
|
||||
// This would normally be done by users when it receives an auth from the ClientManager,
|
||||
// but as we don't have a full users instance here, we do this manually.
|
||||
func mockAuthUpdate(user *User, token string, m mocks) {
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().UpdateToken("user", ":"+token).Return(nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(token), nil),
|
||||
)
|
||||
|
||||
user.updateAuthToken(refreshWithToken(token))
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
func TestNewUsersWithConnectedUser(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
|
||||
|
||||
mockConnectedUser(m)
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
|
||||
}
|
||||
|
||||
// Tests two users with different states and checks also the order from
|
||||
// credentials store is kept also in array of users.
|
||||
func TestNewUsersWithUsers(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
|
||||
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
|
||||
// Set up mocks for store initialisation for the unauth user.
|
||||
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
|
||||
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(nil),
|
||||
)
|
||||
|
||||
mockConnectedUser(m)
|
||||
|
||||
mockEventLoopNoAction(m)
|
||||
|
||||
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
|
||||
}
|
||||
|
||||
func TestNewUsersFirstStart(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
|
||||
m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(true),
|
||||
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
|
||||
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.FirstStart), gomock.Any()),
|
||||
m.pmapiClient.EXPECT().Logout(),
|
||||
)
|
||||
|
||||
testNewUsers(t, m)
|
||||
}
|
||||
|
||||
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
|
||||
users := testNewUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
assert.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
|
||||
|
||||
credentials := []*credentials.Credentials{}
|
||||
for _, user := range users.users {
|
||||
credentials = append(credentials, user.creds)
|
||||
}
|
||||
|
||||
assert.Equal(m.t, expectedCredentials, credentials)
|
||||
}
|
||||
303
internal/users/users_test.go
Normal file
303
internal/users/users_test.go
Normal file
@ -0,0 +1,303 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
||||
"github.com/ProtonMail/proton-bridge/internal/store"
|
||||
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
|
||||
usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if os.Getenv("VERBOSITY") == "fatal" {
|
||||
logrus.SetLevel(logrus.FatalLevel)
|
||||
}
|
||||
if os.Getenv("VERBOSITY") == "trace" {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
var (
|
||||
testAuth = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||
RefreshToken: "tok",
|
||||
KeySalt: "", // No salting in tests.
|
||||
}
|
||||
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
|
||||
RefreshToken: "reftok",
|
||||
KeySalt: "", // No salting in tests.
|
||||
}
|
||||
|
||||
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
Name: "username",
|
||||
Emails: "user@pm.me",
|
||||
APIToken: "token",
|
||||
MailboxPassword: "pass",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
Timestamp: 123456789,
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: true,
|
||||
}
|
||||
testCredentialsSplit = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "users",
|
||||
Name: "usersname",
|
||||
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
|
||||
APIToken: "token",
|
||||
MailboxPassword: "pass",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
Timestamp: 123456789,
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: false,
|
||||
}
|
||||
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
|
||||
UserID: "user",
|
||||
Name: "username",
|
||||
Emails: "user@pm.me",
|
||||
APIToken: "",
|
||||
MailboxPassword: "",
|
||||
BridgePassword: "0123456789abcdef",
|
||||
Version: "v1",
|
||||
Timestamp: 123456789,
|
||||
IsHidden: false,
|
||||
IsCombinedAddressMode: true,
|
||||
}
|
||||
|
||||
testPMAPIUser = &pmapi.User{ //nolint[gochecknoglobals]
|
||||
ID: "user",
|
||||
Name: "username",
|
||||
}
|
||||
|
||||
testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
|
||||
ID: "testAddressID",
|
||||
Type: pmapi.OriginalAddress,
|
||||
Email: "user@pm.me",
|
||||
Receive: pmapi.CanReceive,
|
||||
}
|
||||
|
||||
testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
|
||||
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: pmapi.CanReceive, Type: pmapi.OriginalAddress},
|
||||
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
|
||||
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
|
||||
}
|
||||
|
||||
testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals]
|
||||
EventID: "ACXDmTaBub14w==",
|
||||
}
|
||||
)
|
||||
|
||||
func waitForEvents() {
|
||||
// Wait for goroutine to add listener.
|
||||
// E.g. calling login to invoke firstsync event. Functions can end sooner than
|
||||
// goroutines call the listener mock. We need to wait a little bit before the end of
|
||||
// the test to capture all event calls. This allows us to detect whether there were
|
||||
// missing calls, or perhaps whether something was called too many times.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
type mocks struct {
|
||||
t *testing.T
|
||||
|
||||
ctrl *gomock.Controller
|
||||
config *usersmocks.MockConfiger
|
||||
PanicHandler *usersmocks.MockPanicHandler
|
||||
prefProvider *usersmocks.MockPreferenceProvider
|
||||
clientManager *usersmocks.MockClientManager
|
||||
credentialsStore *usersmocks.MockCredentialsStorer
|
||||
eventListener *MockListener
|
||||
|
||||
pmapiClient *pmapimocks.MockClient
|
||||
|
||||
storeCache *store.Cache
|
||||
}
|
||||
|
||||
type fullStackReporter struct {
|
||||
T testing.TB
|
||||
}
|
||||
|
||||
func (fr *fullStackReporter) Errorf(format string, args ...interface{}) {
|
||||
fmt.Printf("err: "+format+"\n", args...)
|
||||
fr.T.Fail()
|
||||
}
|
||||
func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
|
||||
debug.PrintStack()
|
||||
fmt.Printf("fail: "+format+"\n", args...)
|
||||
fr.T.FailNow()
|
||||
}
|
||||
|
||||
func initMocks(t *testing.T) mocks {
|
||||
var mockCtrl *gomock.Controller
|
||||
if os.Getenv("VERBOSITY") == "trace" {
|
||||
mockCtrl = gomock.NewController(&fullStackReporter{t})
|
||||
} else {
|
||||
mockCtrl = gomock.NewController(t)
|
||||
}
|
||||
|
||||
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
|
||||
require.NoError(t, err, "could not get temporary file for store cache")
|
||||
|
||||
m := mocks{
|
||||
t: t,
|
||||
|
||||
ctrl: mockCtrl,
|
||||
config: usersmocks.NewMockConfiger(mockCtrl),
|
||||
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
|
||||
prefProvider: usersmocks.NewMockPreferenceProvider(mockCtrl),
|
||||
clientManager: usersmocks.NewMockClientManager(mockCtrl),
|
||||
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
|
||||
eventListener: NewMockListener(mockCtrl),
|
||||
|
||||
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
|
||||
|
||||
storeCache: store.NewCache(cacheFile.Name()),
|
||||
}
|
||||
|
||||
// Ignore heartbeat calls because they always happen.
|
||||
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Heartbeat), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
m.prefProvider.EXPECT().Get(preferences.NextHeartbeatKey).AnyTimes()
|
||||
m.prefProvider.EXPECT().Set(preferences.NextHeartbeatKey, gomock.Any()).AnyTimes()
|
||||
|
||||
// Called during clean-up.
|
||||
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
|
||||
// Events are asynchronous
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
|
||||
|
||||
gomock.InOrder(
|
||||
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
|
||||
|
||||
// Init for user.
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
// Init for users.
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
|
||||
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
|
||||
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
|
||||
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
|
||||
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
|
||||
)
|
||||
|
||||
users := testNewUsers(t, m)
|
||||
|
||||
user, _ := users.GetUser("user")
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
user, _ = users.GetUser("user")
|
||||
mockAuthUpdate(user, "reftok", m)
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
func testNewUsers(t *testing.T, m mocks) *Users {
|
||||
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
|
||||
require.NoError(t, err, "could not get temporary file for store cache")
|
||||
|
||||
m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(false).AnyTimes()
|
||||
m.prefProvider.EXPECT().GetBool(preferences.AllowProxyKey).Return(false).AnyTimes()
|
||||
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
|
||||
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
|
||||
m.config.EXPECT().GetVersion().Return("ver").AnyTimes()
|
||||
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
|
||||
m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan pmapi.ClientAuth))
|
||||
|
||||
users := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore)
|
||||
|
||||
waitForEvents()
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
func cleanUpUsersData(b *Users) {
|
||||
for _, user := range b.users {
|
||||
_ = user.clearStore()
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearData(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
|
||||
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
|
||||
|
||||
users := testNewUsersWithUsers(t, m)
|
||||
defer cleanUpUsersData(users)
|
||||
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
|
||||
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
|
||||
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("user").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
|
||||
|
||||
m.pmapiClient.EXPECT().Logout()
|
||||
m.credentialsStore.EXPECT().Logout("users").Return(nil)
|
||||
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil)
|
||||
|
||||
m.config.EXPECT().ClearData().Return(nil)
|
||||
|
||||
require.NoError(t, users.ClearData())
|
||||
|
||||
waitForEvents()
|
||||
}
|
||||
|
||||
func mockEventLoopNoAction(m mocks) {
|
||||
// Set up mocks for starting the store's event loop (in store.New).
|
||||
// The event loop runs in another goroutine so this might happen at any time.
|
||||
gomock.InOrder(
|
||||
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil),
|
||||
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil),
|
||||
// Set up mocks for performing the initial store sync.
|
||||
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil),
|
||||
)
|
||||
}
|
||||
23
internal/users/users_test_exports.go
Normal file
23
internal/users/users_test_exports.go
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package users
|
||||
|
||||
// IsAuthorized returns whether the user has received an Auth from the API yet.
|
||||
func (u *User) IsAuthorized() bool {
|
||||
return u.isAuthorized
|
||||
}
|
||||
Reference in New Issue
Block a user