diff --git a/internal/app/app.go b/internal/app/app.go index 4a1fbfba..55d7346d 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/http/cookiejar" + "os" "path/filepath" "github.com/Masterminds/semver/v3" @@ -123,10 +124,23 @@ func run(c *cli.Context) error { // Create a new Sentry client that will be used to report crashes etc. reporter := sentry.NewReporter(constants.FullAppName, constants.Version, identifier) + // Determine the exe that should be used to restart/autostart the app. + // By default, this is the launcher, if used. Otherwise, we try to get + // the current exe, and fall back to os.Args[0] if that fails. + var exe string + + if launcher := c.String(flagLauncher); launcher != "" { + exe = launcher + } else if executable, err := os.Executable(); err == nil { + exe = executable + } else { + exe = os.Args[0] + } + // Run with profiling if requested. return withProfiler(c, func() error { // Restart the app if requested. - return withRestarter(func(restarter *restarter.Restarter) error { + return withRestarter(exe, func(restarter *restarter.Restarter) error { // Handle crashes with various actions. return withCrashHandler(restarter, reporter, func(crashHandler *crash.Handler) error { // Load the locations where we store our files. @@ -140,7 +154,7 @@ func run(c *cli.Context) error { // Load the cookies from the vault. return withCookieJar(vault, func(cookieJar http.CookieJar) error { // Create a new bridge instance. - return withBridge(c, locations, version, identifier, reporter, vault, cookieJar, func(b *bridge.Bridge, eventCh <-chan events.Event) error { + return withBridge(c, exe, locations, version, identifier, reporter, vault, cookieJar, func(b *bridge.Bridge, eventCh <-chan events.Event) error { if insecure { logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted") b.PushError(bridge.ErrVaultInsecure) @@ -223,8 +237,8 @@ func withProfiler(c *cli.Context, fn func() error) error { } // Restart the app if necessary. -func withRestarter(fn func(*restarter.Restarter) error) error { - restarter := restarter.New() +func withRestarter(exe string, fn func(*restarter.Restarter) error) error { + restarter := restarter.New(exe) defer restarter.Restart() return fn(restarter) diff --git a/internal/app/bridge.go b/internal/app/bridge.go index dc0518a5..0a7ea032 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -3,7 +3,6 @@ package app import ( "fmt" "net/http" - "os" "runtime" "github.com/Masterminds/semver/v3" @@ -28,6 +27,7 @@ const vaultSecretName = "bridge-vault-key" // withBridge creates creates and tears down the bridge. func withBridge( c *cli.Context, + exe string, locations *locations.Locations, version *semver.Version, identifier *useragent.UserAgent, @@ -48,7 +48,7 @@ func withBridge( proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost) // Create the autostarter. - autostarter, err := newAutostarter() + autostarter, err := newAutostarter(exe) if err != nil { return fmt.Errorf("could not create autostarter: %w", err) } @@ -95,12 +95,7 @@ func withBridge( return fn(bridge, eventCh) } -func newAutostarter() (*autostart.App, error) { - exe, err := os.Executable() - if err != nil { - return nil, err - } - +func newAutostarter(exe string) (*autostart.App, error) { return &autostart.App{ Name: constants.FullAppName, DisplayName: constants.FullAppName, diff --git a/internal/frontend/grpc/service_methods.go b/internal/frontend/grpc/service_methods.go index b068b68f..5a987b16 100644 --- a/internal/frontend/grpc/service_methods.go +++ b/internal/frontend/grpc/service_methods.go @@ -332,27 +332,21 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp return &emptypb.Empty{}, nil } -/* func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("launcher", launcher.Value).Debug("ForceLauncher") - go func() { - defer s.panicHandler.HandlePanic() - s.restarter.ForceLauncher(launcher.Value) - }() + s.restarter.Override(launcher.Value) + return &emptypb.Empty{}, nil } func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("executable", exe.Value).Debug("SetMainExecutable") - go func() { - defer s.panicHandler.HandlePanic() - s.restarter.SetMainExecutable(exe.Value) - }() + s.restarter.AddFlags("--wait", exe.Value) + return &emptypb.Empty{}, nil } -*/ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) { s.log.WithField("username", login.Username).Debug("Login") diff --git a/pkg/restarter/restarter.go b/pkg/restarter/restarter.go index 0745bccd..841eecf9 100644 --- a/pkg/restarter/restarter.go +++ b/pkg/restarter/restarter.go @@ -5,30 +5,22 @@ import ( "strconv" "strings" + "github.com/bradenaw/juniper/xslices" "github.com/sirupsen/logrus" "golang.org/x/sys/execabs" ) -const ( - BridgeCrashCount = "BRIDGE_CRASH_COUNT" - BridgeLauncher = "BRIDGE_LAUNCHER" -) +const BridgeCrashCount = "BRIDGE_CRASH_COUNT" type Restarter struct { restart bool crash bool - exe string + + exe string + flags []string } -func New() *Restarter { - var exe string - - if osExe, err := os.Executable(); err == nil { - exe = osExe - } else { - logrus.WithError(err).Error("Failed to get executable path, the app will not be able to restart") - } - +func New(exe string) *Restarter { return &Restarter{exe: exe} } @@ -37,6 +29,14 @@ func (restarter *Restarter) Set(restart, crash bool) { restarter.crash = crash } +func (restarter *Restarter) Override(exe string) { + restarter.exe = exe +} + +func (restarter *Restarter) AddFlags(flags ...string) { + restarter.flags = append(restarter.flags, flags...) +} + func (restarter *Restarter) Restart() { if !restarter.restart { return @@ -49,12 +49,12 @@ func (restarter *Restarter) Restart() { env := getEnvMap() if restarter.crash { - env[BridgeCrashCount] = increment(env[BridgeLauncher]) + env[BridgeCrashCount] = increment(env[BridgeCrashCount]) } else { delete(env, BridgeCrashCount) } - cmd := execabs.Command(restarter.exe, os.Args[1:]...) + cmd := execabs.Command(restarter.exe, xslices.Join(os.Args[1:], restarter.flags)...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout