We build too many walls and not enough bridges

This commit is contained in:
Jakub
2020-04-08 12:59:16 +02:00
commit 17f4d6097a
494 changed files with 62753 additions and 0 deletions

259
pkg/config/config.go Normal file
View File

@ -0,0 +1,259 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/ProtonMail/go-appdir"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/hashicorp/go-multierror"
)
var (
log = GetLogEntry("config") //nolint[gochecknoglobals]
)
type appDirProvider interface {
UserConfig() string
UserCache() string
UserLogs() string
}
type Config struct {
appName string
version string
revision string
cacheVersion string
appDirs appDirProvider
appDirsVersion appDirProvider
apiConfig *pmapi.ClientConfig
}
// New returns fully initialized config struct.
// `appName` should be in camelCase format for folder or file names. It's also used in API
// as `AppVersion` which is converted to CamelCase.
// `version` is the version of the app (e.g. v1.2.3).
// `cacheVersion` is the version of the cache files (setting a different number will remove the old ones).
func New(appName, version, revision, cacheVersion string) *Config {
appDirs := appdir.New(filepath.Join("protonmail", appName))
appDirsVersion := appdir.New(filepath.Join("protonmail", appName, cacheVersion))
return newConfig(appName, version, revision, cacheVersion, appDirs, appDirsVersion)
}
func newConfig(appName, version, revision, cacheVersion string, appDirs, appDirsVersion appDirProvider) *Config {
return &Config{
appName: appName,
version: version,
revision: revision,
cacheVersion: cacheVersion,
appDirs: appDirs,
appDirsVersion: appDirsVersion,
apiConfig: &pmapi.ClientConfig{
AppVersion: strings.Title(appName) + "_" + version,
ClientID: appName,
Transport: &http.Transport{
DialContext: (&net.Dialer{Timeout: 3 * time.Second}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
},
// TokenManager should not be required, but PMAPI still doesn't handle not-set cases everywhere.
TokenManager: pmapi.NewTokenManager(),
},
}
}
// CreateDirs creates all folders that are necessary for bridge to properly function.
func (c *Config) CreateDirs() error {
// Log files.
if err := os.MkdirAll(c.appDirs.UserLogs(), 0700); err != nil {
return err
}
// TLS files.
if err := os.MkdirAll(c.appDirs.UserConfig(), 0750); err != nil {
return err
}
// Lock, events, preferences, user_info, db files.
if err := os.MkdirAll(c.appDirsVersion.UserCache(), 0750); err != nil {
return err
}
return nil
}
// ClearData removes all files except the lock file.
// The lock file will be removed when the Bridge stops.
func (c *Config) ClearData() error {
dirs := []string{
c.appDirs.UserLogs(),
c.appDirs.UserConfig(),
c.appDirs.UserCache(),
}
shouldRemove := func(filePath string) bool {
return filePath != c.GetLockPath()
}
return c.removeAllExcept(dirs, shouldRemove)
}
// ClearOldData removes all old files, such as old log files or old versions of cache and so on.
func (c *Config) ClearOldData() error {
// `appDirs` is parent for `appDirsVersion`.
// `dir` then contains all subfolders and only `cacheVersion` should stay.
// But on Windows all files (dirs) are in the same one - we cannot remove log, lock or tls files.
dir := c.appDirs.UserCache()
return c.removeExcept(dir, func(filePath string) bool {
fileName := filepath.Base(filePath)
return (fileName != c.cacheVersion &&
!logFileRgx.MatchString(fileName) &&
filePath != c.GetTLSCertPath() &&
filePath != c.GetTLSKeyPath() &&
filePath != c.GetEventsPath() &&
filePath != c.GetIMAPCachePath() &&
filePath != c.GetLockPath() &&
filePath != c.GetPreferencesPath())
})
}
func (c *Config) removeAllExcept(dirs []string, shouldRemove func(string) bool) error {
var result *multierror.Error
for _, dir := range dirs {
if err := c.removeExcept(dir, shouldRemove); err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}
func (c *Config) removeExcept(dir string, shouldRemove func(string) bool) error {
files, err := ioutil.ReadDir(dir)
if err != nil {
return err
}
var result *multierror.Error
for _, file := range files {
filePath := filepath.Join(dir, file.Name())
if !shouldRemove(filePath) {
continue
}
if !file.IsDir() {
if err := os.RemoveAll(filePath); err != nil {
result = multierror.Append(result, err)
}
continue
}
subDir := filepath.Join(dir, file.Name())
if err := c.removeExcept(subDir, shouldRemove); err != nil {
result = multierror.Append(result, err)
} else {
// Remove dir itself only if it's empty.
subFiles, err := ioutil.ReadDir(subDir)
if err != nil {
result = multierror.Append(result, err)
} else if len(subFiles) == 0 {
if err := os.RemoveAll(subDir); err != nil {
result = multierror.Append(result, err)
}
}
}
}
return result.ErrorOrNil()
}
// IsDevMode should be used for development conditions such us whether to send sentry reports.
func (c *Config) IsDevMode() bool {
return os.Getenv("PROTONMAIL_ENV") == "dev"
}
// GetLogDir returns folder for log files.
func (c *Config) GetLogDir() string {
return c.appDirs.UserLogs()
}
// GetLogPrefix returns prefix for log files. Bridge uses format vVERSION.
func (c *Config) GetLogPrefix() string {
return "v" + c.version + "_" + c.revision
}
// GetTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP and API).
func (c *Config) GetTLSCertPath() string {
return filepath.Join(c.appDirs.UserConfig(), "cert.pem")
}
// GetTLSKeyPath returns path to private key; used for TLS servers (IMAP, SMTP and API).
func (c *Config) GetTLSKeyPath() string {
return filepath.Join(c.appDirs.UserConfig(), "key.pem")
}
// GetDBDir returns folder for db files.
func (c *Config) GetDBDir() string {
return filepath.Join(c.appDirsVersion.UserCache())
}
// GetEventsPath returns path to events file containing the last processed event IDs.
func (c *Config) GetEventsPath() string {
return filepath.Join(c.appDirsVersion.UserCache(), "events.json")
}
// GetIMAPCachePath returns path to file with IMAP status.
func (c *Config) GetIMAPCachePath() string {
return filepath.Join(c.appDirsVersion.UserCache(), "user_info.json")
}
// GetLockPath returns path to lock file to check if bridge is already running.
func (c *Config) GetLockPath() string {
return filepath.Join(c.appDirsVersion.UserCache(), c.appName+".lock")
}
// GetUpdateDir returns folder for update files; such as new binary.
func (c *Config) GetUpdateDir() string {
return filepath.Join(c.appDirsVersion.UserCache(), "updates")
}
// GetPreferencesPath returns path to preference file.
func (c *Config) GetPreferencesPath() string {
return filepath.Join(c.appDirsVersion.UserCache(), "prefs.json")
}
// GetAPIConfig returns config for ProtonMail API.
func (c *Config) GetAPIConfig() *pmapi.ClientConfig {
return c.apiConfig
}
// GetDefaultAPIPort returns default Bridge local API port.
func (c *Config) GetDefaultAPIPort() int {
return 1042
}
// GetDefaultIMAPPort returns default Bridge IMAP port.
func (c *Config) GetDefaultIMAPPort() int {
return 1143
}
// GetDefaultSMTPPort returns default Bridge SMTP port.
func (c *Config) GetDefaultSMTPPort() int {
return 1025
}

238
pkg/config/config_test.go Normal file
View File

@ -0,0 +1,238 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
const testAppName = "bridge-test"
var testConfigDir string //nolint[gochecknoglobals]
func TestMain(m *testing.M) {
setupTestConfig()
setupTestLogs()
code := m.Run()
shutdownTestConfig()
shutdownTestLogs()
shutdownTestPreferences()
os.Exit(code)
}
func setupTestConfig() {
var err error
testConfigDir, err = ioutil.TempDir("", "config")
if err != nil {
panic(err)
}
}
func shutdownTestConfig() {
_ = os.RemoveAll(testConfigDir)
}
type mocks struct {
t *testing.T
ctrl *gomock.Controller
appDir *MockappDirer
appDirVersion *MockappDirer
}
func initMocks(t *testing.T) mocks {
mockCtrl := gomock.NewController(t)
return mocks{
t: t,
ctrl: mockCtrl,
appDir: NewMockappDirer(mockCtrl),
appDirVersion: NewMockappDirer(mockCtrl),
}
}
func TestClearDataLinux(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureLinux(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"config",
"logs",
})
}
func TestClearDataWindows(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureWindows(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"config",
})
}
// OldData touches only cache folder.
// Removes only c1 folder as nothing else is part of cache folder on Linux/Mac.
func TestClearOldDataLinux(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureLinux(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearOldData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"config",
"config/cert.pem",
"config/key.pem",
"logs",
"logs/other.log",
"logs/v1_10.log",
"logs/v1_11.log",
"logs/v2_12.log",
"logs/v2_13.log",
})
}
// OldData touches only cache folder. Removes everything except c2 folder
// and bridge log files which are part of cache folder on Windows.
func TestClearOldDataWindows(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
createTestStructureWindows(m, testConfigDir)
cfg := newConfig(testAppName, "v1", "rev123", "c2", m.appDir, m.appDirVersion)
require.NoError(t, cfg.ClearOldData())
checkFileNames(t, testConfigDir, []string{
"cache",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"cache/v1_10.log",
"cache/v1_11.log",
"cache/v2_12.log",
"cache/v2_13.log",
"config",
"config/cert.pem",
"config/key.pem",
})
}
func createTestStructureLinux(m mocks, baseDir string) {
logsDir := filepath.Join(baseDir, "logs")
configDir := filepath.Join(baseDir, "config")
cacheDir := filepath.Join(baseDir, "cache")
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
}
func createTestStructureWindows(m mocks, baseDir string) {
logsDir := filepath.Join(baseDir, "cache")
configDir := filepath.Join(baseDir, "config")
cacheDir := filepath.Join(baseDir, "cache")
versionedOldCacheDir := filepath.Join(baseDir, "cache", "c1")
versionedCacheDir := filepath.Join(baseDir, "cache", "c2")
createTestStructure(m, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir)
}
func createTestStructure(m mocks, baseDir, logsDir, configDir, cacheDir, versionedOldCacheDir, versionedCacheDir string) {
m.appDir.EXPECT().UserLogs().Return(logsDir).AnyTimes()
m.appDir.EXPECT().UserConfig().Return(configDir).AnyTimes()
m.appDir.EXPECT().UserCache().Return(cacheDir).AnyTimes()
m.appDirVersion.EXPECT().UserCache().Return(versionedCacheDir).AnyTimes()
require.NoError(m.t, os.RemoveAll(baseDir))
require.NoError(m.t, os.MkdirAll(baseDir, 0700))
require.NoError(m.t, os.MkdirAll(logsDir, 0700))
require.NoError(m.t, os.MkdirAll(configDir, 0700))
require.NoError(m.t, os.MkdirAll(cacheDir, 0700))
require.NoError(m.t, os.MkdirAll(versionedOldCacheDir, 0700))
require.NoError(m.t, os.MkdirAll(versionedCacheDir, 0700))
require.NoError(m.t, os.MkdirAll(filepath.Join(versionedCacheDir, "updates"), 0700))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "other.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_10.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v1_11.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_12.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(logsDir, "v2_13.log"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "cert.pem"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(configDir, "key.pem"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "prefs.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "events.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "user_info.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedOldCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "prefs.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "events.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "user_info.json"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, testAppName+".lock"), []byte("Hello"), 0755))
require.NoError(m.t, ioutil.WriteFile(filepath.Join(versionedCacheDir, "mailbox-user@pm.me.db"), []byte("Hello"), 0755))
}
func checkFileNames(t *testing.T, dir string, expectedFileNames []string) {
fileNames := getFileNames(t, dir)
require.Equal(t, expectedFileNames, fileNames)
}
func getFileNames(t *testing.T, dir string) []string {
files, err := ioutil.ReadDir(dir)
require.NoError(t, err)
fileNames := []string{}
for _, file := range files {
fileNames = append(fileNames, file.Name())
if file.IsDir() {
subDir := filepath.Join(dir, file.Name())
subFileNames := getFileNames(t, subDir)
for _, subFileName := range subFileNames {
fileNames = append(fileNames, file.Name()+"/"+subFileName)
}
}
}
return fileNames
}

