diff --git a/go.mod b/go.mod index 8e93ee44..aa13f862 100644 --- a/go.mod +++ b/go.mod @@ -125,6 +125,8 @@ require ( ) replace ( + github.com/ProtonMail/gluon => /home/dev/gopath18/src/gluon + github.com/ProtonMail/go-proton-api => /home/dev/gopath18/src/go-proton-api github.com/docker/docker-credential-helpers => github.com/ProtonMail/docker-credential-helpers v1.1.0 github.com/emersion/go-message => github.com/ProtonMail/go-message v0.0.0-20210611055058-fabeff2ec753 github.com/keybase/go-keychain => github.com/cuthix/go-keychain v0.0.0-20220405075754-31e7cee908fe diff --git a/internal/app/app.go b/internal/app/app.go index 36240e74..4cdbd187 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -185,14 +185,14 @@ func run(c *cli.Context) error { exe = os.Args[0] } - migrationErr := migrateOldVersions() + // Restart the app if requested. + return withRestarter(exe, func(restarter *restarter.Restarter) error { + // Handle crashes with various actions. + return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler, quitCh <-chan struct{}) error { + migrationErr := migrateOldVersions() - // Run with profiling if requested. - return withProfiler(c, func() error { - // Restart the app if requested. - return withRestarter(exe, func(restarter *restarter.Restarter) error { - // Handle crashes with various actions. - return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler, quitCh <-chan struct{}) error { + // Run with profiling if requested. + return withProfiler(c, func() error { // Load the locations where we store our files. return WithLocations(func(locations *locations.Locations) error { // Migrate the keychain helper. @@ -215,7 +215,7 @@ func run(c *cli.Context) error { return withSingleInstance(settings, locations.GetLockFile(), version, func() error { // Unlock the encrypted vault. - return WithVault(locations, func(v *vault.Vault, insecure, corrupt bool) error { + return WithVault(locations, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error { // Report insecure vault. if insecure { _ = reporter.ReportMessageWithContext("Vault is insecure", map[string]interface{}{}) diff --git a/internal/app/bridge.go b/internal/app/bridge.go index 6705227a..280d444f 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -78,7 +78,7 @@ func withBridge( ) // Create a proxy dialer which switches to a proxy if the request fails. - proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost) + proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost, crashHandler) // Create the autostarter. autostarter := newAutostarter(exe) diff --git a/internal/app/frontend.go b/internal/app/frontend.go index 4c56d968..6e1bd8fa 100644 --- a/internal/app/frontend.go +++ b/internal/app/frontend.go @@ -46,7 +46,7 @@ func runFrontend( switch { case c.Bool(flagCLI): - return bridgeCLI.New(bridge, restarter, eventCh).Loop() + return bridgeCLI.New(bridge, restarter, eventCh, crashHandler).Loop() case c.Bool(flagNonInteractive): select {} diff --git a/internal/app/migration_test.go b/internal/app/migration_test.go index 02dff7e1..c994ed63 100644 --- a/internal/app/migration_test.go +++ b/internal/app/migration_test.go @@ -25,6 +25,7 @@ import ( "runtime" "testing" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/cookies" @@ -40,7 +41,7 @@ import ( func TestMigratePrefsToVaultWithKeys(t *testing.T) { // Create a new vault. - vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) + vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) @@ -61,7 +62,7 @@ func TestMigratePrefsToVaultWithKeys(t *testing.T) { func TestMigratePrefsToVaultWithoutKeys(t *testing.T) { // Create a new vault. - vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) + vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) @@ -173,7 +174,7 @@ func TestUserMigration(t *testing.T) { token, err := crypto.RandomToken(32) require.NoError(t, err) - v, corrupt, err := vault.New(settingsFolder, settingsFolder, token) + v, corrupt, err := vault.New(settingsFolder, settingsFolder, token, queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) diff --git a/internal/app/vault.go b/internal/app/vault.go index 55b583b1..013653a5 100644 --- a/internal/app/vault.go +++ b/internal/app/vault.go @@ -21,6 +21,7 @@ import ( "fmt" "path" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "github.com/ProtonMail/proton-bridge/v3/internal/certs" "github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/locations" @@ -29,12 +30,12 @@ import ( "github.com/sirupsen/logrus" ) -func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) error) error { +func WithVault(locations *locations.Locations, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error { logrus.Debug("Creating vault") defer logrus.Debug("Vault stopped") // Create the encVault. - encVault, insecure, corrupt, err := newVault(locations) + encVault, insecure, corrupt, err := newVault(locations, panicHandler) if err != nil { return fmt.Errorf("could not create vault: %w", err) } @@ -66,7 +67,7 @@ func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) return fn(encVault, insecure, corrupt) } -func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) { +func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) { vaultDir, err := locations.ProvideSettingsPath() if err != nil { return nil, false, false, fmt.Errorf("could not get vault dir: %w", err) @@ -93,7 +94,7 @@ func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err) } - vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey) + vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler) if err != nil { return nil, false, false, fmt.Errorf("could not create vault: %w", err) } diff --git a/internal/bridge/api.go b/internal/bridge/api.go index cbe2f2be..f7c3d41a 100644 --- a/internal/bridge/api.go +++ b/internal/bridge/api.go @@ -21,6 +21,7 @@ import ( "net/http" "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/sirupsen/logrus" @@ -32,6 +33,7 @@ func defaultAPIOptions( version *semver.Version, cookieJar http.CookieJar, transport http.RoundTripper, + panicHandler queue.PanicHandler, ) []proton.Option { return []proton.Option{ proton.WithHostURL(apiURL), @@ -39,5 +41,6 @@ func defaultAPIOptions( proton.WithCookieJar(cookieJar), proton.WithTransport(transport), proton.WithLogger(logrus.StandardLogger()), + proton.WithPanicHandler(panicHandler), } } diff --git a/internal/bridge/api_default.go b/internal/bridge/api_default.go index 6f800d77..cc89fe16 100644 --- a/internal/bridge/api_default.go +++ b/internal/bridge/api_default.go @@ -23,6 +23,7 @@ import ( "net/http" "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/go-proton-api" ) @@ -32,6 +33,7 @@ func newAPIOptions( version *semver.Version, cookieJar http.CookieJar, transport http.RoundTripper, + panicHandler queue.PanicHandler, ) []proton.Option { - return defaultAPIOptions(apiURL, version, cookieJar, transport) + return defaultAPIOptions(apiURL, version, cookieJar, transport, panicHandler) } diff --git a/internal/bridge/api_qa.go b/internal/bridge/api_qa.go index 618ce42e..a8c5a71a 100644 --- a/internal/bridge/api_qa.go +++ b/internal/bridge/api_qa.go @@ -24,6 +24,7 @@ import ( "os" "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/go-proton-api" ) @@ -33,8 +34,9 @@ func newAPIOptions( version *semver.Version, cookieJar http.CookieJar, transport http.RoundTripper, + panicHandler queue.PanicHandler, ) []proton.Option { - opt := defaultAPIOptions(apiURL, version, cookieJar, transport) + opt := defaultAPIOptions(apiURL, version, cookieJar, transport, panicHandler) if host := os.Getenv("BRIDGE_API_HOST"); host != "" { opt = append(opt, proton.WithHostURL(host)) diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 27cde2a4..0eba71fc 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -93,8 +93,8 @@ type Bridge struct { // locator is the bridge's locator. locator Locator - // crashHandler - crashHandler async.PanicHandler + // panicHandler + panicHandler async.PanicHandler // reporter reporter reporter.Reporter @@ -143,7 +143,7 @@ func New( tlsReporter TLSReporter, // the TLS reporter to report TLS errors roundTripper http.RoundTripper, // the round tripper to use for API requests proxyCtl ProxyController, // the DoH controller - crashHandler async.PanicHandler, + panicHandler async.PanicHandler, reporter reporter.Reporter, uidValidityGenerator imap.UIDValidityGenerator, @@ -151,10 +151,10 @@ func New( logSMTP bool, // whether to log SMTP activity ) (*Bridge, <-chan events.Event, error) { // api is the user's API manager. - api := proton.New(newAPIOptions(apiURL, curVersion, cookieJar, roundTripper)...) + api := proton.New(newAPIOptions(apiURL, curVersion, cookieJar, roundTripper, panicHandler)...) // tasks holds all the bridge's background tasks. - tasks := async.NewGroup(context.Background(), crashHandler) + tasks := async.NewGroup(context.Background(), panicHandler) // imapEventCh forwards IMAP events from gluon instances to the bridge for processing. imapEventCh := make(chan imapEvents.Event) @@ -169,7 +169,7 @@ func New( autostarter, updater, curVersion, - crashHandler, + panicHandler, reporter, api, @@ -202,7 +202,7 @@ func newBridge( autostarter Autostarter, updater Updater, curVersion *semver.Version, - crashHandler async.PanicHandler, + panicHandler async.PanicHandler, reporter reporter.Reporter, api *proton.Manager, @@ -248,12 +248,13 @@ func newBridge( imapEventCh, tasks, uidValidityGenerator, + panicHandler, ) if err != nil { return nil, fmt.Errorf("failed to create IMAP server: %w", err) } - focusService, err := focus.NewService(locator, curVersion) + focusService, err := focus.NewService(locator, curVersion, panicHandler) if err != nil { return nil, fmt.Errorf("failed to create focus service: %w", err) } @@ -279,7 +280,7 @@ func newBridge( newVersion: curVersion, newVersionLock: safe.NewRWMutex(), - crashHandler: crashHandler, + panicHandler: panicHandler, reporter: reporter, focusService: focusService, @@ -495,7 +496,7 @@ func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events bridge.watchersLock.Lock() defer bridge.watchersLock.Unlock() - watcher := watcher.New(ofType...) + watcher := watcher.New(bridge.panicHandler, ofType...) bridge.watchers = append(bridge.watchers, watcher) diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 992d241f..28778cd5 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -31,6 +31,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api/server" "github.com/ProtonMail/go-proton-api/server/backend" @@ -699,7 +700,7 @@ func withBridgeNoMocks( require.NoError(t, err) // Create the vault. - vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey) + vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey, queue.NoopPanicHandler{}) require.NoError(t, err) // Create a new cookie jar. diff --git a/internal/bridge/imap.go b/internal/bridge/imap.go index d6efff44..752155dd 100644 --- a/internal/bridge/imap.go +++ b/internal/bridge/imap.go @@ -299,6 +299,7 @@ func newIMAPServer( eventCh chan<- imapEvents.Event, tasks *async.Group, uidValidityGenerator imap.UIDValidityGenerator, + panicHandler async.PanicHandler, ) (*gluon.Server, error) { gluonCacheDir = ApplyGluonCachePathSuffix(gluonCacheDir) gluonConfigDir = ApplyGluonConfigPathSuffix(gluonConfigDir) @@ -343,6 +344,7 @@ func newIMAPServer( getGluonVersionInfo(version), gluon.WithReporter(reporter), gluon.WithUIDValidityGenerator(uidValidityGenerator), + gluon.WithPanicHandler(panicHandler), ) if err != nil { return nil, err diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index 0ab4c5a0..8c262cc9 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -157,6 +157,7 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error bridge.imapEventCh, bridge.tasks, bridge.uidValidityGenerator, + bridge.panicHandler, ) if err != nil { return fmt.Errorf("failed to create new IMAP server: %w", err) diff --git a/internal/bridge/sync_test.go b/internal/bridge/sync_test.go index 182986c2..e0ffe1c8 100644 --- a/internal/bridge/sync_test.go +++ b/internal/bridge/sync_test.go @@ -28,6 +28,7 @@ import ( "testing" "time" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api/server" @@ -428,7 +429,7 @@ func createMessages(ctx context.Context, t *testing.T, c *proton.Client, addrID, keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID) require.NoError(t, err) - _, addrKRs, err := proton.Unlock(user, addr, keyPass) + _, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{}) require.NoError(t, err) _, ok := addrKRs[addrID] diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 59f8f00c..4bdb6152 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -516,7 +516,7 @@ func (bridge *Bridge) addUserWithVault( client, bridge.reporter, apiUser, - bridge.crashHandler, + bridge.panicHandler, bridge.vault.GetShowAllMail(), bridge.vault.GetMaxSyncMemory(), ) diff --git a/internal/bridge/user_event_test.go b/internal/bridge/user_event_test.go index d48094c7..616f492c 100644 --- a/internal/bridge/user_event_test.go +++ b/internal/bridge/user_event_test.go @@ -28,6 +28,7 @@ import ( "testing" "time" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api/server" @@ -474,7 +475,7 @@ func TestBridge_User_UpdateDraftAndCreateOtherMessage(t *testing.T) { keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID) require.NoError(t, err) - _, addrKRs, err := proton.Unlock(user, addrs, keyPass) + _, addrKRs, err := proton.Unlock(user, addrs, keyPass, queue.NoopPanicHandler{}) require.NoError(t, err) // Create a draft (generating a "create draft message" event). @@ -556,7 +557,7 @@ func TestBridge_User_SendDraftRemoveDraftFlag(t *testing.T) { keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID) require.NoError(t, err) - _, addrKRs, err := proton.Unlock(user, addrs, keyPass) + _, addrKRs, err := proton.Unlock(user, addrs, keyPass, queue.NoopPanicHandler{}) require.NoError(t, err) // Create a draft (generating a "create draft message" event). diff --git a/internal/clientconfig/applemail.go b/internal/clientconfig/applemail.go index 7eb56922..1a16fdb5 100644 --- a/internal/clientconfig/applemail.go +++ b/internal/clientconfig/applemail.go @@ -98,6 +98,8 @@ func saveConfigTemporarily(mc *mobileconfig.Config) (fname string, err error) { // Make sure the temporary file is deleted. go func() { + defer recover() //nolint:errcheck + <-time.After(10 * time.Minute) _ = os.RemoveAll(dir) }() diff --git a/internal/dialer/dialer_proxy.go b/internal/dialer/dialer_proxy.go index 0ad1474e..0ce2945e 100644 --- a/internal/dialer/dialer_proxy.go +++ b/internal/dialer/dialer_proxy.go @@ -24,6 +24,7 @@ import ( "sync" "time" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -40,17 +41,20 @@ type ProxyTLSDialer struct { allowProxy bool proxyProvider *proxyProvider proxyUseDuration time.Duration + + panicHandler async.PanicHandler } // NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer. -func NewProxyTLSDialer(dialer TLSDialer, hostURL string) *ProxyTLSDialer { +func NewProxyTLSDialer(dialer TLSDialer, hostURL string, panicHandler async.PanicHandler) *ProxyTLSDialer { return &ProxyTLSDialer{ dialer: dialer, locker: sync.RWMutex{}, directAddress: formatAsAddress(hostURL), proxyAddress: formatAsAddress(hostURL), - proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders), + proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders, panicHandler), proxyUseDuration: proxyUseDuration, + panicHandler: panicHandler, } } @@ -75,6 +79,12 @@ func formatAsAddress(rawURL string) string { return net.JoinHostPort(host, port) } +func (d *ProxyTLSDialer) handlePanic() { + if d.panicHandler != nil { + d.panicHandler.HandlePanic() + } +} + // DialTLSContext dials the given network/address. If it fails, it retries using a proxy. func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { d.locker.RLock() @@ -129,6 +139,8 @@ func (d *ProxyTLSDialer) switchToReachableServer() error { // This means we want to disable it again in 24 hours. if d.proxyAddress == d.directAddress { go func() { + defer d.handlePanic() + <-time.After(d.proxyUseDuration) d.locker.Lock() diff --git a/internal/dialer/dialer_proxy_provider.go b/internal/dialer/dialer_proxy_provider.go index b9cfd2fd..e4006891 100644 --- a/internal/dialer/dialer_proxy_provider.go +++ b/internal/dialer/dialer_proxy_provider.go @@ -24,6 +24,7 @@ import ( "sync" "time" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "github.com/go-resty/resty/v2" "github.com/miekg/dns" "github.com/pkg/errors" @@ -67,11 +68,13 @@ type proxyProvider struct { canReachTimeout time.Duration lastLookup time.Time // The time at which we last attempted to find a proxy. + + panicHandler async.PanicHandler } // newProxyProvider creates a new proxyProvider that queries the given DoH providers // to retrieve DNS records for the given query string. -func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *proxyProvider) { +func newProxyProvider(dialer TLSDialer, hostURL string, providers []string, panicHandler async.PanicHandler) (p *proxyProvider) { p = &proxyProvider{ dialer: dialer, hostURL: hostURL, @@ -80,6 +83,7 @@ func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p * cacheRefreshTimeout: proxyCacheRefreshTimeout, dohTimeout: proxyDoHTimeout, canReachTimeout: proxyCanReachTimeout, + panicHandler: panicHandler, } // Use the default DNS lookup method; this can be overridden if necessary. @@ -88,6 +92,12 @@ func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p * return } +func (p *proxyProvider) handlePanic() { + if p.panicHandler != nil { + p.panicHandler.HandlePanic() + } +} + // findReachableServer returns a working API server (either proxy or standard API). func (p *proxyProvider) findReachableServer() (proxy string, err error) { logrus.Debug("Trying to find a reachable server") @@ -109,11 +119,13 @@ func (p *proxyProvider) findReachableServer() (proxy string, err error) { wg.Add(2) go func() { + defer p.handlePanic() defer wg.Done() apiReachable = p.canReach(p.hostURL) }() go func() { + defer p.handlePanic() defer wg.Done() err = p.refreshProxyCache() }() @@ -150,6 +162,8 @@ func (p *proxyProvider) refreshProxyCache() error { resultChan := make(chan []string) go func() { + defer p.handlePanic() + for _, provider := range p.providers { if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil { resultChan <- proxies @@ -203,6 +217,7 @@ func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider dataChan, errChan := make(chan []string), make(chan error) go func() { + defer p.handlePanic() // Build new DNS request in RFC1035 format. dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT) diff --git a/internal/dialer/dialer_proxy_provider_test.go b/internal/dialer/dialer_proxy_provider_test.go index f322d462..ddd7aed2 100644 --- a/internal/dialer/dialer_proxy_provider_test.go +++ b/internal/dialer/dialer_proxy_provider_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/proton-bridge/v3/internal/useragent" r "github.com/stretchr/testify/require" ) @@ -31,7 +32,7 @@ func TestProxyProvider_FindProxy(t *testing.T) { proxy := getTrustedServer() defer closeServer(proxy) - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{}) p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil } url, err := p.findReachableServer() @@ -47,7 +48,7 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) { unreachableProxy := getTrustedServer() closeServer(unreachableProxy) - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{}) p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{reachableProxy.URL, unreachableProxy.URL}, nil } @@ -68,7 +69,7 @@ func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) { checker := NewTLSPinChecker(TrustedAPIPins) dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker) - p := newProxyProvider(dialer, "", []string{"not used"}) + p := newProxyProvider(dialer, "", []string{"not used"}, queue.NoopPanicHandler{}) p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{untrustedProxy.URL, trustedProxy.URL}, nil } @@ -85,7 +86,7 @@ func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) { unreachableProxy2 := getTrustedServer() closeServer(unreachableProxy2) - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{}) p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil } @@ -105,7 +106,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) { checker := NewTLSPinChecker(TrustedAPIPins) dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker) - p := newProxyProvider(dialer, "", []string{"not used"}) + p := newProxyProvider(dialer, "", []string{"not used"}, queue.NoopPanicHandler{}) p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil } @@ -115,7 +116,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) { } func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) { - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{}) p.cacheRefreshTimeout = 1 * time.Second p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } @@ -132,7 +133,7 @@ func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) { })) defer closeServer(slowProxy) - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{}) p.canReachTimeout = 1 * time.Second p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } @@ -144,7 +145,7 @@ func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) { } func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{}) records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider) r.NoError(t, err) @@ -155,7 +156,7 @@ func TestProxyProvider_DoHLookup_Quad9(t *testing.T) { // port filter. Basic functionality should be covered by other tests. Keeping // code here to be able to run it locally if needed. func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) { - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{}) records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider) r.NoError(t, err) @@ -163,7 +164,7 @@ func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) { } func TestProxyProvider_DoHLookup_Google(t *testing.T) { - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{}) records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider) r.NoError(t, err) @@ -173,7 +174,7 @@ func TestProxyProvider_DoHLookup_Google(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { skipIfProxyIsSet(t) - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{}) url, err := p.findReachableServer() r.NoError(t, err) @@ -183,7 +184,7 @@ func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { skipIfProxyIsSet(t) - p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider}) + p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{}) url, err := p.findReachableServer() r.NoError(t, err) diff --git a/internal/dialer/dialer_proxy_test.go b/internal/dialer/dialer_proxy_test.go index 28290aeb..a78b3055 100644 --- a/internal/dialer/dialer_proxy_test.go +++ b/internal/dialer/dialer_proxy_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/ProtonMail/gluon/queue" "github.com/stretchr/testify/require" ) @@ -141,8 +142,8 @@ func TestProxyDialer_UseProxy(t *testing.T) { trustedProxy := getTrustedServer() defer closeServer(trustedProxy) - provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) - d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") + provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{}) + d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{}) d.proxyProvider = provider provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } @@ -159,8 +160,8 @@ func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) { proxy3 := getTrustedServer() defer closeServer(proxy3) - provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) - d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") + provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{}) + d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{}) d.proxyProvider = provider provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil } @@ -189,8 +190,8 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) { trustedProxy := getTrustedServer() defer closeServer(trustedProxy) - provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) - d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") + provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{}) + d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{}) d.proxyProvider = provider d.proxyUseDuration = time.Second @@ -212,8 +213,8 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) { func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) { trustedProxy := getTrustedServer() - provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) - d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") + provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{}) + d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{}) d.proxyProvider = provider provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil } @@ -242,8 +243,8 @@ func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBloc proxy2 := getTrustedServer() defer closeServer(proxy2) - provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders) - d := NewProxyTLSDialer(NewBasicTLSDialer(""), "") + provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{}) + d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{}) d.proxyProvider = provider provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } diff --git a/internal/focus/focus_test.go b/internal/focus/focus_test.go index 059d243d..101ed389 100644 --- a/internal/focus/focus_test.go +++ b/internal/focus/focus_test.go @@ -30,7 +30,7 @@ func TestFocus_Raise(t *testing.T) { tmpDir := t.TempDir() locations := locations.New(newTestLocationsProvider(tmpDir), "config-name") // Start the focus service. - service, err := NewService(locations, semver.MustParse("1.2.3")) + service, err := NewService(locations, semver.MustParse("1.2.3"), nil) require.NoError(t, err) settingsFolder, err := locations.ProvideSettingsPath() @@ -52,7 +52,7 @@ func TestFocus_Version(t *testing.T) { tmpDir := t.TempDir() locations := locations.New(newTestLocationsProvider(tmpDir), "config-name") // Start the focus service. - _, err := NewService(locations, semver.MustParse("1.2.3")) + _, err := NewService(locations, semver.MustParse("1.2.3"), nil) require.NoError(t, err) settingsFolder, err := locations.ProvideSettingsPath() diff --git a/internal/focus/service.go b/internal/focus/service.go index ab2d0127..ad0395d1 100644 --- a/internal/focus/service.go +++ b/internal/focus/service.go @@ -24,6 +24,7 @@ import ( "net" "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "github.com/ProtonMail/proton-bridge/v3/internal/focus/proto" "github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/sirupsen/logrus" @@ -43,15 +44,18 @@ type Service struct { server *grpc.Server raiseCh chan struct{} version *semver.Version + + panicHandler async.PanicHandler } // NewService creates a new focus service. // It listens on the local host and port 1042 (by default). -func NewService(locator service.Locator, version *semver.Version) (*Service, error) { +func NewService(locator service.Locator, version *semver.Version, panicHandler async.PanicHandler) (*Service, error) { serv := &Service{ - server: grpc.NewServer(), - raiseCh: make(chan struct{}, 1), - version: version, + server: grpc.NewServer(), + raiseCh: make(chan struct{}, 1), + version: version, + panicHandler: panicHandler, } proto.RegisterFocusServer(serv.server, serv) @@ -73,6 +77,8 @@ func NewService(locator service.Locator, version *semver.Version) (*Service, err } go func() { + defer serv.handlePanic() + if err := serv.server.Serve(listener); err != nil { fmt.Printf("failed to serve: %v", err) } @@ -82,6 +88,12 @@ func NewService(locator service.Locator, version *semver.Version) (*Service, err return serv, nil } +func (service *Service) handlePanic() { + if service.panicHandler != nil { + service.panicHandler.HandlePanic() + } +} + // Raise implements the gRPC FocusService interface; it raises the application. func (service *Service) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) { service.raiseCh <- struct{}{} @@ -103,6 +115,8 @@ func (service *Service) GetRaiseCh() <-chan struct{} { // Close closes the service. func (service *Service) Close() { go func() { + defer service.handlePanic() + // we do this in a goroutine, as on Windows, the gRPC shutdown may take minutes if something tries to // interact with it in an invalid way (e.g. HTTP GET request from a Qt QNetworkManager instance). service.server.Stop() diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index 4063f416..94b5414a 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -21,6 +21,7 @@ package cli import ( "errors" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "github.com/ProtonMail/proton-bridge/v3/internal/bridge" "github.com/ProtonMail/proton-bridge/v3/internal/constants" "github.com/ProtonMail/proton-bridge/v3/internal/events" @@ -39,15 +40,18 @@ type frontendCLI struct { restarter *restarter.Restarter badUserID string + + panicHandler async.PanicHandler } // New returns a new CLI frontend configured with the given options. -func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan events.Event) *frontendCLI { //nolint:revive +func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan events.Event, panicHandler async.PanicHandler) *frontendCLI { //nolint:revive fe := &frontendCLI{ - Shell: ishell.New(), - bridge: bridge, - restarter: restarter, - badUserID: "", + Shell: ishell.New(), + bridge: bridge, + restarter: restarter, + badUserID: "", + panicHandler: panicHandler, } // Clear commands. @@ -285,6 +289,8 @@ func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan e } func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyclo + defer f.handlePanic() + // GODT-1949: Better error events. for _, err := range f.bridge.GetErrors() { switch { @@ -445,6 +451,12 @@ func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyc */ } +func (f *frontendCLI) handlePanic() { + if f.panicHandler != nil { + f.panicHandler.HandlePanic() + } +} + // Loop starts the frontend loop with an interactive shell. func (f *frontendCLI) Loop() error { f.Printf(` diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index 89c6b585..4f308ef3 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -191,6 +191,12 @@ func NewService( return s, nil } +func (s *Service) handlePanic() { + if s.panicHandler != nil { + s.panicHandler.HandlePanic() + } +} + func (s *Service) initAutostart() { s.firstTimeAutostart.Do(func() { shouldAutostartBeOn := s.bridge.GetAutostart() @@ -207,11 +213,14 @@ func (s *Service) Loop() error { if s.parentPID < 0 { s.log.Info("Not monitoring parent PID") } else { - go s.monitorParentPID() + go func() { + defer s.handlePanic() + s.monitorParentPID() + }() } go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() s.watchEvents() }() @@ -221,6 +230,8 @@ func (s *Service) Loop() error { defer close(doneCh) go func() { + defer s.handlePanic() + select { case <-s.quitCh: s.log.Info("Stopping gRPC server") @@ -564,6 +575,8 @@ func (s *Service) monitorParentPID() { s.log.Info("Parent process does not exist anymore. Initiating shutdown") // quit will write to the parentPIDDoneCh, so we launch a goroutine. go func() { + defer s.handlePanic() + if err := s.quit(); err != nil { logrus.WithError(err).Error("Error on quit") } diff --git a/internal/frontend/grpc/service_methods.go b/internal/frontend/grpc/service_methods.go index 251a2d84..be5d7276 100644 --- a/internal/frontend/grpc/service_methods.go +++ b/internal/frontend/grpc/service_methods.go @@ -114,6 +114,8 @@ func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empt func (s *Service) quit() error { // Windows is notably slow at Quitting. We do it in a goroutine to speed things up a bit. go func() { + defer s.handlePanic() + if s.parentPID >= 0 { s.parentPIDDoneCh <- struct{}{} } @@ -221,7 +223,8 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb. s.log.Debug("TriggerReset") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() + s.triggerReset() }() return &emptypb.Empty{}, nil @@ -316,6 +319,8 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp }).Debug("ReportBug") go func() { + defer s.handlePanic() + defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }() if err := s.bridge.ReportBug( @@ -343,7 +348,7 @@ func (s *Service) ExportTLSCertificates(_ context.Context, folderPath *wrappersp s.log.WithField("folderPath", folderPath).Info("ExportTLSCertificates") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() cert, key := s.bridge.GetBridgeTLSCert() @@ -379,7 +384,7 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt s.log.WithField("username", login.Username).Debug("Login") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() password, err := base64Decode(login.Password) if err != nil { @@ -435,7 +440,7 @@ func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.E s.log.WithField("username", login.Username).Debug("Login2FA") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() if s.auth.UID == "" || s.authClient == nil { s.log.Errorf("Login 2FA: authethication incomplete %s %p", s.auth.UID, s.authClient) @@ -480,7 +485,7 @@ func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*em s.log.WithField("username", login.Username).Debug("Login2Passwords") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() password, err := base64Decode(login.Password) if err != nil { @@ -502,7 +507,7 @@ func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) s.log.WithField("username", loginAbort.Username).Debug("LoginAbort") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() s.loginAbort() }() @@ -514,7 +519,7 @@ func (s *Service) CheckUpdate(context.Context, *emptypb.Empty) (*emptypb.Empty, s.log.Debug("CheckUpdate") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() updateCh, done := s.bridge.GetEvents( events.UpdateAvailable{}, @@ -546,7 +551,7 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb s.log.Debug("InstallUpdate") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() safe.RLock(func() { s.bridge.InstallUpdate(s.target) @@ -587,6 +592,8 @@ func (s *Service) SetDiskCachePath(ctx context.Context, newPath *wrapperspb.Stri s.log.WithField("path", newPath.Value).Debug("setDiskCachePath") go func() { + defer s.handlePanic() + defer func() { _ = s.SendEvent(NewDiskCachePathChangeFinishedEvent()) }() @@ -652,7 +659,7 @@ func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSet Debug("SetConnectionMode") go func() { - defer s.panicHandler.HandlePanic() + defer s.handlePanic() defer func() { _ = s.SendEvent(NewChangeMailServerSettingsFinishedEvent()) }() diff --git a/internal/try/try.go b/internal/try/try.go index 1485c683..c134d5ab 100644 --- a/internal/try/try.go +++ b/internal/try/try.go @@ -21,6 +21,7 @@ import ( "fmt" "sync" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "github.com/bradenaw/juniper/xerrors" "github.com/sirupsen/logrus" ) @@ -68,17 +69,32 @@ func catch(handlers ...func() error) { } type Group struct { - mu sync.Mutex + mu sync.Mutex + panicHandler async.PanicHandler +} + +func (wg *Group) SetPanicHandler(panicHandler async.PanicHandler) { + wg.panicHandler = panicHandler +} + +func (wg *Group) handlePanic() { + if wg.panicHandler != nil { + wg.panicHandler.HandlePanic() + } } func (wg *Group) GoTry(fn func(bool)) { if wg.mu.TryLock() { go func() { + defer wg.handlePanic() defer wg.mu.Unlock() fn(true) }() } else { - go fn(false) + go func() { + defer wg.handlePanic() + fn(false) + }() } } diff --git a/internal/user/events.go b/internal/user/events.go index f980454a..8f978ed0 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -225,7 +225,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID] case vault.SplitMode: - user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0) + user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler) } user.eventCh.Enqueue(events.UserAddressCreated{ @@ -284,7 +284,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID] case vault.SplitMode: - user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0) + user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler) } user.eventCh.Enqueue(events.UserAddressEnabled{ @@ -594,7 +594,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M "subject": logging.Sensitive(message.Subject), }).Info("Handling message created event") - full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) + full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { // If the message is not found, it means that it has been deleted before we could fetch it. if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity { @@ -686,7 +686,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa "subject": logging.Sensitive(event.Message.Subject), }).Info("Handling draft updated event") - full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) + full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { // If the message is not found, it means that it has been deleted before we could fetch it. if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity { diff --git a/internal/user/imap.go b/internal/user/imap.go index fb5cdaa3..6e607ddb 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -290,7 +290,7 @@ func (conn *imapConnector) CreateMessage( conn.log.WithField("messageID", messageID).Warn("Message already sent") // Query the server-side message. - full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) + full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err) } @@ -354,7 +354,7 @@ func (conn *imapConnector) CreateMessage( } func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) { - msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()) + msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()) if err != nil { return nil, err } @@ -572,7 +572,7 @@ func (conn *imapConnector) importMessage( var err error - if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()); err != nil { + if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil { return fmt.Errorf("failed to fetch message: %w", err) } diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 22cc3b66..bcc093f1 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -48,6 +48,8 @@ import ( // sendMail sends an email from the given address to the given recipients. func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error { + defer user.handlePanic() + return safe.RLockRet(func() error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -143,7 +145,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // Send the message using the correct key. - sent, err := sendWithKey( + sent, err := user.sendWithKey( ctx, user.client, user.reporter, @@ -167,7 +169,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader) } // sendWithKey sends the message with the given address key. -func sendWithKey( +func (user *User) sendWithKey( ctx context.Context, client *proton.Client, sentry reporter.Reporter, @@ -226,12 +228,12 @@ func sendWithKey( return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err) } - attKeys, err := createAttachments(ctx, client, addrKR, draft.ID, message.Attachments) + attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments) if err != nil { return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err) } - recipients, err := getRecipients(ctx, client, userKR, settings, draft) + recipients, err := user.getRecipients(ctx, client, userKR, settings, draft) if err != nil { return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err) } @@ -377,7 +379,7 @@ func createDraft( }) } -func createAttachments( +func (user *User) createAttachments( ctx context.Context, client *proton.Client, addrKR *crypto.KeyRing, @@ -390,6 +392,8 @@ func createAttachments( } keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) { + defer user.handlePanic() + logrus.WithFields(logrus.Fields{ "name": logging.Sensitive(att.Name), "contentID": att.ContentID, @@ -455,7 +459,7 @@ func createAttachments( return attKeys, nil } -func getRecipients( +func (user *User) getRecipients( ctx context.Context, client *proton.Client, userKR *crypto.KeyRing, @@ -467,6 +471,8 @@ func getRecipients( }) prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) { + defer user.handlePanic() + pubKeys, recType, err := client.GetPublicKeys(ctx, recipient) if err != nil { return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err) diff --git a/internal/user/sync.go b/internal/user/sync.go index 215c3b82..6b281386 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -153,7 +153,7 @@ func (user *User) sync(ctx context.Context) error { } // Sync the messages. - if err := syncMessages( + if err := user.syncMessages( ctx, user.ID(), messageIDs, @@ -242,7 +242,7 @@ func toMB(v uint64) float64 { } // nolint:gocyclo -func syncMessages( +func (user *User) syncMessages( ctx context.Context, userID string, messageIDs []string, @@ -370,7 +370,7 @@ func syncMessages( errorCh := make(chan error, maxParallelDownloads*4) // Go routine in charge of downloading message metadata - logging.GoAnnotated(ctx, func(ctx context.Context) { + logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { defer close(downloadCh) const MetadataDataPageSize = 150 @@ -433,14 +433,14 @@ func syncMessages( }, logging.Labels{"sync-stage": "meta-data"}) // Goroutine in charge of downloading and building messages in maxBatchSize batches. - logging.GoAnnotated(ctx, func(ctx context.Context) { + logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { defer close(buildCh) defer close(errorCh) defer func() { logrus.Debugf("sync downloader exit") }() - attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads) + attachmentDownloader := user.newAttachmentDownloader(ctx, client, maxParallelDownloads) defer attachmentDownloader.close() for request := range downloadCh { @@ -456,6 +456,8 @@ func syncMessages( } result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) { + defer user.handlePanic() + var result proton.FullMessage msg, err := client.GetMessage(ctx, id) @@ -490,7 +492,7 @@ func syncMessages( }, logging.Labels{"sync-stage": "download"}) // Goroutine which builds messages after they have been downloaded - logging.GoAnnotated(ctx, func(ctx context.Context) { + logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { defer close(flushCh) defer func() { logrus.Debugf("sync builder exit") @@ -509,6 +511,8 @@ func syncMessages( logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk)) result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) { + defer user.handlePanic() + return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil }) if err != nil { @@ -526,7 +530,7 @@ func syncMessages( }, logging.Labels{"sync-stage": "builder"}) // Goroutine which converts the messages into updates and builds a waitable structure for progress tracking. - logging.GoAnnotated(ctx, func(ctx context.Context) { + logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { defer close(flushUpdateCh) defer func() { logrus.Debugf("sync flush exit") @@ -771,12 +775,12 @@ func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan at } } -func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader { +func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader { workerCh := make(chan attachmentJob, (workerCount+2)*workerCount) ctx, cancel := context.WithCancel(ctx) for i := 0; i < workerCount; i++ { workerCh = make(chan attachmentJob) - logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{ + logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{ "sync": fmt.Sprintf("att-downloader %v", i), }) } diff --git a/internal/user/types.go b/internal/user/types.go index 09478f5e..3da287fb 100644 --- a/internal/user/types.go +++ b/internal/user/types.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) @@ -93,6 +94,6 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item { return sorted } -func newProtonAPIScheduler() proton.Scheduler { - return proton.NewParallelScheduler(runtime.NumCPU() / 2) +func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler { + return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler) } diff --git a/internal/user/user.go b/internal/user/user.go index 2d0769cb..cd4b14c7 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -91,6 +91,8 @@ type User struct { showAllMail uint32 maxSyncMemory uint64 + + panicHandler async.PanicHandler } // New returns a new user. @@ -127,7 +129,7 @@ func New( reporter: reporter, sendHash: newSendRecorder(sendEntryExpiry), - eventCh: queue.NewQueuedChannel[events.Event](0, 0), + eventCh: queue.NewQueuedChannel[events.Event](0, 0, crashHandler), eventLock: safe.NewRWMutex(), apiUser: apiUser, @@ -148,6 +150,8 @@ func New( showAllMail: b32(showAllMail), maxSyncMemory: maxSyncMemory, + + panicHandler: crashHandler, } // Initialize the user's update channels for its current address mode. @@ -179,7 +183,10 @@ func New( user.goPollAPIEvents = func(wait bool) { doneCh := make(chan struct{}) - go func() { user.pollAPIEventsCh <- doneCh }() + go func() { + defer user.handlePanic() + user.pollAPIEventsCh <- doneCh + }() if wait { <-doneCh @@ -230,6 +237,12 @@ func New( return user, nil } +func (user *User) handlePanic() { + if user.panicHandler != nil { + user.panicHandler.HandlePanic() + } +} + func (user *User) TriggerSync() { user.goSync() } @@ -596,7 +609,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) { switch mode { case vault.CombinedMode: - primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0) + primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler) for addrID := range user.apiAddrs { user.updateCh[addrID] = primaryUpdateCh @@ -604,7 +617,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) { case vault.SplitMode: for addrID := range user.apiAddrs { - user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0) + user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler) } } } @@ -614,7 +627,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) { // When we receive an API event, we attempt to handle it. // If successful, we update the event ID in the vault. func (user *User) startEvents(ctx context.Context) { - ticker := proton.NewTicker(EventPeriod, EventJitter) + ticker := proton.NewTicker(EventPeriod, EventJitter, user.panicHandler) defer ticker.Stop() for { diff --git a/internal/user/user_test.go b/internal/user/user_test.go index a9ac6e04..c0888147 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -119,7 +119,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID) require.NoError(tb, err) - v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key")) + v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil) require.NoError(tb, err) require.False(tb, corrupt) diff --git a/internal/vault/migrate_test.go b/internal/vault/migrate_test.go index 32af3a19..2d3662e1 100644 --- a/internal/vault/migrate_test.go +++ b/internal/vault/migrate_test.go @@ -25,6 +25,7 @@ import ( "path/filepath" "testing" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/stretchr/testify/require" "github.com/vmihailenco/msgpack/v5" @@ -52,7 +53,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(dir, "vault.enc"), b, 0o600)) // Migrate the vault. - s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key")) + s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) diff --git a/internal/vault/settings_test.go b/internal/vault/settings_test.go index 33528fae..e0265798 100644 --- a/internal/vault/settings_test.go +++ b/internal/vault/settings_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/stretchr/testify/require" @@ -63,7 +64,7 @@ func TestVault_Settings_SMTP(t *testing.T) { func TestVault_Settings_GluonDir(t *testing.T) { // create a new test vault. - s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key")) + s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) diff --git a/internal/vault/vault.go b/internal/vault/vault.go index 4e605e3a..757f5707 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -29,6 +29,7 @@ import ( "path/filepath" "sync" + "github.com/ProtonMail/proton-bridge/v3/internal/async" "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" @@ -44,10 +45,12 @@ type Vault struct { ref map[string]int refLock sync.Mutex + + panicHandler async.PanicHandler } // New constructs a new encrypted data vault at the given filepath using the given encryption key. -func New(vaultDir, gluonCacheDir string, key []byte) (*Vault, bool, error) { +func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHandler) (*Vault, bool, error) { if err := os.MkdirAll(vaultDir, 0o700); err != nil { return nil, false, err } @@ -69,9 +72,17 @@ func New(vaultDir, gluonCacheDir string, key []byte) (*Vault, bool, error) { return nil, false, err } + vault.panicHandler = panicHandler + return vault, corrupt, nil } +func (vault *Vault) handlePanic() { + if vault.panicHandler != nil { + vault.panicHandler.HandlePanic() + } +} + // GetUserIDs returns the user IDs and usernames of all users in the vault. func (vault *Vault) GetUserIDs() []string { return xslices.Map(vault.get().Users, func(user UserData) string { @@ -115,6 +126,8 @@ func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error { userIDs := vault.GetUserIDs() return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error { + defer vault.handlePanic() + user, err := vault.NewUser(userIDs[idx]) if err != nil { return err diff --git a/internal/vault/vault_bench_test.go b/internal/vault/vault_bench_test.go index d4410bc5..c5903ceb 100644 --- a/internal/vault/vault_bench_test.go +++ b/internal/vault/vault_bench_test.go @@ -22,6 +22,7 @@ import ( "runtime" "testing" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -31,7 +32,7 @@ func BenchmarkVault(b *testing.B) { vaultDir, gluonDir := b.TempDir(), b.TempDir() // Create a new vault. - s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) + s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(b, err) require.False(b, corrupt) diff --git a/internal/vault/vault_test.go b/internal/vault/vault_test.go index 07b2b329..12fea53d 100644 --- a/internal/vault/vault_test.go +++ b/internal/vault/vault_test.go @@ -22,6 +22,7 @@ import ( "path/filepath" "testing" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/stretchr/testify/require" ) @@ -30,19 +31,19 @@ func TestVault_Corrupt(t *testing.T) { vaultDir, gluonDir := t.TempDir(), t.TempDir() { - _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) + _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) } { - _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) + _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) } { - _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key")) + _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.True(t, corrupt) } @@ -52,13 +53,13 @@ func TestVault_Corrupt_JunkData(t *testing.T) { vaultDir, gluonDir := t.TempDir(), t.TempDir() { - _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) + _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) } { - _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) + _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) } @@ -71,7 +72,7 @@ func TestVault_Corrupt_JunkData(t *testing.T) { _, err = f.Write([]byte("junk data")) require.NoError(t, err) - _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key")) + _, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.True(t, corrupt) } @@ -99,7 +100,7 @@ func TestVault_Reset(t *testing.T) { func newVault(t *testing.T) *vault.Vault { t.Helper() - s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) + s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{}) require.NoError(t, err) require.False(t, corrupt) diff --git a/tests/collector_test.go b/tests/collector_test.go index 7df7d7b8..cfcd62f6 100644 --- a/tests/collector_test.go +++ b/tests/collector_test.go @@ -44,7 +44,7 @@ func (c *eventCollector) collectFrom(eventCh <-chan events.Event) <-chan events. c.lock.Lock() defer c.lock.Unlock() - fwdCh := queue.NewQueuedChannel[events.Event](0, 0) + fwdCh := queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{}) c.fwdCh = append(c.fwdCh, fwdCh) @@ -87,7 +87,7 @@ func (c *eventCollector) push(event events.Event) { defer c.lock.Unlock() if _, ok := c.events[reflect.TypeOf(event)]; !ok { - c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0) + c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{}) } c.events[reflect.TypeOf(event)].Enqueue(event) @@ -102,7 +102,7 @@ func (c *eventCollector) getEventCh(ofType events.Event) <-chan events.Event { defer c.lock.Unlock() if _, ok := c.events[reflect.TypeOf(ofType)]; !ok { - c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0) + c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{}) } return c.events[reflect.TypeOf(ofType)].GetChannel() diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index cb0a2446..7a6edf66 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -108,7 +108,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) { } // Create the vault. - vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey) + vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey, queue.NoopPanicHandler{}) if err != nil { return nil, fmt.Errorf("could not create vault: %w", err) } else if corrupt { @@ -301,7 +301,7 @@ func (t *testCtx) initFrontendClient() error { return fmt.Errorf("could not start event stream: %w", err) } - eventCh := queue.NewQueuedChannel[*frontend.StreamEvent](0, 0) + eventCh := queue.NewQueuedChannel[*frontend.StreamEvent](0, 0, queue.NoopPanicHandler{}) go func() { defer eventCh.CloseAndDiscardQueued() diff --git a/tests/ctx_helper_test.go b/tests/ctx_helper_test.go index 04c4c51b..05c5b0a4 100644 --- a/tests/ctx_helper_test.go +++ b/tests/ctx_helper_test.go @@ -23,6 +23,7 @@ import ( "os" "runtime" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/stream" @@ -113,7 +114,7 @@ func (t *testCtx) withAddrKR( return err } - _, addrKRs, err := proton.Unlock(user, addr, keyPass) + _, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{}) if err != nil { return err }