mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
feat(GODT-2500): Add panic handlers everywhere.
This commit is contained in:
2
go.mod
2
go.mod
@ -125,6 +125,8 @@ require (
|
||||
)
|
||||
|
||||
replace (
|
||||
github.com/ProtonMail/gluon => /home/dev/gopath18/src/gluon
|
||||
github.com/ProtonMail/go-proton-api => /home/dev/gopath18/src/go-proton-api
|
||||
github.com/docker/docker-credential-helpers => github.com/ProtonMail/docker-credential-helpers v1.1.0
|
||||
github.com/emersion/go-message => github.com/ProtonMail/go-message v0.0.0-20210611055058-fabeff2ec753
|
||||
github.com/keybase/go-keychain => github.com/cuthix/go-keychain v0.0.0-20220405075754-31e7cee908fe
|
||||
|
||||
@ -185,14 +185,14 @@ func run(c *cli.Context) error {
|
||||
exe = os.Args[0]
|
||||
}
|
||||
|
||||
migrationErr := migrateOldVersions()
|
||||
// Restart the app if requested.
|
||||
return withRestarter(exe, func(restarter *restarter.Restarter) error {
|
||||
// Handle crashes with various actions.
|
||||
return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler, quitCh <-chan struct{}) error {
|
||||
migrationErr := migrateOldVersions()
|
||||
|
||||
// Run with profiling if requested.
|
||||
return withProfiler(c, func() error {
|
||||
// Restart the app if requested.
|
||||
return withRestarter(exe, func(restarter *restarter.Restarter) error {
|
||||
// Handle crashes with various actions.
|
||||
return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler, quitCh <-chan struct{}) error {
|
||||
// Run with profiling if requested.
|
||||
return withProfiler(c, func() error {
|
||||
// Load the locations where we store our files.
|
||||
return WithLocations(func(locations *locations.Locations) error {
|
||||
// Migrate the keychain helper.
|
||||
@ -215,7 +215,7 @@ func run(c *cli.Context) error {
|
||||
|
||||
return withSingleInstance(settings, locations.GetLockFile(), version, func() error {
|
||||
// Unlock the encrypted vault.
|
||||
return WithVault(locations, func(v *vault.Vault, insecure, corrupt bool) error {
|
||||
return WithVault(locations, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error {
|
||||
// Report insecure vault.
|
||||
if insecure {
|
||||
_ = reporter.ReportMessageWithContext("Vault is insecure", map[string]interface{}{})
|
||||
|
||||
@ -78,7 +78,7 @@ func withBridge(
|
||||
)
|
||||
|
||||
// Create a proxy dialer which switches to a proxy if the request fails.
|
||||
proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost)
|
||||
proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost, crashHandler)
|
||||
|
||||
// Create the autostarter.
|
||||
autostarter := newAutostarter(exe)
|
||||
|
||||
@ -46,7 +46,7 @@ func runFrontend(
|
||||
|
||||
switch {
|
||||
case c.Bool(flagCLI):
|
||||
return bridgeCLI.New(bridge, restarter, eventCh).Loop()
|
||||
return bridgeCLI.New(bridge, restarter, eventCh, crashHandler).Loop()
|
||||
|
||||
case c.Bool(flagNonInteractive):
|
||||
select {}
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/cookies"
|
||||
@ -40,7 +41,7 @@ import (
|
||||
|
||||
func TestMigratePrefsToVaultWithKeys(t *testing.T) {
|
||||
// Create a new vault.
|
||||
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
|
||||
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
@ -61,7 +62,7 @@ func TestMigratePrefsToVaultWithKeys(t *testing.T) {
|
||||
|
||||
func TestMigratePrefsToVaultWithoutKeys(t *testing.T) {
|
||||
// Create a new vault.
|
||||
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
|
||||
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
@ -173,7 +174,7 @@ func TestUserMigration(t *testing.T) {
|
||||
token, err := crypto.RandomToken(32)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, corrupt, err := vault.New(settingsFolder, settingsFolder, token)
|
||||
v, corrupt, err := vault.New(settingsFolder, settingsFolder, token, queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"fmt"
|
||||
"path"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/certs"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/locations"
|
||||
@ -29,12 +30,12 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) error) error {
|
||||
func WithVault(locations *locations.Locations, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error {
|
||||
logrus.Debug("Creating vault")
|
||||
defer logrus.Debug("Vault stopped")
|
||||
|
||||
// Create the encVault.
|
||||
encVault, insecure, corrupt, err := newVault(locations)
|
||||
encVault, insecure, corrupt, err := newVault(locations, panicHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create vault: %w", err)
|
||||
}
|
||||
@ -66,7 +67,7 @@ func WithVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool)
|
||||
return fn(encVault, insecure, corrupt)
|
||||
}
|
||||
|
||||
func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) {
|
||||
func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) {
|
||||
vaultDir, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err)
|
||||
@ -93,7 +94,7 @@ func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error)
|
||||
return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err)
|
||||
}
|
||||
|
||||
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey)
|
||||
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler)
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("could not create vault: %w", err)
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -32,6 +33,7 @@ func defaultAPIOptions(
|
||||
version *semver.Version,
|
||||
cookieJar http.CookieJar,
|
||||
transport http.RoundTripper,
|
||||
panicHandler queue.PanicHandler,
|
||||
) []proton.Option {
|
||||
return []proton.Option{
|
||||
proton.WithHostURL(apiURL),
|
||||
@ -39,5 +41,6 @@ func defaultAPIOptions(
|
||||
proton.WithCookieJar(cookieJar),
|
||||
proton.WithTransport(transport),
|
||||
proton.WithLogger(logrus.StandardLogger()),
|
||||
proton.WithPanicHandler(panicHandler),
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
)
|
||||
|
||||
@ -32,6 +33,7 @@ func newAPIOptions(
|
||||
version *semver.Version,
|
||||
cookieJar http.CookieJar,
|
||||
transport http.RoundTripper,
|
||||
panicHandler queue.PanicHandler,
|
||||
) []proton.Option {
|
||||
return defaultAPIOptions(apiURL, version, cookieJar, transport)
|
||||
return defaultAPIOptions(apiURL, version, cookieJar, transport, panicHandler)
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
)
|
||||
|
||||
@ -33,8 +34,9 @@ func newAPIOptions(
|
||||
version *semver.Version,
|
||||
cookieJar http.CookieJar,
|
||||
transport http.RoundTripper,
|
||||
panicHandler queue.PanicHandler,
|
||||
) []proton.Option {
|
||||
opt := defaultAPIOptions(apiURL, version, cookieJar, transport)
|
||||
opt := defaultAPIOptions(apiURL, version, cookieJar, transport, panicHandler)
|
||||
|
||||
if host := os.Getenv("BRIDGE_API_HOST"); host != "" {
|
||||
opt = append(opt, proton.WithHostURL(host))
|
||||
|
||||
@ -93,8 +93,8 @@ type Bridge struct {
|
||||
// locator is the bridge's locator.
|
||||
locator Locator
|
||||
|
||||
// crashHandler
|
||||
crashHandler async.PanicHandler
|
||||
// panicHandler
|
||||
panicHandler async.PanicHandler
|
||||
|
||||
// reporter
|
||||
reporter reporter.Reporter
|
||||
@ -143,7 +143,7 @@ func New(
|
||||
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
|
||||
roundTripper http.RoundTripper, // the round tripper to use for API requests
|
||||
proxyCtl ProxyController, // the DoH controller
|
||||
crashHandler async.PanicHandler,
|
||||
panicHandler async.PanicHandler,
|
||||
reporter reporter.Reporter,
|
||||
uidValidityGenerator imap.UIDValidityGenerator,
|
||||
|
||||
@ -151,10 +151,10 @@ func New(
|
||||
logSMTP bool, // whether to log SMTP activity
|
||||
) (*Bridge, <-chan events.Event, error) {
|
||||
// api is the user's API manager.
|
||||
api := proton.New(newAPIOptions(apiURL, curVersion, cookieJar, roundTripper)...)
|
||||
api := proton.New(newAPIOptions(apiURL, curVersion, cookieJar, roundTripper, panicHandler)...)
|
||||
|
||||
// tasks holds all the bridge's background tasks.
|
||||
tasks := async.NewGroup(context.Background(), crashHandler)
|
||||
tasks := async.NewGroup(context.Background(), panicHandler)
|
||||
|
||||
// imapEventCh forwards IMAP events from gluon instances to the bridge for processing.
|
||||
imapEventCh := make(chan imapEvents.Event)
|
||||
@ -169,7 +169,7 @@ func New(
|
||||
autostarter,
|
||||
updater,
|
||||
curVersion,
|
||||
crashHandler,
|
||||
panicHandler,
|
||||
reporter,
|
||||
|
||||
api,
|
||||
@ -202,7 +202,7 @@ func newBridge(
|
||||
autostarter Autostarter,
|
||||
updater Updater,
|
||||
curVersion *semver.Version,
|
||||
crashHandler async.PanicHandler,
|
||||
panicHandler async.PanicHandler,
|
||||
reporter reporter.Reporter,
|
||||
|
||||
api *proton.Manager,
|
||||
@ -248,12 +248,13 @@ func newBridge(
|
||||
imapEventCh,
|
||||
tasks,
|
||||
uidValidityGenerator,
|
||||
panicHandler,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IMAP server: %w", err)
|
||||
}
|
||||
|
||||
focusService, err := focus.NewService(locator, curVersion)
|
||||
focusService, err := focus.NewService(locator, curVersion, panicHandler)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create focus service: %w", err)
|
||||
}
|
||||
@ -279,7 +280,7 @@ func newBridge(
|
||||
newVersion: curVersion,
|
||||
newVersionLock: safe.NewRWMutex(),
|
||||
|
||||
crashHandler: crashHandler,
|
||||
panicHandler: panicHandler,
|
||||
reporter: reporter,
|
||||
|
||||
focusService: focusService,
|
||||
@ -495,7 +496,7 @@ func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events
|
||||
bridge.watchersLock.Lock()
|
||||
defer bridge.watchersLock.Unlock()
|
||||
|
||||
watcher := watcher.New(ofType...)
|
||||
watcher := watcher.New(bridge.panicHandler, ofType...)
|
||||
|
||||
bridge.watchers = append(bridge.watchers, watcher)
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ import (
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server"
|
||||
"github.com/ProtonMail/go-proton-api/server/backend"
|
||||
@ -699,7 +700,7 @@ func withBridgeNoMocks(
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create the vault.
|
||||
vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey)
|
||||
vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey, queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new cookie jar.
|
||||
|
||||
@ -299,6 +299,7 @@ func newIMAPServer(
|
||||
eventCh chan<- imapEvents.Event,
|
||||
tasks *async.Group,
|
||||
uidValidityGenerator imap.UIDValidityGenerator,
|
||||
panicHandler async.PanicHandler,
|
||||
) (*gluon.Server, error) {
|
||||
gluonCacheDir = ApplyGluonCachePathSuffix(gluonCacheDir)
|
||||
gluonConfigDir = ApplyGluonConfigPathSuffix(gluonConfigDir)
|
||||
@ -343,6 +344,7 @@ func newIMAPServer(
|
||||
getGluonVersionInfo(version),
|
||||
gluon.WithReporter(reporter),
|
||||
gluon.WithUIDValidityGenerator(uidValidityGenerator),
|
||||
gluon.WithPanicHandler(panicHandler),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -157,6 +157,7 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error
|
||||
bridge.imapEventCh,
|
||||
bridge.tasks,
|
||||
bridge.uidValidityGenerator,
|
||||
bridge.panicHandler,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new IMAP server: %w", err)
|
||||
|
||||
@ -28,6 +28,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server"
|
||||
@ -428,7 +429,7 @@ func createMessages(ctx context.Context, t *testing.T, c *proton.Client, addrID,
|
||||
keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addr, keyPass)
|
||||
_, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, ok := addrKRs[addrID]
|
||||
|
||||
@ -516,7 +516,7 @@ func (bridge *Bridge) addUserWithVault(
|
||||
client,
|
||||
bridge.reporter,
|
||||
apiUser,
|
||||
bridge.crashHandler,
|
||||
bridge.panicHandler,
|
||||
bridge.vault.GetShowAllMail(),
|
||||
bridge.vault.GetMaxSyncMemory(),
|
||||
)
|
||||
|
||||
@ -28,6 +28,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server"
|
||||
@ -474,7 +475,7 @@ func TestBridge_User_UpdateDraftAndCreateOtherMessage(t *testing.T) {
|
||||
keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addrs, keyPass)
|
||||
_, addrKRs, err := proton.Unlock(user, addrs, keyPass, queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a draft (generating a "create draft message" event).
|
||||
@ -556,7 +557,7 @@ func TestBridge_User_SendDraftRemoveDraftFlag(t *testing.T) {
|
||||
keyPass, err := salts.SaltForKey(password, user.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addrs, keyPass)
|
||||
_, addrKRs, err := proton.Unlock(user, addrs, keyPass, queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a draft (generating a "create draft message" event).
|
||||
|
||||
@ -98,6 +98,8 @@ func saveConfigTemporarily(mc *mobileconfig.Config) (fname string, err error) {
|
||||
|
||||
// Make sure the temporary file is deleted.
|
||||
go func() {
|
||||
defer recover() //nolint:errcheck
|
||||
|
||||
<-time.After(10 * time.Minute)
|
||||
_ = os.RemoveAll(dir)
|
||||
}()
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@ -40,17 +41,20 @@ type ProxyTLSDialer struct {
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
proxyUseDuration time.Duration
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
|
||||
func NewProxyTLSDialer(dialer TLSDialer, hostURL string) *ProxyTLSDialer {
|
||||
func NewProxyTLSDialer(dialer TLSDialer, hostURL string, panicHandler async.PanicHandler) *ProxyTLSDialer {
|
||||
return &ProxyTLSDialer{
|
||||
dialer: dialer,
|
||||
locker: sync.RWMutex{},
|
||||
directAddress: formatAsAddress(hostURL),
|
||||
proxyAddress: formatAsAddress(hostURL),
|
||||
proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders),
|
||||
proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders, panicHandler),
|
||||
proxyUseDuration: proxyUseDuration,
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@ -75,6 +79,12 @@ func formatAsAddress(rawURL string) string {
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
func (d *ProxyTLSDialer) handlePanic() {
|
||||
if d.panicHandler != nil {
|
||||
d.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
// DialTLSContext dials the given network/address. If it fails, it retries using a proxy.
|
||||
func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d.locker.RLock()
|
||||
@ -129,6 +139,8 @@ func (d *ProxyTLSDialer) switchToReachableServer() error {
|
||||
// This means we want to disable it again in 24 hours.
|
||||
if d.proxyAddress == d.directAddress {
|
||||
go func() {
|
||||
defer d.handlePanic()
|
||||
|
||||
<-time.After(d.proxyUseDuration)
|
||||
|
||||
d.locker.Lock()
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
@ -67,11 +68,13 @@ type proxyProvider struct {
|
||||
canReachTimeout time.Duration
|
||||
|
||||
lastLookup time.Time // The time at which we last attempted to find a proxy.
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
|
||||
// to retrieve DNS records for the given query string.
|
||||
func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *proxyProvider) {
|
||||
func newProxyProvider(dialer TLSDialer, hostURL string, providers []string, panicHandler async.PanicHandler) (p *proxyProvider) {
|
||||
p = &proxyProvider{
|
||||
dialer: dialer,
|
||||
hostURL: hostURL,
|
||||
@ -80,6 +83,7 @@ func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *
|
||||
cacheRefreshTimeout: proxyCacheRefreshTimeout,
|
||||
dohTimeout: proxyDoHTimeout,
|
||||
canReachTimeout: proxyCanReachTimeout,
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
|
||||
// Use the default DNS lookup method; this can be overridden if necessary.
|
||||
@ -88,6 +92,12 @@ func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *
|
||||
return
|
||||
}
|
||||
|
||||
func (p *proxyProvider) handlePanic() {
|
||||
if p.panicHandler != nil {
|
||||
p.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
// findReachableServer returns a working API server (either proxy or standard API).
|
||||
func (p *proxyProvider) findReachableServer() (proxy string, err error) {
|
||||
logrus.Debug("Trying to find a reachable server")
|
||||
@ -109,11 +119,13 @@ func (p *proxyProvider) findReachableServer() (proxy string, err error) {
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer p.handlePanic()
|
||||
defer wg.Done()
|
||||
apiReachable = p.canReach(p.hostURL)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer p.handlePanic()
|
||||
defer wg.Done()
|
||||
err = p.refreshProxyCache()
|
||||
}()
|
||||
@ -150,6 +162,8 @@ func (p *proxyProvider) refreshProxyCache() error {
|
||||
resultChan := make(chan []string)
|
||||
|
||||
go func() {
|
||||
defer p.handlePanic()
|
||||
|
||||
for _, provider := range p.providers {
|
||||
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
|
||||
resultChan <- proxies
|
||||
@ -203,6 +217,7 @@ func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider
|
||||
dataChan, errChan := make(chan []string), make(chan error)
|
||||
|
||||
go func() {
|
||||
defer p.handlePanic()
|
||||
// Build new DNS request in RFC1035 format.
|
||||
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/useragent"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -31,7 +32,7 @@ func TestProxyProvider_FindProxy(t *testing.T) {
|
||||
proxy := getTrustedServer()
|
||||
defer closeServer(proxy)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
@ -47,7 +48,7 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
|
||||
unreachableProxy := getTrustedServer()
|
||||
closeServer(unreachableProxy)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{reachableProxy.URL, unreachableProxy.URL}, nil
|
||||
}
|
||||
@ -68,7 +69,7 @@ func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
|
||||
checker := NewTLSPinChecker(TrustedAPIPins)
|
||||
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
|
||||
|
||||
p := newProxyProvider(dialer, "", []string{"not used"})
|
||||
p := newProxyProvider(dialer, "", []string{"not used"}, queue.NoopPanicHandler{})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy.URL, trustedProxy.URL}, nil
|
||||
}
|
||||
@ -85,7 +86,7 @@ func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
|
||||
unreachableProxy2 := getTrustedServer()
|
||||
closeServer(unreachableProxy2)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
|
||||
}
|
||||
@ -105,7 +106,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
|
||||
checker := NewTLSPinChecker(TrustedAPIPins)
|
||||
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
|
||||
|
||||
p := newProxyProvider(dialer, "", []string{"not used"})
|
||||
p := newProxyProvider(dialer, "", []string{"not used"}, queue.NoopPanicHandler{})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
|
||||
}
|
||||
@ -115,7 +116,7 @@ func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
|
||||
p.cacheRefreshTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
|
||||
|
||||
@ -132,7 +133,7 @@ func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
|
||||
}))
|
||||
defer closeServer(slowProxy)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"}, queue.NoopPanicHandler{})
|
||||
p.canReachTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
|
||||
|
||||
@ -144,7 +145,7 @@ func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
|
||||
|
||||
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider)
|
||||
r.NoError(t, err)
|
||||
@ -155,7 +156,7 @@ func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||
// port filter. Basic functionality should be covered by other tests. Keeping
|
||||
// code here to be able to run it locally if needed.
|
||||
func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
|
||||
|
||||
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider)
|
||||
r.NoError(t, err)
|
||||
@ -163,7 +164,7 @@ func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
|
||||
|
||||
records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider)
|
||||
r.NoError(t, err)
|
||||
@ -173,7 +174,7 @@ func TestProxyProvider_DoHLookup_Google(t *testing.T) {
|
||||
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
@ -183,7 +184,7 @@ func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
|
||||
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider})
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider}, queue.NoopPanicHandler{})
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -141,8 +142,8 @@ func TestProxyDialer_UseProxy(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
|
||||
@ -159,8 +160,8 @@ func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) {
|
||||
proxy3 := getTrustedServer()
|
||||
defer closeServer(proxy3)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
||||
|
||||
@ -189,8 +190,8 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
|
||||
d.proxyProvider = provider
|
||||
d.proxyUseDuration = time.Second
|
||||
|
||||
@ -212,8 +213,8 @@ func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
|
||||
@ -242,8 +243,8 @@ func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBloc
|
||||
proxy2 := getTrustedServer()
|
||||
defer closeServer(proxy2)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders, queue.NoopPanicHandler{})
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "", queue.NoopPanicHandler{})
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ func TestFocus_Raise(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
locations := locations.New(newTestLocationsProvider(tmpDir), "config-name")
|
||||
// Start the focus service.
|
||||
service, err := NewService(locations, semver.MustParse("1.2.3"))
|
||||
service, err := NewService(locations, semver.MustParse("1.2.3"), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsFolder, err := locations.ProvideSettingsPath()
|
||||
@ -52,7 +52,7 @@ func TestFocus_Version(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
locations := locations.New(newTestLocationsProvider(tmpDir), "config-name")
|
||||
// Start the focus service.
|
||||
_, err := NewService(locations, semver.MustParse("1.2.3"))
|
||||
_, err := NewService(locations, semver.MustParse("1.2.3"), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsFolder, err := locations.ProvideSettingsPath()
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/focus/proto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/service"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -43,15 +44,18 @@ type Service struct {
|
||||
server *grpc.Server
|
||||
raiseCh chan struct{}
|
||||
version *semver.Version
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
// NewService creates a new focus service.
|
||||
// It listens on the local host and port 1042 (by default).
|
||||
func NewService(locator service.Locator, version *semver.Version) (*Service, error) {
|
||||
func NewService(locator service.Locator, version *semver.Version, panicHandler async.PanicHandler) (*Service, error) {
|
||||
serv := &Service{
|
||||
server: grpc.NewServer(),
|
||||
raiseCh: make(chan struct{}, 1),
|
||||
version: version,
|
||||
server: grpc.NewServer(),
|
||||
raiseCh: make(chan struct{}, 1),
|
||||
version: version,
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
|
||||
proto.RegisterFocusServer(serv.server, serv)
|
||||
@ -73,6 +77,8 @@ func NewService(locator service.Locator, version *semver.Version) (*Service, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer serv.handlePanic()
|
||||
|
||||
if err := serv.server.Serve(listener); err != nil {
|
||||
fmt.Printf("failed to serve: %v", err)
|
||||
}
|
||||
@ -82,6 +88,12 @@ func NewService(locator service.Locator, version *semver.Version) (*Service, err
|
||||
return serv, nil
|
||||
}
|
||||
|
||||
func (service *Service) handlePanic() {
|
||||
if service.panicHandler != nil {
|
||||
service.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
// Raise implements the gRPC FocusService interface; it raises the application.
|
||||
func (service *Service) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
service.raiseCh <- struct{}{}
|
||||
@ -103,6 +115,8 @@ func (service *Service) GetRaiseCh() <-chan struct{} {
|
||||
// Close closes the service.
|
||||
func (service *Service) Close() {
|
||||
go func() {
|
||||
defer service.handlePanic()
|
||||
|
||||
// we do this in a goroutine, as on Windows, the gRPC shutdown may take minutes if something tries to
|
||||
// interact with it in an invalid way (e.g. HTTP GET request from a Qt QNetworkManager instance).
|
||||
service.server.Stop()
|
||||
|
||||
@ -21,6 +21,7 @@ package cli
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
@ -39,15 +40,18 @@ type frontendCLI struct {
|
||||
restarter *restarter.Restarter
|
||||
|
||||
badUserID string
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
// New returns a new CLI frontend configured with the given options.
|
||||
func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan events.Event) *frontendCLI { //nolint:revive
|
||||
func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan events.Event, panicHandler async.PanicHandler) *frontendCLI { //nolint:revive
|
||||
fe := &frontendCLI{
|
||||
Shell: ishell.New(),
|
||||
bridge: bridge,
|
||||
restarter: restarter,
|
||||
badUserID: "",
|
||||
Shell: ishell.New(),
|
||||
bridge: bridge,
|
||||
restarter: restarter,
|
||||
badUserID: "",
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
|
||||
// Clear commands.
|
||||
@ -285,6 +289,8 @@ func New(bridge *bridge.Bridge, restarter *restarter.Restarter, eventCh <-chan e
|
||||
}
|
||||
|
||||
func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyclo
|
||||
defer f.handlePanic()
|
||||
|
||||
// GODT-1949: Better error events.
|
||||
for _, err := range f.bridge.GetErrors() {
|
||||
switch {
|
||||
@ -445,6 +451,12 @@ func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyc
|
||||
*/
|
||||
}
|
||||
|
||||
func (f *frontendCLI) handlePanic() {
|
||||
if f.panicHandler != nil {
|
||||
f.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
// Loop starts the frontend loop with an interactive shell.
|
||||
func (f *frontendCLI) Loop() error {
|
||||
f.Printf(`
|
||||
|
||||
@ -191,6 +191,12 @@ func NewService(
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) handlePanic() {
|
||||
if s.panicHandler != nil {
|
||||
s.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) initAutostart() {
|
||||
s.firstTimeAutostart.Do(func() {
|
||||
shouldAutostartBeOn := s.bridge.GetAutostart()
|
||||
@ -207,11 +213,14 @@ func (s *Service) Loop() error {
|
||||
if s.parentPID < 0 {
|
||||
s.log.Info("Not monitoring parent PID")
|
||||
} else {
|
||||
go s.monitorParentPID()
|
||||
go func() {
|
||||
defer s.handlePanic()
|
||||
s.monitorParentPID()
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
s.watchEvents()
|
||||
}()
|
||||
|
||||
@ -221,6 +230,8 @@ func (s *Service) Loop() error {
|
||||
defer close(doneCh)
|
||||
|
||||
go func() {
|
||||
defer s.handlePanic()
|
||||
|
||||
select {
|
||||
case <-s.quitCh:
|
||||
s.log.Info("Stopping gRPC server")
|
||||
@ -564,6 +575,8 @@ func (s *Service) monitorParentPID() {
|
||||
s.log.Info("Parent process does not exist anymore. Initiating shutdown")
|
||||
// quit will write to the parentPIDDoneCh, so we launch a goroutine.
|
||||
go func() {
|
||||
defer s.handlePanic()
|
||||
|
||||
if err := s.quit(); err != nil {
|
||||
logrus.WithError(err).Error("Error on quit")
|
||||
}
|
||||
|
||||
@ -114,6 +114,8 @@ func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empt
|
||||
func (s *Service) quit() error {
|
||||
// Windows is notably slow at Quitting. We do it in a goroutine to speed things up a bit.
|
||||
go func() {
|
||||
defer s.handlePanic()
|
||||
|
||||
if s.parentPID >= 0 {
|
||||
s.parentPIDDoneCh <- struct{}{}
|
||||
}
|
||||
@ -221,7 +223,8 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.
|
||||
s.log.Debug("TriggerReset")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
s.triggerReset()
|
||||
}()
|
||||
return &emptypb.Empty{}, nil
|
||||
@ -316,6 +319,8 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
|
||||
}).Debug("ReportBug")
|
||||
|
||||
go func() {
|
||||
defer s.handlePanic()
|
||||
|
||||
defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }()
|
||||
|
||||
if err := s.bridge.ReportBug(
|
||||
@ -343,7 +348,7 @@ func (s *Service) ExportTLSCertificates(_ context.Context, folderPath *wrappersp
|
||||
s.log.WithField("folderPath", folderPath).Info("ExportTLSCertificates")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
cert, key := s.bridge.GetBridgeTLSCert()
|
||||
|
||||
@ -379,7 +384,7 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt
|
||||
s.log.WithField("username", login.Username).Debug("Login")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
password, err := base64Decode(login.Password)
|
||||
if err != nil {
|
||||
@ -435,7 +440,7 @@ func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.E
|
||||
s.log.WithField("username", login.Username).Debug("Login2FA")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
if s.auth.UID == "" || s.authClient == nil {
|
||||
s.log.Errorf("Login 2FA: authethication incomplete %s %p", s.auth.UID, s.authClient)
|
||||
@ -480,7 +485,7 @@ func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*em
|
||||
s.log.WithField("username", login.Username).Debug("Login2Passwords")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
password, err := base64Decode(login.Password)
|
||||
if err != nil {
|
||||
@ -502,7 +507,7 @@ func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest)
|
||||
s.log.WithField("username", loginAbort.Username).Debug("LoginAbort")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
s.loginAbort()
|
||||
}()
|
||||
@ -514,7 +519,7 @@ func (s *Service) CheckUpdate(context.Context, *emptypb.Empty) (*emptypb.Empty,
|
||||
s.log.Debug("CheckUpdate")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
updateCh, done := s.bridge.GetEvents(
|
||||
events.UpdateAvailable{},
|
||||
@ -546,7 +551,7 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb
|
||||
s.log.Debug("InstallUpdate")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
safe.RLock(func() {
|
||||
s.bridge.InstallUpdate(s.target)
|
||||
@ -587,6 +592,8 @@ func (s *Service) SetDiskCachePath(ctx context.Context, newPath *wrapperspb.Stri
|
||||
s.log.WithField("path", newPath.Value).Debug("setDiskCachePath")
|
||||
|
||||
go func() {
|
||||
defer s.handlePanic()
|
||||
|
||||
defer func() {
|
||||
_ = s.SendEvent(NewDiskCachePathChangeFinishedEvent())
|
||||
}()
|
||||
@ -652,7 +659,7 @@ func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSet
|
||||
Debug("SetConnectionMode")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
defer s.handlePanic()
|
||||
|
||||
defer func() { _ = s.SendEvent(NewChangeMailServerSettingsFinishedEvent()) }()
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"github.com/bradenaw/juniper/xerrors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@ -68,17 +69,32 @@ func catch(handlers ...func() error) {
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
mu sync.Mutex
|
||||
mu sync.Mutex
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
func (wg *Group) SetPanicHandler(panicHandler async.PanicHandler) {
|
||||
wg.panicHandler = panicHandler
|
||||
}
|
||||
|
||||
func (wg *Group) handlePanic() {
|
||||
if wg.panicHandler != nil {
|
||||
wg.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
func (wg *Group) GoTry(fn func(bool)) {
|
||||
if wg.mu.TryLock() {
|
||||
go func() {
|
||||
defer wg.handlePanic()
|
||||
defer wg.mu.Unlock()
|
||||
fn(true)
|
||||
}()
|
||||
} else {
|
||||
go fn(false)
|
||||
go func() {
|
||||
defer wg.handlePanic()
|
||||
fn(false)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -225,7 +225,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
|
||||
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
@ -284,7 +284,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
|
||||
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressEnabled{
|
||||
@ -594,7 +594,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
"subject": logging.Sensitive(message.Subject),
|
||||
}).Info("Handling message created event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
@ -686,7 +686,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
|
||||
"subject": logging.Sensitive(event.Message.Subject),
|
||||
}).Info("Handling draft updated event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
|
||||
@ -290,7 +290,7 @@ func (conn *imapConnector) CreateMessage(
|
||||
conn.log.WithField("messageID", messageID).Warn("Message already sent")
|
||||
|
||||
// Query the server-side message.
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
@ -354,7 +354,7 @@ func (conn *imapConnector) CreateMessage(
|
||||
}
|
||||
|
||||
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -572,7 +572,7 @@ func (conn *imapConnector) importMessage(
|
||||
|
||||
var err error
|
||||
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
return fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@ -48,6 +48,8 @@ import (
|
||||
|
||||
// sendMail sends an email from the given address to the given recipients.
|
||||
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
|
||||
defer user.handlePanic()
|
||||
|
||||
return safe.RLockRet(func() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@ -143,7 +145,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
// Send the message using the correct key.
|
||||
sent, err := sendWithKey(
|
||||
sent, err := user.sendWithKey(
|
||||
ctx,
|
||||
user.client,
|
||||
user.reporter,
|
||||
@ -167,7 +169,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
// sendWithKey sends the message with the given address key.
|
||||
func sendWithKey(
|
||||
func (user *User) sendWithKey(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
sentry reporter.Reporter,
|
||||
@ -226,12 +228,12 @@ func sendWithKey(
|
||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
attKeys, err := createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
|
||||
attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
recipients, err := getRecipients(ctx, client, userKR, settings, draft)
|
||||
recipients, err := user.getRecipients(ctx, client, userKR, settings, draft)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
||||
}
|
||||
@ -377,7 +379,7 @@ func createDraft(
|
||||
})
|
||||
}
|
||||
|
||||
func createAttachments(
|
||||
func (user *User) createAttachments(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
addrKR *crypto.KeyRing,
|
||||
@ -390,6 +392,8 @@ func createAttachments(
|
||||
}
|
||||
|
||||
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"name": logging.Sensitive(att.Name),
|
||||
"contentID": att.ContentID,
|
||||
@ -455,7 +459,7 @@ func createAttachments(
|
||||
return attKeys, nil
|
||||
}
|
||||
|
||||
func getRecipients(
|
||||
func (user *User) getRecipients(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
userKR *crypto.KeyRing,
|
||||
@ -467,6 +471,8 @@ func getRecipients(
|
||||
})
|
||||
|
||||
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
|
||||
if err != nil {
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err)
|
||||
|
||||
@ -153,7 +153,7 @@ func (user *User) sync(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Sync the messages.
|
||||
if err := syncMessages(
|
||||
if err := user.syncMessages(
|
||||
ctx,
|
||||
user.ID(),
|
||||
messageIDs,
|
||||
@ -242,7 +242,7 @@ func toMB(v uint64) float64 {
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func syncMessages(
|
||||
func (user *User) syncMessages(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
messageIDs []string,
|
||||
@ -370,7 +370,7 @@ func syncMessages(
|
||||
errorCh := make(chan error, maxParallelDownloads*4)
|
||||
|
||||
// Go routine in charge of downloading message metadata
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(downloadCh)
|
||||
const MetadataDataPageSize = 150
|
||||
|
||||
@ -433,14 +433,14 @@ func syncMessages(
|
||||
}, logging.Labels{"sync-stage": "meta-data"})
|
||||
|
||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(buildCh)
|
||||
defer close(errorCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync downloader exit")
|
||||
}()
|
||||
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads)
|
||||
attachmentDownloader := user.newAttachmentDownloader(ctx, client, maxParallelDownloads)
|
||||
defer attachmentDownloader.close()
|
||||
|
||||
for request := range downloadCh {
|
||||
@ -456,6 +456,8 @@ func syncMessages(
|
||||
}
|
||||
|
||||
result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
var result proton.FullMessage
|
||||
|
||||
msg, err := client.GetMessage(ctx, id)
|
||||
@ -490,7 +492,7 @@ func syncMessages(
|
||||
}, logging.Labels{"sync-stage": "download"})
|
||||
|
||||
// Goroutine which builds messages after they have been downloaded
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(flushCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync builder exit")
|
||||
@ -509,6 +511,8 @@ func syncMessages(
|
||||
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
|
||||
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
|
||||
})
|
||||
if err != nil {
|
||||
@ -526,7 +530,7 @@ func syncMessages(
|
||||
}, logging.Labels{"sync-stage": "builder"})
|
||||
|
||||
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(flushUpdateCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync flush exit")
|
||||
@ -771,12 +775,12 @@ func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan at
|
||||
}
|
||||
}
|
||||
|
||||
func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
|
||||
func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
|
||||
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := 0; i < workerCount; i++ {
|
||||
workerCh = make(chan attachmentJob)
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
|
||||
"sync": fmt.Sprintf("att-downloader %v", i),
|
||||
})
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@ -93,6 +94,6 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
|
||||
return sorted
|
||||
}
|
||||
|
||||
func newProtonAPIScheduler() proton.Scheduler {
|
||||
return proton.NewParallelScheduler(runtime.NumCPU() / 2)
|
||||
func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
|
||||
return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
|
||||
}
|
||||
|
||||
@ -91,6 +91,8 @@ type User struct {
|
||||
showAllMail uint32
|
||||
|
||||
maxSyncMemory uint64
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
// New returns a new user.
|
||||
@ -127,7 +129,7 @@ func New(
|
||||
reporter: reporter,
|
||||
sendHash: newSendRecorder(sendEntryExpiry),
|
||||
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0, crashHandler),
|
||||
eventLock: safe.NewRWMutex(),
|
||||
|
||||
apiUser: apiUser,
|
||||
@ -148,6 +150,8 @@ func New(
|
||||
showAllMail: b32(showAllMail),
|
||||
|
||||
maxSyncMemory: maxSyncMemory,
|
||||
|
||||
panicHandler: crashHandler,
|
||||
}
|
||||
|
||||
// Initialize the user's update channels for its current address mode.
|
||||
@ -179,7 +183,10 @@ func New(
|
||||
user.goPollAPIEvents = func(wait bool) {
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
go func() { user.pollAPIEventsCh <- doneCh }()
|
||||
go func() {
|
||||
defer user.handlePanic()
|
||||
user.pollAPIEventsCh <- doneCh
|
||||
}()
|
||||
|
||||
if wait {
|
||||
<-doneCh
|
||||
@ -230,6 +237,12 @@ func New(
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (user *User) handlePanic() {
|
||||
if user.panicHandler != nil {
|
||||
user.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) TriggerSync() {
|
||||
user.goSync()
|
||||
}
|
||||
@ -596,7 +609,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = primaryUpdateCh
|
||||
@ -604,7 +617,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -614,7 +627,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
func (user *User) startEvents(ctx context.Context) {
|
||||
ticker := proton.NewTicker(EventPeriod, EventJitter)
|
||||
ticker := proton.NewTicker(EventPeriod, EventJitter, user.panicHandler)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
||||
@ -119,7 +119,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
|
||||
saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID)
|
||||
require.NoError(tb, err)
|
||||
|
||||
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"))
|
||||
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil)
|
||||
require.NoError(tb, err)
|
||||
require.False(tb, corrupt)
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
@ -52,7 +53,7 @@ func TestMigrate(t *testing.T) {
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "vault.enc"), b, 0o600))
|
||||
|
||||
// Migrate the vault.
|
||||
s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key"))
|
||||
s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -63,7 +64,7 @@ func TestVault_Settings_SMTP(t *testing.T) {
|
||||
|
||||
func TestVault_Settings_GluonDir(t *testing.T) {
|
||||
// create a new test vault.
|
||||
s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key"))
|
||||
s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ import (
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -44,10 +45,12 @@ type Vault struct {
|
||||
|
||||
ref map[string]int
|
||||
refLock sync.Mutex
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
|
||||
func New(vaultDir, gluonCacheDir string, key []byte) (*Vault, bool, error) {
|
||||
func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHandler) (*Vault, bool, error) {
|
||||
if err := os.MkdirAll(vaultDir, 0o700); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
@ -69,9 +72,17 @@ func New(vaultDir, gluonCacheDir string, key []byte) (*Vault, bool, error) {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
vault.panicHandler = panicHandler
|
||||
|
||||
return vault, corrupt, nil
|
||||
}
|
||||
|
||||
func (vault *Vault) handlePanic() {
|
||||
if vault.panicHandler != nil {
|
||||
vault.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
||||
func (vault *Vault) GetUserIDs() []string {
|
||||
return xslices.Map(vault.get().Users, func(user UserData) string {
|
||||
@ -115,6 +126,8 @@ func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
|
||||
userIDs := vault.GetUserIDs()
|
||||
|
||||
return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error {
|
||||
defer vault.handlePanic()
|
||||
|
||||
user, err := vault.NewUser(userIDs[idx])
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -31,7 +32,7 @@ func BenchmarkVault(b *testing.B) {
|
||||
vaultDir, gluonDir := b.TempDir(), b.TempDir()
|
||||
|
||||
// Create a new vault.
|
||||
s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"))
|
||||
s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(b, err)
|
||||
require.False(b, corrupt)
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -30,19 +31,19 @@ func TestVault_Corrupt(t *testing.T) {
|
||||
vaultDir, gluonDir := t.TempDir(), t.TempDir()
|
||||
|
||||
{
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"))
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
}
|
||||
|
||||
{
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"))
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
}
|
||||
|
||||
{
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key"))
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.True(t, corrupt)
|
||||
}
|
||||
@ -52,13 +53,13 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
|
||||
vaultDir, gluonDir := t.TempDir(), t.TempDir()
|
||||
|
||||
{
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"))
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
}
|
||||
|
||||
{
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"))
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
}
|
||||
@ -71,7 +72,7 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
|
||||
_, err = f.Write([]byte("junk data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"))
|
||||
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.True(t, corrupt)
|
||||
}
|
||||
@ -99,7 +100,7 @@ func TestVault_Reset(t *testing.T) {
|
||||
func newVault(t *testing.T) *vault.Vault {
|
||||
t.Helper()
|
||||
|
||||
s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
|
||||
s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), queue.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ func (c *eventCollector) collectFrom(eventCh <-chan events.Event) <-chan events.
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
fwdCh := queue.NewQueuedChannel[events.Event](0, 0)
|
||||
fwdCh := queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{})
|
||||
|
||||
c.fwdCh = append(c.fwdCh, fwdCh)
|
||||
|
||||
@ -87,7 +87,7 @@ func (c *eventCollector) push(event events.Event) {
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if _, ok := c.events[reflect.TypeOf(event)]; !ok {
|
||||
c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0)
|
||||
c.events[reflect.TypeOf(event)] = queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{})
|
||||
}
|
||||
|
||||
c.events[reflect.TypeOf(event)].Enqueue(event)
|
||||
@ -102,7 +102,7 @@ func (c *eventCollector) getEventCh(ofType events.Event) <-chan events.Event {
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if _, ok := c.events[reflect.TypeOf(ofType)]; !ok {
|
||||
c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0)
|
||||
c.events[reflect.TypeOf(ofType)] = queue.NewQueuedChannel[events.Event](0, 0, queue.NoopPanicHandler{})
|
||||
}
|
||||
|
||||
return c.events[reflect.TypeOf(ofType)].GetChannel()
|
||||
|
||||
@ -108,7 +108,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
|
||||
}
|
||||
|
||||
// Create the vault.
|
||||
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey)
|
||||
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey, queue.NoopPanicHandler{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create vault: %w", err)
|
||||
} else if corrupt {
|
||||
@ -301,7 +301,7 @@ func (t *testCtx) initFrontendClient() error {
|
||||
return fmt.Errorf("could not start event stream: %w", err)
|
||||
}
|
||||
|
||||
eventCh := queue.NewQueuedChannel[*frontend.StreamEvent](0, 0)
|
||||
eventCh := queue.NewQueuedChannel[*frontend.StreamEvent](0, 0, queue.NoopPanicHandler{})
|
||||
|
||||
go func() {
|
||||
defer eventCh.CloseAndDiscardQueued()
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
@ -113,7 +114,7 @@ func (t *testCtx) withAddrKR(
|
||||
return err
|
||||
}
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addr, keyPass)
|
||||
_, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user