Launcher, app/base, sentry, update service

This commit is contained in:
James Houlahan
2020-11-23 11:56:57 +01:00
parent 6fffb460b8
commit dc3f61acee
164 changed files with 5368 additions and 4039 deletions

View File

@ -1,319 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"os"
"path/filepath"
"runtime"
"github.com/ProtonMail/go-appdir"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
)
var (
log = logrus.WithField("pkg", "config") //nolint[gochecknoglobals]
)
type appDirProvider interface {
UserConfig() string
UserCache() string
UserLogs() string
}
type Config struct {
appName string
version string
revision string
cacheVersion string
appDirs appDirProvider
appDirsVersion appDirProvider
}
// New returns fully initialized config struct.
// `appName` should be in camelCase format for folder or file names. It's also used in API
// as `AppVersion` which is converted to CamelCase.
// `version` is the version of the app (e.g. v1.2.3).
// `cacheVersion` is the version of the cache files (setting a different number will remove the old ones).
func New(appName, version, revision, cacheVersion string) *Config {
appDirs := appdir.New(filepath.Join("protonmail", appName))
appDirsVersion := appdir.New(filepath.Join("protonmail", appName, cacheVersion))
return newConfig(appName, version, revision, cacheVersion, appDirs, appDirsVersion)
}
func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirsVersion appDirProvider) *Config {
return &Config{
appName: appName,
version: version,
revision: revision,
cacheVersion: cacheVersion,
appDirs: appDirs,
appDirsVersion: appDirsVersion,
}
}
// CreateDirs creates all folders that are necessary for bridge to properly function.
func (c *Config) CreateDirs() error {
// Log files.
if err := os.MkdirAll(c.appDirs.UserLogs(), 0700); err != nil {
return err
}
// TLS files.
if err := os.MkdirAll(c.appDirs.UserConfig(), 0750); err != nil {
return err
}
// Lock, events, preferences, user_info, db files.
if err := os.MkdirAll(c.appDirsVersion.UserCache(), 0750); err != nil {
return err
}
return nil
}
// ClearData removes all files except the lock file.
// The lock file will be removed when the Bridge stops.
func (c *Config) ClearData() error {
dirs := []string{
c.appDirs.UserLogs(),
c.appDirs.UserConfig(),
c.appDirs.UserCache(),
}
shouldRemove := func(filePath string) bool {
return filePath != c.GetLockPath()
}
return c.removeAllExcept(dirs, shouldRemove)
}
// ClearOldData removes all old files, such as old log files or old versions of cache and so on.
func (c *Config) ClearOldData() error {
// `appDirs` is parent for `appDirsVersion`.
// `dir` then contains all subfolders and only `cacheVersion` should stay.
// But on Windows all files (dirs) are in the same one - we cannot remove log, lock or tls files.
dir := c.appDirs.UserCache()
return c.removeExcept(dir, func(filePath string) bool {
fileName := filepath.Base(filePath)
return (fileName != c.cacheVersion &&
!logFileRgx.MatchString(fileName) &&
filePath != c.GetLogDir() &&
filePath != c.GetTLSCertPath() &&
filePath != c.GetTLSKeyPath() &&
filePath != c.GetEventsPath() &&
filePath != c.GetIMAPCachePath() &&
filePath != c.GetLockPath() &&
filePath != c.GetPreferencesPath())
})
}
func (c *Config) removeAllExcept(dirs []string, shouldRemove func(string) bool) error {
var result *multierror.Error
for _, dir := range dirs {
if err := c.removeExcept(dir, shouldRemove); err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}
func (c *Config) removeExcept(dir string, shouldRemove func(string) bool) error {
files, err := ioutil.ReadDir(dir)
if err != nil {
return err
}
var result *multierror.Error
for _, file := range files {
filePath := filepath.Join(dir, file.Name())
if !shouldRemove(filePath) {
continue
}
if !file.IsDir() {
if err := os.RemoveAll(filePath); err != nil {
result = multierror.Append(result, err)
}
continue
}
subDir := filepath.Join(dir, file.Name())
if err := c.removeExcept(subDir, shouldRemove); err != nil {
result = multierror.Append(result, err)
} else {
// Remove dir itself only if it's empty.
subFiles, err := ioutil.ReadDir(subDir)
if err != nil {
result = multierror.Append(result, err)
} else if len(subFiles) == 0 {
if err := os.RemoveAll(subDir); err != nil {
result = multierror.Append(result, err)
}
}
}
}
return result.ErrorOrNil()
}
// IsDevMode should be used for development conditions such us whether to send sentry reports.
func (c *Config) IsDevMode() bool {
return os.Getenv("PROTONMAIL_ENV") == "dev"
}
// GetVersion returns the version.
func (c *Config) GetVersion() string {
return c.version
}
// GetLogDir returns folder for log files.
func (c *Config) GetLogDir() string {
return c.appDirs.UserLogs()
}
// GetLogPrefix returns prefix for log files. Bridge uses format vVERSION.
func (c *Config) GetLogPrefix() string {
return "v" + c.version + "_" + c.revision
}
// GetLicenseFilePath returns path to liense file.
func (c *Config) GetLicenseFilePath() string {
path := c.getLicenseFilePath()
log.WithField("path", path).Info("License file path")
return path
}
func (c *Config) getLicenseFilePath() string {
// User can install app to different location, or user can run it
// directly from the package without installation, or it could be
// automatically updated (app started from differenet location).
// For all those cases, first let's check LICENSE next to the binary.
path := filepath.Join(filepath.Dir(os.Args[0]), "LICENSE")
if _, err := os.Stat(path); err == nil {
return path
}
switch runtime.GOOS {
case "linux":
appName := c.appName
if c.appName == "importExport" {
appName = "import-export"
}
// Most Linux distributions.
path := "/usr/share/doc/protonmail/" + appName + "/LICENSE"
if _, err := os.Stat(path); err == nil {
return path
}
// Arch distributions.
return "/usr/share/licenses/protonmail-" + appName + "/LICENSE"
case "darwin": //nolint[goconst]
path := filepath.Join(filepath.Dir(os.Args[0]), "..", "Resources", "LICENSE")
if _, err := os.Stat(path); err == nil {
return path
}
appName := "ProtonMail Bridge.app"
if c.appName == "importExport" {
appName = "ProtonMail Import-Export.app"
}
return "/Applications/" + appName + "/Contents/Resources/LICENSE"
case "windows":
path := filepath.Join(filepath.Dir(os.Args[0]), "LICENSE.txt")
if _, err := os.Stat(path); err == nil {
return path
}
// This should not happen, Windows should be handled by relative
// location to the binary above. This is just fallback which may
// or may not work, depends where user installed the app and how
// user started the app.
return filepath.FromSlash("C:/Program Files/Proton Technologies AG/ProtonMail Bridge/LICENSE.txt")
}
return ""
}
// GetTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP and API).
func (c *Config) GetTLSCertPath() string {
return filepath.Join(c.appDirs.UserConfig(), "cert.pem")
}
// GetTLSKeyPath returns path to private key; used for TLS servers (IMAP, SMTP and API).
func (c *Config) GetTLSKeyPath() string {
return filepath.Join(c.appDirs.UserConfig(), "key.pem")
}
// GetDBDir returns folder for db files.
func (c *Config) GetDBDir() string {
return c.appDirsVersion.UserCache()
}
// GetEventsPath returns path to events file containing the last processed event IDs.
func (c *Config) GetEventsPath() string {
return filepath.Join(c.appDirsVersion.UserCache(), "events.json")
}
// GetIMAPCachePath returns path to file with IMAP status.
func (c *Config) GetIMAPCachePath() string {
return filepath.Join(c.appDirsVersion.UserCache(), "user_info.json")
}
// GetLockPath returns path to lock file to check if bridge is already running.
func (c *Config) GetLockPath() string {
return filepath.Join(c.appDirsVersion.UserCache(), c.appName+".lock")
}
// GetUpdateDir returns folder for update files; such as new binary.
func (c *Config) GetUpdateDir() string {
return filepath.Join(c.appDirsVersion.UserCache(), "updates")
}
// GetPreferencesPath returns path to preference file.
func (c *Config) GetPreferencesPath() string {
return filepath.Join(c.appDirsVersion.UserCache(), "prefs.json")
}
// GetTransferDir returns folder for import-export rules files.
func (c *Config) GetTransferDir() string {
return c.appDirsVersion.UserCache()
}
// GetDefaultAPIPort returns default Bridge local API port.
func (c *Config) GetDefaultAPIPort() int {
return 1042
}
// GetDefaultIMAPPort returns default Bridge IMAP port.
func (c *Config) GetDefaultIMAPPort() int {
return 1143
}
// GetDefaultSMTPPort returns default Bridge SMTP port.
func (c *Config) GetDefaultSMTPPort() int {
return 1025
}
// getAPIOS returns actual operating system.
func (c *Config) getAPIOS() string {
switch os := runtime.GOOS; os {
case "darwin": // nolint: goconst
return "macOS"
case "linux":
return "Linux"
case "windows":
return "Windows"
}
return "Linux"
}

View File

@ -1,238 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
const testAppName = "bridge-test"
var testConfigDir string //nolint[gochecknoglobals]
func TestMain(m *testing.M) {
setupTestConfig()
setupTestLogs()
code := m.Run()
shutdownTestConfig()
shutdownTestLogs()
shutdownTestPreferences()
os.Exit(code)
}
func setupTestConfig() {
var err error
testConfigDir, err = ioutil.TempDir("", "config")
if err != nil {
panic(err)
}
}
func shutdownTestConfig() {
_ = os.RemoveAll(testConfigDir)
}
type mocks struct {
t *testing.T
ctrl *gomock.Controller
appDir *MockappDirer
appDirVersion *MockappDirer
}
func initMocks(t *testing.T) mocks {
mockCtrl := gomock.NewController(t)
return mocks{
t: t,
ctrl: mockCtrl,
appDir: NewMockappDirer(mockCtrl),
appDirVersion: NewMockappDirer(mockCtrl),
}
}
func TestClearDataLinux(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureLinux(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"config",
"logs",
})
}
func TestClearDataWindows(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureWindows(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"config",
})
}
// OldData touches only cache folder.
// Removes only c1 folder as nothing else is part of cache folder on Linux/Mac.
func TestClearOldDataLinux(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureLinux(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearOldData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"config",
"config/cert.pem",
"config/key.pem",
"logs",
"logs/other.log",
"logs/v1_10.log",
"logs/v1_11.log",
"logs/v2_12.log",
"logs/v2_13.log",
})
}
// OldData touches only cache folder. Removes everything except c2 folder
// and bridge log files which are part of cache folder on Windows.
func TestClearOldDataWindows(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureWindows(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearOldData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"cache/v1_10.log",
"cache/v1_11.log",
"cache/v2_12.log",
"cache/v2_13.log",
"config",
"config/cert.pem",
"config/key.pem",
})
}
func createTestStructureLinux(m mocks, baseDir string) {
logsDir := filepath.Join(baseDir, "logs")
configDir := filepath.Join(baseDir, "config")
cacheDir := filepath.Join(baseDir, "cache")
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
}
func createTestStructureWindows(m mocks, baseDir string) {
logsDir := filepath.Join(baseDir, "cache")
configDir := filepath.Join(baseDir, "config")
cacheDir := filepath.Join(baseDir, "cache")
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
}
func createTestStructure(m mocks, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir string) {
m.appDir.EXPECT().UserLogs().Return(logsDir).AnyTimes()
m.appDir.EXPECT().UserConfig().Return(configDir).AnyTimes()
m.appDir.EXPECT().UserCache().Return(cacheDir).AnyTimes()
m.appDirVersion.EXPECT().UserCache().Return(versionedCacheDir).AnyTimes()
require.NoError(m.t, os.RemoveAll(baseDir))
require.NoError(m.t, os.MkdirAll(baseDir, 0700))
require.NoError(m.t, os.MkdirAll(logsDir, 0700))
require.NoError(m.t, os.MkdirAll(configDir, 0700))
require.NoError(m.t, os.MkdirAll(cacheDir, 0700))
require.NoError(m.t, os.MkdirAll(versionedOldCacheDir, 0700))
require.NoError(m.t, os.MkdirAll(versionedCacheDir, 0700))
require.NoError(m.t, os.MkdirAll(filepath.Join(versionedCacheDir, "updates"), 0700))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "other.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_10.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_11.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_12.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_13.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "cert.pem"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "key.pem"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "prefs.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "events.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "user_info.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "prefs.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "events.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "user_info.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, testAppName+".lock"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
}
func checkFileNames(t *testing.T, dir string, expectedFileNames []string) {
fileNames := getFileNames(t, dir)
require.Equal(t, expectedFileNames, fileNames)
}
func getFileNames(t *testing.T, dir string) []string {
files, err := ioutil.ReadDir(dir)
require.NoError(t, err)
fileNames := []string{}
for _, file := range files {
fileNames = append(fileNames, file.Name())
if file.IsDir() {
subDir := filepath.Join(dir, file.Name())
subFileNames := getFileNames(t, subDir)
for _, subFileName := range subFileNames {
fileNames = append(fileNames, file.Name()+"/"+subFileName)
}
}
}
return fileNames
}

