Renamed bridge to general users and keep bridge only for bridge stuff

This commit is contained in:
Michal Horejsek
2020-05-21 14:37:15 +02:00
parent 4e2ab9b389
commit 4d2baa6b85
30 changed files with 720 additions and 661 deletions

View File

@ -0,0 +1,136 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package credentials implements our struct stored in keychain.
// Store struct is kind of like a database client.
// Credentials struct is kind of like one record from the database.
package credentials
import (
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"strings"
"github.com/sirupsen/logrus"
)
const sep = "\x00"
var (
log = logrus.WithField("pkg", "credentials") //nolint[gochecknoglobals]
ErrWrongFormat = errors.New("backend/creds: malformed password")
)
type Credentials struct {
UserID, // Do not marshal; used as a key.
Name,
Emails,
APIToken,
MailboxPassword,
BridgePassword,
Version string
Timestamp int64
IsHidden, // Deprecated.
IsCombinedAddressMode bool
}
func (s *Credentials) Marshal() string {
items := []string{
s.Name, // 0
s.Emails, // 1
s.APIToken, // 2
s.MailboxPassword, // 3
s.BridgePassword, // 4
s.Version, // 5
"", // 6
"", // 7
"", // 8
}
items[6] = fmt.Sprint(s.Timestamp)
if s.IsHidden {
items[7] = "1"
}
if s.IsCombinedAddressMode {
items[8] = "1"
}
str := strings.Join(items, sep)
return base64.StdEncoding.EncodeToString([]byte(str))
}
func (s *Credentials) Unmarshal(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
items := strings.Split(string(b), sep)
if len(items) != 9 {
return ErrWrongFormat
}
s.Name = items[0]
s.Emails = items[1]
s.APIToken = items[2]
s.MailboxPassword = items[3]
s.BridgePassword = items[4]
s.Version = items[5]
if _, err = fmt.Sscan(items[6], &s.Timestamp); err != nil {
s.Timestamp = 0
}
if s.IsHidden = false; items[7] == "1" {
s.IsHidden = true
}
if s.IsCombinedAddressMode = false; items[8] == "1" {
s.IsCombinedAddressMode = true
}
return nil
}
func (s *Credentials) SetEmailList(list []string) {
s.Emails = strings.Join(list, ";")
}
func (s *Credentials) EmailList() []string {
return strings.Split(s.Emails, ";")
}
func (s *Credentials) CheckPassword(password string) error {
if subtle.ConstantTimeCompare([]byte(s.BridgePassword), []byte(password)) != 1 {
log.WithFields(logrus.Fields{
"userID": s.UserID,
}).Debug("Incorrect bridge password")
return fmt.Errorf("backend/credentials: incorrect password")
}
return nil
}
func (s *Credentials) Logout() {
s.APIToken = ""
s.MailboxPassword = ""
}
func (s *Credentials) IsConnected() bool {
return s.APIToken != "" && s.MailboxPassword != ""
}

View File

@ -0,0 +1,39 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package credentials
import (
"crypto/rand"
"encoding/base64"
"io"
)
const keySize = 16
// generateKey generates a new random key.
func generateKey() []byte {
key := make([]byte, keySize)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
panic(err)
}
return key
}
func generatePassword() string {
return base64.RawURLEncoding.EncodeToString(generateKey())
}

View File

@ -0,0 +1,330 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package credentials
import (
"errors"
"fmt"
"sort"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/pkg/keychain"
"github.com/sirupsen/logrus"
)
var storeLocker = sync.RWMutex{} //nolint[gochecknoglobals]
// Store is an encrypted credentials store.
type Store struct {
secrets *keychain.Access
}
// NewStore creates a new encrypted credentials store.
func NewStore(appName string) (*Store, error) {
secrets, err := keychain.NewAccess(appName)
return &Store{
secrets: secrets,
}, err
}
func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails []string) (creds *Credentials, err error) {
storeLocker.Lock()
defer storeLocker.Unlock()
log.WithFields(logrus.Fields{
"user": userID,
"username": userName,
"emails": emails,
}).Trace("Adding new credentials")
if err = s.checkKeychain(); err != nil {
return
}
creds = &Credentials{
UserID: userID,
Name: userName,
APIToken: apiToken,
MailboxPassword: mailboxPassword,
IsHidden: false,
}
creds.SetEmailList(emails)
var has bool
if has, err = s.has(userID); err != nil {
log.WithField("userID", userID).WithError(err).Error("Could not check if user credentials already exist")
return
}
if has {
log.Info("Updating credentials of existing user")
currentCredentials, err := s.get(userID)
if err != nil {
return nil, err
}
creds.BridgePassword = currentCredentials.BridgePassword
creds.IsCombinedAddressMode = currentCredentials.IsCombinedAddressMode
creds.Timestamp = currentCredentials.Timestamp
} else {
log.Info("Generating credentials for new user")
creds.BridgePassword = generatePassword()
creds.IsCombinedAddressMode = true
creds.Timestamp = time.Now().Unix()
}
if err = s.saveCredentials(creds); err != nil {
return
}
return creds, err
}
func (s *Store) SwitchAddressMode(userID string) error {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
}
credentials.IsCombinedAddressMode = !credentials.IsCombinedAddressMode
credentials.BridgePassword = generatePassword()
return s.saveCredentials(credentials)
}
func (s *Store) UpdateEmails(userID string, emails []string) error {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
}
credentials.SetEmailList(emails)
return s.saveCredentials(credentials)
}
func (s *Store) UpdatePassword(userID, password string) error {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
}
credentials.MailboxPassword = password
return s.saveCredentials(credentials)
}
func (s *Store) UpdateToken(userID, apiToken string) error {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
}
credentials.APIToken = apiToken
return s.saveCredentials(credentials)
}
func (s *Store) Logout(userID string) error {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
}
credentials.Logout()
return s.saveCredentials(credentials)
}
// List returns a list of usernames that have credentials stored.
func (s *Store) List() (userIDs []string, err error) {
storeLocker.RLock()
defer storeLocker.RUnlock()
log.Trace("Listing credentials in credentials store")
if err = s.checkKeychain(); err != nil {
return
}
var allUserIDs []string
if allUserIDs, err = s.secrets.List(); err != nil {
log.WithError(err).Error("Could not list credentials")
return
}
credentialList := []*Credentials{}
for _, userID := range allUserIDs {
creds, getErr := s.get(userID)
if getErr != nil {
log.WithField("userID", userID).WithError(getErr).Warn("Failed to get credentials")
continue
}
if creds.Timestamp == 0 {
continue
}
credentialList = append(credentialList, creds)
}
sort.Slice(credentialList, func(i, j int) bool {
return credentialList[i].Timestamp < credentialList[j].Timestamp
})
for _, credentials := range credentialList {
userIDs = append(userIDs, credentials.UserID)
}
return userIDs, err
}
func (s *Store) GetAndCheckPassword(userID, password string) (creds *Credentials, err error) {
storeLocker.RLock()
defer storeLocker.RUnlock()
log.WithFields(logrus.Fields{
"userID": userID,
}).Debug("Checking bridge password")
credentials, err := s.Get(userID)
if err != nil {
return nil, err
}
if err := credentials.CheckPassword(password); err != nil {
log.WithFields(logrus.Fields{
"userID": userID,
"err": err,
}).Debug("Incorrect bridge password")
return nil, err
}
return credentials, nil
}
func (s *Store) Get(userID string) (creds *Credentials, err error) {
storeLocker.RLock()
defer storeLocker.RUnlock()
var has bool
if has, err = s.has(userID); err != nil {
log.WithError(err).Error("Could not check for credentials")
return
}
if !has {
err = errors.New("no credentials found for given userID")
return
}
return s.get(userID)
}
func (s *Store) has(userID string) (has bool, err error) {
if err = s.checkKeychain(); err != nil {
return
}
var ids []string
if ids, err = s.secrets.List(); err != nil {
log.WithError(err).Error("Could not list credentials")
return
}
for _, id := range ids {
if id == userID {
has = true
}
}
return
}
func (s *Store) get(userID string) (creds *Credentials, err error) {
log := log.WithField("user", userID)
if err = s.checkKeychain(); err != nil {
return
}
secret, err := s.secrets.Get(userID)
if err != nil {
log.WithError(err).Error("Could not get credentials from native keychain")
return
}
credentials := &Credentials{UserID: userID}
if err = credentials.Unmarshal(secret); err != nil {
err = fmt.Errorf("backend/credentials: malformed secret: %v", err)
_ = s.secrets.Delete(userID)
log.WithError(err).Error("Could not unmarshal secret")
return
}
return credentials, nil
}
// saveCredentials encrypts and saves password to the keychain store.
func (s *Store) saveCredentials(credentials *Credentials) (err error) {
if err = s.checkKeychain(); err != nil {
return
}
credentials.Version = keychain.KeychainVersion
return s.secrets.Put(credentials.UserID, credentials.Marshal())
}
func (s *Store) checkKeychain() (err error) {
if s.secrets == nil {
err = keychain.ErrNoKeychainInstalled
log.WithError(err).Error("Store is unusable")
}
return
}
// Delete removes credentials from the store.
func (s *Store) Delete(userID string) (err error) {
storeLocker.Lock()
defer storeLocker.Unlock()
if err = s.checkKeychain(); err != nil {
return
}
return s.secrets.Delete(userID)
}

