forked from Silverfish/proton-bridge
GODT-1936: check gRPC server token via interceptors.
This commit is contained in:
@ -44,8 +44,10 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
@ -129,7 +131,11 @@ func (s *Service) startGRPCServer() {
|
||||
|
||||
s.pemCert = string(pemCert)
|
||||
|
||||
s.grpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
s.grpcServer = grpc.NewServer(
|
||||
grpc.Creds(credentials.NewTLS(tlsConfig)),
|
||||
grpc.UnaryInterceptor(s.validateUnaryServerToken),
|
||||
grpc.StreamInterceptor(s.validateStreamServerToken),
|
||||
)
|
||||
|
||||
RegisterBridgeServer(s.grpcServer, s)
|
||||
|
||||
@ -458,3 +464,57 @@ func (s *Service) saveGRPCServerConfigFile() (string, error) {
|
||||
|
||||
return configPath, sc.save(configPath)
|
||||
}
|
||||
|
||||
// validateServerToken verify that the server token provided by the client is valid.
|
||||
func (s *Service) validateServerToken(ctx context.Context) error {
|
||||
values, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "missing server token")
|
||||
}
|
||||
|
||||
token := values.Get(serverTokenMetadataKey)
|
||||
if len(token) == 0 {
|
||||
return status.Error(codes.Unauthenticated, "missing server token")
|
||||
}
|
||||
|
||||
if len(token) > 1 {
|
||||
return status.Error(codes.Unauthenticated, "more than one server token was provided")
|
||||
}
|
||||
|
||||
if token[0] != s.token {
|
||||
return status.Error(codes.Unauthenticated, "invalid server token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUnaryServerToken check the server token for every unary gRPC call.
|
||||
func (s *Service) validateUnaryServerToken(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (resp interface{}, err error) {
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// validateStreamServerToken check the server token for every gRPC stream request.
|
||||
func (s *Service) validateStreamServerToken(
|
||||
_ interface{},
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
_ grpc.StreamHandler,
|
||||
) error {
|
||||
logEntry := s.log.WithField("FullMethod", info.FullMethod)
|
||||
|
||||
if err := s.validateServerToken(ss.Context()); err != nil {
|
||||
logEntry.WithError(err).Error("Stream validator failed")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -33,7 +33,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/ports"
|
||||
"github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
@ -43,10 +42,6 @@ import (
|
||||
func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("CheckTokens")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path := clientConfigPath.Value
|
||||
logEntry := s.log.WithField("path", path)
|
||||
|
||||
@ -65,10 +60,6 @@ func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.
|
||||
func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest) (*emptypb.Empty, error) {
|
||||
entry := s.log
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(request.Package) > 0 {
|
||||
entry = entry.WithField("pkg", request.Package)
|
||||
}
|
||||
@ -97,10 +88,6 @@ func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest)
|
||||
func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("GuiReady")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.initializationDone.Do(s.initializing.Done)
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -109,10 +96,6 @@ func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empt
|
||||
func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("Quit")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Windows is notably slow at Quitting. We do it in a goroutine to speed things up a bit.
|
||||
go func() {
|
||||
var err error
|
||||
@ -133,10 +116,6 @@ func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empt
|
||||
func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("Restart")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.restarter.SetToRestart()
|
||||
return s.Quit(ctx, empty)
|
||||
}
|
||||
@ -144,20 +123,12 @@ func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.E
|
||||
func (s *Service) ShowOnStartup(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("ShowOnStartup")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.showOnStartup), nil
|
||||
}
|
||||
|
||||
func (s *Service) ShowSplashScreen(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("ShowSplashScreen")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.bridge.IsFirstStart() {
|
||||
return wrapperspb.Bool(false), nil
|
||||
}
|
||||
@ -176,20 +147,12 @@ func (s *Service) ShowSplashScreen(ctx context.Context, _ *emptypb.Empty) (*wrap
|
||||
func (s *Service) IsFirstGuiStart(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsFirstGuiStart")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.FirstStartGUIKey)), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("show", isOn.Value).Debug("SetIsAutostartOn")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() { _ = s.SendEvent(NewToggleAutostartFinishedEvent()) }()
|
||||
|
||||
if isOn.Value == s.bridge.IsAutostartEnabled() {
|
||||
@ -216,20 +179,12 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal
|
||||
func (s *Service) IsAutostartOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAutostartOn")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.IsAutostartEnabled()), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsBetaEnabled")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channel := updater.StableChannel
|
||||
if isEnabled.Value {
|
||||
channel = updater.EarlyChannel
|
||||
@ -244,20 +199,12 @@ func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.Bo
|
||||
func (s *Service) IsBetaEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsBetaEnabled")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetUpdateChannel() == updater.EarlyChannel), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isVisible", isVisible.Value).Debug("SetIsAllMailVisible")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.bridge.SetIsAllMailVisible(isVisible.Value)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
@ -266,30 +213,18 @@ func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb
|
||||
func (s *Service) IsAllMailVisible(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAllMailVisible")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.IsAllMailVisible()), nil
|
||||
}
|
||||
|
||||
func (s *Service) GoOs(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("GoOs") // TO-DO We can probably get rid of this and use QSysInfo::product name
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(runtime.GOOS), nil
|
||||
}
|
||||
|
||||
func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("TriggerReset")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
s.triggerReset()
|
||||
@ -300,20 +235,12 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.
|
||||
func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("Version")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(constants.Version), nil
|
||||
}
|
||||
|
||||
func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("LogsPath")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path, err := s.bridge.ProvideLogsPath()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("Cannot determine logs path")
|
||||
@ -325,44 +252,24 @@ func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.S
|
||||
func (s *Service) LicensePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("LicensePath")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(s.bridge.GetLicenseFilePath()), nil
|
||||
}
|
||||
|
||||
func (s *Service) DependencyLicensesLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(s.bridge.GetDependencyLicensesLink()), nil
|
||||
}
|
||||
|
||||
func (s *Service) ReleaseNotesPageLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(s.newVersionInfo.ReleaseNotesPage), nil
|
||||
}
|
||||
|
||||
func (s *Service) LandingPageLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(s.newVersionInfo.LandingPage), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("ColorSchemeName", name.Value).Debug("SetColorSchemeName")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !theme.IsAvailable(theme.Theme(name.Value)) {
|
||||
s.log.WithField("scheme", name.Value).Warn("Color scheme not available")
|
||||
return nil, status.Error(codes.NotFound, "Color scheme not available")
|
||||
@ -376,10 +283,6 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin
|
||||
func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("ColorSchemeName")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
current := s.bridge.Get(settings.ColorScheme)
|
||||
if !theme.IsAvailable(theme.Theme(current)) {
|
||||
current = string(theme.DefaultTheme())
|
||||
@ -392,10 +295,6 @@ func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapp
|
||||
func (s *Service) CurrentEmailClient(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("CurrentEmailClient")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(s.bridge.GetCurrentUserAgent()), nil
|
||||
}
|
||||
|
||||
@ -409,10 +308,6 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
|
||||
"includeLogs": report.IncludeLogs,
|
||||
}).Debug("ReportBug")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }()
|
||||
|
||||
@ -439,10 +334,6 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
|
||||
func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("launcher", launcher.Value).Debug("ForceLauncher")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
s.restarter.ForceLauncher(launcher.Value)
|
||||
@ -453,10 +344,6 @@ func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.String
|
||||
func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("executable", exe.Value).Debug("SetMainExecutable")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
s.restarter.SetMainExecutable(exe.Value)
|
||||
@ -467,10 +354,6 @@ func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringV
|
||||
func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
@ -518,10 +401,6 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt
|
||||
func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login2FA")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
@ -575,10 +454,6 @@ func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.E
|
||||
func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login2Passwords")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
@ -601,10 +476,6 @@ func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*em
|
||||
func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", loginAbort.Username).Debug("LoginAbort")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
@ -617,10 +488,6 @@ func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest)
|
||||
func (s *Service) CheckUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("CheckUpdate")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
@ -632,10 +499,6 @@ func (s *Service) CheckUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.E
|
||||
func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("InstallUpdate")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
@ -648,10 +511,6 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb
|
||||
func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isOn", isOn.Value).Debug("SetIsAutomaticUpdateOn")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentlyOn := s.bridge.GetBool(settings.AutoUpdateKey)
|
||||
if currentlyOn == isOn.Value {
|
||||
return &emptypb.Empty{}, nil
|
||||
@ -670,30 +529,18 @@ func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.B
|
||||
func (s *Service) IsAutomaticUpdateOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAutomaticUpdateOn")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.AutoUpdateKey)), nil
|
||||
}
|
||||
|
||||
func (s *Service) IsCacheOnDiskEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsCacheOnDiskEnabled")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.CacheEnabledKey)), nil
|
||||
}
|
||||
|
||||
func (s *Service) DiskCachePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("DiskCachePath")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(s.bridge.Get(settings.CacheLocationKey)), nil
|
||||
}
|
||||
|
||||
@ -702,10 +549,6 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache
|
||||
WithField("diskCachePath", change.DiskCachePath).
|
||||
Debug("DiskCachePath")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restart := false
|
||||
defer func(willRestart *bool) {
|
||||
_ = s.SendEvent(NewCacheChangeLocalCacheFinishedEvent(*willRestart))
|
||||
@ -758,10 +601,6 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache
|
||||
func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsDohEnabled")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.bridge.SetProxyAllowed(isEnabled.Value)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
@ -770,20 +609,12 @@ func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.Boo
|
||||
func (s *Service) IsDoHEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsDohEnabled")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetProxyAllowed()), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolValue) (*emptypb.Empty, error) { //nolint:revive,stylecheck
|
||||
s.log.WithField("useSsl", useSsl.Value).Debug("SetUseSslForSmtp")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.bridge.GetBool(settings.SMTPSSLKey) == useSsl.Value {
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -798,50 +629,30 @@ func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolV
|
||||
func (s *Service) UseSslForSmtp(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { //nolint:revive,stylecheck
|
||||
s.log.Debug("UseSslForSmtp")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.SMTPSSLKey)), nil
|
||||
}
|
||||
|
||||
func (s *Service) Hostname(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("Hostname")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(bridge.Host), nil
|
||||
}
|
||||
|
||||
func (s *Service) ImapPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) {
|
||||
s.log.Debug("ImapPort")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Int32(int32(s.bridge.GetInt(settings.IMAPPortKey))), nil
|
||||
}
|
||||
|
||||
func (s *Service) SmtpPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) { //nolint:revive,stylecheck
|
||||
s.log.Debug("SmtpPort")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Int32(int32(s.bridge.GetInt(settings.SMTPPortKey))), nil
|
||||
}
|
||||
|
||||
func (s *Service) ChangePorts(ctx context.Context, ports *ChangePortsRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("imapPort", ports.ImapPort).WithField("smtpPort", ports.SmtpPort).Debug("ChangePorts")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.bridge.SetInt(settings.IMAPPortKey, int(ports.ImapPort))
|
||||
s.bridge.SetInt(settings.SMTPPortKey, int(ports.SmtpPort))
|
||||
|
||||
@ -852,9 +663,6 @@ func (s *Service) ChangePorts(ctx context.Context, ports *ChangePortsRequest) (*
|
||||
|
||||
func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsPortFree")
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.Bool(ports.IsPortFree(int(port.Value))), nil
|
||||
}
|
||||
@ -862,10 +670,6 @@ func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) (
|
||||
func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
|
||||
s.log.Debug("AvailableKeychains")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keychains := make([]string, 0, len(keychain.Helpers))
|
||||
for chain := range keychain.Helpers {
|
||||
keychains = append(keychains, chain)
|
||||
@ -877,10 +681,6 @@ func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*Av
|
||||
func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("keychain", keychain.Value).Debug("SetCurrentKeyChain") // we do not check validity.
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }()
|
||||
defer func() { _ = s.SendEvent(NewKeychainChangeKeychainFinishedEvent()) }()
|
||||
|
||||
@ -896,32 +696,5 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S
|
||||
func (s *Service) CurrentKeychain(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("CurrentKeychain")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrapperspb.String(s.bridge.GetKeychainApp()), nil
|
||||
}
|
||||
|
||||
// validateServerToken verify that the server token provided by the client is valid.
|
||||
func (s *Service) validateServerToken(ctx context.Context) error {
|
||||
values, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "missing server token")
|
||||
}
|
||||
|
||||
token := values.Get(serverTokenMetadataKey)
|
||||
if len(token) == 0 {
|
||||
return status.Error(codes.Unauthenticated, "missing server token")
|
||||
}
|
||||
|
||||
if len(token) > 1 {
|
||||
return status.Error(codes.Unauthenticated, "more than one server token was provided")
|
||||
}
|
||||
|
||||
if token[0] != s.token {
|
||||
return status.Error(codes.Unauthenticated, "invalid server token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -29,10 +29,6 @@ import (
|
||||
func (s *Service) RunEventStream(request *EventStreamRequest, server Bridge_RunEventStreamServer) error {
|
||||
s.log.Debug("Starting Event stream")
|
||||
|
||||
if err := s.validateServerToken(server.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.eventStreamCh != nil {
|
||||
return status.Errorf(codes.AlreadyExists, "the service is already streaming") // TO-DO GODT-1667 decide if we want to kill the existing stream.
|
||||
}
|
||||
@ -80,10 +76,6 @@ func (s *Service) RunEventStream(request *EventStreamRequest, server Bridge_RunE
|
||||
|
||||
// StopEventStream stops the event stream.
|
||||
func (s *Service) StopEventStream(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.eventStreamCh == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "The service is not streaming")
|
||||
}
|
||||
|
||||
@ -32,10 +32,6 @@ import (
|
||||
func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListResponse, error) {
|
||||
s.log.Debug("GetUserList")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userIDs := s.bridge.GetUserIDs()
|
||||
userList := make([]*User, len(userIDs))
|
||||
|
||||
@ -59,10 +55,6 @@ func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListR
|
||||
func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) (*User, error) {
|
||||
s.log.WithField("userID", userID).Debug("GetUser")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := s.bridge.GetUserInfo(userID.Value)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found %v", userID.Value)
|
||||
@ -74,10 +66,6 @@ func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) (
|
||||
func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitModeRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("UserID", splitMode.UserID).WithField("Active", splitMode.Active).Debug("SetUserSplitMode")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := s.bridge.GetUserInfo(splitMode.UserID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found %v", splitMode.UserID)
|
||||
@ -106,10 +94,6 @@ func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitMode
|
||||
func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("UserID", userID.Value).Debug("LogoutUser")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := s.bridge.GetUserInfo(userID.Value); err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found %v", userID.Value)
|
||||
}
|
||||
@ -128,10 +112,6 @@ func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue
|
||||
func (s *Service) RemoveUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("UserID", userID.Value).Debug("RemoveUser")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
@ -147,10 +127,6 @@ func (s *Service) RemoveUser(ctx context.Context, userID *wrapperspb.StringValue
|
||||
func (s *Service) ConfigureUserAppleMail(ctx context.Context, request *ConfigureAppleMailRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("UserID", request.UserID).WithField("Address", request.Address).Debug("ConfigureUserAppleMail")
|
||||
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restart, err := s.bridge.ConfigureAppleMail(request.UserID, request.Address)
|
||||
if err != nil {
|
||||
s.log.WithField("userID", request.UserID).Error("Cannot configure AppleMail for user")
|
||||
|
||||
Reference in New Issue
Block a user