chore: merge release/Rialto into devel

This commit is contained in:
Jakub
2023-05-23 15:53:29 +02:00
144 changed files with 2947 additions and 1145 deletions

View File

@ -19,14 +19,12 @@ package app
import (
"fmt"
"math/rand"
"net/http"
"net/http/cookiejar"
"net/url"
"os"
"path/filepath"
"runtime"
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/async"
@ -160,9 +158,6 @@ func New() *cli.App {
}
func run(c *cli.Context) error {
// Seed the default RNG from the math/rand package.
rand.Seed(time.Now().UnixNano())
// Get the current bridge version.
version, err := semver.NewVersion(constants.Version)
if err != nil {

View File

@ -29,7 +29,6 @@ import (
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/async"
imapEvents "github.com/ProtonMail/gluon/events"
"github.com/ProtonMail/gluon/imap"
@ -45,7 +44,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
)
@ -67,13 +65,7 @@ type Bridge struct {
tlsConfig *tls.Config
// imapServer is the bridge's IMAP server.
imapServer *gluon.Server
imapListener net.Listener
imapEventCh chan imapEvents.Event
// smtpServer is the bridge's SMTP server.
smtpServer *smtp.Server
smtpListener net.Listener
imapEventCh chan imapEvents.Event
// updater is the bridge's updater.
updater Updater
@ -134,6 +126,8 @@ type Bridge struct {
goHeartbeat func()
uidValidityGenerator imap.UIDValidityGenerator
serverManager *ServerManager
}
// New creates a new bridge.
@ -224,16 +218,6 @@ func newBridge(
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}
gluonCacheDir, err := getGluonDir(vault)
if err != nil {
return nil, fmt.Errorf("failed to get Gluon directory: %w", err)
}
gluonDataDir, err := locator.ProvideGluonDataPath()
if err != nil {
return nil, fmt.Errorf("failed to get Gluon Database directory: %w", err)
}
firstStart := vault.GetFirstStart()
if err := vault.SetFirstStart(false); err != nil {
return nil, fmt.Errorf("failed to save first start indicator: %w", err)
@ -246,23 +230,6 @@ func newBridge(
identifier.SetClientString(vault.GetLastUserAgent())
imapServer, err := newIMAPServer(
gluonCacheDir,
gluonDataDir,
curVersion,
tlsConfig,
reporter,
logIMAPClient,
logIMAPServer,
imapEventCh,
tasks,
uidValidityGenerator,
panicHandler,
)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP server: %w", err)
}
focusService, err := focus.NewService(locator, curVersion, panicHandler)
if err != nil {
return nil, fmt.Errorf("failed to create focus service: %w", err)
@ -279,7 +246,6 @@ func newBridge(
identifier: identifier,
tlsConfig: tlsConfig,
imapServer: imapServer,
imapEventCh: imapEventCh,
updater: updater,
@ -306,9 +272,13 @@ func newBridge(
tasks: tasks,
uidValidityGenerator: uidValidityGenerator,
serverManager: newServerManager(),
}
bridge.smtpServer = newSMTPServer(bridge, tlsConfig, logSMTP)
if err := bridge.serverManager.Init(bridge); err != nil {
return nil, err
}
return bridge, nil
}
@ -381,10 +351,6 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
})
})
// We need to load users before we can start the IMAP and SMTP servers.
// We must only start the servers once.
var once sync.Once
// Attempt to load users from the vault when triggered.
bridge.goLoad = bridge.tasks.Trigger(func(ctx context.Context) {
if err := bridge.loadUsers(ctx); err != nil {
@ -396,17 +362,6 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
}
bridge.publish(events.AllUsersLoaded{})
// Once all users have been loaded, start the bridge's IMAP and SMTP servers.
once.Do(func() {
if err := bridge.serveIMAP(); err != nil {
logrus.WithError(err).Error("Failed to start IMAP server")
}
if err := bridge.serveSMTP(); err != nil {
logrus.WithError(err).Error("Failed to start SMTP server")
}
})
})
defer bridge.goLoad()
@ -452,18 +407,13 @@ func (bridge *Bridge) GetErrors() []error {
func (bridge *Bridge) Close(ctx context.Context) {
logrus.Info("Closing bridge")
// Close the IMAP server.
if err := bridge.closeIMAP(ctx); err != nil {
logrus.WithError(err).Error("Failed to close IMAP server")
}
// Close the SMTP server.
if err := bridge.closeSMTP(); err != nil {
logrus.WithError(err).Error("Failed to close SMTP server")
// Close the servers
if err := bridge.serverManager.CloseServers(ctx); err != nil {
logrus.WithError(err).Error("Failed to close servers")
}
// Close all users.
safe.RLock(func() {
safe.Lock(func() {
for _, user := range bridge.users {
user.Close()
}

View File

@ -50,7 +50,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/tests"
"github.com/bradenaw/juniper/xslices"
imapid "github.com/emersion/go-imap-id"
"github.com/emersion/go-imap/client"
"github.com/stretchr/testify/require"
)
@ -173,11 +172,27 @@ func TestBridge_UserAgent(t *testing.T) {
func TestBridge_UserAgent_Persistence(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
otherPassword := []byte("bar")
otherUser := "foo"
_, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(b)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(b)
defer smtpWaiter.Done()
require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil)))
imapWaiter.Wait()
smtpWaiter.Wait()
currentUserAgent := b.GetCurrentUserAgent()
require.Contains(t, currentUserAgent, vault.DefaultUserAgent)
imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
defer func() { _ = imapClient.Logout() }()
@ -220,8 +235,24 @@ func TestBridge_UserAgentFromIMAPID(t *testing.T) {
calls = append(calls, call)
})
otherPassword := []byte("bar")
otherUser := "foo"
_, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
imapWaiter := waitForIMAPServerReady(b)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(b)
defer smtpWaiter.Done()
require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil)))
imapWaiter.Wait()
smtpWaiter.Wait()
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
defer func() { _ = imapClient.Logout() }()
@ -592,10 +623,22 @@ func TestBridge_InitGluonDirectory(t *testing.T) {
func TestBridge_LoginFailed(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
failCh, done := chToType[events.Event, events.IMAPLoginFailed](bridge.GetEvents(events.IMAPLoginFailed{}))
defer done()
imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
imapClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err)
require.Error(t, imapClient.Login("badUser", "badPass"))
@ -622,6 +665,12 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) {
configDir, err := b.GetGluonDataDir()
require.NoError(t, err)
imapWaiter := waitForIMAPServerReady(b)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(b)
defer smtpWaiter.Done()
// Login the user.
syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{}))
defer done()
@ -655,7 +704,10 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
imapWaiter.Wait()
smtpWaiter.Wait()
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
@ -695,7 +747,7 @@ func TestBridge_ChangeAddressOrder(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
@ -716,7 +768,7 @@ func TestBridge_ChangeAddressOrder(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
@ -778,6 +830,7 @@ func withBridgeNoMocks(
locator bridge.Locator,
vaultKey []byte,
tests func(*bridge.Bridge),
waitOnServers bool,
) {
// Bridge will disable the proxy by default at startup.
mocks.ProxyCtl.EXPECT().DisallowProxy()
@ -828,14 +881,17 @@ func withBridgeNoMocks(
// Wait for bridge to finish loading users.
waitForEvent(t, eventCh, events.AllUsersLoaded{})
// Wait for bridge to start the IMAP server.
waitForEvent(t, eventCh, events.IMAPServerReady{})
// Wait for bridge to start the SMTP server.
waitForEvent(t, eventCh, events.SMTPServerReady{})
// Set random IMAP and SMTP ports for the tests.
require.NoError(t, bridge.SetIMAPPort(0))
require.NoError(t, bridge.SetSMTPPort(0))
require.NoError(t, bridge.SetIMAPPort(ctx, 0))
require.NoError(t, bridge.SetSMTPPort(ctx, 0))
if waitOnServers {
// Wait for bridge to start the IMAP server.
waitForEvent(t, eventCh, events.IMAPServerReady{})
// Wait for bridge to start the SMTP server.
waitForEvent(t, eventCh, events.SMTPServerReady{})
}
// Close the bridge when done.
defer bridge.Close(ctx)
@ -857,7 +913,24 @@ func withBridge(
withMocks(t, func(mocks *bridge.Mocks) {
withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) {
tests(bridge, mocks)
})
}, false)
})
}
// withBridgeWaitForServers is the same as withBridge, but it will wait until IMAP & SMTP servers are ready.
func withBridgeWaitForServers(
ctx context.Context,
t *testing.T,
apiURL string,
netCtl *proton.NetCtl,
locator bridge.Locator,
vaultKey []byte,
tests func(*bridge.Bridge, *bridge.Mocks),
) {
withMocks(t, func(mocks *bridge.Mocks) {
withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) {
tests(bridge, mocks)
}, true)
})
}
@ -910,3 +983,48 @@ func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) {
return outCh, done
}
type eventWaiter struct {
evtCh <-chan events.Event
cancel func()
}
func (e *eventWaiter) Done() {
e.cancel()
}
func (e *eventWaiter) Wait() {
<-e.evtCh
}
func waitForSMTPServerReady(b *bridge.Bridge) *eventWaiter {
evtCh, cancel := b.GetEvents(events.SMTPServerReady{})
return &eventWaiter{
evtCh: evtCh,
cancel: cancel,
}
}
func waitForSMTPServerStopped(b *bridge.Bridge) *eventWaiter {
evtCh, cancel := b.GetEvents(events.SMTPServerStopped{})
return &eventWaiter{
evtCh: evtCh,
cancel: cancel,
}
}
func waitForIMAPServerReady(b *bridge.Bridge) *eventWaiter {
evtCh, cancel := b.GetEvents(events.IMAPServerReady{})
return &eventWaiter{
evtCh: evtCh,
cancel: cancel,
}
}
func waitForIMAPServerStopped(b *bridge.Bridge) *eventWaiter {
evtCh, cancel := b.GetEvents(events.IMAPServerStopped{})
return &eventWaiter{
evtCh: evtCh,
cancel: cancel,
}
}

View File

@ -18,6 +18,7 @@
package bridge
import (
"context"
"strings"
"github.com/ProtonMail/proton-bridge/v3/internal/clientconfig"
@ -31,7 +32,7 @@ import (
// ConfigureAppleMail configures apple mail for the given userID and address.
// If configuring apple mail for Catalina or newer, it ensures Bridge is using SSL.
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
func (bridge *Bridge) ConfigureAppleMail(ctx context.Context, userID, address string) error {
logrus.WithFields(logrus.Fields{
"userID": userID,
"address": logging.Sensitive(address),
@ -56,7 +57,7 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
}
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
if err := bridge.SetSMTPSSL(true); err != nil {
if err := bridge.SetSMTPSSL(ctx, true); err != nil {
return err
}
}

View File

@ -58,11 +58,7 @@ func moveFile(from, to string) error {
return err
}
if err := os.Rename(from, to); err != nil {
return err
}
return nil
return os.Rename(from, to)
}
func copyDir(from, to string) error {

View File

@ -20,7 +20,6 @@ package bridge
import (
"context"
"crypto/tls"
"fmt"
"io"
"os"
"path/filepath"
@ -37,203 +36,21 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/sirupsen/logrus"
)
func (bridge *Bridge) serveIMAP() error {
port, err := func() (int, error) {
if bridge.imapServer == nil {
return 0, fmt.Errorf("no IMAP server instance running")
}
logrus.WithFields(logrus.Fields{
"port": bridge.vault.GetIMAPPort(),
"ssl": bridge.vault.GetIMAPSSL(),
}).Info("Starting IMAP server")
imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig)
if err != nil {
return 0, fmt.Errorf("failed to create IMAP listener: %w", err)
}
bridge.imapListener = imapListener
if err := bridge.imapServer.Serve(context.Background(), bridge.imapListener); err != nil {
return 0, fmt.Errorf("failed to serve IMAP: %w", err)
}
if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil {
return 0, fmt.Errorf("failed to store IMAP port in vault: %w", err)
}
return getPort(imapListener.Addr()), nil
}()
if err != nil {
bridge.publish(events.IMAPServerError{
Error: err,
})
return err
}
bridge.publish(events.IMAPServerReady{
Port: port,
})
return nil
}
func (bridge *Bridge) restartIMAP() error {
logrus.Info("Restarting IMAP server")
if bridge.imapListener != nil {
if err := bridge.imapListener.Close(); err != nil {
return fmt.Errorf("failed to close IMAP listener: %w", err)
}
bridge.publish(events.IMAPServerStopped{})
}
return bridge.serveIMAP()
}
func (bridge *Bridge) closeIMAP(ctx context.Context) error {
logrus.Info("Closing IMAP server")
if bridge.imapServer != nil {
if err := bridge.imapServer.Close(ctx); err != nil {
return fmt.Errorf("failed to close IMAP server: %w", err)
}
bridge.imapServer = nil
}
if bridge.imapListener != nil {
if err := bridge.imapListener.Close(); err != nil {
return fmt.Errorf("failed to close IMAP listener: %w", err)
}
}
bridge.publish(events.IMAPServerStopped{})
return nil
func (bridge *Bridge) restartIMAP(ctx context.Context) error {
return bridge.serverManager.RestartIMAP(ctx)
}
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
if bridge.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
}
for addrID, imapConn := range imapConn {
log := logrus.WithFields(logrus.Fields{
"userID": user.ID(),
"addrID": addrID,
})
if gluonID, ok := user.GetGluonID(addrID); ok {
log.WithField("gluonID", gluonID).Info("Loading existing IMAP user")
// Load the user, checking whether the DB was newly created.
isNew, err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
if isNew {
// If the DB was newly created, clear the sync status; gluon's DB was not found.
logrus.Warn("IMAP user DB was newly created, clearing sync status")
// Remove the user from IMAP so we can clear the sync status.
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
// Clear the sync status -- we need to resync all messages.
if err := user.ClearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
}
// Add the user back to the IMAP server.
if isNew, err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
} else if isNew {
panic("IMAP user should already have a database")
}
} else if status := user.GetSyncStatus(); !status.HasLabels {
// Otherwise, the DB already exists -- if the labels are not yet synced, we need to re-create the DB.
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
return fmt.Errorf("failed to remove old IMAP user: %w", err)
}
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to remove old IMAP user ID: %w", err)
}
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
log.WithField("gluonID", gluonID).Info("Re-created IMAP user")
}
} else {
log.Info("Creating new IMAP user")
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
log.WithField("gluonID", gluonID).Info("Created new IMAP user")
}
}
// Trigger a sync for the user, if needed.
user.TriggerSync()
return nil
return bridge.serverManager.AddIMAPUser(ctx, user)
}
// removeIMAPUser disconnects the given user from gluon, optionally also removing its files.
func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withData bool) error {
if bridge.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
logrus.WithFields(logrus.Fields{
"userID": user.ID(),
"withData": withData,
}).Debug("Removing IMAP user")
for addrID, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withData); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
if withData {
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to remove IMAP user ID: %w", err)
}
}
}
return nil
return bridge.serverManager.RemoveIMAPUser(ctx, user, withData)
}
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
@ -262,19 +79,12 @@ func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
logrus.WithFields(logrus.Fields{
"sessionID": event.SessionID,
"username": event.Username,
}).Info("Received IMAP login failure notification")
"pkg": "imap",
}).Error("Incorrect login credentials.")
bridge.publish(events.IMAPLoginFailed{Username: event.Username})
}
}
func getGluonDir(encVault *vault.Vault) (string, error) {
if err := os.MkdirAll(encVault.GetGluonCacheDir(), 0o700); err != nil {
return "", fmt.Errorf("failed to create gluon dir: %w", err)
}
return encVault.GetGluonCacheDir(), nil
}
func ApplyGluonCachePathSuffix(basePath string) string {
return filepath.Join(basePath, "backend", "store")
}

View File

@ -144,13 +144,13 @@ func (testUpdater *TestUpdater) SetLatestVersion(version, minAuto *semver.Versio
}
}
func (testUpdater *TestUpdater) GetVersionInfo(ctx context.Context, downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error) {
func (testUpdater *TestUpdater) GetVersionInfo(_ context.Context, _ updater.Downloader, _ updater.Channel) (updater.VersionInfo, error) {
testUpdater.lock.RLock()
defer testUpdater.lock.RUnlock()
return testUpdater.latest, nil
}
func (testUpdater *TestUpdater) InstallUpdate(ctx context.Context, downloader updater.Downloader, update updater.VersionInfo) error {
func (testUpdater *TestUpdater) InstallUpdate(_ context.Context, _ updater.Downloader, _ updater.VersionInfo) error {
return nil
}

View File

@ -28,7 +28,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/bradenaw/juniper/iterator"
"github.com/emersion/go-imap/client"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
@ -66,7 +65,7 @@ func TestBridge_Refresh(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
@ -99,7 +98,7 @@ func TestBridge_Refresh(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()

View File

@ -34,7 +34,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/emersion/go-imap"
"github.com/emersion/go-imap/client"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/stretchr/testify/require"
@ -46,12 +45,17 @@ func TestBridge_Send(t *testing.T) {
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
recipientUserID, err := bridge.LoginFull(ctx, "recipient", password, nil, nil)
require.NoError(t, err)
smtpWaiter.Wait()
senderInfo, err := bridge.GetUserInfo(senderUserID)
require.NoError(t, err)
@ -91,13 +95,13 @@ func TestBridge_Send(t *testing.T) {
}
// Connect the sender IMAP client.
senderIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
senderIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err)
require.NoError(t, senderIMAPClient.Login(senderInfo.Addresses[0], string(senderInfo.BridgePass)))
defer senderIMAPClient.Logout() //nolint:errcheck
// Connect the recipient IMAP client.
recipientIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
recipientIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err)
require.NoError(t, recipientIMAPClient.Login(recipientInfo.Addresses[0], string(recipientInfo.BridgePass)))
defer recipientIMAPClient.Logout() //nolint:errcheck
@ -135,13 +139,13 @@ func TestBridge_SendDraftFlags(t *testing.T) {
})
// Start the bridge.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
// Get the sender user info.
userInfo, err := bridge.QueryUserInfo(username)
require.NoError(t, err)
// Connect the sender IMAP client.
imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
imapClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err)
require.NoError(t, imapClient.Login(userInfo.Addresses[0], string(userInfo.BridgePass)))
defer imapClient.Logout() //nolint:errcheck
@ -245,13 +249,13 @@ func TestBridge_SendInvite(t *testing.T) {
})
// Start the bridge.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
// Get the sender user info.
userInfo, err := bridge.QueryUserInfo(username)
require.NoError(t, err)
// Connect the sender IMAP client.
imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
imapClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err)
require.NoError(t, imapClient.Login(userInfo.Addresses[0], string(userInfo.BridgePass)))
defer imapClient.Logout() //nolint:errcheck
@ -401,6 +405,9 @@ SGVsbG8gd29ybGQK
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
@ -420,6 +427,8 @@ SGVsbG8gd29ybGQK
messageMultipartWithoutTextWithTextAttachment,
}
smtpWaiter.Wait()
for _, m := range messages {
// Dial the server.
client, err := smtp.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetSMTPPort())))
@ -444,13 +453,13 @@ SGVsbG8gd29ybGQK
}
// Connect the sender IMAP client.
senderIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
senderIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err)
require.NoError(t, senderIMAPClient.Login(senderInfo.Addresses[0], string(senderInfo.BridgePass)))
defer senderIMAPClient.Logout() //nolint:errcheck
// Connect the recipient IMAP client.
recipientIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
recipientIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
require.NoError(t, err)
require.NoError(t, recipientIMAPClient.Login(recipientInfo.Addresses[0], string(recipientInfo.BridgePass)))
defer recipientIMAPClient.Logout() //nolint:errcheck