View File

@ -1,251 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"runtime"
"runtime/pprof"
"sort"
"strconv"
"time"
"github.com/ProtonMail/proton-bridge/pkg/sentry"
"github.com/sirupsen/logrus"
)
type logConfiger interface {
GetLogDir() string
GetLogPrefix() string
}
const (
// Zendesk now has a file size limit of 20MB. When the last N log files
// are zipped, it should fit under 20MB. Value in MB (average file has
// few hundreds kB).
maxLogFileSize = 10 * 1024 * 1024 //nolint[gochecknoglobals]
// Including the current logfile.
maxNumberLogFiles = 3 //nolint[gochecknoglobals]
)
// logFile is pointer to currently open file used by logrus.
var logFile *os.File //nolint[gochecknoglobals]
var logFileRgx = regexp.MustCompile("^v.*\\.log$") //nolint[gochecknoglobals]
var logCrashRgx = regexp.MustCompile("^v.*_crash_.*\\.log$") //nolint[gochecknoglobals]
// HandlePanic reports the crash to sentry or local file when sentry fails.
func HandlePanic(cfg *Config, output string) {
sentry.SkipDuringUnwind()
if !cfg.IsDevMode() {
apiCfg := cfg.GetAPIConfig()
if err := sentry.ReportSentryCrash(apiCfg.ClientID, apiCfg.AppVersion, apiCfg.UserAgent, errors.New(output)); err != nil {
log.Error("Sentry crash report failed: ", err)
}
}
filename := getLogFilename(cfg.GetLogPrefix() + "_crash_")
filepath := filepath.Join(cfg.GetLogDir(), filename)
f, err := os.OpenFile(filepath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
if err != nil {
log.Error("Cannot open file to write crash report: ", err)
return
}
_, _ = f.WriteString(output)
_ = pprof.Lookup("goroutine").WriteTo(f, 2)
log.Warn("Crash report saved to ", filepath)
}
// GetGID returns goroutine number which can be used to distiguish logs from
// the concurent processes. Keep in mind that it returns the number of routine
// which executes the function.
func GetGID() uint64 {
b := make([]byte, 64)
b = b[:runtime.Stack(b, false)]
b = bytes.TrimPrefix(b, []byte("goroutine "))
b = b[:bytes.IndexByte(b, ' ')]
n, _ := strconv.ParseUint(string(b), 10, 64)
return n
}
// SetupLog set up log level, formatter and output (file or stdout).
// Returns whether should be used debug for IMAP and SMTP servers.
func SetupLog(cfg logConfiger, levelFlag string) (debugClient, debugServer bool) {
level, useFile := getLogLevelAndFile(levelFlag)
logrus.SetLevel(level)
if useFile {
logrus.SetFormatter(&logrus.JSONFormatter{})
setLogFile(cfg.GetLogDir(), cfg.GetLogPrefix())
watchLogFileSize(cfg.GetLogDir(), cfg.GetLogPrefix())
} else {
logrus.SetFormatter(&logrus.TextFormatter{
ForceColors: true,
FullTimestamp: true,
TimestampFormat: time.StampMilli,
})
logrus.SetOutput(os.Stdout)
}
switch levelFlag {
case "debug-client", "debug-client-json":
debugClient = true
case "debug-server", "debug-server-json", "trace":
fmt.Println("THE LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
log.Warning("================================================")
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
log.Warning("================================================")
debugClient = true
debugServer = true
}
return debugClient, debugServer
}
func setLogFile(logDir, logPrefix string) {
if logFile != nil {
return
}
filename := getLogFilename(logPrefix)
var err error
logFile, err = os.OpenFile(filepath.Join(logDir, filename), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
if err != nil {
panic(err)
}
logrus.SetOutput(logFile)
// Users sometimes change the name of the log file. We want to always log
// information about bridge version (included in log prefix) and OS.
log.Warn("Bridge version: ", logPrefix, " ", runtime.GOOS)
}
func getLogFilename(logPrefix string) string {
currentTime := strconv.Itoa(int(time.Now().Unix()))
return logPrefix + "_" + currentTime + ".log"
}
func watchLogFileSize(logDir, logPrefix string) {
go func() {
for {
// Some rare bug can cause log file spamming a lot. Checking file
// size too often is not good, and at the same time postpone next
// check for too long is the same thing. 30 seconds seems as good
// compromise; average computer can generates ~500MB in 30 seconds.
time.Sleep(30 * time.Second)
checkLogFileSize(logDir, logPrefix)
}
}()
}
func checkLogFileSize(logDir, logPrefix string) {
if logFile == nil {
return
}
stat, err := logFile.Stat()
if err != nil {
log.Error("Log file size check failed: ", err)
return
}
if stat.Size() >= maxLogFileSize {
log.Warn("Current log file ", logFile.Name(), " is too big, opening new file")
closeLogFile()
setLogFile(logDir, logPrefix)
}
if err := clearLogs(logDir); err != nil {
log.Error("Cannot clear logs ", err)
}
}
func closeLogFile() {
if logFile != nil {
_ = logFile.Close()
logFile = nil
}
}
func clearLogs(logDir string) error {
files, err := ioutil.ReadDir(logDir)
if err != nil {
return err
}
var logsWithPrefix []string
var crashesWithPrefix []string
for _, file := range files {
if logFileRgx.MatchString(file.Name()) {
if logCrashRgx.MatchString(file.Name()) {
crashesWithPrefix = append(crashesWithPrefix, file.Name())
} else {
logsWithPrefix = append(logsWithPrefix, file.Name())
}
} else {
// Older versions of Bridge stored logs in subfolders for each version.
// That also has to be cleared and the functionality can be removed after some time.
if file.IsDir() {
if err := clearLogs(filepath.Join(logDir, file.Name())); err != nil {
return err
}
} else {
removeLog(logDir, file.Name())
}
}
}
removeOldLogs(logDir, logsWithPrefix)
removeOldLogs(logDir, crashesWithPrefix)
return nil
}
func removeOldLogs(logDir string, filenames []string) {
count := len(filenames)
if count <= maxNumberLogFiles {
return
}
sort.Strings(filenames) // Sorted by timestamp: oldest first.
for _, filename := range filenames[:count-maxNumberLogFiles] {
removeLog(logDir, filename)
}
}
func removeLog(logDir, filename string) {
// We need to be sure to delete only log files.
// Directory with logs can also contain other files.
if !logFileRgx.MatchString(filename) {
return
}
if err := os.RemoveAll(filepath.Join(logDir, filename)); err != nil {
log.Error("Cannot remove old logs ", err)
}
}

View File

@ -1,49 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build !build_qa
package config
import (
"github.com/sirupsen/logrus"
)
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
useFile = true
switch levelFlag {
case "panic":
level = logrus.PanicLevel
case "fatal":
level = logrus.FatalLevel
case "error":
level = logrus.ErrorLevel
case "warn":
level = logrus.WarnLevel
case "info":
level = logrus.InfoLevel
case "debug", "debug-client", "debug-server", "debug-client-json", "debug-server-json":
level = logrus.DebugLevel
useFile = false
case "trace":
level = logrus.TraceLevel
useFile = false
default:
level = logrus.InfoLevel
}
return
}

View File

@ -1,50 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build build_qa
package config
import (
"github.com/sirupsen/logrus"
)
// getLogLevelAndFile for QA build is altered in a way even decrypted data are stored
// in the log file when forced with `debug-client-json` or `debug-server-json`.
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
useFile = true
switch levelFlag {
case "panic":
level = logrus.PanicLevel
case "fatal":
level = logrus.FatalLevel
case "error":
level = logrus.ErrorLevel
case "warn":
level = logrus.WarnLevel
case "info":
level = logrus.InfoLevel
case "debug-client-json", "debug-server-json":
level = logrus.DebugLevel
case "debug", "debug-client", "debug-server":
level = logrus.DebugLevel
useFile = false
default:
level = logrus.InfoLevel
}
return
}

View File

@ -1,225 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
type testLogConfig struct{ logDir, logPrefix string }
func (c *testLogConfig) GetLogDir() string { return c.logDir }
func (c *testLogConfig) GetLogPrefix() string { return c.logPrefix }
var testLogDir string //nolint[gochecknoglobals]
func setupTestLogs() {
var err error
testLogDir, err = ioutil.TempDir("", "log")
if err != nil {
panic(err)
}
}
func shutdownTestLogs() {
_ = os.RemoveAll(testLogDir)
}
func TestLogNameLength(t *testing.T) {
cfg := New("bridge-test", "longVersion123", "longRevision1234567890", "c2")
name := getLogFilename(cfg.GetLogPrefix())
if len(name) > 128 {
t.Fatal("Name of the log is too long - limit for encrypted linux is 128 characters")
}
}
// Info and higher levels writes to the file.
func TestSetupLogInfo(t *testing.T) {
dir := beforeEachCreateTestDir(t, "setupInfo")
SetupLog(&testLogConfig{dir, "v"}, "info")
require.Equal(t, "info", logrus.GetLevel().String())
logrus.Info("test message")
files := checkLogFiles(t, dir, 1)
checkLogContains(t, dir, files[0].Name(), "test message")
}
// Debug levels writes to stdout.
func TestSetupLogDebug(t *testing.T) {
dir := beforeEachCreateTestDir(t, "setupDebug")
SetupLog(&testLogConfig{dir, "v"}, "debug")
require.Equal(t, "debug", logrus.GetLevel().String())
logrus.Info("test message")
checkLogFiles(t, dir, 0)
}
func TestReopenLogFile(t *testing.T) {
dir := beforeEachCreateTestDir(t, "reopenLogFile")
setLogFile(dir, "v1")
done := make(chan interface{})
log.Info("first message")
go func() {
<-done // Wait for closing file and opening new one.
log.Info("second message")
done <- nil
}()
closeLogFile()
setLogFile(dir, "v2")
done <- nil
<-done // Wait for second log message.
files := checkLogFiles(t, dir, 2)
checkLogContains(t, dir, files[0].Name(), "first message")
checkLogContains(t, dir, files[1].Name(), "second message")
}
func TestCheckLogFileSizeSmall(t *testing.T) {
dir := beforeEachCreateTestDir(t, "logFileSizeSmall")
setLogFile(dir, "v1")
originalFileName := logFile.Name()
_, _ = logFile.WriteString("small file")
checkLogFileSize(dir, "v2")
require.Equal(t, originalFileName, logFile.Name())
}
func TestCheckLogFileSizeBig(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
dir := beforeEachCreateTestDir(t, "logFileSizeBig")
setLogFile(dir, "v1")
originalFileName := logFile.Name()
// The limit for big file is 10*1024*1024 - keep the string 10 letters long.
for i := 0; i < 1024*1024; i++ {
_, _ = logFile.WriteString("big file!\n")
}
checkLogFileSize(dir, "v2")
require.NotEqual(t, originalFileName, logFile.Name())
}
// ClearLogs removes only bridge old log files keeping last three of them.
func TestClearLogsLinux(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
dir := beforeEachCreateTestDir(t, "clearLogs")
createTestStructureLinux(m, dir)
require.NoError(t, clearLogs(dir))
checkFileNames(t, dir, []string{
"cache",
"cache/c1",
"cache/c1/events.json",
"cache/c1/mailbox-user@pm.me.db",
"cache/c1/prefs.json",
"cache/c1/user_info.json",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"config",
"config/cert.pem",
"config/key.pem",
"logs",
"logs/other.log",
"logs/v1_11.log",
"logs/v2_12.log",
"logs/v2_13.log",
})
}
// ClearLogs removes only bridge old log files even when log folder
// is shared with other files on Windows.
func TestClearLogsWindows(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
dir := beforeEachCreateTestDir(t, "clearLogs")
createTestStructureWindows(m, dir)
require.NoError(t, clearLogs(dir))
checkFileNames(t, dir, []string{
"cache",
"cache/c1",
"cache/c1/events.json",
"cache/c1/mailbox-user@pm.me.db",
"cache/c1/prefs.json",
"cache/c1/user_info.json",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"cache/other.log",
"cache/v1_11.log",
"cache/v2_12.log",
"cache/v2_13.log",
"config",
"config/cert.pem",
"config/key.pem",
})
}
func beforeEachCreateTestDir(t *testing.T, dir string) string {
// Make sure opened file (from the previous test) is cleared.
closeLogFile()
dir = filepath.Join(testLogDir, dir)
require.NoError(t, os.MkdirAll(dir, 0700))
return dir
}
func checkLogFiles(t *testing.T, dir string, expectedCount int) []os.FileInfo {
files, err := ioutil.ReadDir(dir)
require.NoError(t, err)
require.Equal(t, expectedCount, len(files))
return files
}
func checkLogContains(t *testing.T, dir, fileName, expectedSubstr string) {
data, err := ioutil.ReadFile(filepath.Join(dir, fileName)) //nolint[gosec]
require.NoError(t, err)
require.Contains(t, string(data), expectedSubstr)
}

View File

@ -1,76 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: config/config.go
// Package config is a generated GoMock package.
package config
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockappDirer is a mock of appDirer interface
type MockappDirer struct {
ctrl *gomock.Controller
recorder *MockappDirerMockRecorder
}
// MockappDirerMockRecorder is the mock recorder for MockappDirer
type MockappDirerMockRecorder struct {
mock *MockappDirer
}
// NewMockappDirer creates a new mock instance
func NewMockappDirer(ctrl *gomock.Controller) *MockappDirer {
mock := &MockappDirer{ctrl: ctrl}
mock.recorder = &MockappDirerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockappDirer) EXPECT() *MockappDirerMockRecorder {
return m.recorder
}
// UserConfig mocks base method
func (m *MockappDirer) UserConfig() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserConfig")
ret0, _ := ret[0].(string)
return ret0
}
// UserConfig indicates an expected call of UserConfig
func (mr *MockappDirerMockRecorder) UserConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserConfig", reflect.TypeOf((*MockappDirer)(nil).UserConfig))
}
// UserCache mocks base method
func (m *MockappDirer) UserCache() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserCache")
ret0, _ := ret[0].(string)
return ret0
}
// UserCache indicates an expected call of UserCache
func (mr *MockappDirerMockRecorder) UserCache() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCache", reflect.TypeOf((*MockappDirer)(nil).UserCache))
}
// UserLogs mocks base method
func (m *MockappDirer) UserLogs() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserLogs")
ret0, _ := ret[0].(string)
return ret0
}
// UserLogs indicates an expected call of UserLogs
func (mr *MockappDirerMockRecorder) UserLogs() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserLogs", reflect.TypeOf((*MockappDirer)(nil).UserLogs))
}

