diff --git a/internal/app/bridge/bridge.go b/internal/app/bridge/bridge.go index 0f58d5bf..c20f9536 100644 --- a/internal/app/bridge/bridge.go +++ b/internal/app/bridge/bridge.go @@ -44,9 +44,9 @@ const ( flagLogSMTP = "log-smtp" flagNonInteractive = "noninteractive" - // Memory cache was estimated by empirical usage in the past, and it was set to 100MB. + // Memory cache was estimated by empirical usage in past and it was set to 100MB. // NOTE: This value must not be less than maximal size of one email (~30MB). - inMemoryCacheLimit = 100 * (1 << 20) + inMemoryCacheLimnit = 100 * (1 << 20) ) func New(base *base.Base) *cli.App { @@ -63,7 +63,7 @@ func New(base *base.Base) *cli.App { }, &cli.BoolFlag{ Name: flagNonInteractive, - Usage: "Start Bridge entirely non-interactively", + Usage: "Start Bridge entirely noninteractively", }, }...) @@ -71,17 +71,6 @@ func New(base *base.Base) *cli.App { } func main(b *base.Base, c *cli.Context) error { //nolint:funlen - frontendType := getFrontendTypeFromCLIParams(c) - f := frontend.New( - frontendType, - !c.Bool(base.FlagNoWindow), - b.CrashHandler, - b.Listener, - b.Updater, - b, - b.Locations, - ) - cache, cacheErr := loadMessageCache(b) if cacheErr != nil { logrus.WithError(cacheErr).Error("Could not load local cache.") @@ -152,10 +141,28 @@ func main(b *base.Base, c *cli.Context) error { //nolint:funlen // We want cookies to be saved to disk so they are loaded the next time. b.AddTeardownAction(b.CookieJar.PersistCookies) - if frontendType == frontend.NonInteractive { - return <-(make(chan error)) + var frontendMode string + + switch { + case c.Bool(base.FlagCLI): + frontendMode = "cli" + case c.Bool(flagNonInteractive): + return <-(make(chan error)) // Block forever. + default: + frontendMode = "grpc" } + f := frontend.New( + frontendMode, + !c.Bool(base.FlagNoWindow), + b.CrashHandler, + b.Listener, + b.Updater, + bridge, + b, + b.Locations, + ) + // Watch for updates routine go func() { ticker := time.NewTicker(constants.UpdateCheckInterval) @@ -166,18 +173,7 @@ func main(b *base.Base, c *cli.Context) error { //nolint:funlen } }() - return f.Loop(bridge) -} - -func getFrontendTypeFromCLIParams(c *cli.Context) frontend.Type { - switch { - case c.Bool(base.FlagCLI): - return frontend.CLI - case c.Bool(flagNonInteractive): - return frontend.NonInteractive - default: - return frontend.GRPC - } + return f.Loop() } func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) { @@ -230,7 +226,7 @@ func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) // local cache is enabled but unavailable (in-memory cache will be returned nevertheless). func loadMessageCache(b *base.Base) (cache.Cache, error) { if !b.Settings.GetBool(settings.CacheEnabledKey) { - return cache.NewInMemoryCache(inMemoryCacheLimit), nil + return cache.NewInMemoryCache(inMemoryCacheLimnit), nil } var compressor cache.Compressor @@ -250,12 +246,12 @@ func loadMessageCache(b *base.Base) (cache.Cache, error) { path = customPath } else { path = b.Cache.GetDefaultMessageCacheDir() - // Store path so it will always persist if default location + // Store path so it will allways persist if default location // will be changed in new version. b.Settings.Set(settings.CacheLocationKey, path) } - // To prevent memory peaks we set maximal write concurrency for store + // To prevent memory peaks we set maximal write concurency for store // build jobs. store.SetBuildAndCacheJobLimit(b.Settings.GetInt(settings.CacheConcurrencyWrite)) @@ -266,7 +262,7 @@ func loadMessageCache(b *base.Base) (cache.Cache, error) { ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite), }) if err != nil { - return cache.NewInMemoryCache(inMemoryCacheLimit), err + return cache.NewInMemoryCache(inMemoryCacheLimnit), err } return messageCache, nil diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index 1597ea80..10ad573c 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -47,6 +47,7 @@ func New( //nolint:funlen eventListener listener.Listener, updater types.Updater, + bridge types.Bridger, restarter types.Restarter, ) *frontendCLI { //nolint:revive fe := &frontendCLI{ @@ -54,6 +55,7 @@ func New( //nolint:funlen eventListener: eventListener, updater: updater, + bridge: bridge, restarter: restarter, } @@ -317,8 +319,7 @@ func (f *frontendCLI) watchEvents() { } // Loop starts the frontend loop with an interactive shell. -func (f *frontendCLI) Loop(b types.Bridger) error { - f.bridge = b +func (f *frontendCLI) Loop() error { f.Printf(` Welcome to %s interactive shell ___....___ diff --git a/internal/frontend/frontend.go b/internal/frontend/frontend.go index b701787f..fc3e407b 100644 --- a/internal/frontend/frontend.go +++ b/internal/frontend/frontend.go @@ -19,6 +19,7 @@ package frontend import ( + "github.com/ProtonMail/proton-bridge/v2/internal/bridge" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/types" @@ -27,17 +28,8 @@ import ( "github.com/ProtonMail/proton-bridge/v2/pkg/listener" ) -// Type describes the available types of frontend. -type Type int - -const ( - CLI Type = iota - GRPC - NonInteractive -) - type Frontend interface { - Loop(b types.Bridger) error + Loop() error NotifyManualUpdate(update updater.VersionInfo, canInstall bool) SetVersion(update updater.VersionInfo) NotifySilentUpdateInstalled() @@ -45,38 +37,38 @@ type Frontend interface { WaitUntilFrontendIsReady() } -// New returns initialized frontend based on `frontendType`, which can be `CLI` or `GRPC`. +// New returns initialized frontend based on `frontendType`, which can be `cli` or `grpc`. func New( - frontendType Type, + frontendType string, showWindowOnStart bool, panicHandler types.PanicHandler, eventListener listener.Listener, updater types.Updater, + bridge *bridge.Bridge, restarter types.Restarter, locations *locations.Locations, ) Frontend { switch frontendType { - case GRPC: + case "grpc": return grpc.NewService( showWindowOnStart, panicHandler, eventListener, updater, + bridge, restarter, locations, ) - case CLI: + case "cli": return cli.New( panicHandler, eventListener, updater, + bridge, restarter, ) - case NonInteractive: - fallthrough - default: return nil } diff --git a/internal/frontend/grpc/grpc_test.go b/internal/frontend/grpc/config_test.go similarity index 57% rename from internal/frontend/grpc/grpc_test.go rename to internal/frontend/grpc/config_test.go index 37d31ebe..f54e6f88 100644 --- a/internal/frontend/grpc/grpc_test.go +++ b/internal/frontend/grpc/config_test.go @@ -53,25 +53,3 @@ func TestConfig(t *testing.T) { // failure to save require.Error(t, conf2.save(filepath.Join(tempDir, "non/existing/folder", tempFileName))) } - -func TestIsInternetStatus(t *testing.T) { - require.True(t, NewInternetStatusEvent(true).isInternetStatus()) - require.True(t, NewInternetStatusEvent(false).isInternetStatus()) - require.False(t, NewKeychainHasNoKeychainEvent().isInternetStatus()) - require.False(t, NewLoginAlreadyLoggedInEvent("").isInternetStatus()) -} - -func TestFilterOutInternetStatusEvents(t *testing.T) { - require.Zero(t, len(filterOutInternetStatusEvents([]*StreamEvent{}))) - - off := NewInternetStatusEvent(false) - on := NewInternetStatusEvent(true) - show := NewShowMainWindowEvent() - finished := NewLoginFinishedEvent("id") - - require.Zero(t, len(filterOutInternetStatusEvents([]*StreamEvent{}))) - require.Zero(t, len(filterOutInternetStatusEvents([]*StreamEvent{off, on, off}))) - require.Equal(t, filterOutInternetStatusEvents([]*StreamEvent{off, show, on}), []*StreamEvent{show}) - require.Equal(t, filterOutInternetStatusEvents([]*StreamEvent{finished, off, show, on}), []*StreamEvent{finished, show}) - require.Equal(t, filterOutInternetStatusEvents([]*StreamEvent{finished, show}), []*StreamEvent{finished, show}) -} diff --git a/internal/frontend/grpc/event_utils.go b/internal/frontend/grpc/event_utils.go deleted file mode 100644 index cbbab163..00000000 --- a/internal/frontend/grpc/event_utils.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2022 Proton AG -// -// This file is part of Proton Mail Bridge. -// -// Proton Mail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// Proton Mail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with Proton Mail Bridge. If not, see . - -package grpc - -import "github.com/bradenaw/juniper/xslices" - -// isInternetStatus returns true iff the event is InternetStatus. -func (x *StreamEvent) isInternetStatus() bool { - appEvent := x.GetApp() - - return (appEvent != nil) && (appEvent.GetInternetStatus() != nil) -} - -// filterOutInternetStatusEvents return a copy of the events list where all internet connection events have been removed. -func filterOutInternetStatusEvents(events []*StreamEvent) []*StreamEvent { - return xslices.Filter(events, func(event *StreamEvent) bool { return !event.isInternetStatus() }) -} diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index 00d7aebc..a03c773c 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -92,6 +92,7 @@ func NewService( panicHandler types.PanicHandler, eventListener listener.Listener, updater types.Updater, + bridge types.Bridger, restarter types.Restarter, locations *locations.Locations, ) *Service { @@ -100,6 +101,7 @@ func NewService( panicHandler: panicHandler, eventListener: eventListener, updater: updater, + bridge: bridge, restarter: restarter, showOnStartup: showOnStartup, @@ -115,16 +117,6 @@ func NewService( // set to 1 s.initializing.Add(1) - go func() { - defer s.panicHandler.HandlePanic() - s.watchEvents() - }() - - return &s -} - -func (s *Service) startGRPCServer() { - s.log.Info("Starting gRPC server") tlsConfig, pemCert, err := s.generateTLSConfig() if err != nil { s.log.WithError(err).Panic("Could not generate gRPC TLS config") @@ -132,13 +124,14 @@ func (s *Service) startGRPCServer() { s.pemCert = string(pemCert) + s.initAutostart() s.grpcServer = grpc.NewServer( grpc.Creds(credentials.NewTLS(tlsConfig)), grpc.UnaryInterceptor(s.validateUnaryServerToken), grpc.StreamInterceptor(s.validateStreamServerToken), ) - RegisterBridgeServer(s.grpcServer, s) + RegisterBridgeServer(s.grpcServer, &s) s.listener, err = net.Listen("tcp", "127.0.0.1:0") // Port 0 means that the port is randomly picked by the system. if err != nil { @@ -151,7 +144,9 @@ func (s *Service) startGRPCServer() { s.log.WithField("path", path).Info("Successfully saved gRPC service config file") } - s.log.Info("gRPC server listening at ", s.listener.Addr()) + s.log.Info("gRPC server listening on ", s.listener.Addr()) + + return &s } func (s *Service) initAutostart() { @@ -171,24 +166,23 @@ func (s *Service) initAutostart() { }) } -func (s *Service) Loop(b types.Bridger) error { - s.bridge = b - s.initAutostart() - s.startGRPCServer() - +func (s *Service) Loop() error { defer func() { s.bridge.SetBool(settings.FirstStartGUIKey, false) }() - if s.bridge.HasError(bridge.ErrLocalCacheUnavailable) { - _ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_UNAVAILABLE_ERROR)) - } + go func() { + defer s.panicHandler.HandlePanic() + s.watchEvents() + }() - err := s.grpcServer.Serve(s.listener) - if err != nil { - s.log.WithError(err).Error("error serving RPC") + s.log.Info("Starting gRPC server") + + if err := s.grpcServer.Serve(s.listener); err != nil { + s.log.WithError(err).Error("Error serving gRPC") return err } + return nil } @@ -219,6 +213,10 @@ func (s *Service) WaitUntilFrontendIsReady() { } func (s *Service) watchEvents() { // nolint:funlen + if s.bridge.HasError(bridge.ErrLocalCacheUnavailable) { + _ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_UNAVAILABLE_ERROR)) + } + errorCh := s.eventListener.ProvideChannel(events.ErrorEvent) credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent) noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent) @@ -270,10 +268,6 @@ func (s *Service) watchEvents() { // nolint:funlen case address := <-addressChangedLogoutCh: _ = s.SendEvent(NewMailAddressChangeLogoutEvent(address)) case userID := <-logoutCh: - if s.bridge == nil { - logrus.Error("Received a logout event but bridge is not yet instantiated.") - break - } user, err := s.bridge.GetUserInfo(userID) if err != nil { return diff --git a/internal/frontend/grpc/service_stream.go b/internal/frontend/grpc/service_stream.go index f32511fd..6ab46b0f 100644 --- a/internal/frontend/grpc/service_stream.go +++ b/internal/frontend/grpc/service_stream.go @@ -87,8 +87,12 @@ func (s *Service) StopEventStream(ctx context.Context, _ *emptypb.Empty) (*empty // SendEvent sends an event to the via the gRPC event stream. func (s *Service) SendEvent(event *StreamEvent) error { - if s.eventStreamCh == nil { // nobody is connected to the event stream, we queue events - s.queueEvent(event) + s.eventQueueMutex.Lock() + defer s.eventQueueMutex.Unlock() + + if s.eventStreamCh == nil { + // nobody is connected to the event stream, we queue events + s.eventQueue = append(s.eventQueue, event) return nil } @@ -163,14 +167,3 @@ func (s *Service) StartEventTest() error { //nolint:funlen return nil } - -func (s *Service) queueEvent(event *StreamEvent) { - s.eventQueueMutex.Lock() - defer s.eventQueueMutex.Unlock() - - if event.isInternetStatus() { - s.eventQueue = append(filterOutInternetStatusEvents(s.eventQueue), event) - } else { - s.eventQueue = append(s.eventQueue, event) - } -}