mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-15 14:56:42 +00:00
GODT-35: Finish all details and make tests pass
This commit is contained in:
@ -32,12 +32,6 @@ const (
|
||||
EnabledAddress
|
||||
)
|
||||
|
||||
// Address receive values.
|
||||
const (
|
||||
CannotReceive = iota
|
||||
CanReceive
|
||||
)
|
||||
|
||||
// Address HasKeys values.
|
||||
const (
|
||||
MissingKeys = iota
|
||||
@ -66,7 +60,7 @@ type Address struct {
|
||||
DomainID string
|
||||
Email string
|
||||
Send int
|
||||
Receive int
|
||||
Receive Boolean
|
||||
Status int
|
||||
Order int `json:",omitempty"`
|
||||
Type int
|
||||
@ -103,7 +97,7 @@ func (l AddressList) AllEmails() (addresses []string) {
|
||||
// ActiveEmails returns only active emails.
|
||||
func (l AddressList) ActiveEmails() (addresses []string) {
|
||||
for _, a := range l {
|
||||
if a.Receive == CanReceive {
|
||||
if a.Receive {
|
||||
addresses = append(addresses, a.Email)
|
||||
}
|
||||
}
|
||||
@ -175,8 +169,19 @@ func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err e
|
||||
return res.Addresses, nil
|
||||
}
|
||||
|
||||
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) (err error) {
|
||||
panic("TODO")
|
||||
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(&struct {
|
||||
AddressIDs []string
|
||||
}{
|
||||
AddressIDs: addressIDs,
|
||||
}).Put("/addresses/order")
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := c.UpdateUser(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Addresses returns the addresses stored in the client object itself rather than fetching from the API.
|
||||
@ -185,24 +190,22 @@ func (c *client) Addresses() AddressList {
|
||||
}
|
||||
|
||||
// unlockAddresses unlocks all keys for all addresses of current user.
|
||||
func (c *client) unlockAddress(passphrase []byte, address *Address) (err error) {
|
||||
func (c *client) unlockAddress(passphrase []byte, address *Address) error {
|
||||
if address == nil {
|
||||
return errors.New("address data is missing")
|
||||
}
|
||||
|
||||
if address.HasKeys == MissingKeys {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
var kr *crypto.KeyRing
|
||||
|
||||
if kr, err = address.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil {
|
||||
return
|
||||
kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.addrKeyRing[address.ID] = kr
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
|
||||
|
||||
@ -20,6 +20,8 @@ package pmapi
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testAddressList = AddressList{
|
||||
@ -46,39 +48,29 @@ var testAddressList = AddressList{
|
||||
},
|
||||
}
|
||||
|
||||
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(tb, checkMethodAndPath(r, "GET", "/addresses"))
|
||||
Ok(tb, isAuthReq(r, testUID, testAccessToken))
|
||||
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(tb, checkMethodAndPath(req, "GET", "/addresses"))
|
||||
r.NoError(tb, isAuthReq(req, testUID, testAccessToken))
|
||||
return "addresses/get_response.json"
|
||||
}
|
||||
|
||||
func TestAddressList(t *testing.T) {
|
||||
input := "1"
|
||||
addr := testAddressList.ByID(input)
|
||||
if addr != testAddressList[0] {
|
||||
t.Errorf("ById(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[0], addr)
|
||||
}
|
||||
r.Equal(t, testAddressList[0], addr)
|
||||
|
||||
input = "42"
|
||||
addr = testAddressList.ByID(input)
|
||||
if addr != nil {
|
||||
t.Errorf("ById expected nil for %s but have : %v\n", input, addr)
|
||||
}
|
||||
r.Nil(t, addr)
|
||||
|
||||
input = "root@protonmail.com"
|
||||
addr = testAddressList.ByEmail(input)
|
||||
if addr != testAddressList[2] {
|
||||
t.Errorf("ByEmail(%s) expected:\n%v\n but have:\n%v\n", input, testAddressList[2], addr)
|
||||
}
|
||||
r.Equal(t, testAddressList[2], addr)
|
||||
|
||||
input = "idontexist@protonmail.com"
|
||||
addr = testAddressList.ByEmail(input)
|
||||
if addr != nil {
|
||||
t.Errorf("ByEmail expected nil for %s but have : %v\n", input, addr)
|
||||
}
|
||||
r.Nil(t, addr)
|
||||
|
||||
addr = testAddressList.Main()
|
||||
if addr != testAddressList[1] {
|
||||
t.Errorf("Main() expected:\n%v\n but have:\n%v\n", testAddressList[1], addr)
|
||||
}
|
||||
r.Equal(t, testAddressList[1], addr)
|
||||
}
|
||||
|
||||
@ -23,7 +23,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
@ -138,44 +137,6 @@ func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io.
|
||||
return signAttachment(kr, att)
|
||||
}
|
||||
|
||||
func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.Reader) (err error) {
|
||||
// Create metadata fields.
|
||||
if err = w.WriteField("Filename", att.Name); err != nil {
|
||||
return
|
||||
}
|
||||
if err = w.WriteField("MessageID", att.MessageID); err != nil {
|
||||
return
|
||||
}
|
||||
if err = w.WriteField("MIMEType", att.MIMEType); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = w.WriteField("ContentID", att.ContentID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// And send attachment data.
|
||||
ff, err := w.CreateFormFile("DataPacket", "DataPacket.pgp")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = io.Copy(ff, r); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// And send attachment data.
|
||||
sigff, err := w.CreateFormFile("Signature", "Signature.pgp")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = io.Copy(sigff, sig); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
|
||||
//
|
||||
// The returned created attachment contains the new attachment ID and its size.
|
||||
|
||||
@ -28,13 +28,13 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pmmime "github.com/ProtonMail/proton-bridge/pkg/mime"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testAttachment = &Attachment{
|
||||
@ -77,65 +77,40 @@ const testCreateAttachmentBody = `{
|
||||
"Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="}
|
||||
}`
|
||||
|
||||
const testDeleteAttachmentBody = `{
|
||||
"Code": 1000
|
||||
}`
|
||||
|
||||
func TestAttachment_UnmarshalJSON(t *testing.T) {
|
||||
att := new(Attachment)
|
||||
if err := json.Unmarshal([]byte(testAttachmentJSON), att); err != nil {
|
||||
t.Fatal("Expected no error while unmarshaling JSON, got:", err)
|
||||
}
|
||||
err := json.Unmarshal([]byte(testAttachmentJSON), att)
|
||||
r.NoError(t, err)
|
||||
|
||||
att.MessageID = testAttachment.MessageID // This isn't in the JSON object
|
||||
|
||||
if !reflect.DeepEqual(testAttachment, att) {
|
||||
t.Errorf("Invalid attachment: expected %+v but got %+v", testAttachment, att)
|
||||
}
|
||||
r.Equal(t, testAttachment, att)
|
||||
}
|
||||
|
||||
func TestClient_CreateAttachment(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/attachments"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/attachments"))
|
||||
|
||||
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing request content type, got:", err)
|
||||
}
|
||||
if contentType != "multipart/form-data" {
|
||||
t.Errorf("Invalid request content type: expected %v but got %v", "multipart/form-data", contentType)
|
||||
}
|
||||
contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type"))
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, "multipart/form-data", contentType)
|
||||
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
mr := multipart.NewReader(req.Body, params["boundary"])
|
||||
form, err := mr.ReadForm(10 * 1024)
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing request form, got:", err)
|
||||
}
|
||||
defer Ok(t, form.RemoveAll())
|
||||
r.NoError(t, err)
|
||||
defer r.NoError(t, form.RemoveAll())
|
||||
|
||||
if form.Value["Filename"][0] != testAttachment.Name {
|
||||
t.Errorf("Invalid attachment filename: expected %v but got %v", testAttachment.Name, form.Value["Filename"][0])
|
||||
}
|
||||
if form.Value["MessageID"][0] != testAttachment.MessageID {
|
||||
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MessageID, form.Value["MessageID"][0])
|
||||
}
|
||||
if form.Value["MIMEType"][0] != testAttachment.MIMEType {
|
||||
t.Errorf("Invalid attachment message id: expected %v but got %v", testAttachment.MIMEType, form.Value["MIMEType"][0])
|
||||
}
|
||||
r.Equal(t, testAttachment.Name, form.Value["Filename"][0])
|
||||
r.Equal(t, testAttachment.MessageID, form.Value["MessageID"][0])
|
||||
r.Equal(t, testAttachment.MIMEType, form.Value["MIMEType"][0])
|
||||
|
||||
dataFile, err := form.File["DataPacket"][0].Open()
|
||||
if err != nil {
|
||||
t.Error("Expected no error while opening packets file, got:", err)
|
||||
}
|
||||
defer Ok(t, dataFile.Close())
|
||||
r.NoError(t, err)
|
||||
defer r.NoError(t, dataFile.Close())
|
||||
|
||||
b, err := ioutil.ReadAll(dataFile)
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading packets file, got:", err)
|
||||
}
|
||||
if string(b) != testAttachmentCleartext {
|
||||
t.Errorf("Invalid attachment packets: expected %v but got %v", testAttachment.KeyPackets, string(b))
|
||||
}
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testAttachmentCleartext, string(b))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@ -143,50 +118,39 @@ func TestClient_CreateAttachment(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
r := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
|
||||
created, err := c.CreateAttachment(context.TODO(), testAttachment, r, strings.NewReader(""))
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while creating attachment, got:", err)
|
||||
}
|
||||
reader := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
|
||||
created, err := c.CreateAttachment(context.Background(), testAttachment, reader, strings.NewReader(""))
|
||||
r.NoError(t, err)
|
||||
|
||||
if created.ID != testAttachment.ID {
|
||||
t.Errorf("Invalid attachment id: expected %v but got %v", testAttachment.ID, created.ID)
|
||||
}
|
||||
r.Equal(t, testAttachment.ID, created.ID)
|
||||
}
|
||||
|
||||
func TestClient_GetAttachment(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/mail/v4/attachments/"+testAttachment.ID))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/mail/v4/attachments/"+testAttachment.ID))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testAttachmentCleartext)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
r, err := c.GetAttachment(context.TODO(), testAttachment.ID)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting attachment, got:", err)
|
||||
}
|
||||
defer r.Close() //nolint[errcheck]
|
||||
att, err := c.GetAttachment(context.Background(), testAttachment.ID)
|
||||
r.NoError(t, err)
|
||||
defer att.Close() //nolint[errcheck]
|
||||
|
||||
// In reality, r contains encrypted data
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while reading attachment, got:", err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(att)
|
||||
r.NoError(t, err)
|
||||
|
||||
if string(b) != testAttachmentCleartext {
|
||||
t.Errorf("Invalid attachment data: expected %q but got %q", testAttachmentCleartext, string(b))
|
||||
}
|
||||
r.Equal(t, testAttachmentCleartext, string(b))
|
||||
}
|
||||
|
||||
func TestAttachment_Encrypt(t *testing.T) {
|
||||
data := bytes.NewBufferString(testAttachmentCleartext)
|
||||
r, err := testAttachment.Encrypt(testPublicKeyRing, data)
|
||||
assert.Nil(t, err)
|
||||
a.Nil(t, err)
|
||||
b, err := ioutil.ReadAll(r)
|
||||
assert.Nil(t, err)
|
||||
a.Nil(t, err)
|
||||
|
||||
// Result is always different, so the best way is to test it by decrypting again.
|
||||
// Another test for decrypting will help us to be sure it's working.
|
||||
@ -202,8 +166,8 @@ func TestAttachment_Decrypt(t *testing.T) {
|
||||
|
||||
func decryptAndCheck(t *testing.T, data io.Reader) {
|
||||
r, err := testAttachment.Decrypt(data, testPrivateKeyRing)
|
||||
assert.Nil(t, err)
|
||||
a.Nil(t, err)
|
||||
b, err := ioutil.ReadAll(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testAttachmentCleartext, string(b))
|
||||
a.Nil(t, err)
|
||||
a.Equal(t, testAttachmentCleartext, string(b))
|
||||
}
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
@ -6,15 +23,117 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
func (c *client) Auth2FA(ctx context.Context, req Auth2FAReq) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(req).Post("/auth/2fa")
|
||||
type AuthModulus struct {
|
||||
Modulus string
|
||||
ModulusID string
|
||||
}
|
||||
|
||||
type GetAuthInfoReq struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
type AuthInfo struct {
|
||||
Version int
|
||||
Modulus string
|
||||
ServerEphemeral string
|
||||
Salt string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type TwoFAInfo struct {
|
||||
Enabled TwoFAStatus
|
||||
}
|
||||
|
||||
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
|
||||
return twoFAInfo.Enabled > 0
|
||||
}
|
||||
|
||||
type TwoFAStatus int
|
||||
|
||||
const (
|
||||
TwoFADisabled TwoFAStatus = iota
|
||||
TOTPEnabled
|
||||
U2FEnabled
|
||||
TOTPAndU2FEnabled
|
||||
)
|
||||
|
||||
type PasswordMode int
|
||||
|
||||
const (
|
||||
OnePasswordMode PasswordMode = iota + 1
|
||||
TwoPasswordMode
|
||||
)
|
||||
|
||||
type AuthReq struct {
|
||||
Username string
|
||||
ClientProof string
|
||||
ClientEphemeral string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type AuthRefresh struct {
|
||||
UID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresIn int64
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
AuthRefresh
|
||||
|
||||
UserID string
|
||||
ServerProof string
|
||||
PasswordMode PasswordMode
|
||||
TwoFA *TwoFAInfo `json:"2FA,omitempty"`
|
||||
}
|
||||
|
||||
func (a Auth) HasTwoFactor() bool {
|
||||
if a.TwoFA == nil {
|
||||
return false
|
||||
}
|
||||
return a.TwoFA.hasTwoFactor()
|
||||
}
|
||||
|
||||
func (a Auth) HasMailboxPassword() bool {
|
||||
return a.PasswordMode == TwoPasswordMode
|
||||
}
|
||||
|
||||
type auth2FAReq struct {
|
||||
TwoFactorCode string
|
||||
}
|
||||
|
||||
type authRefreshReq struct {
|
||||
UID string
|
||||
RefreshToken string
|
||||
ResponseType string
|
||||
GrantType string
|
||||
RedirectURI string
|
||||
State string
|
||||
}
|
||||
|
||||
func (c *client) Auth2FA(ctx context.Context, twoFactorCode string) error {
|
||||
// 2FA is called during login procedure during which refresh token should
|
||||
// be valid, therefore, no refresh is needed if there is an error.
|
||||
ctx = ContextWithoutAuthRefresh(ctx)
|
||||
|
||||
if res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(auth2FAReq{TwoFactorCode: twoFactorCode}).Post("/auth/2fa")
|
||||
}); err != nil {
|
||||
if res != nil {
|
||||
switch res.StatusCode() {
|
||||
case http.StatusUnauthorized:
|
||||
return ErrBad2FACode
|
||||
case http.StatusUnprocessableEntity:
|
||||
return ErrBad2FACodeTryAgain
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@ -29,9 +148,7 @@ func (c *client) AuthDelete(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
|
||||
|
||||
// FIXME(conman): should we perhaps signal via AuthHandler that the auth was deleted?
|
||||
|
||||
c.sendAuthRefresh(nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -54,7 +171,7 @@ func (c *client) AuthSalt(ctx context.Context) (string, error) {
|
||||
return "", errors.New("no matching salt found")
|
||||
}
|
||||
|
||||
func (c *client) AddAuthHandler(handler AuthHandler) {
|
||||
func (c *client) AddAuthRefreshHandler(handler AuthRefreshHandler) {
|
||||
c.authHandlers = append(c.authHandlers, handler)
|
||||
}
|
||||
|
||||
@ -62,23 +179,35 @@ func (c *client) authRefresh(ctx context.Context) error {
|
||||
c.authLocker.Lock()
|
||||
defer c.authLocker.Unlock()
|
||||
|
||||
auth, err := c.req.authRefresh(ctx, c.uid, c.ref)
|
||||
if c.ref == "" {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
|
||||
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
|
||||
if err != nil {
|
||||
if err != ErrNoConnection {
|
||||
c.sendAuthRefresh(nil)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.acc = auth.AccessToken
|
||||
c.ref = auth.RefreshToken
|
||||
c.exp = expiresIn(auth.ExpiresIn)
|
||||
|
||||
for _, handler := range c.authHandlers {
|
||||
if err := handler(auth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.sendAuthRefresh(auth)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) sendAuthRefresh(auth *AuthRefresh) {
|
||||
for _, handler := range c.authHandlers {
|
||||
go handler(auth)
|
||||
}
|
||||
if auth == nil {
|
||||
c.authHandlers = []AuthRefreshHandler{}
|
||||
}
|
||||
}
|
||||
|
||||
func randomString(length int) string {
|
||||
noise := make([]byte, length)
|
||||
|
||||
|
||||
@ -1,22 +1,40 @@
|
||||
package pmapi_test
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
ExpiresIn: 100,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@ -24,7 +42,7 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
@ -35,28 +53,28 @@ func TestAutomaticAuthRefresh(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// Make a request with an access token that already expired one second ago.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
a.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
|
||||
cl := c.(*client) //nolint[forcetypeassert] we want to panic here
|
||||
a.Equal(t, wantAuthRefresh.AccessToken, cl.acc)
|
||||
a.Equal(t, wantAuthRefresh.RefreshToken, cl.ref)
|
||||
a.WithinDuration(t, expiresIn(100), cl.exp, time.Second)
|
||||
}
|
||||
|
||||
func Test401AuthRefresh(t *testing.T) {
|
||||
var wantAuth = &pmapi.Auth{
|
||||
var wantAuthRefresh = &AuthRefresh{
|
||||
UID: "testUID",
|
||||
AccessToken: "testAcc",
|
||||
RefreshToken: "testRef",
|
||||
@ -67,7 +85,7 @@ func Test401AuthRefresh(t *testing.T) {
|
||||
mux.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(wantAuthRefresh); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
@ -86,24 +104,21 @@ func Test401AuthRefresh(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
var gotAuth *pmapi.Auth
|
||||
var gotAuthRefresh *AuthRefresh
|
||||
|
||||
// Create a new client.
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// Register an auth handler.
|
||||
c.AddAuthHandler(func(auth *pmapi.Auth) error { gotAuth = auth; return nil })
|
||||
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
|
||||
|
||||
// The first request will fail with 401, triggering a refresh and retry.
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The auth callback should have been called.
|
||||
if *gotAuth != *wantAuth {
|
||||
t.Fatal("got unexpected auth", gotAuth)
|
||||
}
|
||||
r.Equal(t, *wantAuthRefresh, *gotAuthRefresh)
|
||||
}
|
||||
|
||||
func Test401RevokedAuth(t *testing.T) {
|
||||
@ -119,17 +134,57 @@ func Test401RevokedAuth(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
|
||||
c := pmapi.New(pmapi.Config{HostURL: ts.URL}).
|
||||
c := New(Config{HostURL: ts.URL}).
|
||||
NewClient("uid", "acc", "ref", time.Now().Add(time.Hour))
|
||||
|
||||
// The request will fail with 401, triggering a refresh.
|
||||
// The retry will also fail with 401, returning an error.
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrUnauthorized) {
|
||||
t.Fatal("expected error to be ErrUnauthorized, instead got", err)
|
||||
}
|
||||
r.EqualError(t, err, ErrUnauthorized.Error())
|
||||
}
|
||||
|
||||
func TestAuth2FA(t *testing.T) {
|
||||
twoFACode := "code"
|
||||
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
|
||||
var twoFAreq auth2FAReq
|
||||
r.NoError(t, json.NewDecoder(req.Body).Decode(&twoFAreq))
|
||||
r.Equal(t, twoFAreq.TwoFactorCode, twoFACode)
|
||||
|
||||
return "/auth/2fa/post_response.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), twoFACode)
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Fail(t *testing.T) {
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_401_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACode, err)
|
||||
}
|
||||
|
||||
func TestAuth2FA_Retry(t *testing.T) {
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/auth/2fa"))
|
||||
return "/auth/2fa/post_422_bad_password.json"
|
||||
},
|
||||
)
|
||||
defer finish()
|
||||
|
||||
err := c.Auth2FA(context.Background(), "code")
|
||||
r.Equal(t, ErrBad2FACodeTryAgain, err)
|
||||
}
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
package pmapi
|
||||
|
||||
type AuthModulus struct {
|
||||
Modulus string
|
||||
ModulusID string
|
||||
}
|
||||
|
||||
type GetAuthInfoReq struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
type AuthInfo struct {
|
||||
Version int
|
||||
Modulus string
|
||||
ServerEphemeral string
|
||||
Salt string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type TwoFAInfo struct {
|
||||
Enabled TwoFAStatus
|
||||
}
|
||||
|
||||
type TwoFAStatus int
|
||||
|
||||
const (
|
||||
TwoFADisabled TwoFAStatus = iota
|
||||
TOTPEnabled
|
||||
// TODO: Support UTF
|
||||
)
|
||||
|
||||
type PasswordMode int
|
||||
|
||||
const (
|
||||
OnePasswordMode PasswordMode = iota + 1
|
||||
TwoPasswordMode
|
||||
)
|
||||
|
||||
type AuthReq struct {
|
||||
Username string
|
||||
ClientProof string
|
||||
ClientEphemeral string
|
||||
SRPSession string
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
UserID string
|
||||
|
||||
UID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresIn int64
|
||||
|
||||
Scope string
|
||||
ServerProof string
|
||||
|
||||
TwoFA TwoFAInfo `json:"2FA"`
|
||||
PasswordMode PasswordMode
|
||||
}
|
||||
|
||||
type Auth2FAReq struct {
|
||||
TwoFactorCode string
|
||||
}
|
||||
|
||||
type AuthRefreshReq struct {
|
||||
UID string
|
||||
RefreshToken string
|
||||
ResponseType string
|
||||
GrantType string
|
||||
RedirectURI string
|
||||
State string
|
||||
}
|
||||
41
pkg/pmapi/boolean.go
Normal file
41
pkg/pmapi/boolean.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Boolean bool
|
||||
|
||||
func (boolean *Boolean) UnmarshalJSON(b []byte) error {
|
||||
var value int
|
||||
err := json.Unmarshal(b, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*boolean = Boolean(value == 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (boolean Boolean) MarshalJSON() ([]byte, error) {
|
||||
var value int
|
||||
if boolean {
|
||||
value = 1
|
||||
}
|
||||
return json.Marshal(value)
|
||||
}
|
||||
@ -25,15 +25,14 @@ import (
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// client is a client of the protonmail API. It implements the Client interface.
|
||||
type client struct {
|
||||
req requester
|
||||
manager clientManager
|
||||
|
||||
uid, acc, ref string
|
||||
authHandlers []AuthHandler
|
||||
authHandlers []AuthRefreshHandler
|
||||
authLocker sync.RWMutex
|
||||
|
||||
user *User
|
||||
@ -45,9 +44,9 @@ type client struct {
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
func newClient(req requester, uid string) *client {
|
||||
func newClient(manager clientManager, uid string) *client {
|
||||
return &client{
|
||||
req: req,
|
||||
manager: manager,
|
||||
uid: uid,
|
||||
addrKeyRing: make(map[string]*crypto.KeyRing),
|
||||
keyRingLock: &sync.RWMutex{},
|
||||
@ -63,7 +62,7 @@ func (c *client) withAuth(acc, ref string, exp time.Time) *client {
|
||||
}
|
||||
|
||||
func (c *client) r(ctx context.Context) (*resty.Request, error) {
|
||||
r := c.req.r(ctx)
|
||||
r := c.manager.r(ctx)
|
||||
|
||||
if c.uid != "" {
|
||||
r.SetHeader("x-pm-uid", c.uid)
|
||||
@ -91,30 +90,23 @@ func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Respons
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := wrapRestyError(fn(r))
|
||||
res, err := wrapNoConnection(fn(r))
|
||||
if err != nil {
|
||||
if res.StatusCode() != http.StatusUnauthorized {
|
||||
return nil, err
|
||||
// Return also response so caller has more options to decide what to do.
|
||||
return res, err
|
||||
}
|
||||
|
||||
if err := c.authRefresh(ctx); err != nil {
|
||||
return nil, err
|
||||
if !isAuthRefreshDisabled(ctx) {
|
||||
if err := c.authRefresh(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapNoConnection(fn(r))
|
||||
}
|
||||
|
||||
return wrapRestyError(fn(r))
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func wrapRestyError(res *resty.Response, err error) (*resty.Response, error) {
|
||||
if err, ok := err.(*resty.ResponseError); ok {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.RawResponse != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, errors.Wrap(ErrNoConnection, err.Error())
|
||||
}
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
@ -12,8 +29,6 @@ func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
// FIXME(conman): Should this be done as part of NewClient somehow?
|
||||
|
||||
return c.unlock(ctx, passphrase)
|
||||
}
|
||||
|
||||
@ -65,6 +80,15 @@ func (c *client) clearKeys() {
|
||||
}
|
||||
|
||||
func (c *client) IsUnlocked() bool {
|
||||
// FIXME(conman): Better way to check? we don't currently check address keys.
|
||||
return c.userKeyRing != nil
|
||||
if c.userKeyRing == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, address := range c.addresses {
|
||||
if address.HasKeys != MissingKeys && c.addrKeyRing[address.ID] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@ -27,10 +27,10 @@ import (
|
||||
|
||||
// Client defines the interface of a PMAPI client.
|
||||
type Client interface {
|
||||
Auth2FA(context.Context, Auth2FAReq) error
|
||||
Auth2FA(context.Context, string) error
|
||||
AuthSalt(ctx context.Context) (string, error)
|
||||
AuthDelete(context.Context) error
|
||||
AddAuthHandler(AuthHandler)
|
||||
AddAuthRefreshHandler(AuthRefreshHandler)
|
||||
|
||||
CurrentUser(ctx context.Context) (*User, error)
|
||||
UpdateUser(ctx context.Context) (*User, error)
|
||||
@ -75,9 +75,9 @@ type Client interface {
|
||||
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
|
||||
}
|
||||
|
||||
type AuthHandler func(*Auth) error
|
||||
type AuthRefreshHandler func(*AuthRefresh)
|
||||
|
||||
type requester interface {
|
||||
type clientManager interface {
|
||||
r(context.Context) *resty.Request
|
||||
authRefresh(context.Context, string, string) (*Auth, error)
|
||||
authRefresh(context.Context, string, string) (*AuthRefresh, error)
|
||||
}
|
||||
|
||||
@ -1,11 +1,72 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
HostURL string
|
||||
// HostURL is the base URL of API.
|
||||
HostURL string
|
||||
|
||||
// AppVersion sets version to headers of each request.
|
||||
AppVersion string
|
||||
|
||||
// UserAgent sets user agent to headers of each request.
|
||||
// Used only if GetUserAgent is not set.
|
||||
UserAgent string
|
||||
|
||||
// GetUserAgent is dynamic version of UserAgent.
|
||||
// Overrides UserAgent.
|
||||
GetUserAgent func() string
|
||||
|
||||
// UpgradeApplicationHandler is used to notify when there is a force upgrade.
|
||||
UpgradeApplicationHandler func()
|
||||
|
||||
// TLSIssueHandler is used to notify when there is a TLS issue.
|
||||
TLSIssueHandler func()
|
||||
}
|
||||
|
||||
var DefaultConfig = Config{
|
||||
HostURL: "https://api.protonmail.ch",
|
||||
AppVersion: "Other",
|
||||
func NewConfig(appVersionName, appVersion string) Config {
|
||||
return Config{
|
||||
HostURL: getRootURL(),
|
||||
AppVersion: getAPIOS() + strings.Title(appVersionName) + "_" + appVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) getUserAgent() string {
|
||||
if c.GetUserAgent == nil {
|
||||
return c.UserAgent
|
||||
}
|
||||
return c.GetUserAgent()
|
||||
}
|
||||
|
||||
// getAPIOS returns actual operating system.
|
||||
func getAPIOS() string {
|
||||
switch os := runtime.GOOS; os {
|
||||
case "darwin": // nolint: goconst
|
||||
return "macOS"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
}
|
||||
return "Linux"
|
||||
}
|
||||
|
||||
35
pkg/pmapi/config_default.go
Normal file
35
pkg/pmapi/config_default.go
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build !build_qa
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func getRootURL() string {
|
||||
return "https://api.protonmail.ch"
|
||||
}
|
||||
|
||||
func newProxyDialerAndTransport(cfg Config) (*ProxyTLSDialer, http.RoundTripper) {
|
||||
basicDialer := NewBasicTLSDialer(cfg)
|
||||
pinningDialer := NewPinningTLSDialer(cfg, basicDialer)
|
||||
proxyDialer := NewProxyTLSDialer(cfg, pinningDialer)
|
||||
return proxyDialer, CreateTransportWithDialer(proxyDialer)
|
||||
}
|
||||
48
pkg/pmapi/config_qa.go
Normal file
48
pkg/pmapi/config_qa.go
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build build_qa
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getRootURL() string {
|
||||
// This config allows to dynamically change ROOT URL.
|
||||
url := os.Getenv("PMAPI_ROOT_URL")
|
||||
if strings.HasPrefix(url, "http") {
|
||||
return url
|
||||
}
|
||||
if url != "" {
|
||||
return "https://" + url
|
||||
}
|
||||
return "https://api.protonmail.ch"
|
||||
}
|
||||
|
||||
func newProxyDialerAndTransport(cfg Config) (*ProxyTLSDialer, http.RoundTripper) {
|
||||
transport := CreateTransportWithDialer(NewBasicTLSDialer(cfg))
|
||||
|
||||
// TLS certificate of testing environment might be self-signed.
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
return nil, transport
|
||||
}
|
||||
@ -129,11 +129,14 @@ func (c *client) GetContactEmailByEmail(ctx context.Context, email string, page
|
||||
}
|
||||
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetQueryParams(map[string]string{
|
||||
"Email": email,
|
||||
"Page": strconv.Itoa(page),
|
||||
"PageSize": strconv.Itoa(pageSize),
|
||||
}).SetResult(&res).Get("/contacts/v4")
|
||||
r = r.SetQueryParams(map[string]string{
|
||||
"Email": email,
|
||||
"Page": strconv.Itoa(page),
|
||||
})
|
||||
if pageSize != 0 {
|
||||
r.SetQueryParam("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
return r.SetResult(&res).Get("/contacts/v4")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -106,19 +106,16 @@ var testGetContactByID = Contact{
|
||||
}
|
||||
|
||||
func TestContact_GetContactById(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/contacts/v4/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/contacts/v4/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testGetContactByIDResponseBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
contact, err := c.GetContactByID(context.TODO(), "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while getting contacts, got:", err)
|
||||
}
|
||||
contact, err := c.GetContactByID(context.Background(), "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")
|
||||
r.NoError(t, err)
|
||||
|
||||
if !reflect.DeepEqual(contact, testGetContactByID) {
|
||||
t.Fatalf("Invalid got contact: expected %+v, got %+v", testGetContactByID, contact)
|
||||
@ -160,24 +157,24 @@ var testCardsCleartext = []Card{
|
||||
}
|
||||
|
||||
func TestClient_Encrypt(t *testing.T) {
|
||||
c := newClient(newManager(DefaultConfig), "")
|
||||
c := newClient(newManager(Config{}), "")
|
||||
c.userKeyRing = testPrivateKeyRing
|
||||
|
||||
cardEncrypted, err := c.EncryptAndSignCards(testCardsCleartext)
|
||||
assert.Nil(t, err)
|
||||
r.Nil(t, err)
|
||||
|
||||
// Result is always different, so the best way is to test it by decrypting again.
|
||||
// Another test for decrypting will help us to be sure it's working.
|
||||
cardCleartext, err := c.DecryptAndVerifyCards(cardEncrypted)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data)
|
||||
r.Nil(t, err)
|
||||
r.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data)
|
||||
}
|
||||
|
||||
func TestClient_Decrypt(t *testing.T) {
|
||||
c := newClient(newManager(DefaultConfig), "")
|
||||
c := newClient(newManager(Config{}), "")
|
||||
c.userKeyRing = testPrivateKeyRing
|
||||
|
||||
cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data)
|
||||
r.Nil(t, err)
|
||||
r.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data)
|
||||
}
|
||||
|
||||
54
pkg/pmapi/context.go
Normal file
54
pkg/pmapi/context.go
Normal file
@ -0,0 +1,54 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type pmapiContextKey string
|
||||
|
||||
const (
|
||||
retryContextKey = pmapiContextKey("retry")
|
||||
retryDisabled = "disabled"
|
||||
|
||||
authRefreshContextKey = pmapiContextKey("authRefresh")
|
||||
authRefreshDisabled = "disabled"
|
||||
)
|
||||
|
||||
func ContextWithoutRetry(parent context.Context) context.Context {
|
||||
return context.WithValue(parent, retryContextKey, retryDisabled)
|
||||
}
|
||||
|
||||
func isRetryDisabled(ctx context.Context) bool {
|
||||
if v := ctx.Value(retryContextKey); v != nil {
|
||||
return v == retryDisabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ContextWithoutAuthRefresh(parent context.Context) context.Context {
|
||||
return context.WithValue(parent, authRefreshContextKey, authRefreshDisabled)
|
||||
}
|
||||
|
||||
func isAuthRefreshDisabled(ctx context.Context) bool {
|
||||
if v := ctx.Value(authRefreshContextKey); v != nil {
|
||||
return v == authRefreshDisabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import "github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
@ -8,9 +25,6 @@ var testIdentity = &crypto.Identity{
|
||||
}
|
||||
|
||||
const (
|
||||
testUsername = "jason"
|
||||
testAPIPassword = "apple"
|
||||
|
||||
testUID = "729ad6012421d67ad26950dc898bebe3a6e3caa2" //nolint[gosec]
|
||||
testAccessToken = "de0423049b44243afeec7d9c1d99be7b46da1e8a" //nolint[gosec]
|
||||
testAccessTokenOld = "feb3159ac63fb05119bcf4480d939278aa746926" //nolint[gosec]
|
||||
|
||||
76
pkg/pmapi/dialer_basic.go
Normal file
76
pkg/pmapi/dialer_basic.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TLSDialer interface {
|
||||
DialTLS(network, address string) (conn net.Conn, err error)
|
||||
}
|
||||
|
||||
// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
|
||||
func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
|
||||
return &http.Transport{
|
||||
DialTLS: dialer.DialTLS,
|
||||
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 5 * time.Minute,
|
||||
ExpectContinueTimeout: 500 * time.Millisecond,
|
||||
|
||||
// GODT-126: this was initially 10s but logs from users showed a significant number
|
||||
// were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
|
||||
// Bumping to 30s for now to avoid this problem.
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
|
||||
// If we allow up to 30 seconds for response headers, it is reasonable to allow up
|
||||
// to 30 seconds for the TLS handshake to take place.
|
||||
TLSHandshakeTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// BasicTLSDialer implements TLSDialer.
|
||||
type BasicTLSDialer struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
// NewBasicTLSDialer returns a new BasicTLSDialer.
|
||||
func NewBasicTLSDialer(cfg Config) *BasicTLSDialer {
|
||||
return &BasicTLSDialer{
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// DialTLS returns a connection to the given address using the given network.
|
||||
func (d *BasicTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
|
||||
dialer := &net.Dialer{Timeout: 30 * time.Second} // Alternative Routes spec says this should be a 30s timeout.
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
// If we are not dialing the standard API then we should skip cert verification checks.
|
||||
if address != d.cfg.HostURL {
|
||||
tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
|
||||
}
|
||||
|
||||
return tls.DialWithDialer(dialer, network, address, tlsConfig)
|
||||
}
|
||||
110
pkg/pmapi/dialer_pinning.go
Normal file
110
pkg/pmapi/dialer_pinning.go
Normal file
@ -0,0 +1,110 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TrustedAPIPins contains trusted public keys of the protonmail API and proxies.
|
||||
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;).
|
||||
var TrustedAPIPins = []string{ // nolint[gochecknoglobals]
|
||||
// api.protonmail.ch
|
||||
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
|
||||
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot backup
|
||||
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold backup
|
||||
|
||||
// protonmail.com
|
||||
`pin-sha256="8joiNBdqaYiQpKskgtkJsqRxF7zN0C0aqfi8DacknnI="`, // current
|
||||
`pin-sha256="JMI8yrbc6jB1FYGyyWRLFTmDNgIszrNEMGlgy972e7w="`, // hot backup
|
||||
`pin-sha256="Iu44zU84EOCZ9vx/vz67/MRVrxF1IO4i4NIa8ETwiIY="`, // cold backup
|
||||
|
||||
// proxies
|
||||
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // main
|
||||
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // backup 1
|
||||
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // backup 2
|
||||
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // backup 3
|
||||
}
|
||||
|
||||
// TLSReportURI is the address where TLS reports should be sent.
|
||||
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
|
||||
|
||||
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
|
||||
// to report errors if the fingerprint check fails.
|
||||
type PinningTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
|
||||
// pinChecker is used to check TLS keys of connections.
|
||||
pinChecker *pinChecker
|
||||
|
||||
reporter *tlsReporter
|
||||
|
||||
// tlsIssueNotifier is used to notify something when there is a TLS issue.
|
||||
tlsIssueNotifier func()
|
||||
|
||||
// A logger for logging messages.
|
||||
log logrus.FieldLogger
|
||||
}
|
||||
|
||||
// NewPinningTLSDialer constructs a new dialer which only returns tcp connections to servers
|
||||
// which present known certificates.
|
||||
// If enabled, it reports any invalid certificates it finds.
|
||||
func NewPinningTLSDialer(cfg Config, dialer TLSDialer) *PinningTLSDialer {
|
||||
return &PinningTLSDialer{
|
||||
dialer: dialer,
|
||||
pinChecker: newPinChecker(TrustedAPIPins),
|
||||
reporter: newTLSReporter(cfg, TrustedAPIPins),
|
||||
tlsIssueNotifier: cfg.TLSIssueHandler,
|
||||
log: logrus.WithField("pkg", "pmapi/tls-pinning"),
|
||||
}
|
||||
}
|
||||
|
||||
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
|
||||
func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) {
|
||||
conn, err := p.dialer.DialTLS(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.pinChecker.checkCertificate(conn); err != nil {
|
||||
if p.tlsIssueNotifier != nil {
|
||||
go p.tlsIssueNotifier()
|
||||
}
|
||||
|
||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
|
||||
p.reporter.reportCertIssue(
|
||||
TLSReportURI,
|
||||
host,
|
||||
port,
|
||||
tlsConn.ConnectionState(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
68
pkg/pmapi/dialer_pinning_checker.go
Normal file
68
pkg/pmapi/dialer_pinning_checker.go
Normal file
@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ErrTLSMismatch indicates that no TLS fingerprint match could be found.
|
||||
var ErrTLSMismatch = errors.New("no TLS fingerprint match found")
|
||||
|
||||
type pinChecker struct {
|
||||
trustedPins []string
|
||||
}
|
||||
|
||||
func newPinChecker(trustedPins []string) *pinChecker {
|
||||
return &pinChecker{
|
||||
trustedPins: trustedPins,
|
||||
}
|
||||
}
|
||||
|
||||
// checkCertificate returns whether the connection presents a known TLS certificate.
|
||||
func (p *pinChecker) checkCertificate(conn net.Conn) error {
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return errors.New("connection is not a TLS connection")
|
||||
}
|
||||
|
||||
connState := tlsConn.ConnectionState()
|
||||
|
||||
for _, peerCert := range connState.PeerCertificates {
|
||||
fingerprint := certFingerprint(peerCert)
|
||||
|
||||
for _, pin := range p.trustedPins {
|
||||
if pin == fingerprint {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ErrTLSMismatch
|
||||
}
|
||||
|
||||
func certFingerprint(cert *x509.Certificate) string {
|
||||
hash := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
|
||||
return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
|
||||
}
|
||||
144
pkg/pmapi/dialer_pinning_report.go
Normal file
144
pkg/pmapi/dialer_pinning_report.go
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
|
||||
// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI.
|
||||
type tlsReport struct {
|
||||
// DateTime of observed pin validation in time.RFC3339 format.
|
||||
DateTime string `json:"date-time"`
|
||||
|
||||
// Hostname to which the UA made original request that failed pin validation.
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// Port to which the UA made original request that failed pin validation.
|
||||
Port int `json:"port"`
|
||||
|
||||
// EffectiveExpirationDate for noted pins in time.RFC3339 format.
|
||||
EffectiveExpirationDate string `json:"effective-expiration-date"`
|
||||
|
||||
// IncludeSubdomains indicates whether or not the UA has noted the
|
||||
// includeSubDomains directive for the Known Pinned Host.
|
||||
IncludeSubdomains bool `json:"include-subdomains"`
|
||||
|
||||
// NotedHostname indicates the hostname that the UA noted when it noted
|
||||
// the Known Pinned Host. This field allows operators to understand why
|
||||
// Pin Validation was performed for, e.g., foo.example.com when the
|
||||
// noted Known Pinned Host was example.com with includeSubDomains set.
|
||||
NotedHostname string `json:"noted-hostname"`
|
||||
|
||||
// ServedCertificateChain is the certificate chain, as served by
|
||||
// the Known Pinned Host during TLS session setup. It is provided as an
|
||||
// array of strings; each string pem1, ... pemN is the Privacy-Enhanced
|
||||
// Mail (PEM) representation of each X.509 certificate as described in
|
||||
// [RFC7468].
|
||||
ServedCertificateChain []string `json:"served-certificate-chain"`
|
||||
|
||||
// ValidatedCertificateChain is the certificate chain, as
|
||||
// constructed by the UA during certificate chain verification. (This
|
||||
// may differ from the served-certificate-chain.) It is provided as an
|
||||
// array of strings; each string pem1, ... pemN is the PEM
|
||||
// representation of each X.509 certificate as described in [RFC7468].
|
||||
// UAs that build certificate chains in more than one way during the
|
||||
// validation process SHOULD send the last chain built. In this way,
|
||||
// they can avoid keeping too much state during the validation process.
|
||||
ValidatedCertificateChain []string `json:"validated-certificate-chain"`
|
||||
|
||||
// The known-pins are the Pins that the UA has noted for the Known
|
||||
// Pinned Host. They are provided as an array of strings with the
|
||||
// syntax: known-pin = token "=" quoted-string
|
||||
// e.g.:
|
||||
// ```
|
||||
// "known-pins": [
|
||||
// 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="',
|
||||
// "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""
|
||||
// ]
|
||||
// ```
|
||||
KnownPins []string `json:"known-pins"`
|
||||
|
||||
// AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit.
|
||||
AppVersion string `json:"app-version"`
|
||||
}
|
||||
|
||||
// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys.
|
||||
// Temporal things (current date/time) are not set yet -- they are set when sendReport is called.
|
||||
func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) {
|
||||
// If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway.
|
||||
intPort, _ := strconv.Atoi(port)
|
||||
|
||||
report = tlsReport{
|
||||
Hostname: host,
|
||||
Port: intPort,
|
||||
NotedHostname: server,
|
||||
ServedCertificateChain: certChain,
|
||||
KnownPins: knownPins,
|
||||
AppVersion: appVersion,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// sendReport posts the given TLS report to the standard TLS Report URI.
|
||||
func (r tlsReport) sendReport(cfg Config, uri string) {
|
||||
now := time.Now()
|
||||
r.DateTime = now.Format(time.RFC3339)
|
||||
r.EffectiveExpirationDate = now.Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339)
|
||||
|
||||
b, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("Failed to marshal TLS report")
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", uri, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("Failed to create http request")
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", cfg.getUserAgent())
|
||||
req.Header.Set("x-pm-appversion", r.AppVersion)
|
||||
|
||||
logrus.WithField("request", req).Warn("Reporting TLS mismatch")
|
||||
res, err := (&http.Client{Transport: CreateTransportWithDialer(NewBasicTLSDialer(cfg))}).Do(req)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("Failed to report TLS mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
logrus.WithField("response", res).Error("Reported TLS mismatch")
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
logrus.WithField("status", http.StatusOK).Error("StatusCode was not OK")
|
||||
}
|
||||
|
||||
_, _ = ioutil.ReadAll(res.Body)
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
107
pkg/pmapi/dialer_pinning_reporter.go
Normal file
107
pkg/pmapi/dialer_pinning_reporter.go
Normal file
@ -0,0 +1,107 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type sentReport struct {
|
||||
r tlsReport
|
||||
t time.Time
|
||||
}
|
||||
|
||||
type tlsReporter struct {
|
||||
cfg Config
|
||||
trustedPins []string
|
||||
sentReports []sentReport
|
||||
}
|
||||
|
||||
func newTLSReporter(cfg Config, trustedPins []string) *tlsReporter {
|
||||
return &tlsReporter{
|
||||
cfg: cfg,
|
||||
trustedPins: trustedPins,
|
||||
}
|
||||
}
|
||||
|
||||
// reportCertIssue reports a TLS key mismatch.
|
||||
func (r *tlsReporter) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
|
||||
var certChain []string
|
||||
|
||||
if len(connState.VerifiedChains) > 0 {
|
||||
certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1])
|
||||
} else {
|
||||
certChain = marshalCert7468(connState.PeerCertificates)
|
||||
}
|
||||
|
||||
report := newTLSReport(host, port, connState.ServerName, certChain, r.trustedPins, r.cfg.AppVersion)
|
||||
|
||||
if !r.hasRecentlySentReport(report) {
|
||||
r.recordReport(report)
|
||||
go report.sendReport(r.cfg, remoteURI)
|
||||
}
|
||||
}
|
||||
|
||||
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
|
||||
func (r *tlsReporter) hasRecentlySentReport(report tlsReport) bool {
|
||||
var validReports []sentReport
|
||||
|
||||
for _, r := range r.sentReports {
|
||||
if time.Since(r.t) < 24*time.Hour {
|
||||
validReports = append(validReports, r)
|
||||
}
|
||||
}
|
||||
|
||||
r.sentReports = validReports
|
||||
|
||||
for _, r := range r.sentReports {
|
||||
if cmp.Equal(report, r.r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// recordReport records the given report and the current time so we can check whether we recently sent this report.
|
||||
func (r *tlsReporter) recordReport(report tlsReport) {
|
||||
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
|
||||
}
|
||||
|
||||
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
|
||||
var buffer bytes.Buffer
|
||||
for _, cert := range certs {
|
||||
if err := pem.Encode(&buffer, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}); err != nil {
|
||||
logrus.WithField("pkg", "pmapi/tls-pinning").WithError(err).Error("Failed to encode TLS certificate")
|
||||
}
|
||||
pemCerts = append(pemCerts, buffer.String())
|
||||
buffer.Reset()
|
||||
}
|
||||
|
||||
return pemCerts
|
||||
}
|
||||
62
pkg/pmapi/dialer_pinning_reporter_test.go
Normal file
62
pkg/pmapi/dialer_pinning_reporter_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTLSReporter_DoubleReport(t *testing.T) {
|
||||
reportCounter := 0
|
||||
|
||||
reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reportCounter++
|
||||
}))
|
||||
|
||||
cfg := Config{
|
||||
AppVersion: "3",
|
||||
UserAgent: "useragent",
|
||||
}
|
||||
r := newTLSReporter(cfg, TrustedAPIPins)
|
||||
|
||||
// Report the same issue many times.
|
||||
for i := 0; i < 10; i++ {
|
||||
r.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
|
||||
}
|
||||
|
||||
// We should only report once.
|
||||
assert.Eventually(t, func() bool {
|
||||
return reportCounter == 1
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
// If we then report something else many times.
|
||||
for i := 0; i < 10; i++ {
|
||||
r.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
|
||||
}
|
||||
|
||||
// We should get a second report.
|
||||
assert.Eventually(t, func() bool {
|
||||
return reportCounter == 2
|
||||
}, time.Second, time.Millisecond)
|
||||
}
|
||||
149
pkg/pmapi/dialer_pinning_test.go
Normal file
149
pkg/pmapi/dialer_pinning_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTLSPinValid(t *testing.T) {
|
||||
called, _, cm := createClientWithPinningDialer(getRootURL())
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
checkTLSIssueHandler(t, 0, called)
|
||||
}
|
||||
|
||||
func TestTLSPinBackup(t *testing.T) {
|
||||
called, dialer, cm := createClientWithPinningDialer(getRootURL())
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
dialer.pinChecker.trustedPins[1] = dialer.pinChecker.trustedPins[0]
|
||||
dialer.pinChecker.trustedPins[0] = ""
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
checkTLSIssueHandler(t, 0, called)
|
||||
}
|
||||
|
||||
func TestTLSPinInvalid(t *testing.T) {
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
called, _, cm := createClientWithPinningDialer(ts.URL)
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
checkTLSIssueHandler(t, 1, called)
|
||||
}
|
||||
|
||||
func TestTLSPinNoMatch(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
called, dialer, cm := createClientWithPinningDialer(getRootURL())
|
||||
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
for i := 0; i < len(dialer.pinChecker.trustedPins); i++ {
|
||||
dialer.pinChecker.trustedPins[i] = "testing"
|
||||
}
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
|
||||
// Check that it will be reported only once per session, but notified every time.
|
||||
r.Equal(t, 1, len(dialer.reporter.sentReports))
|
||||
checkTLSIssueHandler(t, 2, called)
|
||||
}
|
||||
|
||||
func TestTLSSignedCertWrongPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _ := createClientWithPinningDialer("")
|
||||
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
|
||||
r.Error(t, err, "expected dial to fail because of wrong public key")
|
||||
}
|
||||
|
||||
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _ := createClientWithPinningDialer("")
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
|
||||
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
|
||||
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
|
||||
}
|
||||
|
||||
func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _ := createClientWithPinningDialer("")
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
|
||||
_, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443")
|
||||
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
|
||||
}
|
||||
|
||||
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *manager) {
|
||||
called := 0
|
||||
|
||||
cfg := Config{
|
||||
AppVersion: "Bridge_1.2.4-test",
|
||||
HostURL: hostURL,
|
||||
TLSIssueHandler: func() { called++ },
|
||||
}
|
||||
|
||||
dialer := NewPinningTLSDialer(cfg, NewBasicTLSDialer(cfg))
|
||||
|
||||
cm := newManager(cfg)
|
||||
cm.SetTransport(CreateTransportWithDialer(dialer))
|
||||
|
||||
return &called, dialer, cm
|
||||
}
|
||||
|
||||
func copyTrustedPins(pinChecker *pinChecker) {
|
||||
copiedPins := make([]string, len(pinChecker.trustedPins))
|
||||
copy(copiedPins, pinChecker.trustedPins)
|
||||
pinChecker.trustedPins = copiedPins
|
||||
}
|
||||
|
||||
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
|
||||
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
|
||||
a.Eventually(
|
||||
t,
|
||||
func() bool {
|
||||
if wantCalledAtLeast == 0 {
|
||||
return *called == 0
|
||||
}
|
||||
// Dialer can do more attempts resulting in more calls.
|
||||
return *called >= wantCalledAtLeast
|
||||
},
|
||||
time.Second,
|
||||
10*time.Millisecond,
|
||||
)
|
||||
// Repeated again so it generates nice message.
|
||||
if wantCalledAtLeast == 0 {
|
||||
r.Equal(t, 0, *called)
|
||||
} else {
|
||||
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
|
||||
}
|
||||
}
|
||||
144
pkg/pmapi/dialer_proxy.go
Normal file
144
pkg/pmapi/dialer_proxy.go
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
|
||||
type ProxyTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
|
||||
locker sync.RWMutex
|
||||
directAddress string
|
||||
proxyAddress string
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
proxyUseDuration time.Duration
|
||||
}
|
||||
|
||||
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
|
||||
func NewProxyTLSDialer(cfg Config, dialer TLSDialer) *ProxyTLSDialer {
|
||||
return &ProxyTLSDialer{
|
||||
dialer: dialer,
|
||||
locker: sync.RWMutex{},
|
||||
directAddress: formatAsAddress(cfg.HostURL),
|
||||
proxyAddress: formatAsAddress(cfg.HostURL),
|
||||
proxyProvider: newProxyProvider(cfg, dohProviders, proxyQuery),
|
||||
proxyUseDuration: proxyUseDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// formatAsAddress returns URL as `host:port` for easy comparison in DialTLS.
|
||||
func formatAsAddress(rawURL string) string {
|
||||
url, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
// This means wrong configuration.
|
||||
// Developer should get feedback right away.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
port := "443"
|
||||
if url.Scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
return net.JoinHostPort(url.Host, port)
|
||||
}
|
||||
|
||||
// DialTLS dials the given network/address. If it fails, it retries using a proxy.
|
||||
func (d *ProxyTLSDialer) DialTLS(network, address string) (net.Conn, error) {
|
||||
if address == d.directAddress {
|
||||
address = d.proxyAddress
|
||||
}
|
||||
|
||||
conn, err := d.dialer.DialTLS(network, address)
|
||||
if err == nil || !d.allowProxy {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
err = d.switchToReachableServer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.dialer.DialTLS(network, d.proxyAddress)
|
||||
}
|
||||
|
||||
// switchToReachableServer switches to using a reachable server (either proxy or standard API).
|
||||
func (d *ProxyTLSDialer) switchToReachableServer() error {
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
logrus.Info("Attempting to switch to a proxy")
|
||||
|
||||
proxy, err := d.proxyProvider.findReachableServer()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find a usable proxy")
|
||||
}
|
||||
|
||||
proxyAddress := formatAsAddress(proxy)
|
||||
|
||||
// If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
|
||||
if proxyAddress == d.directAddress {
|
||||
logrus.Info("The standard API is reachable again; connection drop was only intermittent")
|
||||
d.proxyAddress = proxyAddress
|
||||
return ErrNoConnection
|
||||
}
|
||||
|
||||
logrus.WithField("proxy", proxyAddress).Info("Switching to a proxy")
|
||||
|
||||
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
|
||||
// This means we want to disable it again in 24 hours.
|
||||
if d.proxyAddress == d.directAddress {
|
||||
go func() {
|
||||
<-time.After(d.proxyUseDuration)
|
||||
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
d.proxyAddress = d.directAddress
|
||||
}()
|
||||
}
|
||||
|
||||
d.proxyAddress = proxyAddress
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowProxy allows the dialer to switch to a proxy if need be.
|
||||
func (d *ProxyTLSDialer) AllowProxy() {
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
d.allowProxy = true
|
||||
}
|
||||
|
||||
// DisallowProxy prevents the dialer from switching to a proxy if need be.
|
||||
func (d *ProxyTLSDialer) DisallowProxy() {
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
d.allowProxy = false
|
||||
d.proxyAddress = d.directAddress
|
||||
}
|
||||
249
pkg/pmapi/dialer_proxy_provider.go
Normal file
249
pkg/pmapi/dialer_proxy_provider.go
Normal file
@ -0,0 +1,249 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
proxyUseDuration = 24 * time.Hour
|
||||
proxyLookupWait = 5 * time.Second
|
||||
proxyCacheRefreshTimeout = 20 * time.Second
|
||||
proxyDoHTimeout = 20 * time.Second
|
||||
proxyCanReachTimeout = 20 * time.Second
|
||||
proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
|
||||
)
|
||||
|
||||
var dohProviders = []string{ //nolint[gochecknoglobals]
|
||||
"https://dns11.quad9.net/dns-query",
|
||||
"https://dns.google/dns-query",
|
||||
}
|
||||
|
||||
// proxyProvider manages known proxies.
|
||||
type proxyProvider struct {
|
||||
cfg Config
|
||||
|
||||
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
|
||||
dohLookup func(ctx context.Context, query, provider string) (urls []string, err error)
|
||||
|
||||
providers []string // List of known doh providers.
|
||||
query string // The query string used to find proxies.
|
||||
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
|
||||
|
||||
cacheRefreshTimeout time.Duration
|
||||
dohTimeout time.Duration
|
||||
canReachTimeout time.Duration
|
||||
|
||||
lastLookup time.Time // The time at which we last attempted to find a proxy.
|
||||
}
|
||||
|
||||
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
|
||||
// to retrieve DNS records for the given query string.
|
||||
func newProxyProvider(cfg Config, providers []string, query string) (p *proxyProvider) { // nolint[unparam]
|
||||
p = &proxyProvider{
|
||||
cfg: cfg,
|
||||
providers: providers,
|
||||
query: query,
|
||||
cacheRefreshTimeout: proxyCacheRefreshTimeout,
|
||||
dohTimeout: proxyDoHTimeout,
|
||||
canReachTimeout: proxyCanReachTimeout,
|
||||
}
|
||||
|
||||
// Use the default DNS lookup method; this can be overridden if necessary.
|
||||
p.dohLookup = p.defaultDoHLookup
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// findReachableServer returns a working API server (either proxy or standard API).
|
||||
func (p *proxyProvider) findReachableServer() (proxy string, err error) {
|
||||
log.Debug("Trying to find a reachable server")
|
||||
|
||||
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
|
||||
return "", errors.New("not looking for a proxy, too soon")
|
||||
}
|
||||
|
||||
p.lastLookup = time.Now()
|
||||
|
||||
// We use a waitgroup to wait for both
|
||||
// a) the check whether the API is reachable, and
|
||||
// b) the DoH queries.
|
||||
// This is because the Alternative Routes v2 spec says:
|
||||
// Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished)
|
||||
var wg sync.WaitGroup
|
||||
var apiReachable bool
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
apiReachable = p.canReach(p.cfg.HostURL)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err = p.refreshProxyCache()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if apiReachable {
|
||||
proxy = p.cfg.HostURL
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, url := range p.proxyCache {
|
||||
if p.canReach(url) {
|
||||
proxy = url
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no reachable server could be found")
|
||||
}
|
||||
|
||||
// refreshProxyCache loads the latest proxies from the known providers.
|
||||
// If the process takes longer than proxyCacheRefreshTimeout, an error is returned.
|
||||
func (p *proxyProvider) refreshProxyCache() error {
|
||||
log.Info("Refreshing proxy cache")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout)
|
||||
defer cancel()
|
||||
|
||||
resultChan := make(chan []string)
|
||||
|
||||
go func() {
|
||||
for _, provider := range p.providers {
|
||||
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
|
||||
resultChan <- proxies
|
||||
return
|
||||
}
|
||||
}
|
||||
// If no dohLoopkup worked, cancel right after it's done to not
|
||||
// block refreshing for the whole cacheRefreshTimeout.
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
p.proxyCache = result
|
||||
return nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return errors.New("timed out while refreshing proxy cache")
|
||||
}
|
||||
}
|
||||
|
||||
// canReach returns whether we can reach the given url.
|
||||
func (p *proxyProvider) canReach(url string) bool {
|
||||
log.WithField("url", url).Debug("Trying to ping proxy")
|
||||
|
||||
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
|
||||
url = "https://" + url
|
||||
}
|
||||
|
||||
dialer := NewPinningTLSDialer(p.cfg, NewBasicTLSDialer(p.cfg))
|
||||
|
||||
pinger := resty.New().
|
||||
SetHostURL(url).
|
||||
SetTimeout(p.canReachTimeout).
|
||||
SetTransport(CreateTransportWithDialer(dialer))
|
||||
|
||||
if _, err := pinger.R().Get("/tests/ping"); err != nil {
|
||||
log.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// defaultDoHLookup is the default implementation of the proxy manager's DoH lookup.
|
||||
// It looks up DNS TXT records for the given query URL using the given DoH provider.
|
||||
// It returns a list of all found TXT records.
|
||||
// If the whole process takes more than proxyDoHTimeout then an error is returned.
|
||||
func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, p.dohTimeout)
|
||||
defer cancel()
|
||||
|
||||
dataChan, errChan := make(chan []string), make(chan error)
|
||||
|
||||
go func() {
|
||||
// Build new DNS request in RFC1035 format.
|
||||
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)
|
||||
|
||||
// Pack the DNS request message into wire format.
|
||||
rawRequest, err := dnsRequest.Pack()
|
||||
if err != nil {
|
||||
errChan <- errors.Wrap(err, "failed to pack DNS request")
|
||||
return
|
||||
}
|
||||
|
||||
// Encode wire-format DNS request message as base64url (RFC4648) without padding chars.
|
||||
encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest)
|
||||
|
||||
// Make DoH request to the given DoH provider.
|
||||
rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider)
|
||||
if err != nil {
|
||||
errChan <- errors.Wrap(err, "failed to make DoH request")
|
||||
return
|
||||
}
|
||||
|
||||
// Unpack the DNS response.
|
||||
dnsResponse := new(dns.Msg)
|
||||
if err = dnsResponse.Unpack(rawResponse.Body()); err != nil {
|
||||
errChan <- errors.Wrap(err, "failed to unpack DNS response")
|
||||
return
|
||||
}
|
||||
|
||||
// Pick out the TXT answers.
|
||||
for _, answer := range dnsResponse.Answer {
|
||||
if t, ok := answer.(*dns.TXT); ok {
|
||||
data = append(data, t.Txt...)
|
||||
}
|
||||
}
|
||||
|
||||
dataChan <- data
|
||||
}()
|
||||
|
||||
select {
|
||||
case data = <-dataChan:
|
||||
log.WithField("data", data).Info("Received TXT records")
|
||||
return
|
||||
|
||||
case err = <-errChan:
|
||||
log.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records")
|
||||
return
|
||||
|
||||
case <-ctx.Done():
|
||||
log.WithField("provider", dohProvider).Error("Timed out querying DNS records")
|
||||
return []string{}, errors.New("timed out querying DNS records")
|
||||
}
|
||||
}
|
||||
187
pkg/pmapi/dialer_proxy_provider_test.go
Normal file
187
pkg/pmapi/dialer_proxy_provider_test.go
Normal file
@ -0,0 +1,187 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
r "github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/http/httpproxy"
|
||||
)
|
||||
|
||||
const (
|
||||
TestDoHQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
|
||||
TestQuad9Provider = "https://dns11.quad9.net/dns-query"
|
||||
TestGoogleProvider = "https://dns.google/dns-query"
|
||||
)
|
||||
|
||||
func TestProxyProvider_FindProxy(t *testing.T) {
|
||||
proxy := getTrustedServer()
|
||||
defer closeServer(proxy)
|
||||
|
||||
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, proxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
|
||||
reachableProxy := getTrustedServer()
|
||||
defer closeServer(reachableProxy)
|
||||
|
||||
// We actually close the unreachable proxy straight away rather than deferring the closure.
|
||||
unreachableProxy := getTrustedServer()
|
||||
closeServer(unreachableProxy)
|
||||
|
||||
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{reachableProxy.URL, unreachableProxy.URL}, nil
|
||||
}
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, reachableProxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
untrustedProxy := getUntrustedServer()
|
||||
defer closeServer(untrustedProxy)
|
||||
|
||||
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy.URL, trustedProxy.URL}, nil
|
||||
}
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, trustedProxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
|
||||
unreachableProxy1 := getTrustedServer()
|
||||
closeServer(unreachableProxy1)
|
||||
|
||||
unreachableProxy2 := getTrustedServer()
|
||||
closeServer(unreachableProxy2)
|
||||
|
||||
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
|
||||
}
|
||||
|
||||
_, err := p.findReachableServer()
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
|
||||
untrustedProxy1 := getUntrustedServer()
|
||||
defer closeServer(untrustedProxy1)
|
||||
|
||||
untrustedProxy2 := getUntrustedServer()
|
||||
defer closeServer(untrustedProxy2)
|
||||
|
||||
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
|
||||
}
|
||||
|
||||
_, err := p.findReachableServer()
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
|
||||
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
|
||||
p.cacheRefreshTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
|
||||
|
||||
// We should fail to refresh the proxy cache because the doh provider
|
||||
// takes 2 seconds to respond but we timeout after just 1 second.
|
||||
_, err := p.findReachableServer()
|
||||
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
|
||||
slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
time.Sleep(2 * time.Second)
|
||||
}))
|
||||
defer closeServer(slowProxy)
|
||||
|
||||
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
|
||||
p.canReachTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
|
||||
|
||||
// We should fail to reach the returned proxy because it takes 2 seconds
|
||||
// to reach it and we only allow 1.
|
||||
_, err := p.findReachableServer()
|
||||
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||
p := newProxyProvider(Config{}, []string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
records, err := p.dohLookup(context.Background(), TestDoHQuery, TestQuad9Provider)
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
|
||||
p := newProxyProvider(Config{}, []string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
records, err := p.dohLookup(context.Background(), TestDoHQuery, TestGoogleProvider)
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
p := newProxyProvider(Config{}, []string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
p := newProxyProvider(Config{}, []string{"https://unreachable", TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, url)
|
||||
}
|
||||
|
||||
// skipIfProxyIsSet skips the tests if HTTPS proxy is set.
|
||||
// Should be used for tests depending on proper certificate checks which
|
||||
// is not possible under our CI setup.
|
||||
func skipIfProxyIsSet(t *testing.T) {
|
||||
if httpproxy.FromEnvironment().HTTPSProxy != "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
}
|
||||
253
pkg/pmapi/dialer_proxy_test.go
Normal file
253
pkg/pmapi/dialer_proxy_test.go
Normal file
@ -0,0 +1,253 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// getTrustedServer returns a server and sets its public key as one of the pinned ones.
|
||||
func getTrustedServer() *httptest.Server {
|
||||
return getTrustedServerWithHandler(
|
||||
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
// Do nothing.
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server {
|
||||
proxy := httptest.NewTLSServer(handler)
|
||||
|
||||
pin := certFingerprint(proxy.Certificate())
|
||||
TrustedAPIPins = append(TrustedAPIPins, pin)
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
const servercrt = `
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD
|
||||
VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp
|
||||
dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t
|
||||
T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j
|
||||
b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx
|
||||
MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR
|
||||
BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf
|
||||
MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR
|
||||
aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI
|
||||
hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt
|
||||
r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2
|
||||
pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM
|
||||
GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru
|
||||
rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c
|
||||
kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw
|
||||
ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI
|
||||
DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu
|
||||
ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0
|
||||
MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3
|
||||
LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R
|
||||
BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5
|
||||
L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA
|
||||
CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P
|
||||
3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg
|
||||
yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l
|
||||
z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc
|
||||
u53wgIhCJGuX
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
|
||||
const serverkey = `
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx
|
||||
owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG
|
||||
ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq
|
||||
UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx
|
||||
8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr
|
||||
n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6
|
||||
mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1
|
||||
EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ
|
||||
/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F
|
||||
qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT
|
||||
VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu
|
||||
Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h
|
||||
bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X
|
||||
TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU
|
||||
LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f
|
||||
pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm
|
||||
nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3
|
||||
Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5
|
||||
ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo
|
||||
w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK
|
||||
PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt
|
||||
FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0
|
||||
ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD
|
||||
z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2
|
||||
FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o
|
||||
vwRMog6lPhlRhHh/FZ43Cg==
|
||||
-----END PRIVATE KEY-----
|
||||
`
|
||||
|
||||
// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones.
|
||||
func getUntrustedServer() *httptest.Server {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
|
||||
server.StartTLS()
|
||||
return server
|
||||
}
|
||||
|
||||
// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys.
|
||||
func closeServer(server *httptest.Server) {
|
||||
pin := certFingerprint(server.Certificate())
|
||||
|
||||
for i := range TrustedAPIPins {
|
||||
if TrustedAPIPins[i] == pin {
|
||||
TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
server.Close()
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
cfg := Config{HostURL: ""}
|
||||
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
|
||||
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) {
|
||||
proxy1 := getTrustedServer()
|
||||
defer closeServer(proxy1)
|
||||
proxy2 := getTrustedServer()
|
||||
defer closeServer(proxy2)
|
||||
proxy3 := getTrustedServer()
|
||||
defer closeServer(proxy3)
|
||||
|
||||
cfg := Config{HostURL: ""}
|
||||
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
|
||||
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
|
||||
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
|
||||
err = d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
|
||||
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
|
||||
err = d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy3.URL), d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
cfg := Config{HostURL: ""}
|
||||
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
|
||||
d.proxyUseDuration = time.Second
|
||||
|
||||
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
require.Equal(t, ":443", d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
|
||||
cfg := Config{HostURL: ""}
|
||||
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
|
||||
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
|
||||
|
||||
// Simulate that the proxy stops working and that the standard api is reachable again.
|
||||
closeServer(trustedProxy)
|
||||
d.directAddress = formatAsAddress(getRootURL())
|
||||
d.proxyProvider.cfg.HostURL = getRootURL()
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
// We should now find the original API URL if it is working again.
|
||||
// The error should be ErrAPINotReachable because the connection dropped intermittently but
|
||||
// the original API is now reachable (see Alternative-Routing-v2 spec for details).
|
||||
err = d.switchToReachableServer()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, formatAsAddress(getRootURL()), d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
|
||||
// proxy1 is closed later in this test so we don't defer it here.
|
||||
proxy1 := getTrustedServer()
|
||||
|
||||
proxy2 := getTrustedServer()
|
||||
defer closeServer(proxy2)
|
||||
|
||||
cfg := Config{HostURL: ""}
|
||||
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
|
||||
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
|
||||
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
// The proxy stops working and the protonmail API is still blocked.
|
||||
closeServer(proxy1)
|
||||
|
||||
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
|
||||
err = d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
|
||||
}
|
||||
@ -1,9 +1,37 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoConnection = errors.New("no internet connection")
|
||||
ErrAPIFailure = errors.New("API returned an error")
|
||||
ErrUnauthorized = errors.New("API client is unauthorized")
|
||||
ErrNoConnection = errors.New("no internet connection")
|
||||
ErrUnauthorized = errors.New("API client is unauthorized")
|
||||
ErrUpgradeApplication = errors.New("application upgrade required")
|
||||
|
||||
ErrBad2FACode = errors.New("incorrect 2FA code")
|
||||
ErrBad2FACodeTryAgain = errors.New("incorrect 2FA code: please try again")
|
||||
)
|
||||
|
||||
type ErrUnprocessableEntity struct {
|
||||
originalError error
|
||||
}
|
||||
|
||||
func (err ErrUnprocessableEntity) Error() string {
|
||||
return err.originalError.Error()
|
||||
}
|
||||
|
||||
@ -32,7 +32,7 @@ type Event struct {
|
||||
// If set to one, all cached data must be fetched again.
|
||||
Refresh int
|
||||
// If set to one, fetch more events.
|
||||
More int
|
||||
More Boolean
|
||||
// Changes applied to messages.
|
||||
Messages []*EventMessage
|
||||
// Counts of messages per labels.
|
||||
@ -167,26 +167,32 @@ type EventAddress struct {
|
||||
|
||||
// GetEvent returns a summary of events that occurred since last. To get the latest event,
|
||||
// provide an empty last value. The latest event is always empty.
|
||||
func (c *client) GetEvent(ctx context.Context, eventID string) (event *Event, err error) {
|
||||
func (c *client) GetEvent(ctx context.Context, eventID string) (*Event, error) {
|
||||
return c.getEvent(ctx, eventID, 1)
|
||||
}
|
||||
|
||||
func (c *client) getEvent(ctx context.Context, eventID string, numberOfMergedEvents int) (*Event, error) {
|
||||
if eventID == "" {
|
||||
eventID = "latest"
|
||||
}
|
||||
|
||||
var res struct {
|
||||
*Event
|
||||
|
||||
More int
|
||||
}
|
||||
var event *Event
|
||||
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetResult(&res).Get("/events/" + eventID)
|
||||
return r.SetResult(&event).Get("/events/" + eventID)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// FIXME(conman): use mergeEvents() function.
|
||||
if event.More && numberOfMergedEvents < maxNumberOfMergedEvents {
|
||||
nextEvent, err := c.getEvent(ctx, event.EventID, numberOfMergedEvents+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
event = mergeEvents(event, nextEvent)
|
||||
}
|
||||
|
||||
return res.Event, nil
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// mergeEvents combines an old events and a new events object.
|
||||
|
||||
@ -27,13 +27,12 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClient_GetEvent(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/latest"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/events/latest"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@ -41,14 +40,14 @@ func TestClient_GetEvent(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent(context.TODO(), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testEvent, event)
|
||||
event, err := c.GetEvent(context.Background(), "")
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testEvent, event)
|
||||
}
|
||||
|
||||
func TestClient_GetEvent_withID(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/"+testEvent.EventID))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/events/"+testEvent.EventID))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@ -56,23 +55,22 @@ func TestClient_GetEvent_withID(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent(context.TODO(), testEvent.EventID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testEvent, event)
|
||||
event, err := c.GetEvent(context.Background(), testEvent.EventID)
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testEvent, event)
|
||||
}
|
||||
|
||||
// We first call GetEvent with id of eventID1, which returns More=1 so we fetch with id eventID2.
|
||||
// FIXME(conman): Merging is currently not supported. Implement it and then enable this test again!
|
||||
func _TestClient_GetEvent_mergeEvents(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
func TestClient_GetEvent_mergeEvents(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch r.URL.RequestURI() {
|
||||
switch req.URL.RequestURI() {
|
||||
case "/events/eventID1":
|
||||
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/eventID1"))
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/events/eventID1"))
|
||||
fmt.Fprint(w, testEventBodyMore1)
|
||||
case "/events/eventID2":
|
||||
assert.NoError(t, checkMethodAndPath(r, "GET", "/events/eventID2"))
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/events/eventID2"))
|
||||
fmt.Fprint(w, testEventBodyMore2)
|
||||
default:
|
||||
t.Fail()
|
||||
@ -80,29 +78,26 @@ func _TestClient_GetEvent_mergeEvents(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent(context.TODO(), "eventID1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testEventMerged, event)
|
||||
event, err := c.GetEvent(context.Background(), "eventID1")
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testEventMerged, event)
|
||||
}
|
||||
|
||||
// FIXME(conman): Merging is currently not supported. Implement it and then enable this test again!
|
||||
func _TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
|
||||
func TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
|
||||
numberOfCalls := 0
|
||||
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
numberOfCalls++
|
||||
|
||||
re := regexp.MustCompile(`/eventID([0-9]+)`)
|
||||
eventIDString := re.FindStringSubmatch(r.URL.RequestURI())[1]
|
||||
eventIDString := re.FindStringSubmatch(req.URL.RequestURI())[1]
|
||||
eventID, err := strconv.Atoi(eventIDString)
|
||||
require.NoError(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
if numberOfCalls > maxNumberOfMergedEvents*2 {
|
||||
require.Fail(t, "Too many calls!")
|
||||
r.Fail(t, "Too many calls!")
|
||||
}
|
||||
|
||||
fmt.Println("")
|
||||
|
||||
body := strings.ReplaceAll(testEventBodyMore1, "eventID2", "eventID"+strconv.Itoa(eventID+1))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@ -110,14 +105,14 @@ func _TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
event, err := c.GetEvent(context.TODO(), "eventID1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, maxNumberOfMergedEvents, numberOfCalls)
|
||||
require.Equal(t, 1, event.More)
|
||||
event, err := c.GetEvent(context.Background(), "eventID1")
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, maxNumberOfMergedEvents, numberOfCalls)
|
||||
r.True(t, bool(event.More))
|
||||
}
|
||||
|
||||
var (
|
||||
testEventMessageUpdateUnread = False
|
||||
testEventMessageUpdateUnread = Boolean(false)
|
||||
|
||||
testEvent = &Event{
|
||||
EventID: "eventID1",
|
||||
|
||||
@ -37,9 +37,8 @@ type ImportMsgReq struct {
|
||||
type ImportMsgReqs []*ImportMsgReq
|
||||
|
||||
func (reqs ImportMsgReqs) buildMultipartFormData() ([]*resty.MultipartField, error) {
|
||||
var fields []*resty.MultipartField
|
||||
|
||||
metadata := make(map[string]*ImportMetadata)
|
||||
metadata := make(map[string]*ImportMetadata, len(reqs))
|
||||
fields := make([]*resty.MultipartField, 0, len(reqs))
|
||||
|
||||
for i, req := range reqs {
|
||||
name := strconv.Itoa(i)
|
||||
@ -68,7 +67,6 @@ func (reqs ImportMsgReqs) buildMultipartFormData() ([]*resty.MultipartField, err
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
// TODO: Add other metadata.
|
||||
type ImportMetadata struct {
|
||||
AddressID string
|
||||
Unread Boolean // 0: read, 1: unread.
|
||||
@ -114,7 +112,7 @@ func (c *client) Import(ctx context.Context, reqs ImportMsgReqs) ([]*ImportMsgRe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resps []*ImportMsgRes
|
||||
resps := make([]*ImportMsgRes, 0, len(res.Responses))
|
||||
|
||||
for _, resp := range res.Responses {
|
||||
var err error
|
||||
|
||||
@ -25,17 +25,17 @@ import (
|
||||
"io/ioutil"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
pmmime "github.com/ProtonMail/proton-bridge/pkg/mime"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testImportReqs = []*ImportMsgReq{
|
||||
{
|
||||
Metadata: &ImportMetadata{
|
||||
AddressID: "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==",
|
||||
Unread: 0,
|
||||
Unread: Boolean(false),
|
||||
Flags: FlagReceived | FlagImported,
|
||||
LabelIDs: []string{ArchiveLabel},
|
||||
},
|
||||
@ -57,86 +57,52 @@ var testImportRes = &ImportMsgRes{
|
||||
}
|
||||
|
||||
func TestClient_Import(t *testing.T) { // nolint[funlen]
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/mail/v4/messages/import"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/messages/import"))
|
||||
|
||||
contentType, params, err := pmmime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing request content type, got:", err)
|
||||
}
|
||||
if contentType != "multipart/form-data" {
|
||||
t.Errorf("Invalid request content type: expected %v but got %v", "multipart/form-data", contentType)
|
||||
}
|
||||
contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type"))
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, "multipart/form-data", contentType)
|
||||
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
mr := multipart.NewReader(req.Body, params["boundary"])
|
||||
|
||||
// First part is message body.
|
||||
p, err := mr.NextPart()
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading second part of request body, got:", err)
|
||||
}
|
||||
r.NoError(t, err)
|
||||
|
||||
contentDisp, params, err := pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing part content disposition, got:", err)
|
||||
}
|
||||
if contentDisp != "form-data" {
|
||||
t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType)
|
||||
}
|
||||
if params["name"] != "0" {
|
||||
t.Errorf("Invalid part name: expected %v but got %v", "0", params["name"])
|
||||
}
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, "form-data", contentDisp)
|
||||
r.Equal(t, "0", params["name"])
|
||||
|
||||
b, err := ioutil.ReadAll(p)
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading second part body, got:", err)
|
||||
}
|
||||
|
||||
if string(b) != string(testImportReqs[0].Message) {
|
||||
t.Errorf("Invalid message body: expected %v but got %v", string(testImportReqs[0].Message), string(b))
|
||||
}
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, string(testImportReqs[0].Message), string(b))
|
||||
|
||||
// Second part is metadata.
|
||||
p, err = mr.NextPart()
|
||||
if err != nil {
|
||||
t.Error("Expected no error while reading first part of request body, got:", err)
|
||||
}
|
||||
r.NoError(t, err)
|
||||
|
||||
contentDisp, params, err = pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
|
||||
if err != nil {
|
||||
t.Error("Expected no error while parsing part content disposition, got:", err)
|
||||
}
|
||||
if contentDisp != "form-data" {
|
||||
t.Errorf("Invalid part content disposition: expected %v but got %v", "form-data", contentType)
|
||||
}
|
||||
if params["name"] != "Metadata" {
|
||||
t.Errorf("Invalid part name: expected %v but got %v", "Metadata", params["name"])
|
||||
}
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, "form-data", contentDisp)
|
||||
r.Equal(t, "Metadata", params["name"])
|
||||
|
||||
metadata := map[string]*ImportMetadata{}
|
||||
if err := json.NewDecoder(p).Decode(&metadata); err != nil {
|
||||
t.Error("Expected no error while parsing metadata json, got:", err)
|
||||
}
|
||||
err = json.NewDecoder(p).Decode(&metadata)
|
||||
r.NoError(t, err)
|
||||
|
||||
if len(metadata) != 1 {
|
||||
t.Errorf("Expected metadata to contain exactly one item, got %v", metadata)
|
||||
}
|
||||
r.Equal(t, 1, len(metadata))
|
||||
|
||||
req := metadata["0"]
|
||||
if metadata["0"] == nil {
|
||||
t.Errorf("Expected metadata to contain one item indexed by 0, got %v", metadata)
|
||||
}
|
||||
importReq := metadata["0"]
|
||||
r.NotNil(t, req)
|
||||
|
||||
expected := *testImportReqs[0].Metadata
|
||||
if !reflect.DeepEqual(&expected, req) {
|
||||
t.Errorf("Invalid message metadata: expected %v, got %v", &expected, req)
|
||||
}
|
||||
r.Equal(t, &expected, importReq)
|
||||
|
||||
// No more parts.
|
||||
_, err = mr.NextPart()
|
||||
if err != io.EOF {
|
||||
t.Error("Expected no more parts but error was not EOF, got:", err)
|
||||
}
|
||||
r.EqualError(t, err, io.EOF.Error())
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@ -144,16 +110,8 @@ func TestClient_Import(t *testing.T) { // nolint[funlen]
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
imported, err := c.Import(context.TODO(), testImportReqs)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while importing, got:", err)
|
||||
}
|
||||
|
||||
if len(imported) != 1 {
|
||||
t.Fatalf("Expected exactly one imported message, got %v", len(imported))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(testImportRes, imported[0]) {
|
||||
t.Errorf("Invalid response for imported message: expected %+v but got %+v", testImportRes, imported[0])
|
||||
}
|
||||
imported, err := c.Import(context.Background(), testImportReqs)
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, 1, len(imported))
|
||||
r.Equal(t, testImportRes, imported[0])
|
||||
}
|
||||
|
||||
@ -19,7 +19,6 @@ package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
@ -44,8 +43,6 @@ const (
|
||||
|
||||
// GetPublicKeysForEmail returns all sending public keys for the given email address.
|
||||
func (c *client) GetPublicKeysForEmail(ctx context.Context, email string) (keys []PublicKey, internal bool, err error) {
|
||||
email = url.QueryEscape(email)
|
||||
|
||||
var res struct {
|
||||
Keys []PublicKey
|
||||
RecipientType RecipientType
|
||||
|
||||
@ -75,42 +75,35 @@ var LabelColors = []string{ //nolint[gochecknoglobals]
|
||||
"#dfb286",
|
||||
}
|
||||
|
||||
type LabelAction int
|
||||
|
||||
const (
|
||||
RemoveLabel LabelAction = iota
|
||||
AddLabel
|
||||
)
|
||||
|
||||
// Label for message.
|
||||
type Label struct {
|
||||
type Label struct { //nolint[maligned]
|
||||
ID string
|
||||
Name string
|
||||
Path string
|
||||
Color string
|
||||
Order int `json:",omitempty"`
|
||||
Display int // Not used for now, leave it empty.
|
||||
Exclusive int
|
||||
Exclusive Boolean
|
||||
Type int
|
||||
Notify int
|
||||
Notify Boolean
|
||||
}
|
||||
|
||||
func (c *client) ListLabels(ctx context.Context) (labels []*Label, err error) {
|
||||
return c.ListLabelType(ctx, LabelTypeMailbox)
|
||||
return c.listLabelType(ctx, LabelTypeMailbox)
|
||||
}
|
||||
|
||||
func (c *client) ListContactGroups(ctx context.Context) (labels []*Label, err error) {
|
||||
return c.ListLabelType(ctx, LabelTypeContactGroup)
|
||||
return c.listLabelType(ctx, LabelTypeContactGroup)
|
||||
}
|
||||
|
||||
// ListLabelType lists all labels created by the user.
|
||||
func (c *client) ListLabelType(ctx context.Context, labelType int) (labels []*Label, err error) {
|
||||
// listLabelType lists all labels created by the user.
|
||||
func (c *client) listLabelType(ctx context.Context, labelType int) (labels []*Label, err error) {
|
||||
var res struct {
|
||||
Labels []*Label
|
||||
}
|
||||
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetQueryParam("Type", strconv.Itoa(labelType)).SetResult(&res).Get("/v4/labels")
|
||||
return r.SetQueryParam("Type", strconv.Itoa(labelType)).SetResult(&res).Get("/labels")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -135,7 +128,7 @@ func (c *client) CreateLabel(ctx context.Context, label *Label) (created *Label,
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(&LabelReq{
|
||||
Label: label,
|
||||
}).SetResult(&res).Post("/v4/labels")
|
||||
}).SetResult(&res).Post("/labels")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -156,7 +149,7 @@ func (c *client) UpdateLabel(ctx context.Context, label *Label) (updated *Label,
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.SetBody(&LabelReq{
|
||||
Label: label,
|
||||
}).SetResult(&res).Put("/v4/labels/" + label.ID)
|
||||
}).SetResult(&res).Put("/labels/" + label.ID)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -167,7 +160,7 @@ func (c *client) UpdateLabel(ctx context.Context, label *Label) (updated *Label,
|
||||
// DeleteLabel deletes a label.
|
||||
func (c *client) DeleteLabel(ctx context.Context, labelID string) error {
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
return r.Delete("/v4/labels/" + labelID)
|
||||
return r.Delete("/labels/" + labelID)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -91,61 +91,43 @@ const testDeleteLabelBody = `{
|
||||
`
|
||||
|
||||
func TestClient_ListLabels(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/v4/labels?Type=1"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/labels?Type=1"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testLabelsBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
labels, err := c.ListLabels(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while listing labels, got:", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(labels, testLabels) {
|
||||
for i, l := range testLabels {
|
||||
t.Errorf("expected %d: %#v\n", i, l)
|
||||
}
|
||||
for i, l := range labels {
|
||||
t.Errorf("got %d: %#v\n", i, l)
|
||||
}
|
||||
t.Fatalf("Not same")
|
||||
}
|
||||
labels, err := c.ListLabels(context.Background())
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testLabels, labels)
|
||||
}
|
||||
|
||||
func TestClient_CreateLabel(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/v4/labels"))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/labels"))
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
_, err := body.ReadFrom(r.Body)
|
||||
Ok(t, err)
|
||||
_, err := body.ReadFrom(req.Body)
|
||||
r.NoError(t, err)
|
||||
|
||||
if bytes.Contains(body.Bytes(), []byte("Order")) {
|
||||
t.Fatal("Body contains `Order`: ", body.String())
|
||||
}
|
||||
|
||||
var labelReq LabelReq
|
||||
if err := json.NewDecoder(body).Decode(&labelReq); err != nil {
|
||||
t.Error("Expecting no error while reading request body, got:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(testLabelReq.Label, labelReq.Label) {
|
||||
t.Errorf("Invalid label request: expected %+v but got %+v", testLabelReq.Label, labelReq.Label)
|
||||
}
|
||||
err = json.NewDecoder(body).Decode(&labelReq)
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testLabelReq.Label, labelReq.Label)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testCreateLabelBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
created, err := c.CreateLabel(context.TODO(), testLabelReq.Label)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while creating label, got:", err)
|
||||
}
|
||||
created, err := c.CreateLabel(context.Background(), testLabelReq.Label)
|
||||
r.NoError(t, err)
|
||||
|
||||
if !reflect.DeepEqual(created, testLabelCreated) {
|
||||
t.Fatalf("Invalid created label: expected %+v, got %+v", testLabelCreated, created)
|
||||
@ -158,32 +140,26 @@ func TestClient_CreateEmptyLabel(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
_, err := c.CreateLabel(context.TODO(), &Label{})
|
||||
_, err := c.CreateLabel(context.Background(), &Label{})
|
||||
r.EqualError(t, err, "name is required")
|
||||
}
|
||||
|
||||
func TestClient_UpdateLabel(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "PUT", "/v4/labels/"+testLabelCreated.ID))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "PUT", "/labels/"+testLabelCreated.ID))
|
||||
|
||||
var labelReq LabelReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&labelReq); err != nil {
|
||||
t.Error("Expecting no error while reading request body, got:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(testLabelCreated, labelReq.Label) {
|
||||
t.Errorf("Invalid label request: expected %+v but got %+v", testLabelCreated, labelReq.Label)
|
||||
}
|
||||
err := json.NewDecoder(req.Body).Decode(&labelReq)
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testLabelCreated, labelReq.Label)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testCreateLabelBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
updated, err := c.UpdateLabel(context.TODO(), testLabelCreated)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while updating label, got:", err)
|
||||
}
|
||||
updated, err := c.UpdateLabel(context.Background(), testLabelCreated)
|
||||
r.NoError(t, err)
|
||||
|
||||
if !reflect.DeepEqual(updated, testLabelCreated) {
|
||||
t.Fatalf("Invalid updated label: expected %+v, got %+v", testLabelCreated, updated)
|
||||
@ -196,24 +172,21 @@ func TestClient_UpdateLabelToEmptyName(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
_, err := c.UpdateLabel(context.TODO(), &Label{ID: "label"})
|
||||
_, err := c.UpdateLabel(context.Background(), &Label{ID: "label"})
|
||||
r.EqualError(t, err, "name is required")
|
||||
}
|
||||
|
||||
func TestClient_DeleteLabel(t *testing.T) {
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "DELETE", "/v4/labels/"+testLabelCreated.ID))
|
||||
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "DELETE", "/labels/"+testLabelCreated.ID))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testDeleteLabelBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
err := c.DeleteLabel(context.TODO(), testLabelCreated.ID)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while deleting label, got:", err)
|
||||
}
|
||||
err := c.DeleteLabel(context.Background(), testLabelCreated.ID)
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLeastUsedColor(t *testing.T) {
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
@ -10,56 +27,53 @@ import (
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
rc *resty.Client
|
||||
cfg Config
|
||||
rc *resty.Client
|
||||
|
||||
isDown bool
|
||||
locker sync.Locker
|
||||
observers []ConnectionObserver
|
||||
}
|
||||
|
||||
func newManager(cfg Config) *manager {
|
||||
m := &manager{
|
||||
rc: resty.New(),
|
||||
locker: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Set the API host.
|
||||
m.rc.SetHostURL(cfg.HostURL)
|
||||
|
||||
// Set static header values.
|
||||
m.rc.SetHeader("x-pm-appversion", cfg.AppVersion)
|
||||
|
||||
// Set middleware.
|
||||
m.rc.OnAfterResponse(catchAPIError)
|
||||
|
||||
// Configure retry mechanism.
|
||||
m.rc.SetRetryMaxWaitTime(time.Minute)
|
||||
m.rc.SetRetryAfter(catchRetryAfter)
|
||||
m.rc.AddRetryCondition(catchTooManyRequests)
|
||||
m.rc.AddRetryCondition(catchNoResponse)
|
||||
m.rc.AddRetryCondition(catchProxyAvailable)
|
||||
|
||||
// Determine what happens when requests succeed/fail.
|
||||
m.rc.OnAfterResponse(m.handleRequestSuccess)
|
||||
m.rc.OnError(m.handleRequestFailure)
|
||||
|
||||
// Set the data type of API errors.
|
||||
m.rc.SetError(&Error{})
|
||||
|
||||
return m
|
||||
isDown bool
|
||||
locker sync.Locker
|
||||
connectionObservers []ConnectionObserver
|
||||
proxyDialer *ProxyTLSDialer
|
||||
}
|
||||
|
||||
func New(cfg Config) Manager {
|
||||
return newManager(cfg)
|
||||
}
|
||||
|
||||
func (m *manager) SetLogger(logger resty.Logger) {
|
||||
m.rc.SetLogger(logger)
|
||||
m.rc.SetDebug(true)
|
||||
func newManager(cfg Config) *manager {
|
||||
m := &manager{
|
||||
cfg: cfg,
|
||||
rc: resty.New(),
|
||||
locker: &sync.Mutex{},
|
||||
}
|
||||
|
||||
proxyDialer, transport := newProxyDialerAndTransport(cfg)
|
||||
m.proxyDialer = proxyDialer
|
||||
m.rc.SetTransport(transport)
|
||||
|
||||
m.rc.SetHostURL(cfg.HostURL)
|
||||
m.rc.OnBeforeRequest(m.setHeaderValues)
|
||||
|
||||
// Any HTTP status code higher than 399 with JSON inside (and proper header)
|
||||
// is converted to Error. `catchAPIError` then processes API custom errors
|
||||
// wrapped in JSON. If error is returned, `handleRequestFailure` is called,
|
||||
// otherwise `handleRequestSuccess` is called.
|
||||
m.rc.SetError(&Error{})
|
||||
m.rc.OnAfterResponse(m.catchAPIError)
|
||||
m.rc.OnAfterResponse(m.handleRequestSuccess)
|
||||
m.rc.OnError(m.handleRequestFailure)
|
||||
|
||||
// Configure retry mechanism.
|
||||
m.rc.SetRetryMaxWaitTime(time.Minute)
|
||||
m.rc.SetRetryAfter(catchRetryAfter)
|
||||
m.rc.AddRetryCondition(shouldRetry)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *manager) SetTransport(transport http.RoundTripper) {
|
||||
m.rc.SetTransport(transport)
|
||||
m.proxyDialer = nil
|
||||
}
|
||||
|
||||
func (m *manager) SetCookieJar(jar http.CookieJar) {
|
||||
@ -71,7 +85,15 @@ func (m *manager) SetRetryCount(count int) {
|
||||
}
|
||||
|
||||
func (m *manager) AddConnectionObserver(observer ConnectionObserver) {
|
||||
m.observers = append(m.observers, observer)
|
||||
m.connectionObservers = append(m.connectionObservers, observer)
|
||||
}
|
||||
|
||||
func (m *manager) setHeaderValues(_ *resty.Client, req *resty.Request) error {
|
||||
req.SetHeaders(map[string]string{
|
||||
"x-pm-appversion": m.cfg.AppVersion,
|
||||
"User-Agent": m.cfg.getUserAgent(),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) r(ctx context.Context) *resty.Request {
|
||||
@ -90,7 +112,7 @@ func (m *manager) handleRequestSuccess(_ *resty.Client, res *resty.Response) err
|
||||
|
||||
m.isDown = false
|
||||
|
||||
for _, observer := range m.observers {
|
||||
for _, observer := range m.connectionObservers {
|
||||
observer.OnUp()
|
||||
}
|
||||
|
||||
@ -113,15 +135,9 @@ func (m *manager) handleRequestFailure(req *resty.Request, err error) {
|
||||
|
||||
m.isDown = true
|
||||
|
||||
for _, observer := range m.observers {
|
||||
for _, observer := range m.connectionObservers {
|
||||
observer.OnDown()
|
||||
}
|
||||
|
||||
go m.pingUntilSuccess()
|
||||
}
|
||||
|
||||
func (m *manager) pingUntilSuccess() {
|
||||
for m.testPing(context.Background()) != nil {
|
||||
time.Sleep(time.Second) // TODO: How long to sleep here?
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
@ -9,10 +26,14 @@ import (
|
||||
)
|
||||
|
||||
func (m *manager) NewClient(uid, acc, ref string, exp time.Time) Client {
|
||||
log.Trace("New client")
|
||||
|
||||
return newClient(m, uid).withAuth(acc, ref, exp)
|
||||
}
|
||||
|
||||
func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Client, *Auth, error) {
|
||||
func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Client, *AuthRefresh, error) {
|
||||
log.Trace("New client with refresh")
|
||||
|
||||
c := newClient(m, uid)
|
||||
|
||||
auth, err := m.authRefresh(ctx, uid, ref)
|
||||
@ -24,6 +45,8 @@ func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Cl
|
||||
}
|
||||
|
||||
func (m *manager) NewClientWithLogin(ctx context.Context, username, password string) (Client, *Auth, error) {
|
||||
log.Trace("New client with login")
|
||||
|
||||
info, err := m.getAuthInfo(ctx, GetAuthInfoReq{Username: username})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@ -52,24 +75,13 @@ func (m *manager) NewClientWithLogin(ctx context.Context, username, password str
|
||||
return newClient(m, auth.UID).withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
|
||||
}
|
||||
|
||||
func (m *manager) getAuthModulus(ctx context.Context) (AuthModulus, error) {
|
||||
var res struct {
|
||||
AuthModulus
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetResult(&res).Get("/auth/modulus"); err != nil {
|
||||
return AuthModulus{}, err
|
||||
}
|
||||
|
||||
return res.AuthModulus, nil
|
||||
}
|
||||
|
||||
func (m *manager) getAuthInfo(ctx context.Context, req GetAuthInfoReq) (*AuthInfo, error) {
|
||||
var res struct {
|
||||
*AuthInfo
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/info"); err != nil {
|
||||
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/info"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -81,15 +93,16 @@ func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) {
|
||||
*Auth
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth"); err != nil {
|
||||
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Auth, nil
|
||||
}
|
||||
|
||||
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*Auth, error) {
|
||||
var req = AuthRefreshReq{
|
||||
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefresh, error) {
|
||||
var req = authRefreshReq{
|
||||
UID: uid,
|
||||
RefreshToken: ref,
|
||||
ResponseType: "token",
|
||||
@ -99,14 +112,15 @@ func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*Auth, erro
|
||||
}
|
||||
|
||||
var res struct {
|
||||
*Auth
|
||||
*AuthRefresh
|
||||
}
|
||||
|
||||
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"); err != nil {
|
||||
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Auth, nil
|
||||
return res.AuthRefresh, nil
|
||||
}
|
||||
|
||||
func expiresIn(seconds int64) time.Time {
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
|
||||
71
pkg/pmapi/manager_log.go
Normal file
71
pkg/pmapi/manager_log.go
Normal file
@ -0,0 +1,71 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// restyLogger decreases debug level to trace level so resty logs
|
||||
// are not logged as debug but trace instead. Resty logging is too
|
||||
// verbose which we don't want to have in debug level.
|
||||
type restyLogger struct {
|
||||
logrus *logrus.Entry
|
||||
}
|
||||
|
||||
func (l *restyLogger) Errorf(format string, v ...interface{}) {
|
||||
l.logrus.Errorf(format, v...)
|
||||
}
|
||||
|
||||
func (l *restyLogger) Warnf(format string, v ...interface{}) {
|
||||
l.logrus.Warnf(format, v...)
|
||||
}
|
||||
|
||||
func (l *restyLogger) Debugf(format string, v ...interface{}) {
|
||||
l.logrus.Tracef(format, v...)
|
||||
}
|
||||
|
||||
func (m *manager) SetLogging(logger *logrus.Entry, verbose bool) {
|
||||
if verbose {
|
||||
m.rc.SetLogger(&restyLogger{logrus: logger})
|
||||
m.rc.SetDebug(true)
|
||||
return
|
||||
}
|
||||
|
||||
m.rc.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error {
|
||||
logger.Infof("Requesting %s %s", req.Method, req.URL)
|
||||
return nil
|
||||
})
|
||||
m.rc.OnAfterResponse(func(_ *resty.Client, res *resty.Response) error {
|
||||
log := logger.WithFields(logrus.Fields{
|
||||
"error": res.Error(),
|
||||
"status": res.StatusCode(),
|
||||
"duration": res.Time(),
|
||||
})
|
||||
if res.Request == nil {
|
||||
log.Warn("Requested unknown request")
|
||||
return nil
|
||||
}
|
||||
log.Debugf("Requested %s %s", res.Request.Method, res.Request.URL)
|
||||
return nil
|
||||
})
|
||||
m.rc.OnError(func(req *resty.Request, err error) {
|
||||
logger.WithError(err).Warnf("Failed request %s %s", req.Method, req.URL)
|
||||
})
|
||||
}
|
||||
@ -1,11 +1,34 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
func (m *manager) SendSimpleMetric(context.Context, string, string, string) error {
|
||||
// FIXME(conman): Implement.
|
||||
return errors.New("not implemented")
|
||||
func (m *manager) SendSimpleMetric(ctx context.Context, category, action, label string) error {
|
||||
r := m.r(ctx).SetQueryParams(map[string]string{
|
||||
"Category": category,
|
||||
"Action": action,
|
||||
"Label": label,
|
||||
})
|
||||
if _, err := wrapNoConnection(r.Get("/metrics")); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -23,6 +23,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testSendSimpleMetricsBody = `{
|
||||
@ -30,21 +32,17 @@ const testSendSimpleMetricsBody = `{
|
||||
}
|
||||
`
|
||||
|
||||
// FIXME(conman): Implement metrics then enable this test.
|
||||
func _TestClient_SendSimpleMetric(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "GET", "/metrics?Action=some_action&Category=some_category&Label=some_label"))
|
||||
|
||||
func TestClient_SendSimpleMetric(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "GET", "/metrics?Action=some_action&Category=some_category&Label=some_label"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, testSendSimpleMetricsBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
m := newManager(Config{HostURL: s.URL})
|
||||
m := newManager(newTestConfig(s.URL))
|
||||
|
||||
err := m.SendSimpleMetric(context.TODO(), "some_category", "some_action", "some_label")
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while sending simple metric, got:", err)
|
||||
}
|
||||
err := m.SendSimpleMetric(context.Background(), "some_category", "some_action", "some_label")
|
||||
r.NoError(t, err)
|
||||
}
|
||||
@ -1,11 +1,60 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
// retryConnectionSleeps defines a smooth cool down in seconds.
|
||||
retryConnectionSleeps = []int{2, 5, 10, 30, 60} // nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
func (m *manager) pingUntilSuccess() {
|
||||
attempt := 0
|
||||
for {
|
||||
err := m.testPing(context.Background())
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
waitTime := getRetryConnectionSleep(attempt)
|
||||
attempt++
|
||||
logrus.WithError(err).WithField("attempt", attempt).WithField("wait", waitTime).Debug("Connection not available")
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
|
||||
func getRetryConnectionSleep(idx int) time.Duration {
|
||||
if idx >= len(retryConnectionSleeps) {
|
||||
idx = len(retryConnectionSleeps) - 1
|
||||
}
|
||||
sec := retryConnectionSleeps[idx]
|
||||
return time.Duration(sec) * time.Second
|
||||
}
|
||||
|
||||
func (m *manager) testPing(ctx context.Context) error {
|
||||
if _, err := m.r(ctx).Get("/tests/ping"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
32
pkg/pmapi/manager_proxy.go
Normal file
32
pkg/pmapi/manager_proxy.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
||||
func (m *manager) AllowProxy() {
|
||||
if m.proxyDialer != nil {
|
||||
m.proxyDialer.AllowProxy()
|
||||
}
|
||||
}
|
||||
|
||||
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
||||
func (m *manager) DisallowProxy() {
|
||||
if m.proxyDialer != nil {
|
||||
m.proxyDialer.DisallowProxy()
|
||||
}
|
||||
}
|
||||
@ -1,12 +1,43 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Report sends request as json or multipart (if has attachment).
|
||||
func (m *manager) ReportBug(context.Context, ReportBugReq) error {
|
||||
// FIXME(conman): Implement.
|
||||
return errors.New("not implemented")
|
||||
func (m *manager) ReportBug(ctx context.Context, rep ReportBugReq) error {
|
||||
if rep.ClientType == 0 {
|
||||
rep.ClientType = EmailClientType
|
||||
}
|
||||
|
||||
r := m.r(ctx)
|
||||
if len(rep.Attachments) == 0 {
|
||||
r = r.SetBody(rep)
|
||||
} else {
|
||||
r = r.SetMultipartFormData(rep.GetMultipartFormData())
|
||||
for _, att := range rep.Attachments {
|
||||
r = r.SetMultipartField(att.name, att.filename, "application/octet-stream", att.body)
|
||||
}
|
||||
}
|
||||
if _, err := wrapNoConnection(r.Post("/reports/bug")); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -24,9 +24,10 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testBugReportReq = ReportBugReq{
|
||||
@ -42,28 +43,15 @@ var testBugReportReq = ReportBugReq{
|
||||
Email: "apple@gmail.com",
|
||||
}
|
||||
|
||||
var testBugsCrashReq = ReportBugReq{
|
||||
OS: runtime.GOOS,
|
||||
Client: "demoapp",
|
||||
ClientVersion: "GoPMAPI_1.0.14",
|
||||
ClientType: 1,
|
||||
Debug: "main.func·001()\n/Users/sunny/Code/Go/src/scratch/stack.go:21 +0xabruntime.panic(0x80b80, 0x2101fb150)\n/usr/local/Cellar/go/1.2/libexec/src/pkg/runtime/panic.c:248 +0x106\nmain.inner()/Users/sunny/Code/Go/src/scratch/stack.go:27 +0x68\nmain.outer()\n/Users/sunny/Code/Go/src/scratch/stack.go:13 +0x1a\nmain.main()\n/Users/sunny/Code/Go/src/scratch/stack.go:9 +0x1a",
|
||||
}
|
||||
|
||||
const testBugsBody = `{
|
||||
"Code": 1000
|
||||
}
|
||||
`
|
||||
|
||||
const testAttachmentJSONZipped = "PK\x03\x04\x14\x00\b\x00\b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00last.log\\Rَ\xaaH\x00}ﯨ\xf8r\x1f\xeeܖED;\xe9\ap\x03\x11\x11\x97\x0e8\x99L\xb0(\xa1\xa0\x16\x85b\x91I\xff\xfbD{\x99\xc9}\xab:K\x9d\xa4\xce\xf9\xe7\t\x00\x00z\xf6\xb4\xf7\x02z\xb7a\xe5\xd8\x04*V̭\x8d\xd1lvE}\xd6\xe3\x80\x1f\xd7nX\x9bI[\xa6\xe1a=\xd4a\xa8M\x97\xd9J\xf1F\xeb\x105U\xbd\xb0`XO\xce\xf1hu\x99q\xc3\xfe{\x11ߨ'-\v\x89Z\xa4\x9c5\xaf\xaf\xbd?>R\xd6\x11E\xf7\x1cX\xf0JpF#L\x9eE+\xbe\xe8\x1d\xee\ued2e\u007f\xde]\u06dd\xedo\x97\x87E\xa0V\xf4/$\xc2\xecK\xed\xa0\xdb&\x829\x12\xe5\x9do\xa0\xe9\x1a\xd2\x19\x1e\xf5`\x95гb\xf8\x89\x81\xb7\xa5G\x18\x95\xf3\x9d9\xe8\x93B\x17!\x1a^\xccr\xbb`\xb2\xb4\xb86\x87\xb4h\x0e\xda\xc6u<+\x9e$̓\x95\xccSo\xea\xa4\xdbH!\xe9g\x8b\xd4\b\xb3hܬ\xa6Wk\x14He\xae\x8aPU\xaa\xc1\xee$\xfbH\xb3\xab.I\f<\x89\x06q\xe3-3-\x99\xcdݽ\xe5v\x99\xedn\xac\xadn\xe8Rp=\xb4nJ\xed\xd5\r\x8d\xde\x06Ζ\xf6\xb3\x01\x94\xcb\xf6\xd4\x19r\xe1\xaa$4+\xeaW\xa6F\xfa0\x97\x9cD\f\x8e\xd7\xd6z\v,G\xf3e2\xd4\xe6V\xba\v\xb6\xd9\xe8\xca*\x16\x95V\xa4J\xfbp\xddmF\x8c\x9a\xc6\xc8Č-\xdb\v\xf6\xf5\xf9\x02*\x15e\x874\xc9\xe7\"\xa3\x1an\xabq}ˊq\x957\xd3\xfd\xa91\x82\xe0Lß\\\x17\x8e\x9e_\xed`\t\xe9~5̕\x03\x9a\f\xddN6\xa2\xc4\x17\xdb\xc9V\x1c~\x9e\xea\xbe\xda-xv\xed\x8b\xe2\xc8DŽS\x95E6\xf2\xc3H\x1d:HPx\xc9\x14\xbfɒ\xff\xea\xb4P\x14\xa3\xe2\xfe\xfd\x1f+z\x80\x903\x81\x98\xf8\x15\xa3\x12\x16\xf8\"0g\xf7~B^\xfd \x040T\xa3\x02\x9c\x10\xc1\xa8F\xa0I#\xf1\xa3\x04\x98\x01\x91\xe2\x12\xdc;\x06gL\xd0g\xc0\xe3\xbd\xf6\xd7}&\xa8轀?\xbfяy`X\xf0\x92\x9f\x05\xf0*A8ρ\xac=K\xff\xf3\xfe\xa6Z\xe1\x1a\x017\xc2\x04\f\x94g\xa9\xf7-\xfb\xebqz\u007fz\u007f\xfa7\x00\x00\xff\xffPK\a\b\xf5\\\v\xe5I\x02\x00\x00\r\x03\x00\x00PK\x01\x02\x14\x00\x14\x00\b\x00\b\x00\x00\x00\x00\x00\xf5\\\v\xe5I\x02\x00\x00\r\x03\x00\x00\b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00last.logPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x006\x00\x00\x00\u007f\x02\x00\x00\x00\x00" //nolint[misspell]
|
||||
|
||||
// FIXME(conman): Implement bug reports then enable this test.
|
||||
func _TestClient_BugReportWithAttachment(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/reports/bug"))
|
||||
Ok(t, isAuthReq(r, testUID, testAccessToken))
|
||||
|
||||
Ok(t, r.ParseMultipartForm(10*1024))
|
||||
func TestClient_BugReportWithAttachment(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/reports/bug"))
|
||||
r.NoError(t, req.ParseMultipartForm(10*1024))
|
||||
|
||||
for field, expected := range map[string]string{
|
||||
"OS": testBugReportReq.OS,
|
||||
@ -76,60 +64,43 @@ func _TestClient_BugReportWithAttachment(t *testing.T) {
|
||||
"Username": testBugReportReq.Username,
|
||||
"Email": testBugReportReq.Email,
|
||||
} {
|
||||
if r.PostFormValue(field) != expected {
|
||||
t.Errorf("Field %q has %q but expected %q", field, r.PostFormValue(field), expected)
|
||||
}
|
||||
r.Equal(t, expected, req.PostFormValue(field))
|
||||
}
|
||||
|
||||
attReader, err := r.MultipartForm.File["log"][0].Open()
|
||||
Ok(t, err)
|
||||
|
||||
log, err := ioutil.ReadAll(attReader)
|
||||
Ok(t, err)
|
||||
|
||||
Equals(t, []byte(testAttachmentJSONZipped), log)
|
||||
attReader, err := req.MultipartForm.File["log"][0].Open()
|
||||
r.NoError(t, err)
|
||||
_, err = ioutil.ReadAll(attReader)
|
||||
r.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testBugsBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
cm := newManager(Config{HostURL: s.URL})
|
||||
cm := newManager(newTestConfig(s.URL))
|
||||
|
||||
rep := testBugReportReq
|
||||
rep.AddAttachment("log", "last.log", strings.NewReader(testAttachmentJSON))
|
||||
|
||||
Ok(t, cm.ReportBug(context.TODO(), rep))
|
||||
err := cm.ReportBug(context.Background(), rep)
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
// FIXME(conman): Implement bug reports then enable this test.
|
||||
func _TestClient_BugReport(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Ok(t, checkMethodAndPath(r, "POST", "/reports/bug"))
|
||||
Ok(t, isAuthReq(r, testUID, testAccessToken))
|
||||
func TestClient_BugReport(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
r.NoError(t, checkMethodAndPath(req, "POST", "/reports/bug"))
|
||||
|
||||
var bugsReportReq ReportBugReq
|
||||
Ok(t, json.NewDecoder(r.Body).Decode(&bugsReportReq))
|
||||
Equals(t, testBugReportReq, bugsReportReq)
|
||||
r.NoError(t, json.NewDecoder(req.Body).Decode(&bugsReportReq))
|
||||
r.Equal(t, testBugReportReq, bugsReportReq)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Fprint(w, testBugsBody)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
cm := newManager(Config{HostURL: s.URL})
|
||||
cm := newManager(newTestConfig(s.URL))
|
||||
|
||||
r := ReportBugReq{
|
||||
OS: testBugReportReq.OS,
|
||||
OSVersion: testBugReportReq.OSVersion,
|
||||
Browser: testBugReportReq.Browser,
|
||||
Title: testBugReportReq.Title,
|
||||
Description: testBugReportReq.Description,
|
||||
Username: testBugReportReq.Username,
|
||||
Email: testBugReportReq.Email,
|
||||
}
|
||||
|
||||
Ok(t, cm.ReportBug(context.TODO(), r))
|
||||
err := cm.ReportBug(context.Background(), testBugReportReq)
|
||||
r.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -1,12 +1,25 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ClientType is required by API.
|
||||
@ -47,8 +60,8 @@ func (rep *ReportBugReq) AddAttachment(name, filename string, r io.Reader) {
|
||||
rep.Attachments = append(rep.Attachments, reportAtt{name: name, filename: filename, body: r})
|
||||
}
|
||||
|
||||
func writeMultipartReport(w *multipart.Writer, rep *ReportBugReq) error { // nolint[funlen]
|
||||
fieldData := map[string]string{
|
||||
func (rep *ReportBugReq) GetMultipartFormData() map[string]string {
|
||||
return map[string]string{
|
||||
"OS": rep.OS,
|
||||
"OSVersion": rep.OSVersion,
|
||||
"Browser": rep.Browser,
|
||||
@ -58,7 +71,7 @@ func writeMultipartReport(w *multipart.Writer, rep *ReportBugReq) error { // nol
|
||||
"DisplayMode": rep.DisplayMode,
|
||||
"Client": rep.Client,
|
||||
"ClientVersion": rep.ClientVersion,
|
||||
"ClientType": "1",
|
||||
"ClientType": fmt.Sprintf("%d", rep.ClientType),
|
||||
"Title": rep.Title,
|
||||
"Description": rep.Description,
|
||||
"Username": rep.Username,
|
||||
@ -67,46 +80,4 @@ func writeMultipartReport(w *multipart.Writer, rep *ReportBugReq) error { // nol
|
||||
"ISP": rep.ISP,
|
||||
"Debug": rep.Debug,
|
||||
}
|
||||
|
||||
for field, data := range fieldData {
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
if err := w.WriteField(field, data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
quoteEscaper := strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
|
||||
|
||||
for _, att := range rep.Attachments {
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition",
|
||||
fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
|
||||
quoteEscaper.Replace(att.name), quoteEscaper.Replace(att.filename+".zip")))
|
||||
h.Set("Content-Type", "application/octet-stream")
|
||||
// h.Set("Content-Transfer-Encoding", "base64")
|
||||
attWr, err := w.CreatePart(h)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
zipArch := zip.NewWriter(attWr)
|
||||
zipWr, err := zipArch.Create(att.filename)
|
||||
// b64 := base64.NewEncoder(base64.StdEncoding, zipWr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.Copy(zipWr, att.body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = zipArch.Close()
|
||||
// err = b64.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1,16 +1,39 @@
|
||||
package pmapi_test
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testForceUpgradeBody = `{
|
||||
"Code":5003,
|
||||
"Error":"Upgrade!"
|
||||
}`
|
||||
|
||||
func TestHandleTooManyRequests(t *testing.T) {
|
||||
var numCalls int
|
||||
|
||||
@ -24,21 +47,17 @@ func TestHandleTooManyRequests(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should succeed because the 5th retry should succeed (429s are retried).
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The server should be called 5 times.
|
||||
// The first four calls should return 429 and the last call should return 200.
|
||||
if numCalls != 5 {
|
||||
t.Fatal("expected numCalls to be 5, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 5, numCalls)
|
||||
}
|
||||
|
||||
func TestHandleUnprocessableEntity(t *testing.T) {
|
||||
@ -49,27 +68,16 @@ func TestHandleUnprocessableEntity(t *testing.T) {
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
}))
|
||||
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should fail because the first call should fail (422s are not retried).
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
// API-side errors get ErrAPIFailure
|
||||
if !errors.Is(err, pmapi.ErrAPIFailure) {
|
||||
t.Fatal("expected error to be ErrAPIFailure, instead got", err)
|
||||
}
|
||||
|
||||
r.EqualError(t, err, "422 Unprocessable Entity")
|
||||
// The server should be called 1 time.
|
||||
// The first call should return 422.
|
||||
if numCalls != 1 {
|
||||
t.Fatal("expected numCalls to be 1, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 1, numCalls)
|
||||
}
|
||||
|
||||
func TestHandleDialFailure(t *testing.T) {
|
||||
@ -81,24 +89,17 @@ func TestHandleDialFailure(t *testing.T) {
|
||||
}))
|
||||
|
||||
// The failingRoundTripper will fail the first 5 times it is used.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set a custom transport.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(5))
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should succeed because the last retry should succeed (dial errors are retried).
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
|
||||
// The server should be called 1 time.
|
||||
// The first 4 attempts don't reach the server.
|
||||
if numCalls != 1 {
|
||||
t.Fatal("expected numCalls to be 1, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 1, numCalls)
|
||||
}
|
||||
|
||||
func TestHandleTooManyDialFailures(t *testing.T) {
|
||||
@ -112,28 +113,15 @@ func TestHandleTooManyDialFailures(t *testing.T) {
|
||||
// The failingRoundTripper will fail the first 10 times it is used.
|
||||
// This is more than the number of retries we permit.
|
||||
// Thus, dials will fail.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set a custom transport.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(10))
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should fail because every dial will fail and we'll run out of retries.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrNoConnection) {
|
||||
t.Fatal("expected error to be ErrNoConnection, instead got", err)
|
||||
}
|
||||
|
||||
r.EqualError(t, err, "no internet connection")
|
||||
// The server should never be called.
|
||||
if numCalls != 0 {
|
||||
t.Fatal("expected numCalls to be 0, instead got", numCalls)
|
||||
}
|
||||
r.Equal(t, 0, numCalls)
|
||||
}
|
||||
|
||||
func TestRetriesWithContextTimeout(t *testing.T) {
|
||||
@ -150,24 +138,16 @@ func TestRetriesWithContextTimeout(t *testing.T) {
|
||||
}))
|
||||
|
||||
// Theoretically, this should succeed; on the fifth retry, we'll get StatusOK.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set the retry count to 5.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// However, that will take ~5s, and we only allow 1s in the context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
// However, that will take ~0.5s, and we only allow 10ms in the context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Thus, it will fail.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("expected error to be DeadlineExceeded, instead got", err)
|
||||
}
|
||||
r.EqualError(t, err, context.DeadlineExceeded.Error())
|
||||
}
|
||||
|
||||
func TestObserveConnectionStatus(t *testing.T) {
|
||||
@ -177,36 +157,24 @@ func TestObserveConnectionStatus(t *testing.T) {
|
||||
|
||||
var onDown, onUp bool
|
||||
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
|
||||
// Set a custom transport.
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(10))
|
||||
|
||||
// Set the retry count to 5.
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// Add a connection observer.
|
||||
m.AddConnectionObserver(pmapi.NewConnectionObserver(func() { onDown = true }, func() { onUp = true }))
|
||||
m.AddConnectionObserver(NewConnectionObserver(func() { onDown = true }, func() { onUp = true }))
|
||||
|
||||
// The call should fail because every dial will fail and we'll run out of retries.
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
|
||||
if onDown != true || onUp == true {
|
||||
t.Fatal("expected onDown to have been called and onUp to not have been called")
|
||||
}
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.Error(t, err)
|
||||
r.False(t, onUp)
|
||||
r.True(t, onDown)
|
||||
|
||||
onDown, onUp = false, false
|
||||
|
||||
// The call should succeed because the last dial attempt will succeed.
|
||||
if _, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
}
|
||||
|
||||
if onDown == true || onUp != true {
|
||||
t.Fatal("expected onUp to have been called and onDown to not have been called")
|
||||
}
|
||||
_, err = m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.NoError(t, err)
|
||||
r.True(t, onUp)
|
||||
r.False(t, onDown)
|
||||
}
|
||||
|
||||
func TestReturnErrNoConnection(t *testing.T) {
|
||||
@ -215,19 +183,27 @@ func TestReturnErrNoConnection(t *testing.T) {
|
||||
}))
|
||||
|
||||
// We will fail more times than we retry, so requests should fail with ErrNoConnection.
|
||||
m := pmapi.New(pmapi.Config{HostURL: ts.URL})
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
m.SetTransport(newFailingRoundTripper(10))
|
||||
m.SetRetryCount(5)
|
||||
|
||||
// The call should fail because every dial will fail and we'll run out of retries.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error, instead got", err)
|
||||
}
|
||||
r.EqualError(t, err, "no internet connection")
|
||||
}
|
||||
|
||||
if !errors.Is(err, pmapi.ErrNoConnection) {
|
||||
t.Fatal("expected error to be ErrNoConnection, instead got", err)
|
||||
}
|
||||
func TestReturnErrUpgradeApplication(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("content-type", "application/json")
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
fmt.Fprint(w, testForceUpgradeBody)
|
||||
}))
|
||||
|
||||
m := New(Config{HostURL: ts.URL})
|
||||
|
||||
// The call should fail because every call return force upgrade error.
|
||||
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
|
||||
r.EqualError(t, err, ErrUpgradeApplication.Error())
|
||||
}
|
||||
|
||||
type failingRoundTripper struct {
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
@ -6,21 +23,24 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
NewClient(string, string, string, time.Time) Client
|
||||
NewClientWithRefresh(context.Context, string, string) (Client, *Auth, error)
|
||||
NewClientWithRefresh(context.Context, string, string) (Client, *AuthRefresh, error)
|
||||
NewClientWithLogin(context.Context, string, string) (Client, *Auth, error)
|
||||
|
||||
DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error)
|
||||
ReportBug(context.Context, ReportBugReq) error
|
||||
SendSimpleMetric(context.Context, string, string, string) error
|
||||
|
||||
SetLogger(resty.Logger)
|
||||
SetLogging(logger *logrus.Entry, verbose bool)
|
||||
SetTransport(http.RoundTripper)
|
||||
SetCookieJar(http.CookieJar)
|
||||
SetRetryCount(int)
|
||||
AddConnectionObserver(ConnectionObserver)
|
||||
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
}
|
||||
|
||||
@ -518,7 +518,20 @@ func (c *client) ListMessages(ctx context.Context, filter *MessagesFilter) ([]*M
|
||||
|
||||
// CountMessages counts messages by label.
|
||||
func (c *client) CountMessages(ctx context.Context, addressID string) (counts []*MessagesCount, err error) {
|
||||
panic("TODO")
|
||||
var res struct {
|
||||
Counts []*MessagesCount
|
||||
}
|
||||
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
if addressID != "" {
|
||||
r = r.SetQueryParam("AddressID", addressID)
|
||||
}
|
||||
return r.SetResult(&res).Get("/mail/v4/messages/count")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Counts, nil
|
||||
}
|
||||
|
||||
// GetMessage retrieves a message.
|
||||
@ -640,6 +653,10 @@ func (c *client) UnlabelMessages(ctx context.Context, messageIDs []string, label
|
||||
}
|
||||
|
||||
func (c *client) EmptyFolder(ctx context.Context, labelID, addressID string) error {
|
||||
if labelID == "" {
|
||||
return errors.New("labelID parameter is empty string")
|
||||
}
|
||||
|
||||
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
if addressID != "" {
|
||||
r.SetQueryParam("AddressID", addressID)
|
||||
|
||||
@ -24,7 +24,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testMessageCleartext = `<div>jeej saas<br></div><div><br></div><div class="protonmail_signature_block"><div>Sent from <a href="https://protonmail.ch">ProtonMail</a>, encrypted email based in Switzerland.<br></div><div><br></div></div>`
|
||||
@ -127,66 +128,65 @@ ClW54lp9eeOfYTsdTSbn9VaSO0E6m2/Q4Tk=
|
||||
|
||||
func TestMessage_IsBodyEncrypted(t *testing.T) {
|
||||
msg := &Message{Body: testMessageEncrypted}
|
||||
Assert(t, msg.IsBodyEncrypted(), "the body should be encrypted")
|
||||
r.True(t, msg.IsBodyEncrypted(), "the body should be encrypted")
|
||||
|
||||
msg.Body = testMessageCleartext
|
||||
Assert(t, !msg.IsBodyEncrypted(), "the body should not be encrypted")
|
||||
r.True(t, !msg.IsBodyEncrypted(), "the body should not be encrypted")
|
||||
}
|
||||
|
||||
func TestMessage_Decrypt(t *testing.T) {
|
||||
msg := &Message{Body: testMessageEncrypted}
|
||||
dec, err := msg.Decrypt(testPrivateKeyRing)
|
||||
Ok(t, err)
|
||||
Equals(t, testMessageCleartext, string(dec))
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testMessageCleartext, string(dec))
|
||||
}
|
||||
|
||||
func TestMessage_Decrypt_Legacy(t *testing.T) {
|
||||
testPrivateKeyLegacy := readTestFile("testPrivateKeyLegacy", false)
|
||||
|
||||
key, err := crypto.NewKeyFromArmored(testPrivateKeyLegacy)
|
||||
Ok(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
unlockedKey, err := key.Unlock([]byte(testMailboxPasswordLegacy))
|
||||
Ok(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
testPrivateKeyRingLegacy, err := crypto.NewKeyRing(unlockedKey)
|
||||
Ok(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
msg := &Message{Body: testMessageEncryptedLegacy}
|
||||
|
||||
dec, err := msg.Decrypt(testPrivateKeyRingLegacy)
|
||||
Ok(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
Equals(t, testMessageCleartextLegacy, string(dec))
|
||||
r.Equal(t, testMessageCleartextLegacy, string(dec))
|
||||
}
|
||||
|
||||
func TestMessage_Decrypt_signed(t *testing.T) {
|
||||
msg := &Message{Body: testMessageSigned}
|
||||
dec, err := msg.Decrypt(testPrivateKeyRing)
|
||||
Ok(t, err)
|
||||
Equals(t, testMessageCleartext, string(dec))
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, testMessageCleartext, string(dec))
|
||||
}
|
||||
|
||||
func TestMessage_Encrypt(t *testing.T) {
|
||||
key, err := crypto.NewKeyFromArmored(testMessageSigner)
|
||||
Ok(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
signer, err := crypto.NewKeyRing(key)
|
||||
Ok(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
msg := &Message{Body: testMessageCleartext}
|
||||
Ok(t, msg.Encrypt(testPrivateKeyRing, testPrivateKeyRing))
|
||||
r.NoError(t, msg.Encrypt(testPrivateKeyRing, testPrivateKeyRing))
|
||||
|
||||
dec, err := msg.Decrypt(testPrivateKeyRing)
|
||||
Ok(t, err)
|
||||
r.NoError(t, err)
|
||||
|
||||
Equals(t, testMessageCleartext, string(dec))
|
||||
Equals(t, testIdentity, signer.GetIdentities()[0])
|
||||
r.Equal(t, testMessageCleartext, string(dec))
|
||||
r.Equal(t, testIdentity, signer.GetIdentities()[0])
|
||||
}
|
||||
|
||||
func routeLabelMessages(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(tb, checkMethodAndPath(r, "PUT", "/mail/v4/messages/label"))
|
||||
|
||||
func routeLabelMessages(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(tb, checkMethodAndPath(req, "PUT", "/mail/v4/messages/label"))
|
||||
return "messages/label/put_response.json"
|
||||
}
|
||||
|
||||
@ -203,7 +203,7 @@ func TestMessage_LabelMessages_NoPaging(t *testing.T) {
|
||||
)
|
||||
defer finish()
|
||||
|
||||
assert.NoError(t, c.LabelMessages(context.TODO(), testIDs, "mylabel"))
|
||||
a.NoError(t, c.LabelMessages(context.Background(), testIDs, "mylabel"))
|
||||
}
|
||||
|
||||
func TestMessage_LabelMessages_Paging(t *testing.T) {
|
||||
@ -221,5 +221,5 @@ func TestMessage_LabelMessages_Paging(t *testing.T) {
|
||||
)
|
||||
defer finish()
|
||||
|
||||
assert.NoError(t, c.LabelMessages(context.TODO(), testIDs, "mylabel"))
|
||||
a.NoError(t, c.LabelMessages(context.Background(), testIDs, "mylabel"))
|
||||
}
|
||||
|
||||
@ -13,8 +13,8 @@ import (
|
||||
|
||||
crypto "github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
resty "github.com/go-resty/resty/v2"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface
|
||||
@ -40,16 +40,16 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddAuthHandler mocks base method
|
||||
func (m *MockClient) AddAuthHandler(arg0 pmapi.AuthHandler) {
|
||||
// AddAuthRefreshHandler mocks base method
|
||||
func (m *MockClient) AddAuthRefreshHandler(arg0 pmapi.AuthRefreshHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddAuthHandler", arg0)
|
||||
m.ctrl.Call(m, "AddAuthRefreshHandler", arg0)
|
||||
}
|
||||
|
||||
// AddAuthHandler indicates an expected call of AddAuthHandler
|
||||
func (mr *MockClientMockRecorder) AddAuthHandler(arg0 interface{}) *gomock.Call {
|
||||
// AddAuthRefreshHandler indicates an expected call of AddAuthRefreshHandler
|
||||
func (mr *MockClientMockRecorder) AddAuthRefreshHandler(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddAuthHandler", reflect.TypeOf((*MockClient)(nil).AddAuthHandler), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddAuthRefreshHandler", reflect.TypeOf((*MockClient)(nil).AddAuthRefreshHandler), arg0)
|
||||
}
|
||||
|
||||
// Addresses mocks base method
|
||||
@ -67,7 +67,7 @@ func (mr *MockClientMockRecorder) Addresses() *gomock.Call {
|
||||
}
|
||||
|
||||
// Auth2FA mocks base method
|
||||
func (m *MockClient) Auth2FA(arg0 context.Context, arg1 pmapi.Auth2FAReq) error {
|
||||
func (m *MockClient) Auth2FA(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Auth2FA", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
@ -616,6 +616,30 @@ func (mr *MockManagerMockRecorder) AddConnectionObserver(arg0 interface{}) *gomo
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConnectionObserver", reflect.TypeOf((*MockManager)(nil).AddConnectionObserver), arg0)
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method
|
||||
func (m *MockManager) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy
|
||||
func (mr *MockManagerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockManager)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method
|
||||
func (m *MockManager) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy
|
||||
func (mr *MockManagerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockManager)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// DownloadAndVerify mocks base method
|
||||
func (m *MockManager) DownloadAndVerify(arg0 *crypto.KeyRing, arg1, arg2 string) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@ -662,11 +686,11 @@ func (mr *MockManagerMockRecorder) NewClientWithLogin(arg0, arg1, arg2 interface
|
||||
}
|
||||
|
||||
// NewClientWithRefresh mocks base method
|
||||
func (m *MockManager) NewClientWithRefresh(arg0 context.Context, arg1, arg2 string) (pmapi.Client, *pmapi.Auth, error) {
|
||||
func (m *MockManager) NewClientWithRefresh(arg0 context.Context, arg1, arg2 string) (pmapi.Client, *pmapi.AuthRefresh, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "NewClientWithRefresh", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(pmapi.Client)
|
||||
ret1, _ := ret[1].(*pmapi.Auth)
|
||||
ret1, _ := ret[1].(*pmapi.AuthRefresh)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
@ -717,16 +741,16 @@ func (mr *MockManagerMockRecorder) SetCookieJar(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCookieJar", reflect.TypeOf((*MockManager)(nil).SetCookieJar), arg0)
|
||||
}
|
||||
|
||||
// SetLogger mocks base method
|
||||
func (m *MockManager) SetLogger(arg0 resty.Logger) {
|
||||
// SetLogging mocks base method
|
||||
func (m *MockManager) SetLogging(arg0 *logrus.Entry, arg1 bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetLogger", arg0)
|
||||
m.ctrl.Call(m, "SetLogging", arg0, arg1)
|
||||
}
|
||||
|
||||
// SetLogger indicates an expected call of SetLogger
|
||||
func (mr *MockManagerMockRecorder) SetLogger(arg0 interface{}) *gomock.Call {
|
||||
// SetLogging indicates an expected call of SetLogging
|
||||
func (mr *MockManagerMockRecorder) SetLogging(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLogger", reflect.TypeOf((*MockManager)(nil).SetLogger), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLogging", reflect.TypeOf((*MockManager)(nil).SetLogging), arg0, arg1)
|
||||
}
|
||||
|
||||
// SetRetryCount mocks base method
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
type ConnectionObserver interface {
|
||||
|
||||
@ -1,25 +0,0 @@
|
||||
-- addresses.go
|
||||
-- attachments.go
|
||||
-- auth.go
|
||||
-- contacts.go
|
||||
-- events.go
|
||||
-- import.go
|
||||
-- key.go
|
||||
-- keyring.go
|
||||
-- labels.go
|
||||
-- manager_auth.go
|
||||
-- manager_download.go
|
||||
-- manager.go
|
||||
-- manager_metrics.go
|
||||
-- manager_ping.go
|
||||
-- manager_report.go
|
||||
-- manager_report_types.go
|
||||
-- manager_types.go
|
||||
-- message_send.go
|
||||
-- messages.go
|
||||
-- metrics.go
|
||||
-- observer.go
|
||||
-- passwords.go
|
||||
-- settings.go
|
||||
-- users.go
|
||||
-- utils.go
|
||||
@ -1,8 +1,25 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
const defaultPageSize = 100
|
||||
|
||||
func doPaged(elements []string, pageSize int, fn func([]string) error) error {
|
||||
func doPaged(elements []string, pageSize int, fn func([]string) error) error { //nolint[unparam]
|
||||
for len(elements) > pageSize {
|
||||
if err := fn(elements[:pageSize]); err != nil {
|
||||
return err
|
||||
|
||||
24
pkg/pmapi/pmapi.go
Normal file
24
pkg/pmapi/pmapi.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var log = logrus.WithField("pkg", "pmapi") //nolint[gochecknoglobals]
|
||||
@ -1,12 +1,35 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
errCodeUpgradeApplication = 5003
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
@ -18,26 +41,34 @@ func (err Error) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func catchAPIError(_ *resty.Client, res *resty.Response) error {
|
||||
func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
|
||||
if !res.IsError() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if res.StatusCode() == http.StatusUnauthorized {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
if apiErr, ok := res.Error().(*Error); ok {
|
||||
err = apiErr
|
||||
switch {
|
||||
case apiErr.Code == errCodeUpgradeApplication:
|
||||
err = ErrUpgradeApplication
|
||||
if m.cfg.UpgradeApplicationHandler != nil {
|
||||
m.cfg.UpgradeApplicationHandler()
|
||||
}
|
||||
case res.StatusCode() == http.StatusUnprocessableEntity:
|
||||
err = ErrUnprocessableEntity{apiErr}
|
||||
default:
|
||||
err = apiErr
|
||||
}
|
||||
} else {
|
||||
err = errors.New(res.Status())
|
||||
}
|
||||
|
||||
switch res.StatusCode() {
|
||||
case http.StatusUnauthorized:
|
||||
return errors.Wrap(ErrUnauthorized, err.Error())
|
||||
|
||||
default:
|
||||
return errors.Wrap(ErrAPIFailure, err.Error())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) {
|
||||
@ -45,38 +76,47 @@ func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error
|
||||
if after := res.Header().Get("Retry-After"); after != "" {
|
||||
seconds, err := strconv.Atoi(after)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
log.WithError(err).Warning("Cannot convert Retry-After to number")
|
||||
seconds = 10
|
||||
}
|
||||
|
||||
// To avoid spikes when all clients retry at the same time, we add some random wait.
|
||||
seconds += rand.Intn(10) //nolint[gosec] It is OK to use weak random number generator here.
|
||||
|
||||
log.Warningf("Retrying %s after %ds induced by http code %d", res.Request.URL, seconds, res.StatusCode())
|
||||
return time.Duration(seconds) * time.Second, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 0 and no error means default behaviour which is exponential backoff with jitter.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func catchTooManyRequests(res *resty.Response, _ error) bool {
|
||||
func shouldRetry(res *resty.Response, err error) bool {
|
||||
if isRetryDisabled(res.Request.Context()) {
|
||||
return false
|
||||
}
|
||||
return isTooManyRequest(res) || isNoResponse(res, err)
|
||||
}
|
||||
|
||||
func isTooManyRequest(res *resty.Response) bool {
|
||||
return res.StatusCode() == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
func catchNoResponse(res *resty.Response, err error) bool {
|
||||
func isNoResponse(res *resty.Response, err error) bool {
|
||||
return res.RawResponse == nil && err != nil
|
||||
}
|
||||
|
||||
func catchProxyAvailable(res *resty.Response, err error) bool {
|
||||
/*
|
||||
if res.Request.Attempt < ... {
|
||||
return false
|
||||
}
|
||||
func wrapNoConnection(res *resty.Response, err error) (*resty.Response, error) {
|
||||
if err, ok := err.(*resty.ResponseError); ok {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if response is not empty {
|
||||
return false
|
||||
}
|
||||
if res.RawResponse != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if proxy is available {
|
||||
return true
|
||||
}
|
||||
*/
|
||||
|
||||
return false
|
||||
// Log useful information and return back nicer and clear error message.
|
||||
logrus.WithError(err).WithField("url", res.Request.URL).Warn("No internet connection")
|
||||
return res, ErrNoConnection
|
||||
}
|
||||
|
||||
@ -24,7 +24,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@ -32,6 +31,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -40,36 +40,6 @@ var (
|
||||
reHTTPCode = regexp.MustCompile(`(HTTP|get|post|put|delete)_(\d{3}).*.json`)
|
||||
)
|
||||
|
||||
// Assert fails the test if the condition is false.
|
||||
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
|
||||
if !condition {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
vv := []interface{}{filepath.Base(file), line, colRed}
|
||||
vv = append(vv, v...)
|
||||
vv = append(vv, colNon)
|
||||
fmt.Printf("%s:%d: %s"+msg+"%s\n\n", vv...)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Ok fails the test if an err is not nil.
|
||||
func Ok(tb testing.TB, err error) {
|
||||
if err != nil {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("%s:%d: %sunexpected error: %s%s\n\n", filepath.Base(file), line, colRed, err.Error(), colNon)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Equals fails the test if exp is not equal to act.
|
||||
func Equals(tb testing.TB, exp, act interface{}) {
|
||||
if !reflect.DeepEqual(exp, act) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("%s:%d:\n\n%s\texp: %#v\n\n\tgot: %#v%s\n\n", filepath.Base(file), line, colRed, exp, act, colNon)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func newTestConfig(url string) Config {
|
||||
return Config{
|
||||
HostURL: url,
|
||||
@ -77,7 +47,7 @@ func newTestConfig(url string) Config {
|
||||
}
|
||||
}
|
||||
|
||||
// newTestClient is old function and should be replaced everywhere by newTestServerCallbacks.
|
||||
// newTestClient is old function and should be replaced everywhere by newTestClientCallbacks.
|
||||
func newTestClient(h http.Handler) (*httptest.Server, Client) {
|
||||
s := httptest.NewServer(h)
|
||||
|
||||
@ -93,7 +63,7 @@ func newTestClientCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.Re
|
||||
reqNum++
|
||||
if reqNum > len(callbacks) {
|
||||
fmt.Printf(
|
||||
"%s:%d: %sServer was requeted %d times which is more requests than expected %d%s\n\n",
|
||||
"%s:%d: %sServer was requested %d times which is more requests than expected %d times%s\n\n",
|
||||
file, line, colRed, reqNum, len(callbacks), colNon,
|
||||
)
|
||||
tb.FailNow()
|
||||
@ -134,22 +104,18 @@ func checkMethodAndPath(r *http.Request, method, path string) error {
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func httpResponse(code int) string {
|
||||
return fmt.Sprintf("HTTP_%d.json", code)
|
||||
}
|
||||
|
||||
func writeJSONResponsefromFile(tb testing.TB, w http.ResponseWriter, response string, reqNum int) {
|
||||
if match := reHTTPCode.FindAllSubmatch([]byte(response), -1); len(match) != 0 {
|
||||
httpCode, err := strconv.Atoi(string(match[0][len(match[0])-1]))
|
||||
Ok(tb, err)
|
||||
r.NoError(tb, err)
|
||||
w.WriteHeader(httpCode)
|
||||
}
|
||||
f, err := os.Open("./testdata/routes/" + response)
|
||||
Ok(tb, err)
|
||||
r.NoError(tb, err)
|
||||
w.Header().Set("content-type", "application/json;charset=utf-8")
|
||||
w.Header().Set("x-test-pmapi-response", fmt.Sprintf("%s:%d", tb.Name(), reqNum))
|
||||
_, err = io.Copy(w, f)
|
||||
Ok(tb, err)
|
||||
r.NoError(tb, err)
|
||||
}
|
||||
|
||||
func checkHeader(h http.Header, field, exp string) error {
|
||||
|
||||
@ -1,8 +0,0 @@
|
||||
package pmapi
|
||||
|
||||
type Boolean int
|
||||
|
||||
const (
|
||||
False Boolean = iota
|
||||
True
|
||||
)
|
||||
@ -42,22 +42,12 @@ var testCurrentUser = &User{
|
||||
Keys: *loadPMKeys(readTestFile("keyring_userKey_JSON", false)),
|
||||
}
|
||||
|
||||
func routeGetUsers(tb testing.TB, w http.ResponseWriter, r *http.Request) string {
|
||||
Ok(tb, checkMethodAndPath(r, "GET", "/users"))
|
||||
Ok(tb, isAuthReq(r, testUID, testAccessToken))
|
||||
|
||||
func routeGetUsers(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
|
||||
r.NoError(tb, checkMethodAndPath(req, "GET", "/users"))
|
||||
r.NoError(tb, isAuthReq(req, testUID, testAccessToken))
|
||||
return "users/get_response.json"
|
||||
}
|
||||
|
||||
const testPublicKeysBody = `{
|
||||
"Code": 1000,
|
||||
"RecipientType": 1,
|
||||
"MIMEType": "text/html",
|
||||
"Keys": [
|
||||
{ "Flags": 3, "PublicKey": "-----BEGIN PGP PUBLIC KEY BLOCK-----\nVersion: OpenPGP.js v0.7.1\nComment: http://openpgpjs.org\n\nxsBNBFSI0BMBB/9td6B5RDzVSFTlFzYOS4JxIb5agtNW1rbA4FeLoC47bGLR\n8E42IA6aKcO4H0vOZ1lFms0URiKk1DjCMXn3AUErbxqiV5IATRZLwliH6vwy\nPI6j5rtGF8dyxYfwmLtoUNkDcPdcFEb4NCdowsN7e8tKU0bcpouZcQhAqawC\n9nEdaG/gS5w+2k4hZX2lOKS1EF5SvP48UadlspEK2PLAIp5wB9XsFS9ey2wu\nelzkSfDh7KUAlteqFGSMqIgYH62/gaKm+TcckfZeyiMHWFw6sfrcFQ3QOZPq\nahWt0Rn9XM5xBAxx5vW0oceuQ1vpvdfFlM5ix4gn/9w6MhmStaCee8/fABEB\nAAHNBlVzZXJJRMLAcgQQAQgAJgUCVIjQHQYLCQgHAwIJEASDR1Fk7GNTBBUI\nAgoDFgIBAhsDAh4BAADmhAf/Yt0mCfWqQ25NNGUN14pKKgnPm68zwj1SmMGa\npU7+7ItRpoFNaDwV5QYiQSLC1SvSb1ZeKoY928GPKfqYyJlBpTPL9zC1OHQj\n9+2yYauHjYW9JWQM7hst2S2LBcdiQPOs3ybWPaO9yaccV4thxKOCPvyClaS5\nb9T4Iv9GEVZQIUvArkwI8hyzIi6skRgxflGheq1O+S1W4Gzt2VtYvo8g8r6W\nGzAGMw2nrs2h0+vUr+dLDgIbFCTc5QU99d5jE/e5Hw8iqBxv9tqB1hVATf8T\nwC8aU5MTtxtabOiBgG0PsBs6oIwjFqEjpOIza2/AflPZfo7stp6IiwbwvTHo\n1NlHoM7ATQRUiNAdAQf/eOLJYxX4lUQUzrNQgASDNE8gJPj7ywcGzySyqr0Y\n5rbG57EjtKMIgZrpzJRpSCuRbBjfsltqJ5Q9TBAbPO+oR3rue0LqPKMnmr/q\nKsHswBJRfsb/dbktUNmv/f7R9IVyOuvyP6RgdGeloxdGNeWiZSA6AZYI+WGc\nxaOvVDPz8thtnML4G4MUhXxxNZ7JzQ0Lfz6mN8CCkblIP5xpcJsyRU7lUsGD\nEJGZX0JH/I8bRVN1Xu08uFinIkZyiXRJ5ZGgF3Dns6VbIWmbttY54tBELtk+\n5g9pNSl9qiYwiCdwuZrA//NmD3xlZIN8sG4eM7ZUibZ23vEq+bUt1++6Mpba\nGQARAQABwsBfBBgBCAATBQJUiNAfCRAEg0dRZOxjUwIbDAAAlpMH/085qZdO\nmGRAlbvViUNhF2rtHvCletC48WHGO1ueSh9VTxalkP21YAYLJ4JgJzArJ7tH\nlEeiKiHm8YU9KhLe11Yv/o3AiKIAQjJiQluvk+mWdMcddB4fBjL6ttMTRAXe\ngHnjtMoamHbSZdeUTUadv05Fl6ivWtpXlODG4V02YvDiGBUbDosdGXEqDtpT\ng6MYlj3QMvUiUNQvt7YGMJS8A9iQ9qBNzErgRW8L6CON2RmpQ/wgwP5nwUHz\nJjY51d82Vj8bZeI8LdsX41SPoUhyC7kmNYpw9ZRy7NlrCt8dBIOB4/BKEJ2G\nClW54lp9eeOfYTsdTSbn9VaSO0E6m2/Q4Tk=\n=WFtr\n-----END PGP PUBLIC KEY BLOCK-----"},
|
||||
{ "Flags": 1, "PublicKey": "-----BEGIN PGP PUBLIC KEY BLOCK-----\nVersion: OpenPGP.js v0.7.1\nComment: http://openpgpjs.org\n\nxsBNBFSI0BMBB/9td6B5RDzVSFTlFzYOS4JxIb5agtNW1rbA4FeLoC47bGLR\n8E42IA6aKcO4H0vOZ1lFms0URiKk1DjCMXn3AUErbxqiV5IATRZLwliH6vwy\nPI6j5rtGF8dyxYfwmLtoUNkDcPdcFEb4NCdowsN7e8tKU0bcpouZcQhAqawC\n9nEdaG/gS5w+2k4hZX2lOKS1EF5SvP48UadlspEK2PLAIp5wB9XsFS9ey2wu\nelzkSfDh7KUAlteqFGSMqIgYH62/gaKm+TcckfZeyiMHWFw6sfrcFQ3QOZPq\nahWt0Rn9XM5xBAxx5vW0oceuQ1vpvdfFlM5ix4gn/9w6MhmStaCee8/fABEB\nAAHNBlVzZXJJRMLAcgQQAQgAJgUCVIjQHQYLCQgHAwIJEASDR1Fk7GNTBBUI\nAgoDFgIBAhsDAh4BAADmhAf/Yt0mCfWqQ25NNGUN14pKKgnPm68zwj1SmMGa\npU7+7ItRpoFNaDwV5QYiQSLC1SvSb1ZeKoY928GPKfqYyJlBpTPL9zC1OHQj\n9+2yYauHjYW9JWQM7hst2S2LBcdiQPOs3ybWPaO9yaccV4thxKOCPvyClaS5\nb9T4Iv9GEVZQIUvArkwI8hyzIi6skRgxflGheq1O+S1W4Gzt2VtYvo8g8r6W\nGzAGMw2nrs2h0+vUr+dLDgIbFCTc5QU99d5jE/e5Hw8iqBxv9tqB1hVATf8T\nwC8aU5MTtxtabOiBgG0PsBs6oIwjFqEjpOIza2/AflPZfo7stp6IiwbwvTHo\n1NlHoM7ATQRUiNAdAQf/eOLJYxX4lUQUzrNQgASDNE8gJPj7ywcGzySyqr0Y\n5rbG57EjtKMIgZrpzJRpSCuRbBjfsltqJ5Q9TBAbPO+oR3rue0LqPKMnmr/q\nKsHswBJRfsb/dbktUNmv/f7R9IVyOuvyP6RgdGeloxdGNeWiZSA6AZYI+WGc\nxaOvVDPz8thtnML4G4MUhXxxNZ7JzQ0Lfz6mN8CCkblIP5xpcJsyRU7lUsGD\nEJGZX0JH/I8bRVN1Xu08uFinIkZyiXRJ5ZGgF3Dns6VbIWmbttY54tBELtk+\n5g9pNSl9qiYwiCdwuZrA//NmD3xlZIN8sG4eM7ZUibZ23vEq+bUt1++6Mpba\nGQARAQABwsBfBBgBCAATBQJUiNAfCRAEg0dRZOxjUwIbDAAAlpMH/085qZdO\nmGRAlbvViUNhF2rtHvCletC48WHGO1ueSh9VTxalkP21YAYLJ4JgJzArJ7tH\nlEeiKiHm8YU9KhLe11Yv/o3AiKIAQjJiQluvk+mWdMcddB4fBjL6ttMTRAXe\ngHnjtMoamHbSZdeUTUadv05Fl6ivWtpXlODG4V02YvDiGBUbDosdGXEqDtpT\ng6MYlj3QMvUiUNQvt7YGMJS8A9iQ9qBNzErgRW8L6CON2RmpQ/wgwP5nwUHz\nJjY51d82Vj8bZeI8LdsX41SPoUhyC7kmNYpw9ZRy7NlrCt8dBIOB4/BKEJ2G\nClW54lp9eeOfYTsdTSbn9VaSO0E6m2/Q4Tk=\n=WFtr\n-----END PGP PUBLIC KEY BLOCK-----"}
|
||||
]}`
|
||||
|
||||
func TestClient_CurrentUser(t *testing.T) {
|
||||
finish, c := newTestClientCallbacks(t,
|
||||
routeGetUsers,
|
||||
@ -65,11 +55,11 @@ func TestClient_CurrentUser(t *testing.T) {
|
||||
)
|
||||
defer finish()
|
||||
|
||||
user, err := c.CurrentUser(context.TODO())
|
||||
user, err := c.CurrentUser(context.Background())
|
||||
r.Nil(t, err)
|
||||
|
||||
// Ignore KeyRings during the check because they have unexported fields and cannot be compared
|
||||
r.True(t, cmp.Equal(user, testCurrentUser, cmpopts.IgnoreTypes(&crypto.Key{})))
|
||||
|
||||
r.Nil(t, c.Unlock(context.TODO(), []byte(testMailboxPassword)))
|
||||
r.Nil(t, c.Unlock(context.Background(), []byte(testMailboxPassword)))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user