// Copyright (c) 2020 Proton Technologies AG // // This file is part of ProtonMail Bridge. // // ProtonMail Bridge is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // ProtonMail Bridge is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with ProtonMail Bridge. If not, see . package pmapi import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "io/ioutil" "math/rand" "net/http" "reflect" "strconv" "strings" "sync" "time" pmcrypto "github.com/ProtonMail/gopenpgp/crypto" "github.com/jaytaylor/html2text" "github.com/sirupsen/logrus" ) // Version of the API. const Version = 3 // API return codes. const ( ForceUpgradeBadAPIVersion = 5003 ForceUpgradeInvalidAPI = 5004 ForceUpgradeBadAppVersion = 5005 APIOffline = 7001 ImportMessageTooLong = 36022 BansRequests = 85131 ) // The output errors. var ( ErrInvalidToken = errors.New("refresh token invalid") ErrAPINotReachable = errors.New("cannot reach the server") ErrUpgradeApplication = errors.New("application upgrade required") ErrNoSuchAPIID = errors.New("no such API ID") ) type ErrUnauthorized struct { error } func (err *ErrUnauthorized) Error() string { return fmt.Sprintf("unauthorized access: %+v", err.error.Error()) } // ClientConfig contains Client configuration. type ClientConfig struct { // The client application name and version. AppVersion string // The client ID. ClientID string // The sentry DSN. SentryDSN string // Timeout specifies the timeout from request to getting response headers to our API. // Passed to http.Client, empty means no timeout. Timeout time.Duration // FirstReadTimeout specifies the timeout from getting response to the first read of body response. // This timeout is applied only when MinSpeed is used. // Default is 5 minutes. FirstReadTimeout time.Duration // MinSpeed specifies minimum Bytes per second or the request will be canceled. // Zero means no limitation. MinSpeed int64 } // Client defines the interface of a PMAPI client. type Client interface { Auth(username, password string, info *AuthInfo) (*Auth, error) AuthInfo(username string) (*AuthInfo, error) AuthRefresh(token string) (*Auth, error) Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error) UnlockAddresses(passphrase []byte) error CurrentUser() (*User, error) UpdateUser() (*User, error) Addresses() AddressList Logout() GetEvent(eventID string) (*Event, error) CountMessages(addressID string) ([]*MessagesCount, error) ListMessages(filter *MessagesFilter) ([]*Message, int, error) GetMessage(apiID string) (*Message, error) Import([]*ImportMsgReq) ([]*ImportMsgRes, error) DeleteMessages(apiIDs []string) error LabelMessages(apiIDs []string, labelID string) error UnlabelMessages(apiIDs []string, labelID string) error MarkMessagesRead(apiIDs []string) error MarkMessagesUnread(apiIDs []string) error ListLabels() ([]*Label, error) CreateLabel(label *Label) (*Label, error) UpdateLabel(label *Label) (*Label, error) DeleteLabel(labelID string) error EmptyFolder(labelID string, addressID string) error ReportBugWithEmailClient(os, osVersion, title, description, username, email, emailClient string) error SendSimpleMetric(category, action, label string) error ReportSentryCrash(reportErr error) (err error) Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) GetMailSettings() (MailSettings, error) GetContactEmailByEmail(string, int, int) ([]ContactEmail, error) GetContactByID(string) (Contact, error) DecryptAndVerifyCards([]Card) ([]Card, error) GetPublicKeysForEmail(string) ([]PublicKey, bool, error) SendMessage(string, *SendMessageReq) (sent, parent *Message, err error) CreateDraft(m *Message, parent string, action int) (created *Message, err error) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) DeleteAttachment(attID string) (err error) KeyRingForAddressID(string) (kr *pmcrypto.KeyRing) GetAttachment(id string) (att io.ReadCloser, err error) } // client is a client of the protonmail API. It implements the Client interface. type client struct { cm *ClientManager hc *http.Client uid string accessToken string userID string requestLocker sync.Locker keyLocker sync.Locker user *User addresses AddressList kr *pmcrypto.KeyRing log *logrus.Entry } // newClient creates a new API client. func newClient(cm *ClientManager, userID string) *client { return &client{ cm: cm, hc: getHTTPClient(cm.GetConfig(), cm.GetRoundTripper()), userID: userID, requestLocker: &sync.Mutex{}, keyLocker: &sync.Mutex{}, log: logrus.WithField("pkg", "pmapi").WithField("userID", userID), } } // getHTTPClient returns a http client configured by the given client config and using the given transport. func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper) (hc *http.Client) { return &http.Client{ Timeout: cfg.Timeout, Transport: rt, } } // Do makes an API request. It does not check for HTTP status code errors. func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) { // Copy the request body in case we need to retry it. var bodyBuffer []byte if req.Body != nil { defer req.Body.Close() //nolint[errcheck] bodyBuffer, err = ioutil.ReadAll(req.Body) if err != nil { return nil, err } r := bytes.NewReader(bodyBuffer) req.Body = ioutil.NopCloser(r) } return c.doBuffered(req, bodyBuffer, retryUnauthorized) } // If needed it retries using req and buffered body. func (c *client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen] isAuthReq := strings.Contains(req.URL.Path, "/auth") req.Header.Set("x-pm-appversion", c.cm.GetConfig().AppVersion) req.Header.Set("x-pm-apiversion", strconv.Itoa(Version)) if c.uid != "" { req.Header.Set("x-pm-uid", c.uid) } if c.accessToken != "" { req.Header.Set("Authorization", "Bearer "+c.accessToken) } c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI()) if logrus.GetLevel() == logrus.TraceLevel { head := "" for i, v := range req.Header { head += i + ": " head += strings.Join(v, "") head += "\n" } c.log.Tracef("REQHEAD \n%s", head) c.log.Tracef("REQBODY '%s'", string(bodyBuffer)) } hasBody := len(bodyBuffer) > 0 if res, err = c.hc.Do(req); err != nil { if res == nil { c.log.WithError(err).Error("Cannot get response") err = ErrAPINotReachable } return } resDate := res.Header.Get("Date") if resDate != "" { if serverTime, err := http.ParseTime(resDate); err == nil { pmcrypto.GetGopenPGP().UpdateTime(serverTime.Unix()) } } if res.StatusCode == http.StatusUnauthorized { if hasBody { r := bytes.NewReader(bodyBuffer) req.Body = ioutil.NopCloser(r) } if !isAuthReq { _, _ = io.Copy(ioutil.Discard, res.Body) _ = res.Body.Close() return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized) } } // Retry induced by HTTP status code> retryAfter := 10 doRetry := res.StatusCode == http.StatusTooManyRequests if doRetry { if headerAfter, err := strconv.Atoi(res.Header.Get("Retry-After")); err == nil && headerAfter > 0 { retryAfter = headerAfter } // To avoid spikes when all clients retry at the same time, we add some random wait. retryAfter += rand.Intn(10) if hasBody { r := bytes.NewReader(bodyBuffer) req.Body = ioutil.NopCloser(r) } c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, retryAfter, res.StatusCode) time.Sleep(time.Duration(retryAfter) * time.Second) _, _ = io.Copy(ioutil.Discard, res.Body) _ = res.Body.Close() return c.doBuffered(req, bodyBuffer, false) } return res, err } // DoJSON performs the request and unmarshals the response as JSON into data. // If the API returns a non-2xx HTTP status code, the error returned will contain status // and response as plaintext. API errors must be checked by the caller. // It is performed buffered, in case we need to retry. func (c *client) DoJSON(req *http.Request, data interface{}) error { // Copy the request body in case we need to retry it var reqBodyBuffer []byte if req.Body != nil { defer req.Body.Close() //nolint[errcheck] var err error if reqBodyBuffer, err = ioutil.ReadAll(req.Body); err != nil { return err } req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer)) } return c.doJSONBuffered(req, reqBodyBuffer, data) } // doJSONBuffered performs a buffered json request (see DoJSON for more information). func (c *client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen] req.Header.Set("Accept", "application/vnd.protonmail.v1+json") var cancelRequest context.CancelFunc if c.cm.GetConfig().MinSpeed > 0 { var ctx context.Context ctx, cancelRequest = context.WithCancel(req.Context()) defer func() { cancelRequest() }() req = req.WithContext(ctx) } res, err := c.doBuffered(req, reqBodyBuffer, false) if err != nil { return err } defer res.Body.Close() //nolint[errcheck] var resBody []byte if c.cm.GetConfig().MinSpeed == 0 { resBody, err = ioutil.ReadAll(res.Body) } else { resBody, err = c.readAllMinSpeed(res.Body, cancelRequest) } // The server response may contain data which we want to have in memory // for as little time as possible (such as keys). Go is garbage collected, // so we are not in charge of when the memory will actually be cleared. // We can at least try to rewrite the original data to mitigate this problem. defer func() { for i := 0; i < len(resBody); i++ { resBody[i] = byte(65) } }() if logrus.GetLevel() == logrus.TraceLevel { head := "" for i, v := range res.Header { head += i + ": " head += strings.Join(v, "") head += "\n" } c.log.Tracef("RESHEAD \n%s", head) c.log.Tracef("RESBODY '%s'", resBody) } if err != nil { return err } // Retry induced by API code. errCode := &Res{} if err := json.Unmarshal(resBody, errCode); err == nil { if errCode.Code == BansRequests { retryAfter := 3 c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, errCode.Code) time.Sleep(time.Duration(retryAfter) * time.Second) if len(reqBodyBuffer) > 0 { req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer)) } return c.doJSONBuffered(req, reqBodyBuffer, data) } } if err := json.Unmarshal(resBody, data); err != nil { // Check to see if this is due to a non 2xx HTTP status code. if res.StatusCode != http.StatusOK { r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n"))) plaintext, err := html2text.FromReader(r) if err == nil { return fmt.Errorf("Error: \n\n" + res.Status + "\n\n" + plaintext) } } if errJS, ok := err.(*json.SyntaxError); ok { return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset) } return fmt.Errorf("unmarshal fail: %v ", err) } // Set StatusCode in case data struct supports that field. // It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code. dataValue := reflect.ValueOf(data).Elem() statusCodeField := dataValue.FieldByName("StatusCode") if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int { statusCodeField.SetInt(int64(res.StatusCode)) } if res.StatusCode != http.StatusOK { c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status) } return nil } func (c *client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) { firstReadTimeout := c.cm.GetConfig().FirstReadTimeout if firstReadTimeout == 0 { firstReadTimeout = 5 * time.Minute } timer := time.AfterFunc(firstReadTimeout, func() { cancelRequest() }) var buffer bytes.Buffer for { _, err := io.CopyN(&buffer, data, c.cm.GetConfig().MinSpeed) timer.Stop() timer.Reset(1 * time.Second) if err == io.EOF { break } else if err != nil { return nil, err } } return ioutil.ReadAll(&buffer) } func (c *client) refreshAccessToken() (err error) { c.log.Debug("Refreshing token") refreshToken := c.cm.GetToken(c.userID) if refreshToken == "" { c.sendAuth(nil) return ErrInvalidToken } if _, err := c.AuthRefresh(refreshToken); err != nil { c.sendAuth(nil) return err } return } func (c *client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) { c.log.Info("Handling unauthorized status") // If this is not a retry, then it is the first time handling status unauthorized, // so try again without refreshing the access token. if !retry { c.log.Debug("Handling unauthorized status by retrying") c.requestLocker.Lock() defer c.requestLocker.Unlock() _, _ = io.Copy(ioutil.Discard, res.Body) _ = res.Body.Close() return c.doBuffered(req, reqBodyBuffer, true) } // This is already a retry, so we will try to refresh the access token before trying again. if err = c.refreshAccessToken(); err != nil { c.log.WithError(err).Warn("Cannot refresh token") err = &ErrUnauthorized{err} return } _, err = io.Copy(ioutil.Discard, res.Body) if err != nil { c.log.WithError(err).Warn("Failed to read out response body") } _ = res.Body.Close() return c.doBuffered(req, reqBodyBuffer, true) }