// Copyright (c) 2020 Proton Technologies AG // // This file is part of ProtonMail Bridge. // // ProtonMail Bridge is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // ProtonMail Bridge is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with ProtonMail Bridge. If not, see . package pmapi import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "io/ioutil" "math/rand" "net/http" "reflect" "strconv" "strings" "sync" "time" pmcrypto "github.com/ProtonMail/gopenpgp/crypto" "github.com/jaytaylor/html2text" "github.com/sirupsen/logrus" ) // Version of the API. const Version = 3 // API return codes. const ( ForceUpgradeBadAPIVersion = 5003 ForceUpgradeInvalidAPI = 5004 ForceUpgradeBadAppVersion = 5005 APIOffline = 7001 ImportMessageTooLong = 36022 BansRequests = 85131 ) // The output errors. var ( ErrInvalidToken = errors.New("refresh token invalid") ErrAPINotReachable = errors.New("cannot reach the server") ErrUpgradeApplication = errors.New("application upgrade required") ErrNoSuchAPIID = errors.New("no such API ID") ) type ErrUnauthorized struct { error } func (err *ErrUnauthorized) Error() string { return fmt.Sprintf("unauthorized access: %+v", err.error.Error()) } // ClientConfig contains Client configuration. type ClientConfig struct { // The client application name and version. AppVersion string // The client ID. ClientID string // Timeout specifies the timeout from request to getting response headers to our API. // Passed to http.Client, empty means no timeout. Timeout time.Duration // FirstReadTimeout specifies the timeout from getting response to the first read of body response. // This timeout is applied only when MinSpeed is used. // Default is 5 minutes. FirstReadTimeout time.Duration // MinSpeed specifies minimum Bytes per second or the request will be canceled. // Zero means no limitation. MinSpeed int64 } // Client defines the interface of a PMAPI client. type Client interface { Auth(username, password string, info *AuthInfo) (*Auth, error) AuthInfo(username string) (*AuthInfo, error) AuthRefresh(token string) (*Auth, error) Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) Logout() DeleteAuth() error ClearData() CurrentUser() (*User, error) UpdateUser() (*User, error) Unlock(mailboxPassword string) (kr *pmcrypto.KeyRing, err error) UnlockAddresses(passphrase []byte) error GetAddresses() (addresses AddressList, err error) Addresses() AddressList GetEvent(eventID string) (*Event, error) SendMessage(string, *SendMessageReq) (sent, parent *Message, err error) CreateDraft(m *Message, parent string, action int) (created *Message, err error) Import([]*ImportMsgReq) ([]*ImportMsgRes, error) CountMessages(addressID string) ([]*MessagesCount, error) ListMessages(filter *MessagesFilter) ([]*Message, int, error) GetMessage(apiID string) (*Message, error) DeleteMessages(apiIDs []string) error LabelMessages(apiIDs []string, labelID string) error UnlabelMessages(apiIDs []string, labelID string) error MarkMessagesRead(apiIDs []string) error MarkMessagesUnread(apiIDs []string) error ListLabels() ([]*Label, error) CreateLabel(label *Label) (*Label, error) UpdateLabel(label *Label) (*Label, error) DeleteLabel(labelID string) error EmptyFolder(labelID string, addressID string) error ReportBugWithEmailClient(os, osVersion, title, description, username, email, emailClient string) error SendSimpleMetric(category, action, label string) error ReportSentryCrash(reportErr error) (err error) GetMailSettings() (MailSettings, error) GetContactEmailByEmail(string, int, int) ([]ContactEmail, error) GetContactByID(string) (Contact, error) DecryptAndVerifyCards([]Card) ([]Card, error) GetAttachment(id string) (att io.ReadCloser, err error) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) DeleteAttachment(attID string) (err error) KeyRingForAddressID(string) (kr *pmcrypto.KeyRing) GetPublicKeysForEmail(string) ([]PublicKey, bool, error) } // client is a client of the protonmail API. It implements the Client interface. type client struct { cm *ClientManager hc *http.Client uid string accessToken string userID string requestLocker sync.Locker keyLocker sync.Locker user *User addresses AddressList kr *pmcrypto.KeyRing log *logrus.Entry } // newClient creates a new API client. func newClient(cm *ClientManager, userID string) *client { return &client{ cm: cm, hc: getHTTPClient(cm.GetConfig(), cm.GetRoundTripper()), userID: userID, requestLocker: &sync.Mutex{}, keyLocker: &sync.Mutex{}, log: logrus.WithField("pkg", "pmapi").WithField("userID", userID), } } // getHTTPClient returns a http client configured by the given client config and using the given transport. func getHTTPClient(cfg *ClientConfig, rt http.RoundTripper) (hc *http.Client) { return &http.Client{ Timeout: cfg.Timeout, Transport: rt, } } // Do makes an API request. It does not check for HTTP status code errors. func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) { // Copy the request body in case we need to retry it. var bodyBuffer []byte if req.Body != nil { defer req.Body.Close() //nolint[errcheck] bodyBuffer, err = ioutil.ReadAll(req.Body) if err != nil { return nil, err } r := bytes.NewReader(bodyBuffer) req.Body = ioutil.NopCloser(r) } return c.doBuffered(req, bodyBuffer, retryUnauthorized) } // If needed it retries using req and buffered body. func (c *client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen] isAuthReq := strings.Contains(req.URL.Path, "/auth") req.Header.Set("x-pm-appversion", c.cm.GetConfig().AppVersion) req.Header.Set("x-pm-apiversion", strconv.Itoa(Version)) if c.uid != "" { req.Header.Set("x-pm-uid", c.uid) } if c.accessToken != "" { req.Header.Set("Authorization", "Bearer "+c.accessToken) } c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI()) if logrus.GetLevel() == logrus.TraceLevel { head := "" for i, v := range req.Header { head += i + ": " head += strings.Join(v, "") head += "\n" } c.log.Tracef("REQHEAD \n%s", head) c.log.Tracef("REQBODY '%s'", string(bodyBuffer)) } hasBody := len(bodyBuffer) > 0 if res, err = c.hc.Do(req); err != nil { if res == nil { c.log.WithError(err).Error("Cannot get response") err = ErrAPINotReachable } return } resDate := res.Header.Get("Date") if resDate != "" { if serverTime, err := http.ParseTime(resDate); err == nil { pmcrypto.GetGopenPGP().UpdateTime(serverTime.Unix()) } } if res.StatusCode == http.StatusUnauthorized { if hasBody { r := bytes.NewReader(bodyBuffer) req.Body = ioutil.NopCloser(r) } if !isAuthReq { _, _ = io.Copy(ioutil.Discard, res.Body) _ = res.Body.Close() return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized) } } // Retry induced by HTTP status code> retryAfter := 10 doRetry := res.StatusCode == http.StatusTooManyRequests if doRetry { if headerAfter, err := strconv.Atoi(res.Header.Get("Retry-After")); err == nil && headerAfter > 0 { retryAfter = headerAfter } // To avoid spikes when all clients retry at the same time, we add some random wait. retryAfter += rand.Intn(10) if hasBody { r := bytes.NewReader(bodyBuffer) req.Body = ioutil.NopCloser(r) } c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, retryAfter, res.StatusCode) time.Sleep(time.Duration(retryAfter) * time.Second) _, _ = io.Copy(ioutil.Discard, res.Body) _ = res.Body.Close() return c.doBuffered(req, bodyBuffer, false) } return res, err } // DoJSON performs the request and unmarshals the response as JSON into data. // If the API returns a non-2xx HTTP status code, the error returned will contain status // and response as plaintext. API errors must be checked by the caller. // It is performed buffered, in case we need to retry. func (c *client) DoJSON(req *http.Request, data interface{}) error { // Copy the request body in case we need to retry it var reqBodyBuffer []byte if req.Body != nil { defer req.Body.Close() //nolint[errcheck] var err error if reqBodyBuffer, err = ioutil.ReadAll(req.Body); err != nil { return err } req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer)) } return c.doJSONBuffered(req, reqBodyBuffer, data) } // doJSONBuffered performs a buffered json request (see DoJSON for more information). func (c *client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen] req.Header.Set("Accept", "application/vnd.protonmail.v1+json") var cancelRequest context.CancelFunc if c.cm.GetConfig().MinSpeed > 0 { var ctx context.Context ctx, cancelRequest = context.WithCancel(req.Context()) defer func() { cancelRequest() }() req = req.WithContext(ctx) } res, err := c.doBuffered(req, reqBodyBuffer, false) if err != nil { return err } defer res.Body.Close() //nolint[errcheck] var resBody []byte if c.cm.GetConfig().MinSpeed == 0 { resBody, err = ioutil.ReadAll(res.Body) } else { resBody, err = c.readAllMinSpeed(res.Body, cancelRequest) } // The server response may contain data which we want to have in memory // for as little time as possible (such as keys). Go is garbage collected, // so we are not in charge of when the memory will actually be cleared. // We can at least try to rewrite the original data to mitigate this problem. defer func() { for i := 0; i < len(resBody); i++ { resBody[i] = byte(65) } }() if logrus.GetLevel() == logrus.TraceLevel { head := "" for i, v := range res.Header { head += i + ": " head += strings.Join(v, "") head += "\n" } c.log.Tracef("RESHEAD \n%s", head) c.log.Tracef("RESBODY '%s'", resBody) } if err != nil { return err } // Retry induced by API code. errCode := &Res{} if err := json.Unmarshal(resBody, errCode); err == nil { if errCode.Code == BansRequests { retryAfter := 3 c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, errCode.Code) time.Sleep(time.Duration(retryAfter) * time.Second) if len(reqBodyBuffer) > 0 { req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer)) } return c.doJSONBuffered(req, reqBodyBuffer, data) } } if err := json.Unmarshal(resBody, data); err != nil { // Check to see if this is due to a non 2xx HTTP status code. if res.StatusCode != http.StatusOK { r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n"))) plaintext, err := html2text.FromReader(r) if err == nil { return fmt.Errorf("Error: \n\n" + res.Status + "\n\n" + plaintext) } } if errJS, ok := err.(*json.SyntaxError); ok { return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset) } return fmt.Errorf("unmarshal fail: %v ", err) } // Set StatusCode in case data struct supports that field. // It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code. dataValue := reflect.ValueOf(data).Elem() statusCodeField := dataValue.FieldByName("StatusCode") if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int { statusCodeField.SetInt(int64(res.StatusCode)) } if res.StatusCode != http.StatusOK { c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status) } return nil } func (c *client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) { firstReadTimeout := c.cm.GetConfig().FirstReadTimeout if firstReadTimeout == 0 { firstReadTimeout = 5 * time.Minute } timer := time.AfterFunc(firstReadTimeout, func() { cancelRequest() }) var buffer bytes.Buffer for { _, err := io.CopyN(&buffer, data, c.cm.GetConfig().MinSpeed) timer.Stop() timer.Reset(1 * time.Second) if err == io.EOF { break } else if err != nil { return nil, err } } return ioutil.ReadAll(&buffer) } func (c *client) refreshAccessToken() (err error) { c.log.Debug("Refreshing token") refreshToken := c.cm.GetToken(c.userID) if refreshToken == "" { c.sendAuth(nil) return ErrInvalidToken } if _, err := c.AuthRefresh(refreshToken); err != nil { c.sendAuth(nil) return err } return } func (c *client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) { c.log.Info("Handling unauthorized status") // If this is not a retry, then it is the first time handling status unauthorized, // so try again without refreshing the access token. if !retry { c.log.Debug("Handling unauthorized status by retrying") c.requestLocker.Lock() defer c.requestLocker.Unlock() _, _ = io.Copy(ioutil.Discard, res.Body) _ = res.Body.Close() return c.doBuffered(req, reqBodyBuffer, true) } // This is already a retry, so we will try to refresh the access token before trying again. if err = c.refreshAccessToken(); err != nil { c.log.WithError(err).Warn("Cannot refresh token") err = &ErrUnauthorized{err} return } _, err = io.Copy(ioutil.Discard, res.Body) if err != nil { c.log.WithError(err).Warn("Failed to read out response body") } _ = res.Body.Close() return c.doBuffered(req, reqBodyBuffer, true) }