View File

@ -1,127 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"encoding/json"
"errors"
"os"
"strconv"
"sync"
)
type Preferences struct {
cache map[string]string
path string
lock *sync.RWMutex
}
// NewPreferences returns loaded preferences.
func NewPreferences(preferencesPath string) *Preferences {
p := &Preferences{
path: preferencesPath,
lock: &sync.RWMutex{},
}
if err := p.load(); err != nil {
log.Warn("Cannot load preferences: ", err)
}
return p
}
func (p *Preferences) load() error {
if p.cache != nil {
return nil
}
p.lock.Lock()
defer p.lock.Unlock()
p.cache = map[string]string{}
f, err := os.Open(p.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&p.cache)
}
func (p *Preferences) save() error {
if p.cache == nil {
return errors.New("cannot save preferences: cache is nil")
}
p.lock.Lock()
defer p.lock.Unlock()
f, err := os.Create(p.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(p.cache)
}
func (p *Preferences) SetDefault(key, value string) {
if p.Get(key) == "" {
p.Set(key, value)
}
}
func (p *Preferences) Get(key string) string {
p.lock.RLock()
defer p.lock.RUnlock()
return p.cache[key]
}
func (p *Preferences) GetBool(key string) bool {
return p.Get(key) == "true"
}
func (p *Preferences) GetInt(key string) int {
value, err := strconv.Atoi(p.Get(key))
if err != nil {
log.Error("Cannot parse int: ", err)
}
return value
}
func (p *Preferences) Set(key, value string) {
p.lock.Lock()
p.cache[key] = value
p.lock.Unlock()
if err := p.save(); err != nil {
log.Warn("Cannot save preferences: ", err)
}
}
func (p *Preferences) SetBool(key string, value bool) {
if value {
p.Set(key, "true")
} else {
p.Set(key, "false")
}
}
func (p *Preferences) SetInt(key string, value int) {
p.Set(key, strconv.Itoa(value))
}

View File

@ -1,109 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"os"
"testing"
"github.com/stretchr/testify/require"
)
const testPrefFilePath = "/tmp/pref.json"
func shutdownTestPreferences() {
_ = os.RemoveAll(testPrefFilePath)
}
func TestLoadNoPreferences(t *testing.T) {
pref := newTestEmptyPreferences(t)
require.Equal(t, "", pref.Get("key"))
}
func TestLoadBadPreferences(t *testing.T) {
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"key\":\"value"), 0700))
pref := NewPreferences(testPrefFilePath)
require.Equal(t, "", pref.Get("key"))
}
func TestPreferencesGet(t *testing.T) {
pref := newTestPreferences(t)
require.Equal(t, "value", pref.Get("str"))
require.Equal(t, "42", pref.Get("int"))
require.Equal(t, "true", pref.Get("bool"))
require.Equal(t, "t", pref.Get("falseBool"))
}
func TestPreferencesGetInt(t *testing.T) {
pref := newTestPreferences(t)
require.Equal(t, 0, pref.GetInt("str"))
require.Equal(t, 42, pref.GetInt("int"))
require.Equal(t, 0, pref.GetInt("bool"))
require.Equal(t, 0, pref.GetInt("falseBool"))
}
func TestPreferencesGetBool(t *testing.T) {
pref := newTestPreferences(t)
require.Equal(t, false, pref.GetBool("str"))
require.Equal(t, false, pref.GetBool("int"))
require.Equal(t, true, pref.GetBool("bool"))
require.Equal(t, false, pref.GetBool("falseBool"))
}
func TestPreferencesSetDefault(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.SetDefault("key", "value")
pref.SetDefault("key", "othervalue")
require.Equal(t, "value", pref.Get("key"))
}
func TestPreferencesSet(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.Set("str", "value")
checkSavedPreferences(t, "{\"str\":\"value\"}")
}
func TestPreferencesSetInt(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.SetInt("int", 42)
checkSavedPreferences(t, "{\"int\":\"42\"}")
}
func TestPreferencesSetBool(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.SetBool("trueBool", true)
pref.SetBool("falseBool", false)
checkSavedPreferences(t, "{\"falseBool\":\"false\",\"trueBool\":\"true\"}")
}
func newTestEmptyPreferences(t *testing.T) *Preferences {
require.NoError(t, os.RemoveAll(testPrefFilePath))
return NewPreferences(testPrefFilePath)
}
func newTestPreferences(t *testing.T) *Preferences {
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"str\":\"value\",\"int\":\"42\",\"bool\":\"true\",\"falseBool\":\"t\"}"), 0700))
return NewPreferences(testPrefFilePath)
}
func checkSavedPreferences(t *testing.T, expected string) {
data, err := ioutil.ReadFile(testPrefFilePath)
require.NoError(t, err)
require.Equal(t, expected+"\n", string(data))
}

