diff --git a/go.mod b/go.mod index 5a040601..66f832ab 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.0 github.com/urfave/cli/v2 v2.16.3 - gitlab.protontech.ch/go/liteapi v0.33.2-0.20221011164043-97f5d601ba2b + gitlab.protontech.ch/go/liteapi v0.33.2-0.20221011193656-705963f7a7d9 golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/net v0.1.0 golang.org/x/sys v0.1.0 diff --git a/go.sum b/go.sum index 41b990f2..480806fe 100644 --- a/go.sum +++ b/go.sum @@ -397,8 +397,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221011164043-97f5d601ba2b h1:9bTndevIV9WTSbRsoLXmLj8bycla6O3KU7fFzEV09n0= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221011164043-97f5d601ba2b/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221011193656-705963f7a7d9 h1:WErqL7DdcsQFNNy2Zkj8MT83HbSUbc17qptrEuVcbGA= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221011193656-705963f7a7d9/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/app/app.go b/internal/app/app.go index d0e3f12b..c2a73f0b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -2,9 +2,13 @@ package app import ( "fmt" + "net/http" + "net/http/cookiejar" "path/filepath" + "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/constants" + "github.com/ProtonMail/proton-bridge/v2/internal/cookies" "github.com/ProtonMail/proton-bridge/v2/internal/crash" "github.com/ProtonMail/proton-bridge/v2/internal/focus" bridgeCLI "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli" @@ -12,8 +16,10 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/ProtonMail/proton-bridge/v2/internal/sentry" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/pkg/restarter" "github.com/pkg/profile" + "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" ) @@ -93,33 +99,66 @@ func run(c *cli.Context) error { return nil } - // Start CPU profile if requested. - if c.Bool(flagCPUProfile) { - p := profile.Start(profile.CPUProfile, profile.ProfilePath(".")) - defer p.Stop() - } - - // Start memory profile if requested. - if c.Bool(flagMemProfile) { - p := profile.Start(profile.MemProfile, profile.MemProfileAllocs, profile.ProfilePath(".")) - defer p.Stop() - } - - // Create the restarter. - restarter := restarter.New() - defer restarter.Restart() - // Create a user agent that will be used for all requests. identifier := useragent.New() - // Create a crash handler that will send crash reports to sentry. - crashHandler := crash.NewHandler( - sentry.NewReporter(constants.FullAppName, constants.Version, identifier).ReportException, - crash.ShowErrorNotification(constants.FullAppName), - func(r interface{}) error { restarter.Set(true, true); return nil }, - ) - defer crashHandler.HandlePanic() + // Create a new Sentry client that will be used to report crashes etc. + reporter := sentry.NewReporter(constants.FullAppName, constants.Version, identifier) + // Run with profiling if requested. + return withProfiler(c, func() error { + // Restart the app if requested. + return withRestarter(func(restarter *restarter.Restarter) error { + // Handle crashes with various actions. + return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler) error { + // Load the locations where we store our files. + return withLocations(func(locations *locations.Locations) error { + // Initialize the logging. + if err := initLogging(c, locations, crashHandler); err != nil { + return fmt.Errorf("could not initialize logging: %w", err) + } + + // Unlock the encrypted vault. + return withVault(locations, func(vault *vault.Vault, insecure, corrupt bool) error { + // Load the cookies from the vault. + return withCookieJar(vault, func(cookieJar http.CookieJar) error { + // Create a new bridge instance. + return withBridge(c, locations, identifier, reporter, vault, cookieJar, func(b *bridge.Bridge) error { + if insecure { + logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted") + b.PushError(bridge.ErrVaultInsecure) + } + + if corrupt { + logrus.Warn("The vault is corrupt and has been wiped") + b.PushError(bridge.ErrVaultCorrupt) + } + + switch { + case c.Bool(flagCLI): + return bridgeCLI.New(b).Loop() + + case c.Bool(flagNonInteractive): + select {} + + default: + service, err := grpc.NewService(crashHandler, restarter, locations, b, !c.Bool(flagNoWindow)) + if err != nil { + return fmt.Errorf("could not create service: %w", err) + } + + return service.Loop() + } + }) + }) + }) + }) + }) + }) + }) +} + +func withLocations(fn func(*locations.Locations) error) error { // Create a locations provider to determine where to store our files. provider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, constants.ConfigName)) if err != nil { @@ -129,32 +168,67 @@ func run(c *cli.Context) error { // Create a new locations object that will be used to provide paths to store files. locations := locations.New(provider, constants.ConfigName) - // Initialize the logging. - if err := initLogging(c, locations, crashHandler); err != nil { - return fmt.Errorf("could not initialize logging: %w", err) - } + // TODO: Add teardown actions (removing the lock file, etc.) - // Create the bridge. - bridge, err := newBridge(c, locations, identifier) - if err != nil { - return fmt.Errorf("could not create bridge: %w", err) - } - defer bridge.Close(c.Context) - - // Start the frontend. - switch { - case c.Bool(flagCLI): - return bridgeCLI.New(bridge).Loop() - - case c.Bool(flagNonInteractive): - select {} - - default: - service, err := grpc.NewService(crashHandler, restarter, locations, bridge, !c.Bool(flagNoWindow)) - if err != nil { - return fmt.Errorf("could not create service: %w", err) - } - - return service.Loop() - } + return fn(locations) +} + +func withProfiler(c *cli.Context, fn func() error) error { + // Start CPU profile if requested. + if c.Bool(flagCPUProfile) { + defer profile.Start(profile.CPUProfile, profile.ProfilePath(".")).Stop() + } + + // Start memory profile if requested. + if c.Bool(flagMemProfile) { + defer profile.Start(profile.MemProfile, profile.MemProfileAllocs, profile.ProfilePath(".")).Stop() + } + + return fn() +} + +func withRestarter(fn func(*restarter.Restarter) error) error { + restarter := restarter.New() + defer restarter.Restart() + + return fn(restarter) +} + +func withCrashHandler(restarter *restarter.Restarter, reporter *sentry.Reporter, fn func(*crash.Handler) error) error { + crashHandler := crash.NewHandler(crash.ShowErrorNotification(constants.FullAppName)) + defer crashHandler.HandlePanic() + + // On crash, send crash report to Sentry. + crashHandler.AddRecoveryAction(reporter.ReportException) + + // On crash, notify the user and restart the app. + crashHandler.AddRecoveryAction(crash.ShowErrorNotification(constants.FullAppName)) + + // On crash, restart the app. + crashHandler.AddRecoveryAction(func(r any) error { restarter.Set(true, true); return nil }) + + return fn(crashHandler) +} + +func withCookieJar(vault *vault.Vault, fn func(http.CookieJar) error) error { + // Create the underlying cookie jar. + jar, err := cookiejar.New(nil) + if err != nil { + return fmt.Errorf("could not create cookie jar: %w", err) + } + + // Create the cookie jar which persists to the vault. + persister, err := cookies.NewCookieJar(jar, vault) + if err != nil { + return fmt.Errorf("could not create cookie jar: %w", err) + } + + // Persist the cookies to the vault when we close. + defer func() { + if err := persister.PersistCookies(); err != nil { + logrus.WithError(err).Error("Failed to persist cookies") + } + }() + + return fn(persister) } diff --git a/internal/app/bridge.go b/internal/app/bridge.go index e4c75b18..ae09cbbb 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -1,9 +1,8 @@ package app import ( - "encoding/base64" "fmt" - "github.com/urfave/cli/v2" + "net/http" "os" "runtime" @@ -11,22 +10,30 @@ import ( "github.com/ProtonMail/go-autostart" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/bridge" - "github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/dialer" "github.com/ProtonMail/proton-bridge/v2/internal/locations" + "github.com/ProtonMail/proton-bridge/v2/internal/sentry" "github.com/ProtonMail/proton-bridge/v2/internal/updater" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/internal/versioner" - "github.com/ProtonMail/proton-bridge/v2/pkg/keychain" "github.com/sirupsen/logrus" - "golang.org/x/exp/slices" + "github.com/urfave/cli/v2" ) const vaultSecretName = "bridge-vault-key" -func newBridge(c *cli.Context, locations *locations.Locations, identifier *useragent.UserAgent) (*bridge.Bridge, error) { +// withBridge creates creates and tears down the bridge. +func withBridge( + c *cli.Context, + locations *locations.Locations, + identifier *useragent.UserAgent, + reporter *sentry.Reporter, + vault *vault.Vault, + cookieJar http.CookieJar, + fn func(*bridge.Bridge) error, +) error { // Create the underlying dialer used by the bridge. // It only connects to trusted servers and reports any untrusted servers it finds. pinningDialer := dialer.NewPinningTLSDialer( @@ -41,145 +48,55 @@ func newBridge(c *cli.Context, locations *locations.Locations, identifier *usera // Create the autostarter. autostarter, err := newAutostarter() if err != nil { - return nil, fmt.Errorf("could not create autostarter: %w", err) + return fmt.Errorf("could not create autostarter: %w", err) } // Create the update installer. updater, err := newUpdater(locations) if err != nil { - return nil, fmt.Errorf("could not create updater: %w", err) + return fmt.Errorf("could not create updater: %w", err) } // Get the current bridge version. version, err := semver.NewVersion(constants.Version) if err != nil { - return nil, fmt.Errorf("could not create version: %w", err) - } - - // Create the encVault. - encVault, insecure, corrupt, err := newVault(locations) - if err != nil { - return nil, fmt.Errorf("could not create vault: %w", err) - } else if insecure { - logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted") - } else if corrupt { - logrus.Warn("The vault is corrupt and has been wiped") - } - - // Install the certificates if needed. - if installed := encVault.GetCertsInstalled(); !installed { - if err := certs.NewInstaller().InstallCert(encVault.GetBridgeTLSCert()); err != nil { - return nil, fmt.Errorf("failed to install certs: %w", err) - } - - if err := encVault.SetCertsInstalled(true); err != nil { - return nil, fmt.Errorf("failed to set certs installed: %w", err) - } - - if err := encVault.SetCertsInstalled(true); err != nil { - return nil, fmt.Errorf("could not set certs installed: %w", err) - } + return fmt.Errorf("could not create version: %w", err) } // Create a new bridge. bridge, err := bridge.New( - constants.APIHost, + // The app stuff. locations, - encVault, + vault, + autostarter, + updater, + version, + + // The API stuff. + constants.APIHost, + cookieJar, identifier, pinningDialer, dialer.CreateTransportWithDialer(proxyDialer), proxyDialer, - autostarter, - updater, - version, + + // The logging stuff. c.String(flagLogIMAP) == "client" || c.String(flagLogIMAP) == "all", c.String(flagLogIMAP) == "server" || c.String(flagLogIMAP) == "all", c.Bool(flagLogSMTP), ) if err != nil { - return nil, fmt.Errorf("could not create bridge: %w", err) + return fmt.Errorf("could not create bridge: %w", err) } - // If the vault could not be loaded properly, push errors to the bridge. - switch { - case insecure: - bridge.PushError(vault.ErrInsecure) - - case corrupt: - bridge.PushError(vault.ErrCorrupt) - } - - return bridge, nil -} - -func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) { - var insecure bool - - vaultDir, err := locations.ProvideSettingsPath() - if err != nil { - return nil, false, false, fmt.Errorf("could not get vault dir: %w", err) - } - - var vaultKey []byte - - if key, err := getVaultKey(vaultDir); err != nil { - insecure = true - } else { - vaultKey = key - } - - gluonDir, err := locations.ProvideGluonPath() - if err != nil { - return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err) - } - - vault, corrupt, err := vault.New(vaultDir, gluonDir, vaultKey) - if err != nil { - return nil, false, false, fmt.Errorf("could not create vault: %w", err) - } - - return vault, insecure, corrupt, nil -} - -func getVaultKey(vaultDir string) ([]byte, error) { - helper, err := vault.GetHelper(vaultDir) - if err != nil { - return nil, fmt.Errorf("could not get keychain helper: %w", err) - } - - keychain, err := keychain.NewKeychain(helper, constants.KeyChainName) - if err != nil { - return nil, fmt.Errorf("could not create keychain: %w", err) - } - - secrets, err := keychain.List() - if err != nil { - return nil, fmt.Errorf("could not list keychain: %w", err) - } - - if !slices.Contains(secrets, vaultSecretName) { - tok, err := crypto.RandomToken(32) - if err != nil { - return nil, fmt.Errorf("could not generate random token: %w", err) + // Close the bridge when we exit. + defer func() { + if err := bridge.Close(c.Context); err != nil { + logrus.WithError(err).Error("Failed to close bridge") } + }() - if err := keychain.Put(vaultSecretName, base64.StdEncoding.EncodeToString(tok)); err != nil { - return nil, fmt.Errorf("could not put keychain item: %w", err) - } - } - - _, keyEnc, err := keychain.Get(vaultSecretName) - if err != nil { - return nil, fmt.Errorf("could not get keychain item: %w", err) - } - - keyDec, err := base64.StdEncoding.DecodeString(keyEnc) - if err != nil { - return nil, fmt.Errorf("could not decode keychain item: %w", err) - } - - return keyDec, nil + return fn(bridge) } func newAutostarter() (*autostart.App, error) { diff --git a/internal/app/vault.go b/internal/app/vault.go new file mode 100644 index 00000000..7022e4ae --- /dev/null +++ b/internal/app/vault.go @@ -0,0 +1,110 @@ +package app + +import ( + "encoding/base64" + "fmt" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/v2/internal/certs" + "github.com/ProtonMail/proton-bridge/v2/internal/constants" + "github.com/ProtonMail/proton-bridge/v2/internal/locations" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" + "github.com/ProtonMail/proton-bridge/v2/pkg/keychain" + "golang.org/x/exp/slices" +) + +func withVault(locations *locations.Locations, fn func(*vault.Vault, bool, bool) error) error { + // Create the encVault. + encVault, insecure, corrupt, err := newVault(locations) + if err != nil { + return fmt.Errorf("could not create vault: %w", err) + } + + // Install the certificates if needed. + if installed := encVault.GetCertsInstalled(); !installed { + if err := certs.NewInstaller().InstallCert(encVault.GetBridgeTLSCert()); err != nil { + return fmt.Errorf("failed to install certs: %w", err) + } + + if err := encVault.SetCertsInstalled(true); err != nil { + return fmt.Errorf("failed to set certs installed: %w", err) + } + + if err := encVault.SetCertsInstalled(true); err != nil { + return fmt.Errorf("could not set certs installed: %w", err) + } + } + + // TODO: Add teardown actions (e.g. to close the vault). + + return fn(encVault, insecure, corrupt) +} + +func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) { + var insecure bool + + vaultDir, err := locations.ProvideSettingsPath() + if err != nil { + return nil, false, false, fmt.Errorf("could not get vault dir: %w", err) + } + + var vaultKey []byte + + if key, err := getVaultKey(vaultDir); err != nil { + insecure = true + } else { + vaultKey = key + } + + gluonDir, err := locations.ProvideGluonPath() + if err != nil { + return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err) + } + + vault, corrupt, err := vault.New(vaultDir, gluonDir, vaultKey) + if err != nil { + return nil, false, false, fmt.Errorf("could not create vault: %w", err) + } + + return vault, insecure, corrupt, nil +} + +func getVaultKey(vaultDir string) ([]byte, error) { + helper, err := vault.GetHelper(vaultDir) + if err != nil { + return nil, fmt.Errorf("could not get keychain helper: %w", err) + } + + keychain, err := keychain.NewKeychain(helper, constants.KeyChainName) + if err != nil { + return nil, fmt.Errorf("could not create keychain: %w", err) + } + + secrets, err := keychain.List() + if err != nil { + return nil, fmt.Errorf("could not list keychain: %w", err) + } + + if !slices.Contains(secrets, vaultSecretName) { + tok, err := crypto.RandomToken(32) + if err != nil { + return nil, fmt.Errorf("could not generate random token: %w", err) + } + + if err := keychain.Put(vaultSecretName, base64.StdEncoding.EncodeToString(tok)); err != nil { + return nil, fmt.Errorf("could not put keychain item: %w", err) + } + } + + _, keyEnc, err := keychain.Get(vaultSecretName) + if err != nil { + return nil, fmt.Errorf("could not get keychain item: %w", err) + } + + keyDec, err := base64.StdEncoding.DecodeString(keyEnc) + if err != nil { + return nil, fmt.Errorf("could not decode keychain item: %w", err) + } + + return keyDec, nil +} diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index c0c3b7d1..46ac32d5 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -14,7 +14,6 @@ import ( "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/watcher" "github.com/ProtonMail/proton-bridge/v2/internal/constants" - "github.com/ProtonMail/proton-bridge/v2/internal/cookies" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" "github.com/ProtonMail/proton-bridge/v2/internal/user" @@ -35,7 +34,6 @@ type Bridge struct { // api manages user API clients. api *liteapi.Manager - cookieJar *cookies.Jar proxyCtl ProxyController identifier Identifier @@ -82,24 +80,22 @@ type Bridge struct { // New creates a new bridge. func New( - apiURL string, // the URL of the API to use locator Locator, // the locator to provide paths to store data vault *vault.Vault, // the bridge's encrypted data store + autostarter Autostarter, // the autostarter to manage autostart settings + updater Updater, // the updater to fetch and install updates + curVersion *semver.Version, // the current version of the bridge + + apiURL string, // the URL of the API to use + cookieJar http.CookieJar, // the cookie jar to use identifier Identifier, // the identifier to keep track of the user agent tlsReporter TLSReporter, // the TLS reporter to report TLS errors roundTripper http.RoundTripper, // the round tripper to use for API requests proxyCtl ProxyController, // the DoH controller - autostarter Autostarter, // the autostarter to manage autostart settings - updater Updater, // the updater to fetch and install updates - curVersion *semver.Version, // the current version of the bridge + logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity logSMTP bool, // whether to log SMTP activity ) (*Bridge, error) { - cookieJar, err := cookies.NewCookieJar(vault) - if err != nil { - return nil, fmt.Errorf("failed to create cookie jar: %w", err) - } - api := liteapi.New( liteapi.WithHostURL(apiURL), liteapi.WithAppVersion(constants.AppVersion), @@ -133,19 +129,20 @@ func New( } bridge := newBridge( + locator, vault, + autostarter, + updater, + curVersion, + api, - cookieJar, - proxyCtl, identifier, + proxyCtl, + tlsConfig, imapServer, smtpBackend, - updater, - curVersion, focusService, - autostarter, - locator, logIMAPClient, logIMAPServer, logSMTP, @@ -159,19 +156,20 @@ func New( } func newBridge( + locator Locator, vault *vault.Vault, + autostarter Autostarter, + updater Updater, + curVersion *semver.Version, + api *liteapi.Manager, - cookieJar *cookies.Jar, - proxyCtl ProxyController, identifier Identifier, + proxyCtl ProxyController, + tlsConfig *tls.Config, imapServer *gluon.Server, smtpBackend *smtpBackend, - updater Updater, - curVersion *semver.Version, focusService *focus.Service, - autostarter Autostarter, - locator Locator, logIMAPClient, logIMAPServer, logSMTP bool, ) *Bridge { return &Bridge{ @@ -179,7 +177,6 @@ func newBridge( users: make(map[string]*user.User), api: api, - cookieJar: cookieJar, proxyCtl: proxyCtl, identifier: identifier, @@ -308,11 +305,6 @@ func (bridge *Bridge) Close(ctx context.Context) error { } } - // Persist the cookies. - if err := bridge.cookieJar.PersistCookies(); err != nil { - logrus.WithError(err).Error("Failed to persist cookies") - } - // Close the focus service. bridge.focusService.Close() diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index b13c8976..68b96adc 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -12,6 +12,7 @@ import ( "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/certs" + "github.com/ProtonMail/proton-bridge/v2/internal/cookies" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/focus" "github.com/ProtonMail/proton-bridge/v2/internal/locations" @@ -128,7 +129,7 @@ func TestBridge_UserAgent(t *testing.T) { require.NoError(t, err) // Assert that the user agent was sent to the API. - require.Contains(t, calls[len(calls)-1].Header.Get("User-Agent"), bridge.GetCurrentUserAgent()) + require.Contains(t, calls[len(calls)-1].RequestHeader.Get("User-Agent"), bridge.GetCurrentUserAgent()) }) }) } @@ -137,9 +138,9 @@ func TestBridge_Cookies(t *testing.T) { withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) { sessionIDs := safe.NewSet[string]() - // Save any session IDs the API returns. + // Save any session IDs we use. s.AddCallWatcher(func(call server.Call) { - cookie, err := (&http.Request{Header: call.Header}).Cookie("Session-Id") + cookie, err := (&http.Request{Header: call.RequestHeader}).Cookie("Session-Id") if err != nil { return } @@ -398,18 +399,29 @@ func withBridge( require.NoError(t, vault.SetIMAPPort(0)) require.NoError(t, vault.SetSMTPPort(0)) + // Create a new cookie jar. + cookieJar, err := cookies.NewCookieJar(bridge.NewTestCookieJar(), vault) + require.NoError(t, err) + defer func() { require.NoError(t, cookieJar.PersistCookies()) }() + // Create a new bridge. bridge, err := bridge.New( - apiURL, + // The app stuff. locator, vault, + mocks.Autostarter, + mocks.Updater, + v2_3_0, + + // The API stuff. + apiURL, + cookieJar, useragent.New(), mocks.TLSReporter, liteapi.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), mocks.ProxyCtl, - mocks.Autostarter, - mocks.Updater, - v2_3_0, + + // The logging stuff. false, false, false, diff --git a/internal/bridge/errors.go b/internal/bridge/errors.go index d7d1e25c..6135b366 100644 --- a/internal/bridge/errors.go +++ b/internal/bridge/errors.go @@ -3,6 +3,9 @@ package bridge import "errors" var ( + ErrVaultInsecure = errors.New("the vault is insecure") + ErrVaultCorrupt = errors.New("the vault is corrupt") + ErrServeIMAP = errors.New("failed to serve IMAP") ErrServeSMTP = errors.New("failed to serve SMTP") ErrWatchUpdates = errors.New("failed to watch for updates") diff --git a/internal/bridge/mocks.go b/internal/bridge/mocks.go index ae6108e7..e985ffbf 100644 --- a/internal/bridge/mocks.go +++ b/internal/bridge/mocks.go @@ -1,6 +1,8 @@ package bridge import ( + "net/http" + "net/url" "os" "testing" @@ -41,6 +43,24 @@ func (mocks *Mocks) Close() { close(mocks.TLSIssueCh) } +type TestCookieJar struct { + cookies map[string][]*http.Cookie +} + +func NewTestCookieJar() *TestCookieJar { + return &TestCookieJar{ + cookies: make(map[string][]*http.Cookie), + } +} + +func (j *TestCookieJar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.cookies[u.Host] = cookies +} + +func (j *TestCookieJar) Cookies(u *url.URL) []*http.Cookie { + return j.cookies[u.Host] +} + type TestLocationsProvider struct { config, cache string } diff --git a/internal/cookies/jar.go b/internal/cookies/jar.go index acff7091..eca9691a 100644 --- a/internal/cookies/jar.go +++ b/internal/cookies/jar.go @@ -22,7 +22,6 @@ import ( "encoding/json" "fmt" "net/http" - "net/http/cookiejar" "net/url" "sync" "time" @@ -38,18 +37,14 @@ type Persister interface { // Jar implements http.CookieJar by wrapping the standard library's cookiejar.Jar. // The jar uses a pantry to load cookies at startup and save cookies when set. type Jar struct { - jar *cookiejar.Jar + jar http.CookieJar + persister Persister cookies cookiesByHost - locker sync.Locker + locker sync.RWMutex } -func NewCookieJar(persister Persister) (*Jar, error) { - jar, err := cookiejar.New(nil) - if err != nil { - return nil, err - } - +func NewCookieJar(jar http.CookieJar, persister Persister) (*Jar, error) { cookiesByHost, err := loadCookies(persister) if err != nil { return nil, err @@ -65,10 +60,10 @@ func NewCookieJar(persister Persister) (*Jar, error) { } return &Jar{ - jar: jar, + jar: jar, + persister: persister, cookies: cookiesByHost, - locker: &sync.Mutex{}, }, nil } @@ -88,16 +83,16 @@ func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { } func (j *Jar) Cookies(u *url.URL) []*http.Cookie { - j.locker.Lock() - defer j.locker.Unlock() + j.locker.RLock() + defer j.locker.RUnlock() return j.jar.Cookies(u) } // PersistCookies persists the cookies to disk. func (j *Jar) PersistCookies() error { - j.locker.Lock() - defer j.locker.Unlock() + j.locker.RLock() + defer j.locker.RUnlock() rawCookies, err := json.Marshal(j.cookies) if err != nil { diff --git a/internal/cookies/jar_test.go b/internal/cookies/jar_test.go index e41488b3..a7594b32 100644 --- a/internal/cookies/jar_test.go +++ b/internal/cookies/jar_test.go @@ -21,6 +21,7 @@ import ( "errors" "io/fs" "net/http" + "net/http/cookiejar" "net/http/httptest" "os" "path/filepath" @@ -138,10 +139,13 @@ type testCookie struct { } func getClientWithJar(t *testing.T, persister Persister) (*http.Client, *Jar) { - jar, err := NewCookieJar(persister) + jar, err := cookiejar.New(nil) require.NoError(t, err) - return &http.Client{Jar: jar}, jar + wrapper, err := NewCookieJar(jar, persister) + require.NoError(t, err) + + return &http.Client{Jar: wrapper}, wrapper } func getTestServer(t *testing.T, wantCookies []testCookie) *httptest.Server { diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index 471c267b..a2503188 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -24,7 +24,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/constants" "github.com/ProtonMail/proton-bridge/v2/internal/events" - "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/abiosoft/ishell" "github.com/sirupsen/logrus" @@ -266,10 +265,10 @@ func (f *frontendCLI) watchEvents() { // TODO: Better error events. for _, err := range f.bridge.GetErrors() { switch { - case errors.Is(err, vault.ErrCorrupt): + case errors.Is(err, bridge.ErrVaultCorrupt): f.notifyCredentialsError() - case errors.Is(err, vault.ErrInsecure): + case errors.Is(err, bridge.ErrVaultInsecure): f.notifyCredentialsError() case errors.Is(err, bridge.ErrServeIMAP): diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index dfe92f99..d81e41c4 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -34,7 +34,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/locations" "github.com/ProtonMail/proton-bridge/v2/internal/updater" - "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/pkg/restarter" "github.com/google/uuid" "github.com/sirupsen/logrus" @@ -202,10 +201,10 @@ func (s *Service) watchEvents() { // TODO: Better error events. for _, err := range s.bridge.GetErrors() { switch { - case errors.Is(err, vault.ErrCorrupt): + case errors.Is(err, bridge.ErrVaultCorrupt): _ = s.SendEvent(NewKeychainHasNoKeychainEvent()) - case errors.Is(err, vault.ErrInsecure): + case errors.Is(err, bridge.ErrVaultInsecure): _ = s.SendEvent(NewKeychainHasNoKeychainEvent()) case errors.Is(err, bridge.ErrServeIMAP): diff --git a/internal/vault/vault.go b/internal/vault/vault.go index 85f438ca..468f9556 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -15,11 +15,6 @@ import ( "github.com/bradenaw/juniper/xslices" ) -var ( - ErrInsecure = errors.New("the vault is insecure") - ErrCorrupt = errors.New("the vault is corrupt") -) - type Vault struct { path string enc []byte @@ -122,6 +117,10 @@ func (vault *Vault) DeleteUser(userID string) error { }) } +func (vault *Vault) Close() error { + return nil +} + func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) { if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) { if _, err := initVault(path, gluonDir, gcm); err != nil { diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index e9489aae..dcb0ef16 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -4,8 +4,10 @@ import ( "context" "crypto/tls" "fmt" + "net/http/cookiejar" "github.com/ProtonMail/proton-bridge/v2/internal/bridge" + "github.com/ProtonMail/proton-bridge/v2/internal/cookies" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/useragent" "github.com/ProtonMail/proton-bridge/v2/internal/vault" @@ -36,18 +38,31 @@ func (t *testCtx) startBridge() error { return fmt.Errorf("vault is corrupt") } + jar, err := cookiejar.New(nil) + if err != nil { + return err + } + + persister, err := cookies.NewCookieJar(jar, vault) + if err != nil { + return err + } + // Create the bridge. bridge, err := bridge.New( - t.api.GetHostURL(), t.locator, vault, + t.mocks.Autostarter, + t.mocks.Updater, + t.version, + + t.api.GetHostURL(), + persister, useragent.New(), t.mocks.TLSReporter, liteapi.NewDialer(t.netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), t.mocks.ProxyCtl, - t.mocks.Autostarter, - t.mocks.Updater, - t.version, + false, false, false, diff --git a/tests/environment_test.go b/tests/environment_test.go index e57f72db..c5146aba 100644 --- a/tests/environment_test.go +++ b/tests/environment_test.go @@ -61,7 +61,7 @@ func (s *scenario) theHeaderInTheRequestToHasSetTo(method, path, key, value stri return err } - if haveKey := call.Header.Get(key); haveKey != value { + if haveKey := call.RequestHeader.Get(key); haveKey != value { return fmt.Errorf("have header %q, want %q", haveKey, value) } @@ -76,7 +76,7 @@ func (s *scenario) theBodyInTheRequestToIs(method, path string, value *godog.Doc var body, want map[string]any - if err := json.Unmarshal(call.Body, &body); err != nil { + if err := json.Unmarshal(call.RequestBody, &body); err != nil { return err }