mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-18 08:06:59 +00:00
feat: implement token expiration watcher
This commit is contained in:
@ -181,10 +181,11 @@ func (b *Bridge) watchBridgeOutdated() {
|
|||||||
// watchUserAuths receives auths from the client manager and sends them to the appropriate user.
|
// watchUserAuths receives auths from the client manager and sends them to the appropriate user.
|
||||||
func (b *Bridge) watchUserAuths() {
|
func (b *Bridge) watchUserAuths() {
|
||||||
for auth := range b.clientManager.GetBridgeAuthChannel() {
|
for auth := range b.clientManager.GetBridgeAuthChannel() {
|
||||||
logrus.WithField("token", auth.Auth.GenToken()).WithField("userID", auth.UserID).Info("Received auth from bridge auth channel")
|
logrus.Debug("Bridge received auth from ClientManager")
|
||||||
|
|
||||||
if user, ok := b.hasUser(auth.UserID); ok {
|
if user, ok := b.hasUser(auth.UserID); ok {
|
||||||
user.ReceiveAPIAuth(auth.Auth)
|
logrus.Debug("Bridge is forwarding auth to user")
|
||||||
|
user.AuthorizeWithAPIAuth(auth.Auth)
|
||||||
} else {
|
} else {
|
||||||
logrus.Info("User is not added to bridge yet")
|
logrus.Info("User is not added to bridge yet")
|
||||||
}
|
}
|
||||||
@ -494,11 +495,7 @@ func (b *Bridge) updateCurrentUserAgent() {
|
|||||||
|
|
||||||
// hasUser returns whether the bridge currently has a user with ID `id`.
|
// hasUser returns whether the bridge currently has a user with ID `id`.
|
||||||
func (b *Bridge) hasUser(id string) (user *User, ok bool) {
|
func (b *Bridge) hasUser(id string) (user *User, ok bool) {
|
||||||
logrus.WithField("id", id).Info("Checking whether bridge has given user")
|
|
||||||
|
|
||||||
for _, u := range b.users {
|
for _, u := range b.users {
|
||||||
logrus.WithField("id", u.ID()).Info("Found potential user")
|
|
||||||
|
|
||||||
if u.ID() == id {
|
if u.ID() == id {
|
||||||
user, ok = u, true
|
user, ok = u, true
|
||||||
return
|
return
|
||||||
|
|||||||
@ -243,10 +243,12 @@ func (u *User) authorizeAndUnlock() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) {
|
func (u *User) AuthorizeWithAPIAuth(auth *pmapi.Auth) {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
defer u.lock.Unlock()
|
defer u.lock.Unlock()
|
||||||
|
|
||||||
|
u.log.Debug("User received auth from bridge")
|
||||||
|
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
if err := u.logout(); err != nil {
|
if err := u.logout(); err != nil {
|
||||||
u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API")
|
u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API")
|
||||||
|
|||||||
@ -24,7 +24,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
pmcrypto "github.com/ProtonMail/gopenpgp/crypto"
|
pmcrypto "github.com/ProtonMail/gopenpgp/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/srp"
|
"github.com/ProtonMail/proton-bridge/pkg/srp"
|
||||||
@ -197,6 +196,9 @@ type AuthRefreshReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) sendAuth(auth *Auth) {
|
func (c *Client) sendAuth(auth *Auth) {
|
||||||
|
go func() {
|
||||||
|
c.log.Debug("Client is sending auth to ClientManager")
|
||||||
|
|
||||||
c.cm.getClientAuthChannel() <- ClientAuth{
|
c.cm.getClientAuthChannel() <- ClientAuth{
|
||||||
UserID: c.userID,
|
UserID: c.userID,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
@ -206,6 +208,7 @@ func (c *Client) sendAuth(auth *Auth) {
|
|||||||
c.uid = auth.UID()
|
c.uid = auth.UID()
|
||||||
c.accessToken = auth.accessToken
|
c.accessToken = auth.accessToken
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthInfo gets authentication info for a user.
|
// AuthInfo gets authentication info for a user.
|
||||||
@ -324,7 +327,6 @@ func (c *Client) Auth(username, password string, info *AuthInfo) (auth *Auth, er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
|
||||||
return auth, err
|
return auth, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -445,7 +447,6 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||||||
|
|
||||||
auth = res.getAuth()
|
auth = res.getAuth()
|
||||||
c.sendAuth(auth)
|
c.sendAuth(auth)
|
||||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
|
||||||
|
|
||||||
return auth, err
|
return auth, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -81,6 +81,7 @@ type ClientConfig struct {
|
|||||||
|
|
||||||
// Transport specifies the mechanism by which individual HTTP requests are made.
|
// Transport specifies the mechanism by which individual HTTP requests are made.
|
||||||
// If nil, http.DefaultTransport is used.
|
// If nil, http.DefaultTransport is used.
|
||||||
|
// TODO: This could be removed entirely and set in the client manager via SetClientRoundTripper.
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
|
|
||||||
// Timeout specifies the timeout from request to getting response headers to our API.
|
// Timeout specifies the timeout from request to getting response headers to our API.
|
||||||
@ -108,7 +109,6 @@ type Client struct {
|
|||||||
requestLocker sync.Locker
|
requestLocker sync.Locker
|
||||||
keyLocker sync.Locker
|
keyLocker sync.Locker
|
||||||
|
|
||||||
expiresAt time.Time
|
|
||||||
user *User
|
user *User
|
||||||
addresses AddressList
|
addresses AddressList
|
||||||
kr *pmcrypto.KeyRing
|
kr *pmcrypto.KeyRing
|
||||||
|
|||||||
@ -3,12 +3,15 @@ package pmapi
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var proxyUseDuration = 24 * time.Hour
|
||||||
|
|
||||||
// ClientManager is a manager of clients.
|
// ClientManager is a manager of clients.
|
||||||
type ClientManager struct {
|
type ClientManager struct {
|
||||||
config *ClientConfig
|
config *ClientConfig
|
||||||
@ -17,6 +20,7 @@ type ClientManager struct {
|
|||||||
clientsLocker sync.Locker
|
clientsLocker sync.Locker
|
||||||
|
|
||||||
tokens map[string]string
|
tokens map[string]string
|
||||||
|
tokenExpirations map[string]*tokenExpiration
|
||||||
tokensLocker sync.Locker
|
tokensLocker sync.Locker
|
||||||
|
|
||||||
url string
|
url string
|
||||||
@ -34,6 +38,11 @@ type ClientAuth struct {
|
|||||||
Auth *Auth
|
Auth *Auth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tokenExpiration struct {
|
||||||
|
timer *time.Timer
|
||||||
|
cancel chan (struct{})
|
||||||
|
}
|
||||||
|
|
||||||
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
||||||
func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||||
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
||||||
@ -47,6 +56,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||||||
clientsLocker: &sync.Mutex{},
|
clientsLocker: &sync.Mutex{},
|
||||||
|
|
||||||
tokens: make(map[string]string),
|
tokens: make(map[string]string),
|
||||||
|
tokenExpirations: make(map[string]*tokenExpiration),
|
||||||
tokensLocker: &sync.Mutex{},
|
tokensLocker: &sync.Mutex{},
|
||||||
|
|
||||||
url: RootURL,
|
url: RootURL,
|
||||||
@ -112,43 +122,56 @@ func (cm *ClientManager) GetRootURL() string {
|
|||||||
return cm.url
|
return cm.url
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRootURL sets the root URL to make requests to.
|
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||||
func (cm *ClientManager) SetRootURL(url string) {
|
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||||
cm.urlLocker.Lock()
|
cm.urlLocker.Lock()
|
||||||
defer cm.urlLocker.Unlock()
|
defer cm.urlLocker.Unlock()
|
||||||
|
|
||||||
logrus.WithField("url", url).Info("Changing to a new root URL")
|
|
||||||
|
|
||||||
cm.url = url
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
|
||||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
|
||||||
return cm.allowProxy
|
return cm.allowProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
||||||
func (cm *ClientManager) AllowProxy() {
|
func (cm *ClientManager) AllowProxy() {
|
||||||
|
cm.urlLocker.Lock()
|
||||||
|
defer cm.urlLocker.Unlock()
|
||||||
|
|
||||||
cm.allowProxy = true
|
cm.allowProxy = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
||||||
func (cm *ClientManager) DisallowProxy() {
|
func (cm *ClientManager) DisallowProxy() {
|
||||||
|
cm.urlLocker.Lock()
|
||||||
|
defer cm.urlLocker.Unlock()
|
||||||
|
|
||||||
cm.allowProxy = false
|
cm.allowProxy = false
|
||||||
|
cm.url = RootURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProxyEnabled returns whether we are currently proxying requests.
|
||||||
|
func (cm *ClientManager) IsProxyEnabled() bool {
|
||||||
|
cm.urlLocker.Lock()
|
||||||
|
defer cm.urlLocker.Unlock()
|
||||||
|
|
||||||
|
return cm.url != RootURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindProxy returns a usable proxy server.
|
// FindProxy returns a usable proxy server.
|
||||||
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
||||||
logrus.Info("Attempting gto switch to a proxy")
|
cm.urlLocker.Lock()
|
||||||
|
defer cm.urlLocker.Unlock()
|
||||||
|
|
||||||
|
logrus.Info("Attempting to switch to a proxy")
|
||||||
|
|
||||||
if proxy, err = cm.proxyProvider.findProxy(); err != nil {
|
if proxy, err = cm.proxyProvider.findProxy(); err != nil {
|
||||||
err = errors.Wrap(err, "failed to find usable proxy")
|
err = errors.Wrap(err, "failed to find a usable proxy")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.SetRootURL(proxy)
|
logrus.WithField("proxy", proxy).Info("Switching to a proxy")
|
||||||
|
|
||||||
// TODO: Disable after 24 hours.
|
cm.url = proxy
|
||||||
|
|
||||||
|
// TODO: Disable again after 24 hours.
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -165,7 +188,7 @@ func (cm *ClientManager) GetToken(userID string) string {
|
|||||||
|
|
||||||
// GetBridgeAuthChannel returns a channel on which client auths can be received.
|
// GetBridgeAuthChannel returns a channel on which client auths can be received.
|
||||||
func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth {
|
func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth {
|
||||||
return cm.clientAuths
|
return cm.bridgeAuths
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClientAuthChannel returns a channel on which clients should send auths.
|
// getClientAuthChannel returns a channel on which clients should send auths.
|
||||||
@ -176,25 +199,46 @@ func (cm *ClientManager) getClientAuthChannel() chan ClientAuth {
|
|||||||
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel.
|
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel.
|
||||||
func (cm *ClientManager) forwardClientAuths() {
|
func (cm *ClientManager) forwardClientAuths() {
|
||||||
for auth := range cm.clientAuths {
|
for auth := range cm.clientAuths {
|
||||||
|
logrus.Debug("ClientManager received auth from client")
|
||||||
cm.handleClientAuth(auth)
|
cm.handleClientAuth(auth)
|
||||||
|
logrus.Debug("ClientManager is forwarding auth to bridge")
|
||||||
cm.bridgeAuths <- auth
|
cm.bridgeAuths <- auth
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ClientManager) setToken(userID, token string) {
|
// setToken sets the token for the given userID with the given expiration time.
|
||||||
|
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) {
|
||||||
cm.tokensLocker.Lock()
|
cm.tokensLocker.Lock()
|
||||||
defer cm.tokensLocker.Unlock()
|
defer cm.tokensLocker.Unlock()
|
||||||
|
|
||||||
logrus.WithField("userID", userID).Info("Updating refresh token")
|
logrus.WithField("userID", userID).Info("Updating token")
|
||||||
|
|
||||||
cm.tokens[userID] = token
|
cm.tokens[userID] = token
|
||||||
|
|
||||||
|
cm.setTokenExpiration(userID, expiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||||
|
// If the token already has an expiration time set, it is replaced.
|
||||||
|
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) {
|
||||||
|
if exp, ok := cm.tokenExpirations[userID]; ok {
|
||||||
|
exp.timer.Stop()
|
||||||
|
close(exp.cancel)
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.tokenExpirations[userID] = &tokenExpiration{
|
||||||
|
timer: time.NewTimer(expiration),
|
||||||
|
cancel: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
go cm.watchTokenExpiration(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ClientManager) clearToken(userID string) {
|
func (cm *ClientManager) clearToken(userID string) {
|
||||||
cm.tokensLocker.Lock()
|
cm.tokensLocker.Lock()
|
||||||
defer cm.tokensLocker.Unlock()
|
defer cm.tokensLocker.Unlock()
|
||||||
|
|
||||||
logrus.WithField("userID", userID).Info("Clearing refresh token")
|
logrus.WithField("userID", userID).Info("Clearing token")
|
||||||
|
|
||||||
delete(cm.tokens, userID)
|
delete(cm.tokens, userID)
|
||||||
}
|
}
|
||||||
@ -203,6 +247,7 @@ func (cm *ClientManager) clearToken(userID string) {
|
|||||||
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||||
// If we aren't managing this client, there's nothing to do.
|
// If we aren't managing this client, there's nothing to do.
|
||||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||||
|
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,5 +258,18 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.setToken(ca.UserID, ca.Auth.GenToken())
|
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||||
|
expiration := cm.tokenExpirations[userID]
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-expiration.timer.C:
|
||||||
|
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
||||||
|
cm.clients[userID].AuthRefresh(cm.tokens[userID])
|
||||||
|
|
||||||
|
case <-expiration.cancel:
|
||||||
|
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired, cancelling this watcher")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -282,13 +282,15 @@ func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (c
|
|||||||
func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) {
|
func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) {
|
||||||
p.log.Info("Dialing with proxy fallback")
|
p.log.Info("Dialing with proxy fallback")
|
||||||
|
|
||||||
var host, port string
|
// Try to dial, and if it succeeds, then just return.
|
||||||
if host, port, err = net.SplitHostPort(address); err != nil {
|
if conn, err = p.dial(network, address); err == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to dial, and if it succeeds, then just return.
|
p.log.WithField("address", address).WithError(err).Error("Dialing failed")
|
||||||
if conn, err = p.dial(network, address); err == nil {
|
|
||||||
|
host, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,18 +298,21 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn
|
|||||||
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in
|
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in
|
||||||
// continuing since a proxy won't help us reach that.
|
// continuing since a proxy won't help us reach that.
|
||||||
if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() {
|
if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() {
|
||||||
p.log.WithField("useProxy", p.cm.IsProxyAllowed()).Info("Dial failed but not switching to proxy")
|
p.log.WithField("address", address).Debug("Aborting dial, cannot switch to a proxy")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Switch to a proxy and retry the dial.
|
// Switch to a proxy and retry the dial.
|
||||||
var proxy string
|
proxy, err := p.cm.SwitchToProxy()
|
||||||
|
if err != nil {
|
||||||
if proxy, err = p.cm.SwitchToProxy(); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.dial(network, net.JoinHostPort(proxy, port))
|
proxyAddress := net.JoinHostPort(proxy, port)
|
||||||
|
|
||||||
|
p.log.WithField("address", proxyAddress).Debug("Trying dial again using a proxy")
|
||||||
|
|
||||||
|
return p.dial(network, proxyAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial returns a connection to the given address using the given network.
|
// dial returns a connection to the given address using the given network.
|
||||||
|
|||||||
@ -51,7 +51,6 @@ type proxyProvider struct {
|
|||||||
query string // The query string used to find proxies.
|
query string // The query string used to find proxies.
|
||||||
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
|
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
|
||||||
|
|
||||||
useDuration time.Duration // How much time to use the proxy before returning to the original API.
|
|
||||||
findTimeout, lookupTimeout time.Duration // Timeouts for DNS query and proxy search.
|
findTimeout, lookupTimeout time.Duration // Timeouts for DNS query and proxy search.
|
||||||
|
|
||||||
lastLookup time.Time // The time at which we last attempted to find a proxy.
|
lastLookup time.Time // The time at which we last attempted to find a proxy.
|
||||||
@ -63,7 +62,6 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { //
|
|||||||
p = &proxyProvider{
|
p = &proxyProvider{
|
||||||
providers: providers,
|
providers: providers,
|
||||||
query: query,
|
query: query,
|
||||||
useDuration: proxyRevertTime,
|
|
||||||
findTimeout: proxySearchTimeout,
|
findTimeout: proxySearchTimeout,
|
||||||
lookupTimeout: proxyQueryTimeout,
|
lookupTimeout: proxyQueryTimeout,
|
||||||
}
|
}
|
||||||
@ -148,6 +146,7 @@ func (p *proxyProvider) canReach(url string) bool {
|
|||||||
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec]
|
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec]
|
||||||
|
|
||||||
if _, err := pinger.R().Get("/tests/ping"); err != nil {
|
if _, err := pinger.R().Get("/tests/ping"); err != nil {
|
||||||
|
logrus.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -109,7 +109,6 @@ func (c *Client) UpdateUser() (user *User, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.user = user
|
c.user = user
|
||||||
c.log.Infoln("update user:", user.ID)
|
|
||||||
raven.SetUserContext(&raven.User{ID: user.ID})
|
raven.SetUserContext(&raven.User{ID: user.ID})
|
||||||
|
|
||||||
var tmpList AddressList
|
var tmpList AddressList
|
||||||
@ -117,6 +116,8 @@ func (c *Client) UpdateUser() (user *User, err error) {
|
|||||||
c.addresses = tmpList
|
c.addresses = tmpList
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.log.WithField("userID", user.ID).Info("Updated user")
|
||||||
|
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||||
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetBridge returns bridge instance.
|
// GetBridge returns bridge instance.
|
||||||
@ -34,10 +35,7 @@ func (ctx *TestContext) GetBridge() *bridge.Bridge {
|
|||||||
// withBridgeInstance creates a bridge instance for use in the test.
|
// withBridgeInstance creates a bridge instance for use in the test.
|
||||||
// Every TestContext has this by default and thus this doesn't need to be exported.
|
// Every TestContext has this by default and thus this doesn't need to be exported.
|
||||||
func (ctx *TestContext) withBridgeInstance() {
|
func (ctx *TestContext) withBridgeInstance() {
|
||||||
pmapiFactory := func(userID string) bridge.PMAPIProvider {
|
ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, ctx.clientManager)
|
||||||
return ctx.pmapiController.GetClient(userID)
|
|
||||||
}
|
|
||||||
ctx.bridge = newBridgeInstance(ctx.t, ctx.cfg, ctx.credStore, ctx.listener, pmapiFactory)
|
|
||||||
ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data")
|
ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,7 +60,7 @@ func newBridgeInstance(
|
|||||||
cfg *fakeConfig,
|
cfg *fakeConfig,
|
||||||
credStore bridge.CredentialsStorer,
|
credStore bridge.CredentialsStorer,
|
||||||
eventListener listener.Listener,
|
eventListener listener.Listener,
|
||||||
pmapiFactory bridge.PMAPIProviderFactory,
|
clientManager *pmapi.ClientManager,
|
||||||
) *bridge.Bridge {
|
) *bridge.Bridge {
|
||||||
version := os.Getenv("VERSION")
|
version := os.Getenv("VERSION")
|
||||||
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")
|
bridge.UpdateCurrentUserAgent(version, runtime.GOOS, "", "")
|
||||||
@ -70,7 +68,7 @@ func newBridgeInstance(
|
|||||||
panicHandler := &panicHandler{t: t}
|
panicHandler := &panicHandler{t: t}
|
||||||
pref := preferences.New(cfg)
|
pref := preferences.New(cfg)
|
||||||
|
|
||||||
return bridge.New(cfg, pref, panicHandler, eventListener, version, pmapiFactory, credStore)
|
return bridge.New(cfg, pref, panicHandler, eventListener, version, clientManager, credStore)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLastBridgeError sets the last error that occurred while executing a bridge action.
|
// SetLastBridgeError sets the last error that occurred while executing a bridge action.
|
||||||
|
|||||||
@ -28,7 +28,6 @@ import (
|
|||||||
|
|
||||||
type fakeConfig struct {
|
type fakeConfig struct {
|
||||||
dir string
|
dir string
|
||||||
tm *pmapi.TokenManager
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newFakeConfig creates a temporary folder for files.
|
// newFakeConfig creates a temporary folder for files.
|
||||||
@ -41,7 +40,6 @@ func newFakeConfig() *fakeConfig {
|
|||||||
|
|
||||||
return &fakeConfig{
|
return &fakeConfig{
|
||||||
dir: dir,
|
dir: dir,
|
||||||
tm: pmapi.NewTokenManager(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,8 +50,7 @@ func (c *fakeConfig) GetAPIConfig() *pmapi.ClientConfig {
|
|||||||
return &pmapi.ClientConfig{
|
return &pmapi.ClientConfig{
|
||||||
AppVersion: "Bridge_" + os.Getenv("VERSION"),
|
AppVersion: "Bridge_" + os.Getenv("VERSION"),
|
||||||
ClientID: "bridge",
|
ClientID: "bridge",
|
||||||
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
|
SentryDSN: "",
|
||||||
TokenManager: c.tm,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (c *fakeConfig) GetDBDir() string {
|
func (c *fakeConfig) GetDBDir() string {
|
||||||
|
|||||||
@ -21,6 +21,7 @@ package context
|
|||||||
import (
|
import (
|
||||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||||
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/ProtonMail/proton-bridge/test/accounts"
|
"github.com/ProtonMail/proton-bridge/test/accounts"
|
||||||
"github.com/ProtonMail/proton-bridge/test/mocks"
|
"github.com/ProtonMail/proton-bridge/test/mocks"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -57,6 +58,9 @@ type TestContext struct {
|
|||||||
smtpClients map[string]*mocks.SMTPClient
|
smtpClients map[string]*mocks.SMTPClient
|
||||||
smtpLastResponses map[string]*mocks.SMTPResponse
|
smtpLastResponses map[string]*mocks.SMTPResponse
|
||||||
|
|
||||||
|
// PMAPI related variables.
|
||||||
|
clientManager *pmapi.ClientManager
|
||||||
|
|
||||||
// These are the cleanup steps executed when Cleanup() is called.
|
// These are the cleanup steps executed when Cleanup() is called.
|
||||||
cleanupSteps []*Cleaner
|
cleanupSteps []*Cleaner
|
||||||
|
|
||||||
@ -70,17 +74,20 @@ func New() *TestContext {
|
|||||||
|
|
||||||
cfg := newFakeConfig()
|
cfg := newFakeConfig()
|
||||||
|
|
||||||
|
cm := pmapi.NewClientManager(cfg.GetAPIConfig())
|
||||||
|
|
||||||
ctx := &TestContext{
|
ctx := &TestContext{
|
||||||
t: &bddT{},
|
t: &bddT{},
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
listener: listener.New(),
|
listener: listener.New(),
|
||||||
pmapiController: newPMAPIController(),
|
pmapiController: newPMAPIController(cm),
|
||||||
testAccounts: newTestAccounts(),
|
testAccounts: newTestAccounts(),
|
||||||
credStore: newFakeCredStore(),
|
credStore: newFakeCredStore(),
|
||||||
imapClients: make(map[string]*mocks.IMAPClient),
|
imapClients: make(map[string]*mocks.IMAPClient),
|
||||||
imapLastResponses: make(map[string]*mocks.IMAPResponse),
|
imapLastResponses: make(map[string]*mocks.IMAPResponse),
|
||||||
smtpClients: make(map[string]*mocks.SMTPClient),
|
smtpClients: make(map[string]*mocks.SMTPClient),
|
||||||
smtpLastResponses: make(map[string]*mocks.SMTPResponse),
|
smtpLastResponses: make(map[string]*mocks.SMTPResponse),
|
||||||
|
clientManager: cm,
|
||||||
logger: logrus.StandardLogger(),
|
logger: logrus.StandardLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -40,12 +40,12 @@ type PMAPIController interface {
|
|||||||
GetCalls(method, path string) [][]byte
|
GetCalls(method, path string) [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPMAPIController() PMAPIController {
|
func newPMAPIController(cm *pmapi.ClientManager) PMAPIController {
|
||||||
switch os.Getenv(EnvName) {
|
switch os.Getenv(EnvName) {
|
||||||
case EnvFake:
|
case EnvFake:
|
||||||
return newFakePMAPIController()
|
return newFakePMAPIController()
|
||||||
case EnvLive:
|
case EnvLive:
|
||||||
return newLivePMAPIController()
|
return newLivePMAPIController(cm)
|
||||||
default:
|
default:
|
||||||
panic("unknown env")
|
panic("unknown env")
|
||||||
}
|
}
|
||||||
@ -67,8 +67,8 @@ func (s *fakePMAPIControllerWrap) GetClient(userID string) bridge.PMAPIProvider
|
|||||||
return s.Controller.GetClient(userID)
|
return s.Controller.GetClient(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLivePMAPIController() PMAPIController {
|
func newLivePMAPIController(cm *pmapi.ClientManager) PMAPIController {
|
||||||
return newLiveAPIControllerWrap(liveapi.NewController())
|
return newLiveAPIControllerWrap(liveapi.NewController(cm))
|
||||||
}
|
}
|
||||||
|
|
||||||
type liveAPIControllerWrap struct {
|
type liveAPIControllerWrap struct {
|
||||||
|
|||||||
@ -141,13 +141,12 @@ func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) {
|
|||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *FakePMAPI) Logout() error {
|
func (api *FakePMAPI) Logout() {
|
||||||
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
|
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
// Logout will also emit change to auth channel
|
// Logout will also emit change to auth channel
|
||||||
api.sendAuth(nil)
|
api.sendAuth(nil)
|
||||||
api.controller.deleteSession(api.uid)
|
api.controller.deleteSession(api.uid)
|
||||||
api.unsetUser()
|
api.unsetUser()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,9 +18,7 @@
|
|||||||
package liveapi
|
package liveapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
@ -32,31 +30,31 @@ type Controller struct {
|
|||||||
calls []*fakeCall
|
calls []*fakeCall
|
||||||
pmapiByUsername map[string]*pmapi.Client
|
pmapiByUsername map[string]*pmapi.Client
|
||||||
messageIDsByUsername map[string][]string
|
messageIDsByUsername map[string][]string
|
||||||
|
clientManager *pmapi.ClientManager
|
||||||
|
|
||||||
// State controlled by test.
|
// State controlled by test.
|
||||||
noInternetConnection bool
|
noInternetConnection bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewController() *Controller {
|
func NewController(cm *pmapi.ClientManager) *Controller {
|
||||||
return &Controller{
|
cntrl := &Controller{
|
||||||
lock: &sync.RWMutex{},
|
lock: &sync.RWMutex{},
|
||||||
calls: []*fakeCall{},
|
calls: []*fakeCall{},
|
||||||
pmapiByUsername: map[string]*pmapi.Client{},
|
pmapiByUsername: map[string]*pmapi.Client{},
|
||||||
messageIDsByUsername: map[string][]string{},
|
messageIDsByUsername: map[string][]string{},
|
||||||
|
clientManager: cm,
|
||||||
|
|
||||||
noInternetConnection: false,
|
noInternetConnection: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cntrl.clientManager.SetClientRoundTripper(&fakeTransport{
|
||||||
|
cntrl: cntrl,
|
||||||
|
transport: http.DefaultTransport,
|
||||||
|
})
|
||||||
|
|
||||||
|
return cntrl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cntrl *Controller) GetClient(userID string) *pmapi.Client {
|
func (cntrl *Controller) GetClient(userID string) *pmapi.Client {
|
||||||
cfg := &pmapi.ClientConfig{
|
return cntrl.clientManager.GetClient(userID)
|
||||||
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
|
|
||||||
ClientID: "bridge-test",
|
|
||||||
Transport: &fakeTransport{
|
|
||||||
cntrl: cntrl,
|
|
||||||
transport: http.DefaultTransport,
|
|
||||||
},
|
|
||||||
TokenManager: pmapi.NewTokenManager(),
|
|
||||||
}
|
|
||||||
return pmapi.NewClient(cfg, userID)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,9 +18,6 @@
|
|||||||
package liveapi
|
package liveapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||||
"github.com/cucumber/godog"
|
"github.com/cucumber/godog"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -31,11 +28,7 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
|
|||||||
return godog.ErrPending
|
return godog.ErrPending
|
||||||
}
|
}
|
||||||
|
|
||||||
client := pmapi.NewClient(&pmapi.ClientConfig{
|
client := cntrl.GetClient(user.ID)
|
||||||
AppVersion: fmt.Sprintf("Bridge_%s", os.Getenv("VERSION")),
|
|
||||||
ClientID: "bridge-cntrl",
|
|
||||||
TokenManager: pmapi.NewTokenManager(),
|
|
||||||
}, user.ID)
|
|
||||||
|
|
||||||
authInfo, err := client.AuthInfo(user.Name)
|
authInfo, err := client.AuthInfo(user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -62,5 +55,6 @@ func (cntrl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList,
|
|||||||
}
|
}
|
||||||
|
|
||||||
cntrl.pmapiByUsername[user.Name] = client
|
cntrl.pmapiByUsername[user.Name] = client
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user