forked from Silverfish/proton-bridge
Launcher, app/base, sentry, update service
This commit is contained in:
@ -1,319 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ProtonMail/go-appdir"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
log = logrus.WithField("pkg", "config") //nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
type appDirProvider interface {
|
||||
UserConfig() string
|
||||
UserCache() string
|
||||
UserLogs() string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
appName string
|
||||
version string
|
||||
revision string
|
||||
cacheVersion string
|
||||
appDirs appDirProvider
|
||||
appDirsVersion appDirProvider
|
||||
}
|
||||
|
||||
// New returns fully initialized config struct.
|
||||
// `appName` should be in camelCase format for folder or file names. It's also used in API
|
||||
// as `AppVersion` which is converted to CamelCase.
|
||||
// `version` is the version of the app (e.g. v1.2.3).
|
||||
// `cacheVersion` is the version of the cache files (setting a different number will remove the old ones).
|
||||
func New(appName, version, revision, cacheVersion string) *Config {
|
||||
appDirs := appdir.New(filepath.Join("protonmail", appName))
|
||||
appDirsVersion := appdir.New(filepath.Join("protonmail", appName, cacheVersion))
|
||||
return newConfig(appName, version, revision, cacheVersion, appDirs, appDirsVersion)
|
||||
}
|
||||
|
||||
func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirsVersion appDirProvider) *Config {
|
||||
return &Config{
|
||||
appName: appName,
|
||||
version: version,
|
||||
revision: revision,
|
||||
cacheVersion: cacheVersion,
|
||||
appDirs: appDirs,
|
||||
appDirsVersion: appDirsVersion,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDirs creates all folders that are necessary for bridge to properly function.
|
||||
func (c *Config) CreateDirs() error {
|
||||
// Log files.
|
||||
if err := os.MkdirAll(c.appDirs.UserLogs(), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
// TLS files.
|
||||
if err := os.MkdirAll(c.appDirs.UserConfig(), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
// Lock, events, preferences, user_info, db files.
|
||||
if err := os.MkdirAll(c.appDirsVersion.UserCache(), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearData removes all files except the lock file.
|
||||
// The lock file will be removed when the Bridge stops.
|
||||
func (c *Config) ClearData() error {
|
||||
dirs := []string{
|
||||
c.appDirs.UserLogs(),
|
||||
c.appDirs.UserConfig(),
|
||||
c.appDirs.UserCache(),
|
||||
}
|
||||
shouldRemove := func(filePath string) bool {
|
||||
return filePath != c.GetLockPath()
|
||||
}
|
||||
return c.removeAllExcept(dirs, shouldRemove)
|
||||
}
|
||||
|
||||
// ClearOldData removes all old files, such as old log files or old versions of cache and so on.
|
||||
func (c *Config) ClearOldData() error {
|
||||
// `appDirs` is parent for `appDirsVersion`.
|
||||
// `dir` then contains all subfolders and only `cacheVersion` should stay.
|
||||
// But on Windows all files (dirs) are in the same one - we cannot remove log, lock or tls files.
|
||||
dir := c.appDirs.UserCache()
|
||||
|
||||
return c.removeExcept(dir, func(filePath string) bool {
|
||||
fileName := filepath.Base(filePath)
|
||||
return (fileName != c.cacheVersion &&
|
||||
!logFileRgx.MatchString(fileName) &&
|
||||
filePath != c.GetLogDir() &&
|
||||
filePath != c.GetTLSCertPath() &&
|
||||
filePath != c.GetTLSKeyPath() &&
|
||||
filePath != c.GetEventsPath() &&
|
||||
filePath != c.GetIMAPCachePath() &&
|
||||
filePath != c.GetLockPath() &&
|
||||
filePath != c.GetPreferencesPath())
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) removeAllExcept(dirs []string, shouldRemove func(string) bool) error {
|
||||
var result *multierror.Error
|
||||
for _, dir := range dirs {
|
||||
if err := c.removeExcept(dir, shouldRemove); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (c *Config) removeExcept(dir string, shouldRemove func(string) bool) error {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, file := range files {
|
||||
filePath := filepath.Join(dir, file.Name())
|
||||
if !shouldRemove(filePath) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !file.IsDir() {
|
||||
if err := os.RemoveAll(filePath); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
subDir := filepath.Join(dir, file.Name())
|
||||
if err := c.removeExcept(subDir, shouldRemove); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
// Remove dir itself only if it's empty.
|
||||
subFiles, err := ioutil.ReadDir(subDir)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else if len(subFiles) == 0 {
|
||||
if err := os.RemoveAll(subDir); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// IsDevMode should be used for development conditions such us whether to send sentry reports.
|
||||
func (c *Config) IsDevMode() bool {
|
||||
return os.Getenv("PROTONMAIL_ENV") == "dev"
|
||||
}
|
||||
|
||||
// GetVersion returns the version.
|
||||
func (c *Config) GetVersion() string {
|
||||
return c.version
|
||||
}
|
||||
|
||||
// GetLogDir returns folder for log files.
|
||||
func (c *Config) GetLogDir() string {
|
||||
return c.appDirs.UserLogs()
|
||||
}
|
||||
|
||||
// GetLogPrefix returns prefix for log files. Bridge uses format vVERSION.
|
||||
func (c *Config) GetLogPrefix() string {
|
||||
return "v" + c.version + "_" + c.revision
|
||||
}
|
||||
|
||||
// GetLicenseFilePath returns path to liense file.
|
||||
func (c *Config) GetLicenseFilePath() string {
|
||||
path := c.getLicenseFilePath()
|
||||
log.WithField("path", path).Info("License file path")
|
||||
return path
|
||||
}
|
||||
|
||||
func (c *Config) getLicenseFilePath() string {
|
||||
// User can install app to different location, or user can run it
|
||||
// directly from the package without installation, or it could be
|
||||
// automatically updated (app started from differenet location).
|
||||
// For all those cases, first let's check LICENSE next to the binary.
|
||||
path := filepath.Join(filepath.Dir(os.Args[0]), "LICENSE")
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
appName := c.appName
|
||||
if c.appName == "importExport" {
|
||||
appName = "import-export"
|
||||
}
|
||||
// Most Linux distributions.
|
||||
path := "/usr/share/doc/protonmail/" + appName + "/LICENSE"
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
// Arch distributions.
|
||||
return "/usr/share/licenses/protonmail-" + appName + "/LICENSE"
|
||||
case "darwin": //nolint[goconst]
|
||||
path := filepath.Join(filepath.Dir(os.Args[0]), "..", "Resources", "LICENSE")
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
|
||||
appName := "ProtonMail Bridge.app"
|
||||
if c.appName == "importExport" {
|
||||
appName = "ProtonMail Import-Export.app"
|
||||
}
|
||||
return "/Applications/" + appName + "/Contents/Resources/LICENSE"
|
||||
case "windows":
|
||||
path := filepath.Join(filepath.Dir(os.Args[0]), "LICENSE.txt")
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
// This should not happen, Windows should be handled by relative
|
||||
// location to the binary above. This is just fallback which may
|
||||
// or may not work, depends where user installed the app and how
|
||||
// user started the app.
|
||||
return filepath.FromSlash("C:/Program Files/Proton Technologies AG/ProtonMail Bridge/LICENSE.txt")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP and API).
|
||||
func (c *Config) GetTLSCertPath() string {
|
||||
return filepath.Join(c.appDirs.UserConfig(), "cert.pem")
|
||||
}
|
||||
|
||||
// GetTLSKeyPath returns path to private key; used for TLS servers (IMAP, SMTP and API).
|
||||
func (c *Config) GetTLSKeyPath() string {
|
||||
return filepath.Join(c.appDirs.UserConfig(), "key.pem")
|
||||
}
|
||||
|
||||
// GetDBDir returns folder for db files.
|
||||
func (c *Config) GetDBDir() string {
|
||||
return c.appDirsVersion.UserCache()
|
||||
}
|
||||
|
||||
// GetEventsPath returns path to events file containing the last processed event IDs.
|
||||
func (c *Config) GetEventsPath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "events.json")
|
||||
}
|
||||
|
||||
// GetIMAPCachePath returns path to file with IMAP status.
|
||||
func (c *Config) GetIMAPCachePath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "user_info.json")
|
||||
}
|
||||
|
||||
// GetLockPath returns path to lock file to check if bridge is already running.
|
||||
func (c *Config) GetLockPath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), c.appName+".lock")
|
||||
}
|
||||
|
||||
// GetUpdateDir returns folder for update files; such as new binary.
|
||||
func (c *Config) GetUpdateDir() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "updates")
|
||||
}
|
||||
|
||||
// GetPreferencesPath returns path to preference file.
|
||||
func (c *Config) GetPreferencesPath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "prefs.json")
|
||||
}
|
||||
|
||||
// GetTransferDir returns folder for import-export rules files.
|
||||
func (c *Config) GetTransferDir() string {
|
||||
return c.appDirsVersion.UserCache()
|
||||
}
|
||||
|
||||
// GetDefaultAPIPort returns default Bridge local API port.
|
||||
func (c *Config) GetDefaultAPIPort() int {
|
||||
return 1042
|
||||
}
|
||||
|
||||
// GetDefaultIMAPPort returns default Bridge IMAP port.
|
||||
func (c *Config) GetDefaultIMAPPort() int {
|
||||
return 1143
|
||||
}
|
||||
|
||||
// GetDefaultSMTPPort returns default Bridge SMTP port.
|
||||
func (c *Config) GetDefaultSMTPPort() int {
|
||||
return 1025
|
||||
}
|
||||
|
||||
// getAPIOS returns actual operating system.
|
||||
func (c *Config) getAPIOS() string {
|
||||
switch os := runtime.GOOS; os {
|
||||
case "darwin": // nolint: goconst
|
||||
return "macOS"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
}
|
||||
|
||||
return "Linux"
|
||||
}
|
||||
@ -1,238 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testAppName = "bridge-test"
|
||||
|
||||
var testConfigDir string //nolint[gochecknoglobals]
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
setupTestConfig()
|
||||
setupTestLogs()
|
||||
code := m.Run()
|
||||
shutdownTestConfig()
|
||||
shutdownTestLogs()
|
||||
shutdownTestPreferences()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func setupTestConfig() {
|
||||
var err error
|
||||
testConfigDir, err = ioutil.TempDir("", "config")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownTestConfig() {
|
||||
_ = os.RemoveAll(testConfigDir)
|
||||
}
|
||||
|
||||
type mocks struct {
|
||||
t *testing.T
|
||||
|
||||
ctrl *gomock.Controller
|
||||
appDir *MockappDirer
|
||||
appDirVersion *MockappDirer
|
||||
}
|
||||
|
||||
func initMocks(t *testing.T) mocks {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
return mocks{
|
||||
t: t,
|
||||
|
||||
ctrl: mockCtrl,
|
||||
appDir: NewMockappDirer(mockCtrl),
|
||||
appDirVersion: NewMockappDirer(mockCtrl),
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearDataLinux(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureLinux(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"config",
|
||||
"logs",
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearDataWindows(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureWindows(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"config",
|
||||
})
|
||||
}
|
||||
|
||||
// OldData touches only cache folder.
|
||||
// Removes only c1 folder as nothing else is part of cache folder on Linux/Mac.
|
||||
func TestClearOldDataLinux(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureLinux(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearOldData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
"logs",
|
||||
"logs/other.log",
|
||||
"logs/v1_10.log",
|
||||
"logs/v1_11.log",
|
||||
"logs/v2_12.log",
|
||||
"logs/v2_13.log",
|
||||
})
|
||||
}
|
||||
|
||||
// OldData touches only cache folder. Removes everything except c2 folder
|
||||
// and bridge log files which are part of cache folder on Windows.
|
||||
func TestClearOldDataWindows(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureWindows(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearOldData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"cache/v1_10.log",
|
||||
"cache/v1_11.log",
|
||||
"cache/v2_12.log",
|
||||
"cache/v2_13.log",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
})
|
||||
}
|
||||
|
||||
func createTestStructureLinux(m mocks, baseDir string) {
|
||||
logsDir := filepath.Join(baseDir, "logs")
|
||||
configDir := filepath.Join(baseDir, "config")
|
||||
cacheDir := filepath.Join(baseDir, "cache")
|
||||
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
|
||||
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
|
||||
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
|
||||
}
|
||||
|
||||
func createTestStructureWindows(m mocks, baseDir string) {
|
||||
logsDir := filepath.Join(baseDir, "cache")
|
||||
configDir := filepath.Join(baseDir, "config")
|
||||
cacheDir := filepath.Join(baseDir, "cache")
|
||||
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
|
||||
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
|
||||
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
|
||||
}
|
||||
|
||||
func createTestStructure(m mocks, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir string) {
|
||||
m.appDir.EXPECT().UserLogs().Return(logsDir).AnyTimes()
|
||||
m.appDir.EXPECT().UserConfig().Return(configDir).AnyTimes()
|
||||
m.appDir.EXPECT().UserCache().Return(cacheDir).AnyTimes()
|
||||
m.appDirVersion.EXPECT().UserCache().Return(versionedCacheDir).AnyTimes()
|
||||
|
||||
require.NoError(m.t, os.RemoveAll(baseDir))
|
||||
require.NoError(m.t, os.MkdirAll(baseDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(logsDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(configDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(cacheDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(versionedOldCacheDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(versionedCacheDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(filepath.Join(versionedCacheDir, "updates"), 0700))
|
||||
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "other.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_10.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_11.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_12.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_13.log"), []byte("Hello"), 0755))
|
||||
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "cert.pem"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "key.pem"), []byte("Hello"), 0755))
|
||||
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "prefs.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "events.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "user_info.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "prefs.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "events.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "user_info.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, testAppName+".lock"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
|
||||
}
|
||||
|
||||
func checkFileNames(t *testing.T, dir string, expectedFileNames []string) {
|
||||
fileNames := getFileNames(t, dir)
|
||||
require.Equal(t, expectedFileNames, fileNames)
|
||||
}
|
||||
|
||||
func getFileNames(t *testing.T, dir string) []string {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileNames := []string{}
|
||||
for _, file := range files {
|
||||
fileNames = append(fileNames, file.Name())
|
||||
if file.IsDir() {
|
||||
subDir := filepath.Join(dir, file.Name())
|
||||
subFileNames := getFileNames(t, subDir)
|
||||
for _, subFileName := range subFileNames {
|
||||
fileNames = append(fileNames, file.Name()+"/"+subFileName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return fileNames
|
||||
}
|
||||
@ -1,251 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/sentry"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type logConfiger interface {
|
||||
GetLogDir() string
|
||||
GetLogPrefix() string
|
||||
}
|
||||
|
||||
const (
|
||||
// Zendesk now has a file size limit of 20MB. When the last N log files
|
||||
// are zipped, it should fit under 20MB. Value in MB (average file has
|
||||
// few hundreds kB).
|
||||
maxLogFileSize = 10 * 1024 * 1024 //nolint[gochecknoglobals]
|
||||
// Including the current logfile.
|
||||
maxNumberLogFiles = 3 //nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
// logFile is pointer to currently open file used by logrus.
|
||||
var logFile *os.File //nolint[gochecknoglobals]
|
||||
|
||||
var logFileRgx = regexp.MustCompile("^v.*\\.log$") //nolint[gochecknoglobals]
|
||||
var logCrashRgx = regexp.MustCompile("^v.*_crash_.*\\.log$") //nolint[gochecknoglobals]
|
||||
|
||||
// HandlePanic reports the crash to sentry or local file when sentry fails.
|
||||
func HandlePanic(cfg *Config, output string) {
|
||||
sentry.SkipDuringUnwind()
|
||||
|
||||
if !cfg.IsDevMode() {
|
||||
apiCfg := cfg.GetAPIConfig()
|
||||
if err := sentry.ReportSentryCrash(apiCfg.ClientID, apiCfg.AppVersion, apiCfg.UserAgent, errors.New(output)); err != nil {
|
||||
log.Error("Sentry crash report failed: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
filename := getLogFilename(cfg.GetLogPrefix() + "_crash_")
|
||||
filepath := filepath.Join(cfg.GetLogDir(), filename)
|
||||
f, err := os.OpenFile(filepath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
log.Error("Cannot open file to write crash report: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = f.WriteString(output)
|
||||
_ = pprof.Lookup("goroutine").WriteTo(f, 2)
|
||||
|
||||
log.Warn("Crash report saved to ", filepath)
|
||||
}
|
||||
|
||||
// GetGID returns goroutine number which can be used to distiguish logs from
|
||||
// the concurent processes. Keep in mind that it returns the number of routine
|
||||
// which executes the function.
|
||||
func GetGID() uint64 {
|
||||
b := make([]byte, 64)
|
||||
b = b[:runtime.Stack(b, false)]
|
||||
b = bytes.TrimPrefix(b, []byte("goroutine "))
|
||||
b = b[:bytes.IndexByte(b, ' ')]
|
||||
n, _ := strconv.ParseUint(string(b), 10, 64)
|
||||
return n
|
||||
}
|
||||
|
||||
// SetupLog set up log level, formatter and output (file or stdout).
|
||||
// Returns whether should be used debug for IMAP and SMTP servers.
|
||||
func SetupLog(cfg logConfiger, levelFlag string) (debugClient, debugServer bool) {
|
||||
level, useFile := getLogLevelAndFile(levelFlag)
|
||||
|
||||
logrus.SetLevel(level)
|
||||
|
||||
if useFile {
|
||||
logrus.SetFormatter(&logrus.JSONFormatter{})
|
||||
setLogFile(cfg.GetLogDir(), cfg.GetLogPrefix())
|
||||
watchLogFileSize(cfg.GetLogDir(), cfg.GetLogPrefix())
|
||||
} else {
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
ForceColors: true,
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: time.StampMilli,
|
||||
})
|
||||
logrus.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
switch levelFlag {
|
||||
case "debug-client", "debug-client-json":
|
||||
debugClient = true
|
||||
case "debug-server", "debug-server-json", "trace":
|
||||
fmt.Println("THE LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
|
||||
log.Warning("================================================")
|
||||
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
|
||||
log.Warning("================================================")
|
||||
debugClient = true
|
||||
debugServer = true
|
||||
}
|
||||
|
||||
return debugClient, debugServer
|
||||
}
|
||||
|
||||
func setLogFile(logDir, logPrefix string) {
|
||||
if logFile != nil {
|
||||
return
|
||||
}
|
||||
|
||||
filename := getLogFilename(logPrefix)
|
||||
var err error
|
||||
logFile, err = os.OpenFile(filepath.Join(logDir, filename), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
logrus.SetOutput(logFile)
|
||||
|
||||
// Users sometimes change the name of the log file. We want to always log
|
||||
// information about bridge version (included in log prefix) and OS.
|
||||
log.Warn("Bridge version: ", logPrefix, " ", runtime.GOOS)
|
||||
}
|
||||
|
||||
func getLogFilename(logPrefix string) string {
|
||||
currentTime := strconv.Itoa(int(time.Now().Unix()))
|
||||
return logPrefix + "_" + currentTime + ".log"
|
||||
}
|
||||
|
||||
func watchLogFileSize(logDir, logPrefix string) {
|
||||
go func() {
|
||||
for {
|
||||
// Some rare bug can cause log file spamming a lot. Checking file
|
||||
// size too often is not good, and at the same time postpone next
|
||||
// check for too long is the same thing. 30 seconds seems as good
|
||||
// compromise; average computer can generates ~500MB in 30 seconds.
|
||||
time.Sleep(30 * time.Second)
|
||||
checkLogFileSize(logDir, logPrefix)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func checkLogFileSize(logDir, logPrefix string) {
|
||||
if logFile == nil {
|
||||
return
|
||||
}
|
||||
|
||||
stat, err := logFile.Stat()
|
||||
if err != nil {
|
||||
log.Error("Log file size check failed: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.Size() >= maxLogFileSize {
|
||||
log.Warn("Current log file ", logFile.Name(), " is too big, opening new file")
|
||||
closeLogFile()
|
||||
setLogFile(logDir, logPrefix)
|
||||
}
|
||||
|
||||
if err := clearLogs(logDir); err != nil {
|
||||
log.Error("Cannot clear logs ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func closeLogFile() {
|
||||
if logFile != nil {
|
||||
_ = logFile.Close()
|
||||
logFile = nil
|
||||
}
|
||||
}
|
||||
|
||||
func clearLogs(logDir string) error {
|
||||
files, err := ioutil.ReadDir(logDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var logsWithPrefix []string
|
||||
var crashesWithPrefix []string
|
||||
|
||||
for _, file := range files {
|
||||
if logFileRgx.MatchString(file.Name()) {
|
||||
if logCrashRgx.MatchString(file.Name()) {
|
||||
crashesWithPrefix = append(crashesWithPrefix, file.Name())
|
||||
} else {
|
||||
logsWithPrefix = append(logsWithPrefix, file.Name())
|
||||
}
|
||||
} else {
|
||||
// Older versions of Bridge stored logs in subfolders for each version.
|
||||
// That also has to be cleared and the functionality can be removed after some time.
|
||||
if file.IsDir() {
|
||||
if err := clearLogs(filepath.Join(logDir, file.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
removeLog(logDir, file.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
removeOldLogs(logDir, logsWithPrefix)
|
||||
removeOldLogs(logDir, crashesWithPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeOldLogs(logDir string, filenames []string) {
|
||||
count := len(filenames)
|
||||
if count <= maxNumberLogFiles {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(filenames) // Sorted by timestamp: oldest first.
|
||||
for _, filename := range filenames[:count-maxNumberLogFiles] {
|
||||
removeLog(logDir, filename)
|
||||
}
|
||||
}
|
||||
|
||||
func removeLog(logDir, filename string) {
|
||||
// We need to be sure to delete only log files.
|
||||
// Directory with logs can also contain other files.
|
||||
if !logFileRgx.MatchString(filename) {
|
||||
return
|
||||
}
|
||||
if err := os.RemoveAll(filepath.Join(logDir, filename)); err != nil {
|
||||
log.Error("Cannot remove old logs ", err)
|
||||
}
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build !build_qa
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
|
||||
useFile = true
|
||||
switch levelFlag {
|
||||
case "panic":
|
||||
level = logrus.PanicLevel
|
||||
case "fatal":
|
||||
level = logrus.FatalLevel
|
||||
case "error":
|
||||
level = logrus.ErrorLevel
|
||||
case "warn":
|
||||
level = logrus.WarnLevel
|
||||
case "info":
|
||||
level = logrus.InfoLevel
|
||||
case "debug", "debug-client", "debug-server", "debug-client-json", "debug-server-json":
|
||||
level = logrus.DebugLevel
|
||||
useFile = false
|
||||
case "trace":
|
||||
level = logrus.TraceLevel
|
||||
useFile = false
|
||||
default:
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -1,50 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build build_qa
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// getLogLevelAndFile for QA build is altered in a way even decrypted data are stored
|
||||
// in the log file when forced with `debug-client-json` or `debug-server-json`.
|
||||
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
|
||||
useFile = true
|
||||
switch levelFlag {
|
||||
case "panic":
|
||||
level = logrus.PanicLevel
|
||||
case "fatal":
|
||||
level = logrus.FatalLevel
|
||||
case "error":
|
||||
level = logrus.ErrorLevel
|
||||
case "warn":
|
||||
level = logrus.WarnLevel
|
||||
case "info":
|
||||
level = logrus.InfoLevel
|
||||
case "debug-client-json", "debug-server-json":
|
||||
level = logrus.DebugLevel
|
||||
case "debug", "debug-client", "debug-server":
|
||||
level = logrus.DebugLevel
|
||||
useFile = false
|
||||
default:
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -1,225 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testLogConfig struct{ logDir, logPrefix string }
|
||||
|
||||
func (c *testLogConfig) GetLogDir() string { return c.logDir }
|
||||
func (c *testLogConfig) GetLogPrefix() string { return c.logPrefix }
|
||||
|
||||
var testLogDir string //nolint[gochecknoglobals]
|
||||
|
||||
func setupTestLogs() {
|
||||
var err error
|
||||
testLogDir, err = ioutil.TempDir("", "log")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownTestLogs() {
|
||||
_ = os.RemoveAll(testLogDir)
|
||||
}
|
||||
|
||||
func TestLogNameLength(t *testing.T) {
|
||||
cfg := New("bridge-test", "longVersion123", "longRevision1234567890", "c2")
|
||||
name := getLogFilename(cfg.GetLogPrefix())
|
||||
if len(name) > 128 {
|
||||
t.Fatal("Name of the log is too long - limit for encrypted linux is 128 characters")
|
||||
}
|
||||
}
|
||||
|
||||
// Info and higher levels writes to the file.
|
||||
func TestSetupLogInfo(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "setupInfo")
|
||||
|
||||
SetupLog(&testLogConfig{dir, "v"}, "info")
|
||||
require.Equal(t, "info", logrus.GetLevel().String())
|
||||
|
||||
logrus.Info("test message")
|
||||
files := checkLogFiles(t, dir, 1)
|
||||
checkLogContains(t, dir, files[0].Name(), "test message")
|
||||
}
|
||||
|
||||
// Debug levels writes to stdout.
|
||||
func TestSetupLogDebug(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "setupDebug")
|
||||
|
||||
SetupLog(&testLogConfig{dir, "v"}, "debug")
|
||||
require.Equal(t, "debug", logrus.GetLevel().String())
|
||||
|
||||
logrus.Info("test message")
|
||||
checkLogFiles(t, dir, 0)
|
||||
}
|
||||
|
||||
func TestReopenLogFile(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "reopenLogFile")
|
||||
|
||||
setLogFile(dir, "v1")
|
||||
|
||||
done := make(chan interface{})
|
||||
|
||||
log.Info("first message")
|
||||
|
||||
go func() {
|
||||
<-done // Wait for closing file and opening new one.
|
||||
log.Info("second message")
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
closeLogFile()
|
||||
setLogFile(dir, "v2")
|
||||
|
||||
done <- nil
|
||||
<-done // Wait for second log message.
|
||||
|
||||
files := checkLogFiles(t, dir, 2)
|
||||
checkLogContains(t, dir, files[0].Name(), "first message")
|
||||
checkLogContains(t, dir, files[1].Name(), "second message")
|
||||
}
|
||||
|
||||
func TestCheckLogFileSizeSmall(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "logFileSizeSmall")
|
||||
|
||||
setLogFile(dir, "v1")
|
||||
originalFileName := logFile.Name()
|
||||
|
||||
_, _ = logFile.WriteString("small file")
|
||||
checkLogFileSize(dir, "v2")
|
||||
|
||||
require.Equal(t, originalFileName, logFile.Name())
|
||||
}
|
||||
|
||||
func TestCheckLogFileSizeBig(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
dir := beforeEachCreateTestDir(t, "logFileSizeBig")
|
||||
|
||||
setLogFile(dir, "v1")
|
||||
originalFileName := logFile.Name()
|
||||
|
||||
// The limit for big file is 10*1024*1024 - keep the string 10 letters long.
|
||||
for i := 0; i < 1024*1024; i++ {
|
||||
_, _ = logFile.WriteString("big file!\n")
|
||||
}
|
||||
checkLogFileSize(dir, "v2")
|
||||
|
||||
require.NotEqual(t, originalFileName, logFile.Name())
|
||||
}
|
||||
|
||||
// ClearLogs removes only bridge old log files keeping last three of them.
|
||||
func TestClearLogsLinux(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
dir := beforeEachCreateTestDir(t, "clearLogs")
|
||||
|
||||
createTestStructureLinux(m, dir)
|
||||
require.NoError(t, clearLogs(dir))
|
||||
checkFileNames(t, dir, []string{
|
||||
"cache",
|
||||
"cache/c1",
|
||||
"cache/c1/events.json",
|
||||
"cache/c1/mailbox-user@pm.me.db",
|
||||
"cache/c1/prefs.json",
|
||||
"cache/c1/user_info.json",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
"logs",
|
||||
"logs/other.log",
|
||||
"logs/v1_11.log",
|
||||
"logs/v2_12.log",
|
||||
"logs/v2_13.log",
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLogs removes only bridge old log files even when log folder
|
||||
// is shared with other files on Windows.
|
||||
func TestClearLogsWindows(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
dir := beforeEachCreateTestDir(t, "clearLogs")
|
||||
|
||||
createTestStructureWindows(m, dir)
|
||||
require.NoError(t, clearLogs(dir))
|
||||
checkFileNames(t, dir, []string{
|
||||
"cache",
|
||||
"cache/c1",
|
||||
"cache/c1/events.json",
|
||||
"cache/c1/mailbox-user@pm.me.db",
|
||||
"cache/c1/prefs.json",
|
||||
"cache/c1/user_info.json",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"cache/other.log",
|
||||
"cache/v1_11.log",
|
||||
"cache/v2_12.log",
|
||||
"cache/v2_13.log",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
})
|
||||
}
|
||||
|
||||
func beforeEachCreateTestDir(t *testing.T, dir string) string {
|
||||
// Make sure opened file (from the previous test) is cleared.
|
||||
closeLogFile()
|
||||
|
||||
dir = filepath.Join(testLogDir, dir)
|
||||
require.NoError(t, os.MkdirAll(dir, 0700))
|
||||
return dir
|
||||
}
|
||||
|
||||
func checkLogFiles(t *testing.T, dir string, expectedCount int) []os.FileInfo {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedCount, len(files))
|
||||
return files
|
||||
}
|
||||
|
||||
func checkLogContains(t *testing.T, dir, fileName, expectedSubstr string) {
|
||||
data, err := ioutil.ReadFile(filepath.Join(dir, fileName)) //nolint[gosec]
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(data), expectedSubstr)
|
||||
}
|
||||
@ -1,76 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: config/config.go
|
||||
|
||||
// Package config is a generated GoMock package.
|
||||
package config
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockappDirer is a mock of appDirer interface
|
||||
type MockappDirer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockappDirerMockRecorder
|
||||
}
|
||||
|
||||
// MockappDirerMockRecorder is the mock recorder for MockappDirer
|
||||
type MockappDirerMockRecorder struct {
|
||||
mock *MockappDirer
|
||||
}
|
||||
|
||||
// NewMockappDirer creates a new mock instance
|
||||
func NewMockappDirer(ctrl *gomock.Controller) *MockappDirer {
|
||||
mock := &MockappDirer{ctrl: ctrl}
|
||||
mock.recorder = &MockappDirerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockappDirer) EXPECT() *MockappDirerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// UserConfig mocks base method
|
||||
func (m *MockappDirer) UserConfig() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserConfig")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserConfig indicates an expected call of UserConfig
|
||||
func (mr *MockappDirerMockRecorder) UserConfig() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserConfig", reflect.TypeOf((*MockappDirer)(nil).UserConfig))
|
||||
}
|
||||
|
||||
// UserCache mocks base method
|
||||
func (m *MockappDirer) UserCache() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserCache")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserCache indicates an expected call of UserCache
|
||||
func (mr *MockappDirerMockRecorder) UserCache() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCache", reflect.TypeOf((*MockappDirer)(nil).UserCache))
|
||||
}
|
||||
|
||||
// UserLogs mocks base method
|
||||
func (m *MockappDirer) UserLogs() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserLogs")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserLogs indicates an expected call of UserLogs
|
||||
func (mr *MockappDirerMockRecorder) UserLogs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserLogs", reflect.TypeOf((*MockappDirer)(nil).UserLogs))
|
||||
}
|
||||
@ -1,127 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Preferences struct {
|
||||
cache map[string]string
|
||||
path string
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPreferences returns loaded preferences.
|
||||
func NewPreferences(preferencesPath string) *Preferences {
|
||||
p := &Preferences{
|
||||
path: preferencesPath,
|
||||
lock: &sync.RWMutex{},
|
||||
}
|
||||
if err := p.load(); err != nil {
|
||||
log.Warn("Cannot load preferences: ", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Preferences) load() error {
|
||||
if p.cache != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
p.cache = map[string]string{}
|
||||
|
||||
f, err := os.Open(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close() //nolint[errcheck]
|
||||
|
||||
return json.NewDecoder(f).Decode(&p.cache)
|
||||
}
|
||||
|
||||
func (p *Preferences) save() error {
|
||||
if p.cache == nil {
|
||||
return errors.New("cannot save preferences: cache is nil")
|
||||
}
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
f, err := os.Create(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close() //nolint[errcheck]
|
||||
|
||||
return json.NewEncoder(f).Encode(p.cache)
|
||||
}
|
||||
|
||||
func (p *Preferences) SetDefault(key, value string) {
|
||||
if p.Get(key) == "" {
|
||||
p.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Preferences) Get(key string) string {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
return p.cache[key]
|
||||
}
|
||||
|
||||
func (p *Preferences) GetBool(key string) bool {
|
||||
return p.Get(key) == "true"
|
||||
}
|
||||
|
||||
func (p *Preferences) GetInt(key string) int {
|
||||
value, err := strconv.Atoi(p.Get(key))
|
||||
if err != nil {
|
||||
log.Error("Cannot parse int: ", err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (p *Preferences) Set(key, value string) {
|
||||
p.lock.Lock()
|
||||
p.cache[key] = value
|
||||
p.lock.Unlock()
|
||||
|
||||
if err := p.save(); err != nil {
|
||||
log.Warn("Cannot save preferences: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Preferences) SetBool(key string, value bool) {
|
||||
if value {
|
||||
p.Set(key, "true")
|
||||
} else {
|
||||
p.Set(key, "false")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Preferences) SetInt(key string, value int) {
|
||||
p.Set(key, strconv.Itoa(value))
|
||||
}
|
||||
@ -1,109 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testPrefFilePath = "/tmp/pref.json"
|
||||
|
||||
func shutdownTestPreferences() {
|
||||
_ = os.RemoveAll(testPrefFilePath)
|
||||
}
|
||||
|
||||
func TestLoadNoPreferences(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
require.Equal(t, "", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestLoadBadPreferences(t *testing.T) {
|
||||
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"key\":\"value"), 0700))
|
||||
pref := NewPreferences(testPrefFilePath)
|
||||
require.Equal(t, "", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestPreferencesGet(t *testing.T) {
|
||||
pref := newTestPreferences(t)
|
||||
require.Equal(t, "value", pref.Get("str"))
|
||||
require.Equal(t, "42", pref.Get("int"))
|
||||
require.Equal(t, "true", pref.Get("bool"))
|
||||
require.Equal(t, "t", pref.Get("falseBool"))
|
||||
}
|
||||
|
||||
func TestPreferencesGetInt(t *testing.T) {
|
||||
pref := newTestPreferences(t)
|
||||
require.Equal(t, 0, pref.GetInt("str"))
|
||||
require.Equal(t, 42, pref.GetInt("int"))
|
||||
require.Equal(t, 0, pref.GetInt("bool"))
|
||||
require.Equal(t, 0, pref.GetInt("falseBool"))
|
||||
}
|
||||
|
||||
func TestPreferencesGetBool(t *testing.T) {
|
||||
pref := newTestPreferences(t)
|
||||
require.Equal(t, false, pref.GetBool("str"))
|
||||
require.Equal(t, false, pref.GetBool("int"))
|
||||
require.Equal(t, true, pref.GetBool("bool"))
|
||||
require.Equal(t, false, pref.GetBool("falseBool"))
|
||||
}
|
||||
|
||||
func TestPreferencesSetDefault(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.SetDefault("key", "value")
|
||||
pref.SetDefault("key", "othervalue")
|
||||
require.Equal(t, "value", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestPreferencesSet(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.Set("str", "value")
|
||||
checkSavedPreferences(t, "{\"str\":\"value\"}")
|
||||
}
|
||||
|
||||
func TestPreferencesSetInt(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.SetInt("int", 42)
|
||||
checkSavedPreferences(t, "{\"int\":\"42\"}")
|
||||
}
|
||||
|
||||
func TestPreferencesSetBool(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.SetBool("trueBool", true)
|
||||
pref.SetBool("falseBool", false)
|
||||
checkSavedPreferences(t, "{\"falseBool\":\"false\",\"trueBool\":\"true\"}")
|
||||
}
|
||||
|
||||
func newTestEmptyPreferences(t *testing.T) *Preferences {
|
||||
require.NoError(t, os.RemoveAll(testPrefFilePath))
|
||||
return NewPreferences(testPrefFilePath)
|
||||
}
|
||||
|
||||
func newTestPreferences(t *testing.T) *Preferences {
|
||||
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"str\":\"value\",\"int\":\"42\",\"bool\":\"true\",\"falseBool\":\"t\"}"), 0700))
|
||||
return NewPreferences(testPrefFilePath)
|
||||
}
|
||||
|
||||
func checkSavedPreferences(t *testing.T, expected string) {
|
||||
data, err := ioutil.ReadFile(testPrefFilePath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected+"\n", string(data))
|
||||
}
|
||||
@ -1,173 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
type tlsConfiger interface {
|
||||
GetTLSCertPath() string
|
||||
GetTLSKeyPath() string
|
||||
}
|
||||
|
||||
var tlsTemplate = x509.Certificate{ //nolint[gochecknoglobals]
|
||||
SerialNumber: big.NewInt(-1),
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"CH"},
|
||||
Organization: []string{"Proton Technologies AG"},
|
||||
OrganizationalUnit: []string{"ProtonMail"},
|
||||
CommonName: "127.0.0.1",
|
||||
},
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour),
|
||||
}
|
||||
|
||||
var ErrTLSCertExpireSoon = fmt.Errorf("TLS certificate will expire soon")
|
||||
|
||||
// GetTLSConfig tries to load TLS config or generate new one which is then returned.
|
||||
func GetTLSConfig(cfg tlsConfiger) (tlsConfig *tls.Config, err error) {
|
||||
certPath := cfg.GetTLSCertPath()
|
||||
keyPath := cfg.GetTLSKeyPath()
|
||||
tlsConfig, err = loadTLSConfig(certPath, keyPath)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Cannot load cert, generating a new one")
|
||||
tlsConfig, err = GenerateTLSConfig(certPath, keyPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
if err := exec.Command( // nolint[gosec]
|
||||
"/usr/bin/security",
|
||||
"execute-with-privileges",
|
||||
"/usr/bin/security",
|
||||
"add-trusted-cert",
|
||||
"-d",
|
||||
"-r", "trustRoot",
|
||||
"-p", "ssl",
|
||||
"-k", "/Library/Keychains/System.keychain",
|
||||
certPath,
|
||||
).Run(); err != nil {
|
||||
log.WithError(err).Error("Failed to add cert to system keychain")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig.ServerName = "127.0.0.1"
|
||||
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(tlsConfig.Certificates[0].Leaf)
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
|
||||
/* This is deprecated:
|
||||
* SA1019: tlsConfig.BuildNameToCertificate is deprecated:
|
||||
* NameToCertificate only allows associating a single certificate with a given name.
|
||||
* Leave that field nil to let the library select the first compatible chain from Certificates.
|
||||
*/
|
||||
tlsConfig.BuildNameToCertificate() // nolint[staticcheck]
|
||||
|
||||
return tlsConfig, err
|
||||
}
|
||||
|
||||
func loadTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
|
||||
c, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Leaf, err = x509.ParseCertificate(c.Certificate[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{c},
|
||||
}
|
||||
|
||||
if time.Now().Add(31 * 24 * time.Hour).After(c.Leaf.NotAfter) {
|
||||
err = ErrTLSCertExpireSoon
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GenerateTLSConfig generates certs and keys at the given filepaths and returns a TLS Config which holds them.
|
||||
// See https://golang.org/src/crypto/tls/generate_cert.go
|
||||
func GenerateTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to generate private key: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to generate serial number: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
tlsTemplate.SerialNumber = serialNumber
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &tlsTemplate, &tlsTemplate, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to create certificate: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer certOut.Close() //nolint[errcheck]
|
||||
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer keyOut.Close() //nolint[errcheck]
|
||||
err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return loadTLSConfig(certPath, keyPath)
|
||||
}
|
||||
@ -1,63 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testTLSConfig struct{ certPath, keyPath string }
|
||||
|
||||
func (c *testTLSConfig) GetTLSCertPath() string { return c.certPath }
|
||||
func (c *testTLSConfig) GetTLSKeyPath() string { return c.keyPath }
|
||||
|
||||
func TestTLSKeyRenewal(t *testing.T) {
|
||||
// Remove keys.
|
||||
configPath := "/tmp"
|
||||
certPath := filepath.Join(configPath, "cert.pem")
|
||||
keyPath := filepath.Join(configPath, "key.pem")
|
||||
_ = os.Remove(certPath)
|
||||
_ = os.Remove(keyPath)
|
||||
|
||||
// Put old key there.
|
||||
tlsTemplate.NotBefore = time.Now().Add(-365 * 24 * time.Hour)
|
||||
tlsTemplate.NotAfter = time.Now()
|
||||
cert, err := GenerateTLSConfig(certPath, keyPath)
|
||||
require.Equal(t, err, ErrTLSCertExpireSoon)
|
||||
require.Equal(t, len(cert.Certificates), 1)
|
||||
time.Sleep(time.Second)
|
||||
now, notValidAfter := time.Now(), cert.Certificates[0].Leaf.NotAfter
|
||||
require.True(t, now.After(notValidAfter), "old certificate expected to not be valid at %v but have valid until %v", now, notValidAfter)
|
||||
|
||||
// Renew key.
|
||||
tlsTemplate.NotBefore = time.Now()
|
||||
tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour)
|
||||
cert, err = GetTLSConfig(&testTLSConfig{certPath, keyPath})
|
||||
if runtime.GOOS != "darwin" { // Darwin is not supported.
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, len(cert.Certificates), 1)
|
||||
now, notValidAfter = time.Now(), cert.Certificates[0].Leaf.NotAfter
|
||||
require.False(t, now.After(notValidAfter), "new certificate expected to be valid at %v but have valid until %v", now, notValidAfter)
|
||||
}
|
||||
@ -1,40 +0,0 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package constants contains variables that are set via ldflags during build.
|
||||
package constants
|
||||
|
||||
// nolint[gochecknoglobals]
|
||||
var (
|
||||
// Version of the build.
|
||||
Version = ""
|
||||
|
||||
// Revision is current hash of the build.
|
||||
Revision = ""
|
||||
|
||||
// BuildTime stamp of the build.
|
||||
BuildTime = ""
|
||||
|
||||
// DSNSentry client keys to be able to report crashes to Sentry.
|
||||
DSNSentry = ""
|
||||
|
||||
// LongVersion is derived from Version and Revision.
|
||||
LongVersion = Version + " (" + Revision + ")"
|
||||
|
||||
// BuildVersion is derived from LongVersion and BuildTime.
|
||||
BuildVersion = LongVersion + " " + BuildTime
|
||||
)
|
||||
82
pkg/files/removal.go
Normal file
82
pkg/files/removal.go
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package files provides standard filesystem operations.
|
||||
package files
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
type OpRemove struct {
|
||||
targets []string
|
||||
exceptions []string
|
||||
}
|
||||
|
||||
func Remove(targets ...string) *OpRemove {
|
||||
return &OpRemove{targets: targets}
|
||||
}
|
||||
|
||||
func (op *OpRemove) Except(exceptions ...string) *OpRemove {
|
||||
op.exceptions = exceptions
|
||||
return op
|
||||
}
|
||||
|
||||
func (op *OpRemove) Do() error {
|
||||
var multiErr error
|
||||
|
||||
for _, target := range op.targets {
|
||||
if err := remove(target, op.exceptions...); err != nil {
|
||||
multiErr = multierror.Append(multiErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
return multiErr
|
||||
}
|
||||
|
||||
func remove(dir string, except ...string) error {
|
||||
var toRemove []string
|
||||
|
||||
if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
for _, exception := range except {
|
||||
if path == exception || strings.HasPrefix(exception, path) || strings.HasPrefix(path, exception) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
toRemove = append(toRemove, path)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(toRemove)))
|
||||
|
||||
for _, target := range toRemove {
|
||||
if err := os.RemoveAll(target); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
116
pkg/files/removal_test.go
Normal file
116
pkg/files/removal_test.go
Normal file
@ -0,0 +1,116 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package files
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
dir := newTestDir(t,
|
||||
"subdir1",
|
||||
"subdir2/subdir3",
|
||||
)
|
||||
defer delTestDir(t, dir)
|
||||
|
||||
createTestFiles(t, dir,
|
||||
"subdir1/file1",
|
||||
"subdir1/file2",
|
||||
"subdir2/file3",
|
||||
"subdir2/file4",
|
||||
"subdir2/subdir3/file5",
|
||||
"subdir2/subdir3/file6",
|
||||
)
|
||||
|
||||
require.NoError(t, Remove(
|
||||
filepath.Join(dir, "subdir1"),
|
||||
filepath.Join(dir, "subdir2", "file3"),
|
||||
filepath.Join(dir, "subdir2", "subdir3", "file5"),
|
||||
).Do())
|
||||
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file1"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file2"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "file3"))
|
||||
assert.FileExists(t, filepath.Join(dir, "subdir2", "file4"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file5"))
|
||||
assert.FileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file6"))
|
||||
}
|
||||
|
||||
func TestRemoveWithExceptions(t *testing.T) {
|
||||
dir := newTestDir(t,
|
||||
"subdir1",
|
||||
"subdir2/subdir3",
|
||||
"subdir4",
|
||||
)
|
||||
defer delTestDir(t, dir)
|
||||
|
||||
createTestFiles(t, dir,
|
||||
"subdir1/file1",
|
||||
"subdir1/file2",
|
||||
"subdir2/file3",
|
||||
"subdir2/file4",
|
||||
"subdir2/subdir3/file5",
|
||||
"subdir2/subdir3/file6",
|
||||
"subdir4/file7",
|
||||
"subdir4/file8",
|
||||
)
|
||||
|
||||
require.NoError(t, Remove(dir).Except(
|
||||
filepath.Join(dir, "subdir2", "file4"),
|
||||
filepath.Join(dir, "subdir2", "subdir3", "file6"),
|
||||
filepath.Join(dir, "subdir4"),
|
||||
).Do())
|
||||
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file1"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir1", "file2"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "file3"))
|
||||
assert.FileExists(t, filepath.Join(dir, "subdir2", "file4"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file5"))
|
||||
assert.FileExists(t, filepath.Join(dir, "subdir2", "subdir3", "file6"))
|
||||
assert.FileExists(t, filepath.Join(dir, "subdir4", "file7"))
|
||||
assert.FileExists(t, filepath.Join(dir, "subdir4", "file8"))
|
||||
}
|
||||
|
||||
func newTestDir(t *testing.T, subdirs ...string) string {
|
||||
dir, err := ioutil.TempDir("", "test-files-dir")
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, target := range subdirs {
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(dir, target), 0700))
|
||||
}
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func createTestFiles(t *testing.T, dir string, files ...string) {
|
||||
for _, target := range files {
|
||||
f, err := os.Create(filepath.Join(dir, target))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, f.Close())
|
||||
}
|
||||
}
|
||||
|
||||
func delTestDir(t *testing.T, dir string) {
|
||||
require.NoError(t, os.RemoveAll(dir))
|
||||
}
|
||||
@ -38,6 +38,15 @@ import (
|
||||
|
||||
// Parse parses RAW message.
|
||||
func Parse(r io.Reader) (m *pmapi.Message, mimeBody, plainBody string, attReaders []io.Reader, err error) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = fmt.Errorf("panic while parsing message: %v", r)
|
||||
}()
|
||||
|
||||
p, err := parser.New(r)
|
||||
if err != nil {
|
||||
return nil, "", "", nil, errors.Wrap(err, "failed to create new parser")
|
||||
|
||||
@ -545,6 +545,20 @@ func TestParseEncodedContentTypeBad(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
type panicReader struct{}
|
||||
|
||||
func (panicReader) Read(p []byte) (int, error) {
|
||||
panic("lol")
|
||||
}
|
||||
|
||||
func TestParsePanic(t *testing.T) {
|
||||
var err error
|
||||
require.NotPanics(t, func() {
|
||||
_, _, _, _, err = Parse(&panicReader{})
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func getFileReader(filename string) io.Reader {
|
||||
f, err := os.Open(filepath.Join("testdata", filename))
|
||||
if err != nil {
|
||||
|
||||
@ -81,4 +81,6 @@ type Client interface {
|
||||
|
||||
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
|
||||
GetPublicKeysForEmail(string) ([]PublicKey, bool, error)
|
||||
|
||||
DownloadAndVerify(string, string, *crypto.KeyRing) (io.Reader, error)
|
||||
}
|
||||
|
||||
@ -137,10 +137,18 @@ func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
|
||||
cm.roundTripper = rt
|
||||
}
|
||||
|
||||
func (cm *ClientManager) GetClientConfig() *ClientConfig {
|
||||
return cm.config
|
||||
}
|
||||
|
||||
func (cm *ClientManager) SetUserAgent(clientName, clientVersion, os string) {
|
||||
cm.config.UserAgent = formatUserAgent(clientName, clientVersion, os)
|
||||
}
|
||||
|
||||
func (cm *ClientManager) GetUserAgent() string {
|
||||
return cm.config.UserAgent
|
||||
}
|
||||
|
||||
// GetClient returns a client for the given userID.
|
||||
// If the client does not exist already, it is created.
|
||||
func (cm *ClientManager) GetClient(userID string) Client {
|
||||
@ -366,7 +374,7 @@ func (cm *ClientManager) clearToken(userID string) {
|
||||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Info("Clearing token")
|
||||
logrus.WithField("userID", userID).Debug("Clearing token")
|
||||
|
||||
delete(cm.tokens, userID)
|
||||
}
|
||||
|
||||
@ -18,23 +18,23 @@
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// rootURL is the API root URL.
|
||||
//
|
||||
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
|
||||
// Default is pmapi_prod.
|
||||
//
|
||||
// It must not contain the protocol! The protocol should be in rootScheme.
|
||||
var rootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
|
||||
var rootScheme = "https" //nolint[gochecknoglobals]
|
||||
|
||||
// The HTTP transport to use by default.
|
||||
var defaultTransport = &http.Transport{ //nolint[gochecknoglobals]
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
// rootScheme is the scheme to use for connections to the root URL.
|
||||
var rootScheme = "https" //nolint[gochecknoglobals]
|
||||
|
||||
func GetAPIConfig(configName, appVersion string) *ClientConfig {
|
||||
return &ClientConfig{
|
||||
AppVersion: strings.Title(configName) + "_" + appVersion,
|
||||
ClientID: configName,
|
||||
Timeout: 25 * time.Minute, // Overall request timeout (~25MB / 25 mins => ~16kB/s, should be reasonable).
|
||||
FirstReadTimeout: 30 * time.Second, // 30s to match 30s response header timeout.
|
||||
MinBytesPerSecond: 1 << 10, // Enforce minimum download speed of 1kB/s.
|
||||
}
|
||||
}
|
||||
|
||||
// checkTLSCerts controls whether TLS certs are checked against known fingerprints.
|
||||
// The default is for this to always be done.
|
||||
var checkTLSCerts = true //nolint[gochecknoglobals]
|
||||
|
||||
@ -15,43 +15,30 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build pmapi_prod
|
||||
// +build !pmapi_qa
|
||||
|
||||
package config
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
func (c *Config) GetAPIConfig() *pmapi.ClientConfig {
|
||||
return &pmapi.ClientConfig{
|
||||
AppVersion: c.getAPIOS() + strings.Title(c.appName) + "_" + c.version,
|
||||
ClientID: c.appName,
|
||||
Timeout: 25 * time.Minute, // Overall request timeout (~25MB / 25 mins => ~16kB/s, should be reasonable).
|
||||
FirstReadTimeout: 30 * time.Second, // 30s to match 30s response header timeout.
|
||||
MinBytesPerSecond: 1 << 10, // Enforce minimum download speed of 1kB/s.
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) GetRoundTripper(cm *pmapi.ClientManager, listener listener.Listener) http.RoundTripper {
|
||||
func GetRoundTripper(cm *ClientManager, listener listener.Listener) http.RoundTripper {
|
||||
// We use a TLS dialer.
|
||||
basicDialer := pmapi.NewBasicTLSDialer()
|
||||
basicDialer := NewBasicTLSDialer()
|
||||
|
||||
// We wrap the TLS dialer in a layer which enforces connections to trusted servers.
|
||||
pinningDialer := pmapi.NewPinningTLSDialer(basicDialer)
|
||||
pinningDialer := NewPinningTLSDialer(basicDialer)
|
||||
|
||||
// We want any pin mismatches to be communicated back to bridge GUI and reported.
|
||||
pinningDialer.SetTLSIssueNotifier(func() { listener.Emit(events.TLSCertIssue, "") })
|
||||
pinningDialer.EnableRemoteTLSIssueReporting(c.GetAPIConfig().AppVersion, c.GetAPIConfig().UserAgent)
|
||||
pinningDialer.EnableRemoteTLSIssueReporting(cm)
|
||||
|
||||
// We wrap the pinning dialer in a layer which adds "alternative routing" feature.
|
||||
proxyDialer := pmapi.NewProxyTLSDialer(pinningDialer, cm)
|
||||
proxyDialer := NewProxyTLSDialer(pinningDialer, cm)
|
||||
|
||||
return pmapi.CreateTransportWithDialer(proxyDialer)
|
||||
return CreateTransportWithDialer(proxyDialer)
|
||||
}
|
||||
@ -24,6 +24,8 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -37,13 +39,13 @@ func init() {
|
||||
rootURL = fullRootURL
|
||||
rootScheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
func GetRoundTripper(_ *ClientManager, _ listener.Listener) http.RoundTripper {
|
||||
transport := CreateTransportWithDialer(NewBasicTLSDialer())
|
||||
|
||||
// TLS certificate of testing environment might be self-signed.
|
||||
defaultTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
// This config disables TLS cert checking.
|
||||
checkTLSCerts = false
|
||||
return transport
|
||||
}
|
||||
|
||||
@ -30,19 +30,12 @@ type PinningTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
|
||||
// pinChecker is used to check TLS keys of connections.
|
||||
pinChecker pinChecker
|
||||
pinChecker *pinChecker
|
||||
|
||||
// tlsIssueNotifier is used to notify something when there is a TLS issue.
|
||||
tlsIssueNotifier func()
|
||||
|
||||
// appVersion is needed to report TLS mismatches.
|
||||
appVersion string
|
||||
|
||||
// userAgent is needed to report TLS mismatches.
|
||||
userAgent string
|
||||
|
||||
// enableRemoteReporting instructs the dialer to report TLS mismatches.
|
||||
enableRemoteReporting bool
|
||||
reporter *tlsReporter
|
||||
|
||||
// A logger for logging messages.
|
||||
log logrus.FieldLogger
|
||||
@ -63,41 +56,38 @@ func (p *PinningTLSDialer) SetTLSIssueNotifier(notifier func()) {
|
||||
p.tlsIssueNotifier = notifier
|
||||
}
|
||||
|
||||
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(appVersion, userAgent string) {
|
||||
p.enableRemoteReporting = true
|
||||
p.appVersion = appVersion
|
||||
p.userAgent = userAgent
|
||||
func (p *PinningTLSDialer) EnableRemoteTLSIssueReporting(cm *ClientManager) {
|
||||
p.reporter = newTLSReporter(p.pinChecker, cm)
|
||||
}
|
||||
|
||||
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
|
||||
func (p *PinningTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
|
||||
if conn, err = p.dialer.DialTLS(network, address); err != nil {
|
||||
return
|
||||
func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) {
|
||||
conn, err := p.dialer.DialTLS(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = p.pinChecker.checkCertificate(conn); err != nil {
|
||||
if err := p.pinChecker.checkCertificate(conn); err != nil {
|
||||
if p.tlsIssueNotifier != nil {
|
||||
go p.tlsIssueNotifier()
|
||||
}
|
||||
|
||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.enableRemoteReporting {
|
||||
p.pinChecker.reportCertIssue(
|
||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
|
||||
p.reporter.reportCertIssue(
|
||||
TLSReportURI,
|
||||
host,
|
||||
port,
|
||||
tlsConn.ConnectionState(),
|
||||
p.appVersion,
|
||||
p.userAgent,
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
74
pkg/pmapi/download.go
Normal file
74
pkg/pmapi/download.go
Normal file
@ -0,0 +1,74 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
)
|
||||
|
||||
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
|
||||
// The file and its signature are verified using the given keyring `kr`.
|
||||
// If the file is verified successfully, it can be read from the returned reader.
|
||||
// TLS fingerprinting is used to verify that connections are only made to known servers.
|
||||
func (c *client) DownloadAndVerify(file, sig string, kr *crypto.KeyRing) (io.Reader, error) {
|
||||
var fb, sb []byte
|
||||
|
||||
if err := c.fetchFile(file, func(r io.Reader) (err error) {
|
||||
fb, err = ioutil.ReadAll(r)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.fetchFile(sig, func(r io.Reader) (err error) {
|
||||
sb, err = ioutil.ReadAll(r)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := kr.VerifyDetached(
|
||||
crypto.NewPlainMessage(fb),
|
||||
crypto.NewPGPSignature(sb),
|
||||
crypto.GetUnixTime(),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bytes.NewReader(fb), nil
|
||||
}
|
||||
|
||||
func (c *client) fetchFile(file string, fn func(io.Reader) error) error {
|
||||
res, err := c.hc.Get(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = res.Body.Close() }()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to get file: http error %v", res.StatusCode)
|
||||
}
|
||||
|
||||
return fn(res.Body)
|
||||
}
|
||||
@ -294,6 +294,21 @@ func (mr *MockClientMockRecorder) DeleteMessages(arg0 interface{}) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMessages", reflect.TypeOf((*MockClient)(nil).DeleteMessages), arg0)
|
||||
}
|
||||
|
||||
// DownloadAndVerify mocks base method
|
||||
func (m *MockClient) DownloadAndVerify(arg0, arg1 string, arg2 *crypto.KeyRing) (io.Reader, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DownloadAndVerify", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(io.Reader)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DownloadAndVerify indicates an expected call of DownloadAndVerify
|
||||
func (mr *MockClientMockRecorder) DownloadAndVerify(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadAndVerify", reflect.TypeOf((*MockClient)(nil).DownloadAndVerify), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// EmptyFolder mocks base method
|
||||
func (m *MockClient) EmptyFolder(arg0, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -35,7 +35,6 @@ import (
|
||||
|
||||
type pinChecker struct {
|
||||
trustedPins []string
|
||||
sentReports []sentReport
|
||||
}
|
||||
|
||||
type sentReport struct {
|
||||
@ -43,8 +42,8 @@ type sentReport struct {
|
||||
t time.Time
|
||||
}
|
||||
|
||||
func newPinChecker(trustedPins []string) pinChecker {
|
||||
return pinChecker{
|
||||
func newPinChecker(trustedPins []string) *pinChecker {
|
||||
return &pinChecker{
|
||||
trustedPins: trustedPins,
|
||||
}
|
||||
}
|
||||
@ -76,8 +75,25 @@ func certFingerprint(cert *x509.Certificate) string {
|
||||
return fmt.Sprintf(`pin-sha256=%q`, base64.StdEncoding.EncodeToString(hash[:]))
|
||||
}
|
||||
|
||||
type clientConfigProvider interface {
|
||||
GetClientConfig() *ClientConfig
|
||||
}
|
||||
|
||||
type tlsReporter struct {
|
||||
cm clientConfigProvider
|
||||
p *pinChecker
|
||||
sentReports []sentReport
|
||||
}
|
||||
|
||||
func newTLSReporter(p *pinChecker, cm clientConfigProvider) *tlsReporter {
|
||||
return &tlsReporter{
|
||||
cm: cm,
|
||||
p: p,
|
||||
}
|
||||
}
|
||||
|
||||
// reportCertIssue reports a TLS key mismatch.
|
||||
func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState, appVersion, userAgent string) {
|
||||
func (r *tlsReporter) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
|
||||
var certChain []string
|
||||
|
||||
if len(connState.VerifiedChains) > 0 {
|
||||
@ -86,27 +102,29 @@ func (p *pinChecker) reportCertIssue(remoteURI, host, port string, connState tls
|
||||
certChain = marshalCert7468(connState.PeerCertificates)
|
||||
}
|
||||
|
||||
r := newTLSReport(host, port, connState.ServerName, certChain, p.trustedPins, appVersion)
|
||||
cfg := r.cm.GetClientConfig()
|
||||
|
||||
if !p.hasRecentlySentReport(r) {
|
||||
p.recordReport(r)
|
||||
go r.sendReport(remoteURI, userAgent)
|
||||
report := newTLSReport(host, port, connState.ServerName, certChain, r.p.trustedPins, cfg.AppVersion)
|
||||
|
||||
if !r.hasRecentlySentReport(report) {
|
||||
r.recordReport(report)
|
||||
go report.sendReport(remoteURI, cfg.UserAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
|
||||
func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
|
||||
func (r *tlsReporter) hasRecentlySentReport(report tlsReport) bool {
|
||||
var validReports []sentReport
|
||||
|
||||
for _, r := range p.sentReports {
|
||||
for _, r := range r.sentReports {
|
||||
if time.Since(r.t) < 24*time.Hour {
|
||||
validReports = append(validReports, r)
|
||||
}
|
||||
}
|
||||
|
||||
p.sentReports = validReports
|
||||
r.sentReports = validReports
|
||||
|
||||
for _, r := range p.sentReports {
|
||||
for _, r := range r.sentReports {
|
||||
if cmp.Equal(report, r.r) {
|
||||
return true
|
||||
}
|
||||
@ -116,8 +134,8 @@ func (p *pinChecker) hasRecentlySentReport(report tlsReport) bool {
|
||||
}
|
||||
|
||||
// recordReport records the given report and the current time so we can check whether we recently sent this report.
|
||||
func (p *pinChecker) recordReport(r tlsReport) {
|
||||
p.sentReports = append(p.sentReports, sentReport{r: r, t: time.Now()})
|
||||
func (r *tlsReporter) recordReport(report tlsReport) {
|
||||
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
|
||||
}
|
||||
|
||||
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
|
||||
|
||||
@ -27,6 +27,14 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type fakeClientConfigProvider struct {
|
||||
version, useragent string
|
||||
}
|
||||
|
||||
func (c *fakeClientConfigProvider) GetClientConfig() *ClientConfig {
|
||||
return &ClientConfig{AppVersion: c.version, UserAgent: c.useragent}
|
||||
}
|
||||
|
||||
func TestPinCheckerDoubleReport(t *testing.T) {
|
||||
reportCounter := 0
|
||||
|
||||
@ -34,11 +42,11 @@ func TestPinCheckerDoubleReport(t *testing.T) {
|
||||
reportCounter++
|
||||
}))
|
||||
|
||||
pc := newPinChecker(TrustedAPIPins)
|
||||
r := newTLSReporter(newPinChecker(TrustedAPIPins), &fakeClientConfigProvider{version: "3", useragent: "useragent"})
|
||||
|
||||
// Report the same issue many times.
|
||||
for i := 0; i < 10; i++ {
|
||||
pc.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{}, "3", "useragent")
|
||||
r.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
|
||||
}
|
||||
|
||||
// We should only report once.
|
||||
@ -48,7 +56,7 @@ func TestPinCheckerDoubleReport(t *testing.T) {
|
||||
|
||||
// If we then report something else many times.
|
||||
for i := 0; i < 10; i++ {
|
||||
pc.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{}, "3", "useragent")
|
||||
r.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
|
||||
}
|
||||
|
||||
// We should get a second report.
|
||||
|
||||
@ -120,9 +120,7 @@ func (c *client) UpdateUser() (user *User, err error) {
|
||||
|
||||
c.user = user
|
||||
sentry.ConfigureScope(func(scope *sentry.Scope) {
|
||||
scope.SetUser(sentry.User{
|
||||
ID: user.ID,
|
||||
})
|
||||
scope.SetUser(sentry.User{ID: user.ID})
|
||||
})
|
||||
|
||||
var tmpList AddressList
|
||||
|
||||
@ -18,8 +18,8 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -31,9 +31,12 @@ func IsPortFree(port int) bool {
|
||||
if !(0 < port && port < maxPortNumber) {
|
||||
return false
|
||||
}
|
||||
stringPort := ":" + strconv.Itoa(port)
|
||||
isFree := !isOccupied(stringPort)
|
||||
return isFree
|
||||
// First, check localhost only.
|
||||
if isOccupied(fmt.Sprintf("127.0.0.1:%d", port)) {
|
||||
return false
|
||||
}
|
||||
// Second, check also ports opened to public.
|
||||
return !isOccupied(fmt.Sprintf(":%d", port))
|
||||
}
|
||||
|
||||
func isOccupied(port string) bool {
|
||||
|
||||
@ -19,28 +19,76 @@ package sentry
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/constants"
|
||||
"github.com/getsentry/sentry-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
skippedFunctions = []string{} //nolint[gochecknoglobals]
|
||||
)
|
||||
var skippedFunctions = []string{} //nolint[gochecknoglobals]
|
||||
|
||||
// ReportSentryCrash reports a sentry crash.
|
||||
func ReportSentryCrash(clientID, appVersion, userAgent string, reportErr error) error {
|
||||
func init() { // nolint[noinit]
|
||||
if err := sentry.Init(sentry.ClientOptions{
|
||||
Dsn: constants.DSNSentry,
|
||||
Release: constants.Revision,
|
||||
BeforeSend: EnhanceSentryEvent,
|
||||
}); err != nil {
|
||||
logrus.WithError(err).Error("Failed to initialize sentry options")
|
||||
}
|
||||
|
||||
sentry.ConfigureScope(func(scope *sentry.Scope) {
|
||||
scope.SetFingerprint([]string{"{{ default }}"})
|
||||
})
|
||||
}
|
||||
|
||||
type userAgentProvider interface {
|
||||
GetUserAgent() string
|
||||
}
|
||||
|
||||
type Reporter struct {
|
||||
appName string
|
||||
appVersion string
|
||||
uap userAgentProvider
|
||||
}
|
||||
|
||||
// NewReporter creates new sentry reporter with appName and appVersion to report.
|
||||
func NewReporter(appName, appVersion string) *Reporter {
|
||||
return &Reporter{
|
||||
appName: appName,
|
||||
appVersion: appVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reporter) SetUserAgentProvider(uap userAgentProvider) {
|
||||
r.uap = uap
|
||||
}
|
||||
|
||||
// Report reports a sentry crash with stacktrace from all goroutines.
|
||||
func (r *Reporter) Report(i interface{}) (err error) {
|
||||
SkipDuringUnwind()
|
||||
if reportErr == nil {
|
||||
|
||||
if os.Getenv("PROTONMAIL_ENV") == "dev" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// In case clientManager is not yet created we can get at least OS string.
|
||||
var userAgent string
|
||||
if r.uap != nil {
|
||||
userAgent = r.uap.GetUserAgent()
|
||||
} else {
|
||||
userAgent = runtime.GOOS
|
||||
}
|
||||
|
||||
reportErr := fmt.Errorf("recover: %v", i)
|
||||
|
||||
tags := map[string]string{
|
||||
"OS": runtime.GOOS,
|
||||
"Client": clientID,
|
||||
"Version": appVersion,
|
||||
"Client": r.appName,
|
||||
"Version": r.appVersion,
|
||||
"UserAgent": userAgent,
|
||||
"UserID": "",
|
||||
}
|
||||
@ -49,18 +97,17 @@ func ReportSentryCrash(clientID, appVersion, userAgent string, reportErr error)
|
||||
sentry.WithScope(func(scope *sentry.Scope) {
|
||||
SkipDuringUnwind()
|
||||
scope.SetTags(tags)
|
||||
eventID := sentry.CaptureException(reportErr)
|
||||
if eventID != nil {
|
||||
if eventID := sentry.CaptureException(reportErr); eventID != nil {
|
||||
reportID = string(*eventID)
|
||||
}
|
||||
})
|
||||
|
||||
if !sentry.Flush(time.Second * 10) {
|
||||
log.WithField("error", reportErr).Error("Failed to report sentry error")
|
||||
return errors.New("failed to report sentry error")
|
||||
}
|
||||
|
||||
log.WithField("error", reportErr).WithField("id", reportID).Warn("Sentry error reported")
|
||||
logrus.WithField("error", reportErr).WithField("id", reportID).Warn("Sentry error reported")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -15,29 +15,29 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build !pmapi_prod
|
||||
|
||||
package config
|
||||
// Package signature implements functions to verify files by their detached signatures.
|
||||
package signature
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (c *Config) GetAPIConfig() *pmapi.ClientConfig {
|
||||
return &pmapi.ClientConfig{
|
||||
AppVersion: c.getAPIOS() + strings.Title(c.appName) + "_" + c.version,
|
||||
ClientID: c.appName,
|
||||
// Verify verifies the given file by its signature using the given armored public key.
|
||||
func Verify(fileBytes, sigBytes []byte, pubKey string) error {
|
||||
key, err := crypto.NewKeyFromArmored(pubKey)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to load key")
|
||||
}
|
||||
}
|
||||
|
||||
func SetClientRoundTripper(_ *pmapi.ClientManager, _ *pmapi.ClientConfig, _ listener.Listener) {
|
||||
// Use the default roundtripper; do nothing.
|
||||
}
|
||||
kr, err := crypto.NewKeyRing(key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create keyring")
|
||||
}
|
||||
|
||||
func (c *Config) GetRoundTripper(_ *pmapi.ClientManager, _ listener.Listener) http.RoundTripper {
|
||||
return http.DefaultTransport
|
||||
return kr.VerifyDetached(
|
||||
crypto.NewPlainMessage(fileBytes),
|
||||
crypto.NewPGPSignature(sigBytes),
|
||||
crypto.GetUnixTime(),
|
||||
)
|
||||
}
|
||||
76
pkg/tar/tar.go
Normal file
76
pkg/tar/tar.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package tar
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func UntarToDir(r io.Reader, dir string) error {
|
||||
tr := tar.NewReader(r)
|
||||
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if header == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
target := filepath.Join(dir, header.Name)
|
||||
|
||||
switch {
|
||||
case header.Typeflag == tar.TypeSymlink:
|
||||
if err := os.Symlink(header.Linkname, target); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case header.FileInfo().IsDir():
|
||||
if err := os.MkdirAll(target, header.FileInfo().Mode()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
f, err := os.Create(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(f, tr); err != nil { // nolint[gosec]
|
||||
return err
|
||||
}
|
||||
if runtime.GOOS != "windows" {
|
||||
if err := f.Chmod(header.FileInfo().Mode()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close file")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user