fix(BRIDGE-8): launcher replace session-id if provided instead of adding another one.

This commit is contained in:
Xavier Michelon
2024-04-09 18:01:07 +02:00
parent e94d3be12d
commit bb15efa711
2 changed files with 47 additions and 23 deletions

View File

@ -40,6 +40,7 @@ import (
"github.com/elastic/go-sysinfo/types" "github.com/elastic/go-sysinfo/types"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
"golang.org/x/sys/execabs" "golang.org/x/sys/execabs"
) )
@ -53,9 +54,12 @@ const (
FlagCLIShort = "c" FlagCLIShort = "c"
FlagNonInteractive = "noninteractive" FlagNonInteractive = "noninteractive"
FlagNonInteractiveShort = "n" FlagNonInteractiveShort = "n"
FlagLauncher = "--launcher" FlagLauncher = "launcher"
FlagWait = "--wait" FlagWait = "wait"
FlagSessionID = "--session-id" FlagSessionID = "session-id"
HyphenatedFlagLauncher = "--" + FlagLauncher
HyphenatedFlagWait = "--" + FlagWait
HyphenatedFlagSessionID = "--" + FlagSessionID
) )
func main() { //nolint:funlen func main() { //nolint:funlen
@ -151,7 +155,7 @@ func main() { //nolint:funlen
} }
} }
cmd := execabs.Command(exe, appendLauncherPath(launcher, append(args, FlagSessionID, string(sessionID)))...) //nolint:gosec cmd := execabs.Command(exe, appendLauncherPath(launcher, appendOrModifySessionID(args, string(sessionID)))...) //nolint:gosec
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
@ -173,19 +177,14 @@ func main() { //nolint:funlen
// appendLauncherPath add launcher path if missing. // appendLauncherPath add launcher path if missing.
func appendLauncherPath(path string, args []string) []string { func appendLauncherPath(path string, args []string) []string {
if !sliceContains(args, FlagLauncher) { if !slices.Contains(args, HyphenatedFlagLauncher) {
res := append([]string{}, args...) res := append([]string{}, args...)
res = append(res, FlagLauncher, path) res = append(res, HyphenatedFlagLauncher, path)
return res return res
} }
return args return args
} }
// sliceContains checks if a value is present in a list.
func sliceContains[T comparable](list []T, s T) bool {
return xslices.Any(list, func(arg T) bool { return arg == s })
}
// inCLIMode detect if CLI mode is asked. // inCLIMode detect if CLI mode is asked.
func inCLIMode(args []string) bool { func inCLIMode(args []string) bool {
return hasFlag(args, FlagCLI) || hasFlag(args, FlagCLIShort) || hasFlag(args, FlagNonInteractive) || hasFlag(args, FlagNonInteractiveShort) return hasFlag(args, FlagCLI) || hasFlag(args, FlagCLIShort) || hasFlag(args, FlagNonInteractive) || hasFlag(args, FlagNonInteractiveShort)
@ -193,7 +192,12 @@ func inCLIMode(args []string) bool {
// hasFlag checks if a flag is present in a list. // hasFlag checks if a flag is present in a list.
func hasFlag(args []string, flag string) bool { func hasFlag(args []string, flag string) bool {
return xslices.Any(args, func(arg string) bool { return (arg == "-"+flag) || (arg == "--"+flag) }) return flagIndex(args, flag) >= 0
}
// flagIndex returns the position of the first occurrence of a flag int args, or -1 if the flag is not present.
func flagIndex(args []string, flag string) int {
return slices.IndexFunc(args, func(arg string) bool { return (arg == "-"+flag) || (arg == "--"+flag) })
} }
// findAndStrip check if a value is present in s list and remove all occurrences of the value from this list. // findAndStrip check if a value is present in s list and remove all occurrences of the value from this list.
@ -211,7 +215,7 @@ func findAndStripWait(args []string) ([]string, bool, []string) {
hasFlag := false hasFlag := false
values := make([]string, 0) values := make([]string, 0)
for k, v := range res { for k, v := range res {
if v != FlagWait { if v != HyphenatedFlagWait {
continue continue
} }
if k+1 >= len(res) { if k+1 >= len(res) {
@ -222,7 +226,7 @@ func findAndStripWait(args []string) ([]string, bool, []string) {
} }
if hasFlag { if hasFlag {
res, _ = findAndStrip(res, FlagWait) res, _ = findAndStrip(res, HyphenatedFlagWait)
for _, v := range values { for _, v := range values {
res, _ = findAndStrip(res, v) res, _ = findAndStrip(res, v)
} }
@ -230,6 +234,23 @@ func findAndStripWait(args []string) ([]string, bool, []string) {
return res, hasFlag, values return res, hasFlag, values
} }
// return args with the sessionID flag and value added or modified. The original slice is not modified.
func appendOrModifySessionID(args []string, sessionID string) []string {
index := flagIndex(args, FlagSessionID)
if index < 0 {
return append(args, HyphenatedFlagSessionID, sessionID)
}
if index == len(args)-1 {
return append(args, sessionID)
}
res := slices.Clone(args)
res[index+1] = sessionID
return res
}
func getPathToUpdatedExecutable( func getPathToUpdatedExecutable(
name string, name string,
ver *versioner.Versioner, ver *versioner.Versioner,

View File

@ -20,19 +20,12 @@ package main
import ( import (
"testing" "testing"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/bradenaw/juniper/xslices" "github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
) )
func TestSliceContains(t *testing.T) {
assert.True(t, sliceContains([]string{"a", "b", "c"}, "a"))
assert.True(t, sliceContains([]int{1, 2, 3}, 2))
assert.False(t, sliceContains([]string{"a", "b", "c"}, "A"))
assert.False(t, sliceContains([]int{1, 2, 3}, 4))
assert.False(t, sliceContains([]string{}, "a"))
assert.True(t, sliceContains([]string{"a", "a"}, "a"))
}
func TestFindAndStrip(t *testing.T) { func TestFindAndStrip(t *testing.T) {
list := []string{"a", "b", "c", "c", "b", "c"} list := []string{"a", "b", "c", "c", "b", "c"}
@ -78,3 +71,13 @@ func TestFindAndStripWait(t *testing.T) {
assert.True(t, xslices.Equal(result, []string{"a"})) assert.True(t, xslices.Equal(result, []string{"a"}))
assert.True(t, xslices.Equal(values, []string{"b", "c", "d"})) assert.True(t, xslices.Equal(values, []string{"b", "c", "d"}))
} }
func TestAppendOrModifySessionID(t *testing.T) {
sessionID := string(logging.NewSessionID())
assert.True(t, slices.Equal(appendOrModifySessionID(nil, sessionID), []string{"--session-id", sessionID}))
assert.True(t, slices.Equal(appendOrModifySessionID([]string{}, sessionID), []string{"--session-id", sessionID}))
assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--cli"}, sessionID), []string{"--cli", "--session-id", sessionID}))
assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--cli", "--session-id"}, sessionID), []string{"--cli", "--session-id", sessionID}))
assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--cli", "--session-id"}, sessionID), []string{"--cli", "--session-id", sessionID}))
assert.True(t, slices.Equal(appendOrModifySessionID([]string{"--session-id", "<oldID>", "--cli"}, sessionID), []string{"--session-id", sessionID, "--cli"}))
}