mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-17 15:46:44 +00:00
GODT-35: New pmapi client and manager using resty
This commit is contained in:
@ -18,97 +18,23 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/jaytaylor/html2text"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Version of the API.
|
||||
const Version = 3
|
||||
|
||||
// API return codes.
|
||||
const (
|
||||
ForceUpgradeBadAppVersion = 5003
|
||||
APIOffline = 7001
|
||||
ImportMessageTooLong = 36022
|
||||
BansRequests = 85131
|
||||
)
|
||||
|
||||
// The output errors.
|
||||
var (
|
||||
ErrInvalidToken = errors.New("refresh token invalid")
|
||||
ErrAPINotReachable = errors.New("cannot reach the server")
|
||||
ErrUpgradeApplication = errors.New("application upgrade required")
|
||||
ErrConnectionSlow = errors.New("request canceled because connection speed was too slow")
|
||||
)
|
||||
|
||||
type ErrUnprocessableEntity struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (err *ErrUnprocessableEntity) Error() string {
|
||||
return err.error.Error()
|
||||
}
|
||||
|
||||
type ErrUnauthorized struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (err *ErrUnauthorized) Error() string {
|
||||
return fmt.Sprintf("unauthorized access: %+v", err.error.Error())
|
||||
}
|
||||
|
||||
// ClientConfig contains Client configuration.
|
||||
type ClientConfig struct {
|
||||
// The client application name and version.
|
||||
AppVersion string
|
||||
|
||||
// The client ID.
|
||||
ClientID string
|
||||
|
||||
// Timeout is the timeout of the full request. It is passed to http.Client.
|
||||
// If it is left unset, it means no timeout is applied.
|
||||
Timeout time.Duration
|
||||
|
||||
// FirstReadTimeout specifies the timeout from getting response to the first read of body response.
|
||||
// This timeout is applied only when MinBytesPerSecond is used.
|
||||
// Default is 5 minutes.
|
||||
FirstReadTimeout time.Duration
|
||||
|
||||
// MinBytesPerSecond specifies minimum Bytes per second or the request will be canceled.
|
||||
// Zero means no limitation.
|
||||
MinBytesPerSecond int64
|
||||
|
||||
ConnectionOnHandler func()
|
||||
ConnectionOffHandler func()
|
||||
UpgradeApplicationHandler func()
|
||||
}
|
||||
|
||||
// client is a client of the protonmail API. It implements the Client interface.
|
||||
type client struct {
|
||||
cm *ClientManager
|
||||
hc *http.Client
|
||||
req requester
|
||||
|
||||
uid string
|
||||
accessToken string
|
||||
userID string
|
||||
requestLocker sync.Locker
|
||||
refreshLocker sync.Locker
|
||||
uid, acc, ref string
|
||||
authHandlers []AuthHandler
|
||||
authLocker sync.RWMutex
|
||||
|
||||
user *User
|
||||
addresses AddressList
|
||||
@ -116,404 +42,79 @@ type client struct {
|
||||
addrKeyRing map[string]*crypto.KeyRing
|
||||
keyRingLock sync.Locker
|
||||
|
||||
log *logrus.Entry
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
// newClient creates a new API client.
|
||||
func newClient(cm *ClientManager, userID string) *client {
|
||||
func newClient(req requester, uid string) *client {
|
||||
return &client{
|
||||
cm: cm,
|
||||
hc: getHTTPClient(cm.config, cm.roundTripper, cm.cookieJar),
|
||||
userID: userID,
|
||||
requestLocker: &sync.Mutex{},
|
||||
refreshLocker: &sync.Mutex{},
|
||||
keyRingLock: &sync.Mutex{},
|
||||
addrKeyRing: make(map[string]*crypto.KeyRing),
|
||||
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
||||
req: req,
|
||||
uid: uid,
|
||||
addrKeyRing: make(map[string]*crypto.KeyRing),
|
||||
keyRingLock: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// getHTTPClient returns a http client configured by the given client config and using the given transport.
|
||||
func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper, jar http.CookieJar) (hc *http.Client) {
|
||||
return &http.Client{
|
||||
Transport: rt,
|
||||
Jar: jar,
|
||||
Timeout: cfg.Timeout,
|
||||
}
|
||||
func (c *client) withAuth(acc, ref string, exp time.Time) *client {
|
||||
c.acc = acc
|
||||
c.ref = ref
|
||||
c.exp = exp
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *client) IsUnlocked() bool {
|
||||
return c.userKeyRing != nil
|
||||
}
|
||||
|
||||
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
|
||||
// If the keyrings are already present, they are not recreated.
|
||||
func (c *client) Unlock(passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
return c.unlock(passphrase)
|
||||
}
|
||||
|
||||
// unlock unlocks the user's keys but without locking the keyring lock first.
|
||||
// Should only be used internally by methods that first lock the lock.
|
||||
func (c *client) unlock(passphrase []byte) (err error) {
|
||||
if _, err = c.CurrentUser(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.userKeyRing == nil {
|
||||
if err = c.unlockUser(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
}
|
||||
|
||||
for _, address := range c.addresses {
|
||||
if c.addrKeyRing[address.ID] == nil {
|
||||
if err = c.unlockAddress(passphrase, address); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock address")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) ReloadKeys(passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
c.clearKeys()
|
||||
|
||||
return c.unlock(passphrase)
|
||||
}
|
||||
|
||||
func (c *client) clearKeys() {
|
||||
if c.userKeyRing != nil {
|
||||
c.userKeyRing.ClearPrivateParams()
|
||||
c.userKeyRing = nil
|
||||
}
|
||||
|
||||
for id, kr := range c.addrKeyRing {
|
||||
if kr != nil {
|
||||
kr.ClearPrivateParams()
|
||||
}
|
||||
delete(c.addrKeyRing, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) CloseConnections() {
|
||||
c.hc.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// Do makes an API request. It does not check for HTTP status code errors.
|
||||
func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) {
|
||||
// Copy the request body in case we need to retry it.
|
||||
var bodyBuffer []byte
|
||||
if req.Body != nil {
|
||||
defer req.Body.Close() //nolint[errcheck]
|
||||
bodyBuffer, err = ioutil.ReadAll(req.Body)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := bytes.NewReader(bodyBuffer)
|
||||
req.Body = ioutil.NopCloser(r)
|
||||
}
|
||||
|
||||
return c.doBuffered(req, bodyBuffer, retryUnauthorized)
|
||||
}
|
||||
|
||||
// If needed it retries using req and buffered body.
|
||||
func (c *client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
|
||||
isAuthReq := strings.Contains(req.URL.Path, "/auth")
|
||||
|
||||
req.Header.Set("User-Agent", c.cm.userAgent.String())
|
||||
req.Header.Set("x-pm-appversion", c.cm.config.AppVersion)
|
||||
func (c *client) r(ctx context.Context) (*resty.Request, error) {
|
||||
r := c.req.r(ctx)
|
||||
|
||||
if c.uid != "" {
|
||||
req.Header.Set("x-pm-uid", c.uid)
|
||||
r.SetHeader("x-pm-uid", c.uid)
|
||||
}
|
||||
|
||||
if c.accessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.accessToken)
|
||||
}
|
||||
|
||||
c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI())
|
||||
if logrus.GetLevel() == logrus.TraceLevel {
|
||||
head := ""
|
||||
for i, v := range req.Header {
|
||||
head += i + ": "
|
||||
head += strings.Join(v, "")
|
||||
head += "\n"
|
||||
}
|
||||
c.log.Tracef("REQHEAD \n%s", head)
|
||||
c.log.Tracef("REQBODY '%s'", printBytes(bodyBuffer))
|
||||
}
|
||||
|
||||
hasBody := len(bodyBuffer) > 0
|
||||
if res, err = c.hc.Do(req); err != nil {
|
||||
if res == nil {
|
||||
c.log.WithError(err).Error("Cannot get response")
|
||||
err = ErrAPINotReachable
|
||||
c.cm.noConnection()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cookies are returned only after request was sent.
|
||||
c.log.Tracef("REQCOOKIES '%v'", req.Cookies())
|
||||
|
||||
resDate := res.Header.Get("Date")
|
||||
if resDate != "" {
|
||||
if serverTime, err := http.ParseTime(resDate); err == nil {
|
||||
crypto.UpdateTime(serverTime.Unix())
|
||||
}
|
||||
}
|
||||
|
||||
if res.StatusCode == http.StatusUnauthorized {
|
||||
if hasBody {
|
||||
r := bytes.NewReader(bodyBuffer)
|
||||
req.Body = ioutil.NopCloser(r)
|
||||
}
|
||||
|
||||
if !isAuthReq {
|
||||
_, _ = io.Copy(ioutil.Discard, res.Body)
|
||||
_ = res.Body.Close()
|
||||
return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
// Retry induced by HTTP status code>
|
||||
retryAfter := 10
|
||||
doRetry := res.StatusCode == http.StatusTooManyRequests
|
||||
if doRetry {
|
||||
if headerAfter, err := strconv.Atoi(res.Header.Get("Retry-After")); err == nil && headerAfter > 0 {
|
||||
retryAfter = headerAfter
|
||||
}
|
||||
// To avoid spikes when all clients retry at the same time, we add some random wait.
|
||||
retryAfter += rand.Intn(10) //nolint[gosec] It is OK to use weak random number generator here
|
||||
|
||||
if hasBody {
|
||||
r := bytes.NewReader(bodyBuffer)
|
||||
req.Body = ioutil.NopCloser(r)
|
||||
}
|
||||
|
||||
c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, retryAfter, res.StatusCode)
|
||||
time.Sleep(time.Duration(retryAfter) * time.Second)
|
||||
_, _ = io.Copy(ioutil.Discard, res.Body)
|
||||
_ = res.Body.Close()
|
||||
return c.doBuffered(req, bodyBuffer, false)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// DoJSON performs the request and unmarshals the response as JSON into data.
|
||||
// If the API returns a non-2xx HTTP status code, the error returned will contain status
|
||||
// and response as plaintext. API errors must be checked by the caller.
|
||||
// It is performed buffered, in case we need to retry.
|
||||
func (c *client) DoJSON(req *http.Request, data interface{}) error {
|
||||
// Copy the request body in case we need to retry it
|
||||
var reqBodyBuffer []byte
|
||||
|
||||
if req.Body != nil {
|
||||
defer req.Body.Close() //nolint[errcheck]
|
||||
var err error
|
||||
if reqBodyBuffer, err = ioutil.ReadAll(req.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
|
||||
}
|
||||
|
||||
return c.doJSONBuffered(req, reqBodyBuffer, data)
|
||||
}
|
||||
|
||||
// doJSONBuffered performs a buffered json request (see DoJSON for more information).
|
||||
func (c *client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen]
|
||||
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
|
||||
|
||||
var cancelRequest context.CancelFunc
|
||||
if c.cm.config.MinBytesPerSecond > 0 {
|
||||
var ctx context.Context
|
||||
ctx, cancelRequest = context.WithCancel(req.Context())
|
||||
defer func() {
|
||||
cancelRequest()
|
||||
}()
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
res, err := c.doBuffered(req, reqBodyBuffer, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close() //nolint[errcheck]
|
||||
|
||||
var resBody []byte
|
||||
if c.cm.config.MinBytesPerSecond == 0 {
|
||||
resBody, err = ioutil.ReadAll(res.Body)
|
||||
} else {
|
||||
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
|
||||
if err == context.Canceled {
|
||||
err = ErrConnectionSlow
|
||||
}
|
||||
}
|
||||
|
||||
// The server response may contain data which we want to have in memory
|
||||
// for as little time as possible (such as keys). Go is garbage collected,
|
||||
// so we are not in charge of when the memory will actually be cleared.
|
||||
// We can at least try to rewrite the original data to mitigate this problem.
|
||||
defer func() {
|
||||
for i := 0; i < len(resBody); i++ {
|
||||
resBody[i] = byte(65)
|
||||
}
|
||||
}()
|
||||
|
||||
if logrus.GetLevel() == logrus.TraceLevel {
|
||||
head := ""
|
||||
for i, v := range res.Header {
|
||||
head += i + ": "
|
||||
head += strings.Join(v, "")
|
||||
head += "\n"
|
||||
}
|
||||
c.log.Tracef("RESHEAD \n%s", head)
|
||||
c.log.Tracef("RESBODY '%s'", resBody)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Retry induced by API code.
|
||||
errCode := &Res{}
|
||||
if err := json.Unmarshal(resBody, errCode); err == nil {
|
||||
if errCode.Code == BansRequests {
|
||||
retryAfter := 3
|
||||
c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, errCode.Code)
|
||||
time.Sleep(time.Duration(retryAfter) * time.Second)
|
||||
if len(reqBodyBuffer) > 0 {
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
|
||||
}
|
||||
return c.doJSONBuffered(req, reqBodyBuffer, data)
|
||||
}
|
||||
if errCode.Err() == ErrAPINotReachable {
|
||||
c.cm.noConnection()
|
||||
}
|
||||
if errCode.Err() == ErrUpgradeApplication {
|
||||
c.cm.config.UpgradeApplicationHandler()
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resBody, data); err != nil {
|
||||
// Check to see if this is due to a non 2xx HTTP status code.
|
||||
if res.StatusCode != http.StatusOK {
|
||||
r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n")))
|
||||
plaintext, err := html2text.FromReader(r)
|
||||
if err == nil {
|
||||
return fmt.Errorf("Error: \n\n" + res.Status + "\n\n" + plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
if errJS, ok := err.(*json.SyntaxError); ok {
|
||||
return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unmarshal fail: %v ", err)
|
||||
}
|
||||
|
||||
// Set StatusCode in case data struct supports that field.
|
||||
// It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code.
|
||||
dataValue := reflect.ValueOf(data).Elem()
|
||||
statusCodeField := dataValue.FieldByName("StatusCode")
|
||||
if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int {
|
||||
statusCodeField.SetInt(int64(res.StatusCode))
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
|
||||
firstReadTimeout := c.cm.config.FirstReadTimeout
|
||||
if firstReadTimeout == 0 {
|
||||
firstReadTimeout = 5 * time.Minute
|
||||
}
|
||||
timer := time.AfterFunc(firstReadTimeout, func() {
|
||||
cancelRequest()
|
||||
})
|
||||
|
||||
// speedCheckSeconds controls how often we check the transfer speed.
|
||||
// Note that connection can be unstable, on average very fast, but can be
|
||||
// idle for few seconds; or that API can take its time before sending
|
||||
// another data, e.g., API can send some data and take some time before
|
||||
// processing and sending the rest of the response.
|
||||
const speedCheckSeconds = 30
|
||||
|
||||
var buffer bytes.Buffer
|
||||
for {
|
||||
_, err := io.CopyN(&buffer, data, c.cm.config.MinBytesPerSecond*speedCheckSeconds)
|
||||
timer.Stop()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
if time.Now().After(c.exp) {
|
||||
if err := c.authRefresh(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
timer.Reset(speedCheckSeconds * time.Second)
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(&buffer)
|
||||
c.authLocker.RLock()
|
||||
defer c.authLocker.RUnlock()
|
||||
|
||||
if c.acc != "" {
|
||||
r.SetAuthToken(c.acc)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *client) refreshAccessToken() (err error) {
|
||||
c.log.Debug("Refreshing token")
|
||||
|
||||
refreshToken := c.cm.GetToken(c.userID)
|
||||
|
||||
if refreshToken == "" {
|
||||
c.sendAuth(nil)
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
if _, err := c.AuthRefresh(refreshToken); err != nil {
|
||||
if err != ErrAPINotReachable {
|
||||
c.sendAuth(nil)
|
||||
}
|
||||
return errors.Wrap(err, "failed to refresh auth")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
|
||||
c.log.Info("Handling unauthorized status")
|
||||
|
||||
// If this is not a retry, then it is the first time handling status unauthorized,
|
||||
// so try again without refreshing the access token.
|
||||
if !retry {
|
||||
c.log.Debug("Handling unauthorized status by retrying")
|
||||
c.requestLocker.Lock()
|
||||
defer c.requestLocker.Unlock()
|
||||
|
||||
_, _ = io.Copy(ioutil.Discard, res.Body)
|
||||
_ = res.Body.Close()
|
||||
return c.doBuffered(req, reqBodyBuffer, true)
|
||||
}
|
||||
|
||||
// This is already a retry, so we will try to refresh the access token before trying again.
|
||||
if err = c.refreshAccessToken(); err != nil {
|
||||
c.log.WithError(err).Warn("Cannot refresh token")
|
||||
err = &ErrUnauthorized{err}
|
||||
return
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, res.Body)
|
||||
func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) {
|
||||
r, err := c.r(ctx)
|
||||
if err != nil {
|
||||
c.log.WithError(err).Warn("Failed to read out response body")
|
||||
return nil, err
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
return c.doBuffered(req, reqBodyBuffer, true)
|
||||
|
||||
res, err := wrapRestyError(fn(r))
|
||||
if err != nil {
|
||||
if res.StatusCode() != http.StatusUnauthorized {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.authRefresh(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapRestyError(fn(r))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) {
|
||||
if err, ok := err.(*resty.ResponseError); ok {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.RawResponse != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, errors.Wrap(ErrNoConnection, err.Error())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user