View File

@ -1,173 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"os/exec"
"runtime"
"time"
)
type tlsConfiger interface {
GetTLSCertPath() string
GetTLSKeyPath() string
}
var tlsTemplate = x509.Certificate{ //nolint[gochecknoglobals]
SerialNumber: big.NewInt(-1),
Subject: pkix.Name{
Country: []string{"CH"},
Organization: []string{"Proton Technologies AG"},
OrganizationalUnit: []string{"ProtonMail"},
CommonName: "127.0.0.1",
},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
IsCA: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
NotBefore: time.Now(),
NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour),
}
var ErrTLSCertExpireSoon = fmt.Errorf("TLS certificate will expire soon")
// GetTLSConfig tries to load TLS config or generate new one which is then returned.
func GetTLSConfig(cfg tlsConfiger) (tlsConfig *tls.Config, err error) {
certPath := cfg.GetTLSCertPath()
keyPath := cfg.GetTLSKeyPath()
tlsConfig, err = loadTLSConfig(certPath, keyPath)
if err != nil {
log.WithError(err).Warn("Cannot load cert, generating a new one")
tlsConfig, err = GenerateTLSConfig(certPath, keyPath)
if err != nil {
return
}
if runtime.GOOS == "darwin" {
if err := exec.Command( // nolint[gosec]
"/usr/bin/security",
"execute-with-privileges",
"/usr/bin/security",
"add-trusted-cert",
"-d",
"-r", "trustRoot",
"-p", "ssl",
"-k", "/Library/Keychains/System.keychain",
certPath,
).Run(); err != nil {
log.WithError(err).Error("Failed to add cert to system keychain")
}
}
}
tlsConfig.ServerName = "127.0.0.1"
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
caCertPool := x509.NewCertPool()
caCertPool.AddCert(tlsConfig.Certificates[0].Leaf)
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
/* This is deprecated:
* SA1019: tlsConfig.BuildNameToCertificate is deprecated:
* NameToCertificate only allows associating a single certificate with a given name.
* Leave that field nil to let the library select the first compatible chain from Certificates.
*/
tlsConfig.BuildNameToCertificate() // nolint[staticcheck]
return tlsConfig, err
}
func loadTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
c, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return
}
c.Leaf, err = x509.ParseCertificate(c.Certificate[0])
if err != nil {
return
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{c},
}
if time.Now().Add(31 * 24 * time.Hour).After(c.Leaf.NotAfter) {
err = ErrTLSCertExpireSoon
return
}
return
}
// GenerateTLSConfig generates certs and keys at the given filepaths and returns a TLS Config which holds them.
// See https://golang.org/src/crypto/tls/generate_cert.go
func GenerateTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
err = fmt.Errorf("failed to generate private key: %s", err)
return
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
err = fmt.Errorf("failed to generate serial number: %s", err)
return
}
tlsTemplate.SerialNumber = serialNumber
derBytes, err := x509.CreateCertificate(rand.Reader, &tlsTemplate, &tlsTemplate, &priv.PublicKey, priv)
if err != nil {
err = fmt.Errorf("failed to create certificate: %s", err)
return
}
certOut, err := os.Create(certPath)
if err != nil {
return
}
defer certOut.Close() //nolint[errcheck]
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
if err != nil {
return
}
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return
}
defer keyOut.Close() //nolint[errcheck]
err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
if err != nil {
return
}
return loadTLSConfig(certPath, keyPath)
}

