diff --git a/cmd/launcher/main.go b/cmd/launcher/main.go index fa379fb9..bf01e836 100644 --- a/cmd/launcher/main.go +++ b/cmd/launcher/main.go @@ -40,6 +40,7 @@ import ( "github.com/elastic/go-sysinfo/types" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" "golang.org/x/sys/execabs" ) @@ -53,9 +54,12 @@ const ( FlagCLIShort = "c" FlagNonInteractive = "noninteractive" FlagNonInteractiveShort = "n" - FlagLauncher = "--launcher" - FlagWait = "--wait" - FlagSessionID = "--session-id" + FlagLauncher = "launcher" + FlagWait = "wait" + FlagSessionID = "session-id" + HyphenatedFlagLauncher = "--" + FlagLauncher + HyphenatedFlagWait = "--" + FlagWait + HyphenatedFlagSessionID = "--" + FlagSessionID ) func main() { //nolint:funlen @@ -151,7 +155,7 @@ func main() { //nolint:funlen } } - cmd := execabs.Command(exe, appendLauncherPath(launcher, append(args, FlagSessionID, string(sessionID)))...) //nolint:gosec + cmd := execabs.Command(exe, appendLauncherPath(launcher, appendOrModifySessionID(args, string(sessionID)))...) //nolint:gosec cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -173,19 +177,14 @@ func main() { //nolint:funlen // appendLauncherPath add launcher path if missing. func appendLauncherPath(path string, args []string) []string { - if !sliceContains(args, FlagLauncher) { + if !slices.Contains(args, HyphenatedFlagLauncher) { res := append([]string{}, args...) - res = append(res, FlagLauncher, path) + res = append(res, HyphenatedFlagLauncher, path) return res } return args } -// sliceContains checks if a value is present in a list. -func sliceContains[T comparable](list []T, s T) bool { - return xslices.Any(list, func(arg T) bool { return arg == s }) -} - // inCLIMode detect if CLI mode is asked. func inCLIMode(args []string) bool { return hasFlag(args, FlagCLI) || hasFlag(args, FlagCLIShort) || hasFlag(args, FlagNonInteractive) || hasFlag(args, FlagNonInteractiveShort) @@ -193,7 +192,12 @@ func inCLIMode(args []string) bool { // hasFlag checks if a flag is present in a list. func hasFlag(args []string, flag string) bool { - return xslices.Any(args, func(arg string) bool { return (arg == "-"+flag) || (arg == "--"+flag) }) + return flagIndex(args, flag) >= 0 +} + +// flagIndex returns the position of the first occurrence of a flag int args, or -1 if the flag is not present. +func flagIndex(args []string, flag string) int { + return slices.IndexFunc(args, func(arg string) bool { return (arg == "-"+flag) || (arg == "--"+flag) }) } // findAndStrip check if a value is present in s list and remove all occurrences of the value from this list. @@ -211,7 +215,7 @@ func findAndStripWait(args []string) ([]string, bool, []string) { hasFlag := false values := make([]string, 0) for k, v := range res { - if v != FlagWait { + if v != HyphenatedFlagWait { continue } if k+1 >= len(res) { @@ -222,7 +226,7 @@ func findAndStripWait(args []string) ([]string, bool, []string) { } if hasFlag { - res, _ = findAndStrip(res, FlagWait) + res, _ = findAndStrip(res, HyphenatedFlagWait) for _, v := range values { res, _ = findAndStrip(res, v) } @@ -230,6 +234,23 @@ func findAndStripWait(args []string) ([]string, bool, []string) { return res, hasFlag, values } +// return args with the sessionID flag and value added or modified. The original slice is not modified. +func appendOrModifySessionID(args []string, sessionID string) []string { + index := flagIndex(args, FlagSessionID) + if index < 0 { + return append(args, HyphenatedFlagSessionID, sessionID) + } + + if index == len(args)-1 { + return append(args, sessionID) + } + + res := slices.Clone(args) + res[index+1] = sessionID + + return res +} + func getPathToUpdatedExecutable( name string, ver *versioner.Versioner, diff --git a/cmd/launcher/main_test.go b/cmd/launcher/main_test.go index c055c82b..762f5c66 100644 --- a/cmd/launcher/main_test.go +++ b/cmd/launcher/main_test.go @@ -20,19 +20,12 @@ package main import ( "testing" + "github.com/ProtonMail/proton-bridge/v3/internal/logging" "github.com/bradenaw/juniper/xslices" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) -func TestSliceContains(t *testing.T) { - assert.True(t, sliceContains([]string{"a", "b", "c"}, "a")) - assert.True(t, sliceContains([]int{1, 2, 3}, 2)) - assert.False(t, sliceContains([]string{"a", "b", "c"}, "A")) - assert.False(t, sliceContains([]int{1, 2, 3}, 4)) - assert.False(t, sliceContains([]string{}, "a")) - assert.True(t, sliceContains([]string{"a", "a"}, "a")) -} - func TestFindAndStrip(t *testing.T) { list := []string{"a", "b", "c", "c", "b", "c"} @@ -78,3 +71,13 @@ func TestFindAndStripWait(t *testing.T) { assert.True(t, xslices.Equal(result, []string{"a"})) assert.True(t, xslices.Equal(values, []string{"b", "c", "d"})) } + +func TestAppendOrModifySessionID(t *testing.T) { + sessionID := string(logging.NewSessionID()) + assert.True(t, slices.Equal(appendOrModifySessionID(nil, sessionID), []string{"--session-id", sessionID})) + assert.True(t, slices.Equal(appendOrModifySessionID([]string{}, sessionID), []string{"--session-id", sessionID})) + assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--cli"}, sessionID), []string{"--cli", "--session-id", sessionID})) + assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--cli", "--session-id"}, sessionID), []string{"--cli", "--session-id", sessionID})) + assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--cli", "--session-id"}, sessionID), []string{"--cli", "--session-id", sessionID})) + assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--session-id", "", "--cli"}, sessionID), []string{"--session-id", sessionID, "--cli"})) +}