From 4a8a7ef0939983ef5cc6acf19889a29d0cf5b5d6 Mon Sep 17 00:00:00 2001 From: Atanas Janeshliev Date: Thu, 21 Mar 2024 16:32:12 +0000 Subject: [PATCH] fix(BRIDGE-4): logs not being created when invalid flag is passed --- cmd/Desktop-Bridge/main.go | 76 ++++++++++++++++++++++++++++++++- cmd/Desktop-Bridge/main_test.go | 47 ++++++++++++++++++++ internal/app/app.go | 6 +-- 3 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 cmd/Desktop-Bridge/main_test.go diff --git a/cmd/Desktop-Bridge/main.go b/cmd/Desktop-Bridge/main.go index b7f697d7..0dc977b1 100644 --- a/cmd/Desktop-Bridge/main.go +++ b/cmd/Desktop-Bridge/main.go @@ -19,8 +19,15 @@ package main import ( "os" + "runtime" "strings" + "github.com/ProtonMail/proton-bridge/v3/internal/constants" + "github.com/ProtonMail/proton-bridge/v3/internal/locations" + "github.com/ProtonMail/proton-bridge/v3/internal/logging" + "github.com/ProtonMail/proton-bridge/v3/internal/sentry" + "github.com/sirupsen/logrus" + "github.com/ProtonMail/proton-bridge/v3/internal/app" "github.com/bradenaw/juniper/xslices" ) @@ -43,5 +50,72 @@ import ( */ func main() { - _ = app.New().Run(xslices.Filter(os.Args, func(arg string) bool { return !strings.Contains(arg, "-psn_") })) + appErr := app.New().Run(xslices.Filter(os.Args, func(arg string) bool { return !strings.Contains(arg, "-psn_") })) + if appErr != nil { + _ = app.WithLocations(func(l *locations.Locations) error { + logsPath, err := l.ProvideLogsPath() + if err != nil { + return err + } + + // Get the session ID if its specified + var sessionID logging.SessionID + if flagVal, found := getFlagValue(os.Args, app.FlagSessionID); found { + sessionID = logging.SessionID(flagVal) + } else { + sessionID = logging.NewSessionID() + } + + closer, err := logging.Init( + logsPath, + sessionID, + logging.BridgeShortAppName, + logging.DefaultMaxLogFileSize, + logging.DefaultPruningSize, + "", + ) + if err != nil { + return err + } + + defer func() { + _ = logging.Close(closer) + }() + + logrus. + WithField("appName", constants.FullAppName). + WithField("version", constants.Version). + WithField("revision", constants.Revision). + WithField("tag", constants.Tag). + WithField("build", constants.BuildTime). + WithField("runtime", runtime.GOOS). + WithField("args", os.Args). + WithField("SentryID", sentry.GetProtectedHostname()).WithError(appErr).Error("Failed to initialize bridge") + return nil + }) + } +} + +// getFlagValue - obtains the value of a specified tag +// The flag can be of the following form `-flag value`, `--flag value`, `-flag=value` or `--flags=value`. +func getFlagValue(argList []string, flag string) (string, bool) { + eqPrefix1 := "-" + flag + "=" + eqPrefix2 := "--" + flag + "=" + + for i := 0; i < len(argList); i++ { + arg := argList[i] + if strings.HasPrefix(arg, eqPrefix1) { + val := strings.TrimPrefix(arg, eqPrefix1) + return val, len(val) > 0 + } + if strings.HasPrefix(arg, eqPrefix2) { + val := strings.TrimPrefix(arg, eqPrefix2) + return val, len(val) > 0 + } + if (arg == "-"+flag || arg == "--"+flag) && i+1 < len(argList) { + return argList[i+1], true + } + } + + return "", false } diff --git a/cmd/Desktop-Bridge/main_test.go b/cmd/Desktop-Bridge/main_test.go new file mode 100644 index 00000000..3de451cc --- /dev/null +++ b/cmd/Desktop-Bridge/main_test.go @@ -0,0 +1,47 @@ +// Copyright (c) 2024 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetFlagValue(t *testing.T) { + tests := []struct { + args []string + flag string + expected string + }{ + {[]string{"session-id", ""}, "session-id", ""}, + {[]string{"-session-id", ""}, "session-id", ""}, + {[]string{"--session-id", ""}, "session-id", ""}, + {[]string{"session-id", "test"}, "session-id", ""}, + {[]string{"-session-id", "test"}, "session-id", "test"}, + {[]string{"--session-id", "test"}, "session-id", "test"}, + {[]string{"session-id=test"}, "session-id", ""}, + {[]string{"-session-id=test"}, "session-id", "test"}, + {[]string{"--session-id=test"}, "session-id", "test"}, + } + + for _, tt := range tests { + val, _ := getFlagValue(tt.args, tt.flag) + require.Equal(t, val, tt.expected) + } +} diff --git a/internal/app/app.go b/internal/app/app.go index 49fb83e0..54ffd1be 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -83,7 +83,7 @@ const ( flagNoWindow = "no-window" flagParentPID = "parent-pid" flagSoftwareRenderer = "software-renderer" - flagSessionID = "session-id" + FlagSessionID = "session-id" ) const ( @@ -165,7 +165,7 @@ func New() *cli.App { Value: false, }, &cli.StringFlag{ - Name: flagSessionID, + Name: FlagSessionID, Hidden: true, }, } @@ -346,7 +346,7 @@ func withLogging(c *cli.Context, crashHandler *crash.Handler, locations *locatio logrus.WithField("path", logsPath).Debug("Received logs path") // Initialize logging. - sessionID := logging.NewSessionIDFromString(c.String(flagSessionID)) + sessionID := logging.NewSessionIDFromString(c.String(FlagSessionID)) var closer io.Closer if closer, err = logging.Init( logsPath,