Launcher, app/base, sentry, update service

This commit is contained in:
James Houlahan
2020-11-23 11:56:57 +01:00
parent 6fffb460b8
commit dc3f61acee
164 changed files with 5368 additions and 4039 deletions

View File

@ -81,4 +81,6 @@ type Client interface {
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
DownloadAndVerify(string, string, *crypto.KeyRing) (io.Reader, error)
}

View File

@ -137,10 +137,18 @@ func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
cm.roundTripper = rt
}
func (cm *ClientManager) GetClientConfig() *ClientConfig {
return cm.config
}
func (cm *ClientManager) SetUserAgent(clientName, clientVersion, os string) {
cm.config.UserAgent = formatUserAgent(clientName, clientVersion, os)
}
func (cm *ClientManager) GetUserAgent() string {
return cm.config.UserAgent
}
// GetClient returns a client for the given userID.
// If the client does not exist already, it is created.
func (cm *ClientManager) GetClient(userID string) Client {
@ -366,7 +374,7 @@ func (cm *ClientManager) clearToken(userID string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).Info("Clearing token")
logrus.WithField("userID", userID).Debug("Clearing token")
delete(cm.tokens, userID)
}

View File

@ -18,23 +18,23 @@
package pmapi
import (
"net/http"
"strings"
"time"
)
// rootURL is the API root URL.
//
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
// Default is pmapi_prod.
//
// It must not contain the protocol! The protocol should be in rootScheme.
var rootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
var rootScheme = "https" //nolint[gochecknoglobals]
// The HTTP transport to use by default.
var defaultTransport = &http.Transport{ //nolint[gochecknoglobals]
Proxy: http.ProxyFromEnvironment,
// rootScheme is the scheme to use for connections to the root URL.
var rootScheme = "https" //nolint[gochecknoglobals]
func GetAPIConfig(configName, appVersion string) *ClientConfig {
return &ClientConfig{
AppVersion: strings.Title(configName) + "_" + appVersion,
ClientID: configName,
Timeout: 25 * time.Minute, // Overall request timeout (~25MB / 25 mins => ~16kB/s, should be reasonable).
FirstReadTimeout: 30 * time.Second, // 30s to match 30s response header timeout.
MinBytesPerSecond: 1 << 10, // Enforce minimum download speed of 1kB/s.
}
}
// checkTLSCerts controls whether TLS certs are checked against known fingerprints.
// The default is for this to always be done.
var checkTLSCerts = true //nolint[gochecknoglobals]

View File

@ -0,0 +1,44 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build !pmapi_qa
package pmapi
import (
"net/http"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
)
func GetRoundTripper(cm *ClientManager, listener listener.Listener) http.RoundTripper {
// We use a TLS dialer.
basicDialer := NewBasicTLSDialer()
// We wrap the TLS dialer in a layer which enforces connections to trusted servers.
pinningDialer := NewPinningTLSDialer(basicDialer)
// We want any pin mismatches to be communicated back to bridge GUI and reported.
pinningDialer.SetTLSIssueNotifier(func() { listener.Emit(events.TLSCertIssue, "") })
pinningDialer.EnableRemoteTLSIssueReporting(cm)
// We wrap the pinning dialer in a layer which adds "alternative routing" feature.
proxyDialer := NewProxyTLSDialer(pinningDialer, cm)
return CreateTransportWithDialer(proxyDialer)
}

View File

@ -24,6 +24,8 @@ import (
"net/http"
"os"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/listener"
)
func init() {
@ -37,13 +39,13 @@ func init() {
rootURL = fullRootURL
rootScheme = "https"
}
}
func GetRoundTripper(_ *ClientManager, _ listener.Listener) http.RoundTripper {
transport := CreateTransportWithDialer(NewBasicTLSDialer())
// TLS certificate of testing environment might be self-signed.
defaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
// This config disables TLS cert checking.
checkTLSCerts = false
return transport
}

View File

@ -30,19 +30,12 @@ type PinningTLSDialer struct {
dialer TLSDialer
// pinChecker is used to check TLS keys of connections.
pinChecker pinChecker
pinChecker *pinChecker
// tlsIssueNotifier is used to notify something when there is a TLS issue.
tlsIssueNotifier func()
// appVersion is needed to report TLS mismatches.
appVersion string
// userAgent is needed to report TLS mismatches.
userAgent string
// enableRemoteReporting instructs the dialer to report TLS mismatches.
enableRemoteReporting bool
reporter *tlsReporter
// A logger for logging messages.
log logrus.FieldLogger
@ -63,41 +56,38 @@ func (p *PinningTLSDialer) SetTLSIssueNotifier(notifier func()) {
p.tlsIssueNotifier = notifier
}
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(appVersion, userAgent string) {
p.enableRemoteReporting = true
p.appVersion = appVersion
p.userAgent = userAgent
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(cm *ClientManager) {
p.reporter = newTLSReporter(p.pinChecker, cm)
}
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
if conn, err = p.dialer.DialTLS(network, address); err != nil {
return
func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) {
conn, err := p.dialer.DialTLS(network, address)
if err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return
return nil, err
}
if err = p.pinChecker.checkCertificate(conn); err != nil {
if err := p.pinChecker.checkCertificate(conn); err != nil {
if p.tlsIssueNotifier != nil {
go p.tlsIssueNotifier()
}
if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting {
p.pinChecker.reportCertIssue(
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.reportCertIssue(
TLSReportURI,
host,
port,
tlsConn.ConnectionState(),
p.appVersion,
p.userAgent,
)
}
return
return nil, err
}
return
return conn, nil
}

74
pkg/pmapi/download.go Normal file
View File

@ -0,0 +1,74 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
// The file and its signature are verified using the given keyring `kr`.
// If the file is verified successfully, it can be read from the returned reader.
// TLS fingerprinting is used to verify that connections are only made to known servers.
func (c *client) DownloadAndVerify(file, sig string, kr *crypto.KeyRing) (io.Reader, error) {
var fb, sb []byte
if err := c.fetchFile(file, func(r io.Reader) (err error) {
fb, err = ioutil.ReadAll(r)
return err
}); err != nil {
return nil, err
}
if err := c.fetchFile(sig, func(r io.Reader) (err error) {
sb, err = ioutil.ReadAll(r)
return err
}); err != nil {
return nil, err
}
if err := kr.VerifyDetached(
crypto.NewPlainMessage(fb),
crypto.NewPGPSignature(sb),
crypto.GetUnixTime(),
); err != nil {
return nil, err
}
return bytes.NewReader(fb), nil
}
func (c *client) fetchFile(file string, fn func(io.Reader) error) error {
res, err := c.hc.Get(file)
if err != nil {
return err
}
defer func() { _ = res.Body.Close() }()
if res.StatusCode != http.StatusOK {
return fmt.Errorf("failed to get file: http error %v", res.StatusCode)
}
return fn(res.Body)
}

View File

@ -294,6 +294,21 @@ func (mr *MockClientMockRecorder) DeleteMessages(arg0 interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMessages", reflect.TypeOf((*MockClient)(nil).DeleteMessages), arg0)
}
// DownloadAndVerify mocks base method
func (m *MockClient) DownloadAndVerify(arg0, arg1 string, arg2 *crypto.KeyRing) (io.Reader, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DownloadAndVerify", arg0, arg1, arg2)
ret0, _ := ret[0].(io.Reader)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DownloadAndVerify indicates an expected call of DownloadAndVerify
func (mr *MockClientMockRecorder) DownloadAndVerify(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadAndVerify", reflect.TypeOf((*MockClient)(nil).DownloadAndVerify), arg0, arg1, arg2)
}
// EmptyFolder mocks base method
func (m *MockClient) EmptyFolder(arg0, arg1 string) error {
m.ctrl.T.Helper()

View File

@ -35,7 +35,6 @@ import (
type pinChecker struct {
trustedPins []string
sentReports []sentReport
}
type sentReport struct {
@ -43,8 +42,8 @@ type sentReport struct {
t time.Time
}
func newPinChecker(trustedPins []string) pinChecker {
return pinChecker{
func newPinChecker(trustedPins []string) *pinChecker {
return &pinChecker{
trustedPins: trustedPins,
}
}
@ -76,8 +75,25 @@ func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
}
type clientConfigProvider interface {
GetClientConfig() *ClientConfig
}
type tlsReporter struct {
cm clientConfigProvider
p *pinChecker
sentReports []sentReport
}
func newTLSReporter(p *pinChecker, cm clientConfigProvider) *tlsReporter {
return &tlsReporter{
cm: cm,
p: p,
}
}
// reportCertIssue reports a TLS key mismatch.
func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState, appVersion, userAgent string) {
func (r *tlsReporter) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
var certChain []string
if len(connState.VerifiedChains) > 0 {
@ -86,27 +102,29 @@ func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls
certChain = marshalCert7468(connState.PeerCertificates)
}
r := newTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion)
cfg := r.cm.GetClientConfig()
if !p.hasRecentlySentReport(r) {
p.recordReport(r)
go r.sendReport(remoteURI, userAgent)
report := newTLSReport(host, port, connState.ServerName, certChain, r.p.trustedPins, cfg.AppVersion)
if !r.hasRecentlySentReport(report) {
r.recordReport(report)
go report.sendReport(remoteURI, cfg.UserAgent)
}
}
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
func (r *tlsReporter) hasRecentlySentReport(report tlsReport) bool {
var validReports []sentReport
for _, r := range p.sentReports {
for _, r := range r.sentReports {
if time.Since(r.t) < 24*time.Hour {
validReports = append(validReports, r)
}
}
p.sentReports = validReports
r.sentReports = validReports
for _, r := range p.sentReports {
for _, r := range r.sentReports {
if cmp.Equal(report, r.r) {
return true
}
@ -116,8 +134,8 @@ func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
}
// recordReport records the given report and the current time so we can check whether we recently sent this report.
func (p *pinChecker) recordReport(r tlsReport) {
p.sentReports = append(p.sentReports, sentReport{r: r, t: time.Now()})
func (r *tlsReporter) recordReport(report tlsReport) {
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
}
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {

View File

@ -27,6 +27,14 @@ import (
"github.com/stretchr/testify/assert"
)
type fakeClientConfigProvider struct {
version, useragent string
}
func (c *fakeClientConfigProvider) GetClientConfig() *ClientConfig {
return &ClientConfig{AppVersion: c.version, UserAgent: c.useragent}
}
func TestPinCheckerDoubleReport(t *testing.T) {
reportCounter := 0
@ -34,11 +42,11 @@ func TestPinCheckerDoubleReport(t *testing.T) {
reportCounter++
}))
pc := newPinChecker(TrustedAPIPins)
r := newTLSReporter(newPinChecker(TrustedAPIPins), &fakeClientConfigProvider{version: "3", useragent: "useragent"})
// Report the same issue many times.
for i := 0; i < 10; i++ {
pc.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{}, "3", "useragent")
r.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
}
// We should only report once.
@ -48,7 +56,7 @@ func TestPinCheckerDoubleReport(t *testing.T) {
// If we then report something else many times.
for i := 0; i < 10; i++ {
pc.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{}, "3", "useragent")
r.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
}
// We should get a second report.

View File

@ -120,9 +120,7 @@ func (c *client) UpdateUser() (user *User, err error) {
c.user = user
sentry.ConfigureScope(func(scope *sentry.Scope) {
scope.SetUser(sentry.User{
ID: user.ID,
})
scope.SetUser(sentry.User{ID: user.ID})
})
var tmpList AddressList