feat: simple client manager

This commit is contained in:
James Houlahan
2020-04-01 13:42:25 +02:00
parent fb263e84a9
commit 0a55fac29a
17 changed files with 199 additions and 198 deletions

View File

@ -68,33 +68,6 @@ func (err *ErrUnauthorized) Error() string {
return fmt.Sprintf("unauthorized access: %+v", err.error.Error())
}
type TokenManager struct {
tokensLocker sync.Locker
tokenMap map[string]string
}
func NewTokenManager() *TokenManager {
tm := &TokenManager{
tokensLocker: &sync.Mutex{},
tokenMap: map[string]string{},
}
return tm
}
func (tm *TokenManager) GetToken(userID string) string {
tm.tokensLocker.Lock()
defer tm.tokensLocker.Unlock()
return tm.tokenMap[userID]
}
func (tm *TokenManager) SetToken(userID, token string) {
tm.tokensLocker.Lock()
defer tm.tokensLocker.Unlock()
tm.tokenMap[userID] = token
}
// ClientConfig contains Client configuration.
type ClientConfig struct {
// The client application name and version.
@ -103,7 +76,8 @@ type ClientConfig struct {
// The client ID.
ClientID string
TokenManager *TokenManager
// The sentry DSN.
SentryDSN string
// Transport specifies the mechanism by which individual HTTP requests are made.
// If nil, http.DefaultTransport is used.
@ -127,10 +101,8 @@ type ClientConfig struct {
type Client struct {
auths chan<- *Auth // Channel that sends Auth responses back to the bridge.
log *logrus.Entry
config *ClientConfig
client *http.Client
conrep ConnectionReporter
cm *ClientManager
uid string
accessToken string
@ -138,18 +110,32 @@ type Client struct {
requestLocker sync.Locker
keyLocker sync.Locker
tokenManager *TokenManager
expiresAt time.Time
user *User
addresses AddressList
kr *pmcrypto.KeyRing
expiresAt time.Time
user *User
addresses AddressList
kr *pmcrypto.KeyRing
log *logrus.Entry
}
// NewClient creates a new API client.
func NewClient(cfg *ClientConfig, userID string) *Client {
// newClient creates a new API client.
func newClient(cm *ClientManager, userID string) *Client {
return &Client{
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
client: getHTTPClient(cm.GetConfig()),
cm: cm,
userID: userID,
requestLocker: &sync.Mutex{},
keyLocker: &sync.Mutex{},
}
}
// getHTTPClient returns a http client configured by the given client config.
func getHTTPClient(cfg *ClientConfig) *http.Client {
hc := &http.Client{
Timeout: cfg.Timeout,
}
if cfg.Transport != nil {
cfgTransport, ok := cfg.Transport.(*http.Transport)
if ok {
@ -168,37 +154,7 @@ func NewClient(cfg *ClientConfig, userID string) *Client {
hc.Transport = defaultTransport
}
log := logrus.WithFields(logrus.Fields{
"pkg": "pmapi",
"userID": userID,
})
return &Client{
log: log,
config: cfg,
client: hc,
tokenManager: cfg.TokenManager,
userID: userID,
requestLocker: &sync.Mutex{},
keyLocker: &sync.Mutex{},
}
}
// SetConnectionReporter sets the connection reporter used by the client to report when
// internet connection is lost.
func (c *Client) SetConnectionReporter(conrep ConnectionReporter) {
c.conrep = conrep
}
// reportLostConnection reports that the internet connection has been lost using the connection reporter.
// If the connection reporter has not been set, this does nothing.
func (c *Client) reportLostConnection() {
if c.conrep != nil {
err := c.conrep.NotifyConnectionLost()
if err != nil {
logrus.WithError(err).Error("Failed to notify of lost connection")
}
}
return hc
}
// Do makes an API request. It does not check for HTTP status code errors.
@ -224,7 +180,7 @@ func (c *Client) Do(req *http.Request, retryUnauthorized bool) (res *http.Respon
func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
isAuthReq := strings.Contains(req.URL.Path, "/auth")
req.Header.Set("x-pm-appversion", c.config.AppVersion)
req.Header.Set("x-pm-appversion", c.cm.GetConfig().AppVersion)
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
if c.uid != "" {
@ -252,7 +208,6 @@ func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthori
if res == nil {
c.log.WithError(err).Error("Cannot get response")
err = ErrAPINotReachable
c.reportLostConnection()
}
return
}
@ -328,7 +283,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
var cancelRequest context.CancelFunc
if c.config.MinSpeed > 0 {
if c.cm.GetConfig().MinSpeed > 0 {
var ctx context.Context
ctx, cancelRequest = context.WithCancel(req.Context())
defer func() {
@ -344,7 +299,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
defer res.Body.Close() //nolint[errcheck]
var resBody []byte
if c.config.MinSpeed == 0 {
if c.cm.GetConfig().MinSpeed == 0 {
resBody, err = ioutil.ReadAll(res.Body)
} else {
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
@ -422,7 +377,7 @@ func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data in
}
func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
firstReadTimeout := c.config.FirstReadTimeout
firstReadTimeout := c.cm.GetConfig().FirstReadTimeout
if firstReadTimeout == 0 {
firstReadTimeout = 5 * time.Minute
}
@ -431,7 +386,7 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
})
var buffer bytes.Buffer
for {
_, err := io.CopyN(&buffer, data, c.config.MinSpeed)
_, err := io.CopyN(&buffer, data, c.cm.GetConfig().MinSpeed)
timer.Stop()
timer.Reset(1 * time.Second)
if err == io.EOF {
@ -445,15 +400,13 @@ func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFun
func (c *Client) refreshAccessToken() (err error) {
c.log.Debug("Refreshing token")
refreshToken := c.tokenManager.GetToken(c.userID)
refreshToken := c.cm.GetToken(c.userID)
c.log.WithField("token", refreshToken).Info("Current refresh token")
if refreshToken == "" {
if c.auths != nil {
c.auths <- nil
}
if c.tokenManager != nil {
c.tokenManager.SetToken(c.userID, "")
}
c.cm.ClearToken(c.userID)
return ErrInvalidToken
}
@ -465,9 +418,7 @@ func (c *Client) refreshAccessToken() (err error) {
if c.auths != nil {
c.auths <- nil
}
if c.tokenManager != nil {
c.tokenManager.SetToken(c.userID, "")
}
c.cm.ClearToken(c.userID)
return
}
c.uid = auth.UID()