View File

@ -0,0 +1,297 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package credentials
import (
"bytes"
"encoding/base64"
"encoding/gob"
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const testSep = "\n"
const secretFormat = "%v" + testSep + // UserID,
"%v" + testSep + // Name,
"%v" + testSep + // Emails,
"%v" + testSep + // APIToken,
"%v" + testSep + // Mailbox,
"%v" + testSep + // BridgePassword,
"%v" + testSep + // Version string
"%v" + testSep + // Timestamp,
"%v" + testSep + // IsHidden,
"%v" // IsCombinedAddressMode
// the best would be to run this test on mac, win, and linux separately
type testCredentials struct {
UserID,
Name,
Emails,
APIToken,
Mailbox,
BridgePassword,
Version string
Timestamp int64
IsHidden,
IsCombinedAddressMode bool
}
func init() { //nolint[gochecknoinits]
gob.Register(testCredentials{})
}
func (s *testCredentials) MarshalGob() string {
buf := bytes.Buffer{}
enc := gob.NewEncoder(&buf)
if err := enc.Encode(s); err != nil {
return ""
}
fmt.Printf("MarshalGob: %#v\n", buf.String())
return base64.StdEncoding.EncodeToString(buf.Bytes())
}
func (s *testCredentials) Clear() {
s.UserID = ""
s.Name = ""
s.Emails = ""
s.APIToken = ""
s.Mailbox = ""
s.BridgePassword = ""
s.Version = ""
s.Timestamp = 0
s.IsHidden = false
s.IsCombinedAddressMode = false
}
func (s *testCredentials) UnmarshalGob(secret string) error {
s.Clear()
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
fmt.Println("decode base64", b)
return err
}
buf := bytes.NewBuffer(b)
dec := gob.NewDecoder(buf)
if err = dec.Decode(s); err != nil {
fmt.Println("decode gob", b, buf.Bytes())
return err
}
return nil
}
func (s *testCredentials) ToJSON() string {
if b, err := json.Marshal(s); err == nil {
fmt.Printf("MarshalJSON: %#v\n", string(b))
return base64.StdEncoding.EncodeToString(b)
}
return ""
}
func (s *testCredentials) FromJSON(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
if err = json.Unmarshal(b, s); err == nil {
return nil
}
return err
}
func (s *testCredentials) MarshalFmt() string {
buf := bytes.Buffer{}
fmt.Fprintf(
&buf, secretFormat,
s.UserID,
s.Name,
s.Emails,
s.APIToken,
s.Mailbox,
s.BridgePassword,
s.Version,
s.Timestamp,
s.IsHidden,
s.IsCombinedAddressMode,
)
fmt.Printf("MarshalFmt: %#v\n", buf.String())
return base64.StdEncoding.EncodeToString(buf.Bytes())
}
func (s *testCredentials) UnmarshalFmt(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
buf := bytes.NewBuffer(b)
fmt.Println("decode fmt", b, buf.Bytes())
_, err = fmt.Fscanf(
buf, secretFormat,
&s.UserID,
&s.Name,
&s.Emails,
&s.APIToken,
&s.Mailbox,
&s.BridgePassword,
&s.Version,
&s.Timestamp,
&s.IsHidden,
&s.IsCombinedAddressMode,
)
if err != nil {
return err
}
return nil
}
func (s *testCredentials) MarshalStrings() string { // this is the most space efficient
items := []string{
s.UserID, // 0
s.Name, // 1
s.Emails, // 2
s.APIToken, // 3
s.Mailbox, // 4
s.BridgePassword, // 5
s.Version, // 6
}
items = append(items, fmt.Sprint(s.Timestamp)) // 7
if s.IsHidden { // 8
items = append(items, "1")
} else {
items = append(items, "")
}
if s.IsCombinedAddressMode { // 9
items = append(items, "1")
} else {
items = append(items, "")
}
str := strings.Join(items, sep)
fmt.Printf("MarshalJoin: %#v\n", str)
return base64.StdEncoding.EncodeToString([]byte(str))
}
func (s *testCredentials) UnmarshalStrings(secret string) error {
b, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return err
}
items := strings.Split(string(b), sep)
if len(items) != 10 {
return ErrWrongFormat
}
s.UserID = items[0]
s.Name = items[1]
s.Emails = items[2]
s.APIToken = items[3]
s.Mailbox = items[4]
s.BridgePassword = items[5]
s.Version = items[6]
if _, err = fmt.Sscanf(items[7], "%d", &s.Timestamp); err != nil {
s.Timestamp = 0
}
if s.IsHidden = false; items[8] == "1" {
s.IsHidden = true
}
if s.IsCombinedAddressMode = false; items[9] == "1" {
s.IsCombinedAddressMode = true
}
return nil
}
func (s *testCredentials) IsSame(rhs *testCredentials) bool {
return s.Name == rhs.Name &&
s.Emails == rhs.Emails &&
s.APIToken == rhs.APIToken &&
s.Mailbox == rhs.Mailbox &&
s.BridgePassword == rhs.BridgePassword &&
s.Version == rhs.Version &&
s.Timestamp == rhs.Timestamp &&
s.IsHidden == rhs.IsHidden &&
s.IsCombinedAddressMode == rhs.IsCombinedAddressMode
}
func TestMarshalFormats(t *testing.T) {
input := testCredentials{UserID: "007", Emails: "ja@pm.me;jakub@cu.th", Timestamp: 152469263742, IsHidden: true}
fmt.Printf("input %#v\n", input)
secretStrings := input.MarshalStrings()
fmt.Printf("secretStrings %#v %d\n", secretStrings, len(secretStrings))
secretGob := input.MarshalGob()
fmt.Printf("secretGob %#v %d\n", secretGob, len(secretGob))
secretJSON := input.ToJSON()
fmt.Printf("secretJSON %#v %d\n", secretJSON, len(secretJSON))
secretFmt := input.MarshalFmt()
fmt.Printf("secretFmt %#v %d\n", secretFmt, len(secretFmt))
output := testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalStrings(secretStrings))
fmt.Printf("strings out %#v \n", output)
require.True(t, input.IsSame(&output), "strings out not same")
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalGob(secretGob))
fmt.Printf("gob out %#v\n \n", output)
assert.Equal(t, input, output)
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.FromJSON(secretJSON))
fmt.Printf("json out %#v \n", output)
require.True(t, input.IsSame(&output), "json out not same")
/*
// Simple Fscanf not working!
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalFmt(secretFmt))
fmt.Printf("fmt out %#v \n", output)
require.True(t, input.IsSame(&output), "fmt out not same")
*/
}
func TestMarshal(t *testing.T) {
input := Credentials{
UserID: "",
Name: "007",
Emails: "ja@pm.me;aj@cus.tom",
APIToken: "sdfdsfsdfsdfsdf",
MailboxPassword: "cdcdcdcd",
BridgePassword: "wew123",
Version: "k11",
Timestamp: 152469263742,
IsHidden: true,
IsCombinedAddressMode: false,
}
fmt.Printf("input %#v\n", input)
secret := input.Marshal()
fmt.Printf("secret %#v %d\n", secret, len(secret))
output := Credentials{APIToken: "refresh"}
require.NoError(t, output.Unmarshal(secret))
fmt.Printf("output %#v\n", output)
assert.Equal(t, input, output)
}

View File

@ -0,0 +1,107 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./listener/listener.go
// Package users is a generated GoMock package.
package users
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(eventName string, limit time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", eventName, limit)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit)
}
// Add mocks base method
func (m *MockListener) Add(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", eventName, channel)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel)
}
// Remove mocks base method
func (m *MockListener) Remove(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", eventName, channel)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel)
}
// Emit mocks base method
func (m *MockListener) Emit(eventName, data string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", eventName, data)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", eventName)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", eventName)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName)
}

View File