View File

@ -36,6 +36,9 @@ import (
func TestBridge_Report(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(b)
defer imapWaiter.Done()
syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{}))
defer done()
@ -51,6 +54,8 @@ func TestBridge_Report(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
imapWaiter.Wait()
// Dial the IMAP port.
conn, err := net.Dial("tcp", fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)

View File

@ -0,0 +1,696 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"context"
"fmt"
"net"
"path/filepath"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
)
// ServerManager manages the IMAP & SMTP servers and their listeners.
type ServerManager struct {
requests *cpc.CPC
imapServer *gluon.Server
imapListener net.Listener
smtpServer *smtp.Server
smtpListener net.Listener
loadedUserCount int
}
func newServerManager() *ServerManager {
return &ServerManager{
requests: cpc.NewCPC(),
}
}
func (sm *ServerManager) Init(bridge *Bridge) error {
imapServer, err := createIMAPServer(bridge)
if err != nil {
return err
}
smtpServer := createSMTPServer(bridge)
sm.imapServer = imapServer
sm.smtpServer = smtpServer
bridge.tasks.Once(func(ctx context.Context) {
logging.DoAnnotated(ctx, func(ctx context.Context) {
sm.run(ctx, bridge)
}, logging.Labels{
"service": "server-manager",
})
})
return nil
}
func (sm *ServerManager) CloseServers(ctx context.Context) error {
defer sm.requests.Close()
_, err := sm.requests.Send(ctx, &smRequestClose{})
return err
}
func (sm *ServerManager) RestartIMAP(ctx context.Context) error {
_, err := sm.requests.Send(ctx, &smRequestRestartIMAP{})
return err
}
func (sm *ServerManager) RestartSMTP(ctx context.Context) error {
_, err := sm.requests.Send(ctx, &smRequestRestartSMTP{})
return err
}
func (sm *ServerManager) AddIMAPUser(ctx context.Context, user *user.User) error {
_, err := sm.requests.Send(ctx, &smRequestAddIMAPUser{user: user})
return err
}
func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error {
_, err := sm.requests.Send(ctx, &smRequestRemoveIMAPUser{
user: user,
withData: withData,
})
return err
}
func (sm *ServerManager) SetGluonDir(ctx context.Context, gluonDir string) error {
_, err := sm.requests.Send(ctx, &smRequestSetGluonDir{
dir: gluonDir,
})
return err
}
func (sm *ServerManager) AddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) {
reply, err := cpc.SendTyped[string](ctx, sm.requests, &smRequestAddGluonUser{
conn: conn,
passphrase: passphrase,
})
return reply, err
}
func (sm *ServerManager) RemoveGluonUser(ctx context.Context, gluonID string) error {
_, err := sm.requests.Send(ctx, &smRequestRemoveGluonUser{
userID: gluonID,
})
return err
}
func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
eventCh, cancel := bridge.GetEvents()
defer cancel()
for {
select {
case <-ctx.Done():
sm.handleClose(ctx, bridge)
return
case evt := <-eventCh:
switch evt.(type) {
case events.ConnStatusDown:
logrus.Info("Server Manager, network down stopping listeners")
if err := sm.closeSMTPServer(bridge); err != nil {
logrus.WithError(err).Error("Failed to close SMTP server")
}
if err := sm.stopIMAPListener(bridge); err != nil {
logrus.WithError(err)
}
case events.ConnStatusUp:
logrus.Info("Server Manager, network up starting listeners")
sm.handleLoadedUserCountChange(ctx, bridge)
}
case request, ok := <-sm.requests.ReceiveCh():
if !ok {
return
}
switch r := request.Value().(type) {
case *smRequestClose:
sm.handleClose(ctx, bridge)
request.Reply(ctx, nil, nil)
return
case *smRequestRestartSMTP:
err := sm.restartSMTP(bridge)
request.Reply(ctx, nil, err)
case *smRequestRestartIMAP:
err := sm.restartIMAP(ctx, bridge)
request.Reply(ctx, nil, err)
case *smRequestAddIMAPUser:
err := sm.handleAddIMAPUser(ctx, r.user)
request.Reply(ctx, nil, err)
if err == nil {
sm.loadedUserCount++
sm.handleLoadedUserCountChange(ctx, bridge)
}
case *smRequestRemoveIMAPUser:
err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData)
request.Reply(ctx, nil, err)
if err == nil {
sm.loadedUserCount--
sm.handleLoadedUserCountChange(ctx, bridge)
}
case *smRequestSetGluonDir:
err := sm.handleSetGluonDir(ctx, bridge, r.dir)
request.Reply(ctx, nil, err)
case *smRequestAddGluonUser:
id, err := sm.handleAddGluonUser(ctx, r.conn, r.passphrase)
request.Reply(ctx, id, err)
case *smRequestRemoveGluonUser:
err := sm.handleRemoveGluonUser(ctx, r.userID)
request.Reply(ctx, nil, err)
}
}
}
}
func (sm *ServerManager) handleLoadedUserCountChange(ctx context.Context, bridge *Bridge) {
logrus.Infof("Validating Listener State %v", sm.loadedUserCount)
if sm.shouldStartServers() {
if sm.imapListener == nil {
if err := sm.serveIMAP(ctx, bridge); err != nil {
logrus.WithError(err).Error("Failed to start IMAP server")
}
}
if sm.smtpListener == nil {
if err := sm.restartSMTP(bridge); err != nil {
logrus.WithError(err).Error("Failed to start SMTP server")
}
}
} else {
if sm.imapListener != nil {
if err := sm.stopIMAPListener(bridge); err != nil {
logrus.WithError(err).Error("Failed to stop IMAP server")
}
}
if sm.smtpListener != nil {
if err := sm.closeSMTPServer(bridge); err != nil {
logrus.WithError(err).Error("Failed to stop SMTP server")
}
}
}
}
func (sm *ServerManager) handleClose(ctx context.Context, bridge *Bridge) {
// Close the IMAP server.
if err := sm.closeIMAPServer(ctx, bridge); err != nil {
logrus.WithError(err).Error("Failed to close IMAP server")
}
// Close the SMTP server.
if err := sm.closeSMTPServer(bridge); err != nil {
logrus.WithError(err).Error("Failed to close SMTP server")
}
}
func (sm *ServerManager) handleAddIMAPUser(ctx context.Context, user *user.User) error {
if sm.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
}
for addrID, imapConn := range imapConn {
log := logrus.WithFields(logrus.Fields{
"userID": user.ID(),
"addrID": addrID,
})
if gluonID, ok := user.GetGluonID(addrID); ok {
log.WithField("gluonID", gluonID).Info("Loading existing IMAP user")
// Load the user, checking whether the DB was newly created.
isNew, err := sm.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
if isNew {
// If the DB was newly created, clear the sync status; gluon's DB was not found.
logrus.Warn("IMAP user DB was newly created, clearing sync status")
// Remove the user from IMAP so we can clear the sync status.
if err := sm.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
// Clear the sync status -- we need to resync all messages.
if err := user.ClearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
}
// Add the user back to the IMAP server.
if isNew, err := sm.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
} else if isNew {
panic("IMAP user should already have a database")
}
} else if status := user.GetSyncStatus(); !status.HasLabels {
// Otherwise, the DB already exists -- if the labels are not yet synced, we need to re-create the DB.
if err := sm.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
return fmt.Errorf("failed to remove old IMAP user: %w", err)
}
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to remove old IMAP user ID: %w", err)
}
gluonID, err := sm.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
log.WithField("gluonID", gluonID).Info("Re-created IMAP user")
}
} else {
log.Info("Creating new IMAP user")
gluonID, err := sm.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
log.WithField("gluonID", gluonID).Info("Created new IMAP user")
}
}
// Trigger a sync for the user, if needed.
user.TriggerSync()
return nil
}
func (sm *ServerManager) handleRemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error {
if sm.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
logrus.WithFields(logrus.Fields{
"userID": user.ID(),
"withData": withData,
}).Debug("Removing IMAP user")
for addrID, gluonID := range user.GetGluonIDs() {
if err := sm.imapServer.RemoveUser(ctx, gluonID, withData); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
if withData {
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to remove IMAP user ID: %w", err)
}
}
}
return nil
}
func createIMAPServer(bridge *Bridge) (*gluon.Server, error) {
gluonDataDir, err := bridge.GetGluonDataDir()
if err != nil {
return nil, fmt.Errorf("failed to get Gluon Database directory: %w", err)
}
return newIMAPServer(
bridge.vault.GetGluonCacheDir(),
gluonDataDir,
bridge.curVersion,
bridge.tlsConfig,
bridge.reporter,
bridge.logIMAPClient,
bridge.logIMAPServer,
bridge.imapEventCh,
bridge.tasks,
bridge.uidValidityGenerator,
bridge.panicHandler,
)
}
func createSMTPServer(bridge *Bridge) *smtp.Server {
return newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
}
func (sm *ServerManager) closeSMTPServer(bridge *Bridge) error {
// We close the listener ourselves even though it's also closed by smtpServer.Close().
// This is because smtpServer.Serve() is called in a separate goroutine and might be executed
// after we've already closed the server. However, go-smtp has a bug; it blocks on the listener
// even after the server has been closed. So we close the listener ourselves to unblock it.
if sm.smtpListener != nil {
logrus.Info("Closing SMTP Listener")
if err := sm.smtpListener.Close(); err != nil {
return fmt.Errorf("failed to close SMTP listener: %w", err)
}
sm.smtpListener = nil
}
if sm.smtpServer != nil {
logrus.Info("Closing SMTP server")
if err := sm.smtpServer.Close(); err != nil {
logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)")
}
sm.smtpServer = nil
bridge.publish(events.SMTPServerStopped{})
}
return nil
}
func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) error {
if sm.imapListener != nil {
logrus.Info("Closing IMAP Listener")
if err := sm.imapListener.Close(); err != nil {
return fmt.Errorf("failed to close IMAP listener: %w", err)
}
sm.imapListener = nil
bridge.publish(events.IMAPServerStopped{})
}
if sm.imapServer != nil {
logrus.Info("Closing IMAP server")
if err := sm.imapServer.Close(ctx); err != nil {
return fmt.Errorf("failed to close IMAP server: %w", err)
}
sm.imapServer = nil
}
return nil
}
func (sm *ServerManager) restartIMAP(ctx context.Context, bridge *Bridge) error {
logrus.Info("Restarting IMAP server")
if sm.imapListener != nil {
if err := sm.imapListener.Close(); err != nil {
return fmt.Errorf("failed to close IMAP listener: %w", err)
}
sm.imapListener = nil
bridge.publish(events.IMAPServerStopped{})
}
if sm.shouldStartServers() {
return sm.serveIMAP(ctx, bridge)
}
return nil
}
func (sm *ServerManager) restartSMTP(bridge *Bridge) error {
logrus.Info("Restarting SMTP server")
if err := sm.closeSMTPServer(bridge); err != nil {
return fmt.Errorf("failed to close SMTP: %w", err)
}
bridge.publish(events.SMTPServerStopped{})
sm.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
if sm.shouldStartServers() {
return sm.serveSMTP(bridge)
}
return nil
}
func (sm *ServerManager) serveSMTP(bridge *Bridge) error {
port, err := func() (int, error) {
logrus.WithFields(logrus.Fields{
"port": bridge.vault.GetSMTPPort(),
"ssl": bridge.vault.GetSMTPSSL(),
}).Info("Starting SMTP server")
smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig)
if err != nil {
return 0, fmt.Errorf("failed to create SMTP listener: %w", err)
}
sm.smtpListener = smtpListener
bridge.tasks.Once(func(context.Context) {
if err := sm.smtpServer.Serve(smtpListener); err != nil {
logrus.WithError(err).Info("SMTP server stopped")
}
})
if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil {
return 0, fmt.Errorf("failed to store SMTP port in vault: %w", err)
}
return getPort(smtpListener.Addr()), nil
}()
if err != nil {
bridge.publish(events.SMTPServerError{
Error: err,
})
return err
}
bridge.publish(events.SMTPServerReady{
Port: port,
})
return nil
}
func (sm *ServerManager) serveIMAP(ctx context.Context, bridge *Bridge) error {
port, err := func() (int, error) {
if sm.imapServer == nil {
return 0, fmt.Errorf("no IMAP server instance running")
}
logrus.WithFields(logrus.Fields{
"port": bridge.vault.GetIMAPPort(),
"ssl": bridge.vault.GetIMAPSSL(),
}).Info("Starting IMAP server")
imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig)
if err != nil {
return 0, fmt.Errorf("failed to create IMAP listener: %w", err)
}
sm.imapListener = imapListener
if err := sm.imapServer.Serve(ctx, sm.imapListener); err != nil {
return 0, fmt.Errorf("failed to serve IMAP: %w", err)
}
if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil {
return 0, fmt.Errorf("failed to store IMAP port in vault: %w", err)
}
return getPort(imapListener.Addr()), nil
}()
if err != nil {
bridge.publish(events.IMAPServerError{
Error: err,
})
return err
}
bridge.publish(events.IMAPServerReady{
Port: port,
})
return nil
}
func (sm *ServerManager) stopIMAPListener(bridge *Bridge) error {
logrus.Info("Stopping IMAP listener")
if sm.imapListener != nil {
if err := sm.imapListener.Close(); err != nil {
return err
}
sm.imapListener = nil
bridge.publish(events.IMAPServerStopped{})
}
return nil
}
func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge, newGluonDir string) error {
return safe.RLockRet(func() error {
currentGluonDir := bridge.GetGluonCacheDir()
newGluonDir = filepath.Join(newGluonDir, "gluon")
if newGluonDir == currentGluonDir {
return fmt.Errorf("new gluon dir is the same as the old one")
}
if err := sm.closeIMAPServer(context.Background(), bridge); err != nil {
return fmt.Errorf("failed to close IMAP: %w", err)
}
sm.loadedUserCount = 0
if err := bridge.moveGluonCacheDir(currentGluonDir, newGluonDir); err != nil {
logrus.WithError(err).Error("failed to move GluonCacheDir")
if err := bridge.vault.SetGluonDir(currentGluonDir); err != nil {
return fmt.Errorf("failed to revert GluonCacheDir: %w", err)
}
}
bridge.heartbeat.SetCacheLocation(newGluonDir)
gluonDataDir, err := bridge.GetGluonDataDir()
if err != nil {
return fmt.Errorf("failed to get Gluon Database directory: %w", err)
}
imapServer, err := newIMAPServer(
bridge.vault.GetGluonCacheDir(),
gluonDataDir,
bridge.curVersion,
bridge.tlsConfig,
bridge.reporter,
bridge.logIMAPClient,
bridge.logIMAPServer,
bridge.imapEventCh,
bridge.tasks,
bridge.uidValidityGenerator,
bridge.panicHandler,
)
if err != nil {
return fmt.Errorf("failed to create new IMAP server: %w", err)
}
sm.imapServer = imapServer
for _, bridgeUser := range bridge.users {
if err := sm.handleAddIMAPUser(ctx, bridgeUser); err != nil {
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
}
sm.loadedUserCount++
}
if sm.shouldStartServers() {
if err := sm.serveIMAP(ctx, bridge); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
}
return nil
}, bridge.usersLock)
}
func (sm *ServerManager) handleAddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) {
if sm.imapServer == nil {
return "", fmt.Errorf("no imap server instance running")
}
return sm.imapServer.AddUser(ctx, conn, passphrase)
}
func (sm *ServerManager) handleRemoveGluonUser(ctx context.Context, userID string) error {
if sm.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
return sm.imapServer.RemoveUser(ctx, userID, true)
}
func (sm *ServerManager) shouldStartServers() bool {
return sm.loadedUserCount >= 1
}
type smRequestClose struct{}
type smRequestRestartIMAP struct{}
type smRequestRestartSMTP struct{}
type smRequestAddIMAPUser struct {
user *user.User
}
type smRequestRemoveIMAPUser struct {
user *user.User
withData bool
}
type smRequestSetGluonDir struct {
dir string
}
type smRequestAddGluonUser struct {
conn connector.Connector
passphrase []byte
}
type smRequestRemoveGluonUser struct {
userID string
}

View File

@ -0,0 +1,179 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge_test
import (
"context"
"fmt"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/stretchr/testify/require"
)
func TestServerManager_NoLoadedUsersNoServers(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.Error(t, err)
})
})
}
func TestServerManager_ServersStartAfterFirstConnectedUser(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
})
})
}
func TestServerManager_ServersStopsAfterUserLogsOut(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
imapWaiterStopped := waitForIMAPServerStopped(bridge)
defer imapWaiterStopped.Done()
require.NoError(t, bridge.LogoutUser(ctx, userID))
imapWaiterStopped.Wait()
})
})
}
func TestServerManager_ServersDoNotStopWhenThereIsStillOneActiveUser(t *testing.T) {
otherPassword := []byte("bar")
otherUser := "foo"
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
_, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
userIDOther, err := bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
evtCh, cancel := bridge.GetEvents(events.UserDeauth{})
defer cancel()
require.NoError(t, s.RevokeUser(userIDOther))
waitForEvent(t, evtCh, events.UserDeauth{})
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, imapClient.Logout())
})
})
}
func TestServerManager_ServersStartIfAtLeastOneUserIsLoggedIn(t *testing.T) {
otherPassword := []byte("bar")
otherUser := "foo"
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
userIDOther, _, err := s.CreateUser(otherUser, otherPassword)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
_, err = bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil)
require.NoError(t, err)
})
require.NoError(t, s.RevokeUser(userIDOther))
withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, imapClient.Logout())
})
})
}
func TestServerManager_NetworkLossStopsServers(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
imapWaiterStop := waitForIMAPServerStopped(bridge)
defer imapWaiterStop.Done()
smtpWaiterStop := waitForSMTPServerStopped(bridge)
defer smtpWaiterStop.Done()
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
require.NoError(t, err)
imapWaiter.Wait()
smtpWaiter.Wait()
netCtl.Disable()
imapWaiterStop.Wait()
smtpWaiterStop.Wait()
netCtl.Enable()
imapWaiter.Wait()
smtpWaiter.Wait()
})
})
}

View File

