Files
proton-bridge/pkg/pmapi/client.go
2020-04-21 08:36:38 +00:00

448 lines
12 KiB
Go

// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"reflect"
"strconv"
"strings"
"sync"
"time"
pmcrypto "github.com/ProtonMail/gopenpgp/crypto"
"github.com/jaytaylor/html2text"
"github.com/sirupsen/logrus"
)
// Version is the API version advertised to the server; it is sent as the
// x-pm-apiversion header on every request.
const Version = 3

// API response codes known to the client.
const (
	ForceUpgradeBadAPIVersion = 5003
	ForceUpgradeInvalidAPI    = 5004
	ForceUpgradeBadAppVersion = 5005
	APIOffline                = 7001
	ImportMessageTooLong      = 36022

	// BansRequests makes doJSONBuffered wait and resend the request.
	BansRequests = 85131
)

// Sentinel errors returned by the client.
var (
	// ErrInvalidToken is returned when no refresh token is stored for the user.
	ErrInvalidToken = errors.New("refresh token invalid")
	// ErrAPINotReachable replaces the transport error when no HTTP response was received at all.
	ErrAPINotReachable    = errors.New("cannot reach the server")
	ErrUpgradeApplication = errors.New("application upgrade required")
	ErrNoSuchAPIID        = errors.New("no such API ID")
)
// ErrUnauthorized wraps the error that made re-authorisation impossible
// (for example a failed token refresh).
type ErrUnauthorized struct {
	error
}