@ -0,0 +1,485 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Configer,PreferenceProvider,PanicHandler,ClientManager,CredentialsStorer)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
)
// MockConfiger is a mock of Configer interface
type MockConfiger struct {
ctrl *gomock.Controller
recorder *MockConfigerMockRecorder
}
// MockConfigerMockRecorder is the mock recorder for MockConfiger
type MockConfigerMockRecorder struct {
mock *MockConfiger
}
// NewMockConfiger creates a new mock instance
func NewMockConfiger(ctrl *gomock.Controller) *MockConfiger {
mock := &MockConfiger{ctrl: ctrl}
mock.recorder = &MockConfigerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockConfiger) EXPECT() *MockConfigerMockRecorder {
return m.recorder
}
// ClearData mocks base method
func (m *MockConfiger) ClearData() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClearData")
ret0, _ := ret[0].(error)
return ret0
}
// ClearData indicates an expected call of ClearData
func (mr *MockConfigerMockRecorder) ClearData() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearData", reflect.TypeOf((*MockConfiger)(nil).ClearData))
}
// GetAPIConfig mocks base method
func (m *MockConfiger) GetAPIConfig() *pmapi.ClientConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAPIConfig")
ret0, _ := ret[0].(*pmapi.ClientConfig)
return ret0
}
// GetAPIConfig indicates an expected call of GetAPIConfig
func (mr *MockConfigerMockRecorder) GetAPIConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIConfig", reflect.TypeOf((*MockConfiger)(nil).GetAPIConfig))
}
// GetDBDir mocks base method
func (m *MockConfiger) GetDBDir() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDBDir")
ret0, _ := ret[0].(string)
return ret0
}
// GetDBDir indicates an expected call of GetDBDir
func (mr *MockConfigerMockRecorder) GetDBDir() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDBDir", reflect.TypeOf((*MockConfiger)(nil).GetDBDir))
}
// GetIMAPCachePath mocks base method
func (m *MockConfiger) GetIMAPCachePath() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetIMAPCachePath")
ret0, _ := ret[0].(string)
return ret0
}
// GetIMAPCachePath indicates an expected call of GetIMAPCachePath
func (mr *MockConfigerMockRecorder) GetIMAPCachePath() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIMAPCachePath", reflect.TypeOf((*MockConfiger)(nil).GetIMAPCachePath))
}
// GetVersion mocks base method
func (m *MockConfiger) GetVersion() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetVersion")
ret0, _ := ret[0].(string)
return ret0
}
// GetVersion indicates an expected call of GetVersion
func (mr *MockConfigerMockRecorder) GetVersion() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockConfiger)(nil).GetVersion))
}
// MockPreferenceProvider is a mock of PreferenceProvider interface
type MockPreferenceProvider struct {
ctrl *gomock.Controller
recorder *MockPreferenceProviderMockRecorder
}
// MockPreferenceProviderMockRecorder is the mock recorder for MockPreferenceProvider
type MockPreferenceProviderMockRecorder struct {
mock *MockPreferenceProvider
}
// NewMockPreferenceProvider creates a new mock instance
func NewMockPreferenceProvider(ctrl *gomock.Controller) *MockPreferenceProvider {
mock := &MockPreferenceProvider{ctrl: ctrl}
mock.recorder = &MockPreferenceProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockPreferenceProvider) EXPECT() *MockPreferenceProviderMockRecorder {
return m.recorder
}
// Get mocks base method
func (m *MockPreferenceProvider) Get(arg0 string) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(string)
return ret0
}
// Get indicates an expected call of Get
func (mr *MockPreferenceProviderMockRecorder) Get(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPreferenceProvider)(nil).Get), arg0)
}
// GetBool mocks base method
func (m *MockPreferenceProvider) GetBool(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBool", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// GetBool indicates an expected call of GetBool
func (mr *MockPreferenceProviderMockRecorder) GetBool(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBool", reflect.TypeOf((*MockPreferenceProvider)(nil).GetBool), arg0)
}
// GetInt mocks base method
func (m *MockPreferenceProvider) GetInt(arg0 string) int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetInt", arg0)
ret0, _ := ret[0].(int)
return ret0
}
// GetInt indicates an expected call of GetInt
func (mr *MockPreferenceProviderMockRecorder) GetInt(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInt", reflect.TypeOf((*MockPreferenceProvider)(nil).GetInt), arg0)
}
// Set mocks base method
func (m *MockPreferenceProvider) Set(arg0, arg1 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Set", arg0, arg1)
}
// Set indicates an expected call of Set
func (mr *MockPreferenceProviderMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockPreferenceProvider)(nil).Set), arg0, arg1)
}
// MockPanicHandler is a mock of PanicHandler interface
type MockPanicHandler struct {
ctrl *gomock.Controller
recorder *MockPanicHandlerMockRecorder
}
// MockPanicHandlerMockRecorder is the mock recorder for MockPanicHandler
type MockPanicHandlerMockRecorder struct {
mock *MockPanicHandler
}
// NewMockPanicHandler creates a new mock instance
func NewMockPanicHandler(ctrl *gomock.Controller) *MockPanicHandler {
mock := &MockPanicHandler{ctrl: ctrl}
mock.recorder = &MockPanicHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockPanicHandler) EXPECT() *MockPanicHandlerMockRecorder {
return m.recorder
}
// HandlePanic mocks base method
func (m *MockPanicHandler) HandlePanic() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "HandlePanic")
}
// HandlePanic indicates an expected call of HandlePanic
func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
}
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method
func (m *MockClientManager) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
}
// CheckConnection mocks base method
func (m *MockClientManager) CheckConnection() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckConnection")
ret0, _ := ret[0].(error)
return ret0
}
// CheckConnection indicates an expected call of CheckConnection
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
}
// DisallowProxy mocks base method
func (m *MockClientManager) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
}
// GetAnonymousClient mocks base method
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAnonymousClient")
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetAnonymousClient indicates an expected call of GetAnonymousClient
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
}
// GetAuthUpdateChannel mocks base method
func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthUpdateChannel")
ret0, _ := ret[0].(chan pmapi.ClientAuth)
return ret0
}
// GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel
func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel))
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// SetUserAgent mocks base method
func (m *MockClientManager) SetUserAgent(arg0, arg1, arg2 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetUserAgent", arg0, arg1, arg2)
}
// SetUserAgent indicates an expected call of SetUserAgent
func (mr *MockClientManagerMockRecorder) SetUserAgent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserAgent", reflect.TypeOf((*MockClientManager)(nil).SetUserAgent), arg0, arg1, arg2)
}
// MockCredentialsStorer is a mock of CredentialsStorer interface
type MockCredentialsStorer struct {
ctrl *gomock.Controller
recorder *MockCredentialsStorerMockRecorder
}
// MockCredentialsStorerMockRecorder is the mock recorder for MockCredentialsStorer
type MockCredentialsStorerMockRecorder struct {
mock *MockCredentialsStorer
}
// NewMockCredentialsStorer creates a new mock instance
func NewMockCredentialsStorer(ctrl *gomock.Controller) *MockCredentialsStorer {
mock := &MockCredentialsStorer{ctrl: ctrl}
mock.recorder = &MockCredentialsStorerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCredentialsStorer) EXPECT() *MockCredentialsStorerMockRecorder {
return m.recorder
}
// Add mocks base method
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3 string, arg4 []string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Add indicates an expected call of Add
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4)
}
// Delete mocks base method
func (m *MockCredentialsStorer) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete
func (mr *MockCredentialsStorerMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockCredentialsStorer)(nil).Delete), arg0)
}
// Get mocks base method
func (m *MockCredentialsStorer) Get(arg0 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get
func (mr *MockCredentialsStorerMockRecorder) Get(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCredentialsStorer)(nil).Get), arg0)
}
// List mocks base method
func (m *MockCredentialsStorer) List() ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List
func (mr *MockCredentialsStorerMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockCredentialsStorer)(nil).List))
}
// Logout mocks base method
func (m *MockCredentialsStorer) Logout(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logout", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Logout indicates an expected call of Logout
func (mr *MockCredentialsStorerMockRecorder) Logout(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logout", reflect.TypeOf((*MockCredentialsStorer)(nil).Logout), arg0)
}
// SwitchAddressMode mocks base method
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SwitchAddressMode", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SwitchAddressMode indicates an expected call of SwitchAddressMode
func (mr *MockCredentialsStorerMockRecorder) SwitchAddressMode(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SwitchAddressMode", reflect.TypeOf((*MockCredentialsStorer)(nil).SwitchAddressMode), arg0)
}
// UpdateEmails mocks base method
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateEmails", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateEmails indicates an expected call of UpdateEmails
func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEmails", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateEmails), arg0, arg1)
}
// UpdatePassword mocks base method
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdatePassword indicates an expected call of UpdatePassword
func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdatePassword), arg0, arg1)
}
// UpdateToken mocks base method
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateToken indicates an expected call of UpdateToken
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1)
}

64
internal/users/types.go Normal file
View File

@ -0,0 +1,64 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type Configer interface {
ClearData() error
GetDBDir() string
GetVersion() string
GetIMAPCachePath() string
GetAPIConfig() *pmapi.ClientConfig
}
type PreferenceProvider interface {
Get(key string) string
GetBool(key string) bool
GetInt(key string) int
Set(key string, value string)
}
type PanicHandler interface {
HandlePanic()
}
type CredentialsStorer interface {
List() (userIDs []string, err error)
Add(userID, userName, apiToken, mailboxPassword string, emails []string) (*credentials.Credentials, error)
Get(userID string) (*credentials.Credentials, error)
SwitchAddressMode(userID string) error
UpdateEmails(userID string, emails []string) error
UpdatePassword(userID, password string) error
UpdateToken(userID, apiToken string) error
Logout(userID string) error
Delete(userID string) error
}
type ClientManager interface {
GetClient(userID string) pmapi.Client
GetAnonymousClient() pmapi.Client
AllowProxy()
DisallowProxy()
GetAuthUpdateChannel() chan pmapi.ClientAuth
CheckConnection() error
SetUserAgent(clientName, clientVersion, os string)
}

