diff --git a/cmd/Desktop-Bridge/main.go b/cmd/Desktop-Bridge/main.go index 48e7ec6a..3ec567bc 100644 --- a/cmd/Desktop-Bridge/main.go +++ b/cmd/Desktop-Bridge/main.go @@ -274,8 +274,11 @@ func run(context *cli.Context) (contextError error) { // nolint[funlen] log.Error("Could not get credentials store: ", credentialsError) } - clientman := pmapi.NewClientManager(pmapifactory.GetClientConfig(cfg, eventListener)) - bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, clientman, credentialsStore) + clientConfig := pmapifactory.GetClientConfig(cfg.GetAPIConfig()) + cm := pmapi.NewClientManager(clientConfig) + pmapifactory.SetClientRoundTripper(cm, clientConfig, eventListener) + + bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, cm, credentialsStore) imapBackend := imap.NewIMAPBackend(panicHandler, eventListener, cfg, bridgeInstance) smtpBackend := smtp.NewSMTPBackend(panicHandler, eventListener, pref, bridgeInstance) diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 7c1dcf33..6a19ae6b 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -97,7 +97,7 @@ func New( // Allow DoH before starting bridge if the user has previously set this setting. // This allows us to start even if protonmail is blocked. if pref.GetBool(preferences.AllowProxyKey) { - AllowDoH() + b.AllowProxy() } go func() { @@ -178,15 +178,16 @@ func (b *Bridge) watchBridgeOutdated() { } } +// watchUserAuths receives auths from the client manager and sends them to the appropriate user. func (b *Bridge) watchUserAuths() { for auth := range b.clientManager.GetBridgeAuthChannel() { - user, ok := b.hasUser(auth.UserID) + logrus.WithField("token", auth.Auth.GenToken()).WithField("userID", auth.UserID).Info("Received auth from bridge auth channel") - if !ok { - continue + if user, ok := b.hasUser(auth.UserID); ok { + user.ReceiveAPIAuth(auth.Auth) + } else { + logrus.Info("User is not added to bridge yet") } - - user.ReceiveAPIAuth(auth.Auth) } } @@ -274,7 +275,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass apiClient := b.clientManager.GetClient(apiUser.ID) auth, err = apiClient.AuthRefresh(auth.GenToken()) if err != nil { - log.WithError(err).Error("Could refresh token in new client") + log.WithError(err).Error("Could not refresh token in new client") return } @@ -298,6 +299,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user") return } + b.users = append(b.users, user) } // Set up the user auth and store (which we do for both new and existing users). @@ -307,7 +309,6 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass } if !hasUser { - b.users = append(b.users, user) b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel)) } @@ -475,16 +476,16 @@ func (b *Bridge) GetIMAPUpdatesChannel() chan interface{} { return b.idleUpdates } -// AllowDoH instructs bridge to use DoH to access an API proxy if necessary. +// AllowProxy instructs bridge to use DoH to access an API proxy if necessary. // It also needs to work before bridge is initialised (because we may need to use the proxy at startup). -func AllowDoH() { - pmapi.GlobalAllowDoH() +func (b *Bridge) AllowProxy() { + b.clientManager.AllowProxy() } -// DisallowDoH instructs bridge to not use DoH to access an API proxy if necessary. +// DisallowProxy instructs bridge to not use DoH to access an API proxy if necessary. // It also needs to work before bridge is initialised (because we may need to use the proxy at startup). -func DisallowDoH() { - pmapi.GlobalDisallowDoH() +func (b *Bridge) DisallowProxy() { + b.clientManager.DisallowProxy() } func (b *Bridge) updateCurrentUserAgent() { @@ -493,7 +494,11 @@ func (b *Bridge) updateCurrentUserAgent() { // hasUser returns whether the bridge currently has a user with ID `id`. func (b *Bridge) hasUser(id string) (user *User, ok bool) { + logrus.WithField("id", id).Info("Checking whether bridge has given user") + for _, u := range b.users { + logrus.WithField("id", u.ID()).Info("Found potential user") + if u.ID() == id { user, ok = u, true return diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 1903fbca..f2edeb79 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -107,6 +107,8 @@ func (u *User) init(idleUpdates chan interface{}) (err error) { u.wasKeyringUnlocked = false u.unlockingKeyringLock.Unlock() + u.log.Info("Initialising user") + // Reload the user's credentials (if they log out and back in we need the new // version with the apitoken and mailbox password). creds, err := u.credStorer.Get(u.userID) @@ -242,27 +244,19 @@ func (u *User) authorizeAndUnlock() (err error) { } func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) { + u.lock.Lock() + defer u.lock.Unlock() + if auth == nil { if err := u.logout(); err != nil { - u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API") + u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API") } u.isAuthorized = false return } - u.updateAPIToken(auth.GenToken()) -} - -// updateAPIToken is helper for updating the token in keychain. It's not supposed to be -// called directly from other parts of the code, only from `ReceiveAPIAuth`. -func (u *User) updateAPIToken(newRefreshToken string) { - u.lock.Lock() - defer u.lock.Unlock() - - u.log.WithField("token", newRefreshToken).Info("Saving token to credentials store") - - if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil { - u.log.WithError(err).Error("Cannot update refresh token in credentials store") + if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil { + u.log.WithError(err).Error("Failed to update refresh token in credentials store") return } diff --git a/internal/frontend/cli/system.go b/internal/frontend/cli/system.go index 4fe14702..7ac619c4 100644 --- a/internal/frontend/cli/system.go +++ b/internal/frontend/cli/system.go @@ -22,7 +22,6 @@ import ( "strconv" "strings" - "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/pkg/connection" "github.com/ProtonMail/proton-bridge/pkg/ports" @@ -135,13 +134,13 @@ func (f *frontendCLI) toggleAllowProxy(c *ishell.Context) { f.Println("Bridge is currently set to use alternative routing to connect to Proton if it is being blocked.") if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") { f.preferences.SetBool(preferences.AllowProxyKey, false) - bridge.DisallowDoH() + f.bridge.DisallowProxy() } } else { f.Println("Bridge is currently set to NOT use alternative routing to connect to Proton if it is being blocked.") if f.yesNoQuestion("Are you sure you want to allow bridge to do this") { f.preferences.SetBool(preferences.AllowProxyKey, true) - bridge.AllowDoH() + f.bridge.AllowProxy() } } } diff --git a/internal/frontend/types/types.go b/internal/frontend/types/types.go index 78385d86..b610860a 100644 --- a/internal/frontend/types/types.go +++ b/internal/frontend/types/types.go @@ -52,6 +52,8 @@ type Bridger interface { DeleteUser(userID string, clearCache bool) error ReportBug(osType, osVersion, description, accountName, address, emailClient string) error ClearData() error + AllowProxy() + DisallowProxy() } // BridgeUser is an interface of user needed by frontend. diff --git a/internal/pmapifactory/pmapi_noprod.go b/internal/pmapifactory/pmapi_noprod.go index a8b6a721..2ed31155 100644 --- a/internal/pmapifactory/pmapi_noprod.go +++ b/internal/pmapifactory/pmapi_noprod.go @@ -21,11 +21,14 @@ package pmapifactory import ( - "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/pmapi" ) -func GetClientConfig(config bridge.Configer, _ listener.Listener) *pmapi.ClientConfig { - return config.GetAPIConfig() +func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig { + return clientConfig +} + +func SetClientRoundTripper(_ *pmapi.ClientManager, _ *pmapi.ClientConfig, _ listener.Listener) { + // Use the default roundtripper; do nothing. } diff --git a/internal/pmapifactory/pmapi_prod.go b/internal/pmapifactory/pmapi_prod.go index 7d7f3e0f..f4db4c11 100644 --- a/internal/pmapifactory/pmapi_prod.go +++ b/internal/pmapifactory/pmapi_prod.go @@ -23,26 +23,13 @@ package pmapifactory import ( "time" - "github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/pmapi" + "github.com/sirupsen/logrus" ) -func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.ClientConfig { - clientConfig := config.GetAPIConfig() - - pin := pmapi.NewPMAPIPinning(clientConfig.AppVersion) - pin.ReportCertIssueLocal = func() { - listener.Emit(events.TLSCertIssue, "") - } - - // This transport already has timeouts set governing the roundtrip: - // - IdleConnTimeout: 5 * time.Minute, - // - ExpectContinueTimeout: 500 * time.Millisecond, - // - ResponseHeaderTimeout: 30 * time.Second, - clientConfig.Transport = pin.TransportWithPinning() - +func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig { // We set additional timeouts/thresholds for the request as a whole: clientConfig.Timeout = 10 * time.Minute // Overall request timeout (~25MB / 10 mins => ~40kB/s, should be reasonable). clientConfig.FirstReadTimeout = 30 * time.Second // 30s to match 30s response header timeout. @@ -50,3 +37,15 @@ func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi. return clientConfig } + +func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) { + logrus.Info("Setting dialer with pinning") + + pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion) + + pin.ReportCertIssueLocal = func() { + listener.Emit(events.TLSCertIssue, "") + } + + cm.SetClientRoundTripper(pin.TransportWithPinning()) +} diff --git a/pkg/connection/check_connection.go b/pkg/connection/check_connection.go index c25ed92b..1182649a 100644 --- a/pkg/connection/check_connection.go +++ b/pkg/connection/check_connection.go @@ -39,7 +39,8 @@ var ( // Two errors can be returned, ErrNoInternetConnection or ErrCanNotReachAPI. func CheckInternetConnection() error { client := &http.Client{ - Transport: pmapi.NewPMAPIPinning(pmapi.CurrentUserAgent).TransportWithPinning(), + // TODO: Set transport properly! (Need access to ClientManager somehow) + // Transport: pmapi.NewDialerWithPinning(pmapi.CurrentUserAgent).TransportWithPinning(), } // Do not cumulate timeouts, use goroutines. @@ -51,7 +52,8 @@ func CheckInternetConnection() error { go checkConnection(client, "http://protonstatus.com/vpn_status", retStatus) // Check of API reachability also uses a fast endpoint. - go checkConnection(client, pmapi.GlobalGetRootURL()+"/tests/ping", retAPI) + // TODO: This should check active proxy, not the RootURL + go checkConnection(client, pmapi.RootURL+"/tests/ping", retAPI) errStatus := <-retStatus errAPI := <-retAPI diff --git a/pkg/keychain/keychain_darwin.go b/pkg/keychain/keychain_darwin.go index d93ba5ab..43723fe0 100644 --- a/pkg/keychain/keychain_darwin.go +++ b/pkg/keychain/keychain_darwin.go @@ -36,7 +36,7 @@ type osxkeychain struct { } func newKeychain() (credentials.Helper, error) { - log.Debug("creating osckeychain") + log.Debug("Creating osckeychain") return &osxkeychain{}, nil } diff --git a/pkg/keychain/keychain_linux.go b/pkg/keychain/keychain_linux.go index 86927a8f..9376d9a9 100644 --- a/pkg/keychain/keychain_linux.go +++ b/pkg/keychain/keychain_linux.go @@ -24,14 +24,14 @@ import ( ) func newKeychain() (credentials.Helper, error) { - log.Debug("creating pass") + log.Debug("Creating pass") passHelper := &pass.Pass{} passErr := checkPassIsUsable(passHelper) if passErr == nil { return passHelper, nil } - log.Debug("creating secretservice") + log.Debug("Creating secretservice") sserviceHelper := &secretservice.Secretservice{} _, sserviceErr := sserviceHelper.List() if sserviceErr == nil { diff --git a/pkg/keychain/keychain_windows.go b/pkg/keychain/keychain_windows.go index b77c0ce1..3f038c6a 100644 --- a/pkg/keychain/keychain_windows.go +++ b/pkg/keychain/keychain_windows.go @@ -23,7 +23,7 @@ import ( ) func newKeychain() (credentials.Helper, error) { - log.Debug("creating wincred") + log.Debug("Creating wincred") return &wincred.Wincred{}, nil } diff --git a/pkg/pmapi/addresses.go b/pkg/pmapi/addresses.go index dba5d9ee..1c69664b 100644 --- a/pkg/pmapi/addresses.go +++ b/pkg/pmapi/addresses.go @@ -161,7 +161,7 @@ func ConstructAddress(headerEmail string, addressEmail string) string { // GetAddresses requests all of current user addresses (without pagination). func (c *Client) GetAddresses() (addresses AddressList, err error) { - req, err := NewRequest("GET", "/addresses", nil) + req, err := c.NewRequest("GET", "/addresses", nil) if err != nil { return } diff --git a/pkg/pmapi/attachments.go b/pkg/pmapi/attachments.go index b63be384..86d1d917 100644 --- a/pkg/pmapi/attachments.go +++ b/pkg/pmapi/attachments.go @@ -179,7 +179,7 @@ func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.R // // The returned created attachment contains the new attachment ID and its size. func (c *Client) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) { - req, w, err := NewMultipartRequest("POST", "/attachments") + req, w, err := c.NewMultipartRequest("POST", "/attachments") if err != nil { return } @@ -213,7 +213,7 @@ type UpdateAttachmentSignatureReq struct { func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err error) { updateReq := &UpdateAttachmentSignatureReq{signature} - req, err := NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq) + req, err := c.NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq) if err != nil { return } @@ -228,7 +228,7 @@ func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err // DeleteAttachment removes an attachment. message is the message ID, att is the attachment ID. func (c *Client) DeleteAttachment(attID string) (err error) { - req, err := NewRequest("DELETE", "/attachments/"+attID, nil) + req, err := c.NewRequest("DELETE", "/attachments/"+attID, nil) if err != nil { return } @@ -249,7 +249,7 @@ func (c *Client) GetAttachment(id string) (att io.ReadCloser, err error) { return } - req, err := NewRequest("GET", "/attachments/"+id, nil) + req, err := c.NewRequest("GET", "/attachments/"+id, nil) if err != nil { return } diff --git a/pkg/pmapi/auth.go b/pkg/pmapi/auth.go index dfd00421..4f14da77 100644 --- a/pkg/pmapi/auth.go +++ b/pkg/pmapi/auth.go @@ -214,7 +214,7 @@ func (c *Client) AuthInfo(username string) (info *AuthInfo, err error) { Username: username, } - req, err := NewJSONRequest("POST", "/auth/info", infoReq) + req, err := c.NewJSONRequest("POST", "/auth/info", infoReq) if err != nil { return } @@ -257,7 +257,7 @@ func (c *Client) tryAuth(username, password string, info *AuthInfo, fallbackVers SRPSession: info.srpSession, } - req, err := NewJSONRequest("POST", "/auth", authReq) + req, err := c.NewJSONRequest("POST", "/auth", authReq) if err != nil { return } @@ -335,7 +335,7 @@ func (c *Client) Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) { TwoFactorCode: twoFactorCode, } - req, err := NewJSONRequest("POST", "/auth/2fa", auth2FAReq) + req, err := c.NewJSONRequest("POST", "/auth/2fa", auth2FAReq) if err != nil { return nil, err } @@ -430,7 +430,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) // UID must be set for `x-pm-uid` header field, see backend-communication#11 c.uid = split[0] - req, err := NewJSONRequest("POST", "/auth/refresh", refreshReq) + req, err := c.NewJSONRequest("POST", "/auth/refresh", refreshReq) if err != nil { return } @@ -450,13 +450,14 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error) return auth, err } +// Logout instructs the client manager to log out this client. func (c *Client) Logout() { c.cm.LogoutClient(c.userID) } // logout logs the current user out. func (c *Client) logout() (err error) { - req, err := NewRequest("DELETE", "/auth", nil) + req, err := c.NewRequest("DELETE", "/auth", nil) if err != nil { return } diff --git a/pkg/pmapi/auth_test.go b/pkg/pmapi/auth_test.go index 7be954cb..05ad840c 100644 --- a/pkg/pmapi/auth_test.go +++ b/pkg/pmapi/auth_test.go @@ -332,7 +332,9 @@ func TestClient_Logout(t *testing.T) { c.uid = testUID c.accessToken = testAccessToken - Ok(t, c.Logout()) + c.Logout() + + // TODO: Check that the client is logged out and sensitive data is cleared eventually. } func TestClient_DoUnauthorized(t *testing.T) { @@ -355,7 +357,7 @@ func TestClient_DoUnauthorized(t *testing.T) { c.expiresAt = aLongTimeAgo c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken - req, err := NewRequest("GET", "/", nil) + req, err := c.NewRequest("GET", "/", nil) Ok(t, err) res, err := c.Do(req, true) diff --git a/pkg/pmapi/bugs.go b/pkg/pmapi/bugs.go index 77a95977..7cf20c31 100644 --- a/pkg/pmapi/bugs.go +++ b/pkg/pmapi/bugs.go @@ -139,9 +139,9 @@ func (c *Client) Report(rep ReportReq) (err error) { var req *http.Request var w *MultipartWriter if len(rep.Attachments) > 0 { - req, w, err = NewMultipartRequest("POST", "/reports/bug") + req, w, err = c.NewMultipartRequest("POST", "/reports/bug") } else { - req, err = NewJSONRequest("POST", "/reports/bug", rep) + req, err = c.NewJSONRequest("POST", "/reports/bug", rep) } if err != nil { return @@ -202,7 +202,7 @@ func (c *Client) ReportCrash(stacktrace string) (err error) { OS: runtime.GOOS, Debug: stacktrace, } - req, err := NewJSONRequest("POST", "/reports/crash", crashReq) + req, err := c.NewJSONRequest("POST", "/reports/crash", crashReq) if err != nil { return } diff --git a/pkg/pmapi/client.go b/pkg/pmapi/client.go index e26ba92b..98249e40 100644 --- a/pkg/pmapi/client.go +++ b/pkg/pmapi/client.go @@ -99,12 +99,12 @@ type ClientConfig struct { // Client to communicate with API. type Client struct { - cm *ClientManager - client *http.Client + cm *ClientManager + hc *http.Client uid string accessToken string - userID string // Twice here because Username is not unique. + userID string requestLocker sync.Locker keyLocker sync.Locker @@ -120,7 +120,7 @@ type Client struct { func newClient(cm *ClientManager, userID string) *Client { return &Client{ cm: cm, - client: getHTTPClient(cm.GetConfig()), + hc: getHTTPClient(cm.GetConfig()), userID: userID, requestLocker: &sync.Mutex{}, keyLocker: &sync.Mutex{}, @@ -132,12 +132,10 @@ func newClient(cm *ClientManager, userID string) *Client { func getHTTPClient(cfg *ClientConfig) (hc *http.Client) { hc = &http.Client{Timeout: cfg.Timeout} - if cfg.Transport == nil && defaultTransport == nil { - return - } - - if defaultTransport != nil { - hc.Transport = defaultTransport + if cfg.Transport == nil { + if defaultTransport != nil { + hc.Transport = defaultTransport + } return } @@ -205,7 +203,7 @@ func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthori } hasBody := len(bodyBuffer) > 0 - if res, err = c.client.Do(req); err != nil { + if res, err = c.hc.Do(req); err != nil { if res == nil { c.log.WithError(err).Error("Cannot get response") err = ErrAPINotReachable diff --git a/pkg/pmapi/client_test.go b/pkg/pmapi/client_test.go index fdb55641..b110ba39 100644 --- a/pkg/pmapi/client_test.go +++ b/pkg/pmapi/client_test.go @@ -51,7 +51,7 @@ func TestClient_Do(t *testing.T) { })) defer s.Close() - req, err := NewRequest("GET", "/", nil) + req, err := c.NewRequest("GET", "/", nil) if err != nil { t.Fatal("Expected no error while creating request, got:", err) } @@ -163,8 +163,8 @@ func TestClient_FirstReadTimeout(t *testing.T) { ) defer finish() - c.client.Transport = &slowTransport{ - transport: c.client.Transport, + c.hc.Transport = &slowTransport{ + transport: c.hc.Transport, firstBodySleep: requestTimeout, } diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index 83aec19e..1157b063 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -1,24 +1,32 @@ package pmapi import ( + "net/http" "sync" "github.com/getsentry/raven-go" + "github.com/pkg/errors" "github.com/sirupsen/logrus" ) // ClientManager is a manager of clients. type ClientManager struct { + config *ClientConfig + clients map[string]*Client clientsLocker sync.Locker tokens map[string]string tokensLocker sync.Locker - config *ClientConfig + url string + urlLocker sync.Locker bridgeAuths chan ClientAuth clientAuths chan ClientAuth + + allowProxy bool + proxyProvider *proxyProvider } type ClientAuth struct { @@ -33,13 +41,21 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { } cm = &ClientManager{ + config: config, + clients: make(map[string]*Client), clientsLocker: &sync.Mutex{}, - tokens: make(map[string]string), - tokensLocker: &sync.Mutex{}, - config: config, - bridgeAuths: make(chan ClientAuth), - clientAuths: make(chan ClientAuth), + + tokens: make(map[string]string), + tokensLocker: &sync.Mutex{}, + + url: RootURL, + urlLocker: &sync.Mutex{}, + + bridgeAuths: make(chan ClientAuth), + clientAuths: make(chan ClientAuth), + + proxyProvider: newProxyProvider(dohProviders, proxyQuery), } go cm.forwardClientAuths() @@ -47,6 +63,12 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) { return } +// SetClientRoundTripper sets the roundtripper used by clients created by this client manager. +func (cm *ClientManager) SetClientRoundTripper(rt http.RoundTripper) { + logrus.Info("Setting client roundtripper") + cm.config.Transport = rt +} + // GetClient returns a client for the given userID. // If the client does not exist already, it is created. func (cm *ClientManager) GetClient(userID string) *Client { @@ -71,7 +93,7 @@ func (cm *ClientManager) LogoutClient(userID string) { go func() { if err := client.logout(); err != nil { - // TODO: Try again! + // TODO: Try again! This should loop until it succeeds (might fail the first time due to internet). logrus.WithError(err).Error("Client logout failed, not trying again") } client.clearSensitiveData() @@ -81,6 +103,56 @@ func (cm *ClientManager) LogoutClient(userID string) { return } +// GetRootURL returns the root URL to make requests to. +// It does not include the protocol i.e. no "https://". +func (cm *ClientManager) GetRootURL() string { + cm.urlLocker.Lock() + defer cm.urlLocker.Unlock() + + return cm.url +} + +// SetRootURL sets the root URL to make requests to. +func (cm *ClientManager) SetRootURL(url string) { + cm.urlLocker.Lock() + defer cm.urlLocker.Unlock() + + logrus.WithField("url", url).Info("Changing to a new root URL") + + cm.url = url +} + +// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be. +func (cm *ClientManager) IsProxyAllowed() bool { + return cm.allowProxy +} + +// AllowProxy allows the client manager to switch clients over to a proxy if need be. +func (cm *ClientManager) AllowProxy() { + cm.allowProxy = true +} + +// DisallowProxy prevents the client manager from switching clients over to a proxy if need be. +func (cm *ClientManager) DisallowProxy() { + cm.allowProxy = false +} + +// FindProxy returns a usable proxy server. +func (cm *ClientManager) SwitchToProxy() (proxy string, err error) { + logrus.Info("Attempting gto switch to a proxy") + + if proxy, err = cm.proxyProvider.findProxy(); err != nil { + err = errors.Wrap(err, "failed to find usable proxy") + return + } + + cm.SetRootURL(proxy) + + // TODO: Disable after 24 hours. + + return +} + // GetConfig returns the config used to configure clients. func (cm *ClientManager) GetConfig() *ClientConfig { return cm.config @@ -113,7 +185,7 @@ func (cm *ClientManager) setToken(userID, token string) { cm.tokensLocker.Lock() defer cm.tokensLocker.Unlock() - logrus.WithField("userID", userID).WithField("token", token).Info("Updating refresh token") + logrus.WithField("userID", userID).Info("Updating refresh token") cm.tokens[userID] = token } @@ -127,16 +199,19 @@ func (cm *ClientManager) clearToken(userID string) { delete(cm.tokens, userID) } -// handleClientAuth +// handleClientAuth updates or clears client authorisation based on auths received. func (cm *ClientManager) handleClientAuth(ca ClientAuth) { - // TODO: Maybe want to logout the client in case of nil auth. + // If we aren't managing this client, there's nothing to do. if _, ok := cm.clients[ca.UserID]; !ok { return } + // If the auth is nil, we should clear the token. + // TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself. if ca.Auth == nil { cm.clearToken(ca.UserID) - } else { - cm.setToken(ca.UserID, ca.Auth.GenToken()) + return } + + cm.setToken(ca.UserID, ca.Auth.GenToken()) } diff --git a/pkg/pmapi/config.go b/pkg/pmapi/config.go index 37ee5811..7e1bb421 100644 --- a/pkg/pmapi/config.go +++ b/pkg/pmapi/config.go @@ -24,9 +24,9 @@ import ( // RootURL is the API root URL. // -// This can be changed using build flags: pmapi_local for "http://localhost/api", -// pmapi_dev or pmapi_prod. Default is pmapi_prod. -var RootURL = "https://api.protonmail.ch" //nolint[gochecknoglobals] +// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod. +// Default is pmapi_prod. +var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals] // CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program // version and email client. diff --git a/pkg/pmapi/config_dev.go b/pkg/pmapi/config_dev.go index a456c4af..77a2e6c6 100644 --- a/pkg/pmapi/config_dev.go +++ b/pkg/pmapi/config_dev.go @@ -20,5 +20,5 @@ package pmapi func init() { - RootURL = "https://dev.protonmail.com/api" + RootURL = "dev.protonmail.com/api" } diff --git a/pkg/pmapi/config_local.go b/pkg/pmapi/config_local.go index f534f250..9544e4f2 100644 --- a/pkg/pmapi/config_local.go +++ b/pkg/pmapi/config_local.go @@ -27,7 +27,7 @@ import ( func init() { // Use port above 1000 which doesn't need root access to start anything on it. // Now the port is rounded pi. :-) - RootURL = "http://127.0.0.1:3142/api" + RootURL = "127.0.0.1:3142/api" // TLS certificate is self-signed defaultTransport = &http.Transport{ diff --git a/pkg/pmapi/contacts.go b/pkg/pmapi/contacts.go index 26959569..9cf31b60 100644 --- a/pkg/pmapi/contacts.go +++ b/pkg/pmapi/contacts.go @@ -119,7 +119,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e if pageSize > 0 { v.Set("PageSize", strconv.Itoa(pageSize)) } - req, err := NewRequest("GET", "/contacts?"+v.Encode(), nil) + req, err := c.NewRequest("GET", "/contacts?"+v.Encode(), nil) if err != nil { return @@ -136,7 +136,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e // GetContactByID gets contact details specified by contact ID. func (c *Client) GetContactByID(id string) (contactDetail Contact, err error) { - req, err := NewRequest("GET", "/contacts/"+id, nil) + req, err := c.NewRequest("GET", "/contacts/"+id, nil) if err != nil { return @@ -164,7 +164,7 @@ func (c *Client) GetContactsForExport(page int, pageSize int) (contacts []Contac v.Set("PageSize", strconv.Itoa(pageSize)) } - req, err := NewRequest("GET", "/contacts/export?"+v.Encode(), nil) + req, err := c.NewRequest("GET", "/contacts/export?"+v.Encode(), nil) if err != nil { return @@ -198,7 +198,7 @@ func (c *Client) GetAllContactsEmails(page int, pageSize int) (contactsEmails [] v.Set("PageSize", strconv.Itoa(pageSize)) } - req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil) + req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil) if err != nil { return } @@ -221,7 +221,7 @@ func (c *Client) GetContactEmailByEmail(email string, page int, pageSize int) (c } v.Set("Email", email) - req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil) + req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil) if err != nil { return } @@ -276,7 +276,7 @@ func (c *Client) AddContacts(cards ContactsCards, overwrite int, groups int, lab Labels: labels, } - req, err := NewJSONRequest("POST", "/contacts", reqBody) + req, err := c.NewJSONRequest("POST", "/contacts", reqBody) if err != nil { return } @@ -306,7 +306,7 @@ func (c *Client) UpdateContact(id string, cards []Card) (res *UpdateContactRespo reqBody := UpdateContactReq{ Cards: cards, } - req, err := NewJSONRequest("PUT", "/contacts/"+id, reqBody) + req, err := c.NewJSONRequest("PUT", "/contacts/"+id, reqBody) if err != nil { return } @@ -354,7 +354,7 @@ func (c *Client) modifyContactGroups(groupID string, modifyContactGroupsAction i Action: modifyContactGroupsAction, ContactEmailIDs: contactEmailIDs, } - req, err := NewJSONRequest("PUT", "/contacts/group", reqBody) + req, err := c.NewJSONRequest("PUT", "/contacts/group", reqBody) if err != nil { return } @@ -377,7 +377,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) { IDs: ids, } - req, err := NewJSONRequest("PUT", "/contacts/delete", deleteReq) + req, err := c.NewJSONRequest("PUT", "/contacts/delete", deleteReq) if err != nil { return } @@ -402,7 +402,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) { // DeleteAllContacts deletes all contacts. func (c *Client) DeleteAllContacts() (err error) { - req, err := NewRequest("DELETE", "/contacts", nil) + req, err := c.NewRequest("DELETE", "/contacts", nil) if err != nil { return } diff --git a/pkg/pmapi/conversations.go b/pkg/pmapi/conversations.go index 9401016b..728f09d7 100644 --- a/pkg/pmapi/conversations.go +++ b/pkg/pmapi/conversations.go @@ -36,7 +36,7 @@ func (c *Client) CountConversations(addressID string) (counts []*ConversationsCo if addressID != "" { reqURL += ("?AddressID=" + addressID) } - req, err := NewRequest("GET", reqURL, nil) + req, err := c.NewRequest("GET", reqURL, nil) if err != nil { return } diff --git a/pkg/pmapi/dialer_with_proxy.go b/pkg/pmapi/dialer_with_proxy.go index f2db0234..7c9ca96b 100644 --- a/pkg/pmapi/dialer_with_proxy.go +++ b/pkg/pmapi/dialer_with_proxy.go @@ -112,51 +112,47 @@ type DialerWithPinning struct { // It is used only if set. ReportCertIssueLocal func() - // proxyManager manages API proxies. - proxyManager *proxyManager + // cm is used to find and switch to a proxy if necessary. + cm *ClientManager // A logger for logging messages. log logrus.FieldLogger } -func NewDialerWithPinning(reportURI string, report TLSReport) *DialerWithPinning { +// NewDialerWithPinning constructs a new dialer with pinned certs. +func NewDialerWithPinning(cm *ClientManager, appVersion string) *DialerWithPinning { + reportURI := "https://reports.protonmail.ch/reports/tls" + + report := TLSReport{ + EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339), + IncludeSubdomains: false, + ValidatedCertificateChain: []string{}, + ServedCertificateChain: []string{}, + AppVersion: appVersion, + + // NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;) + KnownPins: []string{ + `pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current + `pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot + `pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold + `pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main + `pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1 + `pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2 + `pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3 + }, + } + log := logrus.WithField("pkg", "pmapi/tls-pinning") - proxyManager := newProxyManager(dohProviders, proxyQuery) - return &DialerWithPinning{ - isReported: false, - reportURI: reportURI, - report: report, - proxyManager: proxyManager, - log: log, + cm: cm, + isReported: false, + reportURI: reportURI, + report: report, + log: log, } } -func NewPMAPIPinning(appVersion string) *DialerWithPinning { - return NewDialerWithPinning( - "https://reports.protonmail.ch/reports/tls", - TLSReport{ - EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339), - IncludeSubdomains: false, - ValidatedCertificateChain: []string{}, - ServedCertificateChain: []string{}, - AppVersion: appVersion, - - // NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;) - KnownPins: []string{ - `pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current - `pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot - `pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold - `pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main - `pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1 - `pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2 - `pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3 - }, - }, - ) -} - func (p *DialerWithPinning) reportCertIssue(connState tls.ConnectionState) { p.isReported = true @@ -231,6 +227,7 @@ func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) { return pemCerts } +// TransportWithPinning creates an http.Transport that checks fingerprints when dialing. func (p *DialerWithPinning) TransportWithPinning() *http.Transport { return &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -258,7 +255,7 @@ func (p *DialerWithPinning) TransportWithPinning() *http.Transport { // p.ReportCertIssueLocal() and p.reportCertIssueRemote() if they are not nil. func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (conn net.Conn, err error) { // If DoH is enabled, we hardfail on fingerprint mismatches. - if globalIsDoHAllowed() && p.isReported { + if p.cm.IsProxyAllowed() && p.isReported { return nil, ErrTLSMatch } @@ -283,6 +280,8 @@ func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (c // dialWithProxyFallback tries to dial the given address but falls back to alternative proxies if need be. func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) { + p.log.Info("Dialing with proxy fallback") + var host, port string if host, port, err = net.SplitHostPort(address); err != nil { return @@ -296,21 +295,18 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn // If DoH is not allowed, give up. Or, if we are dialing something other than the API // (e.g. we dial protonmail.com/... to check for updates), there's also no point in // continuing since a proxy won't help us reach that. - if !globalIsDoHAllowed() || host != stripProtocol(GlobalGetRootURL()) { + if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() { + p.log.WithField("useProxy", p.cm.IsProxyAllowed()).Info("Dial failed but not switching to proxy") return } - // Find a new proxy. + // Switch to a proxy and retry the dial. var proxy string - if proxy, err = p.proxyManager.findProxy(); err != nil { + + if proxy, err = p.cm.SwitchToProxy(); err != nil { return } - // Switch to the proxy. - p.log.WithField("proxy", proxy).Debug("Switching to proxy") - p.proxyManager.useProxy(proxy) - - // Retry dial with proxy. return p.dial(network, net.JoinHostPort(proxy, port)) } @@ -329,7 +325,7 @@ func (p *DialerWithPinning) dial(network, address string) (conn net.Conn, err er // If we are not dialing the standard API then we should skip cert verification checks. var tlsConfig *tls.Config = nil - if address != stripProtocol(globalOriginalURL) { + if address != RootURL { tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec] } diff --git a/pkg/pmapi/events.go b/pkg/pmapi/events.go index 98dd9989..50662891 100644 --- a/pkg/pmapi/events.go +++ b/pkg/pmapi/events.go @@ -179,7 +179,7 @@ func (c *Client) GetEvent(last string) (event *Event, err error) { func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event, err error) { var req *http.Request if last == "" { - req, err = NewRequest("GET", "/events/latest", nil) + req, err = c.NewRequest("GET", "/events/latest", nil) if err != nil { return } @@ -191,7 +191,7 @@ func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event, event, err = res.Event, res.Err() } else { - req, err = NewRequest("GET", "/events/"+last, nil) + req, err = c.NewRequest("GET", "/events/"+last, nil) if err != nil { return } diff --git a/pkg/pmapi/import.go b/pkg/pmapi/import.go index 83f25ee7..3c8c3bb0 100644 --- a/pkg/pmapi/import.go +++ b/pkg/pmapi/import.go @@ -120,7 +120,7 @@ type ImportMsgRes struct { func (c *Client) Import(reqs []*ImportMsgReq) (resps []*ImportMsgRes, err error) { importReq := &ImportReq{Messages: reqs} - req, w, err := NewMultipartRequest("POST", "/import") + req, w, err := c.NewMultipartRequest("POST", "/import") if err != nil { return } diff --git a/pkg/pmapi/key.go b/pkg/pmapi/key.go index c8941e02..4994b4f4 100644 --- a/pkg/pmapi/key.go +++ b/pkg/pmapi/key.go @@ -57,7 +57,7 @@ func (c *Client) PublicKeys(emails []string) (keys map[string]*pmcrypto.KeyRing, email = url.QueryEscape(email) var req *http.Request - if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil { + if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil { return } @@ -90,7 +90,7 @@ func (c *Client) GetPublicKeysForEmail(email string) (keys []PublicKey, internal email = url.QueryEscape(email) var req *http.Request - if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil { + if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil { return } @@ -123,7 +123,7 @@ type KeySaltRes struct { // GetKeySalts sends request to get list of key salts (n.b. locked route). func (c *Client) GetKeySalts() (keySalts []KeySalt, err error) { var req *http.Request - if req, err = NewRequest("GET", "/keys/salts", nil); err != nil { + if req, err = c.NewRequest("GET", "/keys/salts", nil); err != nil { return } diff --git a/pkg/pmapi/labels.go b/pkg/pmapi/labels.go index 60294f98..cdb7637d 100644 --- a/pkg/pmapi/labels.go +++ b/pkg/pmapi/labels.go @@ -103,7 +103,7 @@ func (c *Client) ListContactGroups() (labels []*Label, err error) { // ListLabelType lists all labels created by the user. func (c *Client) ListLabelType(labelType int) (labels []*Label, err error) { - req, err := NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil) + req, err := c.NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil) if err != nil { return } @@ -129,7 +129,7 @@ type LabelRes struct { // CreateLabel creates a new label. func (c *Client) CreateLabel(label *Label) (created *Label, err error) { labelReq := &LabelReq{label} - req, err := NewJSONRequest("POST", "/labels", labelReq) + req, err := c.NewJSONRequest("POST", "/labels", labelReq) if err != nil { return } @@ -146,7 +146,7 @@ func (c *Client) CreateLabel(label *Label) (created *Label, err error) { // UpdateLabel updates a label. func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) { labelReq := &LabelReq{label} - req, err := NewJSONRequest("PUT", "/labels/"+label.ID, labelReq) + req, err := c.NewJSONRequest("PUT", "/labels/"+label.ID, labelReq) if err != nil { return } @@ -162,7 +162,7 @@ func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) { // DeleteLabel deletes a label. func (c *Client) DeleteLabel(id string) (err error) { - req, err := NewRequest("DELETE", "/labels/"+id, nil) + req, err := c.NewRequest("DELETE", "/labels/"+id, nil) if err != nil { return } diff --git a/pkg/pmapi/messages.go b/pkg/pmapi/messages.go index 5a74888f..81caf23f 100644 --- a/pkg/pmapi/messages.go +++ b/pkg/pmapi/messages.go @@ -468,7 +468,7 @@ type MessagesListRes struct { // ListMessages gets message metadata. func (c *Client) ListMessages(filter *MessagesFilter) (msgs []*Message, total int, err error) { - req, err := NewRequest("GET", "/messages", nil) + req, err := c.NewRequest("GET", "/messages", nil) if err != nil { return } @@ -500,7 +500,7 @@ func (c *Client) CountMessages(addressID string) (counts []*MessagesCount, err e if addressID != "" { reqURL += ("?AddressID=" + addressID) } - req, err := NewRequest("GET", reqURL, nil) + req, err := c.NewRequest("GET", reqURL, nil) if err != nil { return } @@ -522,7 +522,7 @@ type MessageRes struct { // GetMessage retrieves a message. func (c *Client) GetMessage(id string) (msg *Message, err error) { - req, err := NewRequest("GET", "/messages/"+id, nil) + req, err := c.NewRequest("GET", "/messages/"+id, nil) if err != nil { return } @@ -599,7 +599,7 @@ func (c *Client) SendMessage(id string, sendReq *SendMessageReq) (sent, parent * sendReq.Packages = []*MessagePackage{} } - req, err := NewJSONRequest("POST", "/messages/"+id, sendReq) + req, err := c.NewJSONRequest("POST", "/messages/"+id, sendReq) if err != nil { return } @@ -629,7 +629,7 @@ type DraftReq struct { func (c *Client) CreateDraft(m *Message, parent string, action int) (created *Message, err error) { createReq := &DraftReq{Message: m, ParentID: parent, Action: action, AttachmentKeyPackets: []string{}} - req, err := NewJSONRequest("POST", "/messages", createReq) + req, err := c.NewJSONRequest("POST", "/messages", createReq) if err != nil { return } @@ -688,7 +688,7 @@ func (c *Client) doMessagesAction(action string, ids []string) (err error) { // You should not call this directly unless you know what you are doing (it can overload the server). func (c *Client) doMessagesActionInner(action string, ids []string) (err error) { actionReq := &MessagesActionReq{IDs: ids} - req, err := NewJSONRequest("PUT", "/messages/"+action, actionReq) + req, err := c.NewJSONRequest("PUT", "/messages/"+action, actionReq) if err != nil { return } @@ -740,7 +740,7 @@ func (c *Client) LabelMessages(ids []string, label string) (err error) { func (c *Client) labelMessages(ids []string, label string) (err error) { labelReq := &LabelMessagesReq{LabelID: label, IDs: ids} - req, err := NewJSONRequest("PUT", "/messages/label", labelReq) + req, err := c.NewJSONRequest("PUT", "/messages/label", labelReq) if err != nil { return } @@ -770,7 +770,7 @@ func (c *Client) UnlabelMessages(ids []string, label string) (err error) { func (c *Client) unlabelMessages(ids []string, label string) (err error) { labelReq := &LabelMessagesReq{LabelID: label, IDs: ids} - req, err := NewJSONRequest("PUT", "/messages/unlabel", labelReq) + req, err := c.NewJSONRequest("PUT", "/messages/unlabel", labelReq) if err != nil { return } @@ -793,7 +793,7 @@ func (c *Client) EmptyFolder(labelID, addressID string) (err error) { reqURL += ("&AddressID=" + addressID) } - req, err := NewRequest("DELETE", reqURL, nil) + req, err := c.NewRequest("DELETE", reqURL, nil) if err != nil { return diff --git a/pkg/pmapi/metrics.go b/pkg/pmapi/metrics.go index b0603d86..4ce17176 100644 --- a/pkg/pmapi/metrics.go +++ b/pkg/pmapi/metrics.go @@ -28,7 +28,7 @@ func (c *Client) SendSimpleMetric(category, action, label string) (err error) { v.Set("Action", action) v.Set("Label", label) - req, err := NewRequest("GET", "/metrics?"+v.Encode(), nil) + req, err := c.NewRequest("GET", "/metrics?"+v.Encode(), nil) if err != nil { return } diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go index 2ba711be..fe36c6be 100644 --- a/pkg/pmapi/proxy.go +++ b/pkg/pmapi/proxy.go @@ -21,7 +21,6 @@ import ( "crypto/tls" "encoding/base64" "strings" - "sync" "time" "github.com/go-resty/resty/v2" @@ -43,63 +42,8 @@ var dohProviders = []string{ //nolint[gochecknoglobals] "https://dns.google/dns-query", } -// globalAllowDoH controls whether or not to enable use of DoH/Proxy in pmapi. -var globalAllowDoH = false // nolint[golint] - -// globalProxyMutex allows threadsafe modification of proxy state. -var globalProxyMutex = sync.RWMutex{} // nolint[golint] - -// globalOriginalURL backs up the original API url so it can be restored later. -var globalOriginalURL = RootURL // nolint[golint] - -// globalIsDoHAllowed returns whether or not to use DoH. -func globalIsDoHAllowed() bool { // nolint[golint] - globalProxyMutex.RLock() - defer globalProxyMutex.RUnlock() - - return globalAllowDoH -} - -// GlobalAllowDoH enables DoH. -func GlobalAllowDoH() { // nolint[golint] - globalProxyMutex.Lock() - defer globalProxyMutex.Unlock() - - globalAllowDoH = true -} - -// GlobalDisallowDoH disables DoH and sets the RootURL back to what it was. -func GlobalDisallowDoH() { // nolint[golint] - globalProxyMutex.Lock() - defer globalProxyMutex.Unlock() - - globalAllowDoH = false - RootURL = globalOriginalURL -} - -// globalSetRootURL sets the global RootURL. -func globalSetRootURL(url string) { // nolint[golint] - globalProxyMutex.Lock() - defer globalProxyMutex.Unlock() - - RootURL = url -} - -// GlobalGetRootURL returns the global RootURL. -func GlobalGetRootURL() (url string) { // nolint[golint] - globalProxyMutex.RLock() - defer globalProxyMutex.RUnlock() - - return RootURL -} - -// isProxyEnabled returns whether or not we are currently using a proxy. -func isProxyEnabled() bool { // nolint[golint] - return globalOriginalURL != GlobalGetRootURL() -} - -// proxyManager manages known proxies. -type proxyManager struct { +// proxyProvider manages known proxies. +type proxyProvider struct { // dohLookup is used to look up the given query at the given DoH provider, returning the TXT records> dohLookup func(query, provider string) (urls []string, err error) @@ -113,10 +57,10 @@ type proxyManager struct { lastLookup time.Time // The time at which we last attempted to find a proxy. } -// newProxyManager creates a new proxyManager that queries the given DoH providers +// newProxyProvider creates a new proxyProvider that queries the given DoH providers // to retrieve DNS records for the given query string. -func newProxyManager(providers []string, query string) (p *proxyManager) { // nolint[unparam] - p = &proxyManager{ +func newProxyProvider(providers []string, query string) (p *proxyProvider) { // nolint[unparam] + p = &proxyProvider{ providers: providers, query: query, useDuration: proxyRevertTime, @@ -132,7 +76,7 @@ func newProxyManager(providers []string, query string) (p *proxyManager) { // no // findProxy returns a new proxy domain which is not equal to the current RootURL. // It returns an error if the process takes longer than ProxySearchTime. -func (p *proxyManager) findProxy() (proxy string, err error) { +func (p *proxyProvider) findProxy() (proxy string, err error) { if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) { return "", errors.New("not looking for a proxy, too soon") } @@ -147,7 +91,7 @@ func (p *proxyManager) findProxy() (proxy string, err error) { } for _, proxy := range p.proxyCache { - if proxy != stripProtocol(GlobalGetRootURL()) && p.canReach(proxy) { + if p.canReach(proxy) { proxyResult <- proxy return } @@ -171,25 +115,8 @@ func (p *proxyManager) findProxy() (proxy string, err error) { } } -// useProxy sets the proxy server to use. It returns to the original RootURL after 24 hours. -func (p *proxyManager) useProxy(proxy string) { - if !isProxyEnabled() { - p.disableProxyAfter(p.useDuration) - } - - globalSetRootURL(https(proxy)) -} - -// disableProxyAfter disables the proxy after the given amount of time. -func (p *proxyManager) disableProxyAfter(d time.Duration) { - go func() { - <-time.After(d) - globalSetRootURL(globalOriginalURL) - }() -} - // refreshProxyCache loads the latest proxies from the known providers. -func (p *proxyManager) refreshProxyCache() error { +func (p *proxyProvider) refreshProxyCache() error { logrus.Info("Refreshing proxy cache") for _, provider := range p.providers { @@ -197,7 +124,7 @@ func (p *proxyManager) refreshProxyCache() error { p.proxyCache = proxies // We also want to allow bridge to switch back to the standard API at any time. - p.proxyCache = append(p.proxyCache, globalOriginalURL) + p.proxyCache = append(p.proxyCache, RootURL) logrus.WithField("proxies", proxies).Info("Available proxies") @@ -210,9 +137,13 @@ func (p *proxyManager) refreshProxyCache() error { // canReach returns whether we can reach the given url. // NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname. -func (p *proxyManager) canReach(url string) bool { +func (p *proxyProvider) canReach(url string) bool { + if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") { + url = "https://" + url + } + pinger := resty.New(). - SetHostURL(https(url)). + SetHostURL(url). SetTimeout(p.lookupTimeout). SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec] @@ -227,7 +158,7 @@ func (p *proxyManager) canReach(url string) bool { // It looks up DNS TXT records for the given query URL using the given DoH provider. // It returns a list of all found TXT records. // If the whole process takes more than ProxyQueryTime then an error is returned. -func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []string, err error) { +func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []string, err error) { dataResult := make(chan []string) errResult := make(chan error) go func() { @@ -282,23 +213,3 @@ func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []strin return } } - -func stripProtocol(url string) string { - if strings.HasPrefix(url, "https://") { - return strings.TrimPrefix(url, "https://") - } - - if strings.HasPrefix(url, "http://") { - return strings.TrimPrefix(url, "http://") - } - - return url -} - -func https(url string) string { - if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") { - url = "https://" + url - } - - return url -} diff --git a/pkg/pmapi/proxy_test.go b/pkg/pmapi/proxy_test.go index 15cf3257..d14718c4 100644 --- a/pkg/pmapi/proxy_test.go +++ b/pkg/pmapi/proxy_test.go @@ -32,14 +32,14 @@ const ( TestGoogleProvider = "https://dns.google/dns-query" ) -func TestProxyManager_FindProxy(t *testing.T) { +func TestProxyProvider_FindProxy(t *testing.T) { blockAPI() defer unblockAPI() proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } url, err := p.findProxy() @@ -47,7 +47,7 @@ func TestProxyManager_FindProxy(t *testing.T) { require.Equal(t, proxy.URL, url) } -func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) { +func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) { blockAPI() defer unblockAPI() @@ -58,7 +58,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) { badProxy.Close() defer goodProxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil } url, err := p.findProxy() @@ -66,7 +66,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) { require.Equal(t, goodProxy.URL, url) } -func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) { +func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) { blockAPI() defer unblockAPI() @@ -77,21 +77,21 @@ func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) { badProxy.Close() anotherBadProxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil } _, err := p.findProxy() require.Error(t, err) } -func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) { +func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) { blockAPI() defer unblockAPI() proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.lookupTimeout = time.Second p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } @@ -100,7 +100,7 @@ func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) { require.Error(t, err) } -func TestProxyManager_FindProxy_FindTimeout(t *testing.T) { +func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) { blockAPI() defer unblockAPI() @@ -109,7 +109,7 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) { })) defer slowProxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.findTimeout = time.Second p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } @@ -118,14 +118,14 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) { require.Error(t, err) } -func TestProxyManager_UseProxy(t *testing.T) { +func TestProxyProvider_UseProxy(t *testing.T) { blockAPI() defer unblockAPI() proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } url, err := p.findProxy() @@ -135,7 +135,7 @@ func TestProxyManager_UseProxy(t *testing.T) { require.Equal(t, proxy.URL, GlobalGetRootURL()) } -func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) { +func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { blockAPI() defer unblockAPI() @@ -146,7 +146,7 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) { proxy3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy3.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil } url, err := p.findProxy() @@ -173,14 +173,14 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) { require.Equal(t, proxy3.URL, GlobalGetRootURL()) } -func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) { +func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) { blockAPI() defer unblockAPI() proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.useDuration = time.Second p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } @@ -195,14 +195,14 @@ func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) { require.Equal(t, globalOriginalURL, GlobalGetRootURL()) } -func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) { +func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) { // Don't block the API here because we want it to be working so the test can find it. defer unblockAPI() proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } url, err := p.findProxy() @@ -225,7 +225,7 @@ func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachabl require.Equal(t, globalOriginalURL, GlobalGetRootURL()) } -func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) { +func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) { blockAPI() defer unblockAPI() @@ -234,7 +234,7 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer proxy2.Close() - p := newProxyManager([]string{"not used"}, "not used") + p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } // Find a proxy. @@ -256,32 +256,32 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo require.Equal(t, proxy2.URL, GlobalGetRootURL()) } -func TestProxyManager_DoHLookup_Quad9(t *testing.T) { - p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) +func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { + p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) records, err := p.dohLookup(TestDoHQuery, TestQuad9Provider) require.NoError(t, err) require.NotEmpty(t, records) } -func TestProxyManager_DoHLookup_Google(t *testing.T) { - p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) +func TestProxyProvider_DoHLookup_Google(t *testing.T) { + p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) records, err := p.dohLookup(TestDoHQuery, TestGoogleProvider) require.NoError(t, err) require.NotEmpty(t, records) } -func TestProxyManager_DoHLookup_FindProxy(t *testing.T) { - p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) +func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { + p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) url, err := p.findProxy() require.NoError(t, err) require.NotEmpty(t, url) } -func TestProxyManager_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { - p := newProxyManager([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery) +func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { + p := newProxyProvider([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery) url, err := p.findProxy() require.NoError(t, err) diff --git a/pkg/pmapi/req.go b/pkg/pmapi/req.go index 66a989b7..b8d8a4f3 100644 --- a/pkg/pmapi/req.go +++ b/pkg/pmapi/req.go @@ -26,8 +26,9 @@ import ( ) // NewRequest creates a new request. -func NewRequest(method, path string, body io.Reader) (req *http.Request, err error) { - req, err = http.NewRequest(method, GlobalGetRootURL()+path, body) +func (c *Client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) { + // TODO: Support other protocols (localhost needs http not https). + req, err = http.NewRequest(method, "https://"+c.cm.GetRootURL()+path, body) if req != nil { req.Header.Set("User-Agent", CurrentUserAgent) } @@ -35,13 +36,13 @@ func NewRequest(method, path string, body io.Reader) (req *http.Request, err err } // NewJSONRequest create a new JSON request. -func NewJSONRequest(method, path string, body interface{}) (*http.Request, error) { +func (c *Client) NewJSONRequest(method, path string, body interface{}) (*http.Request, error) { b, err := json.Marshal(body) if err != nil { panic(err) } - req, err := NewRequest(method, path, bytes.NewReader(b)) + req, err := c.NewRequest(method, path, bytes.NewReader(b)) if err != nil { return nil, err } @@ -70,7 +71,7 @@ func (w *MultipartWriter) Close() error { // that writing the request and sending it MUST be done in parallel. If the // request fails, subsequent writes to the multipart writer will fail with an // io.ErrClosedPipe error. -func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) { +func (c *Client) NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) { // The pipe will connect the multipart writer and the HTTP request body. pr, pw := io.Pipe() @@ -80,7 +81,7 @@ func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWr pw, } - req, err = NewRequest(method, path, pr) + req, err = c.NewRequest(method, path, pr) if err != nil { return } diff --git a/pkg/pmapi/settings.go b/pkg/pmapi/settings.go index a0f89125..20247f58 100644 --- a/pkg/pmapi/settings.go +++ b/pkg/pmapi/settings.go @@ -45,7 +45,7 @@ type UserSettings struct { // GetUserSettings gets general settings. func (c *Client) GetUserSettings() (settings UserSettings, err error) { - req, err := NewRequest("GET", "/settings", nil) + req, err := c.NewRequest("GET", "/settings", nil) if err != nil { return @@ -99,7 +99,7 @@ type MailSettings struct { // GetMailSettings gets contact details specified by contact ID. func (c *Client) GetMailSettings() (settings MailSettings, err error) { - req, err := NewRequest("GET", "/settings/mail", nil) + req, err := c.NewRequest("GET", "/settings/mail", nil) if err != nil { return diff --git a/pkg/pmapi/users.go b/pkg/pmapi/users.go index 1815f1db..0f62cf93 100644 --- a/pkg/pmapi/users.go +++ b/pkg/pmapi/users.go @@ -93,7 +93,7 @@ func (u *User) KeyRing() *pmcrypto.KeyRing { // UpdateUser retrieves details about user and loads its addresses. func (c *Client) UpdateUser() (user *User, err error) { - req, err := NewRequest("GET", "/users", nil) + req, err := c.NewRequest("GET", "/users", nil) if err != nil { return }