forked from Silverfish/proton-bridge
feat: central auth channel for clients
This commit is contained in:
@ -66,7 +66,7 @@ func HandlePanic(cfg *Config, output string) {
|
||||
if !cfg.IsDevMode() {
|
||||
// TODO: Is it okay to just create a throwaway client like this?
|
||||
c := pmapi.NewClientManager(cfg.GetAPIConfig()).GetClient("no-user-id")
|
||||
defer func() { _ = c.Logout() }()
|
||||
defer c.Logout()
|
||||
|
||||
if err := c.ReportSentryCrash(fmt.Errorf(output)); err != nil {
|
||||
log.Error("Sentry crash report failed: ", err)
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@ -122,6 +123,10 @@ func (s *Auth) UID() string {
|
||||
return s.uid
|
||||
}
|
||||
|
||||
func (s *Auth) GenToken() string {
|
||||
return fmt.Sprintf("%v:%v", s.UID(), s.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *Auth) HasTwoFactor() bool {
|
||||
if s.TwoFA == nil {
|
||||
return false
|
||||
@ -191,9 +196,16 @@ type AuthRefreshReq struct {
|
||||
State string
|
||||
}
|
||||
|
||||
// SetAuths sets auths channel.
|
||||
func (c *Client) SetAuths(auths chan<- *Auth) {
|
||||
c.auths = auths
|
||||
func (c *Client) sendAuth(auth *Auth) {
|
||||
c.cm.getClientAuthChannel() <- ClientAuth{
|
||||
UserID: c.userID,
|
||||
Auth: auth,
|
||||
}
|
||||
|
||||
if auth != nil {
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
}
|
||||
}
|
||||
|
||||
// AuthInfo gets authentication info for a user.
|
||||
@ -301,13 +313,7 @@ func (c *Client) Auth(username, password string, info *AuthInfo) (auth *Auth, er
|
||||
}
|
||||
|
||||
auth = authRes.getAuth()
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
|
||||
if c.auths != nil {
|
||||
c.auths <- auth
|
||||
}
|
||||
c.cm.SetToken(c.userID, c.uid+":"+auth.RefreshToken)
|
||||
c.sendAuth(auth)
|
||||
|
||||
// Auth has to be fully unlocked to get key salt. During `Auth` it can happen
|
||||
// only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`.
|
||||
@ -403,7 +409,8 @@ func (c *Client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
|
||||
func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
||||
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
||||
// That way we can try again later (see handleUnauthorizedStatus).
|
||||
c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||
// TODO:
|
||||
// c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||
|
||||
split := strings.Split(uidAndRefreshToken, ":")
|
||||
if len(split) != 2 {
|
||||
@ -437,22 +444,18 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
||||
}
|
||||
|
||||
auth = res.getAuth()
|
||||
// UID should never change after auth, see backend-communication#11
|
||||
auth.uid = c.uid
|
||||
if c.auths != nil {
|
||||
c.auths <- auth
|
||||
}
|
||||
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
c.cm.SetToken(c.userID, c.uid+":"+res.RefreshToken)
|
||||
|
||||
c.sendAuth(auth)
|
||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
||||
|
||||
return auth, err
|
||||
}
|
||||
|
||||
// Logout logs the current user out.
|
||||
func (c *Client) Logout() (err error) {
|
||||
func (c *Client) Logout() {
|
||||
c.cm.LogoutClient(c.userID)
|
||||
}
|
||||
|
||||
// logout logs the current user out.
|
||||
func (c *Client) logout() (err error) {
|
||||
req, err := NewRequest("DELETE", "/auth", nil)
|
||||
if err != nil {
|
||||
return
|
||||
@ -467,23 +470,13 @@ func (c *Client) Logout() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// This can trigger a deadlock! We don't want to do it if the above requests failed (GODT-154).
|
||||
// That's why it's not in the deferred statement above.
|
||||
if c.auths != nil {
|
||||
c.auths <- nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// This should ideally be deferred at the top of this method so that it is executed
|
||||
// regardless of what happens, but we currently don't have a way to prevent ourselves
|
||||
// from using a logged out client. So for now, it's down here, as it was in Charles release.
|
||||
// defer func() {
|
||||
func (c *Client) clearSensitiveData() {
|
||||
c.uid = ""
|
||||
c.accessToken = ""
|
||||
c.kr = nil
|
||||
// c.addresses = nil
|
||||
c.addresses = nil
|
||||
c.user = nil
|
||||
c.cm.ClearToken(c.userID)
|
||||
// }()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -99,10 +99,8 @@ type ClientConfig struct {
|
||||
|
||||
// Client to communicate with API.
|
||||
type Client struct {
|
||||
auths chan<- *Auth // Channel that sends Auth responses back to the bridge.
|
||||
|
||||
client *http.Client
|
||||
cm *ClientManager
|
||||
client *http.Client
|
||||
|
||||
uid string
|
||||
accessToken string
|
||||
@ -121,39 +119,42 @@ type Client struct {
|
||||
// newClient creates a new API client.
|
||||
func newClient(cm *ClientManager, userID string) *Client {
|
||||
return &Client{
|
||||
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
||||
client: getHTTPClient(cm.GetConfig()),
|
||||
cm: cm,
|
||||
client: getHTTPClient(cm.GetConfig()),
|
||||
userID: userID,
|
||||
requestLocker: &sync.Mutex{},
|
||||
keyLocker: &sync.Mutex{},
|
||||
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
||||
}
|
||||
}
|
||||
|
||||
// getHTTPClient returns a http client configured by the given client config.
|
||||
func getHTTPClient(cfg *ClientConfig) *http.Client {
|
||||
hc := &http.Client{
|
||||
Timeout: cfg.Timeout,
|
||||
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
|
||||
hc = &http.Client{Timeout: cfg.Timeout}
|
||||
|
||||
if cfg.Transport == nil && defaultTransport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Transport != nil {
|
||||
cfgTransport, ok := cfg.Transport.(*http.Transport)
|
||||
if ok {
|
||||
// In future use Clone here.
|
||||
// https://go-review.googlesource.com/c/go/+/174597/
|
||||
transport := &http.Transport{}
|
||||
*transport = *cfgTransport //nolint
|
||||
if transport.Proxy == nil {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
hc.Transport = transport
|
||||
} else {
|
||||
hc.Transport = cfg.Transport
|
||||
}
|
||||
} else if defaultTransport != nil {
|
||||
if defaultTransport != nil {
|
||||
hc.Transport = defaultTransport
|
||||
return
|
||||
}
|
||||
|
||||
// In future use Clone here.
|
||||
// https://go-review.googlesource.com/c/go/+/174597/
|
||||
if cfgTransport, ok := cfg.Transport.(*http.Transport); ok {
|
||||
transport := &http.Transport{}
|
||||
*transport = *cfgTransport //nolint
|
||||
if transport.Proxy == nil {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
hc.Transport = transport
|
||||
return
|
||||
}
|
||||
|
||||
hc.Transport = cfg.Transport
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
@ -400,30 +401,20 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
|
||||
|
||||
func (c *Client) refreshAccessToken() (err error) {
|
||||
c.log.Debug("Refreshing token")
|
||||
|
||||
refreshToken := c.cm.GetToken(c.userID)
|
||||
c.log.WithField("token", refreshToken).Info("Current refresh token")
|
||||
|
||||
if refreshToken == "" {
|
||||
if c.auths != nil {
|
||||
c.auths <- nil
|
||||
}
|
||||
c.cm.ClearToken(c.userID)
|
||||
c.sendAuth(nil)
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
auth, err := c.AuthRefresh(refreshToken)
|
||||
if err != nil {
|
||||
c.log.WithError(err).WithField("auths", c.auths).Debug("Token refreshing failed")
|
||||
// The refresh failed, so we should log the user out.
|
||||
// A nil value in the Auths channel will trigger this.
|
||||
if c.auths != nil {
|
||||
c.auths <- nil
|
||||
}
|
||||
c.cm.ClearToken(c.userID)
|
||||
return
|
||||
if _, err := c.AuthRefresh(refreshToken); err != nil {
|
||||
c.sendAuth(nil)
|
||||
return err
|
||||
}
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
return err
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
|
||||
|
||||
@ -1,30 +1,50 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/getsentry/raven-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
// TODO: Lockers.
|
||||
clients map[string]*Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
clients map[string]*Client
|
||||
tokens map[string]string
|
||||
config *ClientConfig
|
||||
tokens map[string]string
|
||||
tokensLocker sync.Locker
|
||||
|
||||
config *ClientConfig
|
||||
|
||||
bridgeAuths chan ClientAuth
|
||||
clientAuths chan ClientAuth
|
||||
}
|
||||
|
||||
type ClientAuth struct {
|
||||
UserID string
|
||||
Auth *Auth
|
||||
}
|
||||
|
||||
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
||||
func NewClientManager(config *ClientConfig) *ClientManager {
|
||||
func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
||||
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
||||
logrus.WithError(err).Error("Could not set up sentry DSN")
|
||||
}
|
||||
|
||||
return &ClientManager{
|
||||
clients: make(map[string]*Client),
|
||||
tokens: make(map[string]string),
|
||||
config: config,
|
||||
cm = &ClientManager{
|
||||
clients: make(map[string]*Client),
|
||||
clientsLocker: &sync.Mutex{},
|
||||
tokens: make(map[string]string),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
config: config,
|
||||
bridgeAuths: make(chan ClientAuth),
|
||||
clientAuths: make(chan ClientAuth),
|
||||
}
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetClient returns a client for the given userID.
|
||||
@ -39,6 +59,28 @@ func (cm *ClientManager) GetClient(userID string) *Client {
|
||||
return cm.clients[userID]
|
||||
}
|
||||
|
||||
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||
func (cm *ClientManager) LogoutClient(userID string) {
|
||||
client, ok := cm.clients[userID]
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(cm.clients, userID)
|
||||
|
||||
go func() {
|
||||
if err := client.logout(); err != nil {
|
||||
// TODO: Try again!
|
||||
logrus.WithError(err).Error("Client logout failed, not trying again")
|
||||
}
|
||||
client.clearSensitiveData()
|
||||
cm.clearToken(userID)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetConfig returns the config used to configure clients.
|
||||
func (cm *ClientManager) GetConfig() *ClientConfig {
|
||||
return cm.config
|
||||
@ -49,21 +91,52 @@ func (cm *ClientManager) GetToken(userID string) string {
|
||||
return cm.tokens[userID]
|
||||
}
|
||||
|
||||
// SetToken sets the token for the given userID.
|
||||
func (cm *ClientManager) SetToken(userID, token string) {
|
||||
// GetBridgeAuthChannel returns a channel on which client auths can be received.
|
||||
func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth {
|
||||
return cm.clientAuths
|
||||
}
|
||||
|
||||
// getClientAuthChannel returns a channel on which clients should send auths.
|
||||
func (cm *ClientManager) getClientAuthChannel() chan ClientAuth {
|
||||
return cm.clientAuths
|
||||
}
|
||||
|
||||
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel.
|
||||
func (cm *ClientManager) forwardClientAuths() {
|
||||
for auth := range cm.clientAuths {
|
||||
cm.handleClientAuth(auth)
|
||||
cm.bridgeAuths <- auth
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ClientManager) setToken(userID, token string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).WithField("token", token).Info("Updating refresh token")
|
||||
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
||||
// SetTokenIfUnset sets the token for the given userID if it does not yet have a token.
|
||||
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
|
||||
if _, ok := cm.tokens[userID]; ok {
|
||||
func (cm *ClientManager) clearToken(userID string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Info("Clearing refresh token")
|
||||
|
||||
delete(cm.tokens, userID)
|
||||
}
|
||||
|
||||
// handleClientAuth
|
||||
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
// TODO: Maybe want to logout the client in case of nil auth.
|
||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||
return
|
||||
}
|
||||
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
||||
// ClearToken clears the token of the given userID.
|
||||
func (cm *ClientManager) ClearToken(userID string) {
|
||||
delete(cm.tokens, userID)
|
||||
if ca.Auth == nil {
|
||||
cm.clearToken(ca.UserID)
|
||||
} else {
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user