588
internal/users/user.go Normal file
View File

@ -0,0 +1,588 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"fmt"
"path/filepath"
"runtime"
"strings"
"sync"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
imapBackend "github.com/emersion/go-imap/backend"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// ErrLoggedOutUser is sent to IMAP and SMTP if user exists, password is OK but user is logged out from the app.
var ErrLoggedOutUser = errors.New("account is logged out, use the app to login again")
// User is a struct on top of API client and credentials store.
type User struct {
log *logrus.Entry
panicHandler PanicHandler
listener listener.Listener
clientManager ClientManager
credStorer CredentialsStorer
imapUpdatesChannel chan imapBackend.Update
store *store.Store
storeCache *store.Cache
storePath string
userID string
creds *credentials.Credentials
lock sync.RWMutex
isAuthorized bool
unlockingKeyringLock sync.Mutex
wasKeyringUnlocked bool
}
// newUser creates a new user.
func newUser(
panicHandler PanicHandler,
userID string,
eventListener listener.Listener,
credStorer CredentialsStorer,
clientManager ClientManager,
storeCache *store.Cache,
storeDir string,
) (u *User, err error) {
log := log.WithField("user", userID)
log.Debug("Creating or loading user")
creds, err := credStorer.Get(userID)
if err != nil {
return nil, errors.Wrap(err, "failed to load user credentials")
}
u = &User{
log: log,
panicHandler: panicHandler,
listener: eventListener,
credStorer: credStorer,
clientManager: clientManager,
storeCache: storeCache,
storePath: getUserStorePath(storeDir, userID),
userID: userID,
creds: creds,
}
return
}
func (u *User) client() pmapi.Client {
return u.clientManager.GetClient(u.userID)
}
// init initialises a user. This includes reloading its credentials from the credentials store
// (such as when logging out and back in, you need to reload the credentials because the new credentials will
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
// something in the store changed).
func (u *User) init(idleUpdates chan imapBackend.Update) (err error) {
u.unlockingKeyringLock.Lock()
u.wasKeyringUnlocked = false
u.unlockingKeyringLock.Unlock()
u.log.Info("Initialising user")
// Reload the user's credentials (if they log out and back in we need the new
// version with the apitoken and mailbox password).
creds, err := u.credStorer.Get(u.userID)
if err != nil {
return errors.Wrap(err, "failed to load user credentials")
}
u.creds = creds
// Try to authorise the user if they aren't already authorised.
// Note: we still allow users to set up accounts if the internet is off.
if authErr := u.authorizeIfNecessary(false); authErr != nil {
switch errors.Cause(authErr) {
case pmapi.ErrAPINotReachable, pmapi.ErrUpgradeApplication, ErrLoggedOutUser:
u.log.WithError(authErr).Warn("Could not authorize user")
default:
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(authErr, "failed to authorize user")
}
}
// Logged-out user keeps store running to access offline data.
// Therefore it is necessary to close it before re-init.
if u.store != nil {
if err := u.store.Close(); err != nil {
log.WithError(err).Error("Not able to close store")
}
u.store = nil
}
store, err := store.New(u.panicHandler, u, u.clientManager, u.listener, u.storePath, u.storeCache)
if err != nil {
return errors.Wrap(err, "failed to create store")
}
u.store = store
// Save the imap updates channel here so it can be set later when imap connects.
u.imapUpdatesChannel = idleUpdates
return err
}
func (u *User) SetIMAPIdleUpdateChannel() {
if u.store == nil {
return
}
u.store.SetIMAPUpdateChannel(u.imapUpdatesChannel)
}
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
// If user is not already connected to the api auth channel (for example there was no internet during start),
// it tries to connect it.
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
// If user is connected and has an auth channel, then perfect, nothing to do here.
if u.creds.IsConnected() && u.isAuthorized {
// The keyring unlock is triggered here to resolve state where apiClient
// is authenticated (we have auth token) but it was not possible to download
// and unlock the keys (internet not reachable).
return u.unlockIfNecessary()
}
if !u.creds.IsConnected() {
err = ErrLoggedOutUser
} else if err = u.authorizeAndUnlock(); err != nil {
u.log.WithError(err).Error("Could not authorize and unlock user")
switch errors.Cause(err) {
case pmapi.ErrUpgradeApplication:
u.listener.Emit(events.UpgradeApplicationEvent, "")
case pmapi.ErrAPINotReachable:
u.listener.Emit(events.InternetOffEvent, "")
default:
if errLogout := u.credStorer.Logout(u.userID); errLogout != nil {
u.log.WithField("err", errLogout).Error("Could not log user out from credentials store")
}
}
}
if emitEvent && err != nil &&
errors.Cause(err) != pmapi.ErrUpgradeApplication &&
errors.Cause(err) != pmapi.ErrAPINotReachable {
u.listener.Emit(events.LogoutEvent, u.userID)
}
return err
}
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
func (u *User) unlockIfNecessary() error {
u.unlockingKeyringLock.Lock()
defer u.unlockingKeyringLock.Unlock()
if u.wasKeyringUnlocked {
return nil
}
if _, err := u.client().Unlock(u.creds.MailboxPassword); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
if err := u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to unlock user addresses")
}
u.wasKeyringUnlocked = true
return nil
}
// authorizeAndUnlock tries to authorize the user with the API using the the user's APIToken.
// If that succeeds, it tries to unlock the user's keys and addresses.
func (u *User) authorizeAndUnlock() (err error) {
if u.creds.APIToken == "" {
u.log.Warn("Could not connect to API auth channel, have no API token")
return nil
}
if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
return errors.Wrap(err, "failed to refresh API auth")
}
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
if err = u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to unlock user addresses")
}
return nil
}
func (u *User) updateAuthToken(auth *pmapi.Auth) {
u.log.Debug("User received auth")
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
return
}
u.refreshFromCredentials()
u.isAuthorized = true
}
// clearStore removes the database.
func (u *User) clearStore() error {
u.log.Trace("Clearing user store")
if u.store != nil {
if err := u.store.Remove(); err != nil {
return errors.Wrap(err, "failed to remove store")
}
} else {
u.log.Warn("Store is not initialized: cleaning up store files manually")
if err := store.RemoveStore(u.storeCache, u.storePath, u.userID); err != nil {
return errors.Wrap(err, "failed to remove store manually")
}
}
return nil
}
// closeStore just closes the store without deleting it.
func (u *User) closeStore() error {
u.log.Trace("Closing user store")
if u.store != nil {
if err := u.store.Close(); err != nil {
return errors.Wrap(err, "failed to close store")
}
}
return nil
}
// getUserStorePath returns the file path of the store database for the given userID.
func getUserStorePath(storeDir string, userID string) (path string) {
fileName := fmt.Sprintf("mailbox-%v.db", userID)
return filepath.Join(storeDir, fileName)
}
// GetTemporaryPMAPIClient returns an authorised PMAPI client.
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
// After proper refactor of SMTP and IMAP remove this method.
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
return u.client()
}
// ID returns the user's userID.
func (u *User) ID() string {
return u.userID
}
// Username returns the user's username as found in the user's credentials.
func (u *User) Username() string {
u.lock.RLock()
defer u.lock.RUnlock()
return u.creds.Name
}
// IsConnected returns whether user is logged in.
func (u *User) IsConnected() bool {
u.lock.RLock()
defer u.lock.RUnlock()
return u.creds.IsConnected()
}
// IsCombinedAddressMode returns whether user is set in combined or split mode.
// Combined mode is the default mode and is what users typically need.
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
// address other than the primary one.
func (u *User) IsCombinedAddressMode() bool {
if u.store != nil {
return u.store.IsCombinedMode()
}
return u.creds.IsCombinedAddressMode
}
// GetPrimaryAddress returns the user's original address (which is
// not necessarily the same as the primary address, because a primary address
// might be an alias and be in position one).
func (u *User) GetPrimaryAddress() string {
u.lock.RLock()
defer u.lock.RUnlock()
return u.creds.EmailList()[0]
}
// GetStoreAddresses returns all addresses used by the store (so in combined mode,
// that's just the original address, but in split mode, that's all active addresses).
func (u *User) GetStoreAddresses() []string {
u.lock.RLock()
defer u.lock.RUnlock()
if u.IsCombinedAddressMode() {
return u.creds.EmailList()[:1]
}
return u.creds.EmailList()
}
// getStoreAddresses returns a user's used addresses (with the original address in first place).
func (u *User) getStoreAddresses() []string { // nolint[unused]
addrInfo, err := u.store.GetAddressInfo()
if err != nil {
u.log.WithError(err).Error("Failed getting address info from store")
return nil
}
addresses := []string{}
for _, addr := range addrInfo {
addresses = append(addresses, addr.Address)
}
if u.IsCombinedAddressMode() {
return addresses[:1]
}
return addresses
}
// GetAddresses returns list of all addresses.
func (u *User) GetAddresses() []string {
u.lock.RLock()
defer u.lock.RUnlock()
return u.creds.EmailList()
}
// GetAddressID returns the API ID of the given address.
func (u *User) GetAddressID(address string) (id string, err error) {
u.lock.RLock()
defer u.lock.RUnlock()
address = strings.ToLower(address)
if u.store == nil {
err = errors.New("store is not initialised")
return
}
return u.store.GetAddressID(address)
}
// GetBridgePassword returns bridge password. This is not a password of the PM
// account, but generated password for local purposes to not use a PM account
// in the clients (such as Thunderbird).
func (u *User) GetBridgePassword() string {
u.lock.RLock()
defer u.lock.RUnlock()
return u.creds.BridgePassword
}
// CheckBridgeLogin checks whether the user is logged in and the bridge
// IMAP/SMTP password is correct.
func (u *User) CheckBridgeLogin(password string) error {
if isApplicationOutdated {
u.listener.Emit(events.UpgradeApplicationEvent, "")
return pmapi.ErrUpgradeApplication
}
u.lock.RLock()
defer u.lock.RUnlock()
// True here because users should be notified by popup of auth failure.
if err := u.authorizeIfNecessary(true); err != nil {
u.log.WithError(err).Error("Failed to authorize user")
return err
}
return u.creds.CheckPassword(password)
}
// UpdateUser updates user details from API and saves to the credentials.
func (u *User) UpdateUser() error {
u.lock.Lock()
defer u.lock.Unlock()
if err := u.authorizeIfNecessary(true); err != nil {
return errors.Wrap(err, "cannot update user")
}
_, err := u.client().UpdateUser()
if err != nil {
return err
}
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
return err
}
if err := u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
return err
}
emails := u.client().Addresses().ActiveEmails()
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil {
return err
}
u.refreshFromCredentials()
return nil
}
// SwitchAddressMode changes mode from combined to split and vice versa. The mode to switch to is determined by the
// state of the user's credentials in the credentials store. See `IsCombinedAddressMode` for more details.
func (u *User) SwitchAddressMode() (err error) {
u.log.Trace("Switching user address mode")
u.lock.Lock()
defer u.lock.Unlock()
u.closeAllConnections()
if u.store == nil {
err = errors.New("store is not initialised")
return
}
newAddressModeState := !u.IsCombinedAddressMode()
if err = u.store.UseCombinedMode(newAddressModeState); err != nil {
u.log.WithError(err).Error("Could not switch store address mode")
return
}
if u.creds.IsCombinedAddressMode != newAddressModeState {
if err = u.credStorer.SwitchAddressMode(u.userID); err != nil {
u.log.WithError(err).Error("Could not switch credentials store address mode")
return
}
}
u.refreshFromCredentials()
return err
}
// logout is the same as Logout, but for internal purposes (logged out from
// the server) which emits LogoutEvent to notify other parts of the app.
func (u *User) logout() error {
u.lock.Lock()
wasConnected := u.creds.IsConnected()
u.lock.Unlock()
err := u.Logout()
if wasConnected {
u.listener.Emit(events.LogoutEvent, u.userID)
u.listener.Emit(events.UserRefreshEvent, u.userID)
}
u.isAuthorized = false
return err
}
// Logout logs out the user from pmapi, the credentials store, the mail store, and tries to remove as much
// sensitive data as possible.
func (u *User) Logout() (err error) {
u.lock.Lock()
defer u.lock.Unlock()
u.log.Debug("Logging out user")
if !u.creds.IsConnected() {
return
}
u.unlockingKeyringLock.Lock()
u.wasKeyringUnlocked = false
u.unlockingKeyringLock.Unlock()
u.client().Logout()
if err = u.credStorer.Logout(u.userID); err != nil {
u.log.WithError(err).Warn("Could not log user out from credentials store")
if err = u.credStorer.Delete(u.userID); err != nil {
u.log.WithError(err).Error("Could not delete user from credentials store")
}
}
u.refreshFromCredentials()
// Do not close whole store, just event loop. Some information might be needed offline (e.g. addressID)
u.closeEventLoop()
u.closeAllConnections()
runtime.GC()
return err
}
func (u *User) refreshFromCredentials() {
if credentials, err := u.credStorer.Get(u.userID); err != nil {
log.WithError(err).Error("Cannot refresh user credentials")
} else {
u.creds = credentials
}
}
func (u *User) closeEventLoop() {
if u.store == nil {
return
}
u.store.CloseEventLoop()
}
// closeAllConnections calls CloseConnection for all users addresses.
func (u *User) closeAllConnections() {
for _, address := range u.creds.EmailList() {
u.CloseConnection(address)
}
if u.store != nil {
u.store.SetIMAPUpdateChannel(nil)
}
}
// CloseConnection emits closeConnection event on `address` which should close all active connection.
func (u *User) CloseConnection(address string) {
u.listener.Emit(events.CloseConnectionEvent, address)
}
func (u *User) GetStore() *store.Store {
return u.store
}

