feat: switch to proxy when need be

This commit is contained in:
James Houlahan
2020-04-01 17:20:03 +02:00
parent f239e8f3bf
commit ce29d4d74e
36 changed files with 311 additions and 320 deletions

View File

@ -274,8 +274,11 @@ func run(context *cli.Context) (contextError error) { // nolint[funlen]
log.Error("Could not get credentials store: ", credentialsError) log.Error("Could not get credentials store: ", credentialsError)
} }
clientman := pmapi.NewClientManager(pmapifactory.GetClientConfig(cfg, eventListener)) clientConfig := pmapifactory.GetClientConfig(cfg.GetAPIConfig())
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, clientman, credentialsStore) cm := pmapi.NewClientManager(clientConfig)
pmapifactory.SetClientRoundTripper(cm, clientConfig, eventListener)
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, cm, credentialsStore)
imapBackend := imap.NewIMAPBackend(panicHandler, eventListener, cfg, bridgeInstance) imapBackend := imap.NewIMAPBackend(panicHandler, eventListener, cfg, bridgeInstance)
smtpBackend := smtp.NewSMTPBackend(panicHandler, eventListener, pref, bridgeInstance) smtpBackend := smtp.NewSMTPBackend(panicHandler, eventListener, pref, bridgeInstance)

View File

@ -97,7 +97,7 @@ func New(
// Allow DoH before starting bridge if the user has previously set this setting. // Allow DoH before starting bridge if the user has previously set this setting.
// This allows us to start even if protonmail is blocked. // This allows us to start even if protonmail is blocked.
if pref.GetBool(preferences.AllowProxyKey) { if pref.GetBool(preferences.AllowProxyKey) {
AllowDoH() b.AllowProxy()
} }
go func() { go func() {
@ -178,15 +178,16 @@ func (b *Bridge) watchBridgeOutdated() {
} }
} }
// watchUserAuths receives auths from the client manager and sends them to the appropriate user.
func (b *Bridge) watchUserAuths() { func (b *Bridge) watchUserAuths() {
for auth := range b.clientManager.GetBridgeAuthChannel() { for auth := range b.clientManager.GetBridgeAuthChannel() {
user, ok := b.hasUser(auth.UserID) logrus.WithField("token", auth.Auth.GenToken()).WithField("userID", auth.UserID).Info("Received auth from bridge auth channel")
if !ok { if user, ok := b.hasUser(auth.UserID); ok {
continue user.ReceiveAPIAuth(auth.Auth)
} else {
logrus.Info("User is not added to bridge yet")
} }
user.ReceiveAPIAuth(auth.Auth)
} }
} }
@ -274,7 +275,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
apiClient := b.clientManager.GetClient(apiUser.ID) apiClient := b.clientManager.GetClient(apiUser.ID)
auth, err = apiClient.AuthRefresh(auth.GenToken()) auth, err = apiClient.AuthRefresh(auth.GenToken())
if err != nil { if err != nil {
log.WithError(err).Error("Could refresh token in new client") log.WithError(err).Error("Could not refresh token in new client")
return return
} }
@ -298,6 +299,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user") log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
return return
} }
b.users = append(b.users, user)
} }
// Set up the user auth and store (which we do for both new and existing users). // Set up the user auth and store (which we do for both new and existing users).
@ -307,7 +309,6 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
} }
if !hasUser { if !hasUser {
b.users = append(b.users, user)
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel)) b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel))
} }
@ -475,16 +476,16 @@ func (b *Bridge) GetIMAPUpdatesChannel() chan interface{} {
return b.idleUpdates return b.idleUpdates
} }
// AllowDoH instructs bridge to use DoH to access an API proxy if necessary. // AllowProxy instructs bridge to use DoH to access an API proxy if necessary.
// It also needs to work before bridge is initialised (because we may need to use the proxy at startup). // It also needs to work before bridge is initialised (because we may need to use the proxy at startup).
func AllowDoH() { func (b *Bridge) AllowProxy() {
pmapi.GlobalAllowDoH() b.clientManager.AllowProxy()
} }
// DisallowDoH instructs bridge to not use DoH to access an API proxy if necessary. // DisallowProxy instructs bridge to not use DoH to access an API proxy if necessary.
// It also needs to work before bridge is initialised (because we may need to use the proxy at startup). // It also needs to work before bridge is initialised (because we may need to use the proxy at startup).
func DisallowDoH() { func (b *Bridge) DisallowProxy() {
pmapi.GlobalDisallowDoH() b.clientManager.DisallowProxy()
} }
func (b *Bridge) updateCurrentUserAgent() { func (b *Bridge) updateCurrentUserAgent() {
@ -493,7 +494,11 @@ func (b *Bridge) updateCurrentUserAgent() {
// hasUser returns whether the bridge currently has a user with ID `id`. // hasUser returns whether the bridge currently has a user with ID `id`.
func (b *Bridge) hasUser(id string) (user *User, ok bool) { func (b *Bridge) hasUser(id string) (user *User, ok bool) {
logrus.WithField("id", id).Info("Checking whether bridge has given user")
for _, u := range b.users { for _, u := range b.users {
logrus.WithField("id", u.ID()).Info("Found potential user")
if u.ID() == id { if u.ID() == id {
user, ok = u, true user, ok = u, true
return return

View File

@ -107,6 +107,8 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
u.wasKeyringUnlocked = false u.wasKeyringUnlocked = false
u.unlockingKeyringLock.Unlock() u.unlockingKeyringLock.Unlock()
u.log.Info("Initialising user")
// Reload the user's credentials (if they log out and back in we need the new // Reload the user's credentials (if they log out and back in we need the new
// version with the apitoken and mailbox password). // version with the apitoken and mailbox password).
creds, err := u.credStorer.Get(u.userID) creds, err := u.credStorer.Get(u.userID)
@ -242,27 +244,19 @@ func (u *User) authorizeAndUnlock() (err error) {
} }
func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) { func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) {
u.lock.Lock()
defer u.lock.Unlock()
if auth == nil { if auth == nil {
if err := u.logout(); err != nil { if err := u.logout(); err != nil {
u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API") u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API")
} }
u.isAuthorized = false u.isAuthorized = false
return return
} }
u.updateAPIToken(auth.GenToken()) if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
} u.log.WithError(err).Error("Failed to update refresh token in credentials store")
// updateAPIToken is helper for updating the token in keychain. It's not supposed to be
// called directly from other parts of the code, only from `ReceiveAPIAuth`.
func (u *User) updateAPIToken(newRefreshToken string) {
u.lock.Lock()
defer u.lock.Unlock()
u.log.WithField("token", newRefreshToken).Info("Saving token to credentials store")
if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil {
u.log.WithError(err).Error("Cannot update refresh token in credentials store")
return return
} }

View File

@ -22,7 +22,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/pkg/connection" "github.com/ProtonMail/proton-bridge/pkg/connection"
"github.com/ProtonMail/proton-bridge/pkg/ports" "github.com/ProtonMail/proton-bridge/pkg/ports"
@ -135,13 +134,13 @@ func (f *frontendCLI) toggleAllowProxy(c *ishell.Context) {
f.Println("Bridge is currently set to use alternative routing to connect to Proton if it is being blocked.") f.Println("Bridge is currently set to use alternative routing to connect to Proton if it is being blocked.")
if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") { if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") {
f.preferences.SetBool(preferences.AllowProxyKey, false) f.preferences.SetBool(preferences.AllowProxyKey, false)
bridge.DisallowDoH() f.bridge.DisallowProxy()
} }
} else { } else {
f.Println("Bridge is currently set to NOT use alternative routing to connect to Proton if it is being blocked.") f.Println("Bridge is currently set to NOT use alternative routing to connect to Proton if it is being blocked.")
if f.yesNoQuestion("Are you sure you want to allow bridge to do this") { if f.yesNoQuestion("Are you sure you want to allow bridge to do this") {
f.preferences.SetBool(preferences.AllowProxyKey, true) f.preferences.SetBool(preferences.AllowProxyKey, true)
bridge.AllowDoH() f.bridge.AllowProxy()
} }
} }
} }

View File

@ -52,6 +52,8 @@ type Bridger interface {
DeleteUser(userID string, clearCache bool) error DeleteUser(userID string, clearCache bool) error
ReportBug(osType, osVersion, description, accountName, address, emailClient string) error ReportBug(osType, osVersion, description, accountName, address, emailClient string) error
ClearData() error ClearData() error
AllowProxy()
DisallowProxy()
} }
// BridgeUser is an interface of user needed by frontend. // BridgeUser is an interface of user needed by frontend.