252
pkg/config/logs.go Normal file
View File

@ -0,0 +1,252 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"runtime"
"runtime/pprof"
"sort"
"strconv"
"time"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/sirupsen/logrus"
)
type logConfiger interface {
GetLogDir() string
GetLogPrefix() string
}
const (
// Zendesk now has a file size limit of 20MB. When the last N log files
// are zipped, it should fit under 20MB. Value in MB (average file has
// few hundreds kB).
maxLogFileSize = 10 * 1024 * 1024 //nolint[gochecknoglobals]
// Including the current logfile.
maxNumberLogFiles = 3 //nolint[gochecknoglobals]
)
// logFile is pointer to currently open file used by logrus.
var logFile *os.File //nolint[gochecknoglobals]
var logFileRgx = regexp.MustCompile("^v.*\\.log$") //nolint[gochecknoglobals]
var logCrashRgx = regexp.MustCompile("^v.*_crash_.*\\.log$") //nolint[gochecknoglobals]
// GetLogEntry returns logrus.Entry with PID and `packageName`.
func GetLogEntry(packageName string) *logrus.Entry {
return logrus.WithFields(logrus.Fields{
"pkg": packageName,
})
}
// HandlePanic reports the crash to sentry or local file when sentry fails.
func HandlePanic(cfg *Config, output string) {
if !cfg.IsDevMode() {
c := pmapi.NewClient(cfg.GetAPIConfig(), "no-user-id")
err := c.ReportSentryCrash(fmt.Errorf(output))
if err != nil {
log.Error("Sentry crash report failed: ", err)
}
}
filename := getLogFilename(cfg.GetLogPrefix() + "_crash_")
filepath := filepath.Join(cfg.GetLogDir(), filename)
f, err := os.OpenFile(filepath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
if err != nil {
log.Error("Cannot open file to write crash report: ", err)
return
}
_, _ = f.WriteString(output)
_ = pprof.Lookup("goroutine").WriteTo(f, 2)
log.Warn("Crash report saved to ", filepath)
}
// GetGID returns goroutine number which can be used to distiguish logs from
// the concurent processes. Keep in mind that it returns the number of routine
// which executes the function.
func GetGID() uint64 {
b := make([]byte, 64)
b = b[:runtime.Stack(b, false)]
b = bytes.TrimPrefix(b, []byte("goroutine "))
b = b[:bytes.IndexByte(b, ' ')]
n, _ := strconv.ParseUint(string(b), 10, 64)
return n
}
// SetupLog set up log level, formatter and output (file or stdout).
// Returns whether should be used debug for IMAP and SMTP servers.
func SetupLog(cfg logConfiger, levelFlag string) (debugClient, debugServer bool) {
level, useFile := getLogLevelAndFile(levelFlag)
logrus.SetLevel(level)
if useFile {
logrus.SetFormatter(&logrus.JSONFormatter{})
setLogFile(cfg.GetLogDir(), cfg.GetLogPrefix())
watchLogFileSize(cfg.GetLogDir(), cfg.GetLogPrefix())
} else {
logrus.SetFormatter(&logrus.TextFormatter{
ForceColors: true,
FullTimestamp: true,
TimestampFormat: time.StampMilli,
})
logrus.SetOutput(os.Stdout)
}
switch levelFlag {
case "debug-client", "debug-client-json":
debugClient = true
case "debug-server", "debug-server-json", "trace":
fmt.Println("THE LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
log.Warning("================================================")
log.Warning("THIS LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
log.Warning("================================================")
debugClient = true
debugServer = true
}
return debugClient, debugServer
}
func setLogFile(logDir, logPrefix string) {
if logFile != nil {
return
}
filename := getLogFilename(logPrefix)
var err error
logFile, err = os.OpenFile(filepath.Join(logDir, filename), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
if err != nil {
panic(err)
}
logrus.SetOutput(logFile)
// Users sometimes change the name of the log file. We want to always log
// information about bridge version (included in log prefix) and OS.
log.Warn("Bridge version: ", logPrefix, " ", runtime.GOOS)
}
func getLogFilename(logPrefix string) string {
currentTime := strconv.Itoa(int(time.Now().Unix()))
return logPrefix + "_" + currentTime + ".log"
}
func watchLogFileSize(logDir, logPrefix string) {
go func() {
for {
time.Sleep(60 * time.Second)
checkLogFileSize(logDir, logPrefix)
}
}()
}
func checkLogFileSize(logDir, logPrefix string) {
if logFile == nil {
return
}
stat, err := logFile.Stat()
if err != nil {
log.Error("Log file size check failed: ", err)
return
}
if stat.Size() >= maxLogFileSize {
log.Warn("Current log file ", logFile.Name(), " is too big, opening new file")
closeLogFile()
setLogFile(logDir, logPrefix)
}
if err := clearLogs(logDir); err != nil {
log.Error("Cannot clear logs ", err)
}
}
func closeLogFile() {
if logFile != nil {
_ = logFile.Close()
logFile = nil
}
}
func clearLogs(logDir string) error {
files, err := ioutil.ReadDir(logDir)
if err != nil {
return err
}
var logsWithPrefix []string
var crashesWithPrefix []string
for _, file := range files {
if logFileRgx.MatchString(file.Name()) {
if logCrashRgx.MatchString(file.Name()) {
crashesWithPrefix = append(crashesWithPrefix, file.Name())
} else {
logsWithPrefix = append(logsWithPrefix, file.Name())
}
} else {
// Older versions of Bridge stored logs in subfolders for each version.
// That also has to be cleared and the functionality can be removed after some time.
if file.IsDir() {
if err := clearLogs(filepath.Join(logDir, file.Name())); err != nil {
return err
}
} else {
removeLog(logDir, file.Name())
}
}
}
removeOldLogs(logDir, logsWithPrefix)
removeOldLogs(logDir, crashesWithPrefix)
return nil
}
func removeOldLogs(logDir string, filenames []string) {
count := len(filenames)
if count <= maxNumberLogFiles {
return
}
sort.Strings(filenames) // Sorted by timestamp: oldest first.
for _, filename := range filenames[:count-maxNumberLogFiles] {
removeLog(logDir, filename)
}
}
func removeLog(logDir, filename string) {
// We need to be sure to delete only log files.
// Directory with logs can also contain other files.
if !logFileRgx.MatchString(filename) {
return
}
if err := os.RemoveAll(filepath.Join(logDir, filename)); err != nil {
log.Error("Cannot remove old logs ", err)
}
}

49
pkg/config/logs_all.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build !build_qa
package config
import (
"github.com/sirupsen/logrus"
)
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
useFile = true
switch levelFlag {
case "panic":
level = logrus.PanicLevel
case "fatal":
level = logrus.FatalLevel
case "error":
level = logrus.ErrorLevel
case "warn":
level = logrus.WarnLevel
case "info":
level = logrus.InfoLevel
case "debug", "debug-client", "debug-server", "debug-client-json", "debug-server-json":
level = logrus.DebugLevel
useFile = false
case "trace":
level = logrus.TraceLevel
useFile = false
default:
level = logrus.InfoLevel
}
return
}

50
pkg/config/logs_qa.go Normal file
View File

@ -0,0 +1,50 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// +build build_qa
package config
import (
"github.com/sirupsen/logrus"
)
// getLogLevelAndFile for QA build is altered in a way even decrypted data are stored
// in the log file when forced with `debug-client-json` or `debug-server-json`.
func getLogLevelAndFile(levelFlag string) (level logrus.Level, useFile bool) {
useFile = true
switch levelFlag {
case "panic":
level = logrus.PanicLevel
case "fatal":
level = logrus.FatalLevel
case "error":
level = logrus.ErrorLevel
case "warn":
level = logrus.WarnLevel
case "info":
level = logrus.InfoLevel
case "debug-client-json", "debug-server-json":
level = logrus.DebugLevel
case "debug", "debug-client", "debug-server":
level = logrus.DebugLevel
useFile = false
default:
level = logrus.InfoLevel
}
return
}

225
pkg/config/logs_test.go Normal file
View File

@ -0,0 +1,225 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
type testLogConfig struct{ logDir, logPrefix string }
func (c *testLogConfig) GetLogDir() string { return c.logDir }
func (c *testLogConfig) GetLogPrefix() string { return c.logPrefix }
var testLogDir string //nolint[gochecknoglobals]
func setupTestLogs() {
var err error
testLogDir, err = ioutil.TempDir("", "log")
if err != nil {
panic(err)
}
}
func shutdownTestLogs() {
_ = os.RemoveAll(testLogDir)
}
func TestLogNameLength(t *testing.T) {
cfg := New("bridge-test", "longVersion123", "longRevision1234567890", "c2")
name := getLogFilename(cfg.GetLogPrefix())
if len(name) > 128 {
t.Fatal("Name of the log is too long - limit for encrypted linux is 128 characters")
}
}
// Info and higher levels writes to the file.
func TestSetupLogInfo(t *testing.T) {
dir := beforeEachCreateTestDir(t, "setupInfo")
SetupLog(&testLogConfig{dir, "v"}, "info")
require.Equal(t, "info", logrus.GetLevel().String())
logrus.Info("test message")
files := checkLogFiles(t, dir, 1)
checkLogContains(t, dir, files[0].Name(), "test message")
}
// Debug levels writes to stdout.
func TestSetupLogDebug(t *testing.T) {
dir := beforeEachCreateTestDir(t, "setupDebug")
SetupLog(&testLogConfig{dir, "v"}, "debug")
require.Equal(t, "debug", logrus.GetLevel().String())
logrus.Info("test message")
checkLogFiles(t, dir, 0)
}
func TestReopenLogFile(t *testing.T) {
dir := beforeEachCreateTestDir(t, "reopenLogFile")
setLogFile(dir, "v1")
done := make(chan interface{})
log.Info("first message")
go func() {
<-done // Wait for closing file and opening new one.
log.Info("second message")
done <- nil
}()
closeLogFile()
setLogFile(dir, "v2")
done <- nil
<-done // Wait for second log message.
files := checkLogFiles(t, dir, 2)
checkLogContains(t, dir, files[0].Name(), "first message")
checkLogContains(t, dir, files[1].Name(), "second message")
}
func TestCheckLogFileSizeSmall(t *testing.T) {
dir := beforeEachCreateTestDir(t, "logFileSizeSmall")
setLogFile(dir, "v1")
originalFileName := logFile.Name()
_, _ = logFile.WriteString("small file")
checkLogFileSize(dir, "v2")
require.Equal(t, originalFileName, logFile.Name())
}
func TestCheckLogFileSizeBig(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
dir := beforeEachCreateTestDir(t, "logFileSizeBig")
setLogFile(dir, "v1")
originalFileName := logFile.Name()
// The limit for big file is 10*1024*1024 - keep the string 10 letters long.
for i := 0; i < 1024*1024; i++ {
_, _ = logFile.WriteString("big file!\n")
}
checkLogFileSize(dir, "v2")
require.NotEqual(t, originalFileName, logFile.Name())
}
// ClearLogs removes only bridge old log files keeping last three of them.
func TestClearLogsLinux(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
dir := beforeEachCreateTestDir(t, "clearLogs")
createTestStructureLinux(m, dir)
require.NoError(t, clearLogs(dir))
checkFileNames(t, dir, []string{
"cache",
"cache/c1",
"cache/c1/events.json",
"cache/c1/mailbox-user@pm.me.db",
"cache/c1/prefs.json",
"cache/c1/user_info.json",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"config",
"config/cert.pem",
"config/key.pem",
"logs",
"logs/other.log",
"logs/v1_11.log",
"logs/v2_12.log",
"logs/v2_13.log",
})
}
// ClearLogs removes only bridge old log files even when log folder
// is shared with other files on Windows.
func TestClearLogsWindows(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
dir := beforeEachCreateTestDir(t, "clearLogs")
createTestStructureWindows(m, dir)
require.NoError(t, clearLogs(dir))
checkFileNames(t, dir, []string{
"cache",
"cache/c1",
"cache/c1/events.json",
"cache/c1/mailbox-user@pm.me.db",
"cache/c1/prefs.json",
"cache/c1/user_info.json",
"cache/c2",
"cache/c2/bridge-test.lock",
"cache/c2/events.json",
"cache/c2/mailbox-user@pm.me.db",
"cache/c2/prefs.json",
"cache/c2/updates",
"cache/c2/user_info.json",
"cache/other.log",
"cache/v1_11.log",
"cache/v2_12.log",
"cache/v2_13.log",
"config",
"config/cert.pem",
"config/key.pem",
})
}
func beforeEachCreateTestDir(t *testing.T, dir string) string {
// Make sure opened file (from the previous test) is cleared.
closeLogFile()
dir = filepath.Join(testLogDir, dir)
require.NoError(t, os.MkdirAll(dir, 0700))
return dir
}
func checkLogFiles(t *testing.T, dir string, expectedCount int) []os.FileInfo {
files, err := ioutil.ReadDir(dir)
require.NoError(t, err)
require.Equal(t, expectedCount, len(files))
return files
}
func checkLogContains(t *testing.T, dir, fileName, expectedSubstr string) {
data, err := ioutil.ReadFile(filepath.Join(dir, fileName)) //nolint[gosec]
require.NoError(t, err)
require.Contains(t, string(data), expectedSubstr)
}

76
pkg/config/mock_config.go Normal file
View File

@ -0,0 +1,76 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: config/config.go
// Package config is a generated GoMock package.
package config
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockappDirer is a mock of appDirer interface
type MockappDirer struct {
ctrl *gomock.Controller
recorder *MockappDirerMockRecorder
}
// MockappDirerMockRecorder is the mock recorder for MockappDirer
type MockappDirerMockRecorder struct {
mock *MockappDirer
}
// NewMockappDirer creates a new mock instance
func NewMockappDirer(ctrl *gomock.Controller) *MockappDirer {
mock := &MockappDirer{ctrl: ctrl}
mock.recorder = &MockappDirerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockappDirer) EXPECT() *MockappDirerMockRecorder {
return m.recorder
}
// UserConfig mocks base method
func (m *MockappDirer) UserConfig() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserConfig")
ret0, _ := ret[0].(string)
return ret0
}
// UserConfig indicates an expected call of UserConfig
func (mr *MockappDirerMockRecorder) UserConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserConfig", reflect.TypeOf((*MockappDirer)(nil).UserConfig))
}
// UserCache mocks base method
func (m *MockappDirer) UserCache() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserCache")
ret0, _ := ret[0].(string)
return ret0
}
// UserCache indicates an expected call of UserCache
func (mr *MockappDirerMockRecorder) UserCache() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCache", reflect.TypeOf((*MockappDirer)(nil).UserCache))
}
// UserLogs mocks base method
func (m *MockappDirer) UserLogs() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserLogs")
ret0, _ := ret[0].(string)
return ret0
}
// UserLogs indicates an expected call of UserLogs
func (mr *MockappDirerMockRecorder) UserLogs() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserLogs", reflect.TypeOf((*MockappDirer)(nil).UserLogs))
}