View File

@ -1,63 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type testTLSConfig struct{ certPath, keyPath string }
func (c *testTLSConfig) GetTLSCertPath() string { return c.certPath }
func (c *testTLSConfig) GetTLSKeyPath() string { return c.keyPath }
func TestTLSKeyRenewal(t *testing.T) {
// Remove keys.
configPath := "/tmp"
certPath := filepath.Join(configPath, "cert.pem")
keyPath := filepath.Join(configPath, "key.pem")
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
// Put old key there.
tlsTemplate.NotBefore = time.Now().Add(-365 * 24 * time.Hour)
tlsTemplate.NotAfter = time.Now()
cert, err := GenerateTLSConfig(certPath, keyPath)
require.Equal(t, err, ErrTLSCertExpireSoon)
require.Equal(t, len(cert.Certificates), 1)
time.Sleep(time.Second)
now, notValidAfter := time.Now(), cert.Certificates[0].Leaf.NotAfter
require.True(t, now.After(notValidAfter), "old certificate expected to not be valid at %v but have valid until %v", now, notValidAfter)
// Renew key.
tlsTemplate.NotBefore = time.Now()
tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour)
cert, err = GetTLSConfig(&testTLSConfig{certPath, keyPath})
if runtime.GOOS != "darwin" { // Darwin is not supported.
require.NoError(t, err)
}
require.Equal(t, len(cert.Certificates), 1)
now, notValidAfter = time.Now(), cert.Certificates[0].Leaf.NotAfter
require.False(t, now.After(notValidAfter), "new certificate expected to be valid at %v but have valid until %v", now, notValidAfter)
}

View File

@ -1,40 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package constants contains variables that are set via ldflags during build.
package constants
// nolint[gochecknoglobals]
var (
// Version of the build.
Version = ""
// Revision is current hash of the build.
Revision = ""
// BuildTime stamp of the build.
BuildTime = ""
// DSNSentry client keys to be able to report crashes to Sentry.
DSNSentry = ""
// LongVersion is derived from Version and Revision.
LongVersion = Version + " (" + Revision + ")"
// BuildVersion is derived from LongVersion and BuildTime.
BuildVersion = LongVersion + " " + BuildTime
)

