mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-11 13:16:53 +00:00
GODT-1779: Remove go-imap
This commit is contained in:
@ -21,34 +21,29 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptotls "crypto/tls"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/users"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/restarter"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -59,6 +54,7 @@ const (
|
||||
// Service is the RPC service struct.
|
||||
type Service struct { // nolint:structcheck
|
||||
UnimplementedBridgeServer
|
||||
|
||||
grpcServer *grpc.Server // the gGRPC server
|
||||
listener net.Listener
|
||||
eventStreamCh chan *StreamEvent
|
||||
@ -66,99 +62,87 @@ type Service struct { // nolint:structcheck
|
||||
eventQueue []*StreamEvent
|
||||
eventQueueMutex sync.Mutex
|
||||
|
||||
panicHandler types.PanicHandler
|
||||
eventListener listener.Listener
|
||||
updater types.Updater
|
||||
updateCheckMutex sync.Mutex
|
||||
bridge types.Bridger
|
||||
restarter types.Restarter
|
||||
showOnStartup bool
|
||||
authClient pmapi.Client
|
||||
auth *pmapi.Auth
|
||||
password []byte
|
||||
newVersionInfo updater.VersionInfo
|
||||
panicHandler *crash.Handler
|
||||
restarter *restarter.Restarter
|
||||
bridge *bridge.Bridge
|
||||
newVersionInfo updater.VersionInfo
|
||||
|
||||
log *logrus.Entry
|
||||
initializing sync.WaitGroup
|
||||
initializationDone sync.Once
|
||||
firstTimeAutostart sync.Once
|
||||
locations *locations.Locations
|
||||
token string
|
||||
pemCert string
|
||||
|
||||
showOnStartup bool
|
||||
}
|
||||
|
||||
// NewService returns a new instance of the service.
|
||||
func NewService(
|
||||
showOnStartup bool,
|
||||
panicHandler types.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
updater types.Updater,
|
||||
bridge types.Bridger,
|
||||
restarter types.Restarter,
|
||||
panicHandler *crash.Handler,
|
||||
restarter *restarter.Restarter,
|
||||
locations *locations.Locations,
|
||||
) *Service {
|
||||
s := Service{
|
||||
UnimplementedBridgeServer: UnimplementedBridgeServer{},
|
||||
panicHandler: panicHandler,
|
||||
eventListener: eventListener,
|
||||
updater: updater,
|
||||
bridge: bridge,
|
||||
restarter: restarter,
|
||||
showOnStartup: showOnStartup,
|
||||
bridge *bridge.Bridge,
|
||||
showOnStartup bool,
|
||||
) (*Service, error) {
|
||||
tlsConfig, certPEM, err := newTLSConfig()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("Could not generate gRPC TLS config")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0") // Port should be provided by the OS.
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("Could not create gRPC listener")
|
||||
}
|
||||
|
||||
token := uuid.NewString()
|
||||
|
||||
if path, err := saveGRPCServerConfigFile(locations, listener, token, certPEM); err != nil {
|
||||
logrus.WithError(err).WithField("path", path).Panic("Could not write gRPC service config file")
|
||||
} else {
|
||||
logrus.WithField("path", path).Info("Successfully saved gRPC service config file")
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
grpcServer: grpc.NewServer(
|
||||
grpc.Creds(credentials.NewTLS(tlsConfig)),
|
||||
grpc.UnaryInterceptor(newUnaryTokenValidator(token)),
|
||||
grpc.StreamInterceptor(newStreamTokenValidator(token)),
|
||||
),
|
||||
listener: listener,
|
||||
|
||||
panicHandler: panicHandler,
|
||||
restarter: restarter,
|
||||
bridge: bridge,
|
||||
|
||||
log: logrus.WithField("pkg", "grpc"),
|
||||
initializing: sync.WaitGroup{},
|
||||
initializationDone: sync.Once{},
|
||||
firstTimeAutostart: sync.Once{},
|
||||
locations: locations,
|
||||
token: uuid.NewString(),
|
||||
|
||||
showOnStartup: showOnStartup,
|
||||
}
|
||||
|
||||
// Initializing.Done is only called sync.Once. Please keep the increment
|
||||
// set to 1
|
||||
// Initializing.Done is only called sync.Once. Please keep the increment set to 1
|
||||
s.initializing.Add(1)
|
||||
|
||||
tlsConfig, pemCert, err := s.generateTLSConfig()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Panic("Could not generate gRPC TLS config")
|
||||
}
|
||||
|
||||
s.pemCert = string(pemCert)
|
||||
|
||||
// Initialize the autostart.
|
||||
s.initAutostart()
|
||||
s.grpcServer = grpc.NewServer(
|
||||
grpc.Creds(credentials.NewTLS(tlsConfig)),
|
||||
grpc.UnaryInterceptor(s.validateUnaryServerToken),
|
||||
grpc.StreamInterceptor(s.validateStreamServerToken),
|
||||
)
|
||||
|
||||
RegisterBridgeServer(s.grpcServer, &s)
|
||||
|
||||
s.listener, err = net.Listen("tcp", "127.0.0.1:0") // Port 0 means that the port is randomly picked by the system.
|
||||
if err != nil {
|
||||
s.log.WithError(err).Panic("Could not create gRPC listener")
|
||||
}
|
||||
|
||||
if path, err := s.saveGRPCServerConfigFile(); err != nil {
|
||||
s.log.WithError(err).WithField("path", path).Panic("Could not write gRPC service config file")
|
||||
} else {
|
||||
s.log.WithField("path", path).Info("Successfully saved gRPC service config file")
|
||||
}
|
||||
// Register the gRPC service implementation.
|
||||
RegisterBridgeServer(s.grpcServer, s)
|
||||
|
||||
s.log.Info("gRPC server listening on ", s.listener.Addr())
|
||||
|
||||
return &s
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GODT-1507 Windows: autostart needs to be created after Qt is initialized.
|
||||
// GODT-1206: if preferences file says it should be on enable it here.
|
||||
// TO-DO GODT-1681 Autostart needs to be properly implement for gRPC approach.
|
||||
func (s *Service) initAutostart() {
|
||||
// GODT-1507 Windows: autostart needs to be created after Qt is initialized.
|
||||
// GODT-1206: if preferences file says it should be on enable it here.
|
||||
|
||||
// TO-DO GODT-1681 Autostart needs to be properly implement for gRPC approach.
|
||||
|
||||
s.firstTimeAutostart.Do(func() {
|
||||
shouldAutostartBeOn := s.bridge.GetBool(settings.AutostartKey)
|
||||
if s.bridge.IsFirstStart() || shouldAutostartBeOn {
|
||||
if err := s.bridge.EnableAutostart(); err != nil {
|
||||
shouldAutostartBeOn := s.bridge.GetAutostart()
|
||||
if s.bridge.GetFirstStart() || shouldAutostartBeOn {
|
||||
if err := s.bridge.SetAutostart(true); err != nil {
|
||||
s.log.WithField("prefs", shouldAutostartBeOn).WithError(err).Error("Failed to enable first autostart")
|
||||
}
|
||||
return
|
||||
@ -168,7 +152,7 @@ func (s *Service) initAutostart() {
|
||||
|
||||
func (s *Service) Loop() error {
|
||||
defer func() {
|
||||
s.bridge.SetBool(settings.FirstStartGUIKey, false)
|
||||
_ = s.bridge.SetFirstStartGUI(false)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
@ -179,7 +163,7 @@ func (s *Service) Loop() error {
|
||||
s.log.Info("Starting gRPC server")
|
||||
|
||||
if err := s.grpcServer.Serve(s.listener); err != nil {
|
||||
s.log.WithError(err).Error("Error serving gRPC")
|
||||
s.log.WithError(err).Error("Failed to serve gRPC")
|
||||
return err
|
||||
}
|
||||
|
||||
@ -212,140 +196,59 @@ func (s *Service) WaitUntilFrontendIsReady() {
|
||||
s.initializing.Wait()
|
||||
}
|
||||
|
||||
func (s *Service) watchEvents() { // nolint:funlen
|
||||
if s.bridge.HasError(bridge.ErrLocalCacheUnavailable) {
|
||||
_ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_UNAVAILABLE_ERROR))
|
||||
}
|
||||
func (s *Service) watchEvents() {
|
||||
eventCh, done := s.bridge.GetEvents()
|
||||
defer done()
|
||||
|
||||
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
|
||||
internetConnChangedCh := s.eventListener.ProvideChannel(events.InternetConnChangedEvent)
|
||||
secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
|
||||
restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
|
||||
addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
|
||||
updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
|
||||
userChangedCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
|
||||
certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||
|
||||
// we forward events to the GUI/frontend via the gRPC event stream.
|
||||
for {
|
||||
select {
|
||||
case errorDetails := <-errorCh:
|
||||
if strings.Contains(errorDetails, "IMAP failed") {
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_IMAP_PORT_ISSUE))
|
||||
}
|
||||
if strings.Contains(errorDetails, "SMTP failed") {
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_SMTP_PORT_ISSUE))
|
||||
}
|
||||
case reason := <-credentialsErrorCh:
|
||||
if reason == keychain.ErrMacKeychainRebuild.Error() {
|
||||
_ = s.SendEvent(NewKeychainRebuildKeychainEvent())
|
||||
continue
|
||||
}
|
||||
// TODO: Better error events.
|
||||
for _, err := range s.bridge.GetErrors() {
|
||||
switch {
|
||||
case errors.Is(err, vault.ErrCorrupt):
|
||||
_ = s.SendEvent(NewKeychainHasNoKeychainEvent())
|
||||
case email := <-noActiveKeyForRecipientCh:
|
||||
_ = s.SendEvent(NewMailNoActiveKeyForRecipientEvent(email))
|
||||
case stat := <-internetConnChangedCh:
|
||||
if stat == events.InternetOff {
|
||||
_ = s.SendEvent(NewInternetStatusEvent(false))
|
||||
}
|
||||
if stat == events.InternetOn {
|
||||
_ = s.SendEvent(NewInternetStatusEvent(true))
|
||||
}
|
||||
|
||||
case <-secondInstanceCh:
|
||||
_ = s.SendEvent(NewShowMainWindowEvent())
|
||||
case <-restartBridgeCh:
|
||||
_, _ = s.Restart(
|
||||
metadata.AppendToOutgoingContext(context.Background(), serverTokenMetadataKey, s.token),
|
||||
&emptypb.Empty{},
|
||||
)
|
||||
case address := <-addressChangedCh:
|
||||
_ = s.SendEvent(NewMailAddressChangeEvent(address))
|
||||
case address := <-addressChangedLogoutCh:
|
||||
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(address))
|
||||
case userID := <-logoutCh:
|
||||
user, err := s.bridge.GetUserInfo(userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = s.SendEvent(NewUserDisconnectedEvent(user.Username))
|
||||
case <-updateApplicationCh:
|
||||
s.updateForce()
|
||||
case userID := <-userChangedCh:
|
||||
_ = s.SendEvent(NewUserChangedEvent(userID))
|
||||
case <-certIssue:
|
||||
_ = s.SendEvent(NewMailApiCertIssue())
|
||||
case errors.Is(err, vault.ErrInsecure):
|
||||
_ = s.SendEvent(NewKeychainHasNoKeychainEvent())
|
||||
|
||||
case errors.Is(err, bridge.ErrServeIMAP):
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_IMAP_PORT_ISSUE))
|
||||
|
||||
case errors.Is(err, bridge.ErrServeSMTP):
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_SMTP_PORT_ISSUE))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) loginAbort() {
|
||||
s.loginClean()
|
||||
}
|
||||
for event := range eventCh {
|
||||
switch event := event.(type) {
|
||||
case events.ConnStatus:
|
||||
_ = s.SendEvent(NewInternetStatusEvent(event.Status == liteapi.StatusUp))
|
||||
|
||||
func (s *Service) loginClean() {
|
||||
s.auth = nil
|
||||
s.authClient = nil
|
||||
for i := range s.password {
|
||||
s.password[i] = '\x00'
|
||||
}
|
||||
s.password = s.password[0:0]
|
||||
}
|
||||
case events.Raise:
|
||||
_ = s.SendEvent(NewShowMainWindowEvent())
|
||||
|
||||
func (s *Service) finishLogin() {
|
||||
defer s.loginClean()
|
||||
case events.UserAddressCreated:
|
||||
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
|
||||
|
||||
if len(s.password) == 0 || s.auth == nil || s.authClient == nil {
|
||||
s.log.
|
||||
WithField("hasPass", len(s.password) != 0).
|
||||
WithField("hasAuth", s.auth != nil).
|
||||
WithField("hasClient", s.authClient != nil).
|
||||
Error("Finish login: authentication incomplete")
|
||||
case events.UserAddressChanged:
|
||||
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
|
||||
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TWO_PASSWORDS_ABORT, "Missing authentication, try again."))
|
||||
return
|
||||
}
|
||||
case events.UserAddressDeleted:
|
||||
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Address))
|
||||
|
||||
done := make(chan string)
|
||||
s.eventListener.Add(events.UserChangeDone, done)
|
||||
defer s.eventListener.Remove(events.UserChangeDone, done)
|
||||
case events.UserChanged:
|
||||
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
|
||||
|
||||
userID, err := s.bridge.FinishLogin(s.authClient, s.auth, s.password)
|
||||
|
||||
if err != nil && err != users.ErrUserAlreadyConnected {
|
||||
s.log.WithError(err).Errorf("Finish login failed")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TWO_PASSWORDS_ABORT, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// The user changed should be triggered by FinishLogin, but it is not
|
||||
// guaranteed when this is going to happen. Therefor we should wait
|
||||
// until we receive the signal from userChanged function.
|
||||
s.waitForUserChangeDone(done, userID)
|
||||
|
||||
s.log.WithField("userID", userID).Debug("Login finished")
|
||||
_ = s.SendEvent(NewLoginFinishedEvent(userID))
|
||||
|
||||
if err == users.ErrUserAlreadyConnected {
|
||||
s.log.WithError(err).Error("User already logged in")
|
||||
_ = s.SendEvent(NewLoginAlreadyLoggedInEvent(userID))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) waitForUserChangeDone(done <-chan string, userID string) {
|
||||
for {
|
||||
select {
|
||||
case changedID := <-done:
|
||||
if changedID == userID {
|
||||
return
|
||||
case events.UserDeauth:
|
||||
if user, err := s.bridge.GetUserInfo(event.UserID); err != nil {
|
||||
s.log.WithError(err).Error("Failed to get user info")
|
||||
} else {
|
||||
_ = s.SendEvent(NewUserDisconnectedEvent(user.Username))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
s.log.WithField("ID", userID).Warning("Login finished but user not added within 2 seconds")
|
||||
return
|
||||
|
||||
case events.TLSIssue:
|
||||
_ = s.SendEvent(NewMailApiCertIssue())
|
||||
|
||||
case events.UpdateForced:
|
||||
panic("TODO")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -354,103 +257,46 @@ func (s *Service) triggerReset() {
|
||||
defer func() {
|
||||
_ = s.SendEvent(NewResetFinishedEvent())
|
||||
}()
|
||||
s.bridge.FactoryReset()
|
||||
if err := s.bridge.FactoryReset(context.Background()); err != nil {
|
||||
s.log.WithError(err).Error("Failed to reset")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) checkUpdate() {
|
||||
version, err := s.updater.Check()
|
||||
func newTLSConfig() (*tls.Config, []byte, error) {
|
||||
template, err := certs.NewTLSTemplate()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("An error occurred while checking for updates")
|
||||
s.SetVersion(updater.VersionInfo{})
|
||||
return
|
||||
}
|
||||
s.SetVersion(version)
|
||||
}
|
||||
|
||||
func (s *Service) updateForce() {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer s.updateCheckMutex.Unlock()
|
||||
s.checkUpdate()
|
||||
_ = s.SendEvent(NewUpdateForceEvent(s.newVersionInfo.Version.String()))
|
||||
}
|
||||
|
||||
func (s *Service) checkUpdateAndNotify(isReqFromUser bool) {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer func() {
|
||||
s.updateCheckMutex.Unlock()
|
||||
_ = s.SendEvent(NewUpdateCheckFinishedEvent())
|
||||
}()
|
||||
|
||||
s.checkUpdate()
|
||||
version := s.newVersionInfo
|
||||
if (version.Version == nil) || (version.Version.String() == "") {
|
||||
if isReqFromUser {
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
}
|
||||
return
|
||||
}
|
||||
if !s.updater.IsUpdateApplicable(s.newVersionInfo) {
|
||||
s.log.Info("No need to update")
|
||||
if isReqFromUser {
|
||||
_ = s.SendEvent(NewUpdateIsLatestVersionEvent())
|
||||
}
|
||||
} else if isReqFromUser {
|
||||
s.NotifyManualUpdate(s.newVersionInfo, s.updater.CanInstall(s.newVersionInfo))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) installUpdate() {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer s.updateCheckMutex.Unlock()
|
||||
|
||||
if !s.updater.CanInstall(s.newVersionInfo) {
|
||||
s.log.Warning("Skipping update installation, current version too old")
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
return
|
||||
return nil, nil, fmt.Errorf("failed to create TLS template: %w", err)
|
||||
}
|
||||
|
||||
if err := s.updater.InstallUpdate(s.newVersionInfo); err != nil {
|
||||
if errors.Cause(err) == updater.ErrDownloadVerify {
|
||||
s.log.WithError(err).Warning("Skipping update installation due to temporary error")
|
||||
} else {
|
||||
s.log.WithError(err).Error("The update couldn't be installed")
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_ = s.SendEvent(NewUpdateSilentRestartNeededEvent())
|
||||
}
|
||||
|
||||
func (s *Service) generateTLSConfig() (tlsConfig *cryptotls.Config, pemCert []byte, err error) {
|
||||
pemCert, pemKey, err := tls.NewPEMKeyPair()
|
||||
certPEM, keyPEM, err := certs.GenerateCert(template)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("Could not get TLS config")
|
||||
return nil, nil, fmt.Errorf("failed to generate cert: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = tls.GetConfigFromPEMKeyPair(pemCert, pemKey)
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("Could not get TLS config")
|
||||
return nil, nil, fmt.Errorf("failed to load cert: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig.ClientAuth = cryptotls.NoClientCert // skip client auth if the certificate allow it.
|
||||
|
||||
return tlsConfig, pemCert, nil
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientAuth: tls.NoClientCert,
|
||||
}, certPEM, nil
|
||||
}
|
||||
|
||||
func (s *Service) saveGRPCServerConfigFile() (string, error) {
|
||||
address, ok := s.listener.Addr().(*net.TCPAddr)
|
||||
func saveGRPCServerConfigFile(locations *locations.Locations, listener net.Listener, token string, certPEM []byte) (string, error) {
|
||||
address, ok := listener.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("could not retrieve gRPC service listener address")
|
||||
}
|
||||
|
||||
sc := config{
|
||||
Port: address.Port,
|
||||
Cert: s.pemCert,
|
||||
Token: s.token,
|
||||
Cert: string(certPEM),
|
||||
Token: token,
|
||||
}
|
||||
|
||||
settingsPath, err := s.locations.ProvideSettingsPath()
|
||||
settingsPath, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -461,7 +307,7 @@ func (s *Service) saveGRPCServerConfigFile() (string, error) {
|
||||
}
|
||||
|
||||
// validateServerToken verify that the server token provided by the client is valid.
|
||||
func (s *Service) validateServerToken(ctx context.Context) error {
|
||||
func validateServerToken(ctx context.Context, wantToken string) error {
|
||||
values, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "missing server token")
|
||||
@ -476,40 +322,31 @@ func (s *Service) validateServerToken(ctx context.Context) error {
|
||||
return status.Error(codes.Unauthenticated, "more than one server token was provided")
|
||||
}
|
||||
|
||||
if token[0] != s.token {
|
||||
if token[0] != wantToken {
|
||||
return status.Error(codes.Unauthenticated, "invalid server token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUnaryServerToken check the server token for every unary gRPC call.
|
||||
func (s *Service) validateUnaryServerToken(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (resp interface{}, err error) {
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// newUnaryTokenValidator checks the server token for every unary gRPC call.
|
||||
func newUnaryTokenValidator(wantToken string) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if err := validateServerToken(ctx, wantToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// validateStreamServerToken check the server token for every gRPC stream request.
|
||||
func (s *Service) validateStreamServerToken(
|
||||
srv interface{},
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
logEntry := s.log.WithField("FullMethod", info.FullMethod)
|
||||
// newStreamTokenValidator checks the server token for every gRPC stream request.
|
||||
func newStreamTokenValidator(wantToken string) grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if err := validateServerToken(stream.Context(), wantToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.validateServerToken(ss.Context()); err != nil {
|
||||
logEntry.WithError(err).Error("Stream validator failed")
|
||||
return err
|
||||
return handler(srv, stream)
|
||||
}
|
||||
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user