// Error returns the wrapped error's message prefixed with "unauthorized access: ".
func (e *ErrUnauthorized) Error() string {
	cause := e.error.Error()
	return fmt.Sprintf("unauthorized access: %s", cause)
}
// ClientConfig contains Client configuration.
// getHTTPClient builds the *http.Client from Transport and Timeout; the request
// path reads AppVersion, MinSpeed and FirstReadTimeout.
type ClientConfig struct {
// The client application name and version.
// Sent with every request as the x-pm-appversion header.
AppVersion string
// The client ID.
ClientID string
// The sentry DSN.
SentryDSN string
// Transport specifies the mechanism by which individual HTTP requests are made.
// If nil (and no package-level defaultTransport is set), http.DefaultTransport is used.
// A non-nil package-level defaultTransport takes precedence over this field.
// When this is an *http.Transport, the client uses a shallow copy of it and sets
// Proxy to http.ProxyFromEnvironment if it was nil.
Transport http.RoundTripper
// Timeout is passed to http.Client.Timeout. Per net/http that limit covers the
// whole exchange: connection time, redirects and reading the response body,
// not only the wait for response headers. Zero means no timeout.
Timeout time.Duration
// FirstReadTimeout is the time allowed, once response headers have arrived,
// to receive the first MinSpeed bytes of the response body.
// This timeout is applied only when MinSpeed is used.
// Zero selects the default of 5 minutes.
FirstReadTimeout time.Duration
// MinSpeed specifies minimum Bytes per second or the request will be canceled.
// The body is read in chunks of MinSpeed bytes; after the first chunk, each
// chunk must arrive within one second or the request's context is cancelled.
// Zero means no limitation.
MinSpeed int64
}
// Client to communicate with API.
// Instances are built by newClient; requests go through Do and DoJSON.
type Client struct {
// cm supplies the shared configuration (GetConfig) and the user's refresh token (GetToken).
cm *ClientManager
// client is the underlying HTTP client, built by getHTTPClient.
client *http.Client
// uid is sent as the x-pm-uid header when non-empty.
uid string
// accessToken is sent as "Authorization: Bearer <token>" when non-empty.
accessToken string
// userID is the key passed to cm.GetToken and is attached to every log entry.
userID string // Twice here because Username is not unique.
// requestLocker serialises the resend that follows a first 401 response
// (see handleStatusUnauthorized).
requestLocker sync.Locker
// NOTE(review): keyLocker, expiresAt, user, addresses and kr are not referenced
// in the code shown here; confirm their roles where they are used.
keyLocker sync.Locker
expiresAt time.Time
user *User
addresses AddressList
kr *pmcrypto.KeyRing
// log is pre-tagged with pkg=pmapi and the userID.
log *logrus.Entry
}
// newClient creates a new API client for userID whose HTTP client is built from
// cm's configuration. The session UID and access token start out empty.
func newClient(cm *ClientManager, userID string) *Client {
	httpClient := getHTTPClient(cm.GetConfig())
	logger := logrus.WithField("pkg", "pmapi").WithField("userID", userID)

	return &Client{
		cm:            cm,
		client:        httpClient,
		userID:        userID,
		requestLocker: new(sync.Mutex),
		keyLocker:     new(sync.Mutex),
		log:           logger,
	}
}
// getHTTPClient returns a http client configured by the given client config.
//
// The client's Timeout is always cfg.Timeout. Its Transport is chosen in this order:
//  1. the package-level defaultTransport, when it is set;
//  2. nil, when cfg.Transport is nil, so net/http falls back to http.DefaultTransport;
//  3. a shallow copy of cfg.Transport when it is an *http.Transport, with Proxy
//     defaulted to http.ProxyFromEnvironment;
//  4. cfg.Transport unchanged otherwise.
//
// TODO: replace the struct copy with (*http.Transport).Clone (Go 1.13+); the
// plain copy also copies the transport's internal lock, hence the nolint.
// https://go-review.googlesource.com/c/go/+/174597/
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
	hc = &http.Client{Timeout: cfg.Timeout}

	switch {
	case defaultTransport != nil:
		hc.Transport = defaultTransport
	case cfg.Transport == nil:
		// Leave Transport nil: net/http then uses http.DefaultTransport.
	default:
		base, isStd := cfg.Transport.(*http.Transport)
		if !isStd {
			hc.Transport = cfg.Transport
			break
		}
		clone := &http.Transport{}
		*clone = *base //nolint
		if clone.Proxy == nil {
			clone.Proxy = http.ProxyFromEnvironment
		}
		hc.Transport = clone
	}

	return hc
}
// Do makes an API request and returns the raw response; it does not check
// for HTTP status code errors.
//
// A non-nil request body is read fully into memory first so doBuffered can
// resend it on retry; the original body is closed before Do returns.
func (c *Client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) {
	var payload []byte

	if original := req.Body; original != nil {
		defer original.Close() //nolint[errcheck]

		if payload, err = ioutil.ReadAll(original); err != nil {
			return nil, err
		}
		req.Body = ioutil.NopCloser(bytes.NewReader(payload))
	}

	return c.doBuffered(req, payload, retryUnauthorized)
}
// doBuffered sends req once, adding the API version headers and the current
// session credentials, and handles the status codes that need a resend.
// bodyBuffer holds the fully-read request body and is used to rewind req.Body.
//
//   - 401 on a non-auth path: the response is drained and closed, and the
//     request is handed to handleStatusUnauthorized.
//   - 401 on a path containing "/auth": the response is returned to the caller as-is.
//   - 429: the request is resent after Retry-After seconds (default 10) plus
//     0-9 seconds of random jitter.
//
// When no response was received at all, the error is replaced by ErrAPINotReachable.
func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
	// rewindBody points req.Body back at the start of the buffered payload.
	rewindBody := func() {
		if len(bodyBuffer) > 0 {
			req.Body = ioutil.NopCloser(bytes.NewReader(bodyBuffer))
		}
	}
	// discard drains and closes a response that will not be returned to the caller.
	discard := func(r *http.Response) {
		_, _ = io.Copy(ioutil.Discard, r.Body)
		_ = r.Body.Close()
	}

	authEndpoint := strings.Contains(req.URL.Path, "/auth")

	req.Header.Set("x-pm-appversion", c.cm.GetConfig().AppVersion)
	req.Header.Set("x-pm-apiversion", strconv.Itoa(Version))
	if c.uid != "" {
		req.Header.Set("x-pm-uid", c.uid)
	}
	if c.accessToken != "" {
		req.Header.Set("Authorization", "Bearer "+c.accessToken)
	}

	c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI())
	if logrus.GetLevel() == logrus.TraceLevel {
		var dump strings.Builder
		for name, values := range req.Header {
			dump.WriteString(name)
			dump.WriteString(": ")
			dump.WriteString(strings.Join(values, ""))
			dump.WriteString("\n")
		}
		c.log.Tracef("REQHEAD \n%s", dump.String())
		c.log.Tracef("REQBODY '%s'", string(bodyBuffer))
	}

	res, err = c.client.Do(req)
	if err != nil {
		if res == nil {
			c.log.WithError(err).Error("Cannot get response")
			err = ErrAPINotReachable
		}
		return res, err
	}

	// Feed the server's Date header to gopenpgp's clock.
	if date := res.Header.Get("Date"); date != "" {
		if serverTime, parseErr := http.ParseTime(date); parseErr == nil {
			pmcrypto.GetGopenPGP().UpdateTime(serverTime.Unix())
		}
	}

	switch res.StatusCode {
	case http.StatusUnauthorized:
		rewindBody()
		if authEndpoint {
			return res, nil
		}
		discard(res)
		return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized)

	case http.StatusTooManyRequests:
		wait := 10
		if fromHeader, convErr := strconv.Atoi(res.Header.Get("Retry-After")); convErr == nil && fromHeader > 0 {
			wait = fromHeader
		}
		// Random jitter so that many clients do not all retry at the same moment.
		wait += rand.Intn(10)

		rewindBody()
		c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, wait, res.StatusCode)
		time.Sleep(time.Duration(wait) * time.Second)
		discard(res)
		return c.doBuffered(req, bodyBuffer, false)
	}

	return res, nil
}
// DoJSON performs the request and unmarshals the response as JSON into data.
// If the API returns a non-2xx HTTP status code, the error returned will contain status
// and response as plaintext. API errors must be checked by the caller.
//
// The request body, if any, is read into memory up front so the request can be
// resent on retry; the original body is closed before DoJSON returns.
func (c *Client) DoJSON(req *http.Request, data interface{}) error {
	if req.Body == nil {
		return c.doJSONBuffered(req, nil, data)
	}

	original := req.Body
	defer original.Close() //nolint[errcheck]

	buffered, err := ioutil.ReadAll(original)
	if err != nil {
		return err
	}
	req.Body = ioutil.NopCloser(bytes.NewReader(buffered))

	return c.doJSONBuffered(req, buffered, data)
}
// doJSONBuffered performs a buffered json request (see DoJSON for more information).
//
// reqBodyBuffer holds the complete request body (or is empty) and is used to
// rewind the body when the API asks for a retry. The response is decoded into
// data, which must be a non-nil pointer. If data points to a struct with a
// settable int field named StatusCode, that field receives the HTTP status code.
//
// When ClientConfig.MinSpeed is positive, the request runs under a cancellable
// context and the body is read through readAllMinSpeed. Zero or negative values
// disable the limit.
func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen]
	req.Header.Set("Accept", "application/vnd.protonmail.v1+json")

	// Read MinSpeed once so the context setup and the read strategy agree.
	// Previously a negative value skipped the context setup (> 0) but still took
	// the throttled read path (!= 0), handing readAllMinSpeed a nil cancel func.
	minSpeed := c.cm.GetConfig().MinSpeed

	// baseReq keeps the caller's context so a retry does not inherit this
	// attempt's (possibly already cancelled) min-speed context.
	baseReq := req

	var cancelRequest context.CancelFunc
	if minSpeed > 0 {
		var ctx context.Context
		ctx, cancelRequest = context.WithCancel(req.Context())
		defer cancelRequest()
		req = req.WithContext(ctx)
	}

	res, err := c.doBuffered(req, reqBodyBuffer, false)
	if err != nil {
		return err
	}
	defer res.Body.Close() //nolint[errcheck]

	var resBody []byte
	if minSpeed > 0 {
		resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
	} else {
		resBody, err = ioutil.ReadAll(res.Body)
	}

	// The server response may contain data which we want to have in memory
	// for as little time as possible (such as keys). Go is garbage collected,
	// so we are not in charge of when the memory will actually be cleared.
	// We can at least try to rewrite the original data to mitigate this problem.
	defer func() {
		for i := range resBody {
			resBody[i] = byte(65)
		}
	}()

	if logrus.GetLevel() == logrus.TraceLevel {
		var head strings.Builder
		for name, values := range res.Header {
			head.WriteString(name + ": " + strings.Join(values, "") + "\n")
		}
		c.log.Tracef("RESHEAD \n%s", head.String())
		c.log.Tracef("RESBODY '%s'", resBody)
	}

	if err != nil {
		return err
	}

	// Retry induced by API code: wait, rewind the body and resend.
	// NOTE(review): there is no upper bound on the number of retries.
	apiRes := &Res{}
	if json.Unmarshal(resBody, apiRes) == nil && apiRes.Code == BansRequests {
		retryAfter := 3
		c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, apiRes.Code)
		time.Sleep(time.Duration(retryAfter) * time.Second)
		if len(reqBodyBuffer) > 0 {
			baseReq.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
		}
		return c.doJSONBuffered(baseReq, reqBodyBuffer, data)
	}

	if err := json.Unmarshal(resBody, data); err != nil {
		// For a non-200 status, report the body as plain text instead of the JSON error.
		if res.StatusCode != http.StatusOK {
			r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n")))
			if plaintext, convErr := html2text.FromReader(r); convErr == nil {
				// Status and body are passed as arguments, not as the format
				// string, because the body may contain '%' verbs.
				return fmt.Errorf("Error: \n\n%s\n\n%s", res.Status, plaintext)
			}
		}
		if errJS, ok := err.(*json.SyntaxError); ok {
			return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset)
		}
		return fmt.Errorf("unmarshal fail: %v ", err)
	}

	// Set StatusCode in case data struct supports that field.
	// It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code.
	// Only structs have fields: FieldByName panics for any other pointee
	// (map, slice, interface, ...), which json.Unmarshal accepts happily.
	if dataValue := reflect.ValueOf(data).Elem(); dataValue.Kind() == reflect.Struct {
		statusCodeField := dataValue.FieldByName("StatusCode")
		if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int {
			statusCodeField.SetInt(int64(res.StatusCode))
		}
	}

	if res.StatusCode != http.StatusOK {
		c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status)
	}
	return nil
}
// readAllMinSpeed reads data to EOF while enforcing ClientConfig.MinSpeed.
//
// The body is consumed in chunks of MinSpeed bytes. The first chunk must arrive
// within FirstReadTimeout (5 minutes when zero), and each later chunk within one
// second. Otherwise cancelRequest is called. The caller passes the request
// context's cancel function, and cancelling that context aborts the pending body
// read. If MinSpeed is not positive, no limit is enforced and data is read with
// ioutil.ReadAll.
//
// The returned slice is the read buffer itself, not a copy. A caller that
// overwrites it afterwards therefore clears the only copy made here.
func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
	cfg := c.cm.GetConfig()

	chunk := cfg.MinSpeed
	if chunk <= 0 {
		// io.CopyN with a non-positive count keeps returning (0, nil), so the
		// throttled loop below would never terminate.
		return ioutil.ReadAll(data)
	}

	firstReadTimeout := cfg.FirstReadTimeout
	if firstReadTimeout == 0 {
		firstReadTimeout = 5 * time.Minute
	}
	timer := time.AfterFunc(firstReadTimeout, func() {
		cancelRequest()
	})
	// Disarm the watchdog on every exit path. If it stayed armed after EOF, it
	// would cancel the request context about a second after a successful read.
	// Any retry derived from that context would then fail immediately.
	defer timer.Stop()

	var buffer bytes.Buffer
	for {
		_, err := io.CopyN(&buffer, data, chunk)
		if err == io.EOF {
			return buffer.Bytes(), nil
		}
		if err != nil {
			return nil, err
		}
		// A full chunk arrived in time; the next one must arrive within a second.
		timer.Stop()
		timer.Reset(1 * time.Second)
	}
}
// refreshAccessToken exchanges the refresh token stored in the ClientManager
// for this user via AuthRefresh. It returns ErrInvalidToken when no refresh
// token is stored, or AuthRefresh's error. In both failure cases it calls
// c.sendAuth(nil) first (presumably to signal that the session is gone; confirm).
func (c *Client) refreshAccessToken() error {
	c.log.Debug("Refreshing token")

	var err error
	if token := c.cm.GetToken(c.userID); token == "" {
		err = ErrInvalidToken
	} else {
		_, err = c.AuthRefresh(token)
	}

	if err != nil {
		c.sendAuth(nil)
	}
	return err
}
// handleStatusUnauthorized handles a 401 response to a non-auth request.
//
// It takes ownership of res and drains and closes its body once, up front, on
// every path. doBuffered already drains and closes it before calling; in that
// case the second read simply returns an error. That error is expected and is
// only logged at debug level; previously it produced a spurious warning on
// every token refresh.
//
// With retry == false (first 401), the request is resent unchanged while
// requestLocker is held. The token refresh on the retry == true path runs inside
// that resend, still under the lock. Other requests that hit 401 meanwhile
// therefore wait and then resend with whatever credentials the refresh left
// behind (presumably a new c.accessToken; confirm in AuthRefresh).
//
// With retry == true, the access token is refreshed first. A failed refresh is
// returned wrapped in *ErrUnauthorized; otherwise the request is resent.
//
// NOTE(review): a request that keeps getting 401 after successful refreshes is
// refreshed and resent again each time, because doBuffered passes retry == true
// back in. There is no upper bound.
func (c *Client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
	c.log.Info("Handling unauthorized status")

	if res != nil && res.Body != nil {
		if _, drainErr := io.Copy(ioutil.Discard, res.Body); drainErr != nil {
			c.log.WithError(drainErr).Debug("Could not drain unauthorized response body")
		}
		_ = res.Body.Close()
	}

	// If this is not a retry, then it is the first time handling status unauthorized,
	// so try again without refreshing the access token.
	if !retry {
		c.log.Debug("Handling unauthorized status by retrying")
		c.requestLocker.Lock()
		defer c.requestLocker.Unlock()
		return c.doBuffered(req, reqBodyBuffer, true)
	}

	// This is already a retry, so we will try to refresh the access token before trying again.
	if err = c.refreshAccessToken(); err != nil {
		c.log.WithError(err).Warn("Cannot refresh token")
		return nil, &ErrUnauthorized{err}
	}
	return c.doBuffered(req, reqBodyBuffer, true)
}