82
pkg/files/removal.go Normal file
View File

@ -0,0 +1,82 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package files provides standard filesystem operations.
package files
import (
"os"
"path/filepath"
"sort"
"strings"
"github.com/hashicorp/go-multierror"
)
type OpRemove struct {
targets []string
exceptions []string
}
func Remove(targets ...string) *OpRemove {
return &OpRemove{targets: targets}
}
func (op *OpRemove) Except(exceptions ...string) *OpRemove {
op.exceptions = exceptions
return op
}
func (op *OpRemove) Do() error {
var multiErr error
for _, target := range op.targets {
if err := remove(target, op.exceptions...); err != nil {
multiErr = multierror.Append(multiErr, err)
}
}
return multiErr
}
func remove(dir string, except ...string) error {
var toRemove []string
if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
for _, exception := range except {
if path == exception || strings.HasPrefix(exception, path) || strings.HasPrefix(path, exception) {
return nil
}
}
toRemove = append(toRemove, path)
return nil
}); err != nil {
return err
}
sort.Sort(sort.Reverse(sort.StringSlice(toRemove)))
for _, target := range toRemove {
if err := os.RemoveAll(target); err != nil {
return err
}
}
return nil
}

116
pkg/files/removal_test.go Normal file
View File

@ -0,0 +1,116 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package files
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRemove(t *testing.T) {
dir := newTestDir(t,
"subdir1",
"subdir2/subdir3",
)
defer delTestDir(t, dir)
createTestFiles(t, dir,
"subdir1/file1",
"subdir1/file2",
"subdir2/file3",
"subdir2/file4",
"subdir2/subdir3/file5",
"subdir2/subdir3/file6",
)
require.NoError(t, Remove(
filepath.Join(dir, "subdir1"),
filepath.Join(dir, "subdir2", "file3"),
filepath.Join(dir, "subdir2", "subdir3", "file5"),
).Do())
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file1"))
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file2"))
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "file3"))
assert.FileExists(t, filepath.Join(dir, "subdir2", "file4"))
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file5"))
assert.FileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file6"))
}
func TestRemoveWithExceptions(t *testing.T) {
dir := newTestDir(t,
"subdir1",
"subdir2/subdir3",
"subdir4",
)
defer delTestDir(t, dir)
createTestFiles(t, dir,
"subdir1/file1",
"subdir1/file2",
"subdir2/file3",
"subdir2/file4",
"subdir2/subdir3/file5",
"subdir2/subdir3/file6",
"subdir4/file7",
"subdir4/file8",
)
require.NoError(t, Remove(dir).Except(
filepath.Join(dir, "subdir2", "file4"),
filepath.Join(dir, "subdir2", "subdir3", "file6"),
filepath.Join(dir, "subdir4"),
).Do())
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file1"))
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file2"))
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "file3"))
assert.FileExists(t, filepath.Join(dir, "subdir2", "file4"))
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file5"))
assert.FileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file6"))
assert.FileExists(t, filepath.Join(dir, "subdir4", "file7"))
assert.FileExists(t, filepath.Join(dir, "subdir4", "file8"))
}
func newTestDir(t *testing.T, subdirs ...string) string {
dir, err := ioutil.TempDir("", "test-files-dir")
require.NoError(t, err)
for _, target := range subdirs {
require.NoError(t, os.MkdirAll(filepath.Join(dir, target), 0700))
}
return dir
}
func createTestFiles(t *testing.T, dir string, files ...string) {
for _, target := range files {
f, err := os.Create(filepath.Join(dir, target))
require.NoError(t, err)
require.NoError(t, f.Close())
}
}
func delTestDir(t *testing.T, dir string) {
require.NoError(t, os.RemoveAll(dir))
}

View File