@ -22,7 +22,6 @@ import (
"fmt"
"net"
"os"
"path/filepath"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
@ -55,7 +54,7 @@ func (bridge *Bridge) GetIMAPPort() int {
return bridge.vault.GetIMAPPort()
}
func (bridge *Bridge) SetIMAPPort(newPort int) error {
func (bridge *Bridge) SetIMAPPort(ctx context.Context, newPort int) error {
if newPort == bridge.vault.GetIMAPPort() {
return nil
}
@ -66,14 +65,14 @@ func (bridge *Bridge) SetIMAPPort(newPort int) error {
bridge.heartbeat.SetIMAPPort(newPort)
return bridge.restartIMAP()
return bridge.restartIMAP(ctx)
}
func (bridge *Bridge) GetIMAPSSL() bool {
return bridge.vault.GetIMAPSSL()
}
func (bridge *Bridge) SetIMAPSSL(newSSL bool) error {
func (bridge *Bridge) SetIMAPSSL(ctx context.Context, newSSL bool) error {
if newSSL == bridge.vault.GetIMAPSSL() {
return nil
}
@ -84,14 +83,14 @@ func (bridge *Bridge) SetIMAPSSL(newSSL bool) error {
bridge.heartbeat.SetIMAPConnectionMode(newSSL)
return bridge.restartIMAP()
return bridge.restartIMAP(ctx)
}
func (bridge *Bridge) GetSMTPPort() int {
return bridge.vault.GetSMTPPort()
}
func (bridge *Bridge) SetSMTPPort(newPort int) error {
func (bridge *Bridge) SetSMTPPort(ctx context.Context, newPort int) error {
if newPort == bridge.vault.GetSMTPPort() {
return nil
}
@ -102,14 +101,14 @@ func (bridge *Bridge) SetSMTPPort(newPort int) error {
bridge.heartbeat.SetSMTPPort(newPort)
return bridge.restartSMTP()
return bridge.restartSMTP(ctx)
}
func (bridge *Bridge) GetSMTPSSL() bool {
return bridge.vault.GetSMTPSSL()
}
func (bridge *Bridge) SetSMTPSSL(newSSL bool) error {
func (bridge *Bridge) SetSMTPSSL(ctx context.Context, newSSL bool) error {
if newSSL == bridge.vault.GetSMTPSSL() {
return nil
}
@ -120,7 +119,7 @@ func (bridge *Bridge) SetSMTPSSL(newSSL bool) error {
bridge.heartbeat.SetSMTPConnectionMode(newSSL)
return bridge.restartSMTP()
return bridge.restartSMTP(ctx)
}
func (bridge *Bridge) GetGluonCacheDir() string {
@ -132,63 +131,7 @@ func (bridge *Bridge) GetGluonDataDir() (string, error) {
}
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
return safe.RLockRet(func() error {
currentGluonDir := bridge.GetGluonCacheDir()
newGluonDir = filepath.Join(newGluonDir, "gluon")
if newGluonDir == currentGluonDir {
return fmt.Errorf("new gluon dir is the same as the old one")
}
if err := bridge.closeIMAP(context.Background()); err != nil {
return fmt.Errorf("failed to close IMAP: %w", err)
}
if err := bridge.moveGluonCacheDir(currentGluonDir, newGluonDir); err != nil {
logrus.WithError(err).Error("failed to move GluonCacheDir")
if err := bridge.vault.SetGluonDir(currentGluonDir); err != nil {
return fmt.Errorf("failed to revert GluonCacheDir: %w", err)
}
}
bridge.heartbeat.SetCacheLocation(newGluonDir)
gluonDataDir, err := bridge.GetGluonDataDir()
if err != nil {
return fmt.Errorf("failed to get Gluon Database directory: %w", err)
}
imapServer, err := newIMAPServer(
bridge.vault.GetGluonCacheDir(),
gluonDataDir,
bridge.curVersion,
bridge.tlsConfig,
bridge.reporter,
bridge.logIMAPClient,
bridge.logIMAPServer,
bridge.imapEventCh,
bridge.tasks,
bridge.uidValidityGenerator,
bridge.panicHandler,
)
if err != nil {
return fmt.Errorf("failed to create new IMAP server: %w", err)
}
bridge.imapServer = imapServer
for _, user := range bridge.users {
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
}
}
if err := bridge.serveIMAP(); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
return nil
}, bridge.usersLock)
return bridge.serverManager.SetGluonDir(ctx, newGluonDir)
}
func (bridge *Bridge) moveGluonCacheDir(oldGluonDir, newGluonDir string) error {

View File

@ -57,7 +57,7 @@ func TestBridge_Settings_IMAPPort(t *testing.T) {
curPort := bridge.GetIMAPPort()
// Set the port to 1144.
require.NoError(t, bridge.SetIMAPPort(1144))
require.NoError(t, bridge.SetIMAPPort(ctx, 1144))
// Get the new setting.
require.Equal(t, 1144, bridge.GetIMAPPort())
@ -75,7 +75,7 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) {
require.False(t, bridge.GetIMAPSSL())
// Enable IMAP SSL.
require.NoError(t, bridge.SetIMAPSSL(true))
require.NoError(t, bridge.SetIMAPSSL(ctx, true))
// Get the new setting.
require.True(t, bridge.GetIMAPSSL())
@ -89,7 +89,7 @@ func TestBridge_Settings_SMTPPort(t *testing.T) {
curPort := bridge.GetSMTPPort()
// Set the port to 1024.
require.NoError(t, bridge.SetSMTPPort(1024))
require.NoError(t, bridge.SetSMTPPort(ctx, 1024))
// Get the new setting.
require.Equal(t, 1024, bridge.GetSMTPPort())
@ -107,7 +107,7 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) {
require.False(t, bridge.GetSMTPSSL())
// Enable SMTP SSL.
require.NoError(t, bridge.SetSMTPSSL(true))
require.NoError(t, bridge.SetSMTPSSL(ctx, true))
// Get the new setting.
require.True(t, bridge.GetSMTPSSL())

View File

@ -20,93 +20,16 @@ package bridge
import (
"context"
"crypto/tls"
"fmt"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
)
func (bridge *Bridge) serveSMTP() error {
port, err := func() (int, error) {
logrus.WithFields(logrus.Fields{
"port": bridge.vault.GetSMTPPort(),
"ssl": bridge.vault.GetSMTPSSL(),
}).Info("Starting SMTP server")
smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig)
if err != nil {
return 0, fmt.Errorf("failed to create SMTP listener: %w", err)
}
bridge.smtpListener = smtpListener
bridge.tasks.Once(func(context.Context) {
if err := bridge.smtpServer.Serve(smtpListener); err != nil {
logrus.WithError(err).Info("SMTP server stopped")
}
})
if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil {
return 0, fmt.Errorf("failed to store SMTP port in vault: %w", err)
}
return getPort(smtpListener.Addr()), nil
}()
if err != nil {
bridge.publish(events.SMTPServerError{
Error: err,
})
return err
}
bridge.publish(events.SMTPServerReady{
Port: port,
})
return nil
}
func (bridge *Bridge) restartSMTP() error {
logrus.Info("Restarting SMTP server")
if err := bridge.closeSMTP(); err != nil {
return fmt.Errorf("failed to close SMTP: %w", err)
}
bridge.publish(events.SMTPServerStopped{})
bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
return bridge.serveSMTP()
}
// We close the listener ourselves even though it's also closed by smtpServer.Close().
// This is because smtpServer.Serve() is called in a separate goroutine and might be executed
// after we've already closed the server. However, go-smtp has a bug; it blocks on the listener
// even after the server has been closed. So we close the listener ourselves to unblock it.
func (bridge *Bridge) closeSMTP() error {
logrus.Info("Closing SMTP server")
if bridge.smtpListener != nil {
if err := bridge.smtpListener.Close(); err != nil {
return fmt.Errorf("failed to close SMTP listener: %w", err)
}
}
if err := bridge.smtpServer.Close(); err != nil {
logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)")
}
bridge.publish(events.SMTPServerStopped{})
return nil
func (bridge *Bridge) restartSMTP(ctx context.Context) error {
return bridge.serverManager.RestartSMTP(ctx)
}
func newSMTPServer(bridge *Bridge, tlsConfig *tls.Config, logSMTP bool) *smtp.Server {

View File

@ -58,6 +58,11 @@ func (s *smtpSession) AuthPlain(username, password string) error {
return nil
}
logrus.WithFields(logrus.Fields{
"username": username,
"pkg": "smtp",
}).Error("Incorrect login credentials.")
return fmt.Errorf("invalid username or password")
}, s.usersLock)
}
@ -72,7 +77,7 @@ func (s *smtpSession) Logout() error {
return nil
}
func (s *smtpSession) Mail(from string, opts *smtp.MailOptions) error {
func (s *smtpSession) Mail(from string, _ *smtp.MailOptions) error {
s.from = from
return nil
}

View File

@ -80,7 +80,7 @@ func TestBridge_Sync(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
@ -112,15 +112,6 @@ func TestBridge_Sync(t *testing.T) {
info, err := b.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Less(t, status.Messages, uint32(numMsg))
}
// Remove the network limit, allowing the sync to finish.
@ -136,7 +127,7 @@ func TestBridge_Sync(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
@ -187,7 +178,7 @@ func _TestBridge_Sync_BadMessage(t *testing.T) { //nolint:unused,deadcode
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
@ -273,15 +264,6 @@ func TestBridge_SyncWithOngoingEvents(t *testing.T) {
info, err := b.GetUserInfo(userID)
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
status, err := client.Select(`Folders/folder`, false)
require.NoError(t, err)
require.Less(t, status.Messages, uint32(numMsg))
}
// Create a new mailbox and move that last 1/3 of the messages into it to simulate user
@ -311,7 +293,7 @@ func TestBridge_SyncWithOngoingEvents(t *testing.T) {
require.NoError(t, err)
require.True(t, info.State == bridge.Connected)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()

View File

@ -0,0 +1,82 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build !windows
package bridge_test
import (
"context"
"syscall"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/stretchr/testify/require"
)
// Disabled due to flakyness.
func _TestBridge_SyncExistsWithErrorWhenTooManyFilesAreOpen(t *testing.T) { //nolint:unused
var rlimitCurrent syscall.Rlimit
require.NoError(t, syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimitCurrent))
// Restore RLimit for Process at the end of this test
defer func() {
require.NoError(t, syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimitCurrent))
}()
rlimit := syscall.Rlimit{
Max: 100,
Cur: 100,
}
require.NoError(t, syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit))
numMsg := 1 << 8
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
userID, addrID, err := s.CreateUser("imap", password)
require.NoError(t, err)
labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
require.NoError(t, err)
withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) {
createNumMessages(ctx, t, c, addrID, labelID, numMsg)
})
// The initial user should be fully synced.
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
syncCh, done := bridge.GetEvents(events.SyncFailed{})
defer done()
userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
require.NoError(t, err)
evt := <-syncCh
switch e := evt.(type) {
case events.SyncFailed:
require.Equal(t, userID, e.UserID)
default:
require.Fail(t, "Expected events.SyncFailed{}")
}
})
}, server.WithTLS(false))
}

View File

@ -584,29 +584,7 @@ func (bridge *Bridge) newVaultUser(
authUID, authRef string,
saltedKeyPass []byte,
) (*vault.User, bool, error) {
if !bridge.vault.HasUser(apiUser.ID) {
user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass)
if err != nil {
return nil, false, fmt.Errorf("failed to add user to vault: %w", err)
}
return user, true, nil
}
user, err := bridge.vault.NewUser(apiUser.ID)
if err != nil {
return nil, false, err
}
if err := user.SetAuth(authUID, authRef); err != nil {
return nil, false, err
}
if err := user.SetKeyPass(saltedKeyPass); err != nil {
return nil, false, err
}
return user, false, nil
return bridge.vault.GetOrAddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass)
}
// logout logs out the given user, optionally logging them out from the API too.

View File

@ -141,6 +141,9 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex
})
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
userLoginAndSync(ctx, t, bridge, "user", password)
var messageIDs []string
@ -176,6 +179,8 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex
userFeedback(t, ctx, bridge, badUserID)
smtpWaiter.Wait()
userContinueEventProcess(ctx, t, s, bridge)
})
})
@ -194,6 +199,9 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) {
})
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
userLoginAndSync(ctx, t, bridge, "user", password)
var messageIDs []string
@ -217,6 +225,7 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) {
require.NoError(t, c.DeleteMessage(ctx, messageIDs...))
})
smtpWaiter.Wait()
userContinueEventProcess(ctx, t, s, bridge)
})
})
@ -412,6 +421,17 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) {
})
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
var count int32
// The first 10 times bridge attempts to sync any of the messages, drop the connection.
s.AddStatusHook(func(req *http.Request) (int, bool) {
if strings.Contains(req.URL.Path, "/mail/v4/messages") {
if atomic.AddInt32(&count, 1) < 10 {
dropListener.DropAll()
}
}
return 0, false
})
userLoginAndSync(ctx, t, bridge, "user", password)
mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).AnyTimes()
@ -421,30 +441,17 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) {
createNumMessages(ctx, t, c, addrID, proton.InboxLabel, 10)
})
var count int
// The first 10 times bridge attempts to sync any of the messages, drop the connection.
s.AddStatusHook(func(req *http.Request) (int, bool) {
if strings.Contains(req.URL.Path, "/mail/v4/messages") {
if count++; count < 10 {
dropListener.DropAll()
}
}
return 0, false
})
info, err := bridge.QueryUserInfo("user")
require.NoError(t, err)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = cli.Logout() }()
// The IMAP client will eventually see 20 messages.
require.Eventually(t, func() bool {
status, err := client.Status("INBOX", []imap.StatusItem{imap.StatusMessages})
status, err := cli.Status("INBOX", []imap.StatusItem{imap.StatusMessages})
return err == nil && status.Messages == 20
}, 10*time.Second, 100*time.Millisecond)
})
@ -638,12 +645,12 @@ func TestBridge_User_SendDraftRemoveDraftFlag(t *testing.T) {
info, err := bridge.QueryUserInfo("user")
require.NoError(t, err)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = cli.Logout() }()
messages, err := clientFetch(client, "Drafts")
messages, err := clientFetch(cli, "Drafts")
require.NoError(t, err)
require.Len(t, messages, 1)
require.Contains(t, messages[0].Flags, imap.DraftFlag)
@ -677,12 +684,12 @@ func TestBridge_User_SendDraftRemoveDraftFlag(t *testing.T) {
info, err := bridge.QueryUserInfo("user")
require.NoError(t, err)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = cli.Logout() }()
messages, err := clientFetch(client, "Sent")
messages, err := clientFetch(cli, "Sent")
require.NoError(t, err)
require.Len(t, messages, 1)
require.NotContains(t, messages[0].Flags, imap.DraftFlag)
@ -771,15 +778,24 @@ func TestBridge_User_CreateDisabledAddress(t *testing.T) {
func TestBridge_User_HandleParentLabelRename(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
imapWaiter := waitForIMAPServerReady(bridge)
defer imapWaiter.Done()
smtpWaiter := waitForSMTPServerReady(bridge)
defer smtpWaiter.Done()
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
info, err := bridge.QueryUserInfo(username)
require.NoError(t, err)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
imapWaiter.Wait()
smtpWaiter.Wait()
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = cli.Logout() }()
withClient(ctx, t, s, username, password, func(ctx context.Context, c *proton.Client) {
parentName := uuid.NewString()
@ -795,7 +811,7 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
// Wait for the parent folder to be created.
require.Eventually(t, func() bool {
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
return mailbox.Name == fmt.Sprintf("Folders/%v", parentName)
}) >= 0
}, 100*user.EventPeriod, user.EventPeriod)
@ -812,7 +828,7 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
// Wait for the parent folder to be created.
require.Eventually(t, func() bool {
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
return mailbox.Name == fmt.Sprintf("Folders/%v/%v", parentName, childName)
}) >= 0
}, 100*user.EventPeriod, user.EventPeriod)
@ -827,14 +843,14 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
// Wait for the parent folder to be renamed.
require.Eventually(t, func() bool {
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
return mailbox.Name == fmt.Sprintf("Folders/%v", newParentName)
}) >= 0
}, 100*user.EventPeriod, user.EventPeriod)
// Wait for the child folder to be renamed.
require.Eventually(t, func() bool {
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
return mailbox.Name == fmt.Sprintf("Folders/%v/%v", newParentName, childName)
}) >= 0
}, 100*user.EventPeriod, user.EventPeriod)
@ -843,48 +859,6 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
})
}
// TBD: GODT-2527.
func _TestBridge503DuringEventDoesNotCauseBadEvent(t *testing.T) { //nolint:unused,deadcode
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
// Create a user.
userID, addrID, err := s.CreateUser("user", password)
require.NoError(t, err)
labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
require.NoError(t, err)
// Create 10 messages for the user.
withClient(ctx, t, s, "user", password, func(ctx context.Context, c *proton.Client) {
createNumMessages(ctx, t, c, addrID, labelID, 10)
})
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
userLoginAndSync(ctx, t, bridge, "user", password)
var messageIDs []string
// Create 10 more messages for the user, generating events.
withClient(ctx, t, s, "user", password, func(ctx context.Context, c *proton.Client) {
messageIDs = createNumMessages(ctx, t, c, addrID, labelID, 10)
})
mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).MinTimes(1)
s.AddStatusHook(func(req *http.Request) (int, bool) {
if xslices.Index(xslices.Map(messageIDs[0:5], func(messageID string) string {
return "/mail/v4/messages/" + messageID
}), req.URL.Path) < 0 {
return 0, false
}
return http.StatusServiceUnavailable, true
})
userContinueEventProcess(ctx, t, s, bridge)
})
})
}
// userLoginAndSync logs in user and waits until user is fully synced.
func userLoginAndSync(
ctx context.Context,
@ -928,10 +902,10 @@ func userContinueEventProcess(
info, err := bridge.QueryUserInfo("user")
require.NoError(t, err)
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
require.NoError(t, err)
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = client.Logout() }()
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
defer func() { _ = cli.Logout() }()
randomLabel := uuid.NewString()
@ -946,8 +920,21 @@ func userContinueEventProcess(
// Wait for the label to be created.
require.Eventually(t, func() bool {
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
return mailbox.Name == "Labels/"+randomLabel
}) >= 0
}, 100*user.EventPeriod, user.EventPeriod)
}
func eventuallyDial(addr string) (cli *client.Client, err error) {
var sleep = 1 * time.Second
for i := 0; i < 5; i++ {
cli, err := client.Dial(addr)
if err == nil {
return cli, nil
}
time.Sleep(sleep)
sleep *= 2
}
return nil, fmt.Errorf("after 5 attempts, last error: %s", err)
}

View File

@ -75,11 +75,7 @@ func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.U
return nil
}
if bridge.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
gluonID, err := bridge.imapServer.AddUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
gluonID, err := bridge.serverManager.AddGluonUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
@ -96,7 +92,7 @@ func (bridge *Bridge) handleUserAddressEnabled(ctx context.Context, user *user.U
return nil
}
gluonID, err := bridge.imapServer.AddUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
gluonID, err := bridge.serverManager.AddGluonUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
@ -118,7 +114,7 @@ func (bridge *Bridge) handleUserAddressDisabled(ctx context.Context, user *user.
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
}
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
if err := bridge.serverManager.RemoveGluonUser(ctx, gluonID); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
@ -134,16 +130,12 @@ func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.U
return nil
}
if bridge.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
gluonID, ok := user.GetGluonID(event.AddressID)
if !ok {
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
}
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
if err := bridge.serverManager.handleRemoveGluonUser(ctx, gluonID); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}

View File