View File

@ -0,0 +1,234 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
)
assert.NoError(t, user.UpdateUser())
waitForEvents()
}
func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil),
)
assert.NoError(t, user.SwitchAddressMode())
assert.False(t, user.store.IsCombinedMode())
assert.False(t, user.creds.IsCombinedAddressMode)
assert.False(t, user.IsCombinedAddressMode())
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
assert.NoError(t, user.SwitchAddressMode())
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
}
func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
}
func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginOK(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
isApplicationOutdated = true
err := user.CheckBridgeLogin("any-pass")
waitForEvents()
assert.Equal(t, pmapi.ErrUpgradeApplication, err)
isApplicationOutdated = false
}
func TestCheckBridgeLoginLoggedOut(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
user, err := newUser(
m.PanicHandler, "user",
m.eventListener, m.credentialsStore,
m.clientManager, m.storeCache, "/tmp",
)
assert.NoError(t, err)
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
err = user.init(nil)
assert.Error(t, err)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
waitForEvents()
assert.Equal(t, ErrLoggedOutUser, err)
}
func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin("wrong!")
waitForEvents()
assert.Equal(t, "backend/credentials: incorrect password", err.Error())
}

View File

@ -0,0 +1,177 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert"
)
func TestNewUserNoCredentialsStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
a.Error(t, err)
}
func TestNewUserAppOutdated(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrUpgradeApplication),
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""),
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrUpgradeApplication),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
checkNewUserHasCredentials(testCredentials, m)
}
func TestNewUserNoInternetConnection(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, pmapi.ErrAPINotReachable),
m.eventListener.EXPECT().Emit(events.InternetOffEvent, ""),
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrAPINotReachable),
m.pmapiClient.EXPECT().Addresses().Return(nil),
m.pmapiClient.EXPECT().GetEvent("").Return(nil, pmapi.ErrAPINotReachable).AnyTimes(),
)
checkNewUserHasCredentials(testCredentials, m)
}
func TestNewUserAuthRefreshFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
}
func TestNewUserUnlockFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, errors.New("bad password")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
}
func TestNewUserUnlockAddressesFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(errors.New("bad password")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
}
func TestNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(testCredentials, m)
}
func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
defer cleanUpUserData(user)
_ = user.init(nil)
waitForEvents()
a.Equal(m.t, creds, user.creds)
}
func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen]
a.Fail(t, "not implemented")
}

101
internal/users/user_test.go Normal file
View File

@ -0,0 +1,101 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err)
err = user.init(nil)
assert.NoError(m.t, err)
mockAuthUpdate(user, "reftok", m)
return user
}
func testNewUserForLogout(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeCache, "/tmp")
assert.NoError(m.t, err)
err = user.init(nil)
assert.NoError(m.t, err)
return user
}
func cleanUpUserData(u *User) {
_ = u.clearStore()
}
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
assert.Fail(t, "not implemented")
}
func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
require.Nil(t, user.store.Close())
user.store = nil
assert.Nil(t, user.clearStore())
}
func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
assert.NotNil(t, user.store)
assert.Nil(t, user.clearStore())
}

537
internal/users/users.go Normal file
View File

