Other(refactor): Move TLS to Bridge

This commit is contained in:
James Houlahan
2022-08-18 15:34:27 +02:00
committed by Jakub
parent 743a2f8dac
commit 2aaec3b6bd
10 changed files with 101 additions and 68 deletions

View File

@ -19,14 +19,12 @@
package bridge package bridge
import ( import (
"crypto/tls"
"time" "time"
"github.com/ProtonMail/proton-bridge/v2/internal/api" "github.com/ProtonMail/proton-bridge/v2/internal/api"
"github.com/ProtonMail/proton-bridge/v2/internal/app/base" "github.com/ProtonMail/proton-bridge/v2/internal/app/base"
pkgBridge "github.com/ProtonMail/proton-bridge/v2/internal/bridge" pkgBridge "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
pkgTLS "github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend" "github.com/ProtonMail/proton-bridge/v2/internal/frontend"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
@ -52,7 +50,7 @@ const (
) )
func New(base *base.Base) *cli.App { func New(base *base.Base) *cli.App {
app := base.NewApp(mailLoop) app := base.NewApp(main)
app.Flags = append(app.Flags, []cli.Flag{ app.Flags = append(app.Flags, []cli.Flag{
&cli.StringFlag{ &cli.StringFlag{
@ -72,12 +70,7 @@ func New(base *base.Base) *cli.App {
return app return app
} }
func mailLoop(b *base.Base, c *cli.Context) error { //nolint:funlen func main(b *base.Base, c *cli.Context) error { //nolint:funlen
tlsConfig, err := loadTLSConfig(b)
if err != nil {
return err
}
// GODT-1481: Always turn off reporting of unencrypted recipient in v2. // GODT-1481: Always turn off reporting of unencrypted recipient in v2.
b.Settings.SetBool(settings.ReportOutgoingNoEncKey, false) b.Settings.SetBool(settings.ReportOutgoingNoEncKey, false)
@ -98,6 +91,7 @@ func mailLoop(b *base.Base, c *cli.Context) error { //nolint:funlen
b.SentryReporter, b.SentryReporter,
b.CrashHandler, b.CrashHandler,
b.Listener, b.Listener,
b.TLS,
cache, cache,
builder, builder,
b.CM, b.CM,
@ -109,6 +103,11 @@ func mailLoop(b *base.Base, c *cli.Context) error { //nolint:funlen
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge) imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge)
smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge) smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge)
tlsConfig, err := bridge.GetTLSConfig()
if err != nil {
return err
}
if cacheErr != nil { if cacheErr != nil {
bridge.AddError(pkgBridge.ErrLocalCacheUnavailable) bridge.AddError(pkgBridge.ErrLocalCacheUnavailable)
} }
@ -159,7 +158,6 @@ func mailLoop(b *base.Base, c *cli.Context) error { //nolint:funlen
frontendMode, frontendMode,
!c.Bool(base.FlagNoWindow), !c.Bool(base.FlagNoWindow),
b.CrashHandler, b.CrashHandler,
b.TLS,
b.Locations, b.Locations,
b.Settings, b.Settings,
b.Listener, b.Listener,
@ -183,44 +181,6 @@ func mailLoop(b *base.Base, c *cli.Context) error { //nolint:funlen
return f.Loop() return f.Loop()
} }
func loadTLSConfig(b *base.Base) (*tls.Config, error) {
if !b.TLS.HasCerts() {
if err := generateTLSCerts(b); err != nil {
return nil, err
}
}
tlsConfig, err := b.TLS.GetConfig()
if err == nil {
return tlsConfig, nil
}
logrus.WithError(err).Error("Failed to load TLS config, regenerating certificates")
if err := generateTLSCerts(b); err != nil {
return nil, err
}
return b.TLS.GetConfig()
}
func generateTLSCerts(b *base.Base) error {
template, err := pkgTLS.NewTLSTemplate()
if err != nil {
return errors.Wrap(err, "failed to generate TLS template")
}
if err := b.TLS.GenerateCerts(template); err != nil {
return errors.Wrap(err, "failed to generate TLS certs")
}
if err := b.TLS.InstallCerts(); err != nil {
return errors.Wrap(err, "failed to install TLS certs")
}
return nil
}
func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) { func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) {
log := logrus.WithField("pkg", "app/bridge") log := logrus.WithField("pkg", "app/bridge")
version, err := u.Check() version, err := u.Check()

View File

@ -27,6 +27,7 @@ import (
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-autostart" "github.com/ProtonMail/go-autostart"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/metrics" "github.com/ProtonMail/proton-bridge/v2/internal/metrics"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry" "github.com/ProtonMail/proton-bridge/v2/internal/sentry"
@ -52,6 +53,7 @@ type Bridge struct {
clientManager pmapi.Manager clientManager pmapi.Manager
updater Updater updater Updater
versioner Versioner versioner Versioner
tls *tls.TLS
cacheProvider CacheProvider cacheProvider CacheProvider
autostart *autostart.App autostart *autostart.App
// Bridge's global errors list. // Bridge's global errors list.
@ -69,6 +71,7 @@ func New(
sentryReporter *sentry.Reporter, sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler, panicHandler users.PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
tls *tls.TLS,
cache cache.Cache, cache cache.Cache,
builder *message.Builder, builder *message.Builder,
clientManager pmapi.Manager, clientManager pmapi.Manager,
@ -99,6 +102,7 @@ func New(
clientManager: clientManager, clientManager: clientManager,
updater: updater, updater: updater,
versioner: versioner, versioner: versioner,
tls: tls,
cacheProvider: cacheProvider, cacheProvider: cacheProvider,
autostart: autostart, autostart: autostart,
isFirstStart: false, isFirstStart: false,

View File

@ -0,0 +1,64 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"crypto/tls"
pkgTLS "github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/pkg/errors"
logrus "github.com/sirupsen/logrus"
)
func (b *Bridge) GetTLSConfig() (*tls.Config, error) {
if !b.tls.HasCerts() {
if err := b.generateTLSCerts(); err != nil {
return nil, err
}
}
tlsConfig, err := b.tls.GetConfig()
if err == nil {
return tlsConfig, nil
}
logrus.WithError(err).Error("Failed to load TLS config, regenerating certificates")
if err := b.generateTLSCerts(); err != nil {
return nil, err
}
return b.tls.GetConfig()
}
func (b *Bridge) generateTLSCerts() error {
template, err := pkgTLS.NewTLSTemplate()
if err != nil {
return errors.Wrap(err, "failed to generate TLS template")
}
if err := b.tls.GenerateCerts(template); err != nil {
return errors.Wrap(err, "failed to generate TLS certs")
}
if err := b.tls.InstallCerts(); err != nil {
return errors.Wrap(err, "failed to install TLS certs")
}
return nil
}

View File

@ -21,7 +21,6 @@ package frontend
import ( import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
@ -45,7 +44,6 @@ func New(
frontendType string, frontendType string,
showWindowOnStart bool, showWindowOnStart bool,
panicHandler types.PanicHandler, panicHandler types.PanicHandler,
tls *tls.TLS,
locations *locations.Locations, locations *locations.Locations,
settings *settings.Settings, settings *settings.Settings,
eventListener listener.Listener, eventListener listener.Listener,
@ -61,7 +59,6 @@ func New(
return grpc.NewService( return grpc.NewService(
showWindowOnStart, showWindowOnStart,
panicHandler, panicHandler,
tls,
locations, locations,
settings, settings,
eventListener, eventListener,

View File

@ -28,7 +28,6 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
bridgetls "github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
@ -53,7 +52,6 @@ type Service struct { // nolint:structcheck
eventStreamDoneCh chan struct{} eventStreamDoneCh chan struct{}
panicHandler types.PanicHandler panicHandler types.PanicHandler
tls *bridgetls.TLS
locations *locations.Locations locations *locations.Locations
settings *settings.Settings settings *settings.Settings
eventListener listener.Listener eventListener listener.Listener
@ -77,7 +75,6 @@ type Service struct { // nolint:structcheck
func NewService( func NewService(
showOnStartup bool, showOnStartup bool,
panicHandler types.PanicHandler, panicHandler types.PanicHandler,
tls *bridgetls.TLS,
locations *locations.Locations, locations *locations.Locations,
settings *settings.Settings, settings *settings.Settings,
eventListener listener.Listener, eventListener listener.Listener,
@ -91,7 +88,6 @@ func NewService(
s := Service{ s := Service{
UnimplementedBridgeServer: UnimplementedBridgeServer{}, UnimplementedBridgeServer: UnimplementedBridgeServer{},
panicHandler: panicHandler, panicHandler: panicHandler,
tls: tls,
locations: locations, locations: locations,
settings: settings, settings: settings,
eventListener: eventListener, eventListener: eventListener,
@ -111,7 +107,7 @@ func NewService(
// set to 1 // set to 1
s.initializing.Add(1) s.initializing.Add(1)
config, err := tls.GetConfig() config, err := bridge.GetTLSConfig()
config.ClientAuth = cryptotls.NoClientCert // skip client auth if the certificate allow it. config.ClientAuth = cryptotls.NoClientCert // skip client auth if the certificate allow it.
if err != nil { if err != nil {
s.log.WithError(err).Error("could not get TLS config") s.log.WithError(err).Error("could not get TLS config")

View File

@ -19,6 +19,8 @@
package types package types
import ( import (
"crypto/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi" "github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
@ -77,6 +79,10 @@ type User interface {
type Bridger interface { type Bridger interface {
UserManager UserManager
GetTLSConfig() (*tls.Config, error)
// -- old --
ReportBug(osType, osVersion, description, accountName, address, emailClient string, attachLogs bool) error ReportBug(osType, osVersion, description, accountName, address, emailClient string, attachLogs bool) error
SetProxyAllowed(bool) SetProxyAllowed(bool)
GetProxyAllowed() bool GetProxyAllowed() bool

View File

@ -23,6 +23,7 @@ import (
"github.com/ProtonMail/go-autostart" "github.com/ProtonMail/go-autostart"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry" "github.com/ProtonMail/proton-bridge/v2/internal/sentry"
@ -41,17 +42,18 @@ func (ctx *TestContext) GetBridge() *bridge.Bridge {
// withBridgeInstance creates a bridge instance for use in the test. // withBridgeInstance creates a bridge instance for use in the test.
// TestContext has this by default once called with env variable TEST_APP=bridge. // TestContext has this by default once called with env variable TEST_APP=bridge.
func (ctx *TestContext) withBridgeInstance() { func (ctx *TestContext) withBridgeInstance() {
ctx.bridge = newBridgeInstance(ctx.t, ctx.locations, ctx.cache, ctx.settings, ctx.credStore, ctx.listener, ctx.clientManager) ctx.bridge = newBridgeInstance(ctx.t, ctx.locations, ctx.cache, ctx.settings, ctx.tls, ctx.credStore, ctx.listener, ctx.clientManager)
ctx.users = ctx.bridge.Users ctx.users = ctx.bridge.Users
ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data") ctx.addCleanupChecked(ctx.bridge.ClearData, "Cleaning bridge data")
} }
// RestartBridge closes store for each user and recreates a bridge instance the same way as `withBridgeInstance`. // RestartBridge closes store for each user and recreates a bridge instance the same way as `withBridgeInstance`.
// NOTE: This is a very problematic method. It doesn't stop the goroutines doing the event loop and the sync. // NOTE: This is a very problematic method. It doesn't stop the goroutines doing the event loop and the sync.
// These goroutines can continue to run and can cause problems or unexpected behaviour (especially //
// regarding authorization, because if an auth fails, it will log out the user). // These goroutines can continue to run and can cause problems or unexpected behaviour (especially
// To truly emulate bridge restart, we need a way to immediately stop those goroutines. // regarding authorization, because if an auth fails, it will log out the user).
// I have added a channel that waits up to one second for the event loop to stop, but that isn't great. // To truly emulate bridge restart, we need a way to immediately stop those goroutines.
// I have added a channel that waits up to one second for the event loop to stop, but that isn't great.
func (ctx *TestContext) RestartBridge() error { func (ctx *TestContext) RestartBridge() error {
for _, user := range ctx.bridge.GetUsers() { for _, user := range ctx.bridge.GetUsers() {
_ = user.GetStore().Close() _ = user.GetStore().Close()
@ -71,6 +73,7 @@ func newBridgeInstance(
locations bridge.Locator, locations bridge.Locator,
cacheProvider bridge.CacheProvider, cacheProvider bridge.CacheProvider,
fakeSettings *fakeSettings, fakeSettings *fakeSettings,
tls *tls.TLS,
credStore users.CredentialsStorer, credStore users.CredentialsStorer,
eventListener listener.Listener, eventListener listener.Listener,
clientManager pmapi.Manager, clientManager pmapi.Manager,
@ -82,6 +85,7 @@ func newBridgeInstance(
sentry.NewReporter("bridge", constants.Version, useragent.New()), sentry.NewReporter("bridge", constants.Version, useragent.New()),
&panicHandler{t: t}, &panicHandler{t: t},
eventListener, eventListener,
tls,
cache.NewInMemoryCache(100*(1<<20)), cache.NewInMemoryCache(100*(1<<20)),
message.NewBuilder(fakeSettings.GetInt(settings.FetchWorkers), fakeSettings.GetInt(settings.AttachmentWorkers)), message.NewBuilder(fakeSettings.GetInt(settings.FetchWorkers), fakeSettings.GetInt(settings.AttachmentWorkers)),
clientManager, clientManager,

View File

@ -22,6 +22,7 @@ import (
"sync" "sync"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/users" "github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener" "github.com/ProtonMail/proton-bridge/v2/pkg/listener"
@ -43,6 +44,7 @@ type TestContext struct {
cache *fakeCache cache *fakeCache
locations *fakeLocations locations *fakeLocations
settings *fakeSettings settings *fakeSettings
tls *tls.TLS
listener listener.Listener listener listener.Listener
userAgent *useragent.UserAgent userAgent *useragent.UserAgent
testAccounts *accounts.TestAccounts testAccounts *accounts.TestAccounts
@ -89,11 +91,15 @@ func New() *TestContext {
listener := listener.New() listener := listener.New()
pmapiController, clientManager := newPMAPIController(listener) pmapiController, clientManager := newPMAPIController(listener)
locations := newFakeLocations()
settingsPath, _ := locations.ProvideSettingsPath()
ctx := &TestContext{ ctx := &TestContext{
t: &bddT{}, t: &bddT{},
cache: newFakeCache(), cache: newFakeCache(),
locations: newFakeLocations(), locations: locations,
settings: newFakeSettings(), settings: newFakeSettings(),
tls: tls.New(settingsPath),
listener: listener, listener: listener,
userAgent: useragent.New(), userAgent: useragent.New(),
pmapiController: pmapiController, pmapiController: pmapiController,

View File

@ -23,7 +23,6 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/imap" "github.com/ProtonMail/proton-bridge/v2/internal/imap"
"github.com/ProtonMail/proton-bridge/v2/test/mocks" "github.com/ProtonMail/proton-bridge/v2/test/mocks"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -53,10 +52,9 @@ func (ctx *TestContext) withIMAPServer() {
return return
} }
settingsPath, _ := ctx.locations.ProvideSettingsPath()
ph := newPanicHandler(ctx.t) ph := newPanicHandler(ctx.t)
port := ctx.settings.GetInt(settings.IMAPPortKey) port := ctx.settings.GetInt(settings.IMAPPortKey)
tls, _ := tls.New(settingsPath).GetConfig() tls, _ := ctx.tls.GetConfig()
backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.settings, ctx.bridge) backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.settings, ctx.bridge)
server := imap.NewIMAPServer(ph, true, true, port, tls, backend, ctx.userAgent, ctx.listener) server := imap.NewIMAPServer(ph, true, true, port, tls, backend, ctx.userAgent, ctx.listener)

View File

@ -23,7 +23,6 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings" "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/smtp" "github.com/ProtonMail/proton-bridge/v2/internal/smtp"
"github.com/ProtonMail/proton-bridge/v2/test/mocks" "github.com/ProtonMail/proton-bridge/v2/test/mocks"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -53,9 +52,8 @@ func (ctx *TestContext) withSMTPServer() {
return return
} }
settingsPath, _ := ctx.locations.ProvideSettingsPath()
ph := newPanicHandler(ctx.t) ph := newPanicHandler(ctx.t)
tls, _ := tls.New(settingsPath).GetConfig() tls, _ := ctx.tls.GetConfig()
port := ctx.settings.GetInt(settings.SMTPPortKey) port := ctx.settings.GetInt(settings.SMTPPortKey)
useSSL := ctx.settings.GetBool(settings.SMTPSSLKey) useSSL := ctx.settings.GetBool(settings.SMTPSSLKey)