127
pkg/config/preferences.go Normal file
View File

@ -0,0 +1,127 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"encoding/json"
"errors"
"os"
"strconv"
"sync"
)
type Preferences struct {
cache map[string]string
path string
lock *sync.RWMutex
}
// NewPreferences returns loaded preferences.
func NewPreferences(preferencesPath string) *Preferences {
p := &Preferences{
path: preferencesPath,
lock: &sync.RWMutex{},
}
if err := p.load(); err != nil {
log.Warn("Cannot load preferences: ", err)
}
return p
}
func (p *Preferences) load() error {
if p.cache != nil {
return nil
}
p.lock.Lock()
defer p.lock.Unlock()
p.cache = map[string]string{}
f, err := os.Open(p.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&p.cache)
}
func (p *Preferences) save() error {
if p.cache == nil {
return errors.New("cannot save preferences: cache is nil")
}
p.lock.Lock()
defer p.lock.Unlock()
f, err := os.Create(p.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(p.cache)
}
func (p *Preferences) SetDefault(key, value string) {
if p.Get(key) == "" {
p.Set(key, value)
}
}
func (p *Preferences) Get(key string) string {
p.lock.RLock()
defer p.lock.RUnlock()
return p.cache[key]
}
func (p *Preferences) GetBool(key string) bool {
return p.Get(key) == "true"
}
func (p *Preferences) GetInt(key string) int {
value, err := strconv.Atoi(p.Get(key))
if err != nil {
log.Error("Cannot parse int: ", err)
}
return value
}
func (p *Preferences) Set(key, value string) {
p.lock.Lock()
p.cache[key] = value
p.lock.Unlock()
if err := p.save(); err != nil {
log.Warn("Cannot save preferences: ", err)
}
}
func (p *Preferences) SetBool(key string, value bool) {
if value {
p.Set(key, "true")
} else {
p.Set(key, "false")
}
}
func (p *Preferences) SetInt(key string, value int) {
p.Set(key, strconv.Itoa(value))
}

View File

@ -0,0 +1,109 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"os"
"testing"
"github.com/stretchr/testify/require"
)
const testPrefFilePath = "/tmp/pref.json"
func shutdownTestPreferences() {
_ = os.RemoveAll(testPrefFilePath)
}
func TestLoadNoPreferences(t *testing.T) {
pref := newTestEmptyPreferences(t)
require.Equal(t, "", pref.Get("key"))
}
func TestLoadBadPreferences(t *testing.T) {
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"key\":\"value"), 0700))
pref := NewPreferences(testPrefFilePath)
require.Equal(t, "", pref.Get("key"))
}
func TestPreferencesGet(t *testing.T) {
pref := newTestPreferences(t)
require.Equal(t, "value", pref.Get("str"))
require.Equal(t, "42", pref.Get("int"))
require.Equal(t, "true", pref.Get("bool"))
require.Equal(t, "t", pref.Get("falseBool"))
}
func TestPreferencesGetInt(t *testing.T) {
pref := newTestPreferences(t)
require.Equal(t, 0, pref.GetInt("str"))
require.Equal(t, 42, pref.GetInt("int"))
require.Equal(t, 0, pref.GetInt("bool"))
require.Equal(t, 0, pref.GetInt("falseBool"))
}
func TestPreferencesGetBool(t *testing.T) {
pref := newTestPreferences(t)
require.Equal(t, false, pref.GetBool("str"))
require.Equal(t, false, pref.GetBool("int"))
require.Equal(t, true, pref.GetBool("bool"))
require.Equal(t, false, pref.GetBool("falseBool"))
}
func TestPreferencesSetDefault(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.SetDefault("key", "value")
pref.SetDefault("key", "othervalue")
require.Equal(t, "value", pref.Get("key"))
}
func TestPreferencesSet(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.Set("str", "value")
checkSavedPreferences(t, "{\"str\":\"value\"}")
}
func TestPreferencesSetInt(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.SetInt("int", 42)
checkSavedPreferences(t, "{\"int\":\"42\"}")
}
func TestPreferencesSetBool(t *testing.T) {
pref := newTestEmptyPreferences(t)
pref.SetBool("trueBool", true)
pref.SetBool("falseBool", false)
checkSavedPreferences(t, "{\"falseBool\":\"false\",\"trueBool\":\"true\"}")
}
func newTestEmptyPreferences(t *testing.T) *Preferences {
require.NoError(t, os.RemoveAll(testPrefFilePath))
return NewPreferences(testPrefFilePath)
}
func newTestPreferences(t *testing.T) *Preferences {
require.NoError(t, ioutil.WriteFile(testPrefFilePath, []byte("{\"str\":\"value\",\"int\":\"42\",\"bool\":\"true\",\"falseBool\":\"t\"}"), 0700))
return NewPreferences(testPrefFilePath)
}
func checkSavedPreferences(t *testing.T, expected string) {
data, err := ioutil.ReadFile(testPrefFilePath)
require.NoError(t, err)
require.Equal(t, expected+"\n", string(data))
}