@ -0,0 +1,537 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package users provides core business logic providing API over credentials store and PM API.
package users
import (
"strconv"
"strings"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
imapBackend "github.com/emersion/go-imap/backend"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
logrus "github.com/sirupsen/logrus"
)
var (
log = logrus.WithField("pkg", "users") //nolint[gochecknoglobals]
isApplicationOutdated = false //nolint[gochecknoglobals]
)
// Users is a struct handling users.
type Users struct {
config Configer
pref PreferenceProvider
panicHandler PanicHandler
events listener.Listener
clientManager ClientManager
credStorer CredentialsStorer
storeCache *store.Cache
// users is a list of accounts that have been added to the app.
// They are stored sorted in the credentials store in the order
// that they were added to the app chronologically.
// People are used to that and so we preserve that ordering here.
users []*User
// idleUpdates is a channel which the imap backend listens to and which it uses
// to send idle updates to the mail client (eg thunderbird).
// The user stores should send idle updates on this channel.
idleUpdates chan imapBackend.Update
lock sync.RWMutex
// stopAll can be closed to stop all goroutines from looping (watchAppOutdated, watchAPIAuths, heartbeat etc).
stopAll chan struct{}
}
func New(
config Configer,
pref PreferenceProvider,
panicHandler PanicHandler,
eventListener listener.Listener,
clientManager ClientManager,
credStorer CredentialsStorer,
) *Users {
log.Trace("Creating new users")
u := &Users{
config: config,
pref: pref,
panicHandler: panicHandler,
events: eventListener,
clientManager: clientManager,
credStorer: credStorer,
storeCache: store.NewCache(config.GetIMAPCachePath()),
idleUpdates: make(chan imapBackend.Update),
lock: sync.RWMutex{},
stopAll: make(chan struct{}),
}
// Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
if pref.GetBool(preferences.AllowProxyKey) {
u.AllowProxy()
}
go func() {
defer panicHandler.HandlePanic()
u.watchAppOutdated()
}()
go func() {
defer panicHandler.HandlePanic()
u.watchAPIAuths()
}()
go u.heartbeat()
if u.credStorer == nil {
log.Error("No credentials store is available")
} else if err := u.loadUsersFromCredentialsStore(); err != nil {
log.WithError(err).Error("Could not load all users from credentials store")
}
if pref.GetBool(preferences.FirstStartKey) {
u.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(config.GetVersion())))
}
return u
}
// heartbeat sends a heartbeat signal once a day.
func (u *Users) heartbeat() {
ticker := time.NewTicker(1 * time.Minute)
for {
select {
case <-ticker.C:
next, err := strconv.ParseInt(u.pref.Get(preferences.NextHeartbeatKey), 10, 64)
if err != nil {
continue
}
nextTime := time.Unix(next, 0)
if time.Now().After(nextTime) {
u.SendMetric(metrics.New(metrics.Heartbeat, metrics.Daily, metrics.NoLabel))
nextTime = nextTime.Add(24 * time.Hour)
u.pref.Set(preferences.NextHeartbeatKey, strconv.FormatInt(nextTime.Unix(), 10))
}
case <-u.stopAll:
return
}
}
}
func (u *Users) loadUsersFromCredentialsStore() (err error) {
u.lock.Lock()
defer u.lock.Unlock()
userIDs, err := u.credStorer.List()
if err != nil {
return
}
for _, userID := range userIDs {
l := log.WithField("user", userID)
user, newUserErr := newUser(u.panicHandler, userID, u.events, u.credStorer, u.clientManager, u.storeCache, u.config.GetDBDir())
if newUserErr != nil {
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
continue
}
u.users = append(u.users, user)
if initUserErr := user.init(u.idleUpdates); initUserErr != nil {
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user")
}
}
return err
}
func (u *Users) watchAppOutdated() {
ch := make(chan string)
u.events.Add(events.UpgradeApplicationEvent, ch)
for {
select {
case <-ch:
isApplicationOutdated = true
u.closeAllConnections()
case <-u.stopAll:
return
}
}
}
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
func (u *Users) watchAPIAuths() {
for {
select {
case auth := <-u.clientManager.GetAuthUpdateChannel():
log.Debug("Users received auth from ClientManager")
user, ok := u.hasUser(auth.UserID)
if !ok {
log.WithField("userID", auth.UserID).Info("User not available for auth update")
continue
}
if auth.Auth != nil {
user.updateAuthToken(auth.Auth)
} else if err := user.logout(); err != nil {
log.WithError(err).
WithField("userID", auth.UserID).
Error("User logout failed while watching API auths")
}
case <-u.stopAll:
return
}
}
}
func (u *Users) closeAllConnections() {
for _, user := range u.users {
user.closeAllConnections()
}
}
// Login authenticates a user.
// The login flow:
// * Authenticate user:
// client, auth, err := users.Authenticate(username, password)
//
// * In case user `auth.HasTwoFactor()`, ask for it and fully authenticate the user.
// auth2FA, err := client.Auth2FA(twoFactorCode)
//
// * In case user `auth.HasMailboxPassword()`, ask for it, otherwise use `password`
// and then finish the login procedure.
// user, err := users.FinishLogin(client, auth, mailboxPassword)
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
u.crashBandicoot(username)
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet.
authClient = u.clientManager.GetAnonymousClient()
authInfo, err := authClient.AuthInfo(username)
if err != nil {
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
return
}
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
return
}
return
}
// FinishLogin finishes the login procedure and adds the user into the credentials store.
// See `Login` for more details of the login flow.
func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassword string) (user *User, err error) { //nolint[funlen]
defer func() {
if err == pmapi.ErrUpgradeApplication {
u.events.Emit(events.UpgradeApplicationEvent, "")
}
if err != nil {
log.WithError(err).Debug("Login not finished; removing auth session")
if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil {
log.WithError(delAuthErr).Error("Failed to clear login session after unlock")
}
}
// The anonymous client will be removed from list and authentication will not be deleted.
authClient.Logout()
}()
apiUser, hashedPassword, err := getAPIUser(authClient, auth, mbPassword)
if err != nil {
log.WithError(err).Error("Failed to get API user")
return
}
var ok bool
if user, ok = u.hasUser(apiUser.ID); ok {
if err = u.connectExistingUser(user, auth, hashedPassword); err != nil {
log.WithError(err).Error("Failed to connect existing user")
return
}
} else {
if err = u.addNewUser(apiUser, auth, hashedPassword); err != nil {
log.WithError(err).Error("Failed to add new user")
return
}
}
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
return u.GetUser(apiUser.ID)
}
// connectExistingUser connects an existing user.
func (u *Users) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassword string) (err error) {
if user.IsConnected() {
return errors.New("user is already connected")
}
// Update the user's password in the cred store in case they changed it.
if err = u.credStorer.UpdatePassword(user.ID(), hashedPassword); err != nil {
return errors.Wrap(err, "failed to update password of user in credentials store")
}
client := u.clientManager.GetClient(user.ID())
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to refresh auth token of new client")
}
if err = u.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to update token of user in credentials store")
}
if err = user.init(u.idleUpdates); err != nil {
return errors.Wrap(err, "failed to initialise user")
}
return
}
// addNewUser adds a new user.
func (u *Users) addNewUser(apiUser *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) {
u.lock.Lock()
defer u.lock.Unlock()
client := u.clientManager.GetClient(apiUser.ID)
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to refresh token in new client")
}
if apiUser, err = client.CurrentUser(); err != nil {
return errors.Wrap(err, "failed to update API user")
}
activeEmails := client.Addresses().ActiveEmails()
if _, err = u.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), hashedPassword, activeEmails); err != nil {
return errors.Wrap(err, "failed to add user to credentials store")
}
user, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.clientManager, u.storeCache, u.config.GetDBDir())
if err != nil {
return errors.Wrap(err, "failed to create user")
}
// The user needs to be part of the users list in order for it to receive an auth during initialisation.
u.users = append(u.users, user)
if err = user.init(u.idleUpdates); err != nil {
u.users = u.users[:len(u.users)-1]
return errors.Wrap(err, "failed to initialise user")
}
u.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel))
return err
}
func getAPIUser(client pmapi.Client, auth *pmapi.Auth, mbPassword string) (user *pmapi.User, hashedPassword string, err error) {
hashedPassword, err = pmapi.HashMailboxPassword(mbPassword, auth.KeySalt)
if err != nil {
log.WithError(err).Error("Could not hash mailbox password")
return
}
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
if _, err = client.Unlock(hashedPassword); err != nil {
log.WithError(err).Error("Wrong mailbox password")
return
}
if user, err = client.CurrentUser(); err != nil {
log.WithError(err).Error("Could not load API user")
return
}
return
}
// GetUsers returns all added users into keychain (even logged out users).
func (u *Users) GetUsers() []*User {
u.lock.RLock()
defer u.lock.RUnlock()
return u.users
}
// GetUser returns a user by `query` which is compared to users' ID, username or any attached e-mail address.
func (u *Users) GetUser(query string) (*User, error) {
u.crashBandicoot(query)
u.lock.RLock()
defer u.lock.RUnlock()
for _, user := range u.users {
if strings.EqualFold(user.ID(), query) || strings.EqualFold(user.Username(), query) {
return user, nil
}
for _, address := range user.GetAddresses() {
if strings.EqualFold(address, query) {
return user, nil
}
}
}
return nil, errors.New("user " + query + " not found")
}
// ClearData closes all connections (to release db files and so on) and clears all data.
func (u *Users) ClearData() error {
var result *multierror.Error
for _, user := range u.users {
if err := user.Logout(); err != nil {
result = multierror.Append(result, err)
}
if err := user.closeStore(); err != nil {
result = multierror.Append(result, err)
}
}
if err := u.config.ClearData(); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
// DeleteUser deletes user completely; it logs user out from the API, stops any
// active connection, deletes from credentials store and removes from the Bridge struct.
func (u *Users) DeleteUser(userID string, clearStore bool) error {
u.lock.Lock()
defer u.lock.Unlock()
log := log.WithField("user", userID)
for idx, user := range u.users {
if user.ID() == userID {
if err := user.Logout(); err != nil {
log.WithError(err).Error("Cannot logout user")
// We can try to continue to remove the user.
// Token will still be valid, but will expire eventually.
}
if err := user.closeStore(); err != nil {
log.WithError(err).Error("Failed to close user store")
}
if clearStore {
// Clear cache after closing connections (done in logout).
if err := user.clearStore(); err != nil {
log.WithError(err).Error("Failed to clear user")
}
}
if err := u.credStorer.Delete(userID); err != nil {
log.WithError(err).Error("Cannot remove user")
return err
}
u.users = append(u.users[:idx], u.users[idx+1:]...)
return nil
}
}
return errors.New("user " + userID + " not found")
}
// SendMetric sends a metric. We don't want to return any errors, only log them.
func (u *Users) SendMetric(m metrics.Metric) {
c := u.clientManager.GetAnonymousClient()
defer c.Logout()
cat, act, lab := m.Get()
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
log.Error("Sending metric failed: ", err)
}
log.WithFields(logrus.Fields{
"cat": cat,
"act": act,
"lab": lab,
}).Debug("Metric successfully sent")
}
// GetIMAPUpdatesChannel sets the channel on which idle events should be sent.
func (u *Users) GetIMAPUpdatesChannel() chan imapBackend.Update {
if u.idleUpdates == nil {
log.Warn("IMAP updates channel is nil")
}
return u.idleUpdates
}
// AllowProxy instructs the app to use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) AllowProxy() {
u.clientManager.AllowProxy()
}
// DisallowProxy instructs the app to not use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (u *Users) DisallowProxy() {
u.clientManager.DisallowProxy()
}
// CheckConnection returns whether there is an internet connection.
// This should use the connection manager when it is eventually implemented.
func (u *Users) CheckConnection() error {
return u.clientManager.CheckConnection()
}
// StopWatchers stops all goroutines.
func (u *Users) StopWatchers() {
close(u.stopAll)
}
// hasUser returns whether the struct currently has a user with ID `id`.
func (u *Users) hasUser(id string) (user *User, ok bool) {
for _, u := range u.users {
if u.ID() == id {
user, ok = u, true
return
}
}
return
}
// "Easter egg" for testing purposes.
func (u *Users) crashBandicoot(username string) {
if username == "crash@bandicoot" {
panic("Your wish is my command… I crash!")
}
}