@ -708,7 +708,26 @@ func TestBridge_User_Refresh(t *testing.T) {
})
}
func TestBridge_User_GetAddresses(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
// Create a user.
userID, _, err := s.CreateUser("user", password)
require.NoError(t, err)
addrID2, err := s.CreateAddress(userID, "user@external.com", []byte("password"))
require.NoError(t, err)
require.NoError(t, s.ChangeAddressType(userID, addrID2, proton.AddressTypeExternal))
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
userLoginAndSync(ctx, t, bridge, "user", password)
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.Equal(t, 1, len(info.Addresses))
require.Equal(t, info.Addresses[0], "user@proton.local")
})
})
}
// getErr returns the error that was passed to it.
func getErr[T any](val T, err error) error {
func getErr[T any](_ T, err error) error {
return err
}

View File

@ -50,7 +50,7 @@ int installTrustedCert(char const *bytes, unsigned long long length) {
(id)kSecTrustSettingsResult: [NSNumber numberWithInt:kSecTrustSettingsResultTrustRoot],
(id)kSecTrustSettingsPolicy: (__bridge id) policy,
};
status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainAdmin, (__bridge CFTypeRef)(trustSettings));
status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainUser, (__bridge CFTypeRef)(trustSettings));
CFRelease(policy);
CFRelease(cert);
@ -72,7 +72,7 @@ int removeTrustedCert(char const *bytes, unsigned long long length) {
(id)kSecTrustSettingsResult: [NSNumber numberWithInt:kSecTrustSettingsResultUnspecified],
(id)kSecTrustSettingsPolicy: (__bridge id) policy,
};
OSStatus status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainAdmin, (__bridge CFTypeRef)(trustSettings));
OSStatus status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainUser, (__bridge CFTypeRef)(trustSettings));
CFRelease(policy);
if (errSecSuccess != status) {
CFRelease(cert);
@ -107,7 +107,6 @@ const (
// certPEMToDER converts a certificate in PEM format to DER format, which is the format required by Apple's Security framework.
func certPEMToDER(certPEM []byte) ([]byte, error) {
block, left := pem.Decode(certPEM)
if block == nil {
return []byte{}, errors.New("invalid PEM certificate")
@ -127,7 +126,7 @@ func installCert(certPEM []byte) error {
}
p := C.CBytes(certDER)
defer C.free(unsafe.Pointer(p))
defer C.free(unsafe.Pointer(p)) //nolint:unconvert
errCode := C.installTrustedCert((*C.char)(p), (C.ulonglong)(len(certDER)))
switch errCode {
@ -147,7 +146,7 @@ func uninstallCert(certPEM []byte) error {
}
p := C.CBytes(certDER)
defer C.free(unsafe.Pointer(p))
defer C.free(unsafe.Pointer(p)) //nolint:unconvert
if errCode := C.removeTrustedCert((*C.char)(p), (C.ulonglong)(len(certDER))); errCode != 0 {
return fmt.Errorf("could not install certificate from keychain (error %v)", errCode)

View File

@ -26,7 +26,7 @@ import (
)
// This test implies human interactions to enter password and is disabled by default.
func _TestTrustedCertsDarwin(t *testing.T) {
func _TestTrustedCertsDarwin(t *testing.T) { //nolint:unused
template, err := NewTLSTemplate()
require.NoError(t, err)

View File

@ -75,7 +75,7 @@ if(NOT UNIX)
set(CMAKE_INSTALL_BINDIR ".")
endif(NOT UNIX)
find_package(Qt6 COMPONENTS Core Quick Qml QuickControls2 Widgets REQUIRED)
find_package(Qt6 COMPONENTS Core Quick Qml QuickControls2 Widgets Svg REQUIRED)
qt_standard_project_setup()
set(CMAKE_AUTORCC ON)
message(STATUS "Using Qt ${Qt6_VERSION}")
@ -147,6 +147,7 @@ target_link_libraries(bridge-gui
Qt6::Quick
Qt6::Qml
Qt6::QuickControls2
Qt6::Svg
sentry::sentry
bridgepp
)

View File

@ -25,6 +25,7 @@
#include <QtQml>
#include <QtWidgets>
#include <QtQuickControls2>
#include <QtSvg>
#include <AppController.h>

View File

@ -994,15 +994,44 @@ void QMLBackend::onUserBadEvent(QString const &userID, QString const &) {
void QMLBackend::onIMAPLoginFailed(QString const &username) {
HANDLE_EXCEPTION(
SPUser const user = users_->getUserWithUsernameOrEmail(username);
if ((!user) || (user->state() != UserState::SignedOut)) { // We want to pop-up only if a signed-out user has been detected
if (!user) {
return;
}
if (user->isInIMAPLoginFailureCooldown()) {
return;
qint64 const cooldownDurationMs = 10 * 60 * 1000; // 10 minutes cooldown period for notifications
switch (user->state()) {
case UserState::SignedOut:
if (user->isNotificationInCooldown(User::ENotification::IMAPLoginWhileSignedOut)) {
return;
}
user->startNotificationCooldownPeriod(User::ENotification::IMAPLoginWhileSignedOut, cooldownDurationMs);
emit selectUser(user->id(), true);
emit imapLoginWhileSignedOut(username);
break;
case UserState::Connected:
if (user->isNotificationInCooldown(User::ENotification::IMAPPasswordFailure)) {
return;
}
user->startNotificationCooldownPeriod(User::ENotification::IMAPPasswordFailure, cooldownDurationMs);
emit selectUser(user->id(), false);
trayIcon_->showErrorPopupNotification(tr("Incorrect password"),
tr("Your email client can't connect to Proton Bridge. Make sure you are using the local Bridge password shown in Bridge."));
break;
case UserState::Locked:
if (user->isNotificationInCooldown(User::ENotification::IMAPLoginWhileLocked)) {
return;
}
user->startNotificationCooldownPeriod(User::ENotification::IMAPLoginWhileLocked, cooldownDurationMs);
emit selectUser(user->id(), false);
trayIcon_->showErrorPopupNotification(tr("Connection in progress"),
tr("Your Proton account in Bridge is being connected. Please wait or restart Bridge."));
break;
default:
break;
}
user->startImapLoginFailureCooldown(60 * 60 * 1000); // 1 hour cooldown during which we will not display this notification to this user again.
emit selectUser(user->id());
emit imapLoginWhileSignedOut(username);
)
}
@ -1134,7 +1163,7 @@ void QMLBackend::displayBadEventDialog(QString const &userID) {
emit userBadEvent(userID,
tr("Bridge ran into an internal error and it is not able to proceed with the account %1. Synchronize your local database now or logout"
" to do it later. Synchronization time depends on the size of your mailbox.").arg(elideLongString(user->primaryEmailOrUsername(), 30)));
emit selectUser(userID);
emit selectUser(userID, true);
emit showMainWindow();
)
}

View File

@ -180,6 +180,8 @@ public slots: // slot for signals received from QML -> To be forwarded to Bridge
void onVersionChanged(); ///< Slot for the version change signal.
void setMailServerSettings(int imapPort, int smtpPort, bool useSSLForIMAP, bool useSSLForSMTP) const; ///< Forwards a connection mode change request from QML to gRPC
void sendBadEventUserFeedback(QString const &userID, bool doResync); ///< Slot the providing user feedback for a bad event.
public slots: // slots for functions that need to be processed locally.
void setNormalTrayIcon(); ///< Set the tray icon to normal.
void setErrorTrayIcon(QString const& stateString, QString const &statusIcon); ///< Set the tray icon to 'error' state.
void setWarnTrayIcon(QString const& stateString, QString const &statusIcon); ///< Set the tray icon to 'warn' state.
@ -245,7 +247,7 @@ signals: // Signals received from the Go backend, to be forwarded to QML
void hideMainWindow(); ///< Signal for the 'hideMainWindow' gRPC stream event.
void showHelp(); ///< Signal for the 'showHelp' event (from the context menu).
void showSettings(); ///< Signal for the 'showHelp' event (from the context menu).
void selectUser(QString const& userID); ///< Signal emitted in order to selected a user with a given ID in the list.
void selectUser(QString const& userID, bool forceShowWindow); ///< Signal emitted in order to selected a user with a given ID in the list.
void genericError(QString const &title, QString const &description); ///< Signal for the 'genericError' gRPC stream event.
void imapLoginWhileSignedOut(QString const& username); ///< Signal for the notification of IMAP login attempt on a signed out account.

View File

@ -49,7 +49,7 @@ QString sentryAttachmentFilePath() {
//****************************************************************************************************************************************************
QByteArray getProtectedHostname() {
QByteArray hostname = QCryptographicHash::hash(QSysInfo::machineHostName().toUtf8(), QCryptographicHash::Sha256);
return hostname.toHex();
return hostname.toBase64();
}
//****************************************************************************************************************************************************

View File

@ -22,6 +22,7 @@
#include <sentry.h>
void initSentry();
QByteArray getProtectedHostname();
void setSentryReportScope();
sentry_options_t* newSentryOptions(const char * sentryDNS, const char * cacheDir);
sentry_uuid_t reportSentryEvent(sentry_level_t level, const char *message);

View File

@ -43,12 +43,75 @@ qint64 const iconRefreshDurationSecs = 10; ///< The total number of seconds duri
QIcon loadIconFromImage(QString const &path) {
QPixmap const pixmap(path);
if (pixmap.isNull()) {
throw Exception(QString("Could create icon from image '%1'.").arg(path));
throw Exception(QString("Could not create an icon from an image '%1'.").arg(path));
}
return QIcon(pixmap);
}
//****************************************************************************************************************************************************
/// \brief Generate an icon from a SVG renderer (a.k.a. path).
///
/// \param[in] renderer The SVG renderer.
/// \param[in] color The color to use in case the SVG path is to be used as a mask.
/// \return The icon.
//****************************************************************************************************************************************************
QIcon loadIconFromSVGRenderer(QSvgRenderer &renderer, QColor const &color = QColor()) {
if (!renderer.isValid()) {
return QIcon();
}
QIcon icon;
qint32 size = 256;
while (size >= 16) {
QPixmap pixmap(size, size);
pixmap.fill(QColor(0, 0, 0, 0));
QPainter painter(&pixmap);
renderer.render(&painter);
if (color.isValid()) {
painter.setCompositionMode(QPainter::CompositionMode_SourceIn);
painter.fillRect(pixmap.rect(), color);
}
painter.end();
icon.addPixmap(pixmap);
size /= 2;
}
return icon;
}
//****************************************************************************************************************************************************
/// \brief Load a multi-resolution icon from a SVG file. The image is assumed to be square. SVG is rasterized in 256, 128, 64, 32 and 16px.
///
/// Note: QPixmap can load SVG files directly, but our SVG file are defined in small shape size and QPixmap will rasterize them a very low resolution
/// by default (eg. 16x16), which is insufficient for some uses. As a consequence, we manually generate a multi-resolution icon that render smoothly
/// at any acceptable resolution for an icon.
///
/// \param[in] path The path of the SVG file.
/// \return The icon.
//****************************************************************************************************************************************************
QIcon loadIconFromSVG(QString const &path, QColor const &color = QColor()) {
QSvgRenderer renderer(path);
QIcon const icon = loadIconFromSVGRenderer(renderer, color);
if (icon.isNull()) {
Exception(QString("Could not create an icon from a vector image '%1'.").arg(path));
}
return icon;
}
//****************************************************************************************************************************************************
//
//****************************************************************************************************************************************************
QIcon loadIcon(QString const &path) {
if (path.endsWith(".svg", Qt::CaseInsensitive)) {
return loadIconFromSVG(path);
}
return loadIconFromImage(path);
}
//****************************************************************************************************************************************************
/// \brief Retrieve the color associated with a tray icon state.
///
@ -95,6 +158,18 @@ QString stateText(TrayIcon::State state) {
}
//****************************************************************************************************************************************************
/// \brief converts a QML resource path to Qt resource path.
/// QML resource paths are a bit different from qt resource paths
/// \param[in] path The resource path.
/// \return
//****************************************************************************************************************************************************
QString qmlResourcePathToQt(QString const &path) {
QString result = path;
result.replace(QRegularExpression(R"(^\.\/)"), ":/qml/");
return result;
}
} // anonymous namespace
@ -103,17 +178,17 @@ QString stateText(TrayIcon::State state) {
//****************************************************************************************************************************************************
TrayIcon::TrayIcon()
: QSystemTrayIcon()
, menu_(new QMenu) {
, menu_(new QMenu)
, notificationErrorIcon_(loadIconFromSVG(":/qml/icons/ic-alert.svg")) {
this->generateDotIcons();
this->setContextMenu(menu_.get());
connect(menu_.get(), &QMenu::aboutToShow, this, &TrayIcon::onMenuAboutToShow);
connect(this, &TrayIcon::selectUser, &app().backend(), &QMLBackend::selectUser);
connect(this, &TrayIcon::activated, this, &TrayIcon::onActivated);
// some OSes/Desktop managers will automatically show main window when clicked, but not all, so we do it manually.
connect(this, &TrayIcon::messageClicked, &app().backend(), &QMLBackend::showMainWindow);
this->show();
this->setState(State::Normal, QString(), QString());
// TrayIcon does not expose its screen, so we connect relevant screen events to our DPI change handler.
for (QScreen *screen: QGuiApplication::screens()) {
@ -151,7 +226,7 @@ void TrayIcon::onUserClicked() {
throw Exception("Could not retrieve context menu's selected user.");
}
emit selectUser(userID);
emit selectUser(userID, true);
} catch (Exception const &e) {
app().log().error(e.qwhat());
}
@ -212,18 +287,17 @@ void TrayIcon::onIconRefreshTimer() {
//
//****************************************************************************************************************************************************
void TrayIcon::generateDotIcons() {
QPixmap dotSVG(":/qml/icons/ic-dot.svg");
QSvgRenderer dotSVG(QString(":/qml/icons/ic-dot.svg"));
struct IconColor {
QIcon &icon;
QColor color;
};
for (auto pair: QList<IconColor> {{ greenDot_, normalColor }, { greyDot_, greyColor }, { orangeDot_, warnColor }}) {
QPixmap p = dotSVG;
QPainter painter(&p);
painter.setCompositionMode(QPainter::CompositionMode_SourceIn);
painter.fillRect(p.rect(), pair.color);
painter.end();
pair.icon = QIcon(p);
pair.icon = loadIconFromSVGRenderer(dotSVG, pair.color);
if (pair.icon.isNull()) {
throw Exception("Could not generate dot icon from vector file.");
}
}
}
@ -242,26 +316,28 @@ void TrayIcon::setState(TrayIcon::State state, QString const &stateString, QStri
}
//****************************************************************************************************************************************************
/// \param[in] title The title.
/// \param[in] message The message.
//****************************************************************************************************************************************************
void TrayIcon::showErrorPopupNotification(QString const &title, QString const &message) {
this->showMessage(title, message, notificationErrorIcon_);
}
//****************************************************************************************************************************************************
/// \param[in] svgPath The path of the SVG file for the icon.
/// \param[in] color The color to apply to the icon.
//****************************************************************************************************************************************************
void TrayIcon::generateStatusIcon(QString const &svgPath, QColor const &color) {
// We use the SVG path as pixmap mask and fill it with the appropriate color
QString resourcePath = svgPath;
resourcePath.replace(QRegularExpression(R"(^\.\/)"), ":/qml/"); // QML resource path are a bit different from the Qt resources path.
QPixmap pixmap(resourcePath);
QPainter painter(&pixmap);
painter.setCompositionMode(QPainter::CompositionMode_SourceIn);
painter.fillRect(pixmap.rect(), color);
painter.end();
statusIcon_ = QIcon(pixmap);
statusIcon_ = loadIconFromSVG(qmlResourcePathToQt(svgPath), color);
}
//**********************************************************************************************************************
//****************************************************************************************************************************************************
//
//**********************************************************************************************************************
//****************************************************************************************************************************************************
void TrayIcon::refreshContextMenu() {
if (!menu_) {
app().log().error("Native tray icon context menu is null.");
@ -297,3 +373,5 @@ void TrayIcon::refreshContextMenu() {
menu_->addSeparator();
menu_->addAction(tr("&Quit Bridge"), onMac ? QKeySequence("Ctrl+Q") : noShortcut, &app().backend(), &QMLBackend::quit);
}

View File

@ -41,10 +41,10 @@ public: // data members
TrayIcon& operator=(TrayIcon const&) = delete; ///< Disabled assignment operator.
TrayIcon& operator=(TrayIcon&&) = delete; ///< Disabled move assignment operator.
void setState(State state, QString const& stateString, QString const &statusIconPath); ///< Set the state of the icon
void showNotificationPopup(QString const& title, QString const &message, QString const& iconPath); ///< Display a pop up notification.
void showErrorPopupNotification(QString const& title, QString const &message); ///< Display a pop up notification.
signals:
void selectUser(QString const& userID); ///< Signal for selecting a user with a given userID
void selectUser(QString const& userID, bool forceShowWindow); ///< Signal for selecting a user with a given userID
private slots:
void onMenuAboutToShow(); ///< Slot called before the context menu is shown.
@ -67,6 +67,7 @@ private: // data members
QIcon greenDot_; ///< The green dot icon.
QIcon greyDot_; ///< The grey dot icon.
QIcon orangeDot_; ///< The orange dot icon.
QIcon const notificationErrorIcon_; ///< The error icon used for notifications.
QTimer iconRefreshTimer_; ///< The timer used to periodically refresh the icon when DPI changes.
QDateTime iconRefreshDeadline_; ///< The deadline for refreshing the icon

View File

@ -305,6 +305,8 @@ int main(int argc, char *argv[]) {
// When not in attached mode, log entries are forwarded to bridge, which output it on stdout/stderr. bridge-gui's process monitor intercept
// these outputs and output them on the command-line.
log.setLevel(cliOptions.logLevel);
log.info(QString("New Sentry reporter - id: %1.").arg(getProtectedHostname()));
QString bridgeexec;
if (!cliOptions.attach) {
if (isBridgeRunning()) {

View File

@ -95,9 +95,11 @@ ApplicationWindow {
root.showAndRise()
}
function onSelectUser(userID) {
function onSelectUser(userID, forceShowWindow) {
contentWrapper.selectUser(userID)
root.showAndRise()
if (forceShowWindow) {
root.showAndRise()
}
}
}

View File

@ -535,11 +535,12 @@ QtObject {
}
property Notification onlyPaidUsers: Notification {
description: qsTr("Bridge is exclusive to our paid plans. Upgrade your account to use Bridge.")
description: qsTr("Bridge is exclusive to our mail paid plans. Upgrade your account to use Bridge.")
brief: qsTr("Upgrade your account")
icon: "./icons/ic-exclamation-circle-filled.svg"
type: Notification.NotificationType.Danger
group: Notifications.Group.Configuration
property var pricingLink: "https://proton.me/mail/pricing"
Connections {
target: Backend
@ -550,8 +551,9 @@ QtObject {
action: [
Action {
text: qsTr("OK")
text: qsTr("Upgrade")
onTriggered: {
Qt.openUrlExternally(root.onlyPaidUsers.pricingLink)
root.onlyPaidUsers.active = false
}
}

View File

@ -344,7 +344,12 @@ FocusScope {
if (str.length === 0) {
return qsTr("Enter the 6-digit code")
}
return
}
onTextChanged: {
if (text.length >= 6) {
twoFAButton.onClicked()
}
}
onAccepted: {

View File

@ -6,19 +6,19 @@
#include "focus.grpc.pb.h"
#include <functional>
#include <grpcpp/impl/codegen/async_stream.h>
#include <grpcpp/impl/codegen/async_unary_call.h>
#include <grpcpp/impl/codegen/channel_interface.h>
#include <grpcpp/impl/codegen/client_unary_call.h>
#include <grpcpp/impl/codegen/client_callback.h>
#include <grpcpp/impl/codegen/message_allocator.h>
#include <grpcpp/impl/codegen/method_handler.h>
#include <grpcpp/impl/codegen/rpc_service_method.h>
#include <grpcpp/impl/codegen/server_callback.h>
#include <grpcpp/support/async_stream.h>
#include <grpcpp/support/async_unary_call.h>
#include <grpcpp/impl/channel_interface.h>
#include <grpcpp/impl/client_unary_call.h>
#include <grpcpp/support/client_callback.h>
#include <grpcpp/support/message_allocator.h>
#include <grpcpp/support/method_handler.h>
#include <grpcpp/impl/rpc_service_method.h>
#include <grpcpp/support/server_callback.h>
#include <grpcpp/impl/codegen/server_callback_handlers.h>
#include <grpcpp/impl/codegen/server_context.h>
#include <grpcpp/impl/codegen/service_type.h>
#include <grpcpp/impl/codegen/sync_stream.h>
#include <grpcpp/server_context.h>
#include <grpcpp/impl/service_type.h>
#include <grpcpp/support/sync_stream.h>
namespace focus {
static const char* Focus_method_names[] = {

View File

@ -25,23 +25,23 @@
#include "focus.pb.h"
#include <functional>
#include <grpcpp/impl/codegen/async_generic_service.h>
#include <grpcpp/impl/codegen/async_stream.h>
#include <grpcpp/impl/codegen/async_unary_call.h>
#include <grpcpp/impl/codegen/client_callback.h>
#include <grpcpp/impl/codegen/client_context.h>
#include <grpcpp/impl/codegen/completion_queue.h>
#include <grpcpp/impl/codegen/message_allocator.h>
#include <grpcpp/impl/codegen/method_handler.h>
#include <grpcpp/generic/async_generic_service.h>
#include <grpcpp/support/async_stream.h>
#include <grpcpp/support/async_unary_call.h>
#include <grpcpp/support/client_callback.h>
#include <grpcpp/client_context.h>
#include <grpcpp/completion_queue.h>
#include <grpcpp/support/message_allocator.h>
#include <grpcpp/support/method_handler.h>
#include <grpcpp/impl/codegen/proto_utils.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/server_callback.h>
#include <grpcpp/impl/rpc_method.h>
#include <grpcpp/support/server_callback.h>
#include <grpcpp/impl/codegen/server_callback_handlers.h>
#include <grpcpp/impl/codegen/server_context.h>
#include <grpcpp/impl/codegen/service_type.h>
#include <grpcpp/server_context.h>
#include <grpcpp/impl/service_type.h>
#include <grpcpp/impl/codegen/status.h>
#include <grpcpp/impl/codegen/stub_options.h>
#include <grpcpp/impl/codegen/sync_stream.h>
#include <grpcpp/support/stub_options.h>
#include <grpcpp/support/sync_stream.h>
namespace focus {

View File

@ -13,7 +13,7 @@
#error incompatible with your Protocol Buffer headers. Please update
#error your headers.
#endif
#if 3021003 < PROTOBUF_MIN_PROTOC_VERSION
#if 3021012 < PROTOBUF_MIN_PROTOC_VERSION
#error This file was generated by an older version of protoc which is
#error incompatible with your Protocol Buffer headers. Please
#error regenerate this file with a newer version of protoc.

View File

@ -6,19 +6,19 @@
#include "bridge.grpc.pb.h"
#include <functional>
#include <grpcpp/impl/codegen/async_stream.h>
#include <grpcpp/impl/codegen/async_unary_call.h>
#include <grpcpp/impl/codegen/channel_interface.h>
#include <grpcpp/impl/codegen/client_unary_call.h>
#include <grpcpp/impl/codegen/client_callback.h>
#include <grpcpp/impl/codegen/message_allocator.h>
#include <grpcpp/impl/codegen/method_handler.h>
#include <grpcpp/impl/codegen/rpc_service_method.h>
#include <grpcpp/impl/codegen/server_callback.h>
#include <grpcpp/support/async_stream.h>
#include <grpcpp/support/async_unary_call.h>
#include <grpcpp/impl/channel_interface.h>
#include <grpcpp/impl/client_unary_call.h>
#include <grpcpp/support/client_callback.h>
#include <grpcpp/support/message_allocator.h>
#include <grpcpp/support/method_handler.h>
#include <grpcpp/impl/rpc_service_method.h>
#include <grpcpp/support/server_callback.h>
#include <grpcpp/impl/codegen/server_callback_handlers.h>
#include <grpcpp/impl/codegen/server_context.h>
#include <grpcpp/impl/codegen/service_type.h>
#include <grpcpp/impl/codegen/sync_stream.h>
#include <grpcpp/server_context.h>
#include <grpcpp/impl/service_type.h>
#include <grpcpp/support/sync_stream.h>
namespace grpc {
static const char* Bridge_method_names[] = {

View File

@ -25,23 +25,23 @@
#include "bridge.pb.h"
#include <functional>
#include <grpcpp/impl/codegen/async_generic_service.h>
#include <grpcpp/impl/codegen/async_stream.h>
#include <grpcpp/impl/codegen/async_unary_call.h>
#include <grpcpp/impl/codegen/client_callback.h>
#include <grpcpp/impl/codegen/client_context.h>
#include <grpcpp/impl/codegen/completion_queue.h>
#include <grpcpp/impl/codegen/message_allocator.h>
#include <grpcpp/impl/codegen/method_handler.h>
#include <grpcpp/generic/async_generic_service.h>
#include <grpcpp/support/async_stream.h>
#include <grpcpp/support/async_unary_call.h>
#include <grpcpp/support/client_callback.h>
#include <grpcpp/client_context.h>
#include <grpcpp/completion_queue.h>
#include <grpcpp/support/message_allocator.h>
#include <grpcpp/support/method_handler.h>
#include <grpcpp/impl/codegen/proto_utils.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/server_callback.h>
#include <grpcpp/impl/rpc_method.h>
#include <grpcpp/support/server_callback.h>
#include <grpcpp/impl/codegen/server_callback_handlers.h>
#include <grpcpp/impl/codegen/server_context.h>
#include <grpcpp/impl/codegen/service_type.h>
#include <grpcpp/server_context.h>
#include <grpcpp/impl/service_type.h>
#include <grpcpp/impl/codegen/status.h>
#include <grpcpp/impl/codegen/stub_options.h>
#include <grpcpp/impl/codegen/sync_stream.h>
#include <grpcpp/support/stub_options.h>
#include <grpcpp/support/sync_stream.h>
namespace grpc {

View File

@ -13,7 +13,7 @@
#error incompatible with your Protocol Buffer headers. Please update
#error your headers.
#endif
#if 3021003 < PROTOBUF_MIN_PROTOC_VERSION
#if 3021012 < PROTOBUF_MIN_PROTOC_VERSION
#error This file was generated by an older version of protoc which is
#error incompatible with your Protocol Buffer headers. Please
#error regenerate this file with a newer version of protoc.

View File

@ -34,9 +34,7 @@ SPUser User::newUser(QObject *parent) {
/// \param[in] parent The parent object.
//****************************************************************************************************************************************************
User::User(QObject *parent)
: QObject(parent)
, imapFailureCooldownEndTime_(QDateTime::currentDateTime()) {
: QObject(parent) {
}
@ -355,22 +353,18 @@ QString User::stateToString(UserState state) {
//****************************************************************************************************************************************************
/// We display a notification and pop the application window if an IMAP client tries to connect to a signed out account, but we do not want to
/// do it repeatedly, as it's an intrusive action. This function let's you define a period of time during which the notification should not be
/// displayed.
///
/// \param durationMSecs The duration of the period in milliseconds.
/// \param[in] durationMSecs The duration of the period in milliseconds.
//****************************************************************************************************************************************************
void User::startImapLoginFailureCooldown(qint64 durationMSecs) {
imapFailureCooldownEndTime_ = QDateTime::currentDateTime().addMSecs(durationMSecs);
void User::startNotificationCooldownPeriod(User::ENotification notification, qint64 durationMSecs) {
notificationCooldownList_[notification] = QDateTime::currentDateTime().addMSecs(durationMSecs);
}
//****************************************************************************************************************************************************
/// \return true if we currently are in a cooldown period for the notification
/// \return true iff the notification is currently in a cooldown period.
//****************************************************************************************************************************************************
bool User::isInIMAPLoginFailureCooldown() const {
return QDateTime::currentDateTime() < imapFailureCooldownEndTime_;
bool User::isNotificationInCooldown(User::ENotification notification) const {
return notificationCooldownList_.contains(notification) && (QDateTime::currentDateTime() < notificationCooldownList_[notification]);
}

View File

@ -62,6 +62,13 @@ typedef std::shared_ptr<class User> SPUser; ///< Type definition for shared poin
class User : public QObject {
Q_OBJECT
public: // data types
enum class ENotification {
IMAPLoginWhileSignedOut, ///< An IMAP client tried to login while the user is signed out.
IMAPPasswordFailure, ///< An IMAP client provided an invalid password for the user.
IMAPLoginWhileLocked, ///< An IMAP client tried to connect while the user is locked.
};
public: // static member function
static SPUser newUser(QObject *parent); ///< Create a new user
static QString stateToString(UserState state); ///< Return a string describing a user state.
@ -74,8 +81,8 @@ public: // member functions.
User &operator=(User &&) = delete; ///< Disabled move assignment operator.
void update(User const &user); ///< Update the user.
Q_INVOKABLE QString primaryEmailOrUsername() const; ///< Return the user primary email, or, if unknown its username.
void startImapLoginFailureCooldown(qint64 durationMSecs); ///< Start the user cooldown period for the IMAP login attempt while signed-out notification.
bool isInIMAPLoginFailureCooldown() const; ///< Check if the user in a IMAP login failure notification.
void startNotificationCooldownPeriod(ENotification notification, qint64 durationMSecs); ///< Start the user cooldown period for a notification.
bool isNotificationInCooldown(ENotification notification) const; ///< Return true iff the notification is in a cooldown period.
public slots:
// slots for QML generated calls
@ -147,7 +154,7 @@ private: // member functions.
User(QObject *parent); ///< Default constructor.
private: // data members.
QDateTime imapFailureCooldownEndTime_; ///< The end date/time for the IMAP login failure notification cooldown period.
QMap<ENotification, QDateTime> notificationCooldownList_; ///< A list of cooldown period end time for notifications.
QString id_; ///< The userID.
QString username_; ///< The username
QString password_; ///< The IMAP password of the user.

View File

@ -297,7 +297,7 @@ func (f *frontendCLI) configureAppleMail(c *ishell.Context) {
return
}
if err := f.bridge.ConfigureAppleMail(user.UserID, user.Addresses[0]); err != nil {
if err := f.bridge.ConfigureAppleMail(context.Background(), user.UserID, user.Addresses[0]); err != nil {
f.printAndLogError(err)
return
}
@ -305,11 +305,11 @@ func (f *frontendCLI) configureAppleMail(c *ishell.Context) {
f.Printf("Apple Mail configured for %v with address %v\n", user.Username, user.Addresses[0])
}
func (f *frontendCLI) badEventSynchronize(c *ishell.Context) {
func (f *frontendCLI) badEventSynchronize(_ *ishell.Context) {
f.badEventFeedback(true)
}
func (f *frontendCLI) badEventLogout(c *ishell.Context) {
func (f *frontendCLI) badEventLogout(_ *ishell.Context) {
f.badEventFeedback(false)
}

View File

@ -20,6 +20,7 @@ package cli
import (
"errors"
"os"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
@ -60,6 +61,11 @@ func New(
panicHandler: panicHandler,
}
// We want to exit at the first Ctrl+C. By default, ishell requires two.
fe.Interrupt(func(_ *ishell.Context, _ int, _ string) {
os.Exit(1)
})
// Clear commands.
clearCmd := &ishell.Cmd{
Name: "clear",

View File

@ -31,7 +31,7 @@ import (
"github.com/abiosoft/ishell"
)
func (f *frontendCLI) printLogDir(c *ishell.Context) {
func (f *frontendCLI) printLogDir(_ *ishell.Context) {
if path, err := f.bridge.GetLogsPath(); err != nil {
f.Println("Failed to determine location of log files")
} else {
@ -39,17 +39,17 @@ func (f *frontendCLI) printLogDir(c *ishell.Context) {
}
}
func (f *frontendCLI) printManual(c *ishell.Context) {
func (f *frontendCLI) printManual(_ *ishell.Context) {
f.Println("More instructions about the Bridge can be found at\n\n https://proton.me/mail/bridge")
}
func (f *frontendCLI) printCredits(c *ishell.Context) {
func (f *frontendCLI) printCredits(_ *ishell.Context) {
for _, pkg := range strings.Split(bridge.Credits, ";") {
f.Println(pkg)
}
}
func (f *frontendCLI) changeIMAPSecurity(c *ishell.Context) {
func (f *frontendCLI) changeIMAPSecurity(_ *ishell.Context) {
f.ShowPrompt(false)
defer f.ShowPrompt(true)
@ -61,14 +61,14 @@ func (f *frontendCLI) changeIMAPSecurity(c *ishell.Context) {
msg := fmt.Sprintf("Are you sure you want to change IMAP setting to %q", newSecurity)
if f.yesNoQuestion(msg) {
if err := f.bridge.SetIMAPSSL(!f.bridge.GetIMAPSSL()); err != nil {
if err := f.bridge.SetIMAPSSL(context.Background(), !f.bridge.GetIMAPSSL()); err != nil {
f.printAndLogError(err)
return
}
}
}
func (f *frontendCLI) changeSMTPSecurity(c *ishell.Context) {
func (f *frontendCLI) changeSMTPSecurity(_ *ishell.Context) {
f.ShowPrompt(false)
defer f.ShowPrompt(true)
@ -80,7 +80,7 @@ func (f *frontendCLI) changeSMTPSecurity(c *ishell.Context) {
msg := fmt.Sprintf("Are you sure you want to change SMTP setting to %q", newSecurity)
if f.yesNoQuestion(msg) {
if err := f.bridge.SetSMTPSSL(!f.bridge.GetSMTPSSL()); err != nil {
if err := f.bridge.SetSMTPSSL(context.Background(), !f.bridge.GetSMTPSSL()); err != nil {
f.printAndLogError(err)
return
}
@ -103,7 +103,7 @@ func (f *frontendCLI) changeIMAPPort(c *ishell.Context) {
return
}
if err := f.bridge.SetIMAPPort(newIMAPPortInt); err != nil {
if err := f.bridge.SetIMAPPort(context.Background(), newIMAPPortInt); err != nil {
f.printAndLogError(err)
return
}
@ -125,13 +125,13 @@ func (f *frontendCLI) changeSMTPPort(c *ishell.Context) {
return
}
if err := f.bridge.SetSMTPPort(newSMTPPortInt); err != nil {
if err := f.bridge.SetSMTPPort(context.Background(), newSMTPPortInt); err != nil {
f.printAndLogError(err)
return
}
}
func (f *frontendCLI) allowProxy(c *ishell.Context) {
func (f *frontendCLI) allowProxy(_ *ishell.Context) {
if f.bridge.GetProxyAllowed() {
f.Println("Bridge is already set to use alternative routing to connect to Proton if it is being blocked.")
return
@ -147,7 +147,7 @@ func (f *frontendCLI) allowProxy(c *ishell.Context) {
}
}
func (f *frontendCLI) disallowProxy(c *ishell.Context) {
func (f *frontendCLI) disallowProxy(_ *ishell.Context) {
if !f.bridge.GetProxyAllowed() {
f.Println("Bridge is already set to NOT use alternative routing to connect to Proton if it is being blocked.")
return
@ -163,7 +163,7 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
}
}
func (f *frontendCLI) hideAllMail(c *ishell.Context) {
func (f *frontendCLI) hideAllMail(_ *ishell.Context) {
if !f.bridge.GetShowAllMail() {
f.Println("All Mail folder is not listed in your local client.")
return
@ -179,7 +179,7 @@ func (f *frontendCLI) hideAllMail(c *ishell.Context) {
}
}
func (f *frontendCLI) showAllMail(c *ishell.Context) {
func (f *frontendCLI) showAllMail(_ *ishell.Context) {
if f.bridge.GetShowAllMail() {
f.Println("All Mail folder is listed in your local client.")
return

View File

@ -23,7 +23,7 @@ import (
"github.com/abiosoft/ishell"
)
func (f *frontendCLI) checkUpdates(c *ishell.Context) {
func (f *frontendCLI) checkUpdates(_ *ishell.Context) {
updateCh, done := f.bridge.GetEvents(events.UpdateAvailable{}, events.UpdateNotAvailable{})
defer done()
@ -38,7 +38,7 @@ func (f *frontendCLI) checkUpdates(c *ishell.Context) {
}
}
func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
func (f *frontendCLI) enableAutoUpdates(_ *ishell.Context) {
if f.bridge.GetAutoUpdate() {
f.Println("Bridge is already set to automatically install updates.")
return
@ -54,7 +54,7 @@ func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
}
}
func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
func (f *frontendCLI) disableAutoUpdates(_ *ishell.Context) {
if !f.bridge.GetAutoUpdate() {
f.Println("Bridge is already set to NOT automatically install updates.")
return
@ -70,7 +70,7 @@ func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
}
}
func (f *frontendCLI) selectEarlyChannel(c *ishell.Context) {
func (f *frontendCLI) selectEarlyChannel(_ *ishell.Context) {
if f.bridge.GetUpdateChannel() == updater.EarlyChannel {
f.Println("Bridge is already on the early-access update channel.")
return
@ -86,7 +86,7 @@ func (f *frontendCLI) selectEarlyChannel(c *ishell.Context) {
}
}
func (f *frontendCLI) selectStableChannel(c *ishell.Context) {
func (f *frontendCLI) selectStableChannel(_ *ishell.Context) {
if f.bridge.GetUpdateChannel() == updater.StableChannel {
f.Println("Bridge is already on the stable update channel.")
return

View File

@ -47,7 +47,7 @@ import (
)
// CheckTokens implements the CheckToken gRPC service call.
func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
func (s *Service) CheckTokens(_ context.Context, clientConfigPath *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
s.log.Debug("CheckTokens")
path := clientConfigPath.Value
@ -65,7 +65,7 @@ func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.
return &wrapperspb.StringValue{Value: clientConfig.Token}, nil
}
func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest) (*emptypb.Empty, error) {
func (s *Service) AddLogEntry(_ context.Context, request *AddLogEntryRequest) (*emptypb.Empty, error) {
entry := s.log
if len(request.Package) > 0 {
@ -93,7 +93,7 @@ func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest)
}
// GuiReady implement the GuiReady gRPC service call.
func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*GuiReadyResponse, error) {
func (s *Service) GuiReady(_ context.Context, _ *emptypb.Empty) (*GuiReadyResponse, error) {
s.log.Debug("GuiReady")
s.initializationDone.Do(s.initializing.Done)
@ -107,7 +107,7 @@ func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*GuiReadyResp
}
// Quit implement the Quit gRPC service call.
func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) {
func (s *Service) Quit(_ context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
s.log.Debug("Quit")
return &emptypb.Empty{}, s.quit()
}
@ -143,13 +143,13 @@ func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.E
return s.Quit(ctx, empty)
}
func (s *Service) ShowOnStartup(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
func (s *Service) ShowOnStartup(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("ShowOnStartup")
return wrapperspb.Bool(s.showOnStartup), nil
}
func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
func (s *Service) SetIsAutostartOn(_ context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("show", isOn.Value).Debug("SetIsAutostartOn")
defer func() { _ = s.SendEvent(NewToggleAutostartFinishedEvent()) }()
@ -169,13 +169,13 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal
return &emptypb.Empty{}, nil
}
func (s *Service) IsAutostartOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
func (s *Service) IsAutostartOn(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsAutostartOn")
return wrapperspb.Bool(s.bridge.GetAutostart()), nil
}
func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
func (s *Service) SetIsBetaEnabled(_ context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsBetaEnabled")
channel := updater.StableChannel
@ -191,13 +191,13 @@ func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.Bo
return &emptypb.Empty{}, nil
}
func (s *Service) IsBetaEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
func (s *Service) IsBetaEnabled(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsBetaEnabled")
return wrapperspb.Bool(s.bridge.GetUpdateChannel() == updater.EarlyChannel), nil
}
func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) {
func (s *Service) SetIsAllMailVisible(_ context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("isVisible", isVisible.Value).Debug("SetIsAllMailVisible")
if err := s.bridge.SetShowAllMail(isVisible.Value); err != nil {
@ -208,7 +208,7 @@ func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb
return &emptypb.Empty{}, nil
}
func (s *Service) IsAllMailVisible(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
func (s *Service) IsAllMailVisible(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsAllMailVisible")
return wrapperspb.Bool(s.bridge.GetShowAllMail()), nil
@ -231,13 +231,13 @@ func (s *Service) IsTelemetryDisabled(_ context.Context, _ *emptypb.Empty) (*wra
return wrapperspb.Bool(s.bridge.GetTelemetryDisabled()), nil
}
func (s *Service) GoOs(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) GoOs(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("GoOs") // TO-DO We can probably get rid of this and use QSysInfo::product name
return wrapperspb.String(runtime.GOOS), nil
}
func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
func (s *Service) TriggerReset(_ context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
s.log.Debug("TriggerReset")
go func() {
@ -248,13 +248,13 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.
return &emptypb.Empty{}, nil
}
func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) Version(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("Version")
return wrapperspb.String(s.bridge.GetCurrentVersion().Original()), nil
}
func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) LogsPath(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("LogsPath")
path, err := s.bridge.GetLogsPath()
@ -265,7 +265,7 @@ func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.S
return wrapperspb.String(path), nil
}
func (s *Service) LicensePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) LicensePath(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("LicensePath")
return wrapperspb.String(s.bridge.GetLicenseFilePath()), nil
@ -275,7 +275,7 @@ func (s *Service) DependencyLicensesLink(_ context.Context, _ *emptypb.Empty) (*
return wrapperspb.String(s.bridge.GetDependencyLicensesLink()), nil
}
func (s *Service) ReleaseNotesPageLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) ReleaseNotesPageLink(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.latestLock.RLock()
defer s.latestLock.RUnlock()
@ -289,7 +289,7 @@ func (s *Service) LandingPageLink(_ context.Context, _ *emptypb.Empty) (*wrapper
return wrapperspb.String(s.latest.LandingPage), nil
}
func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.StringValue) (*emptypb.Empty, error) {
func (s *Service) SetColorSchemeName(_ context.Context, name *wrapperspb.StringValue) (*emptypb.Empty, error) {
s.log.WithField("ColorSchemeName", name.Value).Debug("SetColorSchemeName")
if !theme.IsAvailable(theme.Theme(name.Value)) {
@ -305,7 +305,7 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin
return &emptypb.Empty{}, nil
}
func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) ColorSchemeName(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("ColorSchemeName")
current := s.bridge.GetColorScheme()
@ -320,13 +320,13 @@ func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapp
return wrapperspb.String(current), nil
}
func (s *Service) CurrentEmailClient(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) CurrentEmailClient(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("CurrentEmailClient")
return wrapperspb.String(s.bridge.GetCurrentUserAgent()), nil
}
func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emptypb.Empty, error) {
func (s *Service) ReportBug(_ context.Context, report *ReportBugRequest) (*emptypb.Empty, error) {
s.log.WithFields(logrus.Fields{
"osType": report.OsType,
"osVersion": report.OsVersion,
@ -382,7 +382,7 @@ func (s *Service) ExportTLSCertificates(_ context.Context, folderPath *wrappersp
return &emptypb.Empty{}, nil
}
func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) {
func (s *Service) ForceLauncher(_ context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) {
s.log.WithField("launcher", launcher.Value).Debug("ForceLauncher")
s.restarter.Override(launcher.Value)
@ -390,7 +390,7 @@ func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.String
return &emptypb.Empty{}, nil
}
func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringValue) (*emptypb.Empty, error) {
func (s *Service) SetMainExecutable(_ context.Context, exe *wrapperspb.StringValue) (*emptypb.Empty, error) {
s.log.WithField("executable", exe.Value).Debug("SetMainExecutable")
s.restarter.AddFlags("--wait", exe.Value)
@ -398,7 +398,7 @@ func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringV
return &emptypb.Empty{}, nil
}
func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
func (s *Service) Login(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
s.log.WithField("username", login.Username).Debug("Login")
go func() {
@ -454,7 +454,7 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt
return &emptypb.Empty{}, nil
}
func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
func (s *Service) Login2FA(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
s.log.WithField("username", login.Username).Debug("Login2FA")
go func() {
@ -499,7 +499,7 @@ func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.E
return &emptypb.Empty{}, nil
}
func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
func (s *Service) Login2Passwords(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
s.log.WithField("username", login.Username).Debug("Login2Passwords")
go func() {
@ -521,7 +521,7 @@ func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*em
return &emptypb.Empty{}, nil
}
func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
func (s *Service) LoginAbort(_ context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
s.log.WithField("username", loginAbort.Username).Debug("LoginAbort")
go func() {
@ -565,7 +565,7 @@ func (s *Service) CheckUpdate(context.Context, *emptypb.Empty) (*emptypb.Empty,
return &emptypb.Empty{}, nil
}
func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
func (s *Service) InstallUpdate(_ context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
s.log.Debug("InstallUpdate")
go func() {
@ -579,7 +579,7 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb
return &emptypb.Empty{}, nil
}
func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
func (s *Service) SetIsAutomaticUpdateOn(_ context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("isOn", isOn.Value).Debug("SetIsAutomaticUpdateOn")
if currentlyOn := s.bridge.GetAutoUpdate(); currentlyOn == isOn.Value {
@ -594,19 +594,19 @@ func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.B
return &emptypb.Empty{}, nil
}
func (s *Service) IsAutomaticUpdateOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
func (s *Service) IsAutomaticUpdateOn(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsAutomaticUpdateOn")
return wrapperspb.Bool(s.bridge.GetAutoUpdate()), nil
}
func (s *Service) DiskCachePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) DiskCachePath(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("DiskCachePath")
return wrapperspb.String(s.bridge.GetGluonCacheDir()), nil
}
func (s *Service) SetDiskCachePath(ctx context.Context, newPath *wrapperspb.StringValue) (*emptypb.Empty, error) {
func (s *Service) SetDiskCachePath(_ context.Context, newPath *wrapperspb.StringValue) (*emptypb.Empty, error) {
s.log.WithField("path", newPath.Value).Debug("setDiskCachePath")
go func() {
@ -637,7 +637,7 @@ func (s *Service) SetDiskCachePath(ctx context.Context, newPath *wrapperspb.Stri
return &emptypb.Empty{}, nil
}
func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
func (s *Service) SetIsDoHEnabled(_ context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsDohEnabled")
if err := s.bridge.SetProxyAllowed(isEnabled.Value); err != nil {
@ -648,7 +648,7 @@ func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.Boo
return &emptypb.Empty{}, nil
}
func (s *Service) IsDoHEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
func (s *Service) IsDoHEnabled(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsDohEnabled")
return wrapperspb.Bool(s.bridge.GetProxyAllowed()), nil
@ -668,7 +668,7 @@ func (s *Service) MailServerSettings(_ context.Context, _ *emptypb.Empty) (*Imap
}, nil
}
func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSettings) (*emptypb.Empty, error) {
func (s *Service) SetMailServerSettings(ctx context.Context, settings *ImapSmtpSettings) (*emptypb.Empty, error) {
s.log.
WithField("ImapPort", settings.ImapPort).
WithField("SmtpPort", settings.SmtpPort).
@ -682,28 +682,28 @@ func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSet
defer func() { _ = s.SendEvent(NewChangeMailServerSettingsFinishedEvent()) }()
if s.bridge.GetIMAPSSL() != settings.UseSSLForImap {
if err := s.bridge.SetIMAPSSL(settings.UseSSLForImap); err != nil {
if err := s.bridge.SetIMAPSSL(ctx, settings.UseSSLForImap); err != nil {
s.log.WithError(err).Error("Failed to set IMAP SSL")
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_IMAP_CONNECTION_MODE_CHANGE_ERROR))
}
}
if s.bridge.GetSMTPSSL() != settings.UseSSLForSmtp {
if err := s.bridge.SetSMTPSSL(settings.UseSSLForSmtp); err != nil {
if err := s.bridge.SetSMTPSSL(ctx, settings.UseSSLForSmtp); err != nil {
s.log.WithError(err).Error("Failed to set SMTP SSL")
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_SMTP_CONNECTION_MODE_CHANGE_ERROR))
}
}
if s.bridge.GetIMAPPort() != int(settings.ImapPort) {
if err := s.bridge.SetIMAPPort(int(settings.ImapPort)); err != nil {
if err := s.bridge.SetIMAPPort(ctx, int(settings.ImapPort)); err != nil {
s.log.WithError(err).Error("Failed to set IMAP port")
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_IMAP_PORT_CHANGE_ERROR))
}
}
if s.bridge.GetSMTPPort() != int(settings.SmtpPort) {
if err := s.bridge.SetSMTPPort(int(settings.SmtpPort)); err != nil {
if err := s.bridge.SetSMTPPort(ctx, int(settings.SmtpPort)); err != nil {
s.log.WithError(err).Error("Failed to set SMTP port")
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_SMTP_PORT_CHANGE_ERROR))
}
@ -715,19 +715,19 @@ func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSet
return &emptypb.Empty{}, nil
}
func (s *Service) Hostname(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) Hostname(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("Hostname")
return wrapperspb.String(constants.Host), nil
}
func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) (*wrapperspb.BoolValue, error) {
func (s *Service) IsPortFree(_ context.Context, port *wrapperspb.Int32Value) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsPortFree")
return wrapperspb.Bool(ports.IsPortFree(int(port.Value))), nil
}
func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
func (s *Service) AvailableKeychains(_ context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
s.log.Debug("AvailableKeychains")
return &AvailableKeychainsResponse{Keychains: maps.Keys(keychain.Helpers)}, nil
@ -757,7 +757,7 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S
return &emptypb.Empty{}, nil
}
func (s *Service) CurrentKeychain(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
func (s *Service) CurrentKeychain(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("CurrentKeychain")
helper, err := s.bridge.GetKeychainApp()

View File

@ -28,7 +28,7 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
)
func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListResponse, error) {
func (s *Service) GetUserList(_ context.Context, _ *emptypb.Empty) (*UserListResponse, error) {
s.log.Debug("GetUserList")
userIDs := s.bridge.GetUserIDs()
@ -51,7 +51,7 @@ func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListR
return &UserListResponse{Users: userList}, nil
}
func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) (*User, error) {
func (s *Service) GetUser(_ context.Context, userID *wrapperspb.StringValue) (*User, error) {
s.log.WithField("userID", userID).Debug("GetUser")
user, err := s.bridge.GetUserInfo(userID.Value)
@ -62,7 +62,7 @@ func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) (
return grpcUserFromInfo(user), nil
}
func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitModeRequest) (*emptypb.Empty, error) {
func (s *Service) SetUserSplitMode(_ context.Context, splitMode *UserSplitModeRequest) (*emptypb.Empty, error) {
s.log.WithField("UserID", splitMode.UserID).WithField("Active", splitMode.Active).Debug("SetUserSplitMode")
user, err := s.bridge.GetUserInfo(splitMode.UserID)
@ -96,7 +96,7 @@ func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitMode
return &emptypb.Empty{}, nil
}
func (s *Service) SendBadEventUserFeedback(ctx context.Context, feedback *UserBadEventFeedbackRequest) (*emptypb.Empty, error) {
func (s *Service) SendBadEventUserFeedback(_ context.Context, feedback *UserBadEventFeedbackRequest) (*emptypb.Empty, error) {
l := s.log.WithField("UserID", feedback.UserID).WithField("doResync", feedback.DoResync)
l.Debug("SendBadEventUserFeedback")
@ -114,7 +114,7 @@ func (s *Service) SendBadEventUserFeedback(ctx context.Context, feedback *UserBa
return &emptypb.Empty{}, nil
}
func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
func (s *Service) LogoutUser(_ context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
s.log.WithField("UserID", userID.Value).Debug("LogoutUser")
if _, err := s.bridge.GetUserInfo(userID.Value); err != nil {
@ -132,7 +132,7 @@ func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue
return &emptypb.Empty{}, nil
}
func (s *Service) RemoveUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
func (s *Service) RemoveUser(_ context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
s.log.WithField("UserID", userID.Value).Debug("RemoveUser")
go func() {
@ -152,7 +152,7 @@ func (s *Service) ConfigureUserAppleMail(ctx context.Context, request *Configure
sslWasEnabled := s.bridge.GetSMTPSSL()
if err := s.bridge.ConfigureAppleMail(request.UserID, request.Address); err != nil {
if err := s.bridge.ConfigureAppleMail(ctx, request.UserID, request.Address); err != nil {
s.log.WithField("userID", request.UserID).Error("Cannot configure AppleMail for user")
return nil, status.Error(codes.Internal, "Apple Mail config failed")
}

View File

@ -113,7 +113,7 @@ func Init(logsPath, level string) error {
// Debug or Trace.
func setLevel(level string) error {
if level == "" {
return nil
level = "debug"
}
logLevel, err := logrus.ParseLevel(level)

View File

@ -96,6 +96,7 @@ func GetTimeZone() string {
// NewReporter creates new sentry reporter with appName and appVersion to report.
func NewReporter(appName string, identifier Identifier) *Reporter {
logrus.WithField("id", GetProtectedHostname()).Info("New sentry reporter")
return &Reporter{
appName: appName,
appVersion: constants.Revision,
@ -203,7 +204,7 @@ func SkipDuringUnwind() {
}
// EnhanceSentryEvent swaps type with value and removes panic handlers from the stacktrace.
func EnhanceSentryEvent(event *sentry.Event, hint *sentry.EventHint) *sentry.Event {
func EnhanceSentryEvent(event *sentry.Event, _ *sentry.EventHint) *sentry.Event {
for idx, exception := range event.Exception {
exception.Type, exception.Value = exception.Value, exception.Type
if exception.Stacktrace != nil {

View File

@ -62,6 +62,6 @@ func (i *InstallerDarwin) InstallUpdate(_ *semver.Version, r io.Reader) error {
return syncFolders(oldBundle, newBundle)
}
func (i *InstallerDarwin) IsAlreadyInstalled(version *semver.Version) bool {
func (i *InstallerDarwin) IsAlreadyInstalled(_ *semver.Version) bool {
return false
}

View File

@ -217,7 +217,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
// If the address is enabled, we need to hook it up to the update channels.
switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := getAddrIdx(user.apiAddrs, 0)
primAddr, err := getPrimaryAddr(user.apiAddrs)
if err != nil {
return fmt.Errorf("failed to get primary address: %w", err)
}
@ -276,7 +276,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := getAddrIdx(user.apiAddrs, 0)
primAddr, err := getPrimaryAddr(user.apiAddrs)
if err != nil {
return fmt.Errorf("failed to get primary address: %w", err)
}
@ -394,7 +394,7 @@ func (user *User) handleLabelEvents(ctx context.Context, labelEvents []proton.La
return nil
}
func (user *User) handleCreateLabelEvent(ctx context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
func (user *User) handleCreateLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
return safe.LockRetErr(func() ([]imap.Update, error) {
var updates []imap.Update
@ -480,7 +480,7 @@ func (user *User) handleUpdateLabelEvent(ctx context.Context, event proton.Label
}, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleDeleteLabelEvent(ctx context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
func (user *User) handleDeleteLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
return safe.LockRetErr(func() ([]imap.Update, error) {
var updates []imap.Update
@ -628,7 +628,14 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
}
update = imap.NewMessagesCreated(false, res.update)
user.updateCh[full.AddressID].Enqueue(update)
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
if err != nil {
return err
}
if !didPublish {
update = nil
}
return nil
}); err != nil {
@ -643,7 +650,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleUpdateMessageEvent(ctx context.Context, message proton.MessageMetadata) ([]imap.Update, error) { //nolint:unparam
func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.MessageMetadata) ([]imap.Update, error) { //nolint:unparam
return safe.RLockRetErr(func() ([]imap.Update, error) {
user.log.WithFields(logrus.Fields{
"messageID": message.ID,
@ -674,13 +681,20 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, message proton.M
flags,
)
user.updateCh[message.AddressID].Enqueue(update)
didPublish, err := safePublishMessageUpdate(user, message.AddressID, update)
if err != nil {
return nil, err
}
if !didPublish {
return nil, nil
}
return []imap.Update{update}, nil
}, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleDeleteMessageEvent(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) { //nolint:unparam
func (user *User) handleDeleteMessageEvent(_ context.Context, event proton.MessageEvent) ([]imap.Update, error) {
return safe.RLockRetErr(func() ([]imap.Update, error) {
user.log.WithField("messageID", event.ID).Info("Handling message deleted event")
@ -696,7 +710,7 @@ func (user *User) handleDeleteMessageEvent(ctx context.Context, event proton.Mes
}, user.updateChLock)
}
func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) { //nolint:unparam
func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) {
return safe.RLockRetErr(func() ([]imap.Update, error) {
user.log.WithFields(logrus.Fields{
"messageID": event.ID,
@ -743,13 +757,24 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
true, // Is the message doesn't exist, silently create it.
)
user.updateCh[full.AddressID].Enqueue(update)
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
if err != nil {
return err
}
if !didPublish {
update = nil
}
return nil
}); err != nil {
return nil, err
}
if update == nil {
return nil, nil
}
return []imap.Update{update}, nil
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
@ -816,3 +841,37 @@ func (user *User) reportErrorNoContextCancel(title string, err error, reportCont
}
}
}
// safePublishMessageUpdate handles the rare case where the address' update channel may have been deleted in the same
// event. This rare case can take place if in the same event fetch request there is an update for delete address and
// create/update message.
// If the user is in combined mode, we simply push the update to the primary address. If the user is in split mode
// we do not publish the update as the address no longer exists.
func safePublishMessageUpdate(user *User, addressID string, update imap.Update) (bool, error) {
v, ok := user.updateCh[addressID]
if !ok {
if user.GetAddressMode() == vault.CombinedMode {
primAddr, err := getPrimaryAddr(user.apiAddrs)
if err != nil {
return false, fmt.Errorf("failed to get primary address: %w", err)
}
primaryCh, ok := user.updateCh[primAddr.ID]
if !ok {
return false, fmt.Errorf("primary address channel is not available")
}
primaryCh.Enqueue(update)
return true, nil
}
logrus.Warnf("Update channel not found for address %v, it may have been already deleted", addressID)
_ = user.reporter.ReportMessage("Message Update channel does not exist")
return false, nil
}
v.Enqueue(update)
return true, nil
}

View File

@ -270,7 +270,7 @@ func (conn *imapConnector) CreateMessage(
mailboxID imap.MailboxID,
literal []byte,
flags imap.FlagSet,
date time.Time,
_ time.Time,
) (imap.Message, []byte, error) {
defer conn.goPollAPIEvents(false)
@ -459,11 +459,11 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
var result bool
if v, ok := conn.apiLabels[string(labelFromID)]; ok && v.Type == proton.LabelTypeLabel {
result = result || true
result = true
}
if v, ok := conn.apiLabels[string(labelToID)]; ok && (v.Type == proton.LabelTypeFolder || v.Type == proton.LabelTypeSystem) {
result = result || true
result = true
}
return result
@ -529,7 +529,7 @@ func (conn *imapConnector) GetMailboxVisibility(_ context.Context, mailboxID ima
}
// Close the connector will no longer be used and all resources should be closed/released.
func (conn *imapConnector) Close(ctx context.Context) error {
func (conn *imapConnector) Close(_ context.Context) error {
return nil
}
@ -544,7 +544,7 @@ func (conn *imapConnector) importMessage(
if err := safe.RLockRet(func() error {
return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
messageID := ""
var messageID string
if slices.Contains(labelIDs, proton.DraftsLabel) {
msg, err := conn.createDraft(ctx, literal, addrKR, conn.apiAddrs[conn.addrID])

View File

@ -25,6 +25,7 @@ import (
"time"
"github.com/ProtonMail/gluon/rfc822"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
@ -34,22 +35,30 @@ const sendEntryExpiry = 30 * time.Minute
type sendRecorder struct {
expiry time.Duration
entries map[string]*sendEntry
entries map[string][]*sendEntry
entriesLock sync.Mutex
}
func newSendRecorder(expiry time.Duration) *sendRecorder {
return &sendRecorder{
expiry: expiry,
entries: make(map[string]*sendEntry),
entries: make(map[string][]*sendEntry),
}
}
type sendEntry struct {
msgID string
toList []string
exp time.Time
waitCh chan struct{}
msgID string
toList []string
exp time.Time
waitCh chan struct{}
waitChClosed bool
}
func (s *sendEntry) closeWaitChannel() {
if !s.waitChClosed {
close(s.waitCh)
s.waitChClosed = true
}
}
// tryInsertWait tries to insert the given message into the send recorder.
@ -102,25 +111,40 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline t
return h.hasEntryWait(ctx, hash, deadline)
}
func (h *sendRecorder) removeExpiredUnsafe() {
for hash, entry := range h.entries {
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
return !t.exp.Before(time.Now())
})
if len(remaining) == 0 {
delete(h.entries, hash)
} else {
h.entries[hash] = remaining
}
}
}
func (h *sendRecorder) tryInsert(hash string, toList []string) bool {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
for hash, entry := range h.entries {
if entry.exp.Before(time.Now()) {
delete(h.entries, hash)
h.removeExpiredUnsafe()
entries, ok := h.entries[hash]
if ok {
for _, entry := range entries {
if matchToList(entry.toList, toList) {
return false
}
}
}
if _, ok := h.entries[hash]; ok && matchToList(h.entries[hash].toList, toList) {
return false
}
h.entries[hash] = &sendEntry{
h.entries[hash] = append(entries, &sendEntry{
exp: time.Now().Add(h.expiry),
toList: toList,
waitCh: make(chan struct{}),
}
})
return true
}
@ -129,11 +153,7 @@ func (h *sendRecorder) hasEntry(hash string) bool {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
for hash, entry := range h.entries {
if entry.exp.Before(time.Now()) {
delete(h.entries, hash)
}
}
h.removeExpiredUnsafe()
if _, ok := h.entries[hash]; ok {
return true
@ -142,32 +162,46 @@ func (h *sendRecorder) hasEntry(hash string) bool {
return false
}
func (h *sendRecorder) addMessageID(hash, msgID string) {
// signalMessageSent should be called after a message has been successfully sent.
func (h *sendRecorder) signalMessageSent(hash, msgID string, toList []string) {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
entry, ok := h.entries[hash]
entries, ok := h.entries[hash]
if ok {
entry.msgID = msgID
} else {
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
for _, entry := range entries {
if matchToList(entry.toList, toList) {
entry.msgID = msgID
entry.closeWaitChannel()
return
}
}
}
close(entry.waitCh)
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
}
func (h *sendRecorder) removeOnFail(hash string) {
func (h *sendRecorder) removeOnFail(hash string, toList []string) {
h.entriesLock.Lock()
defer h.entriesLock.Unlock()
entry, ok := h.entries[hash]
if !ok || entry.msgID != "" {
entries, ok := h.entries[hash]
if !ok {
return
}
close(entry.waitCh)
for idx, entry := range entries {
if entry.msgID == "" && matchToList(entry.toList, toList) {
entry.closeWaitChannel()
delete(h.entries, hash)
remaining := xslices.Remove(entries, idx, 1)
if len(remaining) != 0 {
h.entries[hash] = remaining
} else {
delete(h.entries, hash)
}
}
}
}
func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) {
@ -191,7 +225,7 @@ func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time
defer h.entriesLock.Unlock()
if entry, ok := h.entries[hash]; ok {
return entry.msgID, true, nil
return entry[0].msgID, true, nil
}
return "", false, nil
@ -202,7 +236,7 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) {
defer h.entriesLock.Unlock()
if entry, ok := h.entries[hash]; ok {
return entry.waitCh, true
return entry[0].waitCh, true
}
return nil, false

View File

@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) {
require.NotEmpty(t, hash1)
// Simulate successfully sending the message.
h.addMessageID(hash1, "abc")
h.signalMessageSent(hash1, "abc", nil)
// Inserting a message with the same hash should return false.
_, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second))
@ -59,7 +59,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) {
require.NotEmpty(t, hash1)
// Simulate successfully sending the message.
h.addMessageID(hash1, "abc")
h.signalMessageSent(hash1, "abc", nil)
// Wait for the entry to expire.
time.Sleep(time.Second)
@ -106,7 +106,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) {
// Simulate successfully sending the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
h.addMessageID(hash, "abc")
h.signalMessageSent(hash, "abc", nil)
}()
// Inserting a message with the same hash should fail.
@ -127,7 +127,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) {
// Simulate failing to send the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
h.removeOnFail(hash)
h.removeOnFail(hash, nil)
}()
// Inserting a message with the same hash should succeed because the first message failed to send.
@ -163,7 +163,7 @@ func TestSendHasher_HasEntry(t *testing.T) {
require.NotEmpty(t, hash)
// Simulate successfully sending the message.
h.addMessageID(hash, "abc")
h.signalMessageSent(hash, "abc", nil)
// The message was already sent; we should find it in the hasher.
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
@ -184,7 +184,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
// Simulate successfully sending the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
h.addMessageID(hash, "abc")
h.signalMessageSent(hash, "abc", nil)
}()
// The message was already sent; we should find it in the hasher.
@ -194,6 +194,47 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
require.Equal(t, "abc", messageID)
}
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
// inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2,
// resulting in a crash.
h := newSendRecorder(sendEntryExpiry)
// Insert a message into the hasher.
hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
require.NoError(t, err)
require.True(t, ok)
require.NotEmpty(t, hash)
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
// to attempt to send the same message.
h.signalMessageSent(hash, "abc", nil)
h.signalMessageSent(hash, "abc", nil)
// The message was already sent; we should find it in the hasher.
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
require.NoError(t, err)
require.True(t, ok)
require.Equal(t, "abc", messageID)
}
func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) {
h := newSendRecorder(sendEntryExpiry)
// Insert a message into the hasher.
hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>")
require.NoError(t, err)
require.True(t, ok)
require.NotEmpty(t, hash)
hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>")
require.NoError(t, err)
require.True(t, ok)
require.NotEmpty(t, hash2)
require.Equal(t, hash, hash2)
}
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
h := newSendRecorder(sendEntryExpiry)
@ -206,7 +247,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) {
// Simulate failing to send the message after half a second.
go func() {
time.Sleep(time.Millisecond * 500)
h.removeOnFail(hash)
h.removeOnFail(hash, nil)
}()
// The message failed to send; we should not find it in the hasher.
@ -240,7 +281,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) {
require.NotEmpty(t, hash)
// Simulate successfully sending the message.
h.addMessageID(hash, "abc")
h.signalMessageSent(hash, "abc", nil)
// Wait for the entry to expire.
time.Sleep(time.Second)
@ -264,7 +305,6 @@ Content-Disposition: attachment; filename="attname.txt"
attachment
--longrandomstring--
`
const literal2 = `From: Sender <sender@pm.me>
To: Receiver <receiver@pm.me>
Content-Type: multipart/mixed; boundary=longrandomstring

View File

@ -89,7 +89,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// If we fail to send this message, we should remove the hash from the send recorder.
defer user.sendHash.removeOnFail(hash)
defer user.sendHash.removeOnFail(hash, to)
// Create a new message parser from the reader.
parser, err := parser.New(bytes.NewReader(b))
@ -162,7 +162,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// If the message was successfully sent, we can update the message ID in the record.
user.sendHash.addMessageID(hash, sent.ID)
user.sendHash.signalMessageSent(hash, sent.ID, to)
return nil
})
@ -438,6 +438,10 @@ func (user *User) createAttachments(
}
}
// Exclude name from params since this is already provided using Filename.
delete(att.MIMEParams, "name")
delete(att.MIMEParams, "filename")
attachment, err := client.UploadAttachment(ctx, addrKR, proton.CreateAttachmentReq{
Filename: att.Name,
MessageID: draftID,

View File

@ -19,6 +19,6 @@
package user
func debugDumpToDisk(b []byte) error {
func debugDumpToDisk(_ []byte) error {
return nil
}

View File

@ -513,7 +513,16 @@ func (user *User) syncMessages(
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
defer async.HandlePanic(user.panicHandler)
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
kr, ok := addrKRs[msg.AddressID]
if !ok {
return &buildRes{
messageID: msg.ID,
addressID: msg.AddressID,
err: fmt.Errorf("address does not have an unlocked keyring"),
}, nil
}
return buildRFC822(apiLabels, msg, kr, new(bytes.Buffer)), nil
})
if err != nil {
return
@ -572,10 +581,10 @@ func (user *User) syncMessages(
// We could sync a placeholder message here, but for now we skip it entirely.
continue
} else {
if err := vault.RemFailedMessageID(res.messageID); err != nil {
logrus.WithError(err).Error("Failed to remove failed message ID")
}
}
if err := vault.RemFailedMessageID(res.messageID); err != nil {
logrus.WithError(err).Error("Failed to remove failed message ID")
}
targetInfo := addressToIndex[res.addressID]

View File

@ -83,6 +83,18 @@ func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, er
return sorted[idx], nil
}
func getPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) {
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
return a.Order < b.Order
})
if len(sorted) == 0 {
return proton.Address{}, fmt.Errorf("no addresses available")
}
return sorted[0], nil
}
// sortSlice returns the given slice sorted by the given comparator.
func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
sorted := make([]Item, len(items))

View File

@ -282,7 +282,7 @@ func (user *User) Match(query string) bool {
func (user *User) Emails() []string {
return safe.RLockRet(func() []string {
addresses := xslices.Filter(maps.Values(user.apiAddrs), func(addr proton.Address) bool {
return addr.Status == proton.AddressStatusEnabled
return addr.Status == proton.AddressStatusEnabled && addr.Type != proton.AddressTypeExternal
})
slices.SortFunc(addresses, func(a, b proton.Address) bool {
@ -586,6 +586,8 @@ func (user *User) Close() {
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
updateCh.CloseAndDiscardQueued()
}
user.updateCh = make(map[string]*async.QueuedChannel[imap.Update])
}, user.updateChLock)
// Close the user's notify channel.
@ -690,87 +692,89 @@ func (user *User) doEventPoll(ctx context.Context) error {
user.eventLock.Lock()
defer user.eventLock.Unlock()
event, more, err := user.client.GetEvent(ctx, user.vault.EventID())
gpaEvents, more, err := user.client.GetEvent(ctx, user.vault.EventID())
if err != nil {
return fmt.Errorf("failed to get event (caused by %T): %w", internal.ErrCause(err), err)
}
// If the event ID hasn't changed, there are no new events.
if event.EventID == user.vault.EventID() {
if gpaEvents[len(gpaEvents)-1].EventID == user.vault.EventID() {
user.log.Debug("No new API events")
return nil
}
user.log.WithFields(logrus.Fields{
"old": user.vault.EventID(),
"new": event,
}).Info("Received new API event")
for _, event := range gpaEvents {
user.log.WithFields(logrus.Fields{
"old": user.vault.EventID(),
"new": event,
}).Info("Received new API event")
// Handle the event.
if err := user.handleAPIEvent(ctx, event); err != nil {
// If the error is a context cancellation, return error to retry later.
if errors.Is(err, context.Canceled) {
return fmt.Errorf("failed to handle event due to context cancellation: %w", err)
// Handle the event.
if err := user.handleAPIEvent(ctx, event); err != nil {
// If the error is a context cancellation, return error to retry later.
if errors.Is(err, context.Canceled) {
return fmt.Errorf("failed to handle event due to context cancellation: %w", err)
}
// If the error is a network error, return error to retry later.
if netErr := new(proton.NetError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issue: %w", err)
}
// Catch all for uncategorized net errors that may slip through.
if netErr := new(net.OpError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
}
// In case a json decode error slips through.
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
user.eventCh.Enqueue(events.UncategorizedEventError{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to handle event due to JSON issue: %w", err)
}
// If the error is an unexpected EOF, return error to retry later.
if errors.Is(err, io.ErrUnexpectedEOF) {
return fmt.Errorf("failed to handle event due to EOF: %w", err)
}
// If the error is a server-side issue, return error to retry later.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
return fmt.Errorf("failed to handle event due to server error: %w", err)
}
// Otherwise, the error is a client-side issue; notify bridge to handle it.
user.log.WithField("event", event).Warn("Failed to handle API event")
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
OldEventID: user.vault.EventID(),
NewEventID: event.EventID,
EventInfo: event.String(),
Error: err,
})
return fmt.Errorf("failed to handle event due to client error: %w", err)
}
// If the error is a network error, return error to retry later.
if netErr := new(proton.NetError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issue: %w", err)
}
user.log.WithField("event", event).Debug("Handled API event")
// Catch all for uncategorized net errors that may slip through.
if netErr := new(net.OpError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
}
// In case a json decode error slips through.
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
user.eventCh.Enqueue(events.UncategorizedEventError{
// Update the event ID in the vault. If this fails, notify bridge to handle it.
if err := user.vault.SetEventID(event.EventID); err != nil {
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to handle event due to JSON issue: %w", err)
return fmt.Errorf("failed to update event ID: %w", err)
}
// If the error is an unexpected EOF, return error to retry later.
if errors.Is(err, io.ErrUnexpectedEOF) {
return fmt.Errorf("failed to handle event due to EOF: %w", err)
}
// If the error is a server-side issue, return error to retry later.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
return fmt.Errorf("failed to handle event due to server error: %w", err)
}
// Otherwise, the error is a client-side issue; notify bridge to handle it.
user.log.WithField("event", event).Warn("Failed to handle API event")
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
OldEventID: user.vault.EventID(),
NewEventID: event.EventID,
EventInfo: event.String(),
Error: err,
})
return fmt.Errorf("failed to handle event due to client error: %w", err)
user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
}
user.log.WithField("event", event).Debug("Handled API event")
// Update the event ID in the vault. If this fails, notify bridge to handle it.
if err := user.vault.SetEventID(event.EventID); err != nil {
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to update event ID: %w", err)
}
user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
if more {
user.goPollAPIEvents(false)
}

View File

@ -30,7 +30,12 @@ import (
// If CertPEMPath is set, it will attempt to read the certificate from the file.
// Otherwise, or on read/validation failure, it will return the certificate from the vault.
func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) {
if certPath, keyPath := vault.get().Certs.CustomCertPath, vault.get().Certs.CustomKeyPath; certPath != "" && keyPath != "" {
vault.lock.RLock()
defer vault.lock.RUnlock()
certs := vault.getUnsafe().Certs
if certPath, keyPath := certs.CustomCertPath, certs.CustomKeyPath; certPath != "" && keyPath != "" {
if certPEM, keyPEM, err := readPEMCert(certPath, keyPath); err == nil {
return certPEM, keyPEM
}
@ -38,7 +43,7 @@ func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) {
logrus.Error("Failed to read certificate from file, using default")
}
return vault.get().Certs.Bridge.Cert, vault.get().Certs.Bridge.Key
return certs.Bridge.Cert, certs.Bridge.Key
}
// SetBridgeTLSCertPath sets the path to PEM-encoded certificates for the bridge.
@ -47,7 +52,7 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error {
return fmt.Errorf("invalid certificate: %w", err)
}
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Certs.CustomCertPath = certPath
data.Certs.CustomKeyPath = keyPath
})
@ -55,18 +60,18 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error {
// SetBridgeTLSCertKey sets the path to PEM-encoded certificates for the bridge.
func (vault *Vault) SetBridgeTLSCertKey(cert, key []byte) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Certs.Bridge.Cert = cert
data.Certs.Bridge.Key = key
})
}
func (vault *Vault) GetCertsInstalled() bool {
return vault.get().Certs.Installed
return vault.getSafe().Certs.Installed
}
func (vault *Vault) SetCertsInstalled(installed bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Certs.Installed = installed
})
}

View File

@ -18,11 +18,11 @@
package vault
func (vault *Vault) GetCookies() ([]byte, error) {
return vault.get().Cookies, nil
return vault.getSafe().Cookies, nil
}
func (vault *Vault) SetCookies(cookies []byte) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Cookies = cookies
})
}

View File

@ -0,0 +1,46 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package vault
import (
"crypto/sha256"
"fmt"
)
// set archives the password for an email address, overwriting any existing archived value.
func (p *PasswordArchive) set(emailAddress string, password []byte) {
if p.Archive == nil {
p.Archive = make(map[string][]byte)
}
p.Archive[emailHashString(emailAddress)] = password
}
// get retrieves the archived password for an email address, or nil if not found.
func (p *PasswordArchive) get(emailAddress string) []byte {
if p.Archive == nil {
return nil
}
return p.Archive[emailHashString(emailAddress)]
}
// emailHashString returns a hash string for an email address as a hexadecimal string.
func emailHashString(emailAddress string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(emailAddress)))
}

View File

@ -33,72 +33,72 @@ const (
// GetIMAPPort sets the port that the IMAP server should listen on.
func (vault *Vault) GetIMAPPort() int {
return vault.get().Settings.IMAPPort
return vault.getSafe().Settings.IMAPPort
}
// SetIMAPPort sets the port that the IMAP server should listen on.
func (vault *Vault) SetIMAPPort(port int) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.IMAPPort = port
})
}
// GetSMTPPort sets the port that the SMTP server should listen on.
func (vault *Vault) GetSMTPPort() int {
return vault.get().Settings.SMTPPort
return vault.getSafe().Settings.SMTPPort
}
// SetSMTPPort sets the port that the SMTP server should listen on.
func (vault *Vault) SetSMTPPort(port int) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.SMTPPort = port
})
}
// GetIMAPSSL sets whether the IMAP server should use SSL.
func (vault *Vault) GetIMAPSSL() bool {
return vault.get().Settings.IMAPSSL
return vault.getSafe().Settings.IMAPSSL
}
// SetIMAPSSL sets whether the IMAP server should use SSL.
func (vault *Vault) SetIMAPSSL(ssl bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.IMAPSSL = ssl
})
}
// GetSMTPSSL sets whether the SMTP server should use SSL.
func (vault *Vault) GetSMTPSSL() bool {
return vault.get().Settings.SMTPSSL
return vault.getSafe().Settings.SMTPSSL
}
// SetSMTPSSL sets whether the SMTP server should use SSL.
func (vault *Vault) SetSMTPSSL(ssl bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.SMTPSSL = ssl
})
}
// GetGluonCacheDir sets the directory where the gluon should store its data.
func (vault *Vault) GetGluonCacheDir() string {
return vault.get().Settings.GluonDir
return vault.getSafe().Settings.GluonDir
}
// SetGluonDir sets the directory where the gluon should store its data.
func (vault *Vault) SetGluonDir(dir string) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.GluonDir = dir
})
}
// GetUpdateChannel sets the update channel.
func (vault *Vault) GetUpdateChannel() updater.Channel {
return vault.get().Settings.UpdateChannel
return vault.getSafe().Settings.UpdateChannel
}
// SetUpdateChannel sets the update channel.
func (vault *Vault) SetUpdateChannel(channel updater.Channel) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.UpdateChannel = channel
})
}
@ -106,7 +106,7 @@ func (vault *Vault) SetUpdateChannel(channel updater.Channel) error {
// GetUpdateRollout sets the update rollout.
func (vault *Vault) GetUpdateRollout() float64 {
// The rollout value 0.6046602879796196 is forbidden. The RNG was not seeded when it was picked (GODT-2319).
rollout := vault.get().Settings.UpdateRollout
rollout := vault.getSafe().Settings.UpdateRollout
if math.Abs(rollout-ForbiddenRollout) >= 0.00000001 {
return rollout
}
@ -120,110 +120,110 @@ func (vault *Vault) GetUpdateRollout() float64 {
// SetUpdateRollout sets the update rollout.
func (vault *Vault) SetUpdateRollout(rollout float64) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.UpdateRollout = rollout
})
}
// GetColorScheme sets the color scheme to be used by the bridge GUI.
func (vault *Vault) GetColorScheme() string {
return vault.get().Settings.ColorScheme
return vault.getSafe().Settings.ColorScheme
}
// SetColorScheme sets the color scheme to be used by the bridge GUI.
func (vault *Vault) SetColorScheme(colorScheme string) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.ColorScheme = colorScheme
})
}
// GetProxyAllowed sets whether the bridge is allowed to use alternative routing.
func (vault *Vault) GetProxyAllowed() bool {
return vault.get().Settings.ProxyAllowed
return vault.getSafe().Settings.ProxyAllowed
}
// SetProxyAllowed sets whether the bridge is allowed to use alternative routing.
func (vault *Vault) SetProxyAllowed(allowed bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.ProxyAllowed = allowed
})
}
// GetShowAllMail sets whether the bridge should show the All Mail folder.
func (vault *Vault) GetShowAllMail() bool {
return vault.get().Settings.ShowAllMail
return vault.getSafe().Settings.ShowAllMail
}
// SetShowAllMail sets whether the bridge should show the All Mail folder.
func (vault *Vault) SetShowAllMail(showAllMail bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.ShowAllMail = showAllMail
})
}
// GetAutostart sets whether the bridge should autostart.
func (vault *Vault) GetAutostart() bool {
return vault.get().Settings.Autostart
return vault.getSafe().Settings.Autostart
}
// SetAutostart sets whether the bridge should autostart.
func (vault *Vault) SetAutostart(autostart bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.Autostart = autostart
})
}
// GetAutoUpdate sets whether the bridge should automatically update.
func (vault *Vault) GetAutoUpdate() bool {
return vault.get().Settings.AutoUpdate
return vault.getSafe().Settings.AutoUpdate
}
// SetAutoUpdate sets whether the bridge should automatically update.
func (vault *Vault) SetAutoUpdate(autoUpdate bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.AutoUpdate = autoUpdate
})
}
// GetTelemetryDisabled checks whether telemetry is disabled.
func (vault *Vault) GetTelemetryDisabled() bool {
return vault.get().Settings.TelemetryDisabled
return vault.getSafe().Settings.TelemetryDisabled
}
// SetTelemetryDisabled sets whether telemetry is disabled.
func (vault *Vault) SetTelemetryDisabled(telemetryDisabled bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.TelemetryDisabled = telemetryDisabled
})
}
// GetLastVersion returns the last version of the bridge that was run.
func (vault *Vault) GetLastVersion() *semver.Version {
return semver.MustParse(vault.get().Settings.LastVersion)
return semver.MustParse(vault.getSafe().Settings.LastVersion)
}
// SetLastVersion sets the last version of the bridge that was run.
func (vault *Vault) SetLastVersion(version *semver.Version) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.LastVersion = version.String()
})
}
// GetFirstStart returns whether this is the first time the bridge has been started.
func (vault *Vault) GetFirstStart() bool {
return vault.get().Settings.FirstStart
return vault.getSafe().Settings.FirstStart
}
// SetFirstStart sets whether this is the first time the bridge has been started.
func (vault *Vault) SetFirstStart(firstStart bool) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.FirstStart = firstStart
})
}
// GetMaxSyncMemory returns the maximum amount of memory the sync process should use.
func (vault *Vault) GetMaxSyncMemory() uint64 {
v := vault.get().Settings.MaxSyncMemory
v := vault.getSafe().Settings.MaxSyncMemory
// can be zero if never written to vault before.
if v == 0 {
return DefaultMaxSyncMemory
@ -234,14 +234,14 @@ func (vault *Vault) GetMaxSyncMemory() uint64 {
// SetMaxSyncMemory sets the maximum amount of memory the sync process should use.
func (vault *Vault) SetMaxSyncMemory(maxMemory uint64) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.MaxSyncMemory = maxMemory
})
}
// GetLastUserAgent returns the last user agent recorded by bridge.
func (vault *Vault) GetLastUserAgent() string {
v := vault.get().Settings.LastUserAgent
v := vault.getSafe().Settings.LastUserAgent
// Handle case where there may be no value.
if len(v) == 0 {
@ -253,19 +253,19 @@ func (vault *Vault) GetLastUserAgent() string {
// SetLastUserAgent store the last user agent recorded by bridge.
func (vault *Vault) SetLastUserAgent(userAgent string) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.LastUserAgent = userAgent
})
}
// GetLastHeartbeatSent returns the last time heartbeat was sent.
func (vault *Vault) GetLastHeartbeatSent() time.Time {
return vault.get().Settings.LastHeartbeatSent
return vault.getSafe().Settings.LastHeartbeatSent
}
// SetLastHeartbeatSent store the last time heartbeat was sent.
func (vault *Vault) SetLastHeartbeatSent(timestamp time.Time) error {
return vault.mod(func(data *Data) {
return vault.modSafe(func(data *Data) {
data.Settings.LastHeartbeatSent = timestamp
})
}

View File

@ -238,3 +238,30 @@ func TestVault_Settings_LastUserAgent(t *testing.T) {
// Check the default first start value.
require.Equal(t, vault.DefaultUserAgent, s.GetLastUserAgent())
}
func Test_Settings_PasswordArchive(t *testing.T) {
// Create a new test vault.
s := newVault(t)
// The store should have no users.
require.Empty(t, s.GetUserIDs())
// Create a new user.
user, err := s.AddUser("userID1", "username1", "username1@pm.me", "authUID1", "authRef1", []byte("keyPass1"))
require.NoError(t, err)
bridgePass := user.BridgePass()
// Remove the user.
require.NoError(t, user.Close())
require.NoError(t, s.DeleteUser("userID1"))
// Add a different user. Another password is generated.
user, err = s.AddUser("userID2", "username2", "username2@pm.me", "authUID2", "authRef2", []byte("keyPass2"))
require.NoError(t, err)
require.NotEqual(t, user.BridgePass(), bridgePass)
// Add the first user again. The password is restored.
user, err = s.AddUser("userID1", "username1", "username1@pm.me", "authUID1", "authRef1", []byte("keyPass1"))
require.NoError(t, err)
require.Equal(t, user.BridgePass(), bridgePass)
}

View File

@ -48,11 +48,7 @@ func unmarshalFile[T any](gcm cipher.AEAD, b []byte, data *T) error {
}
}
if err := msgpack.Unmarshal(dec, data); err != nil {
return err
}
return nil
return msgpack.Unmarshal(dec, data)
}
func marshalFile[T any](gcm cipher.AEAD, t T) ([]byte, error) {

View File

@ -0,0 +1,25 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package vault
// PasswordArchive maps a list email address hashes to passwords.
// The type is not defined as a map alias to prevent having to handle nil default values when vault was created by an older version of the application.
type PasswordArchive struct {
// we store the SHA-256 sum as string for readability and JSON marshalling of map[[32]byte][]byte will not be allowed, thus breaking vault-editor.
Archive map[string][]byte
}

View File

@ -53,6 +53,8 @@ type Settings struct {
LastHeartbeatSent time.Time
PasswordArchive PasswordArchive
// **WARNING**: These entry can't be removed until they vault has proper migration support.
SyncWorkers int
SyncAttPool int
@ -105,5 +107,7 @@ func newDefaultSettings(gluonDir string) Settings {
LastUserAgent: DefaultUserAgent,
LastHeartbeatSent: time.Time{},
PasswordArchive: PasswordArchive{},
}
}

View File

@ -73,7 +73,7 @@ func (status SyncStatus) IsComplete() bool {
return status.HasLabels && status.HasMessages
}
func newDefaultUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) UserData {
func newDefaultUser(userID, username, primaryEmail, authUID, authRef string, keyPass, bridgePass []byte) UserData {
return UserData{
UserID: userID,
Username: username,
@ -82,7 +82,7 @@ func newDefaultUser(userID, username, primaryEmail, authUID, authRef string, key
GluonKey: newRandomToken(32),
GluonIDs: make(map[string]string),
UIDValidity: make(map[string]imap.UID),
BridgePass: newRandomToken(16),
BridgePass: bridgePass,
AddressMode: CombinedMode,
AuthUID: authUID,

View File

@ -122,6 +122,14 @@ func (user *User) SetAuth(authUID, authRef string) error {
})
}
func (user *User) setAuthAndKeyPassUnsafe(authUID, authRef string, keyPass []byte) error {
return user.vault.modUserUnsafe(user.userID, func(userData *UserData) {
userData.AuthRef = authRef
userData.AuthUID = authUID
userData.KeyPass = keyPass
})
}
// KeyPass returns the user's (salted) key password.
func (user *User) KeyPass() []byte {
return user.vault.getUser(user.userID).KeyPass

View File

@ -40,11 +40,11 @@ type Vault struct {
path string
gcm cipher.AEAD
enc []byte
encLock sync.RWMutex
enc []byte
ref map[string]int
refLock sync.Mutex
ref map[string]int
lock sync.RWMutex
panicHandler async.PanicHandler
}
@ -79,14 +79,46 @@ func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHan
// GetUserIDs returns the user IDs and usernames of all users in the vault.
func (vault *Vault) GetUserIDs() []string {
return xslices.Map(vault.get().Users, func(user UserData) string {
vault.lock.RLock()
defer vault.lock.RUnlock()
return xslices.Map(vault.getUnsafe().Users, func(user UserData) string {
return user.UserID
})
}
func (vault *Vault) getUsers() ([]*User, error) {
vault.lock.Lock()
defer vault.lock.Unlock()
users := vault.getUnsafe().Users
result := make([]*User, 0, len(users))
for _, user := range users {
u, err := vault.newUserUnsafe(user.UserID)
if err != nil {
for _, v := range result {
if err := v.Close(); err != nil {
logrus.WithError(err).Error("Fait to close user after failed get")
}
}
return nil, err
}
result = append(result, u)
}
return result, nil
}
// HasUser returns true if the vault contains a user with the given ID.
func (vault *Vault) HasUser(userID string) bool {
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
vault.lock.RLock()
defer vault.lock.RUnlock()
return xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
return user.UserID == userID
}) >= 0
}
@ -106,46 +138,72 @@ func (vault *Vault) GetUser(userID string, fn func(*User)) error {
// NewUser returns a new vault user. It must be closed before it can be deleted.
func (vault *Vault) NewUser(userID string) (*User, error) {
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
vault.lock.Lock()
defer vault.lock.Unlock()
return vault.newUserUnsafe(userID)
}
func (vault *Vault) newUserUnsafe(userID string) (*User, error) {
if idx := xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
return user.UserID == userID
}); idx < 0 {
return nil, errors.New("no such user")
}
return vault.attachUser(userID), nil
return vault.attachUserUnsafe(userID), nil
}
// ForUser executes a callback for each user in the vault.
func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
userIDs := vault.GetUserIDs()
users, err := vault.getUsers()
if err != nil {
return err
}
return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error {
r := parallel.DoContext(context.Background(), parallelism, len(users), func(_ context.Context, idx int) error {
defer async.HandlePanic(vault.panicHandler)
user, err := vault.NewUser(userIDs[idx])
if err != nil {
return err
}
defer func() { _ = user.Close() }()
user := users[idx]
return fn(user)
})
for _, u := range users {
if err := u.Close(); err != nil {
logrus.WithError(err).Error("Failed to close user after ForUser")
}
}
return r
}
// AddUser creates a new user in the vault with the given ID, username and password.
// A bridge password and gluon key are generated using the package's token generator.
// A gluon key is generated using the package's token generator. If a password is found in the password archive for this user,
// it is restored, otherwise a new bridge password is generated using the package's token generator.
func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
vault.lock.Lock()
defer vault.lock.Unlock()
return vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
}
func (vault *Vault) addUserUnsafe(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
logrus.WithField("userID", userID).Info("Adding vault user")
var exists bool
if err := vault.mod(func(data *Data) {
if err := vault.modUnsafe(func(data *Data) {
if idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
return user.UserID == userID
}); idx >= 0 {
exists = true
} else {
data.Users = append(data.Users, newDefaultUser(userID, username, primaryEmail, authUID, authRef, keyPass))
bridgePass := data.Settings.PasswordArchive.get(primaryEmail)
if len(bridgePass) == 0 {
bridgePass = newRandomToken(16)
}
data.Users = append(data.Users, newDefaultUser(userID, username, primaryEmail, authUID, authRef, keyPass, bridgePass))
}
}); err != nil {
return nil, err
@ -155,13 +213,42 @@ func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef str
return nil, errors.New("user already exists")
}
return vault.NewUser(userID)
return vault.attachUserUnsafe(userID), nil
}
// GetOrAddUser retrieves an existing user and updates the authRef and keyPass or creates a new user. Returns
// the user and whether the user did not exist before.
func (vault *Vault) GetOrAddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, bool, error) {
vault.lock.Lock()
defer vault.lock.Unlock()
{
users := vault.getUnsafe().Users
idx := xslices.IndexFunc(users, func(user UserData) bool {
return user.UserID == userID
})
if idx >= 0 {
user := vault.attachUserUnsafe(userID)
if err := user.setAuthAndKeyPassUnsafe(authUID, authRef, keyPass); err != nil {
return nil, false, err
}
return user, false, nil
}
}
u, err := vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
return u, true, err
}
// DeleteUser removes the given user from the vault.
func (vault *Vault) DeleteUser(userID string) error {
vault.refLock.Lock()
defer vault.refLock.Unlock()
vault.lock.Lock()
defer vault.lock.Unlock()
logrus.WithField("userID", userID).Info("Deleting vault user")
@ -169,7 +256,7 @@ func (vault *Vault) DeleteUser(userID string) error {
return fmt.Errorf("user %s is currently in use", userID)
}
return vault.mod(func(data *Data) {
return vault.modUnsafe(func(data *Data) {
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
return user.UserID == userID
})
@ -177,23 +264,32 @@ func (vault *Vault) DeleteUser(userID string) error {
if idx < 0 {
return
}
data.Settings.PasswordArchive.set(data.Users[idx].PrimaryEmail, data.Users[idx].BridgePass)
data.Users = append(data.Users[:idx], data.Users[idx+1:]...)
})
}
func (vault *Vault) Migrated() bool {
return vault.get().Migrated
vault.lock.RLock()
defer vault.lock.RUnlock()
return vault.getUnsafe().Migrated
}
func (vault *Vault) SetMigrated() error {
return vault.mod(func(data *Data) {
vault.lock.Lock()
defer vault.lock.Unlock()
return vault.modUnsafe(func(data *Data) {
data.Migrated = true
})
}
func (vault *Vault) Reset(gluonDir string) error {
return vault.mod(func(data *Data) {
vault.lock.Lock()
defer vault.lock.Unlock()
return vault.modUnsafe(func(data *Data) {
*data = newDefaultData(gluonDir)
})
}
@ -203,8 +299,8 @@ func (vault *Vault) Path() string {
}
func (vault *Vault) Close() error {
vault.refLock.Lock()
defer vault.refLock.Unlock()
vault.lock.Lock()
defer vault.lock.Unlock()
if len(vault.ref) > 0 {
return errors.New("vault is still in use")
@ -215,10 +311,7 @@ func (vault *Vault) Close() error {
return nil
}
func (vault *Vault) attachUser(userID string) *User {
vault.refLock.Lock()
defer vault.refLock.Unlock()
func (vault *Vault) attachUserUnsafe(userID string) *User {
logrus.WithField("userID", userID).Trace("Attaching vault user")
vault.ref[userID]++
@ -230,8 +323,8 @@ func (vault *Vault) attachUser(userID string) *User {
}
func (vault *Vault) detachUser(userID string) error {
vault.refLock.Lock()
defer vault.refLock.Unlock()
vault.lock.Lock()
defer vault.lock.Unlock()
logrus.WithField("userID", userID).Trace("Detaching vault user")
@ -283,10 +376,14 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
}, corrupt, nil
}
func (vault *Vault) get() Data {
vault.encLock.RLock()
defer vault.encLock.RUnlock()
func (vault *Vault) getSafe() Data {
vault.lock.RLock()
defer vault.lock.RUnlock()
return vault.getUnsafe()
}
func (vault *Vault) getUnsafe() Data {
var data Data
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
@ -296,10 +393,14 @@ func (vault *Vault) get() Data {
return data
}
func (vault *Vault) mod(fn func(data *Data)) error {
vault.encLock.Lock()
defer vault.encLock.Unlock()
func (vault *Vault) modSafe(fn func(data *Data)) error {
vault.lock.Lock()
defer vault.lock.Unlock()
return vault.modUnsafe(fn)
}
func (vault *Vault) modUnsafe(fn func(data *Data)) error {
var data Data
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
@ -319,13 +420,31 @@ func (vault *Vault) mod(fn func(data *Data)) error {
}
func (vault *Vault) getUser(userID string) UserData {
return vault.get().Users[xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
vault.lock.RLock()
defer vault.lock.RUnlock()
users := vault.getUnsafe().Users
idx := xslices.IndexFunc(users, func(user UserData) bool {
return user.UserID == userID
})]
})
if idx < 0 {
panic("Unknown user")
}
return users[idx]
}
func (vault *Vault) modUser(userID string, fn func(userData *UserData)) error {
return vault.mod(func(data *Data) {
vault.lock.Lock()
defer vault.lock.Unlock()
return vault.modUserUnsafe(userID, fn)
}
func (vault *Vault) modUserUnsafe(userID string, fn func(userData *UserData)) error {
return vault.modUnsafe(func(data *Data) {
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
return user.UserID == userID
})

View File

@ -24,7 +24,7 @@ import (
)
func (vault *Vault) ImportJSON(dec []byte) {
vault.mod(func(data *Data) {
vault.modSafe(func(data *Data) {
if err := json.Unmarshal(dec, data); err != nil {
panic(err)
}
@ -32,7 +32,7 @@ func (vault *Vault) ImportJSON(dec []byte) {
}
func (vault *Vault) ExportJSON() []byte {
enc, err := json.MarshalIndent(vault.get(), "", " ")
enc, err := json.MarshalIndent(vault.getSafe(), "", " ")
if err != nil {
panic(err)
}

View File

@ -29,7 +29,7 @@ func (v *Versioner) RemoveOldVersions() error {
}
// RemoveOtherVersions removes all but the specific provided app version.
func (v *Versioner) RemoveOtherVersions(versionToKeep *semver.Version) error {
func (v *Versioner) RemoveOtherVersions(_ *semver.Version) error {
// darwin does not use the versioner; removal is a noop.
return nil
}