170
pkg/config/tls.go Normal file
View File

@ -0,0 +1,170 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/kardianos/osext"
)
type tlsConfiger interface {
GetTLSCertPath() string
GetTLSKeyPath() string
}
var tlsTemplate = x509.Certificate{ //nolint[gochecknoglobals]
SerialNumber: big.NewInt(-1),
Subject: pkix.Name{
Country: []string{"CH"},
Organization: []string{"Proton Technologies AG"},
OrganizationalUnit: []string{"ProtonMail"},
CommonName: "127.0.0.1",
},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
IsCA: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
NotBefore: time.Now(),
NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour),
}
var ErrTLSCertExpireSoon = fmt.Errorf("TLS certificate will expire soon")
// GetTLSConfig tries to load TLS config or generate new one which is then returned.
func GetTLSConfig(cfg tlsConfiger) (tlsConfig *tls.Config, err error) {
certPath := cfg.GetTLSCertPath()
keyPath := cfg.GetTLSKeyPath()
tlsConfig, err = loadTLSConfig(certPath, keyPath)
if err != nil {
log.WithError(err).Warn("Cannot load cert, generating a new one")
tlsConfig, err = generateTLSConfig(certPath, keyPath)
if err != nil {
return
}
if runtime.GOOS == "darwin" {
// If this fails, log the error but continue to load.
if p, err := osext.Executable(); err == nil {
p = strings.TrimSuffix(p, "MacOS/Desktop-Bridge") // This needs to match the executable name.
p += "Resources/addcert.scpt"
if err := exec.Command("/usr/bin/osascript", p).Run(); err != nil { // nolint[gosec]
log.WithError(err).Error("Failed to add cert to system keychain")
}
}
}
}
tlsConfig.ServerName = "127.0.0.1"
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
caCertPool := x509.NewCertPool()
caCertPool.AddCert(tlsConfig.Certificates[0].Leaf)
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
/* This is deprecated:
* SA1019: tlsConfig.BuildNameToCertificate is deprecated:
* NameToCertificate only allows associating a single certificate with a given name.
* Leave that field nil to let the library select the first compatible chain from Certificates.
*/
tlsConfig.BuildNameToCertificate() // nolint[staticcheck]
return tlsConfig, err
}
func loadTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
c, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return
}
c.Leaf, err = x509.ParseCertificate(c.Certificate[0])
if err != nil {
return
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{c},
}
if time.Now().Add(31 * 24 * time.Hour).After(c.Leaf.NotAfter) {
err = ErrTLSCertExpireSoon
return
}
return
}
// See https://golang.org/src/crypto/tls/generate_cert.go
func generateTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
err = fmt.Errorf("failed to generate private key: %s", err)
return
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
err = fmt.Errorf("failed to generate serial number: %s", err)
return
}
tlsTemplate.SerialNumber = serialNumber
derBytes, err := x509.CreateCertificate(rand.Reader, &tlsTemplate, &tlsTemplate, &priv.PublicKey, priv)
if err != nil {
err = fmt.Errorf("failed to create certificate: %s", err)
return
}
certOut, err := os.Create(certPath)
if err != nil {
return
}
defer certOut.Close() //nolint[errcheck]
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
if err != nil {
return
}
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return
}
defer keyOut.Close() //nolint[errcheck]
err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
if err != nil {
return
}
return loadTLSConfig(certPath, keyPath)
}

