Revert "GODT-1932: frontend is instantiated before bridge."

This reverts commit db2379e2fd.
This commit is contained in:
James Houlahan
2022-10-12 22:21:33 +02:00
parent cc9ad17ea5
commit 3b0bc1ca15
7 changed files with 67 additions and 145 deletions

View File

@ -44,9 +44,9 @@ const (
flagLogSMTP = "log-smtp" flagLogSMTP = "log-smtp"
flagNonInteractive = "noninteractive" flagNonInteractive = "noninteractive"
// Memory cache was estimated by empirical usage in the past, and it was set to 100MB. // Memory cache was estimated by empirical usage in past and it was set to 100MB.
// NOTE: This value must not be less than maximal size of one email (~30MB). // NOTE: This value must not be less than maximal size of one email (~30MB).
inMemoryCacheLimit = 100 * (1 << 20) inMemoryCacheLimnit = 100 * (1 << 20)
) )
func New(base *base.Base) *cli.App { func New(base *base.Base) *cli.App {
@ -63,7 +63,7 @@ func New(base *base.Base) *cli.App {
}, },
&cli.BoolFlag{ &cli.BoolFlag{
Name: flagNonInteractive, Name: flagNonInteractive,
Usage: "Start Bridge entirely non-interactively", Usage: "Start Bridge entirely noninteractively",
}, },
}...) }...)
@ -71,17 +71,6 @@ func New(base *base.Base) *cli.App {
} }
func main(b *base.Base, c *cli.Context) error { //nolint:funlen func main(b *base.Base, c *cli.Context) error { //nolint:funlen
frontendType := getFrontendTypeFromCLIParams(c)
f := frontend.New(
frontendType,
!c.Bool(base.FlagNoWindow),
b.CrashHandler,
b.Listener,
b.Updater,
b,
b.Locations,
)
cache, cacheErr := loadMessageCache(b) cache, cacheErr := loadMessageCache(b)
if cacheErr != nil { if cacheErr != nil {
logrus.WithError(cacheErr).Error("Could not load local cache.") logrus.WithError(cacheErr).Error("Could not load local cache.")
@ -152,10 +141,28 @@ func main(b *base.Base, c *cli.Context) error { //nolint:funlen
// We want cookies to be saved to disk so they are loaded the next time. // We want cookies to be saved to disk so they are loaded the next time.
b.AddTeardownAction(b.CookieJar.PersistCookies) b.AddTeardownAction(b.CookieJar.PersistCookies)
if frontendType == frontend.NonInteractive { var frontendMode string
return <-(make(chan error))
switch {
case c.Bool(base.FlagCLI):
frontendMode = "cli"
case c.Bool(flagNonInteractive):
return <-(make(chan error)) // Block forever.
default:
frontendMode = "grpc"
} }
f := frontend.New(
frontendMode,
!c.Bool(base.FlagNoWindow),
b.CrashHandler,
b.Listener,
b.Updater,
bridge,
b,
b.Locations,
)
// Watch for updates routine // Watch for updates routine
go func() { go func() {
ticker := time.NewTicker(constants.UpdateCheckInterval) ticker := time.NewTicker(constants.UpdateCheckInterval)
@ -166,18 +173,7 @@ func main(b *base.Base, c *cli.Context) error { //nolint:funlen
} }
}() }()
return f.Loop(bridge) return f.Loop()
}
func getFrontendTypeFromCLIParams(c *cli.Context) frontend.Type {
switch {
case c.Bool(base.FlagCLI):
return frontend.CLI
case c.Bool(flagNonInteractive):
return frontend.NonInteractive
default:
return frontend.GRPC
}
} }
func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) { func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) {
@ -230,7 +226,7 @@ func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool)
// local cache is enabled but unavailable (in-memory cache will be returned nevertheless). // local cache is enabled but unavailable (in-memory cache will be returned nevertheless).
func loadMessageCache(b *base.Base) (cache.Cache, error) { func loadMessageCache(b *base.Base) (cache.Cache, error) {
if !b.Settings.GetBool(settings.CacheEnabledKey) { if !b.Settings.GetBool(settings.CacheEnabledKey) {
return cache.NewInMemoryCache(inMemoryCacheLimit), nil return cache.NewInMemoryCache(inMemoryCacheLimnit), nil
} }
var compressor cache.Compressor var compressor cache.Compressor
@ -250,12 +246,12 @@ func loadMessageCache(b *base.Base) (cache.Cache, error) {
path = customPath path = customPath
} else { } else {
path = b.Cache.GetDefaultMessageCacheDir() path = b.Cache.GetDefaultMessageCacheDir()
// Store path so it will always persist if default location // Store path so it will allways persist if default location
// will be changed in new version. // will be changed in new version.
b.Settings.Set(settings.CacheLocationKey, path) b.Settings.Set(settings.CacheLocationKey, path)
} }
// To prevent memory peaks we set maximal write concurrency for store // To prevent memory peaks we set maximal write concurency for store
// build jobs. // build jobs.
store.SetBuildAndCacheJobLimit(b.Settings.GetInt(settings.CacheConcurrencyWrite)) store.SetBuildAndCacheJobLimit(b.Settings.GetInt(settings.CacheConcurrencyWrite))
@ -266,7 +262,7 @@ func loadMessageCache(b *base.Base) (cache.Cache, error) {
ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite), ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite),
}) })
if err != nil { if err != nil {
return cache.NewInMemoryCache(inMemoryCacheLimit), err return cache.NewInMemoryCache(inMemoryCacheLimnit), err
} }
return messageCache, nil return messageCache, nil

