mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-23 18:36:46 +00:00
We build too many walls and not enough bridges
This commit is contained in:
259
pkg/config/config.go
Normal file
259
pkg/config/config.go
Normal file
@ -0,0 +1,259 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/go-appdir"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
var (
|
||||
log = GetLogEntry("config") //nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
type appDirProvider interface {
|
||||
UserConfig() string
|
||||
UserCache() string
|
||||
UserLogs() string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
appName string
|
||||
version string
|
||||
revision string
|
||||
cacheVersion string
|
||||
appDirs appDirProvider
|
||||
appDirsVersion appDirProvider
|
||||
apiConfig *pmapi.ClientConfig
|
||||
}
|
||||
|
||||
// New returns fully initialized config struct.
|
||||
// `appName` should be in camelCase format for folder or file names. It's also used in API
|
||||
// as `AppVersion` which is converted to CamelCase.
|
||||
// `version` is the version of the app (e.g. v1.2.3).
|
||||
// `cacheVersion` is the version of the cache files (setting a different number will remove the old ones).
|
||||
func New(appName, version, revision, cacheVersion string) *Config {
|
||||
appDirs := appdir.New(filepath.Join("protonmail", appName))
|
||||
appDirsVersion := appdir.New(filepath.Join("protonmail", appName, cacheVersion))
|
||||
return newConfig(appName, version, revision, cacheVersion, appDirs, appDirsVersion)
|
||||
}
|
||||
|
||||
func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirsVersion appDirProvider) *Config {
|
||||
return &Config{
|
||||
appName: appName,
|
||||
version: version,
|
||||
revision: revision,
|
||||
cacheVersion: cacheVersion,
|
||||
appDirs: appDirs,
|
||||
appDirsVersion: appDirsVersion,
|
||||
apiConfig: &pmapi.ClientConfig{
|
||||
AppVersion: strings.Title(appName) + "_" + version,
|
||||
ClientID: appName,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{Timeout: 3 * time.Second}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
},
|
||||
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
|
||||
TokenManager: pmapi.NewTokenManager(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDirs creates all folders that are necessary for bridge to properly function.
|
||||
func (c *Config) CreateDirs() error {
|
||||
// Log files.
|
||||
if err := os.MkdirAll(c.appDirs.UserLogs(), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
// TLS files.
|
||||
if err := os.MkdirAll(c.appDirs.UserConfig(), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
// Lock, events, preferences, user_info, db files.
|
||||
if err := os.MkdirAll(c.appDirsVersion.UserCache(), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearData removes all files except the lock file.
|
||||
// The lock file will be removed when the Bridge stops.
|
||||
func (c *Config) ClearData() error {
|
||||
dirs := []string{
|
||||
c.appDirs.UserLogs(),
|
||||
c.appDirs.UserConfig(),
|
||||
c.appDirs.UserCache(),
|
||||
}
|
||||
shouldRemove := func(filePath string) bool {
|
||||
return filePath != c.GetLockPath()
|
||||
}
|
||||
return c.removeAllExcept(dirs, shouldRemove)
|
||||
}
|
||||
|
||||
// ClearOldData removes all old files, such as old log files or old versions of cache and so on.
|
||||
func (c *Config) ClearOldData() error {
|
||||
// `appDirs` is parent for `appDirsVersion`.
|
||||
// `dir` then contains all subfolders and only `cacheVersion` should stay.
|
||||
// But on Windows all files (dirs) are in the same one - we cannot remove log, lock or tls files.
|
||||
dir := c.appDirs.UserCache()
|
||||
|
||||
return c.removeExcept(dir, func(filePath string) bool {
|
||||
fileName := filepath.Base(filePath)
|
||||
return (fileName != c.cacheVersion &&
|
||||
!logFileRgx.MatchString(fileName) &&
|
||||
filePath != c.GetTLSCertPath() &&
|
||||
filePath != c.GetTLSKeyPath() &&
|
||||
filePath != c.GetEventsPath() &&
|
||||
filePath != c.GetIMAPCachePath() &&
|
||||
filePath != c.GetLockPath() &&
|
||||
filePath != c.GetPreferencesPath())
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) removeAllExcept(dirs []string, shouldRemove func(string) bool) error {
|
||||
var result *multierror.Error
|
||||
for _, dir := range dirs {
|
||||
if err := c.removeExcept(dir, shouldRemove); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (c *Config) removeExcept(dir string, shouldRemove func(string) bool) error {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, file := range files {
|
||||
filePath := filepath.Join(dir, file.Name())
|
||||
if !shouldRemove(filePath) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !file.IsDir() {
|
||||
if err := os.RemoveAll(filePath); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
subDir := filepath.Join(dir, file.Name())
|
||||
if err := c.removeExcept(subDir, shouldRemove); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
// Remove dir itself only if it's empty.
|
||||
subFiles, err := ioutil.ReadDir(subDir)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else if len(subFiles) == 0 {
|
||||
if err := os.RemoveAll(subDir); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// IsDevMode should be used for development conditions such us whether to send sentry reports.
|
||||
func (c *Config) IsDevMode() bool {
|
||||
return os.Getenv("PROTONMAIL_ENV") == "dev"
|
||||
}
|
||||
|
||||
// GetLogDir returns folder for log files.
|
||||
func (c *Config) GetLogDir() string {
|
||||
return c.appDirs.UserLogs()
|
||||
}
|
||||
|
||||
// GetLogPrefix returns prefix for log files. Bridge uses format vVERSION.
|
||||
func (c *Config) GetLogPrefix() string {
|
||||
return "v" + c.version + "_" + c.revision
|
||||
}
|
||||
|
||||
// GetTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP and API).
|
||||
func (c *Config) GetTLSCertPath() string {
|
||||
return filepath.Join(c.appDirs.UserConfig(), "cert.pem")
|
||||
}
|
||||
|
||||
// GetTLSKeyPath returns path to private key; used for TLS servers (IMAP, SMTP and API).
|
||||
func (c *Config) GetTLSKeyPath() string {
|
||||
return filepath.Join(c.appDirs.UserConfig(), "key.pem")
|
||||
}
|
||||
|
||||
// GetDBDir returns folder for db files.
|
||||
func (c *Config) GetDBDir() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache())
|
||||
}
|
||||
|
||||
// GetEventsPath returns path to events file containing the last processed event IDs.
|
||||
func (c *Config) GetEventsPath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "events.json")
|
||||
}
|
||||
|
||||
// GetIMAPCachePath returns path to file with IMAP status.
|
||||
func (c *Config) GetIMAPCachePath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "user_info.json")
|
||||
}
|
||||
|
||||
// GetLockPath returns path to lock file to check if bridge is already running.
|
||||
func (c *Config) GetLockPath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), c.appName+".lock")
|
||||
}
|
||||
|
||||
// GetUpdateDir returns folder for update files; such as new binary.
|
||||
func (c *Config) GetUpdateDir() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "updates")
|
||||
}
|
||||
|
||||
// GetPreferencesPath returns path to preference file.
|
||||
func (c *Config) GetPreferencesPath() string {
|
||||
return filepath.Join(c.appDirsVersion.UserCache(), "prefs.json")
|
||||
}
|
||||
|
||||
// GetAPIConfig returns config for ProtonMail API.
|
||||
func (c *Config) GetAPIConfig() *pmapi.ClientConfig {
|
||||
return c.apiConfig
|
||||
}
|
||||
|
||||
// GetDefaultAPIPort returns default Bridge local API port.
|
||||
func (c *Config) GetDefaultAPIPort() int {
|
||||
return 1042
|
||||
}
|
||||
|
||||
// GetDefaultIMAPPort returns default Bridge IMAP port.
|
||||
func (c *Config) GetDefaultIMAPPort() int {
|
||||
return 1143
|
||||
}
|
||||
|
||||
// GetDefaultSMTPPort returns default Bridge SMTP port.
|
||||
func (c *Config) GetDefaultSMTPPort() int {
|
||||
return 1025
|
||||
}
|
||||
238
pkg/config/config_test.go
Normal file
238
pkg/config/config_test.go
Normal file
@ -0,0 +1,238 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testAppName = "bridge-test"
|
||||
|
||||
var testConfigDir string //nolint[gochecknoglobals]
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
setupTestConfig()
|
||||
setupTestLogs()
|
||||
code := m.Run()
|
||||
shutdownTestConfig()
|
||||
shutdownTestLogs()
|
||||
shutdownTestPreferences()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func setupTestConfig() {
|
||||
var err error
|
||||
testConfigDir, err = ioutil.TempDir("", "config")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownTestConfig() {
|
||||
_ = os.RemoveAll(testConfigDir)
|
||||
}
|
||||
|
||||
type mocks struct {
|
||||
t *testing.T
|
||||
|
||||
ctrl *gomock.Controller
|
||||
appDir *MockappDirer
|
||||
appDirVersion *MockappDirer
|
||||
}
|
||||
|
||||
func initMocks(t *testing.T) mocks {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
return mocks{
|
||||
t: t,
|
||||
|
||||
ctrl: mockCtrl,
|
||||
appDir: NewMockappDirer(mockCtrl),
|
||||
appDirVersion: NewMockappDirer(mockCtrl),
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearDataLinux(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureLinux(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"config",
|
||||
"logs",
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearDataWindows(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureWindows(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"config",
|
||||
})
|
||||
}
|
||||
|
||||
// OldData touches only cache folder.
|
||||
// Removes only c1 folder as nothing else is part of cache folder on Linux/Mac.
|
||||
func TestClearOldDataLinux(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureLinux(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearOldData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
"logs",
|
||||
"logs/other.log",
|
||||
"logs/v1_10.log",
|
||||
"logs/v1_11.log",
|
||||
"logs/v2_12.log",
|
||||
"logs/v2_13.log",
|
||||
})
|
||||
}
|
||||
|
||||
// OldData touches only cache folder. Removes everything except c2 folder
|
||||
// and bridge log files which are part of cache folder on Windows.
|
||||
func TestClearOldDataWindows(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
createTestStructureWindows(m, testConfigDir)
|
||||
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
|
||||
require.NoError(t, cfg.ClearOldData())
|
||||
checkFileNames(t, testConfigDir, []string{
|
||||
"cache",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"cache/v1_10.log",
|
||||
"cache/v1_11.log",
|
||||
"cache/v2_12.log",
|
||||
"cache/v2_13.log",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
})
|
||||
}
|
||||
|
||||
func createTestStructureLinux(m mocks, baseDir string) {
|
||||
logsDir := filepath.Join(baseDir, "logs")
|
||||
configDir := filepath.Join(baseDir, "config")
|
||||
cacheDir := filepath.Join(baseDir, "cache")
|
||||
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
|
||||
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
|
||||
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
|
||||
}
|
||||
|
||||
func createTestStructureWindows(m mocks, baseDir string) {
|
||||
logsDir := filepath.Join(baseDir, "cache")
|
||||
configDir := filepath.Join(baseDir, "config")
|
||||
cacheDir := filepath.Join(baseDir, "cache")
|
||||
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
|
||||
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
|
||||
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
|
||||
}
|
||||
|
||||
func createTestStructure(m mocks, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir string) {
|
||||
m.appDir.EXPECT().UserLogs().Return(logsDir).AnyTimes()
|
||||
m.appDir.EXPECT().UserConfig().Return(configDir).AnyTimes()
|
||||
m.appDir.EXPECT().UserCache().Return(cacheDir).AnyTimes()
|
||||
m.appDirVersion.EXPECT().UserCache().Return(versionedCacheDir).AnyTimes()
|
||||
|
||||
require.NoError(m.t, os.RemoveAll(baseDir))
|
||||
require.NoError(m.t, os.MkdirAll(baseDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(logsDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(configDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(cacheDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(versionedOldCacheDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(versionedCacheDir, 0700))
|
||||
require.NoError(m.t, os.MkdirAll(filepath.Join(versionedCacheDir, "updates"), 0700))
|
||||
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "other.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_10.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_11.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_12.log"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_13.log"), []byte("Hello"), 0755))
|
||||
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "cert.pem"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "key.pem"), []byte("Hello"), 0755))
|
||||
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "prefs.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "events.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "user_info.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "prefs.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "events.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "user_info.json"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, testAppName+".lock"), []byte("Hello"), 0755))
|
||||
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
|
||||
}
|
||||
|
||||
func checkFileNames(t *testing.T, dir string, expectedFileNames []string) {
|
||||
fileNames := getFileNames(t, dir)
|
||||
require.Equal(t, expectedFileNames, fileNames)
|
||||
}
|
||||
|
||||
func getFileNames(t *testing.T, dir string) []string {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileNames := []string{}
|
||||
for _, file := range files {
|
||||
fileNames = append(fileNames, file.Name())
|
||||
if file.IsDir() {
|
||||
subDir := filepath.Join(dir, file.Name())
|
||||
subFileNames := getFileNames(t, subDir)
|
||||
for _, subFileName := range subFileNames {
|
||||
fileNames = append(fileNames, file.Name()+"/"+subFileName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return fileNames
|
||||
}
|
||||
252
pkg/config/logs.go
Normal file
252
pkg/config/logs.go
Normal file
@ -0,0 +1,252 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type logConfiger interface {
|
||||
GetLogDir() string
|
||||
GetLogPrefix() string
|
||||
}
|
||||
|
||||
const (
|
||||
// Zendesk now has a file size limit of 20MB. When the last N log files
|
||||
// are zipped, it should fit under 20MB. Value in MB (average file has
|
||||
// few hundreds kB).
|
||||
maxLogFileSize = 10 * 1024 * 1024 //nolint[gochecknoglobals]
|
||||
// Including the current logfile.
|
||||
maxNumberLogFiles = 3 //nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
// logFile is pointer to currently open file used by logrus.
|
||||
var logFile *os.File //nolint[gochecknoglobals]
|
||||
|
||||
var logFileRgx = regexp.MustCompile("^v.*\\.log$") //nolint[gochecknoglobals]
|
||||
var logCrashRgx = regexp.MustCompile("^v.*_crash_.*\\.log$") //nolint[gochecknoglobals]
|
||||
|
||||
// GetLogEntry returns logrus.Entry with PID and `packageName`.
|
||||
func GetLogEntry(packageName string) *logrus.Entry {
|
||||
return logrus.WithFields(logrus.Fields{
|
||||
"pkg": packageName,
|
||||
})
|
||||
}
|
||||
|
||||
// HandlePanic reports the crash to sentry or local file when sentry fails.
|
||||
func HandlePanic(cfg *Config, output string) {
|
||||
if !cfg.IsDevMode() {
|
||||
c := pmapi.NewClient(cfg.GetAPIConfig(), "no-user-id")
|
||||
err := c.ReportSentryCrash(fmt.Errorf(output))
|
||||
if err != nil {
|
||||
log.Error("Sentry crash report failed: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
filename := getLogFilename(cfg.GetLogPrefix() + "_crash_")
|
||||
filepath := filepath.Join(cfg.GetLogDir(), filename)
|
||||
f, err := os.OpenFile(filepath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
log.Error("Cannot open file to write crash report: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = f.WriteString(output)
|
||||
_ = pprof.Lookup("goroutine").WriteTo(f, 2)
|
||||
|
||||
log.Warn("Crash report saved to ", filepath)
|
||||
}
|
||||
|
||||
// GetGID returns goroutine number which can be used to distiguish logs from
|
||||
// the concurent processes. Keep in mind that it returns the number of routine
|
||||
// which executes the function.
|
||||
func GetGID() uint64 {
|
||||
b := make([]byte, 64)
|
||||
b = b[:runtime.Stack(b, false)]
|
||||
b = bytes.TrimPrefix(b, []byte("goroutine "))
|
||||
b = b[:bytes.IndexByte(b, ' ')]
|
||||
n, _ := strconv.ParseUint(string(b), 10, 64)
|
||||
return n
|
||||
}
|
||||
|
||||
// SetupLog set up log level, formatter and output (file or stdout).
|
||||
// Returns whether should be used debug for IMAP and SMTP servers.
|
||||
func SetupLog(cfg logConfiger, levelFlag string) (debugClient, debugServer bool) {
|
||||
level, useFile := getLogLevelAndFile(levelFlag)
|
||||
|
||||
logrus.SetLevel(level)
|
||||
|
||||
if useFile {
|
||||
logrus.SetFormatter(&logrus.JSONFormatter{})
|
||||
setLogFile(cfg.GetLogDir(), cfg.GetLogPrefix())
|
||||
watchLogFileSize(cfg.GetLogDir(), cfg.GetLogPrefix())
|
||||
} else {
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
ForceColors: true,
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: time.StampMilli,
|
||||
})
|
||||
logrus.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
switch levelFlag {
|
||||
case "debug-client", "debug-client-json":
|
||||
debugClient = true
|
||||
case "debug-server", "debug-server-json", "trace":
|
||||
fmt.Println("THE LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
|
||||
log.Warning("================================================")
|
||||
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
|
||||
log.Warning("================================================")
|
||||
debugClient = true
|
||||
debugServer = true
|
||||
}
|
||||
|
||||
return debugClient, debugServer
|
||||
}
|
||||
|
||||
func setLogFile(logDir, logPrefix string) {
|
||||
if logFile != nil {
|
||||
return
|
||||
}
|
||||
|
||||
filename := getLogFilename(logPrefix)
|
||||
var err error
|
||||
logFile, err = os.OpenFile(filepath.Join(logDir, filename), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
logrus.SetOutput(logFile)
|
||||
|
||||
// Users sometimes change the name of the log file. We want to always log
|
||||
// information about bridge version (included in log prefix) and OS.
|
||||
log.Warn("Bridge version: ", logPrefix, " ", runtime.GOOS)
|
||||
}
|
||||
|
||||
func getLogFilename(logPrefix string) string {
|
||||
currentTime := strconv.Itoa(int(time.Now().Unix()))
|
||||
return logPrefix + "_" + currentTime + ".log"
|
||||
}
|
||||
|
||||
func watchLogFileSize(logDir, logPrefix string) {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(60 * time.Second)
|
||||
checkLogFileSize(logDir, logPrefix)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func checkLogFileSize(logDir, logPrefix string) {
|
||||
if logFile == nil {
|
||||
return
|
||||
}
|
||||
|
||||
stat, err := logFile.Stat()
|
||||
if err != nil {
|
||||
log.Error("Log file size check failed: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.Size() >= maxLogFileSize {
|
||||
log.Warn("Current log file ", logFile.Name(), " is too big, opening new file")
|
||||
closeLogFile()
|
||||
setLogFile(logDir, logPrefix)
|
||||
}
|
||||
|
||||
if err := clearLogs(logDir); err != nil {
|
||||
log.Error("Cannot clear logs ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func closeLogFile() {
|
||||
if logFile != nil {
|
||||
_ = logFile.Close()
|
||||
logFile = nil
|
||||
}
|
||||
}
|
||||
|
||||
func clearLogs(logDir string) error {
|
||||
files, err := ioutil.ReadDir(logDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var logsWithPrefix []string
|
||||
var crashesWithPrefix []string
|
||||
|
||||
for _, file := range files {
|
||||
if logFileRgx.MatchString(file.Name()) {
|
||||
if logCrashRgx.MatchString(file.Name()) {
|
||||
crashesWithPrefix = append(crashesWithPrefix, file.Name())
|
||||
} else {
|
||||
logsWithPrefix = append(logsWithPrefix, file.Name())
|
||||
}
|
||||
} else {
|
||||
// Older versions of Bridge stored logs in subfolders for each version.
|
||||
// That also has to be cleared and the functionality can be removed after some time.
|
||||
if file.IsDir() {
|
||||
if err := clearLogs(filepath.Join(logDir, file.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
removeLog(logDir, file.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
removeOldLogs(logDir, logsWithPrefix)
|
||||
removeOldLogs(logDir, crashesWithPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeOldLogs(logDir string, filenames []string) {
|
||||
count := len(filenames)
|
||||
if count <= maxNumberLogFiles {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(filenames) // Sorted by timestamp: oldest first.
|
||||
for _, filename := range filenames[:count-maxNumberLogFiles] {
|
||||
removeLog(logDir, filename)
|
||||
}
|
||||
}
|
||||
|
||||
func removeLog(logDir, filename string) {
|
||||
// We need to be sure to delete only log files.
|
||||
// Directory with logs can also contain other files.
|
||||
if !logFileRgx.MatchString(filename) {
|
||||
return
|
||||
}
|
||||
if err := os.RemoveAll(filepath.Join(logDir, filename)); err != nil {
|
||||
log.Error("Cannot remove old logs ", err)
|
||||
}
|
||||
}
|
||||
49
pkg/config/logs_all.go
Normal file
49
pkg/config/logs_all.go
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build !build_qa
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
|
||||
useFile = true
|
||||
switch levelFlag {
|
||||
case "panic":
|
||||
level = logrus.PanicLevel
|
||||
case "fatal":
|
||||
level = logrus.FatalLevel
|
||||
case "error":
|
||||
level = logrus.ErrorLevel
|
||||
case "warn":
|
||||
level = logrus.WarnLevel
|
||||
case "info":
|
||||
level = logrus.InfoLevel
|
||||
case "debug", "debug-client", "debug-server", "debug-client-json", "debug-server-json":
|
||||
level = logrus.DebugLevel
|
||||
useFile = false
|
||||
case "trace":
|
||||
level = logrus.TraceLevel
|
||||
useFile = false
|
||||
default:
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
return
|
||||
}
|
||||
50
pkg/config/logs_qa.go
Normal file
50
pkg/config/logs_qa.go
Normal file
@ -0,0 +1,50 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// +build build_qa
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// getLogLevelAndFile for QA build is altered in a way even decrypted data are stored
|
||||
// in the log file when forced with `debug-client-json` or `debug-server-json`.
|
||||
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
|
||||
useFile = true
|
||||
switch levelFlag {
|
||||
case "panic":
|
||||
level = logrus.PanicLevel
|
||||
case "fatal":
|
||||
level = logrus.FatalLevel
|
||||
case "error":
|
||||
level = logrus.ErrorLevel
|
||||
case "warn":
|
||||
level = logrus.WarnLevel
|
||||
case "info":
|
||||
level = logrus.InfoLevel
|
||||
case "debug-client-json", "debug-server-json":
|
||||
level = logrus.DebugLevel
|
||||
case "debug", "debug-client", "debug-server":
|
||||
level = logrus.DebugLevel
|
||||
useFile = false
|
||||
default:
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
return
|
||||
}
|
||||
225
pkg/config/logs_test.go
Normal file
225
pkg/config/logs_test.go
Normal file
@ -0,0 +1,225 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testLogConfig struct{ logDir, logPrefix string }
|
||||
|
||||
func (c *testLogConfig) GetLogDir() string { return c.logDir }
|
||||
func (c *testLogConfig) GetLogPrefix() string { return c.logPrefix }
|
||||
|
||||
var testLogDir string //nolint[gochecknoglobals]
|
||||
|
||||
func setupTestLogs() {
|
||||
var err error
|
||||
testLogDir, err = ioutil.TempDir("", "log")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownTestLogs() {
|
||||
_ = os.RemoveAll(testLogDir)
|
||||
}
|
||||
|
||||
func TestLogNameLength(t *testing.T) {
|
||||
cfg := New("bridge-test", "longVersion123", "longRevision1234567890", "c2")
|
||||
name := getLogFilename(cfg.GetLogPrefix())
|
||||
if len(name) > 128 {
|
||||
t.Fatal("Name of the log is too long - limit for encrypted linux is 128 characters")
|
||||
}
|
||||
}
|
||||
|
||||
// Info and higher levels writes to the file.
|
||||
func TestSetupLogInfo(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "setupInfo")
|
||||
|
||||
SetupLog(&testLogConfig{dir, "v"}, "info")
|
||||
require.Equal(t, "info", logrus.GetLevel().String())
|
||||
|
||||
logrus.Info("test message")
|
||||
files := checkLogFiles(t, dir, 1)
|
||||
checkLogContains(t, dir, files[0].Name(), "test message")
|
||||
}
|
||||
|
||||
// Debug levels writes to stdout.
|
||||
func TestSetupLogDebug(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "setupDebug")
|
||||
|
||||
SetupLog(&testLogConfig{dir, "v"}, "debug")
|
||||
require.Equal(t, "debug", logrus.GetLevel().String())
|
||||
|
||||
logrus.Info("test message")
|
||||
checkLogFiles(t, dir, 0)
|
||||
}
|
||||
|
||||
func TestReopenLogFile(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "reopenLogFile")
|
||||
|
||||
setLogFile(dir, "v1")
|
||||
|
||||
done := make(chan interface{})
|
||||
|
||||
log.Info("first message")
|
||||
|
||||
go func() {
|
||||
<-done // Wait for closing file and opening new one.
|
||||
log.Info("second message")
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
closeLogFile()
|
||||
setLogFile(dir, "v2")
|
||||
|
||||
done <- nil
|
||||
<-done // Wait for second log message.
|
||||
|
||||
files := checkLogFiles(t, dir, 2)
|
||||
checkLogContains(t, dir, files[0].Name(), "first message")
|
||||
checkLogContains(t, dir, files[1].Name(), "second message")
|
||||
}
|
||||
|
||||
func TestCheckLogFileSizeSmall(t *testing.T) {
|
||||
dir := beforeEachCreateTestDir(t, "logFileSizeSmall")
|
||||
|
||||
setLogFile(dir, "v1")
|
||||
originalFileName := logFile.Name()
|
||||
|
||||
_, _ = logFile.WriteString("small file")
|
||||
checkLogFileSize(dir, "v2")
|
||||
|
||||
require.Equal(t, originalFileName, logFile.Name())
|
||||
}
|
||||
|
||||
func TestCheckLogFileSizeBig(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
dir := beforeEachCreateTestDir(t, "logFileSizeBig")
|
||||
|
||||
setLogFile(dir, "v1")
|
||||
originalFileName := logFile.Name()
|
||||
|
||||
// The limit for big file is 10*1024*1024 - keep the string 10 letters long.
|
||||
for i := 0; i < 1024*1024; i++ {
|
||||
_, _ = logFile.WriteString("big file!\n")
|
||||
}
|
||||
checkLogFileSize(dir, "v2")
|
||||
|
||||
require.NotEqual(t, originalFileName, logFile.Name())
|
||||
}
|
||||
|
||||
// ClearLogs removes only bridge old log files keeping last three of them.
|
||||
func TestClearLogsLinux(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
dir := beforeEachCreateTestDir(t, "clearLogs")
|
||||
|
||||
createTestStructureLinux(m, dir)
|
||||
require.NoError(t, clearLogs(dir))
|
||||
checkFileNames(t, dir, []string{
|
||||
"cache",
|
||||
"cache/c1",
|
||||
"cache/c1/events.json",
|
||||
"cache/c1/mailbox-user@pm.me.db",
|
||||
"cache/c1/prefs.json",
|
||||
"cache/c1/user_info.json",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
"logs",
|
||||
"logs/other.log",
|
||||
"logs/v1_11.log",
|
||||
"logs/v2_12.log",
|
||||
"logs/v2_13.log",
|
||||
})
|
||||
}
|
||||
|
||||
// ClearLogs removes only bridge old log files even when log folder
|
||||
// is shared with other files on Windows.
|
||||
func TestClearLogsWindows(t *testing.T) {
|
||||
m := initMocks(t)
|
||||
defer m.ctrl.Finish()
|
||||
|
||||
dir := beforeEachCreateTestDir(t, "clearLogs")
|
||||
|
||||
createTestStructureWindows(m, dir)
|
||||
require.NoError(t, clearLogs(dir))
|
||||
checkFileNames(t, dir, []string{
|
||||
"cache",
|
||||
"cache/c1",
|
||||
"cache/c1/events.json",
|
||||
"cache/c1/mailbox-user@pm.me.db",
|
||||
"cache/c1/prefs.json",
|
||||
"cache/c1/user_info.json",
|
||||
"cache/c2",
|
||||
"cache/c2/bridge-test.lock",
|
||||
"cache/c2/events.json",
|
||||
"cache/c2/mailbox-user@pm.me.db",
|
||||
"cache/c2/prefs.json",
|
||||
"cache/c2/updates",
|
||||
"cache/c2/user_info.json",
|
||||
"cache/other.log",
|
||||
"cache/v1_11.log",
|
||||
"cache/v2_12.log",
|
||||
"cache/v2_13.log",
|
||||
"config",
|
||||
"config/cert.pem",
|
||||
"config/key.pem",
|
||||
})
|
||||
}
|
||||
|
||||
func beforeEachCreateTestDir(t *testing.T, dir string) string {
|
||||
// Make sure opened file (from the previous test) is cleared.
|
||||
closeLogFile()
|
||||
|
||||
dir = filepath.Join(testLogDir, dir)
|
||||
require.NoError(t, os.MkdirAll(dir, 0700))
|
||||
return dir
|
||||
}
|
||||
|
||||
func checkLogFiles(t *testing.T, dir string, expectedCount int) []os.FileInfo {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedCount, len(files))
|
||||
return files
|
||||
}
|
||||
|
||||
func checkLogContains(t *testing.T, dir, fileName, expectedSubstr string) {
|
||||
data, err := ioutil.ReadFile(filepath.Join(dir, fileName)) //nolint[gosec]
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(data), expectedSubstr)
|
||||
}
|
||||
76
pkg/config/mock_config.go
Normal file
76
pkg/config/mock_config.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: config/config.go
|
||||
|
||||
// Package config is a generated GoMock package.
|
||||
package config
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockappDirer is a mock of appDirer interface
|
||||
type MockappDirer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockappDirerMockRecorder
|
||||
}
|
||||
|
||||
// MockappDirerMockRecorder is the mock recorder for MockappDirer
|
||||
type MockappDirerMockRecorder struct {
|
||||
mock *MockappDirer
|
||||
}
|
||||
|
||||
// NewMockappDirer creates a new mock instance
|
||||
func NewMockappDirer(ctrl *gomock.Controller) *MockappDirer {
|
||||
mock := &MockappDirer{ctrl: ctrl}
|
||||
mock.recorder = &MockappDirerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockappDirer) EXPECT() *MockappDirerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// UserConfig mocks base method
|
||||
func (m *MockappDirer) UserConfig() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserConfig")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserConfig indicates an expected call of UserConfig
|
||||
func (mr *MockappDirerMockRecorder) UserConfig() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserConfig", reflect.TypeOf((*MockappDirer)(nil).UserConfig))
|
||||
}
|
||||
|
||||
// UserCache mocks base method
|
||||
func (m *MockappDirer) UserCache() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserCache")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserCache indicates an expected call of UserCache
|
||||
func (mr *MockappDirerMockRecorder) UserCache() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCache", reflect.TypeOf((*MockappDirer)(nil).UserCache))
|
||||
}
|
||||
|
||||
// UserLogs mocks base method
|
||||
func (m *MockappDirer) UserLogs() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserLogs")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserLogs indicates an expected call of UserLogs
|
||||
func (mr *MockappDirerMockRecorder) UserLogs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserLogs", reflect.TypeOf((*MockappDirer)(nil).UserLogs))
|
||||
}
|
||||
127
pkg/config/preferences.go
Normal file
127
pkg/config/preferences.go
Normal file
@ -0,0 +1,127 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Preferences struct {
|
||||
cache map[string]string
|
||||
path string
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPreferences returns loaded preferences.
|
||||
func NewPreferences(preferencesPath string) *Preferences {
|
||||
p := &Preferences{
|
||||
path: preferencesPath,
|
||||
lock: &sync.RWMutex{},
|
||||
}
|
||||
if err := p.load(); err != nil {
|
||||
log.Warn("Cannot load preferences: ", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Preferences) load() error {
|
||||
if p.cache != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
p.cache = map[string]string{}
|
||||
|
||||
f, err := os.Open(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close() //nolint[errcheck]
|
||||
|
||||
return json.NewDecoder(f).Decode(&p.cache)
|
||||
}
|
||||
|
||||
func (p *Preferences) save() error {
|
||||
if p.cache == nil {
|
||||
return errors.New("cannot save preferences: cache is nil")
|
||||
}
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
f, err := os.Create(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close() //nolint[errcheck]
|
||||
|
||||
return json.NewEncoder(f).Encode(p.cache)
|
||||
}
|
||||
|
||||
func (p *Preferences) SetDefault(key, value string) {
|
||||
if p.Get(key) == "" {
|
||||
p.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Preferences) Get(key string) string {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
return p.cache[key]
|
||||
}
|
||||
|
||||
func (p *Preferences) GetBool(key string) bool {
|
||||
return p.Get(key) == "true"
|
||||
}
|
||||
|
||||
func (p *Preferences) GetInt(key string) int {
|
||||
value, err := strconv.Atoi(p.Get(key))
|
||||
if err != nil {
|
||||
log.Error("Cannot parse int: ", err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (p *Preferences) Set(key, value string) {
|
||||
p.lock.Lock()
|
||||
p.cache[key] = value
|
||||
p.lock.Unlock()
|
||||
|
||||
if err := p.save(); err != nil {
|
||||
log.Warn("Cannot save preferences: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Preferences) SetBool(key string, value bool) {
|
||||
if value {
|
||||
p.Set(key, "true")
|
||||
} else {
|
||||
p.Set(key, "false")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Preferences) SetInt(key string, value int) {
|
||||
p.Set(key, strconv.Itoa(value))
|
||||
}
|
||||
109
pkg/config/preferences_test.go
Normal file
109
pkg/config/preferences_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testPrefFilePath = "/tmp/pref.json"
|
||||
|
||||
func shutdownTestPreferences() {
|
||||
_ = os.RemoveAll(testPrefFilePath)
|
||||
}
|
||||
|
||||
func TestLoadNoPreferences(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
require.Equal(t, "", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestLoadBadPreferences(t *testing.T) {
|
||||
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"key\":\"value"), 0700))
|
||||
pref := NewPreferences(testPrefFilePath)
|
||||
require.Equal(t, "", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestPreferencesGet(t *testing.T) {
|
||||
pref := newTestPreferences(t)
|
||||
require.Equal(t, "value", pref.Get("str"))
|
||||
require.Equal(t, "42", pref.Get("int"))
|
||||
require.Equal(t, "true", pref.Get("bool"))
|
||||
require.Equal(t, "t", pref.Get("falseBool"))
|
||||
}
|
||||
|
||||
func TestPreferencesGetInt(t *testing.T) {
|
||||
pref := newTestPreferences(t)
|
||||
require.Equal(t, 0, pref.GetInt("str"))
|
||||
require.Equal(t, 42, pref.GetInt("int"))
|
||||
require.Equal(t, 0, pref.GetInt("bool"))
|
||||
require.Equal(t, 0, pref.GetInt("falseBool"))
|
||||
}
|
||||
|
||||
func TestPreferencesGetBool(t *testing.T) {
|
||||
pref := newTestPreferences(t)
|
||||
require.Equal(t, false, pref.GetBool("str"))
|
||||
require.Equal(t, false, pref.GetBool("int"))
|
||||
require.Equal(t, true, pref.GetBool("bool"))
|
||||
require.Equal(t, false, pref.GetBool("falseBool"))
|
||||
}
|
||||
|
||||
func TestPreferencesSetDefault(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.SetDefault("key", "value")
|
||||
pref.SetDefault("key", "othervalue")
|
||||
require.Equal(t, "value", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestPreferencesSet(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.Set("str", "value")
|
||||
checkSavedPreferences(t, "{\"str\":\"value\"}")
|
||||
}
|
||||
|
||||
func TestPreferencesSetInt(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.SetInt("int", 42)
|
||||
checkSavedPreferences(t, "{\"int\":\"42\"}")
|
||||
}
|
||||
|
||||
func TestPreferencesSetBool(t *testing.T) {
|
||||
pref := newTestEmptyPreferences(t)
|
||||
pref.SetBool("trueBool", true)
|
||||
pref.SetBool("falseBool", false)
|
||||
checkSavedPreferences(t, "{\"falseBool\":\"false\",\"trueBool\":\"true\"}")
|
||||
}
|
||||
|
||||
func newTestEmptyPreferences(t *testing.T) *Preferences {
|
||||
require.NoError(t, os.RemoveAll(testPrefFilePath))
|
||||
return NewPreferences(testPrefFilePath)
|
||||
}
|
||||
|
||||
func newTestPreferences(t *testing.T) *Preferences {
|
||||
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"str\":\"value\",\"int\":\"42\",\"bool\":\"true\",\"falseBool\":\"t\"}"), 0700))
|
||||
return NewPreferences(testPrefFilePath)
|
||||
}
|
||||
|
||||
func checkSavedPreferences(t *testing.T, expected string) {
|
||||
data, err := ioutil.ReadFile(testPrefFilePath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected+"\n", string(data))
|
||||
}
|
||||
170
pkg/config/tls.go
Normal file
170
pkg/config/tls.go
Normal file
@ -0,0 +1,170 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/osext"
|
||||
)
|
||||
|
||||
type tlsConfiger interface {
|
||||
GetTLSCertPath() string
|
||||
GetTLSKeyPath() string
|
||||
}
|
||||
|
||||
var tlsTemplate = x509.Certificate{ //nolint[gochecknoglobals]
|
||||
SerialNumber: big.NewInt(-1),
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"CH"},
|
||||
Organization: []string{"Proton Technologies AG"},
|
||||
OrganizationalUnit: []string{"ProtonMail"},
|
||||
CommonName: "127.0.0.1",
|
||||
},
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour),
|
||||
}
|
||||
|
||||
var ErrTLSCertExpireSoon = fmt.Errorf("TLS certificate will expire soon")
|
||||
|
||||
// GetTLSConfig tries to load TLS config or generate new one which is then returned.
|
||||
func GetTLSConfig(cfg tlsConfiger) (tlsConfig *tls.Config, err error) {
|
||||
certPath := cfg.GetTLSCertPath()
|
||||
keyPath := cfg.GetTLSKeyPath()
|
||||
tlsConfig, err = loadTLSConfig(certPath, keyPath)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Cannot load cert, generating a new one")
|
||||
tlsConfig, err = generateTLSConfig(certPath, keyPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
// If this fails, log the error but continue to load.
|
||||
if p, err := osext.Executable(); err == nil {
|
||||
p = strings.TrimSuffix(p, "MacOS/Desktop-Bridge") // This needs to match the executable name.
|
||||
p += "Resources/addcert.scpt"
|
||||
if err := exec.Command("/usr/bin/osascript", p).Run(); err != nil { // nolint[gosec]
|
||||
log.WithError(err).Error("Failed to add cert to system keychain")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig.ServerName = "127.0.0.1"
|
||||
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(tlsConfig.Certificates[0].Leaf)
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
|
||||
/* This is deprecated:
|
||||
* SA1019: tlsConfig.BuildNameToCertificate is deprecated:
|
||||
* NameToCertificate only allows associating a single certificate with a given name.
|
||||
* Leave that field nil to let the library select the first compatible chain from Certificates.
|
||||
*/
|
||||
tlsConfig.BuildNameToCertificate() // nolint[staticcheck]
|
||||
|
||||
return tlsConfig, err
|
||||
}
|
||||
|
||||
func loadTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
|
||||
c, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Leaf, err = x509.ParseCertificate(c.Certificate[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{c},
|
||||
}
|
||||
|
||||
if time.Now().Add(31 * 24 * time.Hour).After(c.Leaf.NotAfter) {
|
||||
err = ErrTLSCertExpireSoon
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// See https://golang.org/src/crypto/tls/generate_cert.go
|
||||
func generateTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to generate private key: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to generate serial number: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
tlsTemplate.SerialNumber = serialNumber
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &tlsTemplate, &tlsTemplate, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to create certificate: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer certOut.Close() //nolint[errcheck]
|
||||
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer keyOut.Close() //nolint[errcheck]
|
||||
err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return loadTLSConfig(certPath, keyPath)
|
||||
}
|
||||
63
pkg/config/tls_test.go
Normal file
63
pkg/config/tls_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testTLSConfig struct{ certPath, keyPath string }
|
||||
|
||||
func (c *testTLSConfig) GetTLSCertPath() string { return c.certPath }
|
||||
func (c *testTLSConfig) GetTLSKeyPath() string { return c.keyPath }
|
||||
|
||||
func TestTLSKeyRenewal(t *testing.T) {
|
||||
// Remove keys.
|
||||
configPath := "/tmp"
|
||||
certPath := filepath.Join(configPath, "cert.pem")
|
||||
keyPath := filepath.Join(configPath, "key.pem")
|
||||
_ = os.Remove(certPath)
|
||||
_ = os.Remove(keyPath)
|
||||
|
||||
// Put old key there.
|
||||
tlsTemplate.NotBefore = time.Now().Add(-365 * 24 * time.Hour)
|
||||
tlsTemplate.NotAfter = time.Now()
|
||||
cert, err := generateTLSConfig(certPath, keyPath)
|
||||
require.Equal(t, err, ErrTLSCertExpireSoon)
|
||||
require.Equal(t, len(cert.Certificates), 1)
|
||||
time.Sleep(time.Second)
|
||||
now, notValidAfter := time.Now(), cert.Certificates[0].Leaf.NotAfter
|
||||
require.True(t, now.After(notValidAfter), "old certificate expected to not be valid at %v but have valid until %v", now, notValidAfter)
|
||||
|
||||
// Renew key.
|
||||
tlsTemplate.NotBefore = time.Now()
|
||||
tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour)
|
||||
cert, err = GetTLSConfig(&testTLSConfig{certPath, keyPath})
|
||||
if runtime.GOOS != "darwin" { // Darwin is not supported.
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, len(cert.Certificates), 1)
|
||||
now, notValidAfter = time.Now(), cert.Certificates[0].Leaf.NotAfter
|
||||
require.False(t, now.After(notValidAfter), "new certificate expected to be valid at %v but have valid until %v", now, notValidAfter)
|
||||
}
|
||||
Reference in New Issue
Block a user