63
pkg/config/tls_test.go Normal file
View File

@ -0,0 +1,63 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type testTLSConfig struct{ certPath, keyPath string }
func (c *testTLSConfig) GetTLSCertPath() string { return c.certPath }
func (c *testTLSConfig) GetTLSKeyPath() string { return c.keyPath }
func TestTLSKeyRenewal(t *testing.T) {
// Remove keys.
configPath := "/tmp"
certPath := filepath.Join(configPath, "cert.pem")
keyPath := filepath.Join(configPath, "key.pem")
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
// Put old key there.
tlsTemplate.NotBefore = time.Now().Add(-365 * 24 * time.Hour)
tlsTemplate.NotAfter = time.Now()
cert, err := generateTLSConfig(certPath, keyPath)
require.Equal(t, err, ErrTLSCertExpireSoon)
require.Equal(t, len(cert.Certificates), 1)
time.Sleep(time.Second)
now, notValidAfter := time.Now(), cert.Certificates[0].Leaf.NotAfter
require.True(t, now.After(notValidAfter), "old certificate expected to not be valid at %v but have valid until %v", now, notValidAfter)
// Renew key.
tlsTemplate.NotBefore = time.Now()
tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour)
cert, err = GetTLSConfig(&testTLSConfig{certPath, keyPath})
if runtime.GOOS != "darwin" { // Darwin is not supported.
require.NoError(t, err)
}
require.Equal(t, len(cert.Certificates), 1)
now, notValidAfter = time.Now(), cert.Certificates[0].Leaf.NotAfter
require.False(t, now.After(notValidAfter), "new certificate expected to be valid at %v but have valid until %v", now, notValidAfter)
}