View File

@ -0,0 +1,143 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestGetNoUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "nouser", -1, "user nouser not found")
}
func TestGetUserByID(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "user", 0, "")
checkUsersGetUser(t, m, "users", 1, "")
}
func TestGetUserByName(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "username", 0, "")
checkUsersGetUser(t, m, "usersname", 1, "")
}
func TestGetUserByEmail(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
checkUsersGetUser(t, m, "user@pm.me", 0, "")
checkUsersGetUser(t, m, "users@pm.me", 1, "")
checkUsersGetUser(t, m, "anotheruser@pm.me", 1, "")
checkUsersGetUser(t, m, "alsouser@pm.me", 1, "")
}
func TestDeleteUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
assert.NoError(t, err)
assert.Equal(t, 1, len(users.users))
}
// Even when logout fails, delete is done.
func TestDeleteUserWithFailingLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("no such user")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := users.DeleteUser("user", true)
assert.NoError(t, err)
assert.Equal(t, 1, len(users.users))
}
func checkUsersGetUser(t *testing.T, m mocks, query string, index int, expectedError string) {
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.GetUser(query)
waitForEvents()
if expectedError != "" {
assert.Equal(m.t, expectedError, err.Error())
} else {
assert.NoError(m.t, err)
}
var expectedUser *User
if index >= 0 {
expectedUser = users.users[index]
}
assert.Equal(m.t, expectedUser, user)
}

View File

@ -0,0 +1,239 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestUsersFinishLoginBadMailboxPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
err := errors.New("bad password")
gomock.InOrder(
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// Set up mocks for FinishLogin.
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, err),
m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout(),
)
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", err)
}
func TestUsersFinishLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
err := errors.New("Cannot logout when upgrade needed")
gomock.InOrder(
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// Set up mocks for FinishLogin.
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, pmapi.ErrUpgradeApplication),
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, ""),
m.pmapiClient.EXPECT().DeleteAuth().Return(err),
m.pmapiClient.EXPECT().Logout(),
)
checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "", pmapi.ErrUpgradeApplication)
}
func refreshWithToken(token string) *pmapi.Auth {
return &pmapi.Auth{
RefreshToken: token,
KeySalt: "", // No salting in tests.
}
}
func credentialsWithToken(token string) *credentials.Credentials {
tmp := &credentials.Credentials{}
*tmp = *testCredentials
tmp.APIToken = token
return tmp
}
func TestUsersFinishLoginNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Basically every call client has get client manager
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// users.New() finds no users in keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
// getAPIUser() loads user info from API (e.g. userID).
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// addNewUser()
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().Add("user", "username", ":afterLogin", testCredentials.MailboxPassword, []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
// user.init() in addNewUser
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Emit event for new user and send metrics.
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)),
m.pmapiClient.EXPECT().Logout(),
// Reload account list in GUI.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
// defer logout anonymous
m.pmapiClient.EXPECT().Logout(),
)
mockEventLoopNoAction(m)
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
mockAuthUpdate(user, "afterCredentials", m)
}
func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
loggedOutCreds := *testCredentials
loggedOutCreds.APIToken = ""
loggedOutCreds.MailboxPassword = ""
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// users.New() finds one existing user in keychain.
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
// newUser()
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// user.init()
m.credentialsStore.EXPECT().Get("user").Return(&loggedOutCreds, nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return(nil, pmapi.ErrInvalidToken),
m.pmapiClient.EXPECT().Addresses().Return(nil),
// getAPIUser() loads user info from API (e.g. userID).
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
// connectExistingUser()
m.credentialsStore.EXPECT().UpdatePassword("user", testCredentials.MailboxPassword).Return(nil),
m.pmapiClient.EXPECT().AuthRefresh(":tok").Return(refreshWithToken("afterLogin"), nil),
m.credentialsStore.EXPECT().UpdateToken("user", ":afterLogin").Return(nil),
// user.init() in connectExistingUser
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(":afterLogin"), nil),
m.pmapiClient.EXPECT().AuthRefresh(":afterLogin").Return(refreshWithToken("afterCredentials"), nil),
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
// store.New() in user.init
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Reload account list in GUI.
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
// defer logout anonymous
m.pmapiClient.EXPECT().Logout(),
)
mockEventLoopNoAction(m)
user := checkUsersFinishLogin(t, m, testAuth, testCredentials.MailboxPassword, "user", nil)
mockAuthUpdate(user, "afterCredentials", m)
}
func TestUsersFinishLoginConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
mockConnectedUser(m)
mockEventLoopNoAction(m)
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
// Then, try to log in again...
gomock.InOrder(
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().CurrentUser().Return(testPMAPIUser, nil),
m.pmapiClient.EXPECT().DeleteAuth(),
m.pmapiClient.EXPECT().Logout(),
)
_, err := users.FinishLogin(m.pmapiClient, testAuth, testCredentials.MailboxPassword)
assert.Equal(t, "user is already connected", err.Error())
}
func checkUsersFinishLogin(t *testing.T, m mocks, auth *pmapi.Auth, mailboxPassword string, expectedUserID string, expectedErr error) *User {
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
user, err := users.FinishLogin(m.pmapiClient, auth, mailboxPassword)
waitForEvents()
assert.Equal(t, expectedErr, err)
if expectedUserID != "" {
assert.Equal(t, expectedUserID, user.ID())
assert.Equal(t, 1, len(users.users))
assert.Equal(t, expectedUserID, users.users[0].ID())
} else {
assert.Equal(t, (*User)(nil), user)
assert.Equal(t, 0, len(users.users))
}
return user
}

View File

