From ff288145df977ed7bc6e3c5a3468bc3469847799 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Fri, 10 Nov 2023 12:57:07 +0100 Subject: [PATCH] fix(GODT-1623): Throttle SMTP failed requests If a SMPT client keeps hammering bridge and triggers multiple successive errors in quick succession, force that client to wait 20 seconds before trying again. --- internal/services/smtp/accounts.go | 73 ++++++++++++++++++++++--- internal/services/smtp/accounts_test.go | 46 ++++++++++++++++ internal/services/smtp/errors.go | 1 + 3 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 internal/services/smtp/accounts_test.go diff --git a/internal/services/smtp/accounts.go b/internal/services/smtp/accounts.go index 4bff6965..d31f93f7 100644 --- a/internal/services/smtp/accounts.go +++ b/internal/services/smtp/accounts.go @@ -21,16 +21,21 @@ import ( "context" "io" "sync" + "time" ) type Accounts struct { accountsLock sync.RWMutex - accounts map[string]*Service + accounts map[string]*smtpAccountState } +const maxFailedCommands = 3 +const defaultErrTimeout = 20 * time.Second +const successiveErrInterval = time.Second + func NewAccounts() *Accounts { return &Accounts{ - accounts: make(map[string]*Service), + accounts: make(map[string]*smtpAccountState), } } @@ -38,7 +43,10 @@ func (s *Accounts) AddAccount(account *Service) { s.accountsLock.Lock() defer s.accountsLock.Unlock() - s.accounts[account.UserID()] = account + s.accounts[account.UserID()] = &smtpAccountState{ + service: account, + errTimeout: defaultErrTimeout, + } } func (s *Accounts) RemoveAccount(account *Service) { @@ -52,18 +60,18 @@ func (s *Accounts) CheckAuth(user string, password []byte) (string, string, erro s.accountsLock.RLock() defer s.accountsLock.RUnlock() - for id, service := range s.accounts { - addrID, err := service.checkAuth(context.Background(), user, password) + for id, account := range s.accounts { + addrID, err := account.service.checkAuth(context.Background(), user, password) if err != nil { continue } - service.telemetry.ReportSMTPAuthSuccess(context.Background()) + account.service.telemetry.ReportSMTPAuthSuccess(context.Background()) return id, addrID, nil } for _, service := range s.accounts { - service.telemetry.ReportSMTPAuthFailed(user) + service.service.telemetry.ReportSMTPAuthFailed(user) } return "", "", ErrNoSuchUser @@ -77,10 +85,57 @@ func (s *Accounts) SendMail(ctx context.Context, userID, addrID, from string, to s.accountsLock.RLock() defer s.accountsLock.RUnlock() - service, ok := s.accounts[userID] + requestTime := time.Now() + + account, ok := s.accounts[userID] if !ok { return ErrNoSuchUser } - return service.SendMail(ctx, addrID, from, to, r) + if err := account.canMakeRequest(requestTime); err != nil { + return err + } + + err := account.service.SendMail(ctx, addrID, from, to, r) + account.handleSMTPErr(requestTime, err) + + return err +} + +type smtpAccountState struct { + service *Service + errTimeout time.Duration + + errLock sync.Mutex + errCounter int + lastRequest time.Time +} + +func (s *smtpAccountState) canMakeRequest(requestTime time.Time) error { + s.errLock.Lock() + defer s.errLock.Unlock() + + if s.errCounter >= maxFailedCommands { + if requestTime.Sub(s.lastRequest) >= s.errTimeout { + s.errCounter = 0 + return nil + } + + return ErrTooManyErrors + } + + return nil +} + +func (s *smtpAccountState) handleSMTPErr(requestTime time.Time, err error) { + s.errLock.Lock() + defer s.errLock.Unlock() + + if err == nil || requestTime.Sub(s.lastRequest) > successiveErrInterval { + s.errCounter = 0 + } else { + s.errCounter++ + } + + s.lastRequest = requestTime } diff --git a/internal/services/smtp/accounts_test.go b/internal/services/smtp/accounts_test.go new file mode 100644 index 00000000..0593331c --- /dev/null +++ b/internal/services/smtp/accounts_test.go @@ -0,0 +1,46 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package smtp + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAccountTimeout(t *testing.T) { + account := smtpAccountState{errTimeout: 5 * time.Second} + err := errors.New("fail") + + for i := 0; i <= maxFailedCommands; i++ { + requestTime := time.Now() + assert.Nil(t, account.canMakeRequest(requestTime)) + account.handleSMTPErr(requestTime, err) + } + { + requestTime := time.Now() + assert.ErrorIs(t, account.canMakeRequest(requestTime), ErrTooManyErrors) + } + + assert.Eventually(t, func() bool { + requestTime := time.Now() + return account.canMakeRequest(requestTime) == nil + }, 10*time.Second, time.Second) +} diff --git a/internal/services/smtp/errors.go b/internal/services/smtp/errors.go index 7526e1c7..b0221773 100644 --- a/internal/services/smtp/errors.go +++ b/internal/services/smtp/errors.go @@ -22,3 +22,4 @@ import "errors" var ErrInvalidRecipient = errors.New("invalid recipient") var ErrInvalidReturnPath = errors.New("invalid return path") var ErrNoSuchUser = errors.New("no such user") +var ErrTooManyErrors = errors.New("too many failed requests, please try again later")