View File

@ -47,6 +47,7 @@ func New( //nolint:funlen
eventListener listener.Listener, eventListener listener.Listener,
updater types.Updater, updater types.Updater,
bridge types.Bridger,
restarter types.Restarter, restarter types.Restarter,
) *frontendCLI { //nolint:revive ) *frontendCLI { //nolint:revive
fe := &frontendCLI{ fe := &frontendCLI{
@ -54,6 +55,7 @@ func New( //nolint:funlen
eventListener: eventListener, eventListener: eventListener,
updater: updater, updater: updater,
bridge: bridge,
restarter: restarter, restarter: restarter,
} }
@ -317,8 +319,7 @@ func (f *frontendCLI) watchEvents() {
} }
// Loop starts the frontend loop with an interactive shell. // Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop(b types.Bridger) error { func (f *frontendCLI) Loop() error {
f.bridge = b
f.Printf(` f.Printf(`
Welcome to %s interactive shell Welcome to %s interactive shell
___....___ ___....___

View File

@ -19,6 +19,7 @@
package frontend package frontend
import ( import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types" "github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
@ -27,17 +28,8 @@ import (
"github.com/ProtonMail/proton-bridge/v2/pkg/listener" "github.com/ProtonMail/proton-bridge/v2/pkg/listener"
) )
// Type describes the available types of frontend.
type Type int
const (
CLI Type = iota
GRPC
NonInteractive
)
type Frontend interface { type Frontend interface {
Loop(b types.Bridger) error Loop() error
NotifyManualUpdate(update updater.VersionInfo, canInstall bool) NotifyManualUpdate(update updater.VersionInfo, canInstall bool)
SetVersion(update updater.VersionInfo) SetVersion(update updater.VersionInfo)
NotifySilentUpdateInstalled() NotifySilentUpdateInstalled()
@ -45,38 +37,38 @@ type Frontend interface {
WaitUntilFrontendIsReady() WaitUntilFrontendIsReady()
} }
// New returns initialized frontend based on `frontendType`, which can be `CLI` or `GRPC`. // New returns initialized frontend based on `frontendType`, which can be `cli` or `grpc`.
func New( func New(
frontendType Type, frontendType string,
showWindowOnStart bool, showWindowOnStart bool,
panicHandler types.PanicHandler, panicHandler types.PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
updater types.Updater, updater types.Updater,
bridge *bridge.Bridge,
restarter types.Restarter, restarter types.Restarter,
locations *locations.Locations, locations *locations.Locations,
) Frontend { ) Frontend {
switch frontendType { switch frontendType {
case GRPC: case "grpc":
return grpc.NewService( return grpc.NewService(
showWindowOnStart, showWindowOnStart,
panicHandler, panicHandler,
eventListener, eventListener,
updater, updater,
bridge,
restarter, restarter,
locations, locations,
) )
case CLI: case "cli":
return cli.New( return cli.New(
panicHandler, panicHandler,
eventListener, eventListener,
updater, updater,
bridge,
restarter, restarter,
) )
case NonInteractive:
fallthrough
default: default:
return nil return nil
} }

View File

@ -53,25 +53,3 @@ func TestConfig(t *testing.T) {
// failure to save // failure to save
require.Error(t, conf2.save(filepath.Join(tempDir, "non/existing/folder", tempFileName))) require.Error(t, conf2.save(filepath.Join(tempDir, "non/existing/folder", tempFileName)))
} }
func TestIsInternetStatus(t *testing.T) {
require.True(t, NewInternetStatusEvent(true).isInternetStatus())
require.True(t, NewInternetStatusEvent(false).isInternetStatus())
require.False(t, NewKeychainHasNoKeychainEvent().isInternetStatus())
require.False(t, NewLoginAlreadyLoggedInEvent("").isInternetStatus())
}
func TestFilterOutInternetStatusEvents(t *testing.T) {
require.Zero(t, len(filterOutInternetStatusEvents([]*StreamEvent{})))
off := NewInternetStatusEvent(false)
on := NewInternetStatusEvent(true)
show := NewShowMainWindowEvent()
finished := NewLoginFinishedEvent("id")
require.Zero(t, len(filterOutInternetStatusEvents([]*StreamEvent{})))
require.Zero(t, len(filterOutInternetStatusEvents([]*StreamEvent{off, on, off})))
require.Equal(t, filterOutInternetStatusEvents([]*StreamEvent{off, show, on}), []*StreamEvent{show})
require.Equal(t, filterOutInternetStatusEvents([]*StreamEvent{finished, off, show, on}), []*StreamEvent{finished, show})
require.Equal(t, filterOutInternetStatusEvents([]*StreamEvent{finished, show}), []*StreamEvent{finished, show})
}

View File

@ -1,32 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package grpc
import "github.com/bradenaw/juniper/xslices"
// isInternetStatus returns true iff the event is InternetStatus.
func (x *StreamEvent) isInternetStatus() bool {
appEvent := x.GetApp()
return (appEvent != nil) && (appEvent.GetInternetStatus() != nil)
}
// filterOutInternetStatusEvents return a copy of the events list where all internet connection events have been removed.
func filterOutInternetStatusEvents(events []*StreamEvent) []*StreamEvent {
return xslices.Filter(events, func(event *StreamEvent) bool { return !event.isInternetStatus() })
}

View File

@ -92,6 +92,7 @@ func NewService(
panicHandler types.PanicHandler, panicHandler types.PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
updater types.Updater, updater types.Updater,
bridge types.Bridger,
restarter types.Restarter, restarter types.Restarter,
locations *locations.Locations, locations *locations.Locations,
) *Service { ) *Service {
@ -100,6 +101,7 @@ func NewService(
panicHandler: panicHandler, panicHandler: panicHandler,
eventListener: eventListener, eventListener: eventListener,
updater: updater, updater: updater,
bridge: bridge,
restarter: restarter, restarter: restarter,
showOnStartup: showOnStartup, showOnStartup: showOnStartup,
@ -115,16 +117,6 @@ func NewService(
// set to 1 // set to 1
s.initializing.Add(1) s.initializing.Add(1)
go func() {
defer s.panicHandler.HandlePanic()
s.watchEvents()
}()
return &s
}
func (s *Service) startGRPCServer() {
s.log.Info("Starting gRPC server")
tlsConfig, pemCert, err := s.generateTLSConfig() tlsConfig, pemCert, err := s.generateTLSConfig()
if err != nil { if err != nil {
s.log.WithError(err).Panic("Could not generate gRPC TLS config") s.log.WithError(err).Panic("Could not generate gRPC TLS config")
@ -132,13 +124,14 @@ func (s *Service) startGRPCServer() {
s.pemCert = string(pemCert) s.pemCert = string(pemCert)
s.initAutostart()
s.grpcServer = grpc.NewServer( s.grpcServer = grpc.NewServer(
grpc.Creds(credentials.NewTLS(tlsConfig)), grpc.Creds(credentials.NewTLS(tlsConfig)),
grpc.UnaryInterceptor(s.validateUnaryServerToken), grpc.UnaryInterceptor(s.validateUnaryServerToken),
grpc.StreamInterceptor(s.validateStreamServerToken), grpc.StreamInterceptor(s.validateStreamServerToken),
) )
RegisterBridgeServer(s.grpcServer, s) RegisterBridgeServer(s.grpcServer, &s)
s.listener, err = net.Listen("tcp", "127.0.0.1:0") // Port 0 means that the port is randomly picked by the system. s.listener, err = net.Listen("tcp", "127.0.0.1:0") // Port 0 means that the port is randomly picked by the system.
if err != nil { if err != nil {
@ -151,7 +144,9 @@ func (s *Service) startGRPCServer() {
s.log.WithField("path", path).Info("Successfully saved gRPC service config file") s.log.WithField("path", path).Info("Successfully saved gRPC service config file")
} }
s.log.Info("gRPC server listening at ", s.listener.Addr()) s.log.Info("gRPC server listening on ", s.listener.Addr())
return &s
} }
func (s *Service) initAutostart() { func (s *Service) initAutostart() {
@ -171,24 +166,23 @@ func (s *Service) initAutostart() {
}) })
} }
func (s *Service) Loop(b types.Bridger) error { func (s *Service) Loop() error {
s.bridge = b
s.initAutostart()
s.startGRPCServer()
defer func() { defer func() {
s.bridge.SetBool(settings.FirstStartGUIKey, false) s.bridge.SetBool(settings.FirstStartGUIKey, false)
}() }()
if s.bridge.HasError(bridge.ErrLocalCacheUnavailable) { go func() {
_ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_UNAVAILABLE_ERROR)) defer s.panicHandler.HandlePanic()
} s.watchEvents()
}()
err := s.grpcServer.Serve(s.listener) s.log.Info("Starting gRPC server")
if err != nil {
s.log.WithError(err).Error("error serving RPC") if err := s.grpcServer.Serve(s.listener); err != nil {
s.log.WithError(err).Error("Error serving gRPC")
return err return err
} }
return nil return nil
} }
@ -219,6 +213,10 @@ func (s *Service) WaitUntilFrontendIsReady() {
} }
func (s *Service) watchEvents() { // nolint:funlen func (s *Service) watchEvents() { // nolint:funlen
if s.bridge.HasError(bridge.ErrLocalCacheUnavailable) {
_ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_UNAVAILABLE_ERROR))
}
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent) errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent) credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent) noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
@ -270,10 +268,6 @@ func (s *Service) watchEvents() { // nolint:funlen
case address := <-addressChangedLogoutCh: case address := <-addressChangedLogoutCh:
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(address)) _ = s.SendEvent(NewMailAddressChangeLogoutEvent(address))
case userID := <-logoutCh: case userID := <-logoutCh:
if s.bridge == nil {
logrus.Error("Received a logout event but bridge is not yet instantiated.")
break
}
user, err := s.bridge.GetUserInfo(userID) user, err := s.bridge.GetUserInfo(userID)
if err != nil { if err != nil {
return return

View File

@ -87,8 +87,12 @@ func (s *Service) StopEventStream(ctx context.Context, _ *emptypb.Empty) (*empty
// SendEvent sends an event to the via the gRPC event stream. // SendEvent sends an event to the via the gRPC event stream.
func (s *Service) SendEvent(event *StreamEvent) error { func (s *Service) SendEvent(event *StreamEvent) error {
if s.eventStreamCh == nil { // nobody is connected to the event stream, we queue events s.eventQueueMutex.Lock()
s.queueEvent(event) defer s.eventQueueMutex.Unlock()
if s.eventStreamCh == nil {
// nobody is connected to the event stream, we queue events
s.eventQueue = append(s.eventQueue, event)
return nil return nil
} }
@ -163,14 +167,3 @@ func (s *Service) StartEventTest() error { //nolint:funlen
return nil return nil
} }
func (s *Service) queueEvent(event *StreamEvent) {
s.eventQueueMutex.Lock()
defer s.eventQueueMutex.Unlock()
if event.isInternetStatus() {
s.eventQueue = append(filterOutInternetStatusEvents(s.eventQueue), event)
} else {
s.eventQueue = append(s.eventQueue, event)
}
}