@ -38,6 +38,15 @@ import (
// Parse parses RAW message.
func Parse(r io.Reader) (m *pmapi.Message, mimeBody, plainBody string, attReaders []io.Reader, err error) {
defer func() {
r := recover()
if r == nil {
return
}
err = fmt.Errorf("panic while parsing message: %v", r)
}()
p, err := parser.New(r)
if err != nil {
return nil, "", "", nil, errors.Wrap(err, "failed to create new parser")

View File

@ -545,6 +545,20 @@ func TestParseEncodedContentTypeBad(t *testing.T) {
require.Error(t, err)
}
type panicReader struct{}
func (panicReader) Read(p []byte) (int, error) {
panic("lol")
}
func TestParsePanic(t *testing.T) {
var err error
require.NotPanics(t, func() {
_, _, _, _, err = Parse(&panicReader{})
})
require.Error(t, err)
}
func getFileReader(filename string) io.Reader {
f, err := os.Open(filepath.Join("testdata", filename))
if err != nil {

View File

@ -81,4 +81,6 @@ type Client interface {
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
DownloadAndVerify(string, string, *crypto.KeyRing) (io.Reader, error)
}

View File

@ -137,10 +137,18 @@ func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
cm.roundTripper = rt
}
func (cm *ClientManager) GetClientConfig() *ClientConfig {
return cm.config
}
func (cm *ClientManager) SetUserAgent(clientName, clientVersion, os string) {
cm.config.UserAgent = formatUserAgent(clientName, clientVersion, os)
}
func (cm *ClientManager) GetUserAgent() string {
return cm.config.UserAgent
}
// GetClient returns a client for the given userID.
// If the client does not exist already, it is created.
func (cm *ClientManager) GetClient(userID string) Client {
@ -366,7 +374,7 @@ func (cm *ClientManager) clearToken(userID string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).Info("Clearing token")
logrus.WithField("userID", userID).Debug("Clearing token")
delete(cm.tokens, userID)
}

View File

@ -18,23 +18,23 @@
package pmapi
import (
"net/http"
"strings"
"time"
)
// rootURL is the API root URL.
//
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
// Default is pmapi_prod.
//
// It must not contain the protocol! The protocol should be in rootScheme.
var rootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
var rootScheme = "https" //nolint[gochecknoglobals]
// The HTTP transport to use by default.
var defaultTransport = &http.Transport{ //nolint[gochecknoglobals]
Proxy: http.ProxyFromEnvironment,
// rootScheme is the scheme to use for connections to the root URL.
var rootScheme = "https" //nolint[gochecknoglobals]
func GetAPIConfig(configName, appVersion string) *ClientConfig {
return &ClientConfig{
AppVersion: strings.Title(configName) + "_" + appVersion,
ClientID: configName,
Timeout: 25 * time.Minute, // Overall request timeout (~25MB / 25 mins => ~16kB/s, should be reasonable).
FirstReadTimeout: 30 * time.Second, // 30s to match 30s response header timeout.
MinBytesPerSecond: 1 << 10, // Enforce minimum download speed of 1kB/s.
}
}
// checkTLSCerts controls whether TLS certs are checked against known fingerprints.
// The default is for this to always be done.
var checkTLSCerts = true //nolint[gochecknoglobals]

View File

@ -15,43 +15,30 @@
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build pmapi_prod
// +build !pmapi_qa
package config
package pmapi
import (
"net/http"
"strings"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
func (c *Config) GetAPIConfig() *pmapi.ClientConfig {
return &pmapi.ClientConfig{
AppVersion: c.getAPIOS() + strings.Title(c.appName) + "_" + c.version,
ClientID: c.appName,
Timeout: 25 * time.Minute, // Overall request timeout (~25MB / 25 mins => ~16kB/s, should be reasonable).
FirstReadTimeout: 30 * time.Second, // 30s to match 30s response header timeout.
MinBytesPerSecond: 1 << 10, // Enforce minimum download speed of 1kB/s.
}
}
func (c *Config) GetRoundTripper(cm *pmapi.ClientManager, listener listener.Listener) http.RoundTripper {
func GetRoundTripper(cm *ClientManager, listener listener.Listener) http.RoundTripper {
// We use a TLS dialer.
basicDialer := pmapi.NewBasicTLSDialer()
basicDialer := NewBasicTLSDialer()
// We wrap the TLS dialer in a layer which enforces connections to trusted servers.
pinningDialer := pmapi.NewPinningTLSDialer(basicDialer)
pinningDialer := NewPinningTLSDialer(basicDialer)
// We want any pin mismatches to be communicated back to bridge GUI and reported.
pinningDialer.SetTLSIssueNotifier(func() { listener.Emit(events.TLSCertIssue, "") })
pinningDialer.EnableRemoteTLSIssueReporting(c.GetAPIConfig().AppVersion, c.GetAPIConfig().UserAgent)
pinningDialer.EnableRemoteTLSIssueReporting(cm)
// We wrap the pinning dialer in a layer which adds "alternative routing" feature.
proxyDialer := pmapi.NewProxyTLSDialer(pinningDialer, cm)
proxyDialer := NewProxyTLSDialer(pinningDialer, cm)
return pmapi.CreateTransportWithDialer(proxyDialer)
return CreateTransportWithDialer(proxyDialer)
}

View File

@ -24,6 +24,8 @@ import (
"net/http"
"os"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/listener"
)
func init() {
@ -37,13 +39,13 @@ func init() {
rootURL = fullRootURL
rootScheme = "https"
}
}
func GetRoundTripper(_ *ClientManager, _ listener.Listener) http.RoundTripper {
transport := CreateTransportWithDialer(NewBasicTLSDialer())
// TLS certificate of testing environment might be self-signed.
defaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
// This config disables TLS cert checking.
checkTLSCerts = false
return transport
}

View File

@ -30,19 +30,12 @@ type PinningTLSDialer struct {
dialer TLSDialer
// pinChecker is used to check TLS keys of connections.
pinChecker pinChecker
pinChecker *pinChecker
// tlsIssueNotifier is used to notify something when there is a TLS issue.
tlsIssueNotifier func()
// appVersion is needed to report TLS mismatches.
appVersion string
// userAgent is needed to report TLS mismatches.
userAgent string
// enableRemoteReporting instructs the dialer to report TLS mismatches.
enableRemoteReporting bool
reporter *tlsReporter
// A logger for logging messages.
log logrus.FieldLogger
@ -63,41 +56,38 @@ func (p *PinningTLSDialer) SetTLSIssueNotifier(notifier func()) {
p.tlsIssueNotifier = notifier
}
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(appVersion, userAgent string) {
p.enableRemoteReporting = true
p.appVersion = appVersion
p.userAgent = userAgent
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(cm *ClientManager) {
p.reporter = newTLSReporter(p.pinChecker, cm)
}
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
if conn, err = p.dialer.DialTLS(network, address); err != nil {
return
func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) {
conn, err := p.dialer.DialTLS(network, address)
if err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return
return nil, err
}
if err = p.pinChecker.checkCertificate(conn); err != nil {
if err := p.pinChecker.checkCertificate(conn); err != nil {
if p.tlsIssueNotifier != nil {
go p.tlsIssueNotifier()
}
if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting {
p.pinChecker.reportCertIssue(
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.reportCertIssue(
TLSReportURI,
host,
port,
tlsConn.ConnectionState(),
p.appVersion,
p.userAgent,
)
}
return
return nil, err
}
return
return conn, nil
}

74
pkg/pmapi/download.go Normal file
View File

@ -0,0 +1,74 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
// The file and its signature are verified using the given keyring `kr`.
// If the file is verified successfully, it can be read from the returned reader.
// TLS fingerprinting is used to verify that connections are only made to known servers.
func (c *client) DownloadAndVerify(file, sig string, kr *crypto.KeyRing) (io.Reader, error) {
var fb, sb []byte
if err := c.fetchFile(file, func(r io.Reader) (err error) {
fb, err = ioutil.ReadAll(r)
return err
}); err != nil {
return nil, err
}
if err := c.fetchFile(sig, func(r io.Reader) (err error) {
sb, err = ioutil.ReadAll(r)
return err
}); err != nil {
return nil, err
}
if err := kr.VerifyDetached(
crypto.NewPlainMessage(fb),
crypto.NewPGPSignature(sb),
crypto.GetUnixTime(),
); err != nil {
return nil, err
}
return bytes.NewReader(fb), nil
}
func (c *client) fetchFile(file string, fn func(io.Reader) error) error {
res, err := c.hc.Get(file)
if err != nil {
return err
}
defer func() { _ = res.Body.Close() }()
if res.StatusCode != http.StatusOK {
return fmt.Errorf("failed to get file: http error %v", res.StatusCode)
}
return fn(res.Body)
}

View File

@ -294,6 +294,21 @@ func (mr *MockClientMockRecorder) DeleteMessages(arg0 interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMessages", reflect.TypeOf((*MockClient)(nil).DeleteMessages), arg0)
}
// DownloadAndVerify mocks base method
func (m *MockClient) DownloadAndVerify(arg0, arg1 string, arg2 *crypto.KeyRing) (io.Reader, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DownloadAndVerify", arg0, arg1, arg2)
ret0, _ := ret[0].(io.Reader)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DownloadAndVerify indicates an expected call of DownloadAndVerify
func (mr *MockClientMockRecorder) DownloadAndVerify(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadAndVerify", reflect.TypeOf((*MockClient)(nil).DownloadAndVerify), arg0, arg1, arg2)
}
// EmptyFolder mocks base method
func (m *MockClient) EmptyFolder(arg0, arg1 string) error {
m.ctrl.T.Helper()

View File

@ -35,7 +35,6 @@ import (
type pinChecker struct {
trustedPins []string
sentReports []sentReport
}
type sentReport struct {
@ -43,8 +42,8 @@ type sentReport struct {
t time.Time
}
func newPinChecker(trustedPins []string) pinChecker {
return pinChecker{
func newPinChecker(trustedPins []string) *pinChecker {
return &pinChecker{
trustedPins: trustedPins,
}
}
@ -76,8 +75,25 @@ func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
}
type clientConfigProvider interface {
GetClientConfig() *ClientConfig
}
type tlsReporter struct {
cm clientConfigProvider
p *pinChecker
sentReports []sentReport
}
func newTLSReporter(p *pinChecker, cm clientConfigProvider) *tlsReporter {
return &tlsReporter{
cm: cm,
p: p,
}
}
// reportCertIssue reports a TLS key mismatch.
func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState, appVersion, userAgent string) {
func (r *tlsReporter) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
var certChain []string
if len(connState.VerifiedChains) > 0 {
@ -86,27 +102,29 @@ func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls
certChain = marshalCert7468(connState.PeerCertificates)
}
r := newTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion)
cfg := r.cm.GetClientConfig()
if !p.hasRecentlySentReport(r) {
p.recordReport(r)
go r.sendReport(remoteURI, userAgent)
report := newTLSReport(host, port, connState.ServerName, certChain, r.p.trustedPins, cfg.AppVersion)
if !r.hasRecentlySentReport(report) {
r.recordReport(report)
go report.sendReport(remoteURI, cfg.UserAgent)
}
}
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
func (r *tlsReporter) hasRecentlySentReport(report tlsReport) bool {
var validReports []sentReport
for _, r := range p.sentReports {
for _, r := range r.sentReports {
if time.Since(r.t) < 24*time.Hour {
validReports = append(validReports, r)
}
}
p.sentReports = validReports
r.sentReports = validReports
for _, r := range p.sentReports {
for _, r := range r.sentReports {
if cmp.Equal(report, r.r) {
return true
}
@ -116,8 +134,8 @@ func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
}
// recordReport records the given report and the current time so we can check whether we recently sent this report.
func (p *pinChecker) recordReport(r tlsReport) {
p.sentReports = append(p.sentReports, sentReport{r: r, t: time.Now()})
func (r *tlsReporter) recordReport(report tlsReport) {
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
}
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {

View File

@ -27,6 +27,14 @@ import (
"github.com/stretchr/testify/assert"
)
type fakeClientConfigProvider struct {
version, useragent string
}
func (c *fakeClientConfigProvider) GetClientConfig() *ClientConfig {
return &ClientConfig{AppVersion: c.version, UserAgent: c.useragent}
}
func TestPinCheckerDoubleReport(t *testing.T) {
reportCounter := 0
@ -34,11 +42,11 @@ func TestPinCheckerDoubleReport(t *testing.T) {
reportCounter++
}))
pc := newPinChecker(TrustedAPIPins)
r := newTLSReporter(newPinChecker(TrustedAPIPins), &fakeClientConfigProvider{version: "3", useragent: "useragent"})
// Report the same issue many times.
for i := 0; i < 10; i++ {
pc.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{}, "3", "useragent")
r.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
}
// We should only report once.
@ -48,7 +56,7 @@ func TestPinCheckerDoubleReport(t *testing.T) {
// If we then report something else many times.
for i := 0; i < 10; i++ {
pc.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{}, "3", "useragent")
r.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
}
// We should get a second report.

View File

@ -120,9 +120,7 @@ func (c *client) UpdateUser() (user *User, err error) {
c.user = user
sentry.ConfigureScope(func(scope *sentry.Scope) {
scope.SetUser(sentry.User{
ID: user.ID,
})
scope.SetUser(sentry.User{ID: user.ID})
})
var tmpList AddressList

View File

@ -18,8 +18,8 @@
package ports
import (
"fmt"
"net"
"strconv"
)
const (
@ -31,9 +31,12 @@ func IsPortFree(port int) bool {
if !(0 < port && port < maxPortNumber) {
return false
}
stringPort := ":" + strconv.Itoa(port)
isFree := !isOccupied(stringPort)
return isFree
// First, check localhost only.
if isOccupied(fmt.Sprintf("127.0.0.1:%d", port)) {
return false
}
// Second, check also ports opened to public.
return !isOccupied(fmt.Sprintf(":%d", port))
}
func isOccupied(port string) bool {

View File

@ -19,28 +19,76 @@ package sentry
import (
"errors"
"fmt"
"os"
"runtime"
"time"
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/getsentry/sentry-go"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)
var (
skippedFunctions = []string{} //nolint[gochecknoglobals]
)
var skippedFunctions = []string{} //nolint[gochecknoglobals]
// ReportSentryCrash reports a sentry crash.
func ReportSentryCrash(clientID, appVersion, userAgent string, reportErr error) error {
func init() { // nolint[noinit]
if err := sentry.Init(sentry.ClientOptions{
Dsn: constants.DSNSentry,
Release: constants.Revision,
BeforeSend: EnhanceSentryEvent,
}); err != nil {
logrus.WithError(err).Error("Failed to initialize sentry options")
}
sentry.ConfigureScope(func(scope *sentry.Scope) {
scope.SetFingerprint([]string{"{{ default }}"})
})
}
type userAgentProvider interface {
GetUserAgent() string
}
type Reporter struct {
appName string
appVersion string
uap userAgentProvider
}
// NewReporter creates new sentry reporter with appName and appVersion to report.
func NewReporter(appName, appVersion string) *Reporter {
return &Reporter{
appName: appName,
appVersion: appVersion,
}
}
func (r *Reporter) SetUserAgentProvider(uap userAgentProvider) {
r.uap = uap
}
// Report reports a sentry crash with stacktrace from all goroutines.
func (r *Reporter) Report(i interface{}) (err error) {
SkipDuringUnwind()
if reportErr == nil {
if os.Getenv("PROTONMAIL_ENV") == "dev" {
return nil
}
// In case clientManager is not yet created we can get at least OS string.
var userAgent string
if r.uap != nil {
userAgent = r.uap.GetUserAgent()
} else {
userAgent = runtime.GOOS
}
reportErr := fmt.Errorf("recover: %v", i)
tags := map[string]string{
"OS": runtime.GOOS,
"Client": clientID,
"Version": appVersion,
"Client": r.appName,
"Version": r.appVersion,
"UserAgent": userAgent,
"UserID": "",
}
@ -49,18 +97,17 @@ func ReportSentryCrash(clientID, appVersion, userAgent string, reportErr error)
sentry.WithScope(func(scope *sentry.Scope) {
SkipDuringUnwind()
scope.SetTags(tags)
eventID := sentry.CaptureException(reportErr)
if eventID != nil {
if eventID := sentry.CaptureException(reportErr); eventID != nil {
reportID = string(*eventID)
}
})
if !sentry.Flush(time.Second * 10) {
log.WithField("error", reportErr).Error("Failed to report sentry error")
return errors.New("failed to report sentry error")
}
log.WithField("error", reportErr).WithField("id", reportID).Warn("Sentry error reported")
logrus.WithField("error", reportErr).WithField("id", reportID).Warn("Sentry error reported")
return nil
}

View File

@ -15,29 +15,29 @@
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build !pmapi_prod
package config
// Package signature implements functions to verify files by their detached signatures.
package signature
import (
"net/http"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/pkg/errors"
)
func (c *Config) GetAPIConfig() *pmapi.ClientConfig {
return &pmapi.ClientConfig{
AppVersion: c.getAPIOS() + strings.Title(c.appName) + "_" + c.version,
ClientID: c.appName,
// Verify verifies the given file by its signature using the given armored public key.
func Verify(fileBytes, sigBytes []byte, pubKey string) error {
key, err := crypto.NewKeyFromArmored(pubKey)
if err != nil {
return errors.Wrap(err, "failed to load key")
}
}
func SetClientRoundTripper(_ *pmapi.ClientManager, _ *pmapi.ClientConfig, _ listener.Listener) {
// Use the default roundtripper; do nothing.
}
kr, err := crypto.NewKeyRing(key)
if err != nil {
return errors.Wrap(err, "failed to create keyring")
}
func (c *Config) GetRoundTripper(_ *pmapi.ClientManager, _ listener.Listener) http.RoundTripper {
return http.DefaultTransport
return kr.VerifyDetached(
crypto.NewPlainMessage(fileBytes),
crypto.NewPGPSignature(sigBytes),
crypto.GetUnixTime(),
)
}

76
pkg/tar/tar.go Normal file
View File

@ -0,0 +1,76 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tar
import (
"archive/tar"
"io"
"os"
"path/filepath"
"runtime"
"github.com/sirupsen/logrus"
)
func UntarToDir(r io.Reader, dir string) error {
tr := tar.NewReader(r)
for {
header, err := tr.Next()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
if header == nil {
continue
}
target := filepath.Join(dir, header.Name)
switch {
case header.Typeflag == tar.TypeSymlink:
if err := os.Symlink(header.Linkname, target); err != nil {
return err
}
case header.FileInfo().IsDir():
if err := os.MkdirAll(target, header.FileInfo().Mode()); err != nil {
return err
}
default:
f, err := os.Create(target)
if err != nil {
return err
}
if _, err := io.Copy(f, tr); err != nil { // nolint[gosec]
return err
}
if runtime.GOOS != "windows" {
if err := f.Chmod(header.FileInfo().Mode()); err != nil {
return err
}
}
if err := f.Close(); err != nil {
logrus.WithError(err).Error("Failed to close file")
}
}
}
}