mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2026-02-04 08:18:34 +00:00
fix(GODT-1623): Throttle SMTP failed requests
If a SMPT client keeps hammering bridge and triggers multiple successive errors in quick succession, force that client to wait 20 seconds before trying again.
This commit is contained in:
@ -21,16 +21,21 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Accounts struct {
|
type Accounts struct {
|
||||||
accountsLock sync.RWMutex
|
accountsLock sync.RWMutex
|
||||||
accounts map[string]*Service
|
accounts map[string]*smtpAccountState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const maxFailedCommands = 3
|
||||||
|
const defaultErrTimeout = 20 * time.Second
|
||||||
|
const successiveErrInterval = time.Second
|
||||||
|
|
||||||
func NewAccounts() *Accounts {
|
func NewAccounts() *Accounts {
|
||||||
return &Accounts{
|
return &Accounts{
|
||||||
accounts: make(map[string]*Service),
|
accounts: make(map[string]*smtpAccountState),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +43,10 @@ func (s *Accounts) AddAccount(account *Service) {
|
|||||||
s.accountsLock.Lock()
|
s.accountsLock.Lock()
|
||||||
defer s.accountsLock.Unlock()
|
defer s.accountsLock.Unlock()
|
||||||
|
|
||||||
s.accounts[account.UserID()] = account
|
s.accounts[account.UserID()] = &smtpAccountState{
|
||||||
|
service: account,
|
||||||
|
errTimeout: defaultErrTimeout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Accounts) RemoveAccount(account *Service) {
|
func (s *Accounts) RemoveAccount(account *Service) {
|
||||||
@ -52,18 +60,18 @@ func (s *Accounts) CheckAuth(user string, password []byte) (string, string, erro
|
|||||||
s.accountsLock.RLock()
|
s.accountsLock.RLock()
|
||||||
defer s.accountsLock.RUnlock()
|
defer s.accountsLock.RUnlock()
|
||||||
|
|
||||||
for id, service := range s.accounts {
|
for id, account := range s.accounts {
|
||||||
addrID, err := service.checkAuth(context.Background(), user, password)
|
addrID, err := account.service.checkAuth(context.Background(), user, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
service.telemetry.ReportSMTPAuthSuccess(context.Background())
|
account.service.telemetry.ReportSMTPAuthSuccess(context.Background())
|
||||||
return id, addrID, nil
|
return id, addrID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, service := range s.accounts {
|
for _, service := range s.accounts {
|
||||||
service.telemetry.ReportSMTPAuthFailed(user)
|
service.service.telemetry.ReportSMTPAuthFailed(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", "", ErrNoSuchUser
|
return "", "", ErrNoSuchUser
|
||||||
@ -77,10 +85,57 @@ func (s *Accounts) SendMail(ctx context.Context, userID, addrID, from string, to
|
|||||||
s.accountsLock.RLock()
|
s.accountsLock.RLock()
|
||||||
defer s.accountsLock.RUnlock()
|
defer s.accountsLock.RUnlock()
|
||||||
|
|
||||||
service, ok := s.accounts[userID]
|
requestTime := time.Now()
|
||||||
|
|
||||||
|
account, ok := s.accounts[userID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrNoSuchUser
|
return ErrNoSuchUser
|
||||||
}
|
}
|
||||||
|
|
||||||
return service.SendMail(ctx, addrID, from, to, r)
|
if err := account.canMakeRequest(requestTime); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := account.service.SendMail(ctx, addrID, from, to, r)
|
||||||
|
account.handleSMTPErr(requestTime, err)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type smtpAccountState struct {
|
||||||
|
service *Service
|
||||||
|
errTimeout time.Duration
|
||||||
|
|
||||||
|
errLock sync.Mutex
|
||||||
|
errCounter int
|
||||||
|
lastRequest time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *smtpAccountState) canMakeRequest(requestTime time.Time) error {
|
||||||
|
s.errLock.Lock()
|
||||||
|
defer s.errLock.Unlock()
|
||||||
|
|
||||||
|
if s.errCounter >= maxFailedCommands {
|
||||||
|
if requestTime.Sub(s.lastRequest) >= s.errTimeout {
|
||||||
|
s.errCounter = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrTooManyErrors
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *smtpAccountState) handleSMTPErr(requestTime time.Time, err error) {
|
||||||
|
s.errLock.Lock()
|
||||||
|
defer s.errLock.Unlock()
|
||||||
|
|
||||||
|
if err == nil || requestTime.Sub(s.lastRequest) > successiveErrInterval {
|
||||||
|
s.errCounter = 0
|
||||||
|
} else {
|
||||||
|
s.errCounter++
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lastRequest = requestTime
|
||||||
}
|
}
|
||||||
|
|||||||
46
internal/services/smtp/accounts_test.go
Normal file
46
internal/services/smtp/accounts_test.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
// Copyright (c) 2023 Proton AG
|
||||||
|
//
|
||||||
|
// This file is part of Proton Mail Bridge.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package smtp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountTimeout(t *testing.T) {
|
||||||
|
account := smtpAccountState{errTimeout: 5 * time.Second}
|
||||||
|
err := errors.New("fail")
|
||||||
|
|
||||||
|
for i := 0; i <= maxFailedCommands; i++ {
|
||||||
|
requestTime := time.Now()
|
||||||
|
assert.Nil(t, account.canMakeRequest(requestTime))
|
||||||
|
account.handleSMTPErr(requestTime, err)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
requestTime := time.Now()
|
||||||
|
assert.ErrorIs(t, account.canMakeRequest(requestTime), ErrTooManyErrors)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
requestTime := time.Now()
|
||||||
|
return account.canMakeRequest(requestTime) == nil
|
||||||
|
}, 10*time.Second, time.Second)
|
||||||
|
}
|
||||||
@ -22,3 +22,4 @@ import "errors"
|
|||||||
var ErrInvalidRecipient = errors.New("invalid recipient")
|
var ErrInvalidRecipient = errors.New("invalid recipient")
|
||||||
var ErrInvalidReturnPath = errors.New("invalid return path")
|
var ErrInvalidReturnPath = errors.New("invalid return path")
|
||||||
var ErrNoSuchUser = errors.New("no such user")
|
var ErrNoSuchUser = errors.New("no such user")
|
||||||
|
var ErrTooManyErrors = errors.New("too many failed requests, please try again later")
|
||||||
|
|||||||
Reference in New Issue
Block a user