mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2026-02-04 00:08:33 +00:00
feat: simple client manager
This commit is contained in:
@ -57,6 +57,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/pkg/args"
|
"github.com/ProtonMail/proton-bridge/pkg/args"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/config"
|
"github.com/ProtonMail/proton-bridge/pkg/config"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/updates"
|
"github.com/ProtonMail/proton-bridge/pkg/updates"
|
||||||
"github.com/allan-simon/go-singleinstance"
|
"github.com/allan-simon/go-singleinstance"
|
||||||
"github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
@ -273,9 +274,8 @@ func run(context *cli.Context) (contextError error) { // nolint[funlen]
|
|||||||
log.Error("Could not get credentials store: ", credentialsError)
|
log.Error("Could not get credentials store: ", credentialsError)
|
||||||
}
|
}
|
||||||
|
|
||||||
pmapiClientFactory := pmapifactory.New(cfg, eventListener)
|
clientman := pmapi.NewClientManager(pmapifactory.GetClientConfig(cfg, eventListener))
|
||||||
|
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, clientman, credentialsStore)
|
||||||
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, pmapiClientFactory, credentialsStore)
|
|
||||||
imapBackend := imap.NewIMAPBackend(panicHandler, eventListener, cfg, bridgeInstance)
|
imapBackend := imap.NewIMAPBackend(panicHandler, eventListener, cfg, bridgeInstance)
|
||||||
smtpBackend := smtp.NewSMTPBackend(panicHandler, eventListener, pref, bridgeInstance)
|
smtpBackend := smtp.NewSMTPBackend(panicHandler, eventListener, pref, bridgeInstance)
|
||||||
|
|
||||||
|
|||||||
@ -43,14 +43,14 @@ var (
|
|||||||
|
|
||||||
// Bridge is a struct handling users.
|
// Bridge is a struct handling users.
|
||||||
type Bridge struct {
|
type Bridge struct {
|
||||||
config Configer
|
config Configer
|
||||||
pref PreferenceProvider
|
pref PreferenceProvider
|
||||||
panicHandler PanicHandler
|
panicHandler PanicHandler
|
||||||
events listener.Listener
|
events listener.Listener
|
||||||
version string
|
version string
|
||||||
pmapiClientFactory PMAPIProviderFactory
|
clientMan *pmapi.ClientManager
|
||||||
credStorer CredentialsStorer
|
credStorer CredentialsStorer
|
||||||
storeCache *store.Cache
|
storeCache *store.Cache
|
||||||
|
|
||||||
// users is a list of accounts that have been added to bridge.
|
// users is a list of accounts that have been added to bridge.
|
||||||
// They are stored sorted in the credentials store in the order
|
// They are stored sorted in the credentials store in the order
|
||||||
@ -76,22 +76,22 @@ func New(
|
|||||||
panicHandler PanicHandler,
|
panicHandler PanicHandler,
|
||||||
eventListener listener.Listener,
|
eventListener listener.Listener,
|
||||||
version string,
|
version string,
|
||||||
pmapiClientFactory PMAPIProviderFactory,
|
clientMan *pmapi.ClientManager,
|
||||||
credStorer CredentialsStorer,
|
credStorer CredentialsStorer,
|
||||||
) *Bridge {
|
) *Bridge {
|
||||||
log.Trace("Creating new bridge")
|
log.Trace("Creating new bridge")
|
||||||
|
|
||||||
b := &Bridge{
|
b := &Bridge{
|
||||||
config: config,
|
config: config,
|
||||||
pref: pref,
|
pref: pref,
|
||||||
panicHandler: panicHandler,
|
panicHandler: panicHandler,
|
||||||
events: eventListener,
|
events: eventListener,
|
||||||
version: version,
|
version: version,
|
||||||
pmapiClientFactory: pmapiClientFactory,
|
clientMan: clientMan,
|
||||||
credStorer: credStorer,
|
credStorer: credStorer,
|
||||||
storeCache: store.NewCache(config.GetIMAPCachePath()),
|
storeCache: store.NewCache(config.GetIMAPCachePath()),
|
||||||
idleUpdates: make(chan interface{}),
|
idleUpdates: make(chan interface{}),
|
||||||
lock: sync.RWMutex{},
|
lock: sync.RWMutex{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow DoH before starting bridge if the user has previously set this setting.
|
// Allow DoH before starting bridge if the user has previously set this setting.
|
||||||
@ -148,9 +148,7 @@ func (b *Bridge) loadUsersFromCredentialsStore() (err error) {
|
|||||||
for _, userID := range userIDs {
|
for _, userID := range userIDs {
|
||||||
l := log.WithField("user", userID)
|
l := log.WithField("user", userID)
|
||||||
|
|
||||||
apiClient := b.pmapiClientFactory(userID)
|
user, newUserErr := newUser(b.panicHandler, userID, b.events, b.credStorer, b.clientMan, b.storeCache, b.config.GetDBDir())
|
||||||
|
|
||||||
user, newUserErr := newUser(b.panicHandler, userID, b.events, b.credStorer, apiClient, b.storeCache, b.config.GetDBDir())
|
|
||||||
if newUserErr != nil {
|
if newUserErr != nil {
|
||||||
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
|
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
|
||||||
continue
|
continue
|
||||||
@ -158,7 +156,7 @@ func (b *Bridge) loadUsersFromCredentialsStore() (err error) {
|
|||||||
|
|
||||||
b.users = append(b.users, user)
|
b.users = append(b.users, user)
|
||||||
|
|
||||||
if initUserErr := user.init(b.idleUpdates, apiClient); initUserErr != nil {
|
if initUserErr := user.init(b.idleUpdates); initUserErr != nil {
|
||||||
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user")
|
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -199,7 +197,7 @@ func (b *Bridge) Login(username, password string) (loginClient PMAPIProvider, au
|
|||||||
|
|
||||||
// We need to use "login" client because we need userID to properly
|
// We need to use "login" client because we need userID to properly
|
||||||
// assign access tokens into token manager.
|
// assign access tokens into token manager.
|
||||||
loginClient = b.pmapiClientFactory("login")
|
loginClient = b.clientMan.GetClient("login")
|
||||||
|
|
||||||
authInfo, err := loginClient.AuthInfo(username)
|
authInfo, err := loginClient.AuthInfo(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -268,7 +266,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||||||
}
|
}
|
||||||
|
|
||||||
apiToken := auth.UID() + ":" + auth.RefreshToken
|
apiToken := auth.UID() + ":" + auth.RefreshToken
|
||||||
apiClient := b.pmapiClientFactory(apiUser.ID)
|
apiClient := b.clientMan.GetClient(apiUser.ID)
|
||||||
auth, err = apiClient.AuthRefresh(apiToken)
|
auth, err = apiClient.AuthRefresh(apiToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Could refresh token in new client")
|
log.WithError(err).Error("Could refresh token in new client")
|
||||||
@ -297,7 +295,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||||||
|
|
||||||
// If it's a new user, generate the user object.
|
// If it's a new user, generate the user object.
|
||||||
if !hasUser {
|
if !hasUser {
|
||||||
user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, apiClient, b.storeCache, b.config.GetDBDir())
|
user, err = newUser(b.panicHandler, apiUser.ID, b.events, b.credStorer, b.clientMan, b.storeCache, b.config.GetDBDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
|
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
|
||||||
return
|
return
|
||||||
@ -305,7 +303,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set up the user auth and store (which we do for both new and existing users).
|
// Set up the user auth and store (which we do for both new and existing users).
|
||||||
if err = user.init(b.idleUpdates, apiClient); err != nil {
|
if err = user.init(b.idleUpdates); err != nil {
|
||||||
log.WithField("user", user.userID).WithError(err).Error("Could not initialise user")
|
log.WithField("user", user.userID).WithError(err).Error("Could not initialise user")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -407,9 +405,11 @@ func (b *Bridge) DeleteUser(userID string, clearStore bool) error {
|
|||||||
|
|
||||||
// ReportBug reports a new bug from the user.
|
// ReportBug reports a new bug from the user.
|
||||||
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
|
||||||
apiClient := b.pmapiClientFactory("bug_reporter")
|
c := b.clientMan.GetClient("bug_reporter")
|
||||||
|
defer func() { _ = c.Logout() }()
|
||||||
|
|
||||||
title := "[Bridge] Bug"
|
title := "[Bridge] Bug"
|
||||||
err := apiClient.ReportBugWithEmailClient(
|
if err := c.ReportBugWithEmailClient(
|
||||||
osType,
|
osType,
|
||||||
osVersion,
|
osVersion,
|
||||||
title,
|
title,
|
||||||
@ -417,23 +417,26 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address,
|
|||||||
accountName,
|
accountName,
|
||||||
address,
|
address,
|
||||||
emailClient,
|
emailClient,
|
||||||
)
|
); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Error("Reporting bug failed: ", err)
|
log.Error("Reporting bug failed: ", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Bug successfully reported")
|
log.Info("Bug successfully reported")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMetric sends a metric. We don't want to return any errors, only log them.
|
// SendMetric sends a metric. We don't want to return any errors, only log them.
|
||||||
func (b *Bridge) SendMetric(m m.Metric) {
|
func (b *Bridge) SendMetric(m m.Metric) {
|
||||||
apiClient := b.pmapiClientFactory("metric_reporter")
|
c := b.clientMan.GetClient("metric_reporter")
|
||||||
|
defer func() { _ = c.Logout() }()
|
||||||
|
|
||||||
cat, act, lab := m.Get()
|
cat, act, lab := m.Get()
|
||||||
err := apiClient.SendSimpleMetric(string(cat), string(act), string(lab))
|
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Error("Sending metric failed: ", err)
|
log.Error("Sending metric failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithFields(logrus.Fields{
|
log.WithFields(logrus.Fields{
|
||||||
"cat": cat,
|
"cat": cat,
|
||||||
"act": act,
|
"act": act,
|
||||||
|
|||||||
@ -43,7 +43,8 @@ type PanicHandler interface {
|
|||||||
HandlePanic()
|
HandlePanic()
|
||||||
}
|
}
|
||||||
|
|
||||||
type PMAPIProviderFactory func(string) PMAPIProvider
|
type Clientman interface {
|
||||||
|
}
|
||||||
|
|
||||||
type PMAPIProvider interface {
|
type PMAPIProvider interface {
|
||||||
SetAuths(auths chan<- *pmapi.Auth)
|
SetAuths(auths chan<- *pmapi.Auth)
|
||||||
|
|||||||
@ -41,7 +41,7 @@ type User struct {
|
|||||||
log *logrus.Entry
|
log *logrus.Entry
|
||||||
panicHandler PanicHandler
|
panicHandler PanicHandler
|
||||||
listener listener.Listener
|
listener listener.Listener
|
||||||
apiClient PMAPIProvider
|
clientMan *pmapi.ClientManager
|
||||||
credStorer CredentialsStorer
|
credStorer CredentialsStorer
|
||||||
|
|
||||||
imapUpdatesChannel chan interface{}
|
imapUpdatesChannel chan interface{}
|
||||||
@ -67,7 +67,7 @@ func newUser(
|
|||||||
userID string,
|
userID string,
|
||||||
eventListener listener.Listener,
|
eventListener listener.Listener,
|
||||||
credStorer CredentialsStorer,
|
credStorer CredentialsStorer,
|
||||||
apiClient PMAPIProvider,
|
clientMan *pmapi.ClientManager,
|
||||||
storeCache *store.Cache,
|
storeCache *store.Cache,
|
||||||
storeDir string,
|
storeDir string,
|
||||||
) (u *User, err error) {
|
) (u *User, err error) {
|
||||||
@ -84,7 +84,7 @@ func newUser(
|
|||||||
panicHandler: panicHandler,
|
panicHandler: panicHandler,
|
||||||
listener: eventListener,
|
listener: eventListener,
|
||||||
credStorer: credStorer,
|
credStorer: credStorer,
|
||||||
apiClient: apiClient,
|
clientMan: clientMan,
|
||||||
storeCache: storeCache,
|
storeCache: storeCache,
|
||||||
storePath: getUserStorePath(storeDir, userID),
|
storePath: getUserStorePath(storeDir, userID),
|
||||||
userID: userID,
|
userID: userID,
|
||||||
@ -94,16 +94,16 @@ func newUser(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *User) client() PMAPIProvider {
|
||||||
|
return u.clientMan.GetClient(u.userID)
|
||||||
|
}
|
||||||
|
|
||||||
// init initialises a bridge user. This includes reloading its credentials from the credentials store
|
// init initialises a bridge user. This includes reloading its credentials from the credentials store
|
||||||
// (such as when logging out and back in, you need to reload the credentials because the new credentials will
|
// (such as when logging out and back in, you need to reload the credentials because the new credentials will
|
||||||
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
|
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
|
||||||
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
|
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
|
||||||
// something in the store changed).
|
// something in the store changed).
|
||||||
func (u *User) init(idleUpdates chan interface{}, apiClient PMAPIProvider) (err error) {
|
func (u *User) init(idleUpdates chan interface{}) (err error) {
|
||||||
// If this is an existing user, we still need a new api client to get a new refresh token.
|
|
||||||
// If it's a new user, doesn't matter really; this is basically a noop in this case.
|
|
||||||
u.apiClient = apiClient
|
|
||||||
|
|
||||||
u.unlockingKeyringLock.Lock()
|
u.unlockingKeyringLock.Lock()
|
||||||
u.wasKeyringUnlocked = false
|
u.wasKeyringUnlocked = false
|
||||||
u.unlockingKeyringLock.Unlock()
|
u.unlockingKeyringLock.Unlock()
|
||||||
@ -118,7 +118,7 @@ func (u *User) init(idleUpdates chan interface{}, apiClient PMAPIProvider) (err
|
|||||||
|
|
||||||
// Set up the auth channel on which auths from the api client are sent.
|
// Set up the auth channel on which auths from the api client are sent.
|
||||||
u.authChannel = make(chan *pmapi.Auth)
|
u.authChannel = make(chan *pmapi.Auth)
|
||||||
u.apiClient.SetAuths(u.authChannel)
|
u.client().SetAuths(u.authChannel)
|
||||||
u.hasAPIAuth = false
|
u.hasAPIAuth = false
|
||||||
go func() {
|
go func() {
|
||||||
defer u.panicHandler.HandlePanic()
|
defer u.panicHandler.HandlePanic()
|
||||||
@ -147,7 +147,7 @@ func (u *User) init(idleUpdates chan interface{}, apiClient PMAPIProvider) (err
|
|||||||
}
|
}
|
||||||
u.store = nil
|
u.store = nil
|
||||||
}
|
}
|
||||||
store, err := store.New(u.panicHandler, u, u.apiClient, u.listener, u.storePath, u.storeCache)
|
store, err := store.New(u.panicHandler, u, u.client(), u.listener, u.storePath, u.storeCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to create store")
|
return errors.Wrap(err, "failed to create store")
|
||||||
}
|
}
|
||||||
@ -216,11 +216,11 @@ func (u *User) unlockIfNecessary() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := u.apiClient.Unlock(u.creds.MailboxPassword); err != nil {
|
if _, err := u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
||||||
return errors.Wrap(err, "failed to unlock user")
|
return errors.Wrap(err, "failed to unlock user")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := u.apiClient.UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
if err := u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
||||||
return errors.Wrap(err, "failed to unlock user addresses")
|
return errors.Wrap(err, "failed to unlock user addresses")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,17 +236,17 @@ func (u *User) authorizeAndUnlock() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, err := u.apiClient.AuthRefresh(u.creds.APIToken)
|
auth, err := u.client().AuthRefresh(u.creds.APIToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to refresh API auth")
|
return errors.Wrap(err, "failed to refresh API auth")
|
||||||
}
|
}
|
||||||
u.authChannel <- auth
|
u.authChannel <- auth
|
||||||
|
|
||||||
if _, err = u.apiClient.Unlock(u.creds.MailboxPassword); err != nil {
|
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
||||||
return errors.Wrap(err, "failed to unlock user")
|
return errors.Wrap(err, "failed to unlock user")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = u.apiClient.UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
if err = u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
||||||
return errors.Wrap(err, "failed to unlock user addresses")
|
return errors.Wrap(err, "failed to unlock user addresses")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -321,7 +321,7 @@ func getUserStorePath(storeDir string, userID string) (path string) {
|
|||||||
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
|
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
|
||||||
// After proper refactor of SMTP and IMAP remove this method.
|
// After proper refactor of SMTP and IMAP remove this method.
|
||||||
func (u *User) GetTemporaryPMAPIClient() PMAPIProvider {
|
func (u *User) GetTemporaryPMAPIClient() PMAPIProvider {
|
||||||
return u.apiClient
|
return u.client()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the user's userID.
|
// ID returns the user's userID.
|
||||||
@ -462,20 +462,20 @@ func (u *User) UpdateUser() error {
|
|||||||
return errors.Wrap(err, "cannot update user")
|
return errors.Wrap(err, "cannot update user")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := u.apiClient.UpdateUser()
|
_, err := u.client().UpdateUser()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = u.apiClient.Unlock(u.creds.MailboxPassword); err != nil {
|
if _, err = u.client().Unlock(u.creds.MailboxPassword); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := u.apiClient.UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
if err := u.client().UnlockAddresses([]byte(u.creds.MailboxPassword)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
emails := u.apiClient.Addresses().ActiveEmails()
|
emails := u.client().Addresses().ActiveEmails()
|
||||||
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil {
|
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -548,10 +548,10 @@ func (u *User) Logout() (err error) {
|
|||||||
u.wasKeyringUnlocked = false
|
u.wasKeyringUnlocked = false
|
||||||
u.unlockingKeyringLock.Unlock()
|
u.unlockingKeyringLock.Unlock()
|
||||||
|
|
||||||
if err = u.apiClient.Logout(); err != nil {
|
if err = u.client().Logout(); err != nil {
|
||||||
u.log.WithError(err).Warn("Could not log user out from API client")
|
u.log.WithError(err).Warn("Could not log user out from API client")
|
||||||
}
|
}
|
||||||
u.apiClient.SetAuths(nil)
|
u.client().SetAuths(nil)
|
||||||
|
|
||||||
// Logout needs to stop auth channel so when user logs back in
|
// Logout needs to stop auth channel so when user logs back in
|
||||||
// it can register again with new client.
|
// it can register again with new client.
|
||||||
|
|||||||
@ -26,8 +26,6 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(cfg bridge.Configer, _ listener.Listener) bridge.PMAPIProviderFactory {
|
func GetClientConfig(config bridge.Configer, _ listener.Listener) *pmapi.ClientConfig {
|
||||||
return func(userID string) bridge.PMAPIProvider {
|
return config.GetAPIConfig()
|
||||||
return pmapi.NewClient(cfg.GetAPIConfig(), userID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,10 +29,10 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(config bridge.Configer, listener listener.Listener) bridge.PMAPIProviderFactory {
|
func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.ClientConfig {
|
||||||
cfg := config.GetAPIConfig()
|
clientConfig := config.GetAPIConfig()
|
||||||
|
|
||||||
pin := pmapi.NewPMAPIPinning(cfg.AppVersion)
|
pin := pmapi.NewPMAPIPinning(clientConfig.AppVersion)
|
||||||
pin.ReportCertIssueLocal = func() {
|
pin.ReportCertIssueLocal = func() {
|
||||||
listener.Emit(events.TLSCertIssue, "")
|
listener.Emit(events.TLSCertIssue, "")
|
||||||
}
|
}
|
||||||
@ -41,14 +41,12 @@ func New(config bridge.Configer, listener listener.Listener) bridge.PMAPIProvide
|
|||||||
// - IdleConnTimeout: 5 * time.Minute,
|
// - IdleConnTimeout: 5 * time.Minute,
|
||||||
// - ExpectContinueTimeout: 500 * time.Millisecond,
|
// - ExpectContinueTimeout: 500 * time.Millisecond,
|
||||||
// - ResponseHeaderTimeout: 30 * time.Second,
|
// - ResponseHeaderTimeout: 30 * time.Second,
|
||||||
cfg.Transport = pin.TransportWithPinning()
|
clientConfig.Transport = pin.TransportWithPinning()
|
||||||
|
|
||||||
// We set additional timeouts/thresholds for the request as a whole:
|
// We set additional timeouts/thresholds for the request as a whole:
|
||||||
cfg.Timeout = 10 * time.Minute // Overall request timeout (~25MB / 10 mins => ~40kB/s, should be reasonable).
|
clientConfig.Timeout = 10 * time.Minute // Overall request timeout (~25MB / 10 mins => ~40kB/s, should be reasonable).
|
||||||
cfg.FirstReadTimeout = 30 * time.Second // 30s to match 30s response header timeout.
|
clientConfig.FirstReadTimeout = 30 * time.Second // 30s to match 30s response header timeout.
|
||||||
cfg.MinSpeed = 1 << 13 // Enforce minimum download speed of 8kB/s.
|
clientConfig.MinSpeed = 1 << 13 // Enforce minimum download speed of 8kB/s.
|
||||||
|
|
||||||
return func(userID string) bridge.PMAPIProvider {
|
return clientConfig
|
||||||
return pmapi.NewClient(cfg, userID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -78,8 +78,6 @@ func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirs
|
|||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
ResponseHeaderTimeout: 10 * time.Second,
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
},
|
},
|
||||||
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
|
|
||||||
TokenManager: pmapi.NewTokenManager(),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -64,9 +64,11 @@ func GetLogEntry(packageName string) *logrus.Entry {
|
|||||||
// HandlePanic reports the crash to sentry or local file when sentry fails.
|
// HandlePanic reports the crash to sentry or local file when sentry fails.
|
||||||
func HandlePanic(cfg *Config, output string) {
|
func HandlePanic(cfg *Config, output string) {
|
||||||
if !cfg.IsDevMode() {
|
if !cfg.IsDevMode() {
|
||||||
c := pmapi.NewClient(cfg.GetAPIConfig(), "no-user-id")
|
// TODO: Is it okay to just create a throwaway client like this?
|
||||||
err := c.ReportSentryCrash(fmt.Errorf(output))
|
c := pmapi.NewClientManager(cfg.GetAPIConfig()).GetClient("no-user-id")
|
||||||
if err != nil {
|
defer func() { _ = c.Logout() }()
|
||||||
|
|
||||||
|
if err := c.ReportSentryCrash(fmt.Errorf(output)); err != nil {
|
||||||
log.Error("Sentry crash report failed: ", err)
|
log.Error("Sentry crash report failed: ", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -307,11 +307,7 @@ func (c *Client) Auth(username, password string, info *AuthInfo) (auth *Auth, er
|
|||||||
if c.auths != nil {
|
if c.auths != nil {
|
||||||
c.auths <- auth
|
c.auths <- auth
|
||||||
}
|
}
|
||||||
|
c.cm.SetToken(c.userID, c.uid+":"+auth.RefreshToken)
|
||||||
if c.tokenManager != nil {
|
|
||||||
c.tokenManager.SetToken(c.userID, c.uid+":"+auth.RefreshToken)
|
|
||||||
c.log.Info("Set token from auth " + c.uid + ":" + auth.RefreshToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth has to be fully unlocked to get key salt. During `Auth` it can happen
|
// Auth has to be fully unlocked to get key salt. During `Auth` it can happen
|
||||||
// only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`.
|
// only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`.
|
||||||
@ -407,14 +403,7 @@ func (c *Client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
|
|||||||
func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
||||||
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
||||||
// That way we can try again later (see handleUnauthorizedStatus).
|
// That way we can try again later (see handleUnauthorizedStatus).
|
||||||
if c.tokenManager != nil {
|
c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||||
currentAccessToken := c.tokenManager.GetToken(c.userID)
|
|
||||||
if currentAccessToken == "" {
|
|
||||||
c.log.WithField("token", uidAndRefreshToken).
|
|
||||||
Info("Currently have no access token, setting given one")
|
|
||||||
c.tokenManager.SetToken(c.userID, uidAndRefreshToken)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
split := strings.Split(uidAndRefreshToken, ":")
|
split := strings.Split(uidAndRefreshToken, ":")
|
||||||
if len(split) != 2 {
|
if len(split) != 2 {
|
||||||
@ -456,11 +445,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||||||
|
|
||||||
c.uid = auth.UID()
|
c.uid = auth.UID()
|
||||||
c.accessToken = auth.accessToken
|
c.accessToken = auth.accessToken
|
||||||
|
c.cm.SetToken(c.userID, c.uid+":"+res.RefreshToken)
|
||||||
if c.tokenManager != nil {
|
|
||||||
c.tokenManager.SetToken(c.userID, c.uid+":"+res.RefreshToken)
|
|
||||||
c.log.Info("Set token from auth refresh " + c.uid + ":" + res.RefreshToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
||||||
return auth, err
|
return auth, err
|
||||||
@ -497,9 +482,7 @@ func (c *Client) Logout() (err error) {
|
|||||||
c.kr = nil
|
c.kr = nil
|
||||||
// c.addresses = nil
|
// c.addresses = nil
|
||||||
c.user = nil
|
c.user = nil
|
||||||
if c.tokenManager != nil {
|
c.cm.ClearToken(c.userID)
|
||||||
c.tokenManager.SetToken(c.userID, "")
|
|
||||||
}
|
|
||||||
// }()
|
// }()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
@ -353,8 +353,7 @@ func TestClient_DoUnauthorized(t *testing.T) {
|
|||||||
c.uid = testUID
|
c.uid = testUID
|
||||||
c.accessToken = testAccessTokenOld
|
c.accessToken = testAccessTokenOld
|
||||||
c.expiresAt = aLongTimeAgo
|
c.expiresAt = aLongTimeAgo
|
||||||
c.tokenManager = NewTokenManager()
|
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
|
||||||
c.tokenManager.tokenMap[c.userID] = testUID + ":" + testRefreshToken
|
|
||||||
|
|
||||||
req, err := NewRequest("GET", "/", nil)
|
req, err := NewRequest("GET", "/", nil)
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
|
|||||||
@ -132,8 +132,8 @@ func writeMultipartReport(w *multipart.Writer, rep *ReportReq) error { // nolint
|
|||||||
|
|
||||||
// Report sends request as json or multipart (if has attachment).
|
// Report sends request as json or multipart (if has attachment).
|
||||||
func (c *Client) Report(rep ReportReq) (err error) {
|
func (c *Client) Report(rep ReportReq) (err error) {
|
||||||
rep.Client = c.config.ClientID
|
rep.Client = c.cm.GetConfig().ClientID
|
||||||
rep.ClientVersion = c.config.AppVersion
|
rep.ClientVersion = c.cm.GetConfig().AppVersion
|
||||||
rep.ClientType = EmailClientType
|
rep.ClientType = EmailClientType
|
||||||
|
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
@ -196,8 +196,8 @@ func (c *Client) ReportBugWithEmailClient(os, osVersion, title, description, use
|
|||||||
// ReportCrash is old. Use sentry instead.
|
// ReportCrash is old. Use sentry instead.
|
||||||
func (c *Client) ReportCrash(stacktrace string) (err error) {
|
func (c *Client) ReportCrash(stacktrace string) (err error) {
|
||||||
crashReq := ReportReq{
|
crashReq := ReportReq{
|
||||||
Client: c.config.ClientID,
|
Client: c.cm.GetConfig().ClientID,
|
||||||
ClientVersion: c.config.AppVersion,
|
ClientVersion: c.cm.GetConfig().AppVersion,
|
||||||
ClientType: EmailClientType,
|
ClientType: EmailClientType,
|
||||||
OS: runtime.GOOS,
|
OS: runtime.GOOS,
|
||||||
Debug: stacktrace,
|
Debug: stacktrace,
|
||||||
|
|||||||
@ -68,33 +68,6 @@ func (err *ErrUnauthorized) Error() string {
|
|||||||
return fmt.Sprintf("unauthorized access: %+v", err.error.Error())
|
return fmt.Sprintf("unauthorized access: %+v", err.error.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenManager struct {
|
|
||||||
tokensLocker sync.Locker
|
|
||||||
tokenMap map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTokenManager() *TokenManager {
|
|
||||||
tm := &TokenManager{
|
|
||||||
tokensLocker: &sync.Mutex{},
|
|
||||||
tokenMap: map[string]string{},
|
|
||||||
}
|
|
||||||
return tm
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *TokenManager) GetToken(userID string) string {
|
|
||||||
tm.tokensLocker.Lock()
|
|
||||||
defer tm.tokensLocker.Unlock()
|
|
||||||
|
|
||||||
return tm.tokenMap[userID]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *TokenManager) SetToken(userID, token string) {
|
|
||||||
tm.tokensLocker.Lock()
|
|
||||||
defer tm.tokensLocker.Unlock()
|
|
||||||
|
|
||||||
tm.tokenMap[userID] = token
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientConfig contains Client configuration.
|
// ClientConfig contains Client configuration.
|
||||||
type ClientConfig struct {
|
type ClientConfig struct {
|
||||||
// The client application name and version.
|
// The client application name and version.
|
||||||
@ -103,7 +76,8 @@ type ClientConfig struct {
|
|||||||
// The client ID.
|
// The client ID.
|
||||||
ClientID string
|
ClientID string
|
||||||
|
|
||||||
TokenManager *TokenManager
|
// The sentry DSN.
|
||||||
|
SentryDSN string
|
||||||
|
|
||||||
// Transport specifies the mechanism by which individual HTTP requests are made.
|
// Transport specifies the mechanism by which individual HTTP requests are made.
|
||||||
// If nil, http.DefaultTransport is used.
|
// If nil, http.DefaultTransport is used.
|
||||||
@ -127,10 +101,8 @@ type ClientConfig struct {
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
auths chan<- *Auth // Channel that sends Auth responses back to the bridge.
|
auths chan<- *Auth // Channel that sends Auth responses back to the bridge.
|
||||||
|
|
||||||
log *logrus.Entry
|
|
||||||
config *ClientConfig
|
|
||||||
client *http.Client
|
client *http.Client
|
||||||
conrep ConnectionReporter
|
cm *ClientManager
|
||||||
|
|
||||||
uid string
|
uid string
|
||||||
accessToken string
|
accessToken string
|
||||||
@ -138,18 +110,32 @@ type Client struct {
|
|||||||
requestLocker sync.Locker
|
requestLocker sync.Locker
|
||||||
keyLocker sync.Locker
|
keyLocker sync.Locker
|
||||||
|
|
||||||
tokenManager *TokenManager
|
expiresAt time.Time
|
||||||
expiresAt time.Time
|
user *User
|
||||||
user *User
|
addresses AddressList
|
||||||
addresses AddressList
|
kr *pmcrypto.KeyRing
|
||||||
kr *pmcrypto.KeyRing
|
|
||||||
|
log *logrus.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new API client.
|
// newClient creates a new API client.
|
||||||
func NewClient(cfg *ClientConfig, userID string) *Client {
|
func newClient(cm *ClientManager, userID string) *Client {
|
||||||
|
return &Client{
|
||||||
|
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
||||||
|
client: getHTTPClient(cm.GetConfig()),
|
||||||
|
cm: cm,
|
||||||
|
userID: userID,
|
||||||
|
requestLocker: &sync.Mutex{},
|
||||||
|
keyLocker: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHTTPClient returns a http client configured by the given client config.
|
||||||
|
func getHTTPClient(cfg *ClientConfig) *http.Client {
|
||||||
hc := &http.Client{
|
hc := &http.Client{
|
||||||
Timeout: cfg.Timeout,
|
Timeout: cfg.Timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Transport != nil {
|
if cfg.Transport != nil {
|
||||||
cfgTransport, ok := cfg.Transport.(*http.Transport)
|
cfgTransport, ok := cfg.Transport.(*http.Transport)
|
||||||
if ok {
|
if ok {
|
||||||
@ -168,37 +154,7 @@ func NewClient(cfg *ClientConfig, userID string) *Client {
|
|||||||
hc.Transport = defaultTransport
|
hc.Transport = defaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
log := logrus.WithFields(logrus.Fields{
|
return hc
|
||||||
"pkg": "pmapi",
|
|
||||||
"userID": userID,
|
|
||||||
})
|
|
||||||
|
|
||||||
return &Client{
|
|
||||||
log: log,
|
|
||||||
config: cfg,
|
|
||||||
client: hc,
|
|
||||||
tokenManager: cfg.TokenManager,
|
|
||||||
userID: userID,
|
|
||||||
requestLocker: &sync.Mutex{},
|
|
||||||
keyLocker: &sync.Mutex{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetConnectionReporter sets the connection reporter used by the client to report when
|
|
||||||
// internet connection is lost.
|
|
||||||
func (c *Client) SetConnectionReporter(conrep ConnectionReporter) {
|
|
||||||
c.conrep = conrep
|
|
||||||
}
|
|
||||||
|
|
||||||
// reportLostConnection reports that the internet connection has been lost using the connection reporter.
|
|
||||||
// If the connection reporter has not been set, this does nothing.
|
|
||||||
func (c *Client) reportLostConnection() {
|
|
||||||
if c.conrep != nil {
|
|
||||||
err := c.conrep.NotifyConnectionLost()
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to notify of lost connection")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do makes an API request. It does not check for HTTP status code errors.
|
// Do makes an API request. It does not check for HTTP status code errors.
|
||||||
@ -224,7 +180,7 @@ func (c *Client) Do(req *http.Request, retryUnauthorized bool) (res *http.Respon
|
|||||||
func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
|
func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
|
||||||
isAuthReq := strings.Contains(req.URL.Path, "/auth")
|
isAuthReq := strings.Contains(req.URL.Path, "/auth")
|
||||||
|
|
||||||
req.Header.Set("x-pm-appversion", c.config.AppVersion)
|
req.Header.Set("x-pm-appversion", c.cm.GetConfig().AppVersion)
|
||||||
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
|
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
|
||||||
|
|
||||||
if c.uid != "" {
|
if c.uid != "" {
|
||||||
@ -252,7 +208,6 @@ func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthori
|
|||||||
if res == nil {
|
if res == nil {
|
||||||
c.log.WithError(err).Error("Cannot get response")
|
c.log.WithError(err).Error("Cannot get response")
|
||||||
err = ErrAPINotReachable
|
err = ErrAPINotReachable
|
||||||
c.reportLostConnection()
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -328,7 +283,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
|
|||||||
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
|
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
|
||||||
|
|
||||||
var cancelRequest context.CancelFunc
|
var cancelRequest context.CancelFunc
|
||||||
if c.config.MinSpeed > 0 {
|
if c.cm.GetConfig().MinSpeed > 0 {
|
||||||
var ctx context.Context
|
var ctx context.Context
|
||||||
ctx, cancelRequest = context.WithCancel(req.Context())
|
ctx, cancelRequest = context.WithCancel(req.Context())
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -344,7 +299,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
|
|||||||
defer res.Body.Close() //nolint[errcheck]
|
defer res.Body.Close() //nolint[errcheck]
|
||||||
|
|
||||||
var resBody []byte
|
var resBody []byte
|
||||||
if c.config.MinSpeed == 0 {
|
if c.cm.GetConfig().MinSpeed == 0 {
|
||||||
resBody, err = ioutil.ReadAll(res.Body)
|
resBody, err = ioutil.ReadAll(res.Body)
|
||||||
} else {
|
} else {
|
||||||
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
|
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
|
||||||
@ -422,7 +377,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
|
func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
|
||||||
firstReadTimeout := c.config.FirstReadTimeout
|
firstReadTimeout := c.cm.GetConfig().FirstReadTimeout
|
||||||
if firstReadTimeout == 0 {
|
if firstReadTimeout == 0 {
|
||||||
firstReadTimeout = 5 * time.Minute
|
firstReadTimeout = 5 * time.Minute
|
||||||
}
|
}
|
||||||
@ -431,7 +386,7 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
|
|||||||
})
|
})
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
for {
|
for {
|
||||||
_, err := io.CopyN(&buffer, data, c.config.MinSpeed)
|
_, err := io.CopyN(&buffer, data, c.cm.GetConfig().MinSpeed)
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
timer.Reset(1 * time.Second)
|
timer.Reset(1 * time.Second)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@ -445,15 +400,13 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
|
|||||||
|
|
||||||
func (c *Client) refreshAccessToken() (err error) {
|
func (c *Client) refreshAccessToken() (err error) {
|
||||||
c.log.Debug("Refreshing token")
|
c.log.Debug("Refreshing token")
|
||||||
refreshToken := c.tokenManager.GetToken(c.userID)
|
refreshToken := c.cm.GetToken(c.userID)
|
||||||
c.log.WithField("token", refreshToken).Info("Current refresh token")
|
c.log.WithField("token", refreshToken).Info("Current refresh token")
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
if c.auths != nil {
|
if c.auths != nil {
|
||||||
c.auths <- nil
|
c.auths <- nil
|
||||||
}
|
}
|
||||||
if c.tokenManager != nil {
|
c.cm.ClearToken(c.userID)
|
||||||
c.tokenManager.SetToken(c.userID, "")
|
|
||||||
}
|
|
||||||
return ErrInvalidToken
|
return ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -465,9 +418,7 @@ func (c *Client) refreshAccessToken() (err error) {
|
|||||||
if c.auths != nil {
|
if c.auths != nil {
|
||||||
c.auths <- nil
|
c.auths <- nil
|
||||||
}
|
}
|
||||||
if c.tokenManager != nil {
|
c.cm.ClearToken(c.userID)
|
||||||
c.tokenManager.SetToken(c.userID, "")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.uid = auth.UID()
|
c.uid = auth.UID()
|
||||||
|
|||||||
@ -37,8 +37,7 @@ var testClientConfig = &ClientConfig{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTestClient() *Client {
|
func newTestClient() *Client {
|
||||||
c := NewClient(testClientConfig, "tester")
|
c := newClient(NewClientManager(testClientConfig), "tester")
|
||||||
c.tokenManager = NewTokenManager()
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
69
pkg/pmapi/clientmanager.go
Normal file
69
pkg/pmapi/clientmanager.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package pmapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/getsentry/raven-go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientManager is a manager of clients.
|
||||||
|
type ClientManager struct {
|
||||||
|
// TODO: Lockers.
|
||||||
|
|
||||||
|
clients map[string]*Client
|
||||||
|
tokens map[string]string
|
||||||
|
config *ClientConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
||||||
|
func NewClientManager(config *ClientConfig) *ClientManager {
|
||||||
|
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
||||||
|
logrus.WithError(err).Error("Could not set up sentry DSN")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ClientManager{
|
||||||
|
clients: make(map[string]*Client),
|
||||||
|
tokens: make(map[string]string),
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient returns a client for the given userID.
|
||||||
|
// If the client does not exist already, it is created.
|
||||||
|
func (cm *ClientManager) GetClient(userID string) *Client {
|
||||||
|
if client, ok := cm.clients[userID]; ok {
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.clients[userID] = newClient(cm, userID)
|
||||||
|
|
||||||
|
return cm.clients[userID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig returns the config used to configure clients.
|
||||||
|
func (cm *ClientManager) GetConfig() *ClientConfig {
|
||||||
|
return cm.config
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetToken returns the token for the given userID.
|
||||||
|
func (cm *ClientManager) GetToken(userID string) string {
|
||||||
|
return cm.tokens[userID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetToken sets the token for the given userID.
|
||||||
|
func (cm *ClientManager) SetToken(userID, token string) {
|
||||||
|
cm.tokens[userID] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTokenIfUnset sets the token for the given userID if it does not yet have a token.
|
||||||
|
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
|
||||||
|
if _, ok := cm.tokens[userID]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.tokens[userID] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearToken clears the token of the given userID.
|
||||||
|
func (cm *ClientManager) ClearToken(userID string) {
|
||||||
|
delete(cm.tokens, userID)
|
||||||
|
}
|
||||||
@ -42,7 +42,7 @@ func TestTLSPinValid(t *testing.T) {
|
|||||||
called, _ := newTestDialerWithPinning()
|
called, _ := newTestDialerWithPinning()
|
||||||
|
|
||||||
RootURL = liveAPI
|
RootURL = liveAPI
|
||||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||||
|
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
@ -56,7 +56,7 @@ func TestTLSPinBackup(t *testing.T) {
|
|||||||
p.report.KnownPins[0] = ""
|
p.report.KnownPins[0] = ""
|
||||||
|
|
||||||
RootURL = liveAPI
|
RootURL = liveAPI
|
||||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||||
|
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
@ -71,13 +71,13 @@ func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
|
|||||||
}
|
}
|
||||||
|
|
||||||
RootURL = liveAPI
|
RootURL = liveAPI
|
||||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||||
|
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
|
|
||||||
// check that it will be called only once per session
|
// check that it will be called only once per session
|
||||||
client = NewClient(testLiveConfig, "pmapi"+t.Name())
|
client = newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||||
_, err = client.AuthInfo("this.address.is.disabled")
|
_, err = client.AuthInfo("this.address.is.disabled")
|
||||||
Ok(t, err)
|
Ok(t, err)
|
||||||
|
|
||||||
@ -92,7 +92,7 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
|
|||||||
|
|
||||||
called, _ := newTestDialerWithPinning()
|
called, _ := newTestDialerWithPinning()
|
||||||
|
|
||||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||||
|
|
||||||
RootURL = liveAPI
|
RootURL = liveAPI
|
||||||
_, err := client.AuthInfo("this.address.is.disabled")
|
_, err := client.AuthInfo("this.address.is.disabled")
|
||||||
|
|||||||
@ -152,8 +152,8 @@ func (c *Client) ReportSentryCrash(reportErr error) (err error) {
|
|||||||
}
|
}
|
||||||
tags := map[string]string{
|
tags := map[string]string{
|
||||||
"OS": runtime.GOOS,
|
"OS": runtime.GOOS,
|
||||||
"Client": c.config.ClientID,
|
"Client": c.cm.GetConfig().ClientID,
|
||||||
"Version": c.config.AppVersion,
|
"Version": c.cm.GetConfig().AppVersion,
|
||||||
"UserAgent": CurrentUserAgent,
|
"UserAgent": CurrentUserAgent,
|
||||||
"UserID": c.userID,
|
"UserID": c.userID,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSentryCrashReport(t *testing.T) {
|
func TestSentryCrashReport(t *testing.T) {
|
||||||
c := NewClient(testClientConfig, "bridgetest")
|
c := newClient(NewClientManager(testClientConfig), "bridgetest")
|
||||||
if err := c.ReportSentryCrash(errors.New("Testing crash report - api proxy; goroutines with threads, find origin")); err != nil {
|
if err := c.ReportSentryCrash(errors.New("Testing crash report - api proxy; goroutines with threads, find origin")); err != nil {
|
||||||
t.Fatal("Expected no error while report, but have", err)
|
t.Fatal("Expected no error while report, but have", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user