@ -0,0 +1,188 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"errors"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestNewUsersNoKeychain(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, errors.New("no keychain"))
checkUsersNew(t, m, []*credentials.Credentials{})
}
func TestNewUsersWithoutUsersInCredentialsStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
checkUsersNew(t, m, []*credentials.Credentials{})
}
func TestNewUsersWithDisconnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
// Basically every call client has get client manager.
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
func TestNewUsersWithConnectedUserWithBadToken(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil).Times(2)
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token"))
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.pmapiClient.EXPECT().Logout()
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected})
}
func mockConnectedUser(m mocks) {
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock(testCredentials.MailboxPassword).Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte(testCredentials.MailboxPassword)).Return(nil),
// Set up mocks for store initialisation for the authorized user.
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
)
}
// mockAuthUpdate simulates users calling UpdateAuthToken on the given user.
// This would normally be done by users when it receives an auth from the ClientManager,
// but as we don't have a full users instance here, we do this manually.
func mockAuthUpdate(user *User, token string, m mocks) {
gomock.InOrder(
m.credentialsStore.EXPECT().UpdateToken("user", ":"+token).Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(credentialsWithToken(token), nil),
)
user.updateAuthToken(refreshWithToken(token))
waitForEvents()
}
func TestNewUsersWithConnectedUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"user"}, nil)
mockConnectedUser(m)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
}
// Tests two users with different states and checks also the order from
// credentials store is kept also in array of users.
func TestNewUsersWithUsers(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.credentialsStore.EXPECT().List().Return([]string{"userDisconnected", "user"}, nil)
gomock.InOrder(
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
m.credentialsStore.EXPECT().Get("userDisconnected").Return(testCredentialsDisconnected, nil),
// Set up mocks for store initialisation for the unauth user.
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.clientManager.EXPECT().GetClient("userDisconnected").Return(m.pmapiClient),
m.pmapiClient.EXPECT().Addresses().Return(nil),
)
mockConnectedUser(m)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
}
func TestNewUsersFirstStart(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{}, nil),
m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(true),
m.clientManager.EXPECT().GetAnonymousClient().Return(m.pmapiClient),
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Setup), string(metrics.FirstStart), gomock.Any()),
m.pmapiClient.EXPECT().Logout(),
)
testNewUsers(t, m)
}
func checkUsersNew(t *testing.T, m mocks, expectedCredentials []*credentials.Credentials) {
users := testNewUsers(t, m)
defer cleanUpUsersData(users)
assert.Equal(m.t, len(expectedCredentials), len(users.GetUsers()))
credentials := []*credentials.Credentials{}
for _, user := range users.users {
credentials = append(credentials, user.creds)
}
assert.Equal(m.t, expectedCredentials, credentials)
}

View File

@ -0,0 +1,303 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"fmt"
"io/ioutil"
"os"
"runtime/debug"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
gomock "github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
if os.Getenv("VERBOSITY") == "fatal" {
logrus.SetLevel(logrus.FatalLevel)
}
if os.Getenv("VERBOSITY") == "trace" {
logrus.SetLevel(logrus.TraceLevel)
}
os.Exit(m.Run())
}
var (
testAuth = &pmapi.Auth{ //nolint[gochecknoglobals]
RefreshToken: "tok",
KeySalt: "", // No salting in tests.
}
testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
RefreshToken: "reftok",
KeySalt: "", // No salting in tests.
}
testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "user",
Name: "username",
Emails: "user@pm.me",
APIToken: "token",
MailboxPassword: "pass",
BridgePassword: "0123456789abcdef",
Version: "v1",
Timestamp: 123456789,
IsHidden: false,
IsCombinedAddressMode: true,
}
testCredentialsSplit = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "users",
Name: "usersname",
Emails: "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
APIToken: "token",
MailboxPassword: "pass",
BridgePassword: "0123456789abcdef",
Version: "v1",
Timestamp: 123456789,
IsHidden: false,
IsCombinedAddressMode: false,
}
testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
UserID: "user",
Name: "username",
Emails: "user@pm.me",
APIToken: "",
MailboxPassword: "",
BridgePassword: "0123456789abcdef",
Version: "v1",
Timestamp: 123456789,
IsHidden: false,
IsCombinedAddressMode: true,
}
testPMAPIUser = &pmapi.User{ //nolint[gochecknoglobals]
ID: "user",
Name: "username",
}
testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
ID: "testAddressID",
Type: pmapi.OriginalAddress,
Email: "user@pm.me",
Receive: pmapi.CanReceive,
}
testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
{ID: "usersAddress1ID", Email: "users@pm.me", Receive: pmapi.CanReceive, Type: pmapi.OriginalAddress},
{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: pmapi.CanReceive, Type: pmapi.AliasAddress},
}
testPMAPIEvent = &pmapi.Event{ // nolint[gochecknoglobals]
EventID: "ACXDmTaBub14w==",
}
)
func waitForEvents() {
// Wait for goroutine to add listener.
// E.g. calling login to invoke firstsync event. Functions can end sooner than
// goroutines call the listener mock. We need to wait a little bit before the end of
// the test to capture all event calls. This allows us to detect whether there were
// missing calls, or perhaps whether something was called too many times.
time.Sleep(100 * time.Millisecond)
}
type mocks struct {
t *testing.T
ctrl *gomock.Controller
config *usersmocks.MockConfiger
PanicHandler *usersmocks.MockPanicHandler
prefProvider *usersmocks.MockPreferenceProvider
clientManager *usersmocks.MockClientManager
credentialsStore *usersmocks.MockCredentialsStorer
eventListener *MockListener
pmapiClient *pmapimocks.MockClient
storeCache *store.Cache
}
type fullStackReporter struct {
T testing.TB
}
func (fr *fullStackReporter) Errorf(format string, args ...interface{}) {
fmt.Printf("err: "+format+"\n", args...)
fr.T.Fail()
}
func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
debug.PrintStack()
fmt.Printf("fail: "+format+"\n", args...)
fr.T.FailNow()
}
func initMocks(t *testing.T) mocks {
var mockCtrl *gomock.Controller
if os.Getenv("VERBOSITY") == "trace" {
mockCtrl = gomock.NewController(&fullStackReporter{t})
} else {
mockCtrl = gomock.NewController(t)
}
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
require.NoError(t, err, "could not get temporary file for store cache")
m := mocks{
t: t,
ctrl: mockCtrl,
config: usersmocks.NewMockConfiger(mockCtrl),
PanicHandler: usersmocks.NewMockPanicHandler(mockCtrl),
prefProvider: usersmocks.NewMockPreferenceProvider(mockCtrl),
clientManager: usersmocks.NewMockClientManager(mockCtrl),
credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
eventListener: NewMockListener(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
}
// Ignore heartbeat calls because they always happen.
m.pmapiClient.EXPECT().SendSimpleMetric(string(metrics.Heartbeat), gomock.Any(), gomock.Any()).AnyTimes()
m.prefProvider.EXPECT().Get(preferences.NextHeartbeatKey).AnyTimes()
m.prefProvider.EXPECT().Set(preferences.NextHeartbeatKey, gomock.Any()).AnyTimes()
// Called during clean-up.
m.PanicHandler.EXPECT().HandlePanic().AnyTimes()
return m
}
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
// Events are asynchronous
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).Times(2)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).Times(2)
gomock.InOrder(
m.credentialsStore.EXPECT().List().Return([]string{"user", "users"}, nil),
// Init for user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
// Init for users.
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().Unlock("pass").Return(nil, nil),
m.pmapiClient.EXPECT().UnlockAddresses([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return(testPMAPIAddresses),
)
users := testNewUsers(t, m)
user, _ := users.GetUser("user")
mockAuthUpdate(user, "reftok", m)
user, _ = users.GetUser("user")
mockAuthUpdate(user, "reftok", m)
return users
}
func testNewUsers(t *testing.T, m mocks) *Users {
cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
require.NoError(t, err, "could not get temporary file for store cache")
m.prefProvider.EXPECT().GetBool(preferences.FirstStartKey).Return(false).AnyTimes()
m.prefProvider.EXPECT().GetBool(preferences.AllowProxyKey).Return(false).AnyTimes()
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
m.config.EXPECT().GetVersion().Return("ver").AnyTimes()
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan pmapi.ClientAuth))
users := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore)
waitForEvents()
return users
}
func cleanUpUsersData(b *Users) {
for _, user := range b.users {
_ = user.clearStore()
}
}
func TestClearData(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.clientManager.EXPECT().GetClient("users").Return(m.pmapiClient).MinTimes(1)
users := testNewUsersWithUsers(t, m)
defer cleanUpUsersData(users)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me")
m.pmapiClient.EXPECT().Logout()
m.credentialsStore.EXPECT().Logout("user").Return(nil)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
m.pmapiClient.EXPECT().Logout()
m.credentialsStore.EXPECT().Logout("users").Return(nil)
m.credentialsStore.EXPECT().Get("users").Return(testCredentialsSplit, nil)
m.config.EXPECT().ClearData().Return(nil)
require.NoError(t, users.ClearData())
waitForEvents()
}
func mockEventLoopNoAction(m mocks) {
// Set up mocks for starting the store's event loop (in store.New).
// The event loop runs in another goroutine so this might happen at any time.
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent("").Return(testPMAPIEvent, nil),
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil),
// Set up mocks for performing the initial store sync.
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil),
)
}

View File

@ -0,0 +1,23 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
// IsAuthorized returns whether the user has received an Auth from the API yet.
func (u *User) IsAuthorized() bool {
return u.isAuthorized
}