feat(GODT-2500): Add panic handlers everywhere.

This commit is contained in:
Jakub
2023-03-22 17:18:17 +01:00
parent 9f59e61b14
commit ec92c918cd
42 changed files with 283 additions and 130 deletions

2
go.mod
View File

@ -125,6 +125,8 @@ require (
) )
replace ( replace (
github.com/ProtonMail/gluon => /home/dev/gopath18/src/gluon
github.com/ProtonMail/go-proton-api => /home/dev/gopath18/src/go-proton-api
github.com/docker/docker-credential-helpers => github.com/ProtonMail/docker-credential-helpers v1.1.0 github.com/docker/docker-credential-helpers => github.com/ProtonMail/docker-credential-helpers v1.1.0
github.com/emersion/go-message => github.com/ProtonMail/go-message v0.0.0-20210611055058-fabeff2ec753 github.com/emersion/go-message => github.com/ProtonMail/go-message v0.0.0-20210611055058-fabeff2ec753
github.com/keybase/go-keychain => github.com/cuthix/go-keychain v0.0.0-20220405075754-31e7cee908fe github.com/keybase/go-keychain => github.com/cuthix/go-keychain v0.0.0-20220405075754-31e7cee908fe

View File

@ -185,14 +185,14 @@ func run(c *cli.Context) error {
exe = os.Args[0] exe = os.Args[0]
} }
migrationErr := migrateOldVersions() // Restart the app if requested.
return withRestarter(exe, func(restarter *restarter.Restarter) error {
// Handle crashes with various actions.
return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler, quitCh <-chan struct{}) error {
migrationErr := migrateOldVersions()
// Run with profiling if requested. // Run with profiling if requested.
return withProfiler(c, func() error { return withProfiler(c, func() error {
// Restart the app if requested.
return withRestarter(exe, func(restarter *restarter.Restarter) error {
// Handle crashes with various actions.
return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler, quitCh <-chan struct{}) error {
// Load the locations where we store our files. // Load the locations where we store our files.
return WithLocations(func(locations *locations.Locations) error { return WithLocations(func(locations *locations.Locations) error {
// Migrate the keychain helper. // Migrate the keychain helper.
@ -215,7 +215,7 @@ func run(c *cli.Context) error {
return withSingleInstance(settings, locations.GetLockFile(), version, func() error { return withSingleInstance(settings, locations.GetLockFile(), version, func() error {
// Unlock the encrypted vault. // Unlock the encrypted vault.
return WithVault(locations, func(v *vault.Vault, insecure, corrupt bool) error { return WithVault(locations, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
// Report insecure vault. // Report insecure vault.
if insecure { if insecure {
_ = reporter.ReportMessageWithContext("Vault is insecure", map[string]interface{}{}) _ = reporter.ReportMessageWithContext("Vault is insecure", map[string]interface{}{})

View File

@ -78,7 +78,7 @@ func withBridge(
) )
// Create a proxy dialer which switches to a proxy if the request fails. // Create a proxy dialer which switches to a proxy if the request fails.
proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost) proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost, crashHandler)
// Create the autostarter. // Create the autostarter.
autostarter := newAutostarter(exe) autostarter := newAutostarter(exe)

View File

@ -46,7 +46,7 @@ func runFrontend(
switch { switch {
case c.Bool(flagCLI): case c.Bool(flagCLI):
return bridgeCLI.New(bridge, restarter, eventCh).Loop() return bridgeCLI.New(bridge, restarter, eventCh, crashHandler).Loop()
case c.Bool(flagNonInteractive): case c.Bool(flagNonInteractive):
select {} select {}

View File

@ -25,6 +25,7 @@ import (
"runtime" "runtime"
"testing" "testing"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/cookies" "github.com/ProtonMail/proton-bridge/v3/internal/cookies"
@ -40,7 +41,7 @@ import (
func TestMigratePrefsToVaultWithKeys(t *testing.T) { func TestMigratePrefsToVaultWithKeys(t *testing.T) {
// Create a new vault. // Create a new vault.
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)
@ -61,7 +62,7 @@ func TestMigratePrefsToVaultWithKeys(t *testing.T) {
func TestMigratePrefsToVaultWithoutKeys(t *testing.T) { func TestMigratePrefsToVaultWithoutKeys(t *testing.T) {
// Create a new vault. // Create a new vault.
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)
@ -173,7 +174,7 @@ func TestUserMigration(t *testing.T) {
token, err := crypto.RandomToken(32) token, err := crypto.RandomToken(32)
require.NoError(t, err) require.NoError(t, err)
v, corrupt, err := vault.New(settingsFolder, settingsFolder, token) v, corrupt, err := vault.New(settingsFolder, settingsFolder, token, queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"path" "path"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"github.com/ProtonMail/proton-bridge/v3/internal/certs" "github.com/ProtonMail/proton-bridge/v3/internal/certs"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/locations" "github.com/ProtonMail/proton-bridge/v3/internal/locations"
@ -29,12 +30,12 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) error) error { func WithVault(locations *locations.Locations, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
logrus.Debug("Creating vault") logrus.Debug("Creating vault")
defer logrus.Debug("Vault stopped") defer logrus.Debug("Vault stopped")
// Create the encVault. // Create the encVault.
encVault, insecure, corrupt, err := newVault(locations) encVault, insecure, corrupt, err := newVault(locations, panicHandler)
if err != nil { if err != nil {
return fmt.Errorf("could not create vault: %w", err) return fmt.Errorf("could not create vault: %w", err)
} }
@ -66,7 +67,7 @@ func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool)
return fn(encVault, insecure, corrupt) return fn(encVault, insecure, corrupt)
} }
func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) { func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) {
vaultDir, err := locations.ProvideSettingsPath() vaultDir, err := locations.ProvideSettingsPath()
if err != nil { if err != nil {
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err) return nil, false, false, fmt.Errorf("could not get vault dir: %w", err)
@ -93,7 +94,7 @@ func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error)
return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err) return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err)
} }
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey) vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler)
if err != nil { if err != nil {
return nil, false, false, fmt.Errorf("could not create vault: %w", err) return nil, false, false, fmt.Errorf("could not create vault: %w", err)
} }

View File

@ -21,6 +21,7 @@ import (
"net/http" "net/http"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -32,6 +33,7 @@ func defaultAPIOptions(
version *semver.Version, version *semver.Version,
cookieJar http.CookieJar, cookieJar http.CookieJar,
transport http.RoundTripper, transport http.RoundTripper,
panicHandler queue.PanicHandler,
) []proton.Option { ) []proton.Option {
return []proton.Option{ return []proton.Option{
proton.WithHostURL(apiURL), proton.WithHostURL(apiURL),
@ -39,5 +41,6 @@ func defaultAPIOptions(
proton.WithCookieJar(cookieJar), proton.WithCookieJar(cookieJar),
proton.WithTransport(transport), proton.WithTransport(transport),
proton.WithLogger(logrus.StandardLogger()), proton.WithLogger(logrus.StandardLogger()),
proton.WithPanicHandler(panicHandler),
} }
} }

View File

@ -23,6 +23,7 @@ import (
"net/http" "net/http"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
) )
@ -32,6 +33,7 @@ func newAPIOptions(
version *semver.Version, version *semver.Version,
cookieJar http.CookieJar, cookieJar http.CookieJar,
transport http.RoundTripper, transport http.RoundTripper,
panicHandler queue.PanicHandler,
) []proton.Option { ) []proton.Option {
return defaultAPIOptions(apiURL, version, cookieJar, transport) return defaultAPIOptions(apiURL, version, cookieJar, transport, panicHandler)
} }

View File

@ -24,6 +24,7 @@ import (
"os" "os"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
) )
@ -33,8 +34,9 @@ func newAPIOptions(
version *semver.Version, version *semver.Version,
cookieJar http.CookieJar, cookieJar http.CookieJar,
transport http.RoundTripper, transport http.RoundTripper,
panicHandler queue.PanicHandler,
) []proton.Option { ) []proton.Option {
opt := defaultAPIOptions(apiURL, version, cookieJar, transport) opt := defaultAPIOptions(apiURL, version, cookieJar, transport, panicHandler)
if host := os.Getenv("BRIDGE_API_HOST"); host != "" { if host := os.Getenv("BRIDGE_API_HOST"); host != "" {
opt = append(opt, proton.WithHostURL(host)) opt = append(opt, proton.WithHostURL(host))

View File

@ -93,8 +93,8 @@ type Bridge struct {
// locator is the bridge's locator. // locator is the bridge's locator.
locator Locator locator Locator
// crashHandler // panicHandler
crashHandler async.PanicHandler panicHandler async.PanicHandler
// reporter // reporter
reporter reporter.Reporter reporter reporter.Reporter
@ -143,7 +143,7 @@ func New(
tlsReporter TLSReporter, // the TLS reporter to report TLS errors tlsReporter TLSReporter, // the TLS reporter to report TLS errors
roundTripper http.RoundTripper, // the round tripper to use for API requests roundTripper http.RoundTripper, // the round tripper to use for API requests
proxyCtl ProxyController, // the DoH controller proxyCtl ProxyController, // the DoH controller
crashHandler async.PanicHandler, panicHandler async.PanicHandler,
reporter reporter.Reporter, reporter reporter.Reporter,
uidValidityGenerator imap.UIDValidityGenerator, uidValidityGenerator imap.UIDValidityGenerator,
@ -151,10 +151,10 @@ func New(
logSMTP bool, // whether to log SMTP activity logSMTP bool, // whether to log SMTP activity
) (*Bridge, <-chan events.Event, error) { ) (*Bridge, <-chan events.Event, error) {
// api is the user's API manager. // api is the user's API manager.
api := proton.New(newAPIOptions(apiURL, curVersion, cookieJar, roundTripper)...) api := proton.New(newAPIOptions(apiURL, curVersion, cookieJar, roundTripper, panicHandler)...)
// tasks holds all the bridge's background tasks. // tasks holds all the bridge's background tasks.
tasks := async.NewGroup(context.Background(), crashHandler) tasks := async.NewGroup(context.Background(), panicHandler)
// imapEventCh forwards IMAP events from gluon instances to the bridge for processing. // imapEventCh forwards IMAP events from gluon instances to the bridge for processing.
imapEventCh := make(chan imapEvents.Event) imapEventCh := make(chan imapEvents.Event)
@ -169,7 +169,7 @@ func New(
autostarter, autostarter,
updater, updater,
curVersion, curVersion,
crashHandler, panicHandler,
reporter, reporter,
api, api,
@ -202,7 +202,7 @@ func newBridge(
autostarter Autostarter, autostarter Autostarter,
updater Updater, updater Updater,
curVersion *semver.Version, curVersion *semver.Version,
crashHandler async.PanicHandler, panicHandler async.PanicHandler,
reporter reporter.Reporter, reporter reporter.Reporter,
api *proton.Manager, api *proton.Manager,
@ -248,12 +248,13 @@ func newBridge(
imapEventCh, imapEventCh,
tasks, tasks,
uidValidityGenerator, uidValidityGenerator,
panicHandler,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create IMAP server: %w", err) return nil, fmt.Errorf("failed to create IMAP server: %w", err)
} }
focusService, err := focus.NewService(locator, curVersion) focusService, err := focus.NewService(locator, curVersion, panicHandler)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create focus service: %w", err) return nil, fmt.Errorf("failed to create focus service: %w", err)
} }
@ -279,7 +280,7 @@ func newBridge(
newVersion: curVersion, newVersion: curVersion,
newVersionLock: safe.NewRWMutex(), newVersionLock: safe.NewRWMutex(),
crashHandler: crashHandler, panicHandler: panicHandler,
reporter: reporter, reporter: reporter,
focusService: focusService, focusService: focusService,
@ -495,7 +496,7 @@ func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events
bridge.watchersLock.Lock() bridge.watchersLock.Lock()
defer bridge.watchersLock.Unlock() defer bridge.watchersLock.Unlock()
watcher := watcher.New(ofType...) watcher := watcher.New(bridge.panicHandler, ofType...)
bridge.watchers = append(bridge.watchers, watcher) bridge.watchers = append(bridge.watchers, watcher)

View File

@ -31,6 +31,7 @@ import (
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server" "github.com/ProtonMail/go-proton-api/server"
"github.com/ProtonMail/go-proton-api/server/backend" "github.com/ProtonMail/go-proton-api/server/backend"
@ -699,7 +700,7 @@ func withBridgeNoMocks(
require.NoError(t, err) require.NoError(t, err)
// Create the vault. // Create the vault.
vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey) vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey, queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
// Create a new cookie jar. // Create a new cookie jar.

View File

@ -299,6 +299,7 @@ func newIMAPServer(
eventCh chan<- imapEvents.Event, eventCh chan<- imapEvents.Event,
tasks *async.Group, tasks *async.Group,
uidValidityGenerator imap.UIDValidityGenerator, uidValidityGenerator imap.UIDValidityGenerator,
panicHandler async.PanicHandler,
) (*gluon.Server, error) { ) (*gluon.Server, error) {
gluonCacheDir = ApplyGluonCachePathSuffix(gluonCacheDir) gluonCacheDir = ApplyGluonCachePathSuffix(gluonCacheDir)
gluonConfigDir = ApplyGluonConfigPathSuffix(gluonConfigDir) gluonConfigDir = ApplyGluonConfigPathSuffix(gluonConfigDir)
@ -343,6 +344,7 @@ func newIMAPServer(
getGluonVersionInfo(version), getGluonVersionInfo(version),
gluon.WithReporter(reporter), gluon.WithReporter(reporter),
gluon.WithUIDValidityGenerator(uidValidityGenerator), gluon.WithUIDValidityGenerator(uidValidityGenerator),
gluon.WithPanicHandler(panicHandler),
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -157,6 +157,7 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error
bridge.imapEventCh, bridge.imapEventCh,
bridge.tasks, bridge.tasks,
bridge.uidValidityGenerator, bridge.uidValidityGenerator,
bridge.panicHandler,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to create new IMAP server: %w", err) return fmt.Errorf("failed to create new IMAP server: %w", err)

View File

@ -28,6 +28,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server" "github.com/ProtonMail/go-proton-api/server"
@ -428,7 +429,7 @@ func createMessages(ctx context.Context, t *testing.T, c *proton.Client, addrID,
keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID) keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID)
require.NoError(t, err) require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addr, keyPass) _, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
_, ok := addrKRs[addrID] _, ok := addrKRs[addrID]

View File

@ -516,7 +516,7 @@ func (bridge *Bridge) addUserWithVault(
client, client,
bridge.reporter, bridge.reporter,
apiUser, apiUser,
bridge.crashHandler, bridge.panicHandler,
bridge.vault.GetShowAllMail(), bridge.vault.GetShowAllMail(),
bridge.vault.GetMaxSyncMemory(), bridge.vault.GetMaxSyncMemory(),
) )

View File

@ -28,6 +28,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server" "github.com/ProtonMail/go-proton-api/server"
@ -474,7 +475,7 @@ func TestBridge_User_UpdateDraftAndCreateOtherMessage(t *testing.T) {
keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID) keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID)
require.NoError(t, err) require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addrs, keyPass) _, addrKRs, err := proton.Unlock(user, addrs, keyPass, queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
// Create a draft (generating a "create draft message" event). // Create a draft (generating a "create draft message" event).
@ -556,7 +557,7 @@ func TestBridge_User_SendDraftRemoveDraftFlag(t *testing.T) {
keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID) keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID)
require.NoError(t, err) require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addrs, keyPass) _, addrKRs, err := proton.Unlock(user, addrs, keyPass, queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
// Create a draft (generating a "create draft message" event). // Create a draft (generating a "create draft message" event).

View File

@ -98,6 +98,8 @@ func saveConfigTemporarily(mc *mobileconfig.Config) (fname string, err error) {
// Make sure the temporary file is deleted. // Make sure the temporary file is deleted.
go func() { go func() {
defer recover() //nolint:errcheck
<-time.After(10 * time.Minute) <-time.After(10 * time.Minute)
_ = os.RemoveAll(dir) _ = os.RemoveAll(dir)
}() }()

View File

@ -24,6 +24,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -40,17 +41,20 @@ type ProxyTLSDialer struct {
allowProxy bool allowProxy bool
proxyProvider *proxyProvider proxyProvider *proxyProvider
proxyUseDuration time.Duration proxyUseDuration time.Duration
panicHandler async.PanicHandler
} }
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer. // NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
func NewProxyTLSDialer(dialer TLSDialer, hostURL string) *ProxyTLSDialer { func NewProxyTLSDialer(dialer TLSDialer, hostURL string, panicHandler async.PanicHandler) *ProxyTLSDialer {
return &ProxyTLSDialer{ return &ProxyTLSDialer{
dialer: dialer, dialer: dialer,
locker: sync.RWMutex{}, locker: sync.RWMutex{},
directAddress: formatAsAddress(hostURL), directAddress: formatAsAddress(hostURL),
proxyAddress: formatAsAddress(hostURL), proxyAddress: formatAsAddress(hostURL),
proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders), proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders, panicHandler),
proxyUseDuration: proxyUseDuration, proxyUseDuration: proxyUseDuration,
panicHandler: panicHandler,
} }
} }
@ -75,6 +79,12 @@ func formatAsAddress(rawURL string) string {
return net.JoinHostPort(host, port) return net.JoinHostPort(host, port)
} }
func (d *ProxyTLSDialer) handlePanic() {
if d.panicHandler != nil {
d.panicHandler.HandlePanic()
}
}
// DialTLSContext dials the given network/address. If it fails, it retries using a proxy. // DialTLSContext dials the given network/address. If it fails, it retries using a proxy.
func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
d.locker.RLock() d.locker.RLock()
@ -129,6 +139,8 @@ func (d *ProxyTLSDialer) switchToReachableServer() error {
// This means we want to disable it again in 24 hours. // This means we want to disable it again in 24 hours.
if d.proxyAddress == d.directAddress { if d.proxyAddress == d.directAddress {
go func() { go func() {
defer d.handlePanic()
<-time.After(d.proxyUseDuration) <-time.After(d.proxyUseDuration)
d.locker.Lock() d.locker.Lock()

View File

@ -24,6 +24,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -67,11 +68,13 @@ type proxyProvider struct {
canReachTimeout time.Duration canReachTimeout time.Duration
lastLookup time.Time // The time at which we last attempted to find a proxy. lastLookup time.Time // The time at which we last attempted to find a proxy.
panicHandler async.PanicHandler
} }
// newProxyProvider creates a new proxyProvider that queries the given DoH providers // newProxyProvider creates a new proxyProvider that queries the given DoH providers
// to retrieve DNS records for the given query string. // to retrieve DNS records for the given query string.
func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *proxyProvider) { func newProxyProvider(dialer TLSDialer, hostURL string, providers []string, panicHandler async.PanicHandler) (p *proxyProvider) {
p = &proxyProvider{ p = &proxyProvider{
dialer: dialer, dialer: dialer,
hostURL: hostURL, hostURL: hostURL,
@ -80,6 +83,7 @@ func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *
cacheRefreshTimeout: proxyCacheRefreshTimeout, cacheRefreshTimeout: proxyCacheRefreshTimeout,
dohTimeout: proxyDoHTimeout, dohTimeout: proxyDoHTimeout,
canReachTimeout: proxyCanReachTimeout, canReachTimeout: proxyCanReachTimeout,
panicHandler: panicHandler,
} }
// Use the default DNS lookup method; this can be overridden if necessary. // Use the default DNS lookup method; this can be overridden if necessary.
@ -88,6 +92,12 @@ func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *
return return
} }
func (p *proxyProvider) handlePanic() {
if p.panicHandler != nil {
p.panicHandler.HandlePanic()
}
}
// findReachableServer returns a working API server (either proxy or standard API). // findReachableServer returns a working API server (either proxy or standard API).
func (p *proxyProvider) findReachableServer() (proxy string, err error) { func (p *proxyProvider) findReachableServer() (proxy string, err error) {
logrus.Debug("Trying to find a reachable server") logrus.Debug("Trying to find a reachable server")
@ -109,11 +119,13 @@ func (p *proxyProvider) findReachableServer() (proxy string, err error) {
wg.Add(2) wg.Add(2)
go func() { go func() {
defer p.handlePanic()
defer wg.Done() defer wg.Done()
apiReachable = p.canReach(p.hostURL) apiReachable = p.canReach(p.hostURL)
}() }()
go func() { go func() {
defer p.handlePanic()
defer wg.Done() defer wg.Done()
err = p.refreshProxyCache() err = p.refreshProxyCache()
}() }()
@ -150,6 +162,8 @@ func (p *proxyProvider) refreshProxyCache() error {
resultChan := make(chan []string) resultChan := make(chan []string)
go func() { go func() {
defer p.handlePanic()
for _, provider := range p.providers { for _, provider := range p.providers {
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil { if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
resultChan <- proxies resultChan <- proxies
@ -203,6 +217,7 @@ func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider
dataChan, errChan := make(chan []string), make(chan error) dataChan, errChan := make(chan []string), make(chan error)
go func() { go func() {
defer p.handlePanic()
// Build new DNS request in RFC1035 format. // Build new DNS request in RFC1035 format.
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT) dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)

View File

@ -23,6 +23,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/useragent"
r "github.com/stretchr/testify/require" r "github.com/stretchr/testify/require"
) )
@ -31,7 +32,7 @@ func TestProxyProvider_FindProxy(t *testing.T) {
proxy := getTrustedServer() proxy := getTrustedServer()
defer closeServer(proxy) defer closeServer(proxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findReachableServer() url, err := p.findReachableServer()
@ -47,7 +48,7 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
unreachableProxy := getTrustedServer() unreachableProxy := getTrustedServer()
closeServer(unreachableProxy) closeServer(unreachableProxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{reachableProxy.URL, unreachableProxy.URL}, nil return []string{reachableProxy.URL, unreachableProxy.URL}, nil
} }
@ -68,7 +69,7 @@ func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
checker := NewTLSPinChecker(TrustedAPIPins) checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker) dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
p := newProxyProvider(dialer, "", []string{"not used"}) p := newProxyProvider(dialer, "", []string{"not used"}, queue.NoopPanicHandler{})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy.URL, trustedProxy.URL}, nil return []string{untrustedProxy.URL, trustedProxy.URL}, nil
} }
@ -85,7 +86,7 @@ func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
unreachableProxy2 := getTrustedServer() unreachableProxy2 := getTrustedServer()
closeServer(unreachableProxy2) closeServer(unreachableProxy2)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
} }
@ -105,7 +106,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
checker := NewTLSPinChecker(TrustedAPIPins) checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker) dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
p := newProxyProvider(dialer, "", []string{"not used"}) p := newProxyProvider(dialer, "", []string{"not used"}, queue.NoopPanicHandler{})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
} }
@ -115,7 +116,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
} }
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) { func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
p.cacheRefreshTimeout = 1 * time.Second p.cacheRefreshTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
@ -132,7 +133,7 @@ func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
})) }))
defer closeServer(slowProxy) defer closeServer(slowProxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
p.canReachTimeout = 1 * time.Second p.canReachTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
@ -144,7 +145,7 @@ func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
} }
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider) records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider)
r.NoError(t, err) r.NoError(t, err)
@ -155,7 +156,7 @@ func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
// port filter. Basic functionality should be covered by other tests. Keeping // port filter. Basic functionality should be covered by other tests. Keeping
// code here to be able to run it locally if needed. // code here to be able to run it locally if needed.
func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) { func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider) records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider)
r.NoError(t, err) r.NoError(t, err)
@ -163,7 +164,7 @@ func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
} }
func TestProxyProvider_DoHLookup_Google(t *testing.T) { func TestProxyProvider_DoHLookup_Google(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider) records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider)
r.NoError(t, err) r.NoError(t, err)
@ -173,7 +174,7 @@ func TestProxyProvider_DoHLookup_Google(t *testing.T) {
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
skipIfProxyIsSet(t) skipIfProxyIsSet(t)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
url, err := p.findReachableServer() url, err := p.findReachableServer()
r.NoError(t, err) r.NoError(t, err)
@ -183,7 +184,7 @@ func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
skipIfProxyIsSet(t) skipIfProxyIsSet(t)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider}) p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
url, err := p.findReachableServer() url, err := p.findReachableServer()
r.NoError(t, err) r.NoError(t, err)

View File

@ -25,6 +25,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/ProtonMail/gluon/queue"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -141,8 +142,8 @@ func TestProxyDialer_UseProxy(t *testing.T) {
trustedProxy := getTrustedServer() trustedProxy := getTrustedServer()
defer closeServer(trustedProxy) defer closeServer(trustedProxy)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
d.proxyProvider = provider d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
@ -159,8 +160,8 @@ func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) {
proxy3 := getTrustedServer() proxy3 := getTrustedServer()
defer closeServer(proxy3) defer closeServer(proxy3)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
d.proxyProvider = provider d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil } provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
@ -189,8 +190,8 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
trustedProxy := getTrustedServer() trustedProxy := getTrustedServer()
defer closeServer(trustedProxy) defer closeServer(trustedProxy)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
d.proxyProvider = provider d.proxyProvider = provider
d.proxyUseDuration = time.Second d.proxyUseDuration = time.Second
@ -212,8 +213,8 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) { func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
trustedProxy := getTrustedServer() trustedProxy := getTrustedServer()
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
d.proxyProvider = provider d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
@ -242,8 +243,8 @@ func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBloc
proxy2 := getTrustedServer() proxy2 := getTrustedServer()
defer closeServer(proxy2) defer closeServer(proxy2)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
d.proxyProvider = provider d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }

View File

@ -30,7 +30,7 @@ func TestFocus_Raise(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
locations := locations.New(newTestLocationsProvider(tmpDir), "config-name") locations := locations.New(newTestLocationsProvider(tmpDir), "config-name")
// Start the focus service. // Start the focus service.
service, err := NewService(locations, semver.MustParse("1.2.3")) service, err := NewService(locations, semver.MustParse("1.2.3"), nil)
require.NoError(t, err) require.NoError(t, err)
settingsFolder, err := locations.ProvideSettingsPath() settingsFolder, err := locations.ProvideSettingsPath()
@ -52,7 +52,7 @@ func TestFocus_Version(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
locations := locations.New(newTestLocationsProvider(tmpDir), "config-name") locations := locations.New(newTestLocationsProvider(tmpDir), "config-name")
// Start the focus service. // Start the focus service.
_, err := NewService(locations, semver.MustParse("1.2.3")) _, err := NewService(locations, semver.MustParse("1.2.3"), nil)
require.NoError(t, err) require.NoError(t, err)
settingsFolder, err := locations.ProvideSettingsPath() settingsFolder, err := locations.ProvideSettingsPath()

View File

@ -24,6 +24,7 @@ import (
"net" "net"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"github.com/ProtonMail/proton-bridge/v3/internal/focus/proto" "github.com/ProtonMail/proton-bridge/v3/internal/focus/proto"
"github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/ProtonMail/proton-bridge/v3/internal/service"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -43,15 +44,18 @@ type Service struct {
server *grpc.Server server *grpc.Server
raiseCh chan struct{} raiseCh chan struct{}
version *semver.Version version *semver.Version
panicHandler async.PanicHandler
} }
// NewService creates a new focus service. // NewService creates a new focus service.
// It listens on the local host and port 1042 (by default). // It listens on the local host and port 1042 (by default).
func NewService(locator service.Locator, version *semver.Version) (*Service, error) { func NewService(locator service.Locator, version *semver.Version, panicHandler async.PanicHandler) (*Service, error) {
serv := &Service{ serv := &Service{
server: grpc.NewServer(), server: grpc.NewServer(),
raiseCh: make(chan struct{}, 1), raiseCh: make(chan struct{}, 1),
version: version, version: version,
panicHandler: panicHandler,
} }
proto.RegisterFocusServer(serv.server, serv) proto.RegisterFocusServer(serv.server, serv)
@ -73,6 +77,8 @@ func NewService(locator service.Locator, version *semver.Version) (*Service, err
} }
go func() { go func() {
defer serv.handlePanic()
if err := serv.server.Serve(listener); err != nil { if err := serv.server.Serve(listener); err != nil {
fmt.Printf("failed to serve: %v", err) fmt.Printf("failed to serve: %v", err)
} }
@ -82,6 +88,12 @@ func NewService(locator service.Locator, version *semver.Version) (*Service, err
return serv, nil return serv, nil
} }
func (service *Service) handlePanic() {
if service.panicHandler != nil {
service.panicHandler.HandlePanic()
}
}
// Raise implements the gRPC FocusService interface; it raises the application. // Raise implements the gRPC FocusService interface; it raises the application.
func (service *Service) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) { func (service *Service) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
service.raiseCh <- struct{}{} service.raiseCh <- struct{}{}
@ -103,6 +115,8 @@ func (service *Service) GetRaiseCh() <-chan struct{} {
// Close closes the service. // Close closes the service.
func (service *Service) Close() { func (service *Service) Close() {
go func() { go func() {
defer service.handlePanic()
// we do this in a goroutine, as on Windows, the gRPC shutdown may take minutes if something tries to // we do this in a goroutine, as on Windows, the gRPC shutdown may take minutes if something tries to
// interact with it in an invalid way (e.g. HTTP GET request from a Qt QNetworkManager instance). // interact with it in an invalid way (e.g. HTTP GET request from a Qt QNetworkManager instance).
service.server.Stop() service.server.Stop()

View File

@ -21,6 +21,7 @@ package cli
import ( import (
"errors" "errors"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/events"
@ -39,15 +40,18 @@ type frontendCLI struct {
restarter *restarter.Restarter restarter *restarter.Restarter
badUserID string badUserID string
panicHandler async.PanicHandler
} }
// New returns a new CLI frontend configured with the given options. // New returns a new CLI frontend configured with the given options.
func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan events.Event) *frontendCLI { //nolint:revive func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan events.Event, panicHandler async.PanicHandler) *frontendCLI { //nolint:revive
fe := &frontendCLI{ fe := &frontendCLI{
Shell: ishell.New(), Shell: ishell.New(),
bridge: bridge, bridge: bridge,
restarter: restarter, restarter: restarter,
badUserID: "", badUserID: "",
panicHandler: panicHandler,
} }
// Clear commands. // Clear commands.
@ -285,6 +289,8 @@ func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan e
} }
func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyclo func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyclo
defer f.handlePanic()
// GODT-1949: Better error events. // GODT-1949: Better error events.
for _, err := range f.bridge.GetErrors() { for _, err := range f.bridge.GetErrors() {
switch { switch {
@ -445,6 +451,12 @@ func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyc
*/ */
} }
func (f *frontendCLI) handlePanic() {
if f.panicHandler != nil {
f.panicHandler.HandlePanic()
}
}
// Loop starts the frontend loop with an interactive shell. // Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop() error { func (f *frontendCLI) Loop() error {
f.Printf(` f.Printf(`

View File

@ -191,6 +191,12 @@ func NewService(
return s, nil return s, nil
} }
func (s *Service) handlePanic() {
if s.panicHandler != nil {
s.panicHandler.HandlePanic()
}
}
func (s *Service) initAutostart() { func (s *Service) initAutostart() {
s.firstTimeAutostart.Do(func() { s.firstTimeAutostart.Do(func() {
shouldAutostartBeOn := s.bridge.GetAutostart() shouldAutostartBeOn := s.bridge.GetAutostart()
@ -207,11 +213,14 @@ func (s *Service) Loop() error {
if s.parentPID < 0 { if s.parentPID < 0 {
s.log.Info("Not monitoring parent PID") s.log.Info("Not monitoring parent PID")
} else { } else {
go s.monitorParentPID() go func() {
defer s.handlePanic()
s.monitorParentPID()
}()
} }
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
s.watchEvents() s.watchEvents()
}() }()
@ -221,6 +230,8 @@ func (s *Service) Loop() error {
defer close(doneCh) defer close(doneCh)
go func() { go func() {
defer s.handlePanic()
select { select {
case <-s.quitCh: case <-s.quitCh:
s.log.Info("Stopping gRPC server") s.log.Info("Stopping gRPC server")
@ -564,6 +575,8 @@ func (s *Service) monitorParentPID() {
s.log.Info("Parent process does not exist anymore. Initiating shutdown") s.log.Info("Parent process does not exist anymore. Initiating shutdown")
// quit will write to the parentPIDDoneCh, so we launch a goroutine. // quit will write to the parentPIDDoneCh, so we launch a goroutine.
go func() { go func() {
defer s.handlePanic()
if err := s.quit(); err != nil { if err := s.quit(); err != nil {
logrus.WithError(err).Error("Error on quit") logrus.WithError(err).Error("Error on quit")
} }

View File

@ -114,6 +114,8 @@ func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empt
func (s *Service) quit() error { func (s *Service) quit() error {
// Windows is notably slow at Quitting. We do it in a goroutine to speed things up a bit. // Windows is notably slow at Quitting. We do it in a goroutine to speed things up a bit.
go func() { go func() {
defer s.handlePanic()
if s.parentPID >= 0 { if s.parentPID >= 0 {
s.parentPIDDoneCh <- struct{}{} s.parentPIDDoneCh <- struct{}{}
} }
@ -221,7 +223,8 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.
s.log.Debug("TriggerReset") s.log.Debug("TriggerReset")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
s.triggerReset() s.triggerReset()
}() }()
return &emptypb.Empty{}, nil return &emptypb.Empty{}, nil
@ -316,6 +319,8 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
}).Debug("ReportBug") }).Debug("ReportBug")
go func() { go func() {
defer s.handlePanic()
defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }() defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }()
if err := s.bridge.ReportBug( if err := s.bridge.ReportBug(
@ -343,7 +348,7 @@ func (s *Service) ExportTLSCertificates(_ context.Context, folderPath *wrappersp
s.log.WithField("folderPath", folderPath).Info("ExportTLSCertificates") s.log.WithField("folderPath", folderPath).Info("ExportTLSCertificates")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
cert, key := s.bridge.GetBridgeTLSCert() cert, key := s.bridge.GetBridgeTLSCert()
@ -379,7 +384,7 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt
s.log.WithField("username", login.Username).Debug("Login") s.log.WithField("username", login.Username).Debug("Login")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
password, err := base64Decode(login.Password) password, err := base64Decode(login.Password)
if err != nil { if err != nil {
@ -435,7 +440,7 @@ func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.E
s.log.WithField("username", login.Username).Debug("Login2FA") s.log.WithField("username", login.Username).Debug("Login2FA")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
if s.auth.UID == "" || s.authClient == nil { if s.auth.UID == "" || s.authClient == nil {
s.log.Errorf("Login 2FA: authethication incomplete %s %p", s.auth.UID, s.authClient) s.log.Errorf("Login 2FA: authethication incomplete %s %p", s.auth.UID, s.authClient)
@ -480,7 +485,7 @@ func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*em
s.log.WithField("username", login.Username).Debug("Login2Passwords") s.log.WithField("username", login.Username).Debug("Login2Passwords")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
password, err := base64Decode(login.Password) password, err := base64Decode(login.Password)
if err != nil { if err != nil {
@ -502,7 +507,7 @@ func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest)
s.log.WithField("username", loginAbort.Username).Debug("LoginAbort") s.log.WithField("username", loginAbort.Username).Debug("LoginAbort")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
s.loginAbort() s.loginAbort()
}() }()
@ -514,7 +519,7 @@ func (s *Service) CheckUpdate(context.Context, *emptypb.Empty) (*emptypb.Empty,
s.log.Debug("CheckUpdate") s.log.Debug("CheckUpdate")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
updateCh, done := s.bridge.GetEvents( updateCh, done := s.bridge.GetEvents(
events.UpdateAvailable{}, events.UpdateAvailable{},
@ -546,7 +551,7 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb
s.log.Debug("InstallUpdate") s.log.Debug("InstallUpdate")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
safe.RLock(func() { safe.RLock(func() {
s.bridge.InstallUpdate(s.target) s.bridge.InstallUpdate(s.target)
@ -587,6 +592,8 @@ func (s *Service) SetDiskCachePath(ctx context.Context, newPath *wrapperspb.Stri
s.log.WithField("path", newPath.Value).Debug("setDiskCachePath") s.log.WithField("path", newPath.Value).Debug("setDiskCachePath")
go func() { go func() {
defer s.handlePanic()
defer func() { defer func() {
_ = s.SendEvent(NewDiskCachePathChangeFinishedEvent()) _ = s.SendEvent(NewDiskCachePathChangeFinishedEvent())
}() }()
@ -652,7 +659,7 @@ func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSet
Debug("SetConnectionMode") Debug("SetConnectionMode")
go func() { go func() {
defer s.panicHandler.HandlePanic() defer s.handlePanic()
defer func() { _ = s.SendEvent(NewChangeMailServerSettingsFinishedEvent()) }() defer func() { _ = s.SendEvent(NewChangeMailServerSettingsFinishedEvent()) }()

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"github.com/bradenaw/juniper/xerrors" "github.com/bradenaw/juniper/xerrors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -68,17 +69,32 @@ func catch(handlers ...func() error) {
} }
type Group struct { type Group struct {
mu sync.Mutex mu sync.Mutex
panicHandler async.PanicHandler
}
func (wg *Group) SetPanicHandler(panicHandler async.PanicHandler) {
wg.panicHandler = panicHandler
}
func (wg *Group) handlePanic() {
if wg.panicHandler != nil {
wg.panicHandler.HandlePanic()
}
} }
func (wg *Group) GoTry(fn func(bool)) { func (wg *Group) GoTry(fn func(bool)) {
if wg.mu.TryLock() { if wg.mu.TryLock() {
go func() { go func() {
defer wg.handlePanic()
defer wg.mu.Unlock() defer wg.mu.Unlock()
fn(true) fn(true)
}() }()
} else { } else {
go fn(false) go func() {
defer wg.handlePanic()
fn(false)
}()
} }
} }

View File

@ -225,7 +225,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID] user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
case vault.SplitMode: case vault.SplitMode:
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0) user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
} }
user.eventCh.Enqueue(events.UserAddressCreated{ user.eventCh.Enqueue(events.UserAddressCreated{
@ -284,7 +284,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID] user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
case vault.SplitMode: case vault.SplitMode:
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0) user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
} }
user.eventCh.Enqueue(events.UserAddressEnabled{ user.eventCh.Enqueue(events.UserAddressEnabled{
@ -594,7 +594,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
"subject": logging.Sensitive(message.Subject), "subject": logging.Sensitive(message.Subject),
}).Info("Handling message created event") }).Info("Handling message created event")
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil { if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it. // If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity { if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
@ -686,7 +686,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
"subject": logging.Sensitive(event.Message.Subject), "subject": logging.Sensitive(event.Message.Subject),
}).Info("Handling draft updated event") }).Info("Handling draft updated event")
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil { if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it. // If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity { if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {

View File

@ -290,7 +290,7 @@ func (conn *imapConnector) CreateMessage(
conn.log.WithField("messageID", messageID).Warn("Message already sent") conn.log.WithField("messageID", messageID).Warn("Message already sent")
// Query the server-side message. // Query the server-side message.
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil { if err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err) return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
} }
@ -354,7 +354,7 @@ func (conn *imapConnector) CreateMessage(
} }
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) { func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -572,7 +572,7 @@ func (conn *imapConnector) importMessage(
var err error var err error
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()); err != nil { if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
return fmt.Errorf("failed to fetch message: %w", err) return fmt.Errorf("failed to fetch message: %w", err)
} }

View File

@ -48,6 +48,8 @@ import (
// sendMail sends an email from the given address to the given recipients. // sendMail sends an email from the given address to the given recipients.
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error { func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
defer user.handlePanic()
return safe.RLockRet(func() error { return safe.RLockRet(func() error {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -143,7 +145,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
} }
// Send the message using the correct key. // Send the message using the correct key.
sent, err := sendWithKey( sent, err := user.sendWithKey(
ctx, ctx,
user.client, user.client,
user.reporter, user.reporter,
@ -167,7 +169,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
} }
// sendWithKey sends the message with the given address key. // sendWithKey sends the message with the given address key.
func sendWithKey( func (user *User) sendWithKey(
ctx context.Context, ctx context.Context,
client *proton.Client, client *proton.Client,
sentry reporter.Reporter, sentry reporter.Reporter,
@ -226,12 +228,12 @@ func sendWithKey(
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err) return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
} }
attKeys, err := createAttachments(ctx, client, addrKR, draft.ID, message.Attachments) attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
if err != nil { if err != nil {
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err) return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
} }
recipients, err := getRecipients(ctx, client, userKR, settings, draft) recipients, err := user.getRecipients(ctx, client, userKR, settings, draft)
if err != nil { if err != nil {
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err) return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
} }
@ -377,7 +379,7 @@ func createDraft(
}) })
} }
func createAttachments( func (user *User) createAttachments(
ctx context.Context, ctx context.Context,
client *proton.Client, client *proton.Client,
addrKR *crypto.KeyRing, addrKR *crypto.KeyRing,
@ -390,6 +392,8 @@ func createAttachments(
} }
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) { keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
defer user.handlePanic()
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"name": logging.Sensitive(att.Name), "name": logging.Sensitive(att.Name),
"contentID": att.ContentID, "contentID": att.ContentID,
@ -455,7 +459,7 @@ func createAttachments(
return attKeys, nil return attKeys, nil
} }
func getRecipients( func (user *User) getRecipients(
ctx context.Context, ctx context.Context,
client *proton.Client, client *proton.Client,
userKR *crypto.KeyRing, userKR *crypto.KeyRing,
@ -467,6 +471,8 @@ func getRecipients(
}) })
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) { prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
defer user.handlePanic()
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient) pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
if err != nil { if err != nil {
return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err) return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err)

View File

@ -153,7 +153,7 @@ func (user *User) sync(ctx context.Context) error {
} }
// Sync the messages. // Sync the messages.
if err := syncMessages( if err := user.syncMessages(
ctx, ctx,
user.ID(), user.ID(),
messageIDs, messageIDs,
@ -242,7 +242,7 @@ func toMB(v uint64) float64 {
} }
// nolint:gocyclo // nolint:gocyclo
func syncMessages( func (user *User) syncMessages(
ctx context.Context, ctx context.Context,
userID string, userID string,
messageIDs []string, messageIDs []string,
@ -370,7 +370,7 @@ func syncMessages(
errorCh := make(chan error, maxParallelDownloads*4) errorCh := make(chan error, maxParallelDownloads*4)
// Go routine in charge of downloading message metadata // Go routine in charge of downloading message metadata
logging.GoAnnotated(ctx, func(ctx context.Context) { logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(downloadCh) defer close(downloadCh)
const MetadataDataPageSize = 150 const MetadataDataPageSize = 150
@ -433,14 +433,14 @@ func syncMessages(
}, logging.Labels{"sync-stage": "meta-data"}) }, logging.Labels{"sync-stage": "meta-data"})
// Goroutine in charge of downloading and building messages in maxBatchSize batches. // Goroutine in charge of downloading and building messages in maxBatchSize batches.
logging.GoAnnotated(ctx, func(ctx context.Context) { logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(buildCh) defer close(buildCh)
defer close(errorCh) defer close(errorCh)
defer func() { defer func() {
logrus.Debugf("sync downloader exit") logrus.Debugf("sync downloader exit")
}() }()
attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads) attachmentDownloader := user.newAttachmentDownloader(ctx, client, maxParallelDownloads)
defer attachmentDownloader.close() defer attachmentDownloader.close()
for request := range downloadCh { for request := range downloadCh {
@ -456,6 +456,8 @@ func syncMessages(
} }
result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) { result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
defer user.handlePanic()
var result proton.FullMessage var result proton.FullMessage
msg, err := client.GetMessage(ctx, id) msg, err := client.GetMessage(ctx, id)
@ -490,7 +492,7 @@ func syncMessages(
}, logging.Labels{"sync-stage": "download"}) }, logging.Labels{"sync-stage": "download"})
// Goroutine which builds messages after they have been downloaded // Goroutine which builds messages after they have been downloaded
logging.GoAnnotated(ctx, func(ctx context.Context) { logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(flushCh) defer close(flushCh)
defer func() { defer func() {
logrus.Debugf("sync builder exit") logrus.Debugf("sync builder exit")
@ -509,6 +511,8 @@ func syncMessages(
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk)) logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) { result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
defer user.handlePanic()
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
}) })
if err != nil { if err != nil {
@ -526,7 +530,7 @@ func syncMessages(
}, logging.Labels{"sync-stage": "builder"}) }, logging.Labels{"sync-stage": "builder"})
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking. // Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
logging.GoAnnotated(ctx, func(ctx context.Context) { logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(flushUpdateCh) defer close(flushUpdateCh)
defer func() { defer func() {
logrus.Debugf("sync flush exit") logrus.Debugf("sync flush exit")
@ -771,12 +775,12 @@ func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan at
} }
} }
func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader { func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount) workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
for i := 0; i < workerCount; i++ { for i := 0; i < workerCount; i++ {
workerCh = make(chan attachmentJob) workerCh = make(chan attachmentJob)
logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{ logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
"sync": fmt.Sprintf("att-downloader %v", i), "sync": fmt.Sprintf("att-downloader %v", i),
}) })
} }

View File

@ -24,6 +24,7 @@ import (
"strings" "strings"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@ -93,6 +94,6 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
return sorted return sorted
} }
func newProtonAPIScheduler() proton.Scheduler { func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
return proton.NewParallelScheduler(runtime.NumCPU() / 2) return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
} }

View File

@ -91,6 +91,8 @@ type User struct {
showAllMail uint32 showAllMail uint32
maxSyncMemory uint64 maxSyncMemory uint64
panicHandler async.PanicHandler
} }
// New returns a new user. // New returns a new user.
@ -127,7 +129,7 @@ func New(
reporter: reporter, reporter: reporter,
sendHash: newSendRecorder(sendEntryExpiry), sendHash: newSendRecorder(sendEntryExpiry),
eventCh: queue.NewQueuedChannel[events.Event](0, 0), eventCh: queue.NewQueuedChannel[events.Event](0, 0, crashHandler),
eventLock: safe.NewRWMutex(), eventLock: safe.NewRWMutex(),
apiUser: apiUser, apiUser: apiUser,
@ -148,6 +150,8 @@ func New(
showAllMail: b32(showAllMail), showAllMail: b32(showAllMail),
maxSyncMemory: maxSyncMemory, maxSyncMemory: maxSyncMemory,
panicHandler: crashHandler,
} }
// Initialize the user's update channels for its current address mode. // Initialize the user's update channels for its current address mode.
@ -179,7 +183,10 @@ func New(
user.goPollAPIEvents = func(wait bool) { user.goPollAPIEvents = func(wait bool) {
doneCh := make(chan struct{}) doneCh := make(chan struct{})
go func() { user.pollAPIEventsCh <- doneCh }() go func() {
defer user.handlePanic()
user.pollAPIEventsCh <- doneCh
}()
if wait { if wait {
<-doneCh <-doneCh
@ -230,6 +237,12 @@ func New(
return user, nil return user, nil
} }
func (user *User) handlePanic() {
if user.panicHandler != nil {
user.panicHandler.HandlePanic()
}
}
func (user *User) TriggerSync() { func (user *User) TriggerSync() {
user.goSync() user.goSync()
} }
@ -596,7 +609,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
switch mode { switch mode {
case vault.CombinedMode: case vault.CombinedMode:
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
for addrID := range user.apiAddrs { for addrID := range user.apiAddrs {
user.updateCh[addrID] = primaryUpdateCh user.updateCh[addrID] = primaryUpdateCh
@ -604,7 +617,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
case vault.SplitMode: case vault.SplitMode:
for addrID := range user.apiAddrs { for addrID := range user.apiAddrs {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
} }
} }
} }
@ -614,7 +627,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
// When we receive an API event, we attempt to handle it. // When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault. // If successful, we update the event ID in the vault.
func (user *User) startEvents(ctx context.Context) { func (user *User) startEvents(ctx context.Context) {
ticker := proton.NewTicker(EventPeriod, EventJitter) ticker := proton.NewTicker(EventPeriod, EventJitter, user.panicHandler)
defer ticker.Stop() defer ticker.Stop()
for { for {

View File

@ -119,7 +119,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID) saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID)
require.NoError(tb, err) require.NoError(tb, err)
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key")) v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil)
require.NoError(tb, err) require.NoError(tb, err)
require.False(tb, corrupt) require.False(tb, corrupt)

View File

@ -25,6 +25,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
@ -52,7 +53,7 @@ func TestMigrate(t *testing.T) {
require.NoError(t, os.WriteFile(filepath.Join(dir, "vault.enc"), b, 0o600)) require.NoError(t, os.WriteFile(filepath.Join(dir, "vault.enc"), b, 0o600))
// Migrate the vault. // Migrate the vault.
s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key")) s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)

View File

@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -63,7 +64,7 @@ func TestVault_Settings_SMTP(t *testing.T) {
func TestVault_Settings_GluonDir(t *testing.T) { func TestVault_Settings_GluonDir(t *testing.T) {
// create a new test vault. // create a new test vault.
s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key")) s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)

View File

@ -29,6 +29,7 @@ import (
"path/filepath" "path/filepath"
"sync" "sync"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -44,10 +45,12 @@ type Vault struct {
ref map[string]int ref map[string]int
refLock sync.Mutex refLock sync.Mutex
panicHandler async.PanicHandler
} }
// New constructs a new encrypted data vault at the given filepath using the given encryption key. // New constructs a new encrypted data vault at the given filepath using the given encryption key.
func New(vaultDir, gluonCacheDir string, key []byte) (*Vault, bool, error) { func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHandler) (*Vault, bool, error) {
if err := os.MkdirAll(vaultDir, 0o700); err != nil { if err := os.MkdirAll(vaultDir, 0o700); err != nil {
return nil, false, err return nil, false, err
} }
@ -69,9 +72,17 @@ func New(vaultDir, gluonCacheDir string, key []byte) (*Vault, bool, error) {
return nil, false, err return nil, false, err
} }
vault.panicHandler = panicHandler
return vault, corrupt, nil return vault, corrupt, nil
} }
func (vault *Vault) handlePanic() {
if vault.panicHandler != nil {
vault.panicHandler.HandlePanic()
}
}
// GetUserIDs returns the user IDs and usernames of all users in the vault. // GetUserIDs returns the user IDs and usernames of all users in the vault.
func (vault *Vault) GetUserIDs() []string { func (vault *Vault) GetUserIDs() []string {
return xslices.Map(vault.get().Users, func(user UserData) string { return xslices.Map(vault.get().Users, func(user UserData) string {
@ -115,6 +126,8 @@ func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
userIDs := vault.GetUserIDs() userIDs := vault.GetUserIDs()
return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error { return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error {
defer vault.handlePanic()
user, err := vault.NewUser(userIDs[idx]) user, err := vault.NewUser(userIDs[idx])
if err != nil { if err != nil {
return err return err

View File

@ -22,6 +22,7 @@ import (
"runtime" "runtime"
"testing" "testing"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -31,7 +32,7 @@ func BenchmarkVault(b *testing.B) {
vaultDir, gluonDir := b.TempDir(), b.TempDir() vaultDir, gluonDir := b.TempDir(), b.TempDir()
// Create a new vault. // Create a new vault.
s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(b, err) require.NoError(b, err)
require.False(b, corrupt) require.False(b, corrupt)

View File

@ -22,6 +22,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -30,19 +31,19 @@ func TestVault_Corrupt(t *testing.T) {
vaultDir, gluonDir := t.TempDir(), t.TempDir() vaultDir, gluonDir := t.TempDir(), t.TempDir()
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)
} }
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)
} }
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key")) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.True(t, corrupt) require.True(t, corrupt)
} }
@ -52,13 +53,13 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
vaultDir, gluonDir := t.TempDir(), t.TempDir() vaultDir, gluonDir := t.TempDir(), t.TempDir()
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)
} }
{ {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)
} }
@ -71,7 +72,7 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
_, err = f.Write([]byte("junk data")) _, err = f.Write([]byte("junk data"))
require.NoError(t, err) require.NoError(t, err)
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.True(t, corrupt) require.True(t, corrupt)
} }
@ -99,7 +100,7 @@ func TestVault_Reset(t *testing.T) {
func newVault(t *testing.T) *vault.Vault { func newVault(t *testing.T) *vault.Vault {
t.Helper() t.Helper()
s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{})
require.NoError(t, err) require.NoError(t, err)
require.False(t, corrupt) require.False(t, corrupt)

View File

@ -44,7 +44,7 @@ func (c *eventCollector) collectFrom(eventCh <-chan events.Event) <-chan events.
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
fwdCh := queue.NewQueuedChannel[events.Event](0, 0) fwdCh := queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{})
c.fwdCh = append(c.fwdCh, fwdCh) c.fwdCh = append(c.fwdCh, fwdCh)
@ -87,7 +87,7 @@ func (c *eventCollector) push(event events.Event) {
defer c.lock.Unlock() defer c.lock.Unlock()
if _, ok := c.events[reflect.TypeOf(event)]; !ok { if _, ok := c.events[reflect.TypeOf(event)]; !ok {
c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0) c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{})
} }
c.events[reflect.TypeOf(event)].Enqueue(event) c.events[reflect.TypeOf(event)].Enqueue(event)
@ -102,7 +102,7 @@ func (c *eventCollector) getEventCh(ofType events.Event) <-chan events.Event {
defer c.lock.Unlock() defer c.lock.Unlock()
if _, ok := c.events[reflect.TypeOf(ofType)]; !ok { if _, ok := c.events[reflect.TypeOf(ofType)]; !ok {
c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0) c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{})
} }
return c.events[reflect.TypeOf(ofType)].GetChannel() return c.events[reflect.TypeOf(ofType)].GetChannel()

View File

@ -108,7 +108,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
} }
// Create the vault. // Create the vault.
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey) vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey, queue.NoopPanicHandler{})
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create vault: %w", err) return nil, fmt.Errorf("could not create vault: %w", err)
} else if corrupt { } else if corrupt {
@ -301,7 +301,7 @@ func (t *testCtx) initFrontendClient() error {
return fmt.Errorf("could not start event stream: %w", err) return fmt.Errorf("could not start event stream: %w", err)
} }
eventCh := queue.NewQueuedChannel[*frontend.StreamEvent](0, 0) eventCh := queue.NewQueuedChannel[*frontend.StreamEvent](0, 0, queue.NoopPanicHandler{})
go func() { go func() {
defer eventCh.CloseAndDiscardQueued() defer eventCh.CloseAndDiscardQueued()

View File

@ -23,6 +23,7 @@ import (
"os" "os"
"runtime" "runtime"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/stream"
@ -113,7 +114,7 @@ func (t *testCtx) withAddrKR(
return err return err
} }
_, addrKRs, err := proton.Unlock(user, addr, keyPass) _, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{})
if err != nil { if err != nil {
return err return err
} }