From 3ddd88e1278d92adbe667154bcd1bf3d3639fd30 Mon Sep 17 00:00:00 2001 From: Xavier Michelon Date: Thu, 6 Apr 2023 15:08:41 +0200 Subject: [PATCH] feat(GODT-2538): implement smart picking of default IMAP/SMTP ports --- internal/vault/types_settings.go | 7 +- pkg/ports/ports.go | 11 +-- pkg/ports/ports_test.go | 18 ++++- tests/bdd_test.go | 5 +- tests/bridge_test.go | 36 +++++++++ tests/ctx_test.go | 10 +++ tests/features/bridge/default_ports.feature | 24 ++++++ tests/features/{ => bridge}/updates.feature | 0 utils/port-blocker/port-blocker.go | 83 +++++++++++++++++++++ 9 files changed, 182 insertions(+), 12 deletions(-) create mode 100644 tests/features/bridge/default_ports.feature rename tests/features/{ => bridge}/updates.feature (100%) create mode 100644 utils/port-blocker/port-blocker.go diff --git a/internal/vault/types_settings.go b/internal/vault/types_settings.go index 2c385ff3..68c3e6f1 100644 --- a/internal/vault/types_settings.go +++ b/internal/vault/types_settings.go @@ -22,6 +22,7 @@ import ( "runtime" "github.com/ProtonMail/proton-bridge/v3/internal/updater" + "github.com/ProtonMail/proton-bridge/v3/pkg/ports" ) type Settings struct { @@ -70,12 +71,14 @@ func GetDefaultSyncWorkerCount() int { func newDefaultSettings(gluonDir string) Settings { syncWorkers := GetDefaultSyncWorkerCount() + imapPort := ports.FindFreePortFrom(1143) + smtpPort := ports.FindFreePortFrom(1025, imapPort) return Settings{ GluonDir: gluonDir, - IMAPPort: 1143, - SMTPPort: 1025, + IMAPPort: imapPort, + SMTPPort: smtpPort, IMAPSSL: false, SMTPSSL: false, diff --git a/pkg/ports/ports.go b/pkg/ports/ports.go index d5cfbaff..212a3d42 100644 --- a/pkg/ports/ports.go +++ b/pkg/ports/ports.go @@ -22,6 +22,7 @@ import ( "net" "github.com/ProtonMail/proton-bridge/v3/internal/constants" + "golang.org/x/exp/slices" ) const ( @@ -43,19 +44,19 @@ func IsPortFree(port int) bool { func isOccupied(port string) bool { // Try to create server at port. - dummyserver, err := net.Listen("tcp", port) + dummyServer, err := net.Listen("tcp", port) if err != nil { return true } - _ = dummyserver.Close() + _ = dummyServer.Close() return false } -// FindFreePortFrom finds first empty port, starting with `startPort`. -func FindFreePortFrom(startPort int) int { +// FindFreePortFrom finds first empty port, starting with `startPort`, and excluding ports listed in exclude. +func FindFreePortFrom(startPort int, exclude ...int) int { loopedOnce := false freePort := startPort - for !IsPortFree(freePort) { + for slices.Contains(exclude, freePort) || !IsPortFree(freePort) { freePort++ if freePort >= maxPortNumber { freePort = 1 diff --git a/pkg/ports/ports_test.go b/pkg/ports/ports_test.go index a88bbce2..53276be7 100644 --- a/pkg/ports/ports_test.go +++ b/pkg/ports/ports_test.go @@ -32,12 +32,12 @@ func TestFreePort(t *testing.T) { } func TestOccupiedPort(t *testing.T) { - dummyserver, err := net.Listen("tcp", ":"+strconv.Itoa(testPort)) + dummyServer, err := net.Listen("tcp", ":"+strconv.Itoa(testPort)) require.NoError(t, err) require.True(t, !IsPortFree(testPort), "port should be occupied") - _ = dummyserver.Close() + _ = dummyServer.Close() } func TestFindFreePortFromDirectly(t *testing.T) { @@ -46,11 +46,21 @@ func TestFindFreePortFromDirectly(t *testing.T) { } func TestFindFreePortFromNextOne(t *testing.T) { - dummyserver, err := net.Listen("tcp", ":"+strconv.Itoa(testPort)) + dummyServer, err := net.Listen("tcp", ":"+strconv.Itoa(testPort)) require.NoError(t, err) foundPort := FindFreePortFrom(testPort) require.Equal(t, testPort+1, foundPort) - _ = dummyserver.Close() + _ = dummyServer.Close() +} + +func TestFindFreePortExcluding(t *testing.T) { + dummyServer, err := net.Listen("tcp", ":"+strconv.Itoa(testPort)) + require.NoError(t, err) + + foundPort := FindFreePortFrom(testPort, testPort+1, testPort+2) + require.Equal(t, testPort+3, foundPort) + + _ = dummyServer.Close() } diff --git a/tests/bdd_test.go b/tests/bdd_test.go index d854fb12..90c8f817 100644 --- a/tests/bdd_test.go +++ b/tests/bdd_test.go @@ -107,7 +107,10 @@ func TestFeatures(testingT *testing.T) { ctx.Step(`^the header in the "([^"]*)" request to "([^"]*)" has "([^"]*)" set to "([^"]*)"$`, s.theHeaderInTheRequestToHasSetTo) ctx.Step(`^the body in the "([^"]*)" request to "([^"]*)" is:$`, s.theBodyInTheRequestToIs) ctx.Step(`^the API requires bridge version at least "([^"]*)"$`, s.theAPIRequiresBridgeVersion) - + ctx.Step(`^the network port (\d+) is busy$`, s.networkPortIsBusy) + ctx.Step(`^the network port range (\d+)-(\d+) is busy$`, s.networkPortRangeIsBusy) + ctx.Step(`^bridge IMAP port is (\d+)`, s.bridgeIMAPPortIs) + ctx.Step(`^bridge SMTP port is (\d+)`, s.bridgeSMTPPortIs) // ==== SETUP ==== ctx.Step(`^there exists an account with username "([^"]*)" and password "([^"]*)"$`, s.thereExistsAnAccountWithUsernameAndPassword) ctx.Step(`^there exists a disabled account with username "([^"]*)" and password "([^"]*)"$`, s.thereExistsAnAccountWithUsernameAndPasswordWithDisablePrimary) diff --git a/tests/bridge_test.go b/tests/bridge_test.go index 14632013..40a4667e 100644 --- a/tests/bridge_test.go +++ b/tests/bridge_test.go @@ -21,7 +21,9 @@ import ( "context" "errors" "fmt" + "net" "os" + "strconv" "time" "github.com/Masterminds/semver/v3" @@ -307,3 +309,37 @@ func (s *scenario) theUserHidesAllMail() error { func (s *scenario) theUserShowsAllMail() error { return s.t.bridge.SetShowAllMail(true) } + +func (s *scenario) networkPortIsBusy(port int) { + if listener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port)); err == nil { // we ignore errors. Most likely port is already busy. + s.t.dummyListeners = append(s.t.dummyListeners, listener) + } +} + +func (s *scenario) networkPortRangeIsBusy(startPort, endPort int) { + if startPort > endPort { + startPort, endPort = endPort, startPort + } + + for port := startPort; port <= endPort; port++ { + s.networkPortIsBusy(port) + } +} + +func (s *scenario) bridgeIMAPPortIs(expectedPort int) error { + actualPort := s.t.bridge.GetIMAPPort() + if actualPort != expectedPort { + return fmt.Errorf("expected IMAP port to be %v but got %v", expectedPort, actualPort) + } + + return nil +} + +func (s *scenario) bridgeSMTPPortIs(expectedPort int) error { + actualPort := s.t.bridge.GetSMTPPort() + if actualPort != expectedPort { + return fmt.Errorf("expected SMTP port to be %v but got %v", expectedPort, actualPort) + } + + return nil +} diff --git a/tests/ctx_test.go b/tests/ctx_test.go index 13bff820..004ab501 100644 --- a/tests/ctx_test.go +++ b/tests/ctx_test.go @@ -20,6 +20,7 @@ package tests import ( "context" "fmt" + "net" "net/smtp" "net/url" "regexp" @@ -160,6 +161,9 @@ type testCtx struct { // errors holds test-related errors encountered while running test steps. errors [][]error errorsLock sync.RWMutex + + // This slice contains the dummy listeners that are intended to block network ports. + dummyListeners []net.Listener } type imapClient struct { @@ -437,6 +441,12 @@ func (t *testCtx) close(ctx context.Context) { } } + for _, listener := range t.dummyListeners { + if err := listener.Close(); err != nil { + logrus.WithError(err).Errorf("Failed to close dummy listener %v", listener.Addr()) + } + } + t.api.Close() t.events.close() t.reporter.close() diff --git a/tests/features/bridge/default_ports.feature b/tests/features/bridge/default_ports.feature new file mode 100644 index 00000000..4c7f3703 --- /dev/null +++ b/tests/features/bridge/default_ports.feature @@ -0,0 +1,24 @@ +Feature: Bridge picks default ports wisely + + Scenario: bridge picks ports for IMAP and SMTP using default values. + When bridge starts + Then bridge IMAP port is 1143 + Then bridge SMTP port is 1025 + + Scenario: bridge picks ports for IMAP wisely when default port is busy. + When the network port 1143 is busy + And bridge starts + Then bridge IMAP port is 1144 + Then bridge SMTP port is 1025 + + Scenario: bridge picks ports for SMTP wisely when default port is busy. + When the network port range 1025-1030 is busy + And bridge starts + Then bridge IMAP port is 1143 + Then bridge SMTP port is 1031 + + Scenario: bridge picks ports for IMAP SMTP wisely when default ports are busy. + When the network port range 1025-1200 is busy + And bridge starts + Then bridge IMAP port is 1201 + Then bridge SMTP port is 1202 diff --git a/tests/features/updates.feature b/tests/features/bridge/updates.feature similarity index 100% rename from tests/features/updates.feature rename to tests/features/bridge/updates.feature diff --git a/utils/port-blocker/port-blocker.go b/utils/port-blocker/port-blocker.go new file mode 100644 index 00000000..9a358f44 --- /dev/null +++ b/utils/port-blocker/port-blocker.go @@ -0,0 +1,83 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +// port-blocker is a command-line that ensure a port or range of ports is occupied by creating listeners. +package main + +import ( + "fmt" + "net" + "os" + "strconv" +) + +func main() { + argCount := len(os.Args) + if (len(os.Args) < 2) || (argCount > 3) { + exitWithUsage("Invalid number of arguments.") + } + + startPort := parsePort(os.Args[1]) + endPort := startPort + if argCount == 3 { + endPort = parsePort(os.Args[2]) + } + + runBlocker(startPort, endPort) +} + +func parsePort(portString string) int { + result, err := strconv.Atoi(portString) + if err != nil { + exitWithUsage(fmt.Sprintf("Invalid port '%v'.", portString)) + } + + if (result < 1024) || (result > 65535) { // ports below 1024 are reserved. + exitWithUsage("Ports must be in the range [1024-65535].") + } + + return result +} + +func exitWithUsage(message string) { + fmt.Printf("Usage: port-blocker []\n") + if len(message) > 0 { + fmt.Println(message) + } + os.Exit(1) +} + +func runBlocker(startPort, endPort int) { + if endPort < startPort { + exitWithUsage("startPort must be less than or equal to endPort.") + } + + for port := startPort; port <= endPort; port++ { + listener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port)) + if err != nil { + fmt.Printf("Port %v is already blocked. Skipping.\n", port) + } else { + //goland:noinspection GoDeferInLoop + defer func() { + _ = listener.Close() + }() + } + } + + fmt.Println("Blocking requested ports. Press enter to exit.") + _, _ = fmt.Scanln() +}