forked from Silverfish/proton-bridge
feat: simple client manager
This commit is contained in:
@ -78,8 +78,6 @@ func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirs
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
},
|
||||
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
|
||||
TokenManager: pmapi.NewTokenManager(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,9 +64,11 @@ func GetLogEntry(packageName string) *logrus.Entry {
|
||||
// HandlePanic reports the crash to sentry or local file when sentry fails.
|
||||
func HandlePanic(cfg *Config, output string) {
|
||||
if !cfg.IsDevMode() {
|
||||
c := pmapi.NewClient(cfg.GetAPIConfig(), "no-user-id")
|
||||
err := c.ReportSentryCrash(fmt.Errorf(output))
|
||||
if err != nil {
|
||||
// TODO: Is it okay to just create a throwaway client like this?
|
||||
c := pmapi.NewClientManager(cfg.GetAPIConfig()).GetClient("no-user-id")
|
||||
defer func() { _ = c.Logout() }()
|
||||
|
||||
if err := c.ReportSentryCrash(fmt.Errorf(output)); err != nil {
|
||||
log.Error("Sentry crash report failed: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -307,11 +307,7 @@ func (c *Client) Auth(username, password string, info *AuthInfo) (auth *Auth, er
|
||||
if c.auths != nil {
|
||||
c.auths <- auth
|
||||
}
|
||||
|
||||
if c.tokenManager != nil {
|
||||
c.tokenManager.SetToken(c.userID, c.uid+":"+auth.RefreshToken)
|
||||
c.log.Info("Set token from auth " + c.uid + ":" + auth.RefreshToken)
|
||||
}
|
||||
c.cm.SetToken(c.userID, c.uid+":"+auth.RefreshToken)
|
||||
|
||||
// Auth has to be fully unlocked to get key salt. During `Auth` it can happen
|
||||
// only to accounts without 2FA. For 2FA accounts, it's done in `Auth2FA`.
|
||||
@ -407,14 +403,7 @@ func (c *Client) Unlock(password string) (kr *pmcrypto.KeyRing, err error) {
|
||||
func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) {
|
||||
// If we don't yet have a saved access token, save this one in case the refresh fails!
|
||||
// That way we can try again later (see handleUnauthorizedStatus).
|
||||
if c.tokenManager != nil {
|
||||
currentAccessToken := c.tokenManager.GetToken(c.userID)
|
||||
if currentAccessToken == "" {
|
||||
c.log.WithField("token", uidAndRefreshToken).
|
||||
Info("Currently have no access token, setting given one")
|
||||
c.tokenManager.SetToken(c.userID, uidAndRefreshToken)
|
||||
}
|
||||
}
|
||||
c.cm.SetTokenIfUnset(c.userID, uidAndRefreshToken)
|
||||
|
||||
split := strings.Split(uidAndRefreshToken, ":")
|
||||
if len(split) != 2 {
|
||||
@ -456,11 +445,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
||||
|
||||
c.uid = auth.UID()
|
||||
c.accessToken = auth.accessToken
|
||||
|
||||
if c.tokenManager != nil {
|
||||
c.tokenManager.SetToken(c.userID, c.uid+":"+res.RefreshToken)
|
||||
c.log.Info("Set token from auth refresh " + c.uid + ":" + res.RefreshToken)
|
||||
}
|
||||
c.cm.SetToken(c.userID, c.uid+":"+res.RefreshToken)
|
||||
|
||||
c.expiresAt = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
|
||||
return auth, err
|
||||
@ -497,9 +482,7 @@ func (c *Client) Logout() (err error) {
|
||||
c.kr = nil
|
||||
// c.addresses = nil
|
||||
c.user = nil
|
||||
if c.tokenManager != nil {
|
||||
c.tokenManager.SetToken(c.userID, "")
|
||||
}
|
||||
c.cm.ClearToken(c.userID)
|
||||
// }()
|
||||
|
||||
return err
|
||||
|
||||
@ -353,8 +353,7 @@ func TestClient_DoUnauthorized(t *testing.T) {
|
||||
c.uid = testUID
|
||||
c.accessToken = testAccessTokenOld
|
||||
c.expiresAt = aLongTimeAgo
|
||||
c.tokenManager = NewTokenManager()
|
||||
c.tokenManager.tokenMap[c.userID] = testUID + ":" + testRefreshToken
|
||||
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
|
||||
|
||||
req, err := NewRequest("GET", "/", nil)
|
||||
Ok(t, err)
|
||||
|
||||
@ -132,8 +132,8 @@ func writeMultipartReport(w *multipart.Writer, rep *ReportReq) error { // nolint
|
||||
|
||||
// Report sends request as json or multipart (if has attachment).
|
||||
func (c *Client) Report(rep ReportReq) (err error) {
|
||||
rep.Client = c.config.ClientID
|
||||
rep.ClientVersion = c.config.AppVersion
|
||||
rep.Client = c.cm.GetConfig().ClientID
|
||||
rep.ClientVersion = c.cm.GetConfig().AppVersion
|
||||
rep.ClientType = EmailClientType
|
||||
|
||||
var req *http.Request
|
||||
@ -196,8 +196,8 @@ func (c *Client) ReportBugWithEmailClient(os, osVersion, title, description, use
|
||||
// ReportCrash is old. Use sentry instead.
|
||||
func (c *Client) ReportCrash(stacktrace string) (err error) {
|
||||
crashReq := ReportReq{
|
||||
Client: c.config.ClientID,
|
||||
ClientVersion: c.config.AppVersion,
|
||||
Client: c.cm.GetConfig().ClientID,
|
||||
ClientVersion: c.cm.GetConfig().AppVersion,
|
||||
ClientType: EmailClientType,
|
||||
OS: runtime.GOOS,
|
||||
Debug: stacktrace,
|
||||
|
||||
@ -68,33 +68,6 @@ func (err *ErrUnauthorized) Error() string {
|
||||
return fmt.Sprintf("unauthorized access: %+v", err.error.Error())
|
||||
}
|
||||
|
||||
type TokenManager struct {
|
||||
tokensLocker sync.Locker
|
||||
tokenMap map[string]string
|
||||
}
|
||||
|
||||
func NewTokenManager() *TokenManager {
|
||||
tm := &TokenManager{
|
||||
tokensLocker: &sync.Mutex{},
|
||||
tokenMap: map[string]string{},
|
||||
}
|
||||
return tm
|
||||
}
|
||||
|
||||
func (tm *TokenManager) GetToken(userID string) string {
|
||||
tm.tokensLocker.Lock()
|
||||
defer tm.tokensLocker.Unlock()
|
||||
|
||||
return tm.tokenMap[userID]
|
||||
}
|
||||
|
||||
func (tm *TokenManager) SetToken(userID, token string) {
|
||||
tm.tokensLocker.Lock()
|
||||
defer tm.tokensLocker.Unlock()
|
||||
|
||||
tm.tokenMap[userID] = token
|
||||
}
|
||||
|
||||
// ClientConfig contains Client configuration.
|
||||
type ClientConfig struct {
|
||||
// The client application name and version.
|
||||
@ -103,7 +76,8 @@ type ClientConfig struct {
|
||||
// The client ID.
|
||||
ClientID string
|
||||
|
||||
TokenManager *TokenManager
|
||||
// The sentry DSN.
|
||||
SentryDSN string
|
||||
|
||||
// Transport specifies the mechanism by which individual HTTP requests are made.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
@ -127,10 +101,8 @@ type ClientConfig struct {
|
||||
type Client struct {
|
||||
auths chan<- *Auth // Channel that sends Auth responses back to the bridge.
|
||||
|
||||
log *logrus.Entry
|
||||
config *ClientConfig
|
||||
client *http.Client
|
||||
conrep ConnectionReporter
|
||||
cm *ClientManager
|
||||
|
||||
uid string
|
||||
accessToken string
|
||||
@ -138,18 +110,32 @@ type Client struct {
|
||||
requestLocker sync.Locker
|
||||
keyLocker sync.Locker
|
||||
|
||||
tokenManager *TokenManager
|
||||
expiresAt time.Time
|
||||
user *User
|
||||
addresses AddressList
|
||||
kr *pmcrypto.KeyRing
|
||||
expiresAt time.Time
|
||||
user *User
|
||||
addresses AddressList
|
||||
kr *pmcrypto.KeyRing
|
||||
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
// NewClient creates a new API client.
|
||||
func NewClient(cfg *ClientConfig, userID string) *Client {
|
||||
// newClient creates a new API client.
|
||||
func newClient(cm *ClientManager, userID string) *Client {
|
||||
return &Client{
|
||||
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
||||
client: getHTTPClient(cm.GetConfig()),
|
||||
cm: cm,
|
||||
userID: userID,
|
||||
requestLocker: &sync.Mutex{},
|
||||
keyLocker: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// getHTTPClient returns a http client configured by the given client config.
|
||||
func getHTTPClient(cfg *ClientConfig) *http.Client {
|
||||
hc := &http.Client{
|
||||
Timeout: cfg.Timeout,
|
||||
}
|
||||
|
||||
if cfg.Transport != nil {
|
||||
cfgTransport, ok := cfg.Transport.(*http.Transport)
|
||||
if ok {
|
||||
@ -168,37 +154,7 @@ func NewClient(cfg *ClientConfig, userID string) *Client {
|
||||
hc.Transport = defaultTransport
|
||||
}
|
||||
|
||||
log := logrus.WithFields(logrus.Fields{
|
||||
"pkg": "pmapi",
|
||||
"userID": userID,
|
||||
})
|
||||
|
||||
return &Client{
|
||||
log: log,
|
||||
config: cfg,
|
||||
client: hc,
|
||||
tokenManager: cfg.TokenManager,
|
||||
userID: userID,
|
||||
requestLocker: &sync.Mutex{},
|
||||
keyLocker: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// SetConnectionReporter sets the connection reporter used by the client to report when
|
||||
// internet connection is lost.
|
||||
func (c *Client) SetConnectionReporter(conrep ConnectionReporter) {
|
||||
c.conrep = conrep
|
||||
}
|
||||
|
||||
// reportLostConnection reports that the internet connection has been lost using the connection reporter.
|
||||
// If the connection reporter has not been set, this does nothing.
|
||||
func (c *Client) reportLostConnection() {
|
||||
if c.conrep != nil {
|
||||
err := c.conrep.NotifyConnectionLost()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("Failed to notify of lost connection")
|
||||
}
|
||||
}
|
||||
return hc
|
||||
}
|
||||
|
||||
// Do makes an API request. It does not check for HTTP status code errors.
|
||||
@ -224,7 +180,7 @@ func (c *Client) Do(req *http.Request, retryUnauthorized bool) (res *http.Respon
|
||||
func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
|
||||
isAuthReq := strings.Contains(req.URL.Path, "/auth")
|
||||
|
||||
req.Header.Set("x-pm-appversion", c.config.AppVersion)
|
||||
req.Header.Set("x-pm-appversion", c.cm.GetConfig().AppVersion)
|
||||
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
|
||||
|
||||
if c.uid != "" {
|
||||
@ -252,7 +208,6 @@ func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthori
|
||||
if res == nil {
|
||||
c.log.WithError(err).Error("Cannot get response")
|
||||
err = ErrAPINotReachable
|
||||
c.reportLostConnection()
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -328,7 +283,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
|
||||
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
|
||||
|
||||
var cancelRequest context.CancelFunc
|
||||
if c.config.MinSpeed > 0 {
|
||||
if c.cm.GetConfig().MinSpeed > 0 {
|
||||
var ctx context.Context
|
||||
ctx, cancelRequest = context.WithCancel(req.Context())
|
||||
defer func() {
|
||||
@ -344,7 +299,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
|
||||
defer res.Body.Close() //nolint[errcheck]
|
||||
|
||||
var resBody []byte
|
||||
if c.config.MinSpeed == 0 {
|
||||
if c.cm.GetConfig().MinSpeed == 0 {
|
||||
resBody, err = ioutil.ReadAll(res.Body)
|
||||
} else {
|
||||
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
|
||||
@ -422,7 +377,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
|
||||
}
|
||||
|
||||
func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
|
||||
firstReadTimeout := c.config.FirstReadTimeout
|
||||
firstReadTimeout := c.cm.GetConfig().FirstReadTimeout
|
||||
if firstReadTimeout == 0 {
|
||||
firstReadTimeout = 5 * time.Minute
|
||||
}
|
||||
@ -431,7 +386,7 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
|
||||
})
|
||||
var buffer bytes.Buffer
|
||||
for {
|
||||
_, err := io.CopyN(&buffer, data, c.config.MinSpeed)
|
||||
_, err := io.CopyN(&buffer, data, c.cm.GetConfig().MinSpeed)
|
||||
timer.Stop()
|
||||
timer.Reset(1 * time.Second)
|
||||
if err == io.EOF {
|
||||
@ -445,15 +400,13 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
|
||||
|
||||
func (c *Client) refreshAccessToken() (err error) {
|
||||
c.log.Debug("Refreshing token")
|
||||
refreshToken := c.tokenManager.GetToken(c.userID)
|
||||
refreshToken := c.cm.GetToken(c.userID)
|
||||
c.log.WithField("token", refreshToken).Info("Current refresh token")
|
||||
if refreshToken == "" {
|
||||
if c.auths != nil {
|
||||
c.auths <- nil
|
||||
}
|
||||
if c.tokenManager != nil {
|
||||
c.tokenManager.SetToken(c.userID, "")
|
||||
}
|
||||
c.cm.ClearToken(c.userID)
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
@ -465,9 +418,7 @@ func (c *Client) refreshAccessToken() (err error) {
|
||||
if c.auths != nil {
|
||||
c.auths <- nil
|
||||
}
|
||||
if c.tokenManager != nil {
|
||||
c.tokenManager.SetToken(c.userID, "")
|
||||
}
|
||||
c.cm.ClearToken(c.userID)
|
||||
return
|
||||
}
|
||||
c.uid = auth.UID()
|
||||
|
||||
@ -37,8 +37,7 @@ var testClientConfig = &ClientConfig{
|
||||
}
|
||||
|
||||
func newTestClient() *Client {
|
||||
c := NewClient(testClientConfig, "tester")
|
||||
c.tokenManager = NewTokenManager()
|
||||
c := newClient(NewClientManager(testClientConfig), "tester")
|
||||
return c
|
||||
}
|
||||
|
||||
|
||||
69
pkg/pmapi/clientmanager.go
Normal file
69
pkg/pmapi/clientmanager.go
Normal file
@ -0,0 +1,69 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"github.com/getsentry/raven-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
// TODO: Lockers.
|
||||
|
||||
clients map[string]*Client
|
||||
tokens map[string]string
|
||||
config *ClientConfig
|
||||
}
|
||||
|
||||
// NewClientManager creates a new ClientMan which manages clients configured with the given client config.
|
||||
func NewClientManager(config *ClientConfig) *ClientManager {
|
||||
if err := raven.SetDSN(config.SentryDSN); err != nil {
|
||||
logrus.WithError(err).Error("Could not set up sentry DSN")
|
||||
}
|
||||
|
||||
return &ClientManager{
|
||||
clients: make(map[string]*Client),
|
||||
tokens: make(map[string]string),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GetClient returns a client for the given userID.
|
||||
// If the client does not exist already, it is created.
|
||||
func (cm *ClientManager) GetClient(userID string) *Client {
|
||||
if client, ok := cm.clients[userID]; ok {
|
||||
return client
|
||||
}
|
||||
|
||||
cm.clients[userID] = newClient(cm, userID)
|
||||
|
||||
return cm.clients[userID]
|
||||
}
|
||||
|
||||
// GetConfig returns the config used to configure clients.
|
||||
func (cm *ClientManager) GetConfig() *ClientConfig {
|
||||
return cm.config
|
||||
}
|
||||
|
||||
// GetToken returns the token for the given userID.
|
||||
func (cm *ClientManager) GetToken(userID string) string {
|
||||
return cm.tokens[userID]
|
||||
}
|
||||
|
||||
// SetToken sets the token for the given userID.
|
||||
func (cm *ClientManager) SetToken(userID, token string) {
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
||||
// SetTokenIfUnset sets the token for the given userID if it does not yet have a token.
|
||||
func (cm *ClientManager) SetTokenIfUnset(userID, token string) {
|
||||
if _, ok := cm.tokens[userID]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
||||
// ClearToken clears the token of the given userID.
|
||||
func (cm *ClientManager) ClearToken(userID string) {
|
||||
delete(cm.tokens, userID)
|
||||
}
|
||||
@ -42,7 +42,7 @@ func TestTLSPinValid(t *testing.T) {
|
||||
called, _ := newTestDialerWithPinning()
|
||||
|
||||
RootURL = liveAPI
|
||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
@ -56,7 +56,7 @@ func TestTLSPinBackup(t *testing.T) {
|
||||
p.report.KnownPins[0] = ""
|
||||
|
||||
RootURL = liveAPI
|
||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
@ -71,13 +71,13 @@ func _TestTLSPinNoMatch(t *testing.T) { // nolint[unused]
|
||||
}
|
||||
|
||||
RootURL = liveAPI
|
||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
|
||||
// check that it will be called only once per session
|
||||
client = NewClient(testLiveConfig, "pmapi"+t.Name())
|
||||
client = newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||
_, err = client.AuthInfo("this.address.is.disabled")
|
||||
Ok(t, err)
|
||||
|
||||
@ -92,7 +92,7 @@ func _TestTLSPinInvalid(t *testing.T) { // nolint[unused]
|
||||
|
||||
called, _ := newTestDialerWithPinning()
|
||||
|
||||
client := NewClient(testLiveConfig, "pmapi"+t.Name())
|
||||
client := newClient(NewClientManager(testLiveConfig), "pmapi"+t.Name())
|
||||
|
||||
RootURL = liveAPI
|
||||
_, err := client.AuthInfo("this.address.is.disabled")
|
||||
|
||||
@ -152,8 +152,8 @@ func (c *Client) ReportSentryCrash(reportErr error) (err error) {
|
||||
}
|
||||
tags := map[string]string{
|
||||
"OS": runtime.GOOS,
|
||||
"Client": c.config.ClientID,
|
||||
"Version": c.config.AppVersion,
|
||||
"Client": c.cm.GetConfig().ClientID,
|
||||
"Version": c.cm.GetConfig().AppVersion,
|
||||
"UserAgent": CurrentUserAgent,
|
||||
"UserID": c.userID,
|
||||
}
|
||||
|
||||
@ -25,7 +25,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSentryCrashReport(t *testing.T) {
|
||||
c := NewClient(testClientConfig, "bridgetest")
|
||||
c := newClient(NewClientManager(testClientConfig), "bridgetest")
|
||||
if err := c.ReportSentryCrash(errors.New("Testing crash report - api proxy; goroutines with threads, find origin")); err != nil {
|
||||
t.Fatal("Expected no error while report, but have", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user