diff --git a/internal/frontend/grpc/service.go b/internal/frontend/grpc/service.go index 4318ab95..0312f18e 100644 --- a/internal/frontend/grpc/service.go +++ b/internal/frontend/grpc/service.go @@ -44,8 +44,10 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" ) @@ -129,7 +131,11 @@ func (s *Service) startGRPCServer() { s.pemCert = string(pemCert) - s.grpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + s.grpcServer = grpc.NewServer( + grpc.Creds(credentials.NewTLS(tlsConfig)), + grpc.UnaryInterceptor(s.validateUnaryServerToken), + grpc.StreamInterceptor(s.validateStreamServerToken), + ) RegisterBridgeServer(s.grpcServer, s) @@ -458,3 +464,57 @@ func (s *Service) saveGRPCServerConfigFile() (string, error) { return configPath, sc.save(configPath) } + +// validateServerToken verify that the server token provided by the client is valid. +func (s *Service) validateServerToken(ctx context.Context) error { + values, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.Unauthenticated, "missing server token") + } + + token := values.Get(serverTokenMetadataKey) + if len(token) == 0 { + return status.Error(codes.Unauthenticated, "missing server token") + } + + if len(token) > 1 { + return status.Error(codes.Unauthenticated, "more than one server token was provided") + } + + if token[0] != s.token { + return status.Error(codes.Unauthenticated, "invalid server token") + } + + return nil +} + +// validateUnaryServerToken check the server token for every unary gRPC call. +func (s *Service) validateUnaryServerToken( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (resp interface{}, err error) { + if err := s.validateServerToken(ctx); err != nil { + return nil, err + } + + return handler(ctx, req) +} + +// validateStreamServerToken check the server token for every gRPC stream request. +func (s *Service) validateStreamServerToken( + _ interface{}, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + _ grpc.StreamHandler, +) error { + logEntry := s.log.WithField("FullMethod", info.FullMethod) + + if err := s.validateServerToken(ss.Context()); err != nil { + logEntry.WithError(err).Error("Stream validator failed") + return err + } + + return nil +} diff --git a/internal/frontend/grpc/service_methods.go b/internal/frontend/grpc/service_methods.go index 287aa612..fde708a1 100644 --- a/internal/frontend/grpc/service_methods.go +++ b/internal/frontend/grpc/service_methods.go @@ -33,7 +33,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/pkg/ports" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/wrapperspb" @@ -43,10 +42,6 @@ import ( func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { s.log.Debug("CheckTokens") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - path := clientConfigPath.Value logEntry := s.log.WithField("path", path) @@ -65,10 +60,6 @@ func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb. func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest) (*emptypb.Empty, error) { entry := s.log - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - if len(request.Package) > 0 { entry = entry.WithField("pkg", request.Package) } @@ -97,10 +88,6 @@ func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest) func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { s.log.Debug("GuiReady") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - s.initializationDone.Do(s.initializing.Done) return &emptypb.Empty{}, nil } @@ -109,10 +96,6 @@ func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empt func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) { s.log.Debug("Quit") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - // Windows is notably slow at Quitting. We do it in a goroutine to speed things up a bit. go func() { var err error @@ -133,10 +116,6 @@ func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empt func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) { s.log.Debug("Restart") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - s.restarter.SetToRestart() return s.Quit(ctx, empty) } @@ -144,20 +123,12 @@ func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.E func (s *Service) ShowOnStartup(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("ShowOnStartup") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.showOnStartup), nil } func (s *Service) ShowSplashScreen(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("ShowSplashScreen") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - if s.bridge.IsFirstStart() { return wrapperspb.Bool(false), nil } @@ -176,20 +147,12 @@ func (s *Service) ShowSplashScreen(ctx context.Context, _ *emptypb.Empty) (*wrap func (s *Service) IsFirstGuiStart(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("IsFirstGuiStart") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.GetBool(settings.FirstStartGUIKey)), nil } func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) { s.log.WithField("show", isOn.Value).Debug("SetIsAutostartOn") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - defer func() { _ = s.SendEvent(NewToggleAutostartFinishedEvent()) }() if isOn.Value == s.bridge.IsAutostartEnabled() { @@ -216,20 +179,12 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal func (s *Service) IsAutostartOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("IsAutostartOn") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.IsAutostartEnabled()), nil } func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) { s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsBetaEnabled") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - channel := updater.StableChannel if isEnabled.Value { channel = updater.EarlyChannel @@ -244,20 +199,12 @@ func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.Bo func (s *Service) IsBetaEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("IsBetaEnabled") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.GetUpdateChannel() == updater.EarlyChannel), nil } func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) { s.log.WithField("isVisible", isVisible.Value).Debug("SetIsAllMailVisible") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - s.bridge.SetIsAllMailVisible(isVisible.Value) return &emptypb.Empty{}, nil @@ -266,30 +213,18 @@ func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb func (s *Service) IsAllMailVisible(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("IsAllMailVisible") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.IsAllMailVisible()), nil } func (s *Service) GoOs(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("GoOs") // TO-DO We can probably get rid of this and use QSysInfo::product name - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(runtime.GOOS), nil } func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { s.log.Debug("TriggerReset") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() s.triggerReset() @@ -300,20 +235,12 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb. func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("Version") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(constants.Version), nil } func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("LogsPath") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - path, err := s.bridge.ProvideLogsPath() if err != nil { s.log.WithError(err).Error("Cannot determine logs path") @@ -325,44 +252,24 @@ func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.S func (s *Service) LicensePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("LicensePath") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(s.bridge.GetLicenseFilePath()), nil } func (s *Service) DependencyLicensesLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(s.bridge.GetDependencyLicensesLink()), nil } func (s *Service) ReleaseNotesPageLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(s.newVersionInfo.ReleaseNotesPage), nil } func (s *Service) LandingPageLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(s.newVersionInfo.LandingPage), nil } func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("ColorSchemeName", name.Value).Debug("SetColorSchemeName") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - if !theme.IsAvailable(theme.Theme(name.Value)) { s.log.WithField("scheme", name.Value).Warn("Color scheme not available") return nil, status.Error(codes.NotFound, "Color scheme not available") @@ -376,10 +283,6 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("ColorSchemeName") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - current := s.bridge.Get(settings.ColorScheme) if !theme.IsAvailable(theme.Theme(current)) { current = string(theme.DefaultTheme()) @@ -392,10 +295,6 @@ func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapp func (s *Service) CurrentEmailClient(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("CurrentEmailClient") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(s.bridge.GetCurrentUserAgent()), nil } @@ -409,10 +308,6 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp "includeLogs": report.IncludeLogs, }).Debug("ReportBug") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }() @@ -439,10 +334,6 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("launcher", launcher.Value).Debug("ForceLauncher") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() s.restarter.ForceLauncher(launcher.Value) @@ -453,10 +344,6 @@ func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.String func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("executable", exe.Value).Debug("SetMainExecutable") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() s.restarter.SetMainExecutable(exe.Value) @@ -467,10 +354,6 @@ func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringV func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) { s.log.WithField("username", login.Username).Debug("Login") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() @@ -518,10 +401,6 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) { s.log.WithField("username", login.Username).Debug("Login2FA") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() @@ -575,10 +454,6 @@ func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.E func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) { s.log.WithField("username", login.Username).Debug("Login2Passwords") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() @@ -601,10 +476,6 @@ func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*em func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) { s.log.WithField("username", loginAbort.Username).Debug("LoginAbort") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() @@ -617,10 +488,6 @@ func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) func (s *Service) CheckUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { s.log.Debug("CheckUpdate") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() @@ -632,10 +499,6 @@ func (s *Service) CheckUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.E func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { s.log.Debug("InstallUpdate") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() @@ -648,10 +511,6 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) { s.log.WithField("isOn", isOn.Value).Debug("SetIsAutomaticUpdateOn") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - currentlyOn := s.bridge.GetBool(settings.AutoUpdateKey) if currentlyOn == isOn.Value { return &emptypb.Empty{}, nil @@ -670,30 +529,18 @@ func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.B func (s *Service) IsAutomaticUpdateOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("IsAutomaticUpdateOn") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.GetBool(settings.AutoUpdateKey)), nil } func (s *Service) IsCacheOnDiskEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("IsCacheOnDiskEnabled") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.GetBool(settings.CacheEnabledKey)), nil } func (s *Service) DiskCachePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("DiskCachePath") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(s.bridge.Get(settings.CacheLocationKey)), nil } @@ -702,10 +549,6 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache WithField("diskCachePath", change.DiskCachePath). Debug("DiskCachePath") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - restart := false defer func(willRestart *bool) { _ = s.SendEvent(NewCacheChangeLocalCacheFinishedEvent(*willRestart)) @@ -758,10 +601,6 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) { s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsDohEnabled") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - s.bridge.SetProxyAllowed(isEnabled.Value) return &emptypb.Empty{}, nil @@ -770,20 +609,12 @@ func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.Boo func (s *Service) IsDoHEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { s.log.Debug("IsDohEnabled") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.GetProxyAllowed()), nil } func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolValue) (*emptypb.Empty, error) { //nolint:revive,stylecheck s.log.WithField("useSsl", useSsl.Value).Debug("SetUseSslForSmtp") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - if s.bridge.GetBool(settings.SMTPSSLKey) == useSsl.Value { return &emptypb.Empty{}, nil } @@ -798,50 +629,30 @@ func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolV func (s *Service) UseSslForSmtp(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { //nolint:revive,stylecheck s.log.Debug("UseSslForSmtp") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Bool(s.bridge.GetBool(settings.SMTPSSLKey)), nil } func (s *Service) Hostname(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("Hostname") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(bridge.Host), nil } func (s *Service) ImapPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) { s.log.Debug("ImapPort") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Int32(int32(s.bridge.GetInt(settings.IMAPPortKey))), nil } func (s *Service) SmtpPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) { //nolint:revive,stylecheck s.log.Debug("SmtpPort") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.Int32(int32(s.bridge.GetInt(settings.SMTPPortKey))), nil } func (s *Service) ChangePorts(ctx context.Context, ports *ChangePortsRequest) (*emptypb.Empty, error) { s.log.WithField("imapPort", ports.ImapPort).WithField("smtpPort", ports.SmtpPort).Debug("ChangePorts") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - s.bridge.SetInt(settings.IMAPPortKey, int(ports.ImapPort)) s.bridge.SetInt(settings.SMTPPortKey, int(ports.SmtpPort)) @@ -852,9 +663,6 @@ func (s *Service) ChangePorts(ctx context.Context, ports *ChangePortsRequest) (* func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) (*wrapperspb.BoolValue, error) { s.log.Debug("IsPortFree") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } return wrapperspb.Bool(ports.IsPortFree(int(port.Value))), nil } @@ -862,10 +670,6 @@ func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) ( func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) { s.log.Debug("AvailableKeychains") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - keychains := make([]string, 0, len(keychain.Helpers)) for chain := range keychain.Helpers { keychains = append(keychains, chain) @@ -877,10 +681,6 @@ func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*Av func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("keychain", keychain.Value).Debug("SetCurrentKeyChain") // we do not check validity. - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }() defer func() { _ = s.SendEvent(NewKeychainChangeKeychainFinishedEvent()) }() @@ -896,32 +696,5 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S func (s *Service) CurrentKeychain(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) { s.log.Debug("CurrentKeychain") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - return wrapperspb.String(s.bridge.GetKeychainApp()), nil } - -// validateServerToken verify that the server token provided by the client is valid. -func (s *Service) validateServerToken(ctx context.Context) error { - values, ok := metadata.FromIncomingContext(ctx) - if !ok { - return status.Error(codes.Unauthenticated, "missing server token") - } - - token := values.Get(serverTokenMetadataKey) - if len(token) == 0 { - return status.Error(codes.Unauthenticated, "missing server token") - } - - if len(token) > 1 { - return status.Error(codes.Unauthenticated, "more than one server token was provided") - } - - if token[0] != s.token { - return status.Error(codes.Unauthenticated, "invalid server token") - } - - return nil -} diff --git a/internal/frontend/grpc/service_stream.go b/internal/frontend/grpc/service_stream.go index f671a92d..f32511fd 100644 --- a/internal/frontend/grpc/service_stream.go +++ b/internal/frontend/grpc/service_stream.go @@ -29,10 +29,6 @@ import ( func (s *Service) RunEventStream(request *EventStreamRequest, server Bridge_RunEventStreamServer) error { s.log.Debug("Starting Event stream") - if err := s.validateServerToken(server.Context()); err != nil { - return err - } - if s.eventStreamCh != nil { return status.Errorf(codes.AlreadyExists, "the service is already streaming") // TO-DO GODT-1667 decide if we want to kill the existing stream. } @@ -80,10 +76,6 @@ func (s *Service) RunEventStream(request *EventStreamRequest, server Bridge_RunE // StopEventStream stops the event stream. func (s *Service) StopEventStream(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - if s.eventStreamCh == nil { return nil, status.Errorf(codes.NotFound, "The service is not streaming") } diff --git a/internal/frontend/grpc/service_user.go b/internal/frontend/grpc/service_user.go index df47783c..9ccd1f61 100644 --- a/internal/frontend/grpc/service_user.go +++ b/internal/frontend/grpc/service_user.go @@ -32,10 +32,6 @@ import ( func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListResponse, error) { s.log.Debug("GetUserList") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - userIDs := s.bridge.GetUserIDs() userList := make([]*User, len(userIDs)) @@ -59,10 +55,6 @@ func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListR func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) (*User, error) { s.log.WithField("userID", userID).Debug("GetUser") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - user, err := s.bridge.GetUserInfo(userID.Value) if err != nil { return nil, status.Errorf(codes.NotFound, "user not found %v", userID.Value) @@ -74,10 +66,6 @@ func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) ( func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitModeRequest) (*emptypb.Empty, error) { s.log.WithField("UserID", splitMode.UserID).WithField("Active", splitMode.Active).Debug("SetUserSplitMode") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - user, err := s.bridge.GetUserInfo(splitMode.UserID) if err != nil { return nil, status.Errorf(codes.NotFound, "user not found %v", splitMode.UserID) @@ -106,10 +94,6 @@ func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitMode func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("UserID", userID.Value).Debug("LogoutUser") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - if _, err := s.bridge.GetUserInfo(userID.Value); err != nil { return nil, status.Errorf(codes.NotFound, "user not found %v", userID.Value) } @@ -128,10 +112,6 @@ func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue func (s *Service) RemoveUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) { s.log.WithField("UserID", userID.Value).Debug("RemoveUser") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - go func() { defer s.panicHandler.HandlePanic() @@ -147,10 +127,6 @@ func (s *Service) RemoveUser(ctx context.Context, userID *wrapperspb.StringValue func (s *Service) ConfigureUserAppleMail(ctx context.Context, request *ConfigureAppleMailRequest) (*emptypb.Empty, error) { s.log.WithField("UserID", request.UserID).WithField("Address", request.Address).Debug("ConfigureUserAppleMail") - if err := s.validateServerToken(ctx); err != nil { - return nil, err - } - restart, err := s.bridge.ConfigureAppleMail(request.UserID, request.Address) if err != nil { s.log.WithField("userID", request.UserID).Error("Cannot configure AppleMail for user")