mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
520 lines
14 KiB
Go
520 lines
14 KiB
Go
// Copyright (c) 2021 Proton Technologies AG
|
|
//
|
|
// This file is part of ProtonMail Bridge.
|
|
//
|
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package pmapi
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
|
"github.com/jaytaylor/html2text"
|
|
"github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Version is the version number of the API.
const (
	Version = 3
)
|
|
|
|
// Codes returned by the API in the response body.
const (
	// ForceUpgradeBadAppVersion presumably reports that the application
	// version is not accepted any more — TODO confirm against the API docs.
	ForceUpgradeBadAppVersion = 5003

	// APIOffline presumably reports that the API is offline — TODO confirm.
	APIOffline = 7001

	// ImportMessageTooLong presumably reports that an imported message
	// exceeds the size limit — TODO confirm.
	ImportMessageTooLong = 36022

	// BansRequests makes the JSON request helper wait and send the request again.
	BansRequests = 85131
)
|
|
|
|
// Sentinel errors returned by the client.

// ErrInvalidToken is returned when no usable refresh token is available.
var ErrInvalidToken = errors.New("refresh token invalid")

// ErrAPINotReachable is returned when no response could be obtained from the server.
var ErrAPINotReachable = errors.New("cannot reach the server")

// ErrUpgradeApplication is returned when the API requires a newer application.
var ErrUpgradeApplication = errors.New("application upgrade required")

// ErrConnectionSlow is returned when a request was canceled because the
// response was arriving too slowly.
var ErrConnectionSlow = errors.New("request canceled because connection speed was too slow")
|
|
|
|
// ErrUnprocessableEntity wraps an underlying error; its message is the
// message of the wrapped error.
type ErrUnprocessableEntity struct {
	error
}

// Error returns the message of the wrapped error. It returns a generic
// message instead of panicking when there is no wrapped error.
func (err *ErrUnprocessableEntity) Error() string {
	if err == nil || err.error == nil {
		return "unprocessable entity"
	}
	return err.error.Error()
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect the cause. It returns nil when there is no wrapped error.
func (err *ErrUnprocessableEntity) Unwrap() error {
	if err == nil {
		return nil
	}
	return err.error
}
|
|
|
|
// ErrUnauthorized wraps the error that caused an unauthorized request to fail.
type ErrUnauthorized struct {
	error
}

// Error returns the message of the wrapped error prefixed with
// "unauthorized access: ". It returns just "unauthorized access" instead of
// panicking when there is no wrapped error.
func (err *ErrUnauthorized) Error() string {
	if err == nil || err.error == nil {
		return "unauthorized access"
	}
	return fmt.Sprintf("unauthorized access: %s", err.error.Error())
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect the cause. It returns nil when there is no wrapped error.
func (err *ErrUnauthorized) Unwrap() error {
	if err == nil {
		return nil
	}
	return err.error
}
|
|
|
|
// ClientConfig holds the settings used to configure a Client.
type ClientConfig struct {
	// AppVersion is the client application name and version. It is sent
	// with every request in the x-pm-appversion header.
	AppVersion string

	// ClientID is the client ID.
	ClientID string

	// Timeout limits the duration of the full request and is handed over to
	// http.Client. The zero value means that no timeout is applied.
	Timeout time.Duration

	// FirstReadTimeout limits the time between receiving the response and
	// the first read of the response body. It is only applied when
	// MinBytesPerSecond is used. The zero value selects the default of
	// 5 minutes.
	FirstReadTimeout time.Duration

	// MinBytesPerSecond is the minimum transfer speed in bytes per second;
	// slower requests are canceled. Zero means no limitation.
	MinBytesPerSecond int64

	// Callbacks. UpgradeApplicationHandler is invoked when the API reports
	// that an application upgrade is required. ConnectionOnHandler and
	// ConnectionOffHandler are presumably invoked when the connection to
	// the API comes back or is lost — TODO confirm against ClientManager.
	ConnectionOnHandler,
	ConnectionOffHandler,
	UpgradeApplicationHandler func()
}
|
|
|
|
// client is a client of the protonmail API. It implements the Client interface.
|
|
type client struct {
|
|
cm *ClientManager
|
|
hc *http.Client
|
|
|
|
uid string
|
|
accessToken string
|
|
userID string
|
|
requestLocker sync.Locker
|
|
refreshLocker sync.Locker
|
|
|
|
user *User
|
|
addresses AddressList
|
|
userKeyRing *crypto.KeyRing
|
|
addrKeyRing map[string]*crypto.KeyRing
|
|
keyRingLock sync.Locker
|
|
|
|
log *logrus.Entry
|
|
}
|
|
|
|
// newClient creates a new API client.
|
|
func newClient(cm *ClientManager, userID string) *client {
|
|
return &client{
|
|
cm: cm,
|
|
hc: getHTTPClient(cm.config, cm.roundTripper, cm.cookieJar),
|
|
userID: userID,
|
|
requestLocker: &sync.Mutex{},
|
|
refreshLocker: &sync.Mutex{},
|
|
keyRingLock: &sync.Mutex{},
|
|
addrKeyRing: make(map[string]*crypto.KeyRing),
|
|
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID),
|
|
}
|
|
}
|
|
|
|
// getHTTPClient returns a http client configured by the given client config and using the given transport.
|
|
func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper, jar http.CookieJar) (hc *http.Client) {
|
|
return &http.Client{
|
|
Transport: rt,
|
|
Jar: jar,
|
|
Timeout: cfg.Timeout,
|
|
}
|
|
}
|
|
|
|
func (c *client) IsUnlocked() bool {
|
|
return c.userKeyRing != nil
|
|
}
|
|
|
|
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
|
|
// If the keyrings are already present, they are not recreated.
|
|
func (c *client) Unlock(passphrase []byte) (err error) {
|
|
c.keyRingLock.Lock()
|
|
defer c.keyRingLock.Unlock()
|
|
|
|
return c.unlock(passphrase)
|
|
}
|
|
|
|
// unlock unlocks the user's keys but without locking the keyring lock first.
|
|
// Should only be used internally by methods that first lock the lock.
|
|
func (c *client) unlock(passphrase []byte) (err error) {
|
|
if _, err = c.CurrentUser(); err != nil {
|
|
return
|
|
}
|
|
|
|
if c.userKeyRing == nil {
|
|
if err = c.unlockUser(passphrase); err != nil {
|
|
return errors.Wrap(err, "failed to unlock user")
|
|
}
|
|
}
|
|
|
|
for _, address := range c.addresses {
|
|
if c.addrKeyRing[address.ID] == nil {
|
|
if err = c.unlockAddress(passphrase, address); err != nil {
|
|
return errors.Wrap(err, "failed to unlock address")
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *client) ReloadKeys(passphrase []byte) (err error) {
|
|
c.keyRingLock.Lock()
|
|
defer c.keyRingLock.Unlock()
|
|
|
|
c.clearKeys()
|
|
|
|
return c.unlock(passphrase)
|
|
}
|
|
|
|
func (c *client) clearKeys() {
|
|
if c.userKeyRing != nil {
|
|
c.userKeyRing.ClearPrivateParams()
|
|
c.userKeyRing = nil
|
|
}
|
|
|
|
for id, kr := range c.addrKeyRing {
|
|
if kr != nil {
|
|
kr.ClearPrivateParams()
|
|
}
|
|
delete(c.addrKeyRing, id)
|
|
}
|
|
}
|
|
|
|
func (c *client) CloseConnections() {
|
|
c.hc.CloseIdleConnections()
|
|
}
|
|
|
|
// Do makes an API request. It does not check for HTTP status code errors.
|
|
func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) {
|
|
// Copy the request body in case we need to retry it.
|
|
var bodyBuffer []byte
|
|
if req.Body != nil {
|
|
defer req.Body.Close() //nolint[errcheck]
|
|
bodyBuffer, err = ioutil.ReadAll(req.Body)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r := bytes.NewReader(bodyBuffer)
|
|
req.Body = ioutil.NopCloser(r)
|
|
}
|
|
|
|
return c.doBuffered(req, bodyBuffer, retryUnauthorized)
|
|
}
|
|
|
|
// If needed it retries using req and buffered body.
|
|
func (c *client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen]
|
|
isAuthReq := strings.Contains(req.URL.Path, "/auth")
|
|
|
|
req.Header.Set("User-Agent", c.cm.userAgent.String())
|
|
req.Header.Set("x-pm-appversion", c.cm.config.AppVersion)
|
|
|
|
if c.uid != "" {
|
|
req.Header.Set("x-pm-uid", c.uid)
|
|
}
|
|
|
|
if c.accessToken != "" {
|
|
req.Header.Set("Authorization", "Bearer "+c.accessToken)
|
|
}
|
|
|
|
c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI())
|
|
if logrus.GetLevel() == logrus.TraceLevel {
|
|
head := ""
|
|
for i, v := range req.Header {
|
|
head += i + ": "
|
|
head += strings.Join(v, "")
|
|
head += "\n"
|
|
}
|
|
c.log.Tracef("REQHEAD \n%s", head)
|
|
c.log.Tracef("REQBODY '%s'", printBytes(bodyBuffer))
|
|
}
|
|
|
|
hasBody := len(bodyBuffer) > 0
|
|
if res, err = c.hc.Do(req); err != nil {
|
|
if res == nil {
|
|
c.log.WithError(err).Error("Cannot get response")
|
|
err = ErrAPINotReachable
|
|
c.cm.noConnection()
|
|
}
|
|
return
|
|
}
|
|
|
|
// Cookies are returned only after request was sent.
|
|
c.log.Tracef("REQCOOKIES '%v'", req.Cookies())
|
|
|
|
resDate := res.Header.Get("Date")
|
|
if resDate != "" {
|
|
if serverTime, err := http.ParseTime(resDate); err == nil {
|
|
crypto.UpdateTime(serverTime.Unix())
|
|
}
|
|
}
|
|
|
|
if res.StatusCode == http.StatusUnauthorized {
|
|
if hasBody {
|
|
r := bytes.NewReader(bodyBuffer)
|
|
req.Body = ioutil.NopCloser(r)
|
|
}
|
|
|
|
if !isAuthReq {
|
|
_, _ = io.Copy(ioutil.Discard, res.Body)
|
|
_ = res.Body.Close()
|
|
return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized)
|
|
}
|
|
}
|
|
|
|
// Retry induced by HTTP status code>
|
|
retryAfter := 10
|
|
doRetry := res.StatusCode == http.StatusTooManyRequests
|
|
if doRetry {
|
|
if headerAfter, err := strconv.Atoi(res.Header.Get("Retry-After")); err == nil && headerAfter > 0 {
|
|
retryAfter = headerAfter
|
|
}
|
|
// To avoid spikes when all clients retry at the same time, we add some random wait.
|
|
retryAfter += rand.Intn(10) //nolint[gosec] It is OK to use weak random number generator here
|
|
|
|
if hasBody {
|
|
r := bytes.NewReader(bodyBuffer)
|
|
req.Body = ioutil.NopCloser(r)
|
|
}
|
|
|
|
c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, retryAfter, res.StatusCode)
|
|
time.Sleep(time.Duration(retryAfter) * time.Second)
|
|
_, _ = io.Copy(ioutil.Discard, res.Body)
|
|
_ = res.Body.Close()
|
|
return c.doBuffered(req, bodyBuffer, false)
|
|
}
|
|
|
|
return res, err
|
|
}
|
|
|
|
// DoJSON performs the request and unmarshals the response as JSON into data.
|
|
// If the API returns a non-2xx HTTP status code, the error returned will contain status
|
|
// and response as plaintext. API errors must be checked by the caller.
|
|
// It is performed buffered, in case we need to retry.
|
|
func (c *client) DoJSON(req *http.Request, data interface{}) error {
|
|
// Copy the request body in case we need to retry it
|
|
var reqBodyBuffer []byte
|
|
|
|
if req.Body != nil {
|
|
defer req.Body.Close() //nolint[errcheck]
|
|
var err error
|
|
if reqBodyBuffer, err = ioutil.ReadAll(req.Body); err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
|
|
}
|
|
|
|
return c.doJSONBuffered(req, reqBodyBuffer, data)
|
|
}
|
|
|
|
// doJSONBuffered performs a buffered json request (see DoJSON for more information).
|
|
func (c *client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen]
|
|
req.Header.Set("Accept", "application/vnd.protonmail.v1+json")
|
|
|
|
var cancelRequest context.CancelFunc
|
|
if c.cm.config.MinBytesPerSecond > 0 {
|
|
var ctx context.Context
|
|
ctx, cancelRequest = context.WithCancel(req.Context())
|
|
defer func() {
|
|
cancelRequest()
|
|
}()
|
|
req = req.WithContext(ctx)
|
|
}
|
|
|
|
res, err := c.doBuffered(req, reqBodyBuffer, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer res.Body.Close() //nolint[errcheck]
|
|
|
|
var resBody []byte
|
|
if c.cm.config.MinBytesPerSecond == 0 {
|
|
resBody, err = ioutil.ReadAll(res.Body)
|
|
} else {
|
|
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest)
|
|
if err == context.Canceled {
|
|
err = ErrConnectionSlow
|
|
}
|
|
}
|
|
|
|
// The server response may contain data which we want to have in memory
|
|
// for as little time as possible (such as keys). Go is garbage collected,
|
|
// so we are not in charge of when the memory will actually be cleared.
|
|
// We can at least try to rewrite the original data to mitigate this problem.
|
|
defer func() {
|
|
for i := 0; i < len(resBody); i++ {
|
|
resBody[i] = byte(65)
|
|
}
|
|
}()
|
|
|
|
if logrus.GetLevel() == logrus.TraceLevel {
|
|
head := ""
|
|
for i, v := range res.Header {
|
|
head += i + ": "
|
|
head += strings.Join(v, "")
|
|
head += "\n"
|
|
}
|
|
c.log.Tracef("RESHEAD \n%s", head)
|
|
c.log.Tracef("RESBODY '%s'", resBody)
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Retry induced by API code.
|
|
errCode := &Res{}
|
|
if err := json.Unmarshal(resBody, errCode); err == nil {
|
|
if errCode.Code == BansRequests {
|
|
retryAfter := 3
|
|
c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, errCode.Code)
|
|
time.Sleep(time.Duration(retryAfter) * time.Second)
|
|
if len(reqBodyBuffer) > 0 {
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer))
|
|
}
|
|
return c.doJSONBuffered(req, reqBodyBuffer, data)
|
|
}
|
|
if errCode.Err() == ErrAPINotReachable {
|
|
c.cm.noConnection()
|
|
}
|
|
if errCode.Err() == ErrUpgradeApplication {
|
|
c.cm.config.UpgradeApplicationHandler()
|
|
}
|
|
}
|
|
|
|
if err := json.Unmarshal(resBody, data); err != nil {
|
|
// Check to see if this is due to a non 2xx HTTP status code.
|
|
if res.StatusCode != http.StatusOK {
|
|
r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n")))
|
|
plaintext, err := html2text.FromReader(r)
|
|
if err == nil {
|
|
return fmt.Errorf("Error: \n\n" + res.Status + "\n\n" + plaintext)
|
|
}
|
|
}
|
|
|
|
if errJS, ok := err.(*json.SyntaxError); ok {
|
|
return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset)
|
|
}
|
|
|
|
return fmt.Errorf("unmarshal fail: %v ", err)
|
|
}
|
|
|
|
// Set StatusCode in case data struct supports that field.
|
|
// It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code.
|
|
dataValue := reflect.ValueOf(data).Elem()
|
|
statusCodeField := dataValue.FieldByName("StatusCode")
|
|
if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int {
|
|
statusCodeField.SetInt(int64(res.StatusCode))
|
|
}
|
|
|
|
if res.StatusCode != http.StatusOK {
|
|
c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) {
|
|
firstReadTimeout := c.cm.config.FirstReadTimeout
|
|
if firstReadTimeout == 0 {
|
|
firstReadTimeout = 5 * time.Minute
|
|
}
|
|
timer := time.AfterFunc(firstReadTimeout, func() {
|
|
cancelRequest()
|
|
})
|
|
|
|
// speedCheckSeconds controls how often we check the transfer speed.
|
|
// Note that connection can be unstable, on average very fast, but can be
|
|
// idle for few seconds; or that API can take its time before sending
|
|
// another data, e.g., API can send some data and take some time before
|
|
// processing and sending the rest of the response.
|
|
const speedCheckSeconds = 30
|
|
|
|
var buffer bytes.Buffer
|
|
for {
|
|
_, err := io.CopyN(&buffer, data, c.cm.config.MinBytesPerSecond*speedCheckSeconds)
|
|
timer.Stop()
|
|
if err == io.EOF {
|
|
break
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
timer.Reset(speedCheckSeconds * time.Second)
|
|
}
|
|
|
|
return ioutil.ReadAll(&buffer)
|
|
}
|
|
|
|
func (c *client) refreshAccessToken() (err error) {
|
|
c.log.Debug("Refreshing token")
|
|
|
|
refreshToken := c.cm.GetToken(c.userID)
|
|
|
|
if refreshToken == "" {
|
|
c.sendAuth(nil)
|
|
return ErrInvalidToken
|
|
}
|
|
|
|
if _, err := c.AuthRefresh(refreshToken); err != nil {
|
|
if err != ErrAPINotReachable {
|
|
c.sendAuth(nil)
|
|
}
|
|
return errors.Wrap(err, "failed to refresh auth")
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) {
|
|
c.log.Info("Handling unauthorized status")
|
|
|
|
// If this is not a retry, then it is the first time handling status unauthorized,
|
|
// so try again without refreshing the access token.
|
|
if !retry {
|
|
c.log.Debug("Handling unauthorized status by retrying")
|
|
c.requestLocker.Lock()
|
|
defer c.requestLocker.Unlock()
|
|
|
|
_, _ = io.Copy(ioutil.Discard, res.Body)
|
|
_ = res.Body.Close()
|
|
return c.doBuffered(req, reqBodyBuffer, true)
|
|
}
|
|
|
|
// This is already a retry, so we will try to refresh the access token before trying again.
|
|
if err = c.refreshAccessToken(); err != nil {
|
|
c.log.WithError(err).Warn("Cannot refresh token")
|
|
err = &ErrUnauthorized{err}
|
|
return
|
|
}
|
|
_, err = io.Copy(ioutil.Discard, res.Body)
|
|
if err != nil {
|
|
c.log.WithError(err).Warn("Failed to read out response body")
|
|
}
|
|
_ = res.Body.Close()
|
|
return c.doBuffered(req, reqBodyBuffer, true)
|
|
}
|