View File

@ -21,11 +21,14 @@
package pmapifactory package pmapifactory
import ( import (
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
) )
func GetClientConfig(config bridge.Configer, _ listener.Listener) *pmapi.ClientConfig { func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
return config.GetAPIConfig() return clientConfig
}
func SetClientRoundTripper(_ *pmapi.ClientManager, _ *pmapi.ClientConfig, _ listener.Listener) {
// Use the default roundtripper; do nothing.
} }

View File

@ -23,26 +23,13 @@ package pmapifactory
import ( import (
"time" "time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/sirupsen/logrus"
) )
func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.ClientConfig { func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
clientConfig := config.GetAPIConfig()
pin := pmapi.NewPMAPIPinning(clientConfig.AppVersion)
pin.ReportCertIssueLocal = func() {
listener.Emit(events.TLSCertIssue, "")
}
// This transport already has timeouts set governing the roundtrip:
// - IdleConnTimeout: 5 * time.Minute,
// - ExpectContinueTimeout: 500 * time.Millisecond,
// - ResponseHeaderTimeout: 30 * time.Second,
clientConfig.Transport = pin.TransportWithPinning()
// We set additional timeouts/thresholds for the request as a whole: // We set additional timeouts/thresholds for the request as a whole:
clientConfig.Timeout = 10 * time.Minute // Overall request timeout (~25MB / 10 mins => ~40kB/s, should be reasonable). clientConfig.Timeout = 10 * time.Minute // Overall request timeout (~25MB / 10 mins => ~40kB/s, should be reasonable).
clientConfig.FirstReadTimeout = 30 * time.Second // 30s to match 30s response header timeout. clientConfig.FirstReadTimeout = 30 * time.Second // 30s to match 30s response header timeout.
@ -50,3 +37,15 @@ func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.
return clientConfig return clientConfig
} }
func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) {
logrus.Info("Setting dialer with pinning")
pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion)
pin.ReportCertIssueLocal = func() {
listener.Emit(events.TLSCertIssue, "")
}
cm.SetClientRoundTripper(pin.TransportWithPinning())
}

View File

@ -39,7 +39,8 @@ var (
// Two errors can be returned, ErrNoInternetConnection or ErrCanNotReachAPI. // Two errors can be returned, ErrNoInternetConnection or ErrCanNotReachAPI.
func CheckInternetConnection() error { func CheckInternetConnection() error {
client := &http.Client{ client := &http.Client{
Transport: pmapi.NewPMAPIPinning(pmapi.CurrentUserAgent).TransportWithPinning(), // TODO: Set transport properly! (Need access to ClientManager somehow)
// Transport: pmapi.NewDialerWithPinning(pmapi.CurrentUserAgent).TransportWithPinning(),
} }
// Do not cumulate timeouts, use goroutines. // Do not cumulate timeouts, use goroutines.
@ -51,7 +52,8 @@ func CheckInternetConnection() error {
go checkConnection(client, "http://protonstatus.com/vpn_status", retStatus) go checkConnection(client, "http://protonstatus.com/vpn_status", retStatus)
// Check of API reachability also uses a fast endpoint. // Check of API reachability also uses a fast endpoint.
go checkConnection(client, pmapi.GlobalGetRootURL()+"/tests/ping", retAPI) // TODO: This should check active proxy, not the RootURL
go checkConnection(client, pmapi.RootURL+"/tests/ping", retAPI)
errStatus := <-retStatus errStatus := <-retStatus
errAPI := <-retAPI errAPI := <-retAPI

View File

@ -36,7 +36,7 @@ type osxkeychain struct {
} }
func newKeychain() (credentials.Helper, error) { func newKeychain() (credentials.Helper, error) {
log.Debug("creating osckeychain") log.Debug("Creating osckeychain")
return &osxkeychain{}, nil return &osxkeychain{}, nil
} }

View File

@ -24,14 +24,14 @@ import (
) )
func newKeychain() (credentials.Helper, error) { func newKeychain() (credentials.Helper, error) {
log.Debug("creating pass") log.Debug("Creating pass")
passHelper := &pass.Pass{} passHelper := &pass.Pass{}
passErr := checkPassIsUsable(passHelper) passErr := checkPassIsUsable(passHelper)
if passErr == nil { if passErr == nil {
return passHelper, nil return passHelper, nil
} }
log.Debug("creating secretservice") log.Debug("Creating secretservice")
sserviceHelper := &secretservice.Secretservice{} sserviceHelper := &secretservice.Secretservice{}
_, sserviceErr := sserviceHelper.List() _, sserviceErr := sserviceHelper.List()
if sserviceErr == nil { if sserviceErr == nil {

View File

@ -23,7 +23,7 @@ import (
) )
func newKeychain() (credentials.Helper, error) { func newKeychain() (credentials.Helper, error) {
log.Debug("creating wincred") log.Debug("Creating wincred")
return &wincred.Wincred{}, nil return &wincred.Wincred{}, nil
} }

View File

@ -161,7 +161,7 @@ func ConstructAddress(headerEmail string, addressEmail string) string {
// GetAddresses requests all of current user addresses (without pagination). // GetAddresses requests all of current user addresses (without pagination).
func (c *Client) GetAddresses() (addresses AddressList, err error) { func (c *Client) GetAddresses() (addresses AddressList, err error) {
req, err := NewRequest("GET", "/addresses", nil) req, err := c.NewRequest("GET", "/addresses", nil)
if err != nil { if err != nil {
return return
} }

View File

@ -179,7 +179,7 @@ func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.R
// //
// The returned created attachment contains the new attachment ID and its size. // The returned created attachment contains the new attachment ID and its size.
func (c *Client) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) { func (c *Client) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) {
req, w, err := NewMultipartRequest("POST", "/attachments") req, w, err := c.NewMultipartRequest("POST", "/attachments")
if err != nil { if err != nil {
return return
} }
@ -213,7 +213,7 @@ type UpdateAttachmentSignatureReq struct {
func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err error) { func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err error) {
updateReq := &UpdateAttachmentSignatureReq{signature} updateReq := &UpdateAttachmentSignatureReq{signature}
req, err := NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq) req, err := c.NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq)
if err != nil { if err != nil {
return return
} }
@ -228,7 +228,7 @@ func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err
// DeleteAttachment removes an attachment. message is the message ID, att is the attachment ID. // DeleteAttachment removes an attachment. message is the message ID, att is the attachment ID.
func (c *Client) DeleteAttachment(attID string) (err error) { func (c *Client) DeleteAttachment(attID string) (err error) {
req, err := NewRequest("DELETE", "/attachments/"+attID, nil) req, err := c.NewRequest("DELETE", "/attachments/"+attID, nil)
if err != nil { if err != nil {
return return
} }
@ -249,7 +249,7 @@ func (c *Client) GetAttachment(id string) (att io.ReadCloser, err error) {
return return
} }
req, err := NewRequest("GET", "/attachments/"+id, nil) req, err := c.NewRequest("GET", "/attachments/"+id, nil)
if err != nil { if err != nil {
return return
} }

View File

@ -214,7 +214,7 @@ func (c *Client) AuthInfo(username string) (info *AuthInfo, err error) {
Username: username, Username: username,
} }
req, err := NewJSONRequest("POST", "/auth/info", infoReq) req, err := c.NewJSONRequest("POST", "/auth/info", infoReq)
if err != nil { if err != nil {
return return
} }
@ -257,7 +257,7 @@ func (c *Client) tryAuth(username, password string, info *AuthInfo, fallbackVers
SRPSession: info.srpSession, SRPSession: info.srpSession,
} }
req, err := NewJSONRequest("POST", "/auth", authReq) req, err := c.NewJSONRequest("POST", "/auth", authReq)
if err != nil { if err != nil {
return return
} }
@ -335,7 +335,7 @@ func (c *Client) Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) {
TwoFactorCode: twoFactorCode, TwoFactorCode: twoFactorCode,
} }
req, err := NewJSONRequest("POST", "/auth/2fa", auth2FAReq) req, err := c.NewJSONRequest("POST", "/auth/2fa", auth2FAReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -430,7 +430,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
// UID must be set for `x-pm-uid` header field, see backend-communication#11 // UID must be set for `x-pm-uid` header field, see backend-communication#11
c.uid = split[0] c.uid = split[0]
req, err := NewJSONRequest("POST", "/auth/refresh", refreshReq) req, err := c.NewJSONRequest("POST", "/auth/refresh", refreshReq)
if err != nil { if err != nil {
return return
} }
@ -450,13 +450,14 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
return auth, err return auth, err
} }
// Logout instructs the client manager to log out this client.
func (c *Client) Logout() { func (c *Client) Logout() {
c.cm.LogoutClient(c.userID) c.cm.LogoutClient(c.userID)
} }
// logout logs the current user out. // logout logs the current user out.
func (c *Client) logout() (err error) { func (c *Client) logout() (err error) {
req, err := NewRequest("DELETE", "/auth", nil) req, err := c.NewRequest("DELETE", "/auth", nil)
if err != nil { if err != nil {
return return
} }

View File

@ -332,7 +332,9 @@ func TestClient_Logout(t *testing.T) {
c.uid = testUID c.uid = testUID
c.accessToken = testAccessToken c.accessToken = testAccessToken
Ok(t, c.Logout()) c.Logout()
// TODO: Check that the client is logged out and sensitive data is cleared eventually.
} }
func TestClient_DoUnauthorized(t *testing.T) { func TestClient_DoUnauthorized(t *testing.T) {
@ -355,7 +357,7 @@ func TestClient_DoUnauthorized(t *testing.T) {
c.expiresAt = aLongTimeAgo c.expiresAt = aLongTimeAgo
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
req, err := NewRequest("GET", "/", nil) req, err := c.NewRequest("GET", "/", nil)
Ok(t, err) Ok(t, err)
res, err := c.Do(req, true) res, err := c.Do(req, true)

View File

@ -139,9 +139,9 @@ func (c *Client) Report(rep ReportReq) (err error) {
var req *http.Request var req *http.Request
var w *MultipartWriter var w *MultipartWriter
if len(rep.Attachments) > 0 { if len(rep.Attachments) > 0 {
req, w, err = NewMultipartRequest("POST", "/reports/bug") req, w, err = c.NewMultipartRequest("POST", "/reports/bug")
} else { } else {
req, err = NewJSONRequest("POST", "/reports/bug", rep) req, err = c.NewJSONRequest("POST", "/reports/bug", rep)
} }
if err != nil { if err != nil {
return return
@ -202,7 +202,7 @@ func (c *Client) ReportCrash(stacktrace string) (err error) {
OS: runtime.GOOS, OS: runtime.GOOS,
Debug: stacktrace, Debug: stacktrace,
} }
req, err := NewJSONRequest("POST", "/reports/crash", crashReq) req, err := c.NewJSONRequest("POST", "/reports/crash", crashReq)
if err != nil { if err != nil {
return return
} }

View File

@ -99,12 +99,12 @@ type ClientConfig struct {
// Client to communicate with API. // Client to communicate with API.
type Client struct { type Client struct {
cm *ClientManager cm *ClientManager
client *http.Client hc *http.Client
uid string uid string
accessToken string accessToken string
userID string // Twice here because Username is not unique. userID string
requestLocker sync.Locker requestLocker sync.Locker
keyLocker sync.Locker keyLocker sync.Locker
@ -120,7 +120,7 @@ type Client struct {
func newClient(cm *ClientManager, userID string) *Client { func newClient(cm *ClientManager, userID string) *Client {
return &Client{ return &Client{
cm: cm, cm: cm,
client: getHTTPClient(cm.GetConfig()), hc: getHTTPClient(cm.GetConfig()),
userID: userID, userID: userID,
requestLocker: &sync.Mutex{}, requestLocker: &sync.Mutex{},
keyLocker: &sync.Mutex{}, keyLocker: &sync.Mutex{},
@ -132,12 +132,10 @@ func newClient(cm *ClientManager, userID string) *Client {
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) { func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
hc = &http.Client{Timeout: cfg.Timeout} hc = &http.Client{Timeout: cfg.Timeout}
if cfg.Transport == nil && defaultTransport == nil { if cfg.Transport == nil {
return if defaultTransport != nil {
} hc.Transport = defaultTransport
}
if defaultTransport != nil {
hc.Transport = defaultTransport
return return
} }
@ -205,7 +203,7 @@ func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthori
} }
hasBody := len(bodyBuffer) > 0 hasBody := len(bodyBuffer) > 0
if res, err = c.client.Do(req); err != nil { if res, err = c.hc.Do(req); err != nil {
if res == nil { if res == nil {
c.log.WithError(err).Error("Cannot get response") c.log.WithError(err).Error("Cannot get response")
err = ErrAPINotReachable err = ErrAPINotReachable

View File

@ -51,7 +51,7 @@ func TestClient_Do(t *testing.T) {
})) }))
defer s.Close() defer s.Close()
req, err := NewRequest("GET", "/", nil) req, err := c.NewRequest("GET", "/", nil)
if err != nil { if err != nil {
t.Fatal("Expected no error while creating request, got:", err) t.Fatal("Expected no error while creating request, got:", err)
} }
@ -163,8 +163,8 @@ func TestClient_FirstReadTimeout(t *testing.T) {
) )
defer finish() defer finish()
c.client.Transport = &slowTransport{ c.hc.Transport = &slowTransport{
transport: c.client.Transport, transport: c.hc.Transport,
firstBodySleep: requestTimeout, firstBodySleep: requestTimeout,
} }

View File

@ -1,24 +1,32 @@
package pmapi package pmapi
import ( import (
"net/http"
"sync" "sync"
"github.com/getsentry/raven-go" "github.com/getsentry/raven-go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// ClientManager is a manager of clients. // ClientManager is a manager of clients.
type ClientManager struct { type ClientManager struct {
config *ClientConfig
clients map[string]*Client clients map[string]*Client
clientsLocker sync.Locker clientsLocker sync.Locker
tokens map[string]string tokens map[string]string
tokensLocker sync.Locker tokensLocker sync.Locker
config *ClientConfig url string
urlLocker sync.Locker
bridgeAuths chan ClientAuth bridgeAuths chan ClientAuth
clientAuths chan ClientAuth clientAuths chan ClientAuth
allowProxy bool
proxyProvider *proxyProvider
} }
type ClientAuth struct { type ClientAuth struct {
@ -33,13 +41,21 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
} }
cm = &ClientManager{ cm = &ClientManager{
config: config,
clients: make(map[string]*Client), clients: make(map[string]*Client),
clientsLocker: &sync.Mutex{}, clientsLocker: &sync.Mutex{},
tokens: make(map[string]string),
tokensLocker: &sync.Mutex{}, tokens: make(map[string]string),
config: config, tokensLocker: &sync.Mutex{},
bridgeAuths: make(chan ClientAuth),
clientAuths: make(chan ClientAuth), url: RootURL,
urlLocker: &sync.Mutex{},
bridgeAuths: make(chan ClientAuth),
clientAuths: make(chan ClientAuth),
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
} }
go cm.forwardClientAuths() go cm.forwardClientAuths()
@ -47,6 +63,12 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
return return
} }
// SetClientRoundTripper sets the roundtripper used by clients created by this client manager.
func (cm *ClientManager) SetClientRoundTripper(rt http.RoundTripper) {
logrus.Info("Setting client roundtripper")
cm.config.Transport = rt
}
// GetClient returns a client for the given userID. // GetClient returns a client for the given userID.
// If the client does not exist already, it is created. // If the client does not exist already, it is created.
func (cm *ClientManager) GetClient(userID string) *Client { func (cm *ClientManager) GetClient(userID string) *Client {
@ -71,7 +93,7 @@ func (cm *ClientManager) LogoutClient(userID string) {
go func() { go func() {
if err := client.logout(); err != nil { if err := client.logout(); err != nil {
// TODO: Try again! // TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
logrus.WithError(err).Error("Client logout failed, not trying again") logrus.WithError(err).Error("Client logout failed, not trying again")
} }
client.clearSensitiveData() client.clearSensitiveData()
@ -81,6 +103,56 @@ func (cm *ClientManager) LogoutClient(userID string) {
return return
} }
// GetRootURL returns the root URL to make requests to.
// It does not include the protocol i.e. no "https://".
func (cm *ClientManager) GetRootURL() string {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
return cm.url
}
// SetRootURL sets the root URL to make requests to.
func (cm *ClientManager) SetRootURL(url string) {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
logrus.WithField("url", url).Info("Changing to a new root URL")
cm.url = url
}
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
func (cm *ClientManager) IsProxyAllowed() bool {
return cm.allowProxy
}
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
func (cm *ClientManager) AllowProxy() {
cm.allowProxy = true
}
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
func (cm *ClientManager) DisallowProxy() {
cm.allowProxy = false
}
// FindProxy returns a usable proxy server.
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
logrus.Info("Attempting gto switch to a proxy")
if proxy, err = cm.proxyProvider.findProxy(); err != nil {
err = errors.Wrap(err, "failed to find usable proxy")
return
}
cm.SetRootURL(proxy)
// TODO: Disable after 24 hours.
return
}
// GetConfig returns the config used to configure clients. // GetConfig returns the config used to configure clients.
func (cm *ClientManager) GetConfig() *ClientConfig { func (cm *ClientManager) GetConfig() *ClientConfig {
return cm.config return cm.config
@ -113,7 +185,7 @@ func (cm *ClientManager) setToken(userID, token string) {
cm.tokensLocker.Lock() cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock() defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).WithField("token", token).Info("Updating refresh token") logrus.WithField("userID", userID).Info("Updating refresh token")
cm.tokens[userID] = token cm.tokens[userID] = token
} }
@ -127,16 +199,19 @@ func (cm *ClientManager) clearToken(userID string) {
delete(cm.tokens, userID) delete(cm.tokens, userID)
} }
// handleClientAuth // handleClientAuth updates or clears client authorisation based on auths received.
func (cm *ClientManager) handleClientAuth(ca ClientAuth) { func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
// TODO: Maybe want to logout the client in case of nil auth. // If we aren't managing this client, there's nothing to do.
if _, ok := cm.clients[ca.UserID]; !ok { if _, ok := cm.clients[ca.UserID]; !ok {
return return
} }
// If the auth is nil, we should clear the token.
// TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself.
if ca.Auth == nil { if ca.Auth == nil {
cm.clearToken(ca.UserID) cm.clearToken(ca.UserID)
} else { return
cm.setToken(ca.UserID, ca.Auth.GenToken())
} }
cm.setToken(ca.UserID, ca.Auth.GenToken())
} }

View File

@ -24,9 +24,9 @@ import (
// RootURL is the API root URL. // RootURL is the API root URL.
// //
// This can be changed using build flags: pmapi_local for "http://localhost/api", // This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
// pmapi_dev or pmapi_prod. Default is pmapi_prod. // Default is pmapi_prod.
var RootURL = "https://api.protonmail.ch" //nolint[gochecknoglobals] var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program // CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
// version and email client. // version and email client.

View File

@ -20,5 +20,5 @@
package pmapi package pmapi
func init() { func init() {
RootURL = "https://dev.protonmail.com/api" RootURL = "dev.protonmail.com/api"
} }

View File

@ -27,7 +27,7 @@ import (
func init() { func init() {
// Use port above 1000 which doesn't need root access to start anything on it. // Use port above 1000 which doesn't need root access to start anything on it.
// Now the port is rounded pi. :-) // Now the port is rounded pi. :-)
RootURL = "http://127.0.0.1:3142/api" RootURL = "127.0.0.1:3142/api"
// TLS certificate is self-signed // TLS certificate is self-signed
defaultTransport = &http.Transport{ defaultTransport = &http.Transport{

View File

@ -119,7 +119,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e
if pageSize > 0 { if pageSize > 0 {
v.Set("PageSize", strconv.Itoa(pageSize)) v.Set("PageSize", strconv.Itoa(pageSize))
} }
req, err := NewRequest("GET", "/contacts?"+v.Encode(), nil) req, err := c.NewRequest("GET", "/contacts?"+v.Encode(), nil)
if err != nil { if err != nil {
return return
@ -136,7 +136,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e
// GetContactByID gets contact details specified by contact ID. // GetContactByID gets contact details specified by contact ID.
func (c *Client) GetContactByID(id string) (contactDetail Contact, err error) { func (c *Client) GetContactByID(id string) (contactDetail Contact, err error) {
req, err := NewRequest("GET", "/contacts/"+id, nil) req, err := c.NewRequest("GET", "/contacts/"+id, nil)
if err != nil { if err != nil {
return return
@ -164,7 +164,7 @@ func (c *Client) GetContactsForExport(page int, pageSize int) (contacts []Contac
v.Set("PageSize", strconv.Itoa(pageSize)) v.Set("PageSize", strconv.Itoa(pageSize))
} }
req, err := NewRequest("GET", "/contacts/export?"+v.Encode(), nil) req, err := c.NewRequest("GET", "/contacts/export?"+v.Encode(), nil)
if err != nil { if err != nil {
return return
@ -198,7 +198,7 @@ func (c *Client) GetAllContactsEmails(page int, pageSize int) (contactsEmails []
v.Set("PageSize", strconv.Itoa(pageSize)) v.Set("PageSize", strconv.Itoa(pageSize))
} }
req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil) req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
if err != nil { if err != nil {
return return
} }
@ -221,7 +221,7 @@ func (c *Client) GetContactEmailByEmail(email string, page int, pageSize int) (c
} }
v.Set("Email", email) v.Set("Email", email)
req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil) req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
if err != nil { if err != nil {
return return
} }
@ -276,7 +276,7 @@ func (c *Client) AddContacts(cards ContactsCards, overwrite int, groups int, lab
Labels: labels, Labels: labels,
} }
req, err := NewJSONRequest("POST", "/contacts", reqBody) req, err := c.NewJSONRequest("POST", "/contacts", reqBody)
if err != nil { if err != nil {
return return
} }
@ -306,7 +306,7 @@ func (c *Client) UpdateContact(id string, cards []Card) (res *UpdateContactRespo
reqBody := UpdateContactReq{ reqBody := UpdateContactReq{
Cards: cards, Cards: cards,
} }
req, err := NewJSONRequest("PUT", "/contacts/"+id, reqBody) req, err := c.NewJSONRequest("PUT", "/contacts/"+id, reqBody)
if err != nil { if err != nil {
return return
} }
@ -354,7 +354,7 @@ func (c *Client) modifyContactGroups(groupID string, modifyContactGroupsAction i
Action: modifyContactGroupsAction, Action: modifyContactGroupsAction,
ContactEmailIDs: contactEmailIDs, ContactEmailIDs: contactEmailIDs,
} }
req, err := NewJSONRequest("PUT", "/contacts/group", reqBody) req, err := c.NewJSONRequest("PUT", "/contacts/group", reqBody)
if err != nil { if err != nil {
return return
} }
@ -377,7 +377,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) {
IDs: ids, IDs: ids,
} }
req, err := NewJSONRequest("PUT", "/contacts/delete", deleteReq) req, err := c.NewJSONRequest("PUT", "/contacts/delete", deleteReq)
if err != nil { if err != nil {
return return
} }
@ -402,7 +402,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) {
// DeleteAllContacts deletes all contacts. // DeleteAllContacts deletes all contacts.
func (c *Client) DeleteAllContacts() (err error) { func (c *Client) DeleteAllContacts() (err error) {
req, err := NewRequest("DELETE", "/contacts", nil) req, err := c.NewRequest("DELETE", "/contacts", nil)
if err != nil { if err != nil {
return return
} }

View File

@ -36,7 +36,7 @@ func (c *Client) CountConversations(addressID string) (counts []*ConversationsCo
if addressID != "" { if addressID != "" {
reqURL += ("?AddressID=" + addressID) reqURL += ("?AddressID=" + addressID)
} }
req, err := NewRequest("GET", reqURL, nil) req, err := c.NewRequest("GET", reqURL, nil)
if err != nil { if err != nil {
return return
} }

View File

@ -112,51 +112,47 @@ type DialerWithPinning struct {
// It is used only if set. // It is used only if set.
ReportCertIssueLocal func() ReportCertIssueLocal func()
// proxyManager manages API proxies. // cm is used to find and switch to a proxy if necessary.
proxyManager *proxyManager cm *ClientManager
// A logger for logging messages. // A logger for logging messages.
log logrus.FieldLogger log logrus.FieldLogger
} }
func NewDialerWithPinning(reportURI string, report TLSReport) *DialerWithPinning { // NewDialerWithPinning constructs a new dialer with pinned certs.
func NewDialerWithPinning(cm *ClientManager, appVersion string) *DialerWithPinning {
reportURI := "https://reports.protonmail.ch/reports/tls"
report := TLSReport{
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
IncludeSubdomains: false,
ValidatedCertificateChain: []string{},
ServedCertificateChain: []string{},
AppVersion: appVersion,
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
KnownPins: []string{
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
},
}
log := logrus.WithField("pkg", "pmapi/tls-pinning") log := logrus.WithField("pkg", "pmapi/tls-pinning")
proxyManager := newProxyManager(dohProviders, proxyQuery)
return &DialerWithPinning{ return &DialerWithPinning{
isReported: false, cm: cm,
reportURI: reportURI, isReported: false,
report: report, reportURI: reportURI,
proxyManager: proxyManager, report: report,
log: log, log: log,
} }
} }
func NewPMAPIPinning(appVersion string) *DialerWithPinning {
return NewDialerWithPinning(
"https://reports.protonmail.ch/reports/tls",
TLSReport{
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
IncludeSubdomains: false,
ValidatedCertificateChain: []string{},
ServedCertificateChain: []string{},
AppVersion: appVersion,
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
KnownPins: []string{
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
},
},
)
}
func (p *DialerWithPinning) reportCertIssue(connState tls.ConnectionState) { func (p *DialerWithPinning) reportCertIssue(connState tls.ConnectionState) {
p.isReported = true p.isReported = true
@ -231,6 +227,7 @@ func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
return pemCerts return pemCerts
} }
// TransportWithPinning creates an http.Transport that checks fingerprints when dialing.
func (p *DialerWithPinning) TransportWithPinning() *http.Transport { func (p *DialerWithPinning) TransportWithPinning() *http.Transport {
return &http.Transport{ return &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
@ -258,7 +255,7 @@ func (p *DialerWithPinning) TransportWithPinning() *http.Transport {
// p.ReportCertIssueLocal() and p.reportCertIssueRemote() if they are not nil. // p.ReportCertIssueLocal() and p.reportCertIssueRemote() if they are not nil.
func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (conn net.Conn, err error) { func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (conn net.Conn, err error) {
// If DoH is enabled, we hardfail on fingerprint mismatches. // If DoH is enabled, we hardfail on fingerprint mismatches.
if globalIsDoHAllowed() && p.isReported { if p.cm.IsProxyAllowed() && p.isReported {
return nil, ErrTLSMatch return nil, ErrTLSMatch
} }
@ -283,6 +280,8 @@ func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (c
// dialWithProxyFallback tries to dial the given address but falls back to alternative proxies if need be. // dialWithProxyFallback tries to dial the given address but falls back to alternative proxies if need be.
func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) { func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) {
p.log.Info("Dialing with proxy fallback")
var host, port string var host, port string
if host, port, err = net.SplitHostPort(address); err != nil { if host, port, err = net.SplitHostPort(address); err != nil {
return return
@ -296,21 +295,18 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn
// If DoH is not allowed, give up. Or, if we are dialing something other than the API // If DoH is not allowed, give up. Or, if we are dialing something other than the API
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in // (e.g. we dial protonmail.com/... to check for updates), there's also no point in
// continuing since a proxy won't help us reach that. // continuing since a proxy won't help us reach that.
if !globalIsDoHAllowed() || host != stripProtocol(GlobalGetRootURL()) { if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() {
p.log.WithField("useProxy", p.cm.IsProxyAllowed()).Info("Dial failed but not switching to proxy")
return return
} }
// Find a new proxy. // Switch to a proxy and retry the dial.
var proxy string var proxy string
if proxy, err = p.proxyManager.findProxy(); err != nil {
if proxy, err = p.cm.SwitchToProxy(); err != nil {
return return
} }
// Switch to the proxy.
p.log.WithField("proxy", proxy).Debug("Switching to proxy")
p.proxyManager.useProxy(proxy)
// Retry dial with proxy.
return p.dial(network, net.JoinHostPort(proxy, port)) return p.dial(network, net.JoinHostPort(proxy, port))
} }
@ -329,7 +325,7 @@ func (p *DialerWithPinning) dial(network, address string) (conn net.Conn, err er
// If we are not dialing the standard API then we should skip cert verification checks. // If we are not dialing the standard API then we should skip cert verification checks.
var tlsConfig *tls.Config = nil var tlsConfig *tls.Config = nil
if address != stripProtocol(globalOriginalURL) { if address != RootURL {
tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec] tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
} }

View File

@ -179,7 +179,7 @@ func (c *Client) GetEvent(last string) (event *Event, err error) {
func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event, err error) { func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event, err error) {
var req *http.Request var req *http.Request
if last == "" { if last == "" {
req, err = NewRequest("GET", "/events/latest", nil) req, err = c.NewRequest("GET", "/events/latest", nil)
if err != nil { if err != nil {
return return
} }
@ -191,7 +191,7 @@ func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event,
event, err = res.Event, res.Err() event, err = res.Event, res.Err()
} else { } else {
req, err = NewRequest("GET", "/events/"+last, nil) req, err = c.NewRequest("GET", "/events/"+last, nil)
if err != nil { if err != nil {
return return
} }

View File

@ -120,7 +120,7 @@ type ImportMsgRes struct {
func (c *Client) Import(reqs []*ImportMsgReq) (resps []*ImportMsgRes, err error) { func (c *Client) Import(reqs []*ImportMsgReq) (resps []*ImportMsgRes, err error) {
importReq := &ImportReq{Messages: reqs} importReq := &ImportReq{Messages: reqs}
req, w, err := NewMultipartRequest("POST", "/import") req, w, err := c.NewMultipartRequest("POST", "/import")
if err != nil { if err != nil {
return return
} }

View File

@ -57,7 +57,7 @@ func (c *Client) PublicKeys(emails []string) (keys map[string]*pmcrypto.KeyRing,
email = url.QueryEscape(email) email = url.QueryEscape(email)
var req *http.Request var req *http.Request
if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil { if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
return return
} }
@ -90,7 +90,7 @@ func (c *Client) GetPublicKeysForEmail(email string) (keys []PublicKey, internal
email = url.QueryEscape(email) email = url.QueryEscape(email)
var req *http.Request var req *http.Request
if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil { if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
return return
} }
@ -123,7 +123,7 @@ type KeySaltRes struct {
// GetKeySalts sends request to get list of key salts (n.b. locked route). // GetKeySalts sends request to get list of key salts (n.b. locked route).
func (c *Client) GetKeySalts() (keySalts []KeySalt, err error) { func (c *Client) GetKeySalts() (keySalts []KeySalt, err error) {
var req *http.Request var req *http.Request
if req, err = NewRequest("GET", "/keys/salts", nil); err != nil { if req, err = c.NewRequest("GET", "/keys/salts", nil); err != nil {
return return
} }

View File

@ -103,7 +103,7 @@ func (c *Client) ListContactGroups() (labels []*Label, err error) {
// ListLabelType lists all labels created by the user. // ListLabelType lists all labels created by the user.
func (c *Client) ListLabelType(labelType int) (labels []*Label, err error) { func (c *Client) ListLabelType(labelType int) (labels []*Label, err error) {
req, err := NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil) req, err := c.NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil)
if err != nil { if err != nil {
return return
} }
@ -129,7 +129,7 @@ type LabelRes struct {
// CreateLabel creates a new label. // CreateLabel creates a new label.
func (c *Client) CreateLabel(label *Label) (created *Label, err error) { func (c *Client) CreateLabel(label *Label) (created *Label, err error) {
labelReq := &LabelReq{label} labelReq := &LabelReq{label}
req, err := NewJSONRequest("POST", "/labels", labelReq) req, err := c.NewJSONRequest("POST", "/labels", labelReq)
if err != nil { if err != nil {
return return
} }
@ -146,7 +146,7 @@ func (c *Client) CreateLabel(label *Label) (created *Label, err error) {
// UpdateLabel updates a label. // UpdateLabel updates a label.
func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) { func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) {
labelReq := &LabelReq{label} labelReq := &LabelReq{label}
req, err := NewJSONRequest("PUT", "/labels/"+label.ID, labelReq) req, err := c.NewJSONRequest("PUT", "/labels/"+label.ID, labelReq)
if err != nil { if err != nil {
return return
} }
@ -162,7 +162,7 @@ func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) {
// DeleteLabel deletes a label. // DeleteLabel deletes a label.
func (c *Client) DeleteLabel(id string) (err error) { func (c *Client) DeleteLabel(id string) (err error) {
req, err := NewRequest("DELETE", "/labels/"+id, nil) req, err := c.NewRequest("DELETE", "/labels/"+id, nil)
if err != nil { if err != nil {
return return
} }

View File

@ -468,7 +468,7 @@ type MessagesListRes struct {
// ListMessages gets message metadata. // ListMessages gets message metadata.
func (c *Client) ListMessages(filter *MessagesFilter) (msgs []*Message, total int, err error) { func (c *Client) ListMessages(filter *MessagesFilter) (msgs []*Message, total int, err error) {
req, err := NewRequest("GET", "/messages", nil) req, err := c.NewRequest("GET", "/messages", nil)
if err != nil { if err != nil {
return return
} }
@ -500,7 +500,7 @@ func (c *Client) CountMessages(addressID string) (counts []*MessagesCount, err e
if addressID != "" { if addressID != "" {
reqURL += ("?AddressID=" + addressID) reqURL += ("?AddressID=" + addressID)
} }
req, err := NewRequest("GET", reqURL, nil) req, err := c.NewRequest("GET", reqURL, nil)
if err != nil { if err != nil {
return return
} }
@ -522,7 +522,7 @@ type MessageRes struct {
// GetMessage retrieves a message. // GetMessage retrieves a message.
func (c *Client) GetMessage(id string) (msg *Message, err error) { func (c *Client) GetMessage(id string) (msg *Message, err error) {
req, err := NewRequest("GET", "/messages/"+id, nil) req, err := c.NewRequest("GET", "/messages/"+id, nil)
if err != nil { if err != nil {
return return
} }
@ -599,7 +599,7 @@ func (c *Client) SendMessage(id string, sendReq *SendMessageReq) (sent, parent *
sendReq.Packages = []*MessagePackage{} sendReq.Packages = []*MessagePackage{}
} }
req, err := NewJSONRequest("POST", "/messages/"+id, sendReq) req, err := c.NewJSONRequest("POST", "/messages/"+id, sendReq)
if err != nil { if err != nil {
return return
} }
@ -629,7 +629,7 @@ type DraftReq struct {
func (c *Client) CreateDraft(m *Message, parent string, action int) (created *Message, err error) { func (c *Client) CreateDraft(m *Message, parent string, action int) (created *Message, err error) {
createReq := &DraftReq{Message: m, ParentID: parent, Action: action, AttachmentKeyPackets: []string{}} createReq := &DraftReq{Message: m, ParentID: parent, Action: action, AttachmentKeyPackets: []string{}}
req, err := NewJSONRequest("POST", "/messages", createReq) req, err := c.NewJSONRequest("POST", "/messages", createReq)
if err != nil { if err != nil {
return return
} }
@ -688,7 +688,7 @@ func (c *Client) doMessagesAction(action string, ids []string) (err error) {
// You should not call this directly unless you know what you are doing (it can overload the server). // You should not call this directly unless you know what you are doing (it can overload the server).
func (c *Client) doMessagesActionInner(action string, ids []string) (err error) { func (c *Client) doMessagesActionInner(action string, ids []string) (err error) {
actionReq := &MessagesActionReq{IDs: ids} actionReq := &MessagesActionReq{IDs: ids}
req, err := NewJSONRequest("PUT", "/messages/"+action, actionReq) req, err := c.NewJSONRequest("PUT", "/messages/"+action, actionReq)
if err != nil { if err != nil {
return return
} }
@ -740,7 +740,7 @@ func (c *Client) LabelMessages(ids []string, label string) (err error) {
func (c *Client) labelMessages(ids []string, label string) (err error) { func (c *Client) labelMessages(ids []string, label string) (err error) {
labelReq := &LabelMessagesReq{LabelID: label, IDs: ids} labelReq := &LabelMessagesReq{LabelID: label, IDs: ids}
req, err := NewJSONRequest("PUT", "/messages/label", labelReq) req, err := c.NewJSONRequest("PUT", "/messages/label", labelReq)
if err != nil { if err != nil {
return return
} }
@ -770,7 +770,7 @@ func (c *Client) UnlabelMessages(ids []string, label string) (err error) {
func (c *Client) unlabelMessages(ids []string, label string) (err error) { func (c *Client) unlabelMessages(ids []string, label string) (err error) {
labelReq := &LabelMessagesReq{LabelID: label, IDs: ids} labelReq := &LabelMessagesReq{LabelID: label, IDs: ids}
req, err := NewJSONRequest("PUT", "/messages/unlabel", labelReq) req, err := c.NewJSONRequest("PUT", "/messages/unlabel", labelReq)
if err != nil { if err != nil {
return return
} }
@ -793,7 +793,7 @@ func (c *Client) EmptyFolder(labelID, addressID string) (err error) {
reqURL += ("&AddressID=" + addressID) reqURL += ("&AddressID=" + addressID)
} }
req, err := NewRequest("DELETE", reqURL, nil) req, err := c.NewRequest("DELETE", reqURL, nil)
if err != nil { if err != nil {
return return

View File

@ -28,7 +28,7 @@ func (c *Client) SendSimpleMetric(category, action, label string) (err error) {
v.Set("Action", action) v.Set("Action", action)
v.Set("Label", label) v.Set("Label", label)
req, err := NewRequest("GET", "/metrics?"+v.Encode(), nil) req, err := c.NewRequest("GET", "/metrics?"+v.Encode(), nil)
if err != nil { if err != nil {
return return
} }

View File

@ -21,7 +21,6 @@ import (
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"strings" "strings"
"sync"
"time" "time"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
@ -43,63 +42,8 @@ var dohProviders = []string{ //nolint[gochecknoglobals]
"https://dns.google/dns-query", "https://dns.google/dns-query",
} }
// globalAllowDoH controls whether or not to enable use of DoH/Proxy in pmapi. // proxyProvider manages known proxies.
var globalAllowDoH = false // nolint[golint] type proxyProvider struct {
// globalProxyMutex allows threadsafe modification of proxy state.
var globalProxyMutex = sync.RWMutex{} // nolint[golint]
// globalOriginalURL backs up the original API url so it can be restored later.
var globalOriginalURL = RootURL // nolint[golint]
// globalIsDoHAllowed returns whether or not to use DoH.
func globalIsDoHAllowed() bool { // nolint[golint]
globalProxyMutex.RLock()
defer globalProxyMutex.RUnlock()
return globalAllowDoH
}
// GlobalAllowDoH enables DoH.
func GlobalAllowDoH() { // nolint[golint]
globalProxyMutex.Lock()
defer globalProxyMutex.Unlock()
globalAllowDoH = true
}
// GlobalDisallowDoH disables DoH and sets the RootURL back to what it was.
func GlobalDisallowDoH() { // nolint[golint]
globalProxyMutex.Lock()
defer globalProxyMutex.Unlock()
globalAllowDoH = false
RootURL = globalOriginalURL
}
// globalSetRootURL sets the global RootURL.
func globalSetRootURL(url string) { // nolint[golint]
globalProxyMutex.Lock()
defer globalProxyMutex.Unlock()
RootURL = url
}
// GlobalGetRootURL returns the global RootURL.
func GlobalGetRootURL() (url string) { // nolint[golint]
globalProxyMutex.RLock()
defer globalProxyMutex.RUnlock()
return RootURL
}
// isProxyEnabled returns whether or not we are currently using a proxy.
func isProxyEnabled() bool { // nolint[golint]
return globalOriginalURL != GlobalGetRootURL()
}
// proxyManager manages known proxies.
type proxyManager struct {
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records> // dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
dohLookup func(query, provider string) (urls []string, err error) dohLookup func(query, provider string) (urls []string, err error)
@ -113,10 +57,10 @@ type proxyManager struct {
lastLookup time.Time // The time at which we last attempted to find a proxy. lastLookup time.Time // The time at which we last attempted to find a proxy.
} }
// newProxyManager creates a new proxyManager that queries the given DoH providers // newProxyProvider creates a new proxyProvider that queries the given DoH providers
// to retrieve DNS records for the given query string. // to retrieve DNS records for the given query string.
func newProxyManager(providers []string, query string) (p *proxyManager) { // nolint[unparam] func newProxyProvider(providers []string, query string) (p *proxyProvider) { // nolint[unparam]
p = &proxyManager{ p = &proxyProvider{
providers: providers, providers: providers,
query: query, query: query,
useDuration: proxyRevertTime, useDuration: proxyRevertTime,
@ -132,7 +76,7 @@ func newProxyManager(providers []string, query string) (p *proxyManager) { // no
// findProxy returns a new proxy domain which is not equal to the current RootURL. // findProxy returns a new proxy domain which is not equal to the current RootURL.
// It returns an error if the process takes longer than ProxySearchTime. // It returns an error if the process takes longer than ProxySearchTime.
func (p *proxyManager) findProxy() (proxy string, err error) { func (p *proxyProvider) findProxy() (proxy string, err error) {
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) { if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
return "", errors.New("not looking for a proxy, too soon") return "", errors.New("not looking for a proxy, too soon")
} }
@ -147,7 +91,7 @@ func (p *proxyManager) findProxy() (proxy string, err error) {
} }
for _, proxy := range p.proxyCache { for _, proxy := range p.proxyCache {
if proxy != stripProtocol(GlobalGetRootURL()) && p.canReach(proxy) { if p.canReach(proxy) {
proxyResult <- proxy proxyResult <- proxy
return return
} }
@ -171,25 +115,8 @@ func (p *proxyManager) findProxy() (proxy string, err error) {
} }
} }
// useProxy sets the proxy server to use. It returns to the original RootURL after 24 hours.
func (p *proxyManager) useProxy(proxy string) {
if !isProxyEnabled() {
p.disableProxyAfter(p.useDuration)
}
globalSetRootURL(https(proxy))
}
// disableProxyAfter disables the proxy after the given amount of time.
func (p *proxyManager) disableProxyAfter(d time.Duration) {
go func() {
<-time.After(d)
globalSetRootURL(globalOriginalURL)
}()
}
// refreshProxyCache loads the latest proxies from the known providers. // refreshProxyCache loads the latest proxies from the known providers.
func (p *proxyManager) refreshProxyCache() error { func (p *proxyProvider) refreshProxyCache() error {
logrus.Info("Refreshing proxy cache") logrus.Info("Refreshing proxy cache")
for _, provider := range p.providers { for _, provider := range p.providers {
@ -197,7 +124,7 @@ func (p *proxyManager) refreshProxyCache() error {
p.proxyCache = proxies p.proxyCache = proxies
// We also want to allow bridge to switch back to the standard API at any time. // We also want to allow bridge to switch back to the standard API at any time.
p.proxyCache = append(p.proxyCache, globalOriginalURL) p.proxyCache = append(p.proxyCache, RootURL)
logrus.WithField("proxies", proxies).Info("Available proxies") logrus.WithField("proxies", proxies).Info("Available proxies")
@ -210,9 +137,13 @@ func (p *proxyManager) refreshProxyCache() error {
// canReach returns whether we can reach the given url. // canReach returns whether we can reach the given url.
// NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname. // NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname.
func (p *proxyManager) canReach(url string) bool { func (p *proxyProvider) canReach(url string) bool {
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
pinger := resty.New(). pinger := resty.New().
SetHostURL(https(url)). SetHostURL(url).
SetTimeout(p.lookupTimeout). SetTimeout(p.lookupTimeout).
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec] SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec]
@ -227,7 +158,7 @@ func (p *proxyManager) canReach(url string) bool {
// It looks up DNS TXT records for the given query URL using the given DoH provider. // It looks up DNS TXT records for the given query URL using the given DoH provider.
// It returns a list of all found TXT records. // It returns a list of all found TXT records.
// If the whole process takes more than ProxyQueryTime then an error is returned. // If the whole process takes more than ProxyQueryTime then an error is returned.
func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []string, err error) { func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []string, err error) {
dataResult := make(chan []string) dataResult := make(chan []string)
errResult := make(chan error) errResult := make(chan error)
go func() { go func() {
@ -282,23 +213,3 @@ func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []strin
return return
} }
} }
func stripProtocol(url string) string {
if strings.HasPrefix(url, "https://") {
return strings.TrimPrefix(url, "https://")
}
if strings.HasPrefix(url, "http://") {
return strings.TrimPrefix(url, "http://")
}
return url
}
func https(url string) string {
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
return url
}

View File

@ -32,14 +32,14 @@ const (
TestGoogleProvider = "https://dns.google/dns-query" TestGoogleProvider = "https://dns.google/dns-query"
) )
func TestProxyManager_FindProxy(t *testing.T) { func TestProxyProvider_FindProxy(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close() defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findProxy() url, err := p.findProxy()
@ -47,7 +47,7 @@ func TestProxyManager_FindProxy(t *testing.T) {
require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, url)
} }
func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) { func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
@ -58,7 +58,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
badProxy.Close() badProxy.Close()
defer goodProxy.Close() defer goodProxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil }
url, err := p.findProxy() url, err := p.findProxy()
@ -66,7 +66,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
require.Equal(t, goodProxy.URL, url) require.Equal(t, goodProxy.URL, url)
} }
func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) { func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
@ -77,21 +77,21 @@ func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) {
badProxy.Close() badProxy.Close()
anotherBadProxy.Close() anotherBadProxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil }
_, err := p.findProxy() _, err := p.findProxy()
require.Error(t, err) require.Error(t, err)
} }
func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) { func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close() defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.lookupTimeout = time.Second p.lookupTimeout = time.Second
p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
@ -100,7 +100,7 @@ func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestProxyManager_FindProxy_FindTimeout(t *testing.T) { func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
@ -109,7 +109,7 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
})) }))
defer slowProxy.Close() defer slowProxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.findTimeout = time.Second p.findTimeout = time.Second
p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
@ -118,14 +118,14 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestProxyManager_UseProxy(t *testing.T) { func TestProxyProvider_UseProxy(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close() defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findProxy() url, err := p.findProxy()
@ -135,7 +135,7 @@ func TestProxyManager_UseProxy(t *testing.T) {
require.Equal(t, proxy.URL, GlobalGetRootURL()) require.Equal(t, proxy.URL, GlobalGetRootURL())
} }
func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) { func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
@ -146,7 +146,7 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
proxy3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) proxy3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy3.Close() defer proxy3.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
url, err := p.findProxy() url, err := p.findProxy()
@ -173,14 +173,14 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
require.Equal(t, proxy3.URL, GlobalGetRootURL()) require.Equal(t, proxy3.URL, GlobalGetRootURL())
} }
func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) { func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close() defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.useDuration = time.Second p.useDuration = time.Second
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
@ -195,14 +195,14 @@ func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) {
require.Equal(t, globalOriginalURL, GlobalGetRootURL()) require.Equal(t, globalOriginalURL, GlobalGetRootURL())
} }
func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) { func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
// Don't block the API here because we want it to be working so the test can find it. // Don't block the API here because we want it to be working so the test can find it.
defer unblockAPI() defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close() defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findProxy() url, err := p.findProxy()
@ -225,7 +225,7 @@ func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachabl
require.Equal(t, globalOriginalURL, GlobalGetRootURL()) require.Equal(t, globalOriginalURL, GlobalGetRootURL())
} }
func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) { func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
blockAPI() blockAPI()
defer unblockAPI() defer unblockAPI()
@ -234,7 +234,7 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo
proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy2.Close() defer proxy2.Close()
p := newProxyManager([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
// Find a proxy. // Find a proxy.
@ -256,32 +256,32 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo
require.Equal(t, proxy2.URL, GlobalGetRootURL()) require.Equal(t, proxy2.URL, GlobalGetRootURL())
} }
func TestProxyManager_DoHLookup_Quad9(t *testing.T) { func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
records, err := p.dohLookup(TestDoHQuery, TestQuad9Provider) records, err := p.dohLookup(TestDoHQuery, TestQuad9Provider)
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, records) require.NotEmpty(t, records)
} }
func TestProxyManager_DoHLookup_Google(t *testing.T) { func TestProxyProvider_DoHLookup_Google(t *testing.T) {
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
records, err := p.dohLookup(TestDoHQuery, TestGoogleProvider) records, err := p.dohLookup(TestDoHQuery, TestGoogleProvider)
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, records) require.NotEmpty(t, records)
} }
func TestProxyManager_DoHLookup_FindProxy(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
url, err := p.findProxy() url, err := p.findProxy()
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, url) require.NotEmpty(t, url)
} }
func TestProxyManager_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
p := newProxyManager([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery) p := newProxyProvider([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery)
url, err := p.findProxy() url, err := p.findProxy()
require.NoError(t, err) require.NoError(t, err)

View File

@ -26,8 +26,9 @@ import (
) )
// NewRequest creates a new request. // NewRequest creates a new request.
func NewRequest(method, path string, body io.Reader) (req *http.Request, err error) { func (c *Client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
req, err = http.NewRequest(method, GlobalGetRootURL()+path, body) // TODO: Support other protocols (localhost needs http not https).
req, err = http.NewRequest(method, "https://"+c.cm.GetRootURL()+path, body)
if req != nil { if req != nil {
req.Header.Set("User-Agent", CurrentUserAgent) req.Header.Set("User-Agent", CurrentUserAgent)
} }
@ -35,13 +36,13 @@ func NewRequest(method, path string, body io.Reader) (req *http.Request, err err
} }
// NewJSONRequest create a new JSON request. // NewJSONRequest create a new JSON request.
func NewJSONRequest(method, path string, body interface{}) (*http.Request, error) { func (c *Client) NewJSONRequest(method, path string, body interface{}) (*http.Request, error) {
b, err := json.Marshal(body) b, err := json.Marshal(body)
if err != nil { if err != nil {
panic(err) panic(err)
} }
req, err := NewRequest(method, path, bytes.NewReader(b)) req, err := c.NewRequest(method, path, bytes.NewReader(b))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +71,7 @@ func (w *MultipartWriter) Close() error {
// that writing the request and sending it MUST be done in parallel. If the // that writing the request and sending it MUST be done in parallel. If the
// request fails, subsequent writes to the multipart writer will fail with an // request fails, subsequent writes to the multipart writer will fail with an
// io.ErrClosedPipe error. // io.ErrClosedPipe error.
func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) { func (c *Client) NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) {
// The pipe will connect the multipart writer and the HTTP request body. // The pipe will connect the multipart writer and the HTTP request body.
pr, pw := io.Pipe() pr, pw := io.Pipe()
@ -80,7 +81,7 @@ func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWr
pw, pw,
} }
req, err = NewRequest(method, path, pr) req, err = c.NewRequest(method, path, pr)
if err != nil { if err != nil {
return return
} }

View File

@ -45,7 +45,7 @@ type UserSettings struct {
// GetUserSettings gets general settings. // GetUserSettings gets general settings.
func (c *Client) GetUserSettings() (settings UserSettings, err error) { func (c *Client) GetUserSettings() (settings UserSettings, err error) {
req, err := NewRequest("GET", "/settings", nil) req, err := c.NewRequest("GET", "/settings", nil)
if err != nil { if err != nil {
return return
@ -99,7 +99,7 @@ type MailSettings struct {
// GetMailSettings gets contact details specified by contact ID. // GetMailSettings gets contact details specified by contact ID.
func (c *Client) GetMailSettings() (settings MailSettings, err error) { func (c *Client) GetMailSettings() (settings MailSettings, err error) {
req, err := NewRequest("GET", "/settings/mail", nil) req, err := c.NewRequest("GET", "/settings/mail", nil)
if err != nil { if err != nil {
return return

View File

@ -93,7 +93,7 @@ func (u *User) KeyRing() *pmcrypto.KeyRing {
// UpdateUser retrieves details about user and loads its addresses. // UpdateUser retrieves details about user and loads its addresses.
func (c *Client) UpdateUser() (user *User, err error) { func (c *Client) UpdateUser() (user *User, err error) {
req, err := NewRequest("GET", "/users", nil) req, err := c.NewRequest("GET", "/users", nil)
if err != nil { if err != nil {
return return
} }