forked from Silverfish/proton-bridge
chore: merge release/Rialto into devel
This commit is contained in:
@ -19,14 +19,12 @@ package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
@ -160,9 +158,6 @@ func New() *cli.App {
|
||||
}
|
||||
|
||||
func run(c *cli.Context) error {
|
||||
// Seed the default RNG from the math/rand package.
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Get the current bridge version.
|
||||
version, err := semver.NewVersion(constants.Version)
|
||||
if err != nil {
|
||||
|
||||
@ -29,7 +29,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon"
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
imapEvents "github.com/ProtonMail/gluon/events"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
@ -45,7 +44,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@ -67,13 +65,7 @@ type Bridge struct {
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// imapServer is the bridge's IMAP server.
|
||||
imapServer *gluon.Server
|
||||
imapListener net.Listener
|
||||
imapEventCh chan imapEvents.Event
|
||||
|
||||
// smtpServer is the bridge's SMTP server.
|
||||
smtpServer *smtp.Server
|
||||
smtpListener net.Listener
|
||||
imapEventCh chan imapEvents.Event
|
||||
|
||||
// updater is the bridge's updater.
|
||||
updater Updater
|
||||
@ -134,6 +126,8 @@ type Bridge struct {
|
||||
goHeartbeat func()
|
||||
|
||||
uidValidityGenerator imap.UIDValidityGenerator
|
||||
|
||||
serverManager *ServerManager
|
||||
}
|
||||
|
||||
// New creates a new bridge.
|
||||
@ -224,16 +218,6 @@ func newBridge(
|
||||
return nil, fmt.Errorf("failed to load TLS config: %w", err)
|
||||
}
|
||||
|
||||
gluonCacheDir, err := getGluonDir(vault)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get Gluon directory: %w", err)
|
||||
}
|
||||
|
||||
gluonDataDir, err := locator.ProvideGluonDataPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get Gluon Database directory: %w", err)
|
||||
}
|
||||
|
||||
firstStart := vault.GetFirstStart()
|
||||
if err := vault.SetFirstStart(false); err != nil {
|
||||
return nil, fmt.Errorf("failed to save first start indicator: %w", err)
|
||||
@ -246,23 +230,6 @@ func newBridge(
|
||||
|
||||
identifier.SetClientString(vault.GetLastUserAgent())
|
||||
|
||||
imapServer, err := newIMAPServer(
|
||||
gluonCacheDir,
|
||||
gluonDataDir,
|
||||
curVersion,
|
||||
tlsConfig,
|
||||
reporter,
|
||||
logIMAPClient,
|
||||
logIMAPServer,
|
||||
imapEventCh,
|
||||
tasks,
|
||||
uidValidityGenerator,
|
||||
panicHandler,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IMAP server: %w", err)
|
||||
}
|
||||
|
||||
focusService, err := focus.NewService(locator, curVersion, panicHandler)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create focus service: %w", err)
|
||||
@ -279,7 +246,6 @@ func newBridge(
|
||||
identifier: identifier,
|
||||
|
||||
tlsConfig: tlsConfig,
|
||||
imapServer: imapServer,
|
||||
imapEventCh: imapEventCh,
|
||||
|
||||
updater: updater,
|
||||
@ -306,9 +272,13 @@ func newBridge(
|
||||
tasks: tasks,
|
||||
|
||||
uidValidityGenerator: uidValidityGenerator,
|
||||
|
||||
serverManager: newServerManager(),
|
||||
}
|
||||
|
||||
bridge.smtpServer = newSMTPServer(bridge, tlsConfig, logSMTP)
|
||||
if err := bridge.serverManager.Init(bridge); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bridge, nil
|
||||
}
|
||||
@ -381,10 +351,6 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
|
||||
})
|
||||
})
|
||||
|
||||
// We need to load users before we can start the IMAP and SMTP servers.
|
||||
// We must only start the servers once.
|
||||
var once sync.Once
|
||||
|
||||
// Attempt to load users from the vault when triggered.
|
||||
bridge.goLoad = bridge.tasks.Trigger(func(ctx context.Context) {
|
||||
if err := bridge.loadUsers(ctx); err != nil {
|
||||
@ -396,17 +362,6 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
|
||||
}
|
||||
|
||||
bridge.publish(events.AllUsersLoaded{})
|
||||
|
||||
// Once all users have been loaded, start the bridge's IMAP and SMTP servers.
|
||||
once.Do(func() {
|
||||
if err := bridge.serveIMAP(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to start IMAP server")
|
||||
}
|
||||
|
||||
if err := bridge.serveSMTP(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to start SMTP server")
|
||||
}
|
||||
})
|
||||
})
|
||||
defer bridge.goLoad()
|
||||
|
||||
@ -452,18 +407,13 @@ func (bridge *Bridge) GetErrors() []error {
|
||||
func (bridge *Bridge) Close(ctx context.Context) {
|
||||
logrus.Info("Closing bridge")
|
||||
|
||||
// Close the IMAP server.
|
||||
if err := bridge.closeIMAP(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close IMAP server")
|
||||
}
|
||||
|
||||
// Close the SMTP server.
|
||||
if err := bridge.closeSMTP(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close SMTP server")
|
||||
// Close the servers
|
||||
if err := bridge.serverManager.CloseServers(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close servers")
|
||||
}
|
||||
|
||||
// Close all users.
|
||||
safe.RLock(func() {
|
||||
safe.Lock(func() {
|
||||
for _, user := range bridge.users {
|
||||
user.Close()
|
||||
}
|
||||
|
||||
@ -50,7 +50,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/tests"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
imapid "github.com/emersion/go-imap-id"
|
||||
"github.com/emersion/go-imap/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -173,11 +172,27 @@ func TestBridge_UserAgent(t *testing.T) {
|
||||
|
||||
func TestBridge_UserAgent_Persistence(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
otherPassword := []byte("bar")
|
||||
otherUser := "foo"
|
||||
_, _, err := s.CreateUser(otherUser, otherPassword)
|
||||
require.NoError(t, err)
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(b)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(b)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil)))
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
currentUserAgent := b.GetCurrentUserAgent()
|
||||
require.Contains(t, currentUserAgent, vault.DefaultUserAgent)
|
||||
|
||||
imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = imapClient.Logout() }()
|
||||
|
||||
@ -220,8 +235,24 @@ func TestBridge_UserAgentFromIMAPID(t *testing.T) {
|
||||
calls = append(calls, call)
|
||||
})
|
||||
|
||||
otherPassword := []byte("bar")
|
||||
otherUser := "foo"
|
||||
_, _, err := s.CreateUser(otherUser, otherPassword)
|
||||
require.NoError(t, err)
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapClient, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
imapWaiter := waitForIMAPServerReady(b)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(b)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
require.NoError(t, getErr(b.LoginFull(ctx, otherUser, otherPassword, nil, nil)))
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = imapClient.Logout() }()
|
||||
|
||||
@ -592,10 +623,22 @@ func TestBridge_InitGluonDirectory(t *testing.T) {
|
||||
func TestBridge_LoginFailed(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(bridge)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
failCh, done := chToType[events.Event, events.IMAPLoginFailed](bridge.GetEvents(events.IMAPLoginFailed{}))
|
||||
defer done()
|
||||
|
||||
imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
imapClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Error(t, imapClient.Login("badUser", "badPass"))
|
||||
@ -622,6 +665,12 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) {
|
||||
configDir, err := b.GetGluonDataDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
imapWaiter := waitForIMAPServerReady(b)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(b)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
// Login the user.
|
||||
syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{}))
|
||||
defer done()
|
||||
@ -655,7 +704,10 @@ func TestBridge_ChangeCacheDirectory(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
@ -695,7 +747,7 @@ func TestBridge_ChangeAddressOrder(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
@ -716,7 +768,7 @@ func TestBridge_ChangeAddressOrder(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
@ -778,6 +830,7 @@ func withBridgeNoMocks(
|
||||
locator bridge.Locator,
|
||||
vaultKey []byte,
|
||||
tests func(*bridge.Bridge),
|
||||
waitOnServers bool,
|
||||
) {
|
||||
// Bridge will disable the proxy by default at startup.
|
||||
mocks.ProxyCtl.EXPECT().DisallowProxy()
|
||||
@ -828,14 +881,17 @@ func withBridgeNoMocks(
|
||||
|
||||
// Wait for bridge to finish loading users.
|
||||
waitForEvent(t, eventCh, events.AllUsersLoaded{})
|
||||
// Wait for bridge to start the IMAP server.
|
||||
waitForEvent(t, eventCh, events.IMAPServerReady{})
|
||||
// Wait for bridge to start the SMTP server.
|
||||
waitForEvent(t, eventCh, events.SMTPServerReady{})
|
||||
|
||||
// Set random IMAP and SMTP ports for the tests.
|
||||
require.NoError(t, bridge.SetIMAPPort(0))
|
||||
require.NoError(t, bridge.SetSMTPPort(0))
|
||||
require.NoError(t, bridge.SetIMAPPort(ctx, 0))
|
||||
require.NoError(t, bridge.SetSMTPPort(ctx, 0))
|
||||
|
||||
if waitOnServers {
|
||||
// Wait for bridge to start the IMAP server.
|
||||
waitForEvent(t, eventCh, events.IMAPServerReady{})
|
||||
// Wait for bridge to start the SMTP server.
|
||||
waitForEvent(t, eventCh, events.SMTPServerReady{})
|
||||
}
|
||||
|
||||
// Close the bridge when done.
|
||||
defer bridge.Close(ctx)
|
||||
@ -857,7 +913,24 @@ func withBridge(
|
||||
withMocks(t, func(mocks *bridge.Mocks) {
|
||||
withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) {
|
||||
tests(bridge, mocks)
|
||||
})
|
||||
}, false)
|
||||
})
|
||||
}
|
||||
|
||||
// withBridgeWaitForServers is the same as withBridge, but it will wait until IMAP & SMTP servers are ready.
|
||||
func withBridgeWaitForServers(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
apiURL string,
|
||||
netCtl *proton.NetCtl,
|
||||
locator bridge.Locator,
|
||||
vaultKey []byte,
|
||||
tests func(*bridge.Bridge, *bridge.Mocks),
|
||||
) {
|
||||
withMocks(t, func(mocks *bridge.Mocks) {
|
||||
withBridgeNoMocks(ctx, t, mocks, apiURL, netCtl, locator, vaultKey, func(bridge *bridge.Bridge) {
|
||||
tests(bridge, mocks)
|
||||
}, true)
|
||||
})
|
||||
}
|
||||
|
||||
@ -910,3 +983,48 @@ func chToType[In, Out any](inCh <-chan In, done func()) (<-chan Out, func()) {
|
||||
|
||||
return outCh, done
|
||||
}
|
||||
|
||||
type eventWaiter struct {
|
||||
evtCh <-chan events.Event
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func (e *eventWaiter) Done() {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
func (e *eventWaiter) Wait() {
|
||||
<-e.evtCh
|
||||
}
|
||||
|
||||
func waitForSMTPServerReady(b *bridge.Bridge) *eventWaiter {
|
||||
evtCh, cancel := b.GetEvents(events.SMTPServerReady{})
|
||||
return &eventWaiter{
|
||||
evtCh: evtCh,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func waitForSMTPServerStopped(b *bridge.Bridge) *eventWaiter {
|
||||
evtCh, cancel := b.GetEvents(events.SMTPServerStopped{})
|
||||
return &eventWaiter{
|
||||
evtCh: evtCh,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func waitForIMAPServerReady(b *bridge.Bridge) *eventWaiter {
|
||||
evtCh, cancel := b.GetEvents(events.IMAPServerReady{})
|
||||
return &eventWaiter{
|
||||
evtCh: evtCh,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func waitForIMAPServerStopped(b *bridge.Bridge) *eventWaiter {
|
||||
evtCh, cancel := b.GetEvents(events.IMAPServerStopped{})
|
||||
return &eventWaiter{
|
||||
evtCh: evtCh,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/clientconfig"
|
||||
@ -31,7 +32,7 @@ import (
|
||||
|
||||
// ConfigureAppleMail configures apple mail for the given userID and address.
|
||||
// If configuring apple mail for Catalina or newer, it ensures Bridge is using SSL.
|
||||
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||
func (bridge *Bridge) ConfigureAppleMail(ctx context.Context, userID, address string) error {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"userID": userID,
|
||||
"address": logging.Sensitive(address),
|
||||
@ -56,7 +57,7 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||
}
|
||||
|
||||
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
|
||||
if err := bridge.SetSMTPSSL(true); err != nil {
|
||||
if err := bridge.SetSMTPSSL(ctx, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,11 +58,7 @@ func moveFile(from, to string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(from, to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return os.Rename(from, to)
|
||||
}
|
||||
|
||||
func copyDir(from, to string) error {
|
||||
|
||||
@ -20,7 +20,6 @@ package bridge
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -37,203 +36,21 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (bridge *Bridge) serveIMAP() error {
|
||||
port, err := func() (int, error) {
|
||||
if bridge.imapServer == nil {
|
||||
return 0, fmt.Errorf("no IMAP server instance running")
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"port": bridge.vault.GetIMAPPort(),
|
||||
"ssl": bridge.vault.GetIMAPSSL(),
|
||||
}).Info("Starting IMAP server")
|
||||
|
||||
imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create IMAP listener: %w", err)
|
||||
}
|
||||
|
||||
bridge.imapListener = imapListener
|
||||
|
||||
if err := bridge.imapServer.Serve(context.Background(), bridge.imapListener); err != nil {
|
||||
return 0, fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil {
|
||||
return 0, fmt.Errorf("failed to store IMAP port in vault: %w", err)
|
||||
}
|
||||
|
||||
return getPort(imapListener.Addr()), nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
bridge.publish(events.IMAPServerError{
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.publish(events.IMAPServerReady{
|
||||
Port: port,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) restartIMAP() error {
|
||||
logrus.Info("Restarting IMAP server")
|
||||
|
||||
if bridge.imapListener != nil {
|
||||
if err := bridge.imapListener.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP listener: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.IMAPServerStopped{})
|
||||
}
|
||||
|
||||
return bridge.serveIMAP()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) closeIMAP(ctx context.Context) error {
|
||||
logrus.Info("Closing IMAP server")
|
||||
|
||||
if bridge.imapServer != nil {
|
||||
if err := bridge.imapServer.Close(ctx); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP server: %w", err)
|
||||
}
|
||||
|
||||
bridge.imapServer = nil
|
||||
}
|
||||
|
||||
if bridge.imapListener != nil {
|
||||
if err := bridge.imapListener.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
bridge.publish(events.IMAPServerStopped{})
|
||||
|
||||
return nil
|
||||
func (bridge *Bridge) restartIMAP(ctx context.Context) error {
|
||||
return bridge.serverManager.RestartIMAP(ctx)
|
||||
}
|
||||
|
||||
// addIMAPUser connects the given user to gluon.
|
||||
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
||||
if bridge.imapServer == nil {
|
||||
return fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IMAP connectors: %w", err)
|
||||
}
|
||||
|
||||
for addrID, imapConn := range imapConn {
|
||||
log := logrus.WithFields(logrus.Fields{
|
||||
"userID": user.ID(),
|
||||
"addrID": addrID,
|
||||
})
|
||||
|
||||
if gluonID, ok := user.GetGluonID(addrID); ok {
|
||||
log.WithField("gluonID", gluonID).Info("Loading existing IMAP user")
|
||||
|
||||
// Load the user, checking whether the DB was newly created.
|
||||
isNew, err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if isNew {
|
||||
// If the DB was newly created, clear the sync status; gluon's DB was not found.
|
||||
logrus.Warn("IMAP user DB was newly created, clearing sync status")
|
||||
|
||||
// Remove the user from IMAP so we can clear the sync status.
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||
}
|
||||
|
||||
// Clear the sync status -- we need to resync all messages.
|
||||
if err := user.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
// Add the user back to the IMAP server.
|
||||
if isNew, err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
} else if isNew {
|
||||
panic("IMAP user should already have a database")
|
||||
}
|
||||
} else if status := user.GetSyncStatus(); !status.HasLabels {
|
||||
// Otherwise, the DB already exists -- if the labels are not yet synced, we need to re-create the DB.
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
return fmt.Errorf("failed to remove old IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to remove old IMAP user ID: %w", err)
|
||||
}
|
||||
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||
}
|
||||
|
||||
log.WithField("gluonID", gluonID).Info("Re-created IMAP user")
|
||||
}
|
||||
} else {
|
||||
log.Info("Creating new IMAP user")
|
||||
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||
}
|
||||
|
||||
log.WithField("gluonID", gluonID).Info("Created new IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger a sync for the user, if needed.
|
||||
user.TriggerSync()
|
||||
|
||||
return nil
|
||||
return bridge.serverManager.AddIMAPUser(ctx, user)
|
||||
}
|
||||
|
||||
// removeIMAPUser disconnects the given user from gluon, optionally also removing its files.
|
||||
func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withData bool) error {
|
||||
if bridge.imapServer == nil {
|
||||
return fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"userID": user.ID(),
|
||||
"withData": withData,
|
||||
}).Debug("Removing IMAP user")
|
||||
|
||||
for addrID, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withData); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if withData {
|
||||
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user ID: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return bridge.serverManager.RemoveIMAPUser(ctx, user, withData)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
|
||||
@ -262,19 +79,12 @@ func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"sessionID": event.SessionID,
|
||||
"username": event.Username,
|
||||
}).Info("Received IMAP login failure notification")
|
||||
"pkg": "imap",
|
||||
}).Error("Incorrect login credentials.")
|
||||
bridge.publish(events.IMAPLoginFailed{Username: event.Username})
|
||||
}
|
||||
}
|
||||
|
||||
func getGluonDir(encVault *vault.Vault) (string, error) {
|
||||
if err := os.MkdirAll(encVault.GetGluonCacheDir(), 0o700); err != nil {
|
||||
return "", fmt.Errorf("failed to create gluon dir: %w", err)
|
||||
}
|
||||
|
||||
return encVault.GetGluonCacheDir(), nil
|
||||
}
|
||||
|
||||
func ApplyGluonCachePathSuffix(basePath string) string {
|
||||
return filepath.Join(basePath, "backend", "store")
|
||||
}
|
||||
|
||||
@ -144,13 +144,13 @@ func (testUpdater *TestUpdater) SetLatestVersion(version, minAuto *semver.Versio
|
||||
}
|
||||
}
|
||||
|
||||
func (testUpdater *TestUpdater) GetVersionInfo(ctx context.Context, downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error) {
|
||||
func (testUpdater *TestUpdater) GetVersionInfo(_ context.Context, _ updater.Downloader, _ updater.Channel) (updater.VersionInfo, error) {
|
||||
testUpdater.lock.RLock()
|
||||
defer testUpdater.lock.RUnlock()
|
||||
|
||||
return testUpdater.latest, nil
|
||||
}
|
||||
|
||||
func (testUpdater *TestUpdater) InstallUpdate(ctx context.Context, downloader updater.Downloader, update updater.VersionInfo) error {
|
||||
func (testUpdater *TestUpdater) InstallUpdate(_ context.Context, _ updater.Downloader, _ updater.VersionInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -28,7 +28,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/bradenaw/juniper/iterator"
|
||||
"github.com/emersion/go-imap/client"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -66,7 +65,7 @@ func TestBridge_Refresh(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
@ -99,7 +98,7 @@ func TestBridge_Refresh(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
|
||||
@ -34,7 +34,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/emersion/go-imap"
|
||||
"github.com/emersion/go-imap/client"
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -46,12 +45,17 @@ func TestBridge_Send(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
recipientUserID, err := bridge.LoginFull(ctx, "recipient", password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
smtpWaiter.Wait()
|
||||
|
||||
senderInfo, err := bridge.GetUserInfo(senderUserID)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -91,13 +95,13 @@ func TestBridge_Send(t *testing.T) {
|
||||
}
|
||||
|
||||
// Connect the sender IMAP client.
|
||||
senderIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
senderIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, senderIMAPClient.Login(senderInfo.Addresses[0], string(senderInfo.BridgePass)))
|
||||
defer senderIMAPClient.Logout() //nolint:errcheck
|
||||
|
||||
// Connect the recipient IMAP client.
|
||||
recipientIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
recipientIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, recipientIMAPClient.Login(recipientInfo.Addresses[0], string(recipientInfo.BridgePass)))
|
||||
defer recipientIMAPClient.Logout() //nolint:errcheck
|
||||
@ -135,13 +139,13 @@ func TestBridge_SendDraftFlags(t *testing.T) {
|
||||
})
|
||||
|
||||
// Start the bridge.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
// Get the sender user info.
|
||||
userInfo, err := bridge.QueryUserInfo(username)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect the sender IMAP client.
|
||||
imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
imapClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, imapClient.Login(userInfo.Addresses[0], string(userInfo.BridgePass)))
|
||||
defer imapClient.Logout() //nolint:errcheck
|
||||
@ -245,13 +249,13 @@ func TestBridge_SendInvite(t *testing.T) {
|
||||
})
|
||||
|
||||
// Start the bridge.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
// Get the sender user info.
|
||||
userInfo, err := bridge.QueryUserInfo(username)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect the sender IMAP client.
|
||||
imapClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
imapClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, imapClient.Login(userInfo.Addresses[0], string(userInfo.BridgePass)))
|
||||
defer imapClient.Logout() //nolint:errcheck
|
||||
@ -401,6 +405,9 @@ SGVsbG8gd29ybGQK
|
||||
require.NoError(t, err)
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
senderUserID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -420,6 +427,8 @@ SGVsbG8gd29ybGQK
|
||||
messageMultipartWithoutTextWithTextAttachment,
|
||||
}
|
||||
|
||||
smtpWaiter.Wait()
|
||||
|
||||
for _, m := range messages {
|
||||
// Dial the server.
|
||||
client, err := smtp.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetSMTPPort())))
|
||||
@ -444,13 +453,13 @@ SGVsbG8gd29ybGQK
|
||||
}
|
||||
|
||||
// Connect the sender IMAP client.
|
||||
senderIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
senderIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, senderIMAPClient.Login(senderInfo.Addresses[0], string(senderInfo.BridgePass)))
|
||||
defer senderIMAPClient.Logout() //nolint:errcheck
|
||||
|
||||
// Connect the recipient IMAP client.
|
||||
recipientIMAPClient, err := client.Dial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
recipientIMAPClient, err := eventuallyDial(net.JoinHostPort(constants.Host, fmt.Sprint(bridge.GetIMAPPort())))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, recipientIMAPClient.Login(recipientInfo.Addresses[0], string(recipientInfo.BridgePass)))
|
||||
defer recipientIMAPClient.Logout() //nolint:errcheck
|
||||
|
||||
@ -36,6 +36,9 @@ import (
|
||||
func TestBridge_Report(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(b)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{}))
|
||||
defer done()
|
||||
|
||||
@ -51,6 +54,8 @@ func TestBridge_Report(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
imapWaiter.Wait()
|
||||
|
||||
// Dial the IMAP port.
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
|
||||
696
internal/bridge/server_manager.go
Normal file
696
internal/bridge/server_manager.go
Normal file
@ -0,0 +1,696 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ProtonMail/gluon"
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ServerManager manages the IMAP & SMTP servers and their listeners.
|
||||
type ServerManager struct {
|
||||
requests *cpc.CPC
|
||||
|
||||
imapServer *gluon.Server
|
||||
imapListener net.Listener
|
||||
|
||||
smtpServer *smtp.Server
|
||||
smtpListener net.Listener
|
||||
|
||||
loadedUserCount int
|
||||
}
|
||||
|
||||
func newServerManager() *ServerManager {
|
||||
return &ServerManager{
|
||||
requests: cpc.NewCPC(),
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ServerManager) Init(bridge *Bridge) error {
|
||||
imapServer, err := createIMAPServer(bridge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
smtpServer := createSMTPServer(bridge)
|
||||
|
||||
sm.imapServer = imapServer
|
||||
sm.smtpServer = smtpServer
|
||||
|
||||
bridge.tasks.Once(func(ctx context.Context) {
|
||||
logging.DoAnnotated(ctx, func(ctx context.Context) {
|
||||
sm.run(ctx, bridge)
|
||||
}, logging.Labels{
|
||||
"service": "server-manager",
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) CloseServers(ctx context.Context) error {
|
||||
defer sm.requests.Close()
|
||||
_, err := sm.requests.Send(ctx, &smRequestClose{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) RestartIMAP(ctx context.Context) error {
|
||||
_, err := sm.requests.Send(ctx, &smRequestRestartIMAP{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) RestartSMTP(ctx context.Context) error {
|
||||
_, err := sm.requests.Send(ctx, &smRequestRestartSMTP{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) AddIMAPUser(ctx context.Context, user *user.User) error {
|
||||
_, err := sm.requests.Send(ctx, &smRequestAddIMAPUser{user: user})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error {
|
||||
_, err := sm.requests.Send(ctx, &smRequestRemoveIMAPUser{
|
||||
user: user,
|
||||
withData: withData,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) SetGluonDir(ctx context.Context, gluonDir string) error {
|
||||
_, err := sm.requests.Send(ctx, &smRequestSetGluonDir{
|
||||
dir: gluonDir,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) AddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) {
|
||||
reply, err := cpc.SendTyped[string](ctx, sm.requests, &smRequestAddGluonUser{
|
||||
conn: conn,
|
||||
passphrase: passphrase,
|
||||
})
|
||||
|
||||
return reply, err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) RemoveGluonUser(ctx context.Context, gluonID string) error {
|
||||
_, err := sm.requests.Send(ctx, &smRequestRemoveGluonUser{
|
||||
userID: gluonID,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
|
||||
eventCh, cancel := bridge.GetEvents()
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
sm.handleClose(ctx, bridge)
|
||||
return
|
||||
|
||||
case evt := <-eventCh:
|
||||
switch evt.(type) {
|
||||
case events.ConnStatusDown:
|
||||
logrus.Info("Server Manager, network down stopping listeners")
|
||||
if err := sm.closeSMTPServer(bridge); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close SMTP server")
|
||||
}
|
||||
|
||||
if err := sm.stopIMAPListener(bridge); err != nil {
|
||||
logrus.WithError(err)
|
||||
}
|
||||
case events.ConnStatusUp:
|
||||
logrus.Info("Server Manager, network up starting listeners")
|
||||
sm.handleLoadedUserCountChange(ctx, bridge)
|
||||
}
|
||||
|
||||
case request, ok := <-sm.requests.ReceiveCh():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch r := request.Value().(type) {
|
||||
case *smRequestClose:
|
||||
sm.handleClose(ctx, bridge)
|
||||
request.Reply(ctx, nil, nil)
|
||||
return
|
||||
|
||||
case *smRequestRestartSMTP:
|
||||
err := sm.restartSMTP(bridge)
|
||||
request.Reply(ctx, nil, err)
|
||||
|
||||
case *smRequestRestartIMAP:
|
||||
err := sm.restartIMAP(ctx, bridge)
|
||||
request.Reply(ctx, nil, err)
|
||||
|
||||
case *smRequestAddIMAPUser:
|
||||
err := sm.handleAddIMAPUser(ctx, r.user)
|
||||
request.Reply(ctx, nil, err)
|
||||
if err == nil {
|
||||
sm.loadedUserCount++
|
||||
sm.handleLoadedUserCountChange(ctx, bridge)
|
||||
}
|
||||
|
||||
case *smRequestRemoveIMAPUser:
|
||||
err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData)
|
||||
request.Reply(ctx, nil, err)
|
||||
if err == nil {
|
||||
sm.loadedUserCount--
|
||||
sm.handleLoadedUserCountChange(ctx, bridge)
|
||||
}
|
||||
|
||||
case *smRequestSetGluonDir:
|
||||
err := sm.handleSetGluonDir(ctx, bridge, r.dir)
|
||||
request.Reply(ctx, nil, err)
|
||||
|
||||
case *smRequestAddGluonUser:
|
||||
id, err := sm.handleAddGluonUser(ctx, r.conn, r.passphrase)
|
||||
request.Reply(ctx, id, err)
|
||||
|
||||
case *smRequestRemoveGluonUser:
|
||||
err := sm.handleRemoveGluonUser(ctx, r.userID)
|
||||
request.Reply(ctx, nil, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ServerManager) handleLoadedUserCountChange(ctx context.Context, bridge *Bridge) {
|
||||
logrus.Infof("Validating Listener State %v", sm.loadedUserCount)
|
||||
if sm.shouldStartServers() {
|
||||
if sm.imapListener == nil {
|
||||
if err := sm.serveIMAP(ctx, bridge); err != nil {
|
||||
logrus.WithError(err).Error("Failed to start IMAP server")
|
||||
}
|
||||
}
|
||||
|
||||
if sm.smtpListener == nil {
|
||||
if err := sm.restartSMTP(bridge); err != nil {
|
||||
logrus.WithError(err).Error("Failed to start SMTP server")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if sm.imapListener != nil {
|
||||
if err := sm.stopIMAPListener(bridge); err != nil {
|
||||
logrus.WithError(err).Error("Failed to stop IMAP server")
|
||||
}
|
||||
}
|
||||
|
||||
if sm.smtpListener != nil {
|
||||
if err := sm.closeSMTPServer(bridge); err != nil {
|
||||
logrus.WithError(err).Error("Failed to stop SMTP server")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ServerManager) handleClose(ctx context.Context, bridge *Bridge) {
|
||||
// Close the IMAP server.
|
||||
if err := sm.closeIMAPServer(ctx, bridge); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close IMAP server")
|
||||
}
|
||||
|
||||
// Close the SMTP server.
|
||||
if err := sm.closeSMTPServer(bridge); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close SMTP server")
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ServerManager) handleAddIMAPUser(ctx context.Context, user *user.User) error {
|
||||
if sm.imapServer == nil {
|
||||
return fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IMAP connectors: %w", err)
|
||||
}
|
||||
|
||||
for addrID, imapConn := range imapConn {
|
||||
log := logrus.WithFields(logrus.Fields{
|
||||
"userID": user.ID(),
|
||||
"addrID": addrID,
|
||||
})
|
||||
|
||||
if gluonID, ok := user.GetGluonID(addrID); ok {
|
||||
log.WithField("gluonID", gluonID).Info("Loading existing IMAP user")
|
||||
|
||||
// Load the user, checking whether the DB was newly created.
|
||||
isNew, err := sm.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if isNew {
|
||||
// If the DB was newly created, clear the sync status; gluon's DB was not found.
|
||||
logrus.Warn("IMAP user DB was newly created, clearing sync status")
|
||||
|
||||
// Remove the user from IMAP so we can clear the sync status.
|
||||
if err := sm.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||
}
|
||||
|
||||
// Clear the sync status -- we need to resync all messages.
|
||||
if err := user.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
// Add the user back to the IMAP server.
|
||||
if isNew, err := sm.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
} else if isNew {
|
||||
panic("IMAP user should already have a database")
|
||||
}
|
||||
} else if status := user.GetSyncStatus(); !status.HasLabels {
|
||||
// Otherwise, the DB already exists -- if the labels are not yet synced, we need to re-create the DB.
|
||||
if err := sm.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
return fmt.Errorf("failed to remove old IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to remove old IMAP user ID: %w", err)
|
||||
}
|
||||
|
||||
gluonID, err := sm.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||
}
|
||||
|
||||
log.WithField("gluonID", gluonID).Info("Re-created IMAP user")
|
||||
}
|
||||
} else {
|
||||
log.Info("Creating new IMAP user")
|
||||
|
||||
gluonID, err := sm.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if err := user.SetGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to set IMAP user ID: %w", err)
|
||||
}
|
||||
|
||||
log.WithField("gluonID", gluonID).Info("Created new IMAP user")
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger a sync for the user, if needed.
|
||||
user.TriggerSync()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) handleRemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error {
|
||||
if sm.imapServer == nil {
|
||||
return fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"userID": user.ID(),
|
||||
"withData": withData,
|
||||
}).Debug("Removing IMAP user")
|
||||
|
||||
for addrID, gluonID := range user.GetGluonIDs() {
|
||||
if err := sm.imapServer.RemoveUser(ctx, gluonID, withData); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user: %w", err)
|
||||
}
|
||||
|
||||
if withData {
|
||||
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to remove IMAP user ID: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createIMAPServer(bridge *Bridge) (*gluon.Server, error) {
|
||||
gluonDataDir, err := bridge.GetGluonDataDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get Gluon Database directory: %w", err)
|
||||
}
|
||||
|
||||
return newIMAPServer(
|
||||
bridge.vault.GetGluonCacheDir(),
|
||||
gluonDataDir,
|
||||
bridge.curVersion,
|
||||
bridge.tlsConfig,
|
||||
bridge.reporter,
|
||||
bridge.logIMAPClient,
|
||||
bridge.logIMAPServer,
|
||||
bridge.imapEventCh,
|
||||
bridge.tasks,
|
||||
bridge.uidValidityGenerator,
|
||||
bridge.panicHandler,
|
||||
)
|
||||
}
|
||||
|
||||
func createSMTPServer(bridge *Bridge) *smtp.Server {
|
||||
return newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
|
||||
}
|
||||
|
||||
func (sm *ServerManager) closeSMTPServer(bridge *Bridge) error {
|
||||
// We close the listener ourselves even though it's also closed by smtpServer.Close().
|
||||
// This is because smtpServer.Serve() is called in a separate goroutine and might be executed
|
||||
// after we've already closed the server. However, go-smtp has a bug; it blocks on the listener
|
||||
// even after the server has been closed. So we close the listener ourselves to unblock it.
|
||||
|
||||
if sm.smtpListener != nil {
|
||||
logrus.Info("Closing SMTP Listener")
|
||||
if err := sm.smtpListener.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close SMTP listener: %w", err)
|
||||
}
|
||||
|
||||
sm.smtpListener = nil
|
||||
}
|
||||
|
||||
if sm.smtpServer != nil {
|
||||
logrus.Info("Closing SMTP server")
|
||||
if err := sm.smtpServer.Close(); err != nil {
|
||||
logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)")
|
||||
}
|
||||
|
||||
sm.smtpServer = nil
|
||||
|
||||
bridge.publish(events.SMTPServerStopped{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) error {
|
||||
if sm.imapListener != nil {
|
||||
logrus.Info("Closing IMAP Listener")
|
||||
|
||||
if err := sm.imapListener.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP listener: %w", err)
|
||||
}
|
||||
|
||||
sm.imapListener = nil
|
||||
|
||||
bridge.publish(events.IMAPServerStopped{})
|
||||
}
|
||||
|
||||
if sm.imapServer != nil {
|
||||
logrus.Info("Closing IMAP server")
|
||||
if err := sm.imapServer.Close(ctx); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP server: %w", err)
|
||||
}
|
||||
|
||||
sm.imapServer = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) restartIMAP(ctx context.Context, bridge *Bridge) error {
|
||||
logrus.Info("Restarting IMAP server")
|
||||
|
||||
if sm.imapListener != nil {
|
||||
if err := sm.imapListener.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP listener: %w", err)
|
||||
}
|
||||
|
||||
sm.imapListener = nil
|
||||
|
||||
bridge.publish(events.IMAPServerStopped{})
|
||||
}
|
||||
|
||||
if sm.shouldStartServers() {
|
||||
return sm.serveIMAP(ctx, bridge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) restartSMTP(bridge *Bridge) error {
|
||||
logrus.Info("Restarting SMTP server")
|
||||
|
||||
if err := sm.closeSMTPServer(bridge); err != nil {
|
||||
return fmt.Errorf("failed to close SMTP: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.SMTPServerStopped{})
|
||||
|
||||
sm.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
|
||||
|
||||
if sm.shouldStartServers() {
|
||||
return sm.serveSMTP(bridge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) serveSMTP(bridge *Bridge) error {
|
||||
port, err := func() (int, error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"port": bridge.vault.GetSMTPPort(),
|
||||
"ssl": bridge.vault.GetSMTPSSL(),
|
||||
}).Info("Starting SMTP server")
|
||||
|
||||
smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create SMTP listener: %w", err)
|
||||
}
|
||||
|
||||
sm.smtpListener = smtpListener
|
||||
|
||||
bridge.tasks.Once(func(context.Context) {
|
||||
if err := sm.smtpServer.Serve(smtpListener); err != nil {
|
||||
logrus.WithError(err).Info("SMTP server stopped")
|
||||
}
|
||||
})
|
||||
|
||||
if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil {
|
||||
return 0, fmt.Errorf("failed to store SMTP port in vault: %w", err)
|
||||
}
|
||||
|
||||
return getPort(smtpListener.Addr()), nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
bridge.publish(events.SMTPServerError{
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.publish(events.SMTPServerReady{
|
||||
Port: port,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) serveIMAP(ctx context.Context, bridge *Bridge) error {
|
||||
port, err := func() (int, error) {
|
||||
if sm.imapServer == nil {
|
||||
return 0, fmt.Errorf("no IMAP server instance running")
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"port": bridge.vault.GetIMAPPort(),
|
||||
"ssl": bridge.vault.GetIMAPSSL(),
|
||||
}).Info("Starting IMAP server")
|
||||
|
||||
imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create IMAP listener: %w", err)
|
||||
}
|
||||
|
||||
sm.imapListener = imapListener
|
||||
|
||||
if err := sm.imapServer.Serve(ctx, sm.imapListener); err != nil {
|
||||
return 0, fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetIMAPPort(getPort(imapListener.Addr())); err != nil {
|
||||
return 0, fmt.Errorf("failed to store IMAP port in vault: %w", err)
|
||||
}
|
||||
|
||||
return getPort(imapListener.Addr()), nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
bridge.publish(events.IMAPServerError{
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.publish(events.IMAPServerReady{
|
||||
Port: port,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) stopIMAPListener(bridge *Bridge) error {
|
||||
logrus.Info("Stopping IMAP listener")
|
||||
if sm.imapListener != nil {
|
||||
if err := sm.imapListener.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sm.imapListener = nil
|
||||
|
||||
bridge.publish(events.IMAPServerStopped{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge, newGluonDir string) error {
|
||||
return safe.RLockRet(func() error {
|
||||
currentGluonDir := bridge.GetGluonCacheDir()
|
||||
newGluonDir = filepath.Join(newGluonDir, "gluon")
|
||||
if newGluonDir == currentGluonDir {
|
||||
return fmt.Errorf("new gluon dir is the same as the old one")
|
||||
}
|
||||
|
||||
if err := sm.closeIMAPServer(context.Background(), bridge); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP: %w", err)
|
||||
}
|
||||
|
||||
sm.loadedUserCount = 0
|
||||
|
||||
if err := bridge.moveGluonCacheDir(currentGluonDir, newGluonDir); err != nil {
|
||||
logrus.WithError(err).Error("failed to move GluonCacheDir")
|
||||
|
||||
if err := bridge.vault.SetGluonDir(currentGluonDir); err != nil {
|
||||
return fmt.Errorf("failed to revert GluonCacheDir: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
bridge.heartbeat.SetCacheLocation(newGluonDir)
|
||||
|
||||
gluonDataDir, err := bridge.GetGluonDataDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get Gluon Database directory: %w", err)
|
||||
}
|
||||
|
||||
imapServer, err := newIMAPServer(
|
||||
bridge.vault.GetGluonCacheDir(),
|
||||
gluonDataDir,
|
||||
bridge.curVersion,
|
||||
bridge.tlsConfig,
|
||||
bridge.reporter,
|
||||
bridge.logIMAPClient,
|
||||
bridge.logIMAPServer,
|
||||
bridge.imapEventCh,
|
||||
bridge.tasks,
|
||||
bridge.uidValidityGenerator,
|
||||
bridge.panicHandler,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new IMAP server: %w", err)
|
||||
}
|
||||
|
||||
sm.imapServer = imapServer
|
||||
for _, bridgeUser := range bridge.users {
|
||||
if err := sm.handleAddIMAPUser(ctx, bridgeUser); err != nil {
|
||||
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
|
||||
}
|
||||
sm.loadedUserCount++
|
||||
}
|
||||
|
||||
if sm.shouldStartServers() {
|
||||
if err := sm.serveIMAP(ctx, bridge); err != nil {
|
||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, bridge.usersLock)
|
||||
}
|
||||
|
||||
func (sm *ServerManager) handleAddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) {
|
||||
if sm.imapServer == nil {
|
||||
return "", fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
return sm.imapServer.AddUser(ctx, conn, passphrase)
|
||||
}
|
||||
|
||||
func (sm *ServerManager) handleRemoveGluonUser(ctx context.Context, userID string) error {
|
||||
if sm.imapServer == nil {
|
||||
return fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
return sm.imapServer.RemoveUser(ctx, userID, true)
|
||||
}
|
||||
|
||||
func (sm *ServerManager) shouldStartServers() bool {
|
||||
return sm.loadedUserCount >= 1
|
||||
}
|
||||
|
||||
type smRequestClose struct{}
|
||||
|
||||
type smRequestRestartIMAP struct{}
|
||||
|
||||
type smRequestRestartSMTP struct{}
|
||||
|
||||
type smRequestAddIMAPUser struct {
|
||||
user *user.User
|
||||
}
|
||||
|
||||
type smRequestRemoveIMAPUser struct {
|
||||
user *user.User
|
||||
withData bool
|
||||
}
|
||||
|
||||
type smRequestSetGluonDir struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
type smRequestAddGluonUser struct {
|
||||
conn connector.Connector
|
||||
passphrase []byte
|
||||
}
|
||||
|
||||
type smRequestRemoveGluonUser struct {
|
||||
userID string
|
||||
}
|
||||
179
internal/bridge/server_manager_test.go
Normal file
179
internal/bridge/server_manager_test.go
Normal file
@ -0,0 +1,179 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServerManager_NoLoadedUsersNoServers(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
_, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerManager_ServersStartAfterFirstConnectedUser(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(bridge)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerManager_ServersStopsAfterUserLogsOut(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(bridge)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
userID, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
imapWaiterStopped := waitForIMAPServerStopped(bridge)
|
||||
defer imapWaiterStopped.Done()
|
||||
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
|
||||
imapWaiterStopped.Wait()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerManager_ServersDoNotStopWhenThereIsStillOneActiveUser(t *testing.T) {
|
||||
otherPassword := []byte("bar")
|
||||
otherUser := "foo"
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
_, _, err := s.CreateUser(otherUser, otherPassword)
|
||||
require.NoError(t, err)
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(bridge)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userIDOther, err := bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
evtCh, cancel := bridge.GetEvents(events.UserDeauth{})
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, s.RevokeUser(userIDOther))
|
||||
|
||||
waitForEvent(t, evtCh, events.UserDeauth{})
|
||||
|
||||
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, imapClient.Logout())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerManager_ServersStartIfAtLeastOneUserIsLoggedIn(t *testing.T) {
|
||||
otherPassword := []byte("bar")
|
||||
otherUser := "foo"
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
userIDOther, _, err := s.CreateUser(otherUser, otherPassword)
|
||||
require.NoError(t, err)
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = bridge.LoginFull(ctx, otherUser, otherPassword, nil, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
require.NoError(t, s.RevokeUser(userIDOther))
|
||||
|
||||
withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapClient, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, imapClient.Logout())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerManager_NetworkLossStopsServers(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(bridge)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
imapWaiterStop := waitForIMAPServerStopped(bridge)
|
||||
defer imapWaiterStop.Done()
|
||||
|
||||
smtpWaiterStop := waitForSMTPServerStopped(bridge)
|
||||
defer smtpWaiterStop.Done()
|
||||
|
||||
_, err := bridge.LoginFull(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
netCtl.Disable()
|
||||
|
||||
imapWaiterStop.Wait()
|
||||
smtpWaiterStop.Wait()
|
||||
|
||||
netCtl.Enable()
|
||||
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -22,7 +22,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
@ -55,7 +54,7 @@ func (bridge *Bridge) GetIMAPPort() int {
|
||||
return bridge.vault.GetIMAPPort()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetIMAPPort(newPort int) error {
|
||||
func (bridge *Bridge) SetIMAPPort(ctx context.Context, newPort int) error {
|
||||
if newPort == bridge.vault.GetIMAPPort() {
|
||||
return nil
|
||||
}
|
||||
@ -66,14 +65,14 @@ func (bridge *Bridge) SetIMAPPort(newPort int) error {
|
||||
|
||||
bridge.heartbeat.SetIMAPPort(newPort)
|
||||
|
||||
return bridge.restartIMAP()
|
||||
return bridge.restartIMAP(ctx)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetIMAPSSL() bool {
|
||||
return bridge.vault.GetIMAPSSL()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetIMAPSSL(newSSL bool) error {
|
||||
func (bridge *Bridge) SetIMAPSSL(ctx context.Context, newSSL bool) error {
|
||||
if newSSL == bridge.vault.GetIMAPSSL() {
|
||||
return nil
|
||||
}
|
||||
@ -84,14 +83,14 @@ func (bridge *Bridge) SetIMAPSSL(newSSL bool) error {
|
||||
|
||||
bridge.heartbeat.SetIMAPConnectionMode(newSSL)
|
||||
|
||||
return bridge.restartIMAP()
|
||||
return bridge.restartIMAP(ctx)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetSMTPPort() int {
|
||||
return bridge.vault.GetSMTPPort()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetSMTPPort(newPort int) error {
|
||||
func (bridge *Bridge) SetSMTPPort(ctx context.Context, newPort int) error {
|
||||
if newPort == bridge.vault.GetSMTPPort() {
|
||||
return nil
|
||||
}
|
||||
@ -102,14 +101,14 @@ func (bridge *Bridge) SetSMTPPort(newPort int) error {
|
||||
|
||||
bridge.heartbeat.SetSMTPPort(newPort)
|
||||
|
||||
return bridge.restartSMTP()
|
||||
return bridge.restartSMTP(ctx)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetSMTPSSL() bool {
|
||||
return bridge.vault.GetSMTPSSL()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetSMTPSSL(newSSL bool) error {
|
||||
func (bridge *Bridge) SetSMTPSSL(ctx context.Context, newSSL bool) error {
|
||||
if newSSL == bridge.vault.GetSMTPSSL() {
|
||||
return nil
|
||||
}
|
||||
@ -120,7 +119,7 @@ func (bridge *Bridge) SetSMTPSSL(newSSL bool) error {
|
||||
|
||||
bridge.heartbeat.SetSMTPConnectionMode(newSSL)
|
||||
|
||||
return bridge.restartSMTP()
|
||||
return bridge.restartSMTP(ctx)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetGluonCacheDir() string {
|
||||
@ -132,63 +131,7 @@ func (bridge *Bridge) GetGluonDataDir() (string, error) {
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
|
||||
return safe.RLockRet(func() error {
|
||||
currentGluonDir := bridge.GetGluonCacheDir()
|
||||
newGluonDir = filepath.Join(newGluonDir, "gluon")
|
||||
if newGluonDir == currentGluonDir {
|
||||
return fmt.Errorf("new gluon dir is the same as the old one")
|
||||
}
|
||||
|
||||
if err := bridge.closeIMAP(context.Background()); err != nil {
|
||||
return fmt.Errorf("failed to close IMAP: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.moveGluonCacheDir(currentGluonDir, newGluonDir); err != nil {
|
||||
logrus.WithError(err).Error("failed to move GluonCacheDir")
|
||||
|
||||
if err := bridge.vault.SetGluonDir(currentGluonDir); err != nil {
|
||||
return fmt.Errorf("failed to revert GluonCacheDir: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
bridge.heartbeat.SetCacheLocation(newGluonDir)
|
||||
|
||||
gluonDataDir, err := bridge.GetGluonDataDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get Gluon Database directory: %w", err)
|
||||
}
|
||||
|
||||
imapServer, err := newIMAPServer(
|
||||
bridge.vault.GetGluonCacheDir(),
|
||||
gluonDataDir,
|
||||
bridge.curVersion,
|
||||
bridge.tlsConfig,
|
||||
bridge.reporter,
|
||||
bridge.logIMAPClient,
|
||||
bridge.logIMAPServer,
|
||||
bridge.imapEventCh,
|
||||
bridge.tasks,
|
||||
bridge.uidValidityGenerator,
|
||||
bridge.panicHandler,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new IMAP server: %w", err)
|
||||
}
|
||||
|
||||
bridge.imapServer = imapServer
|
||||
|
||||
for _, user := range bridge.users {
|
||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.serveIMAP(); err != nil {
|
||||
return fmt.Errorf("failed to serve IMAP: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, bridge.usersLock)
|
||||
return bridge.serverManager.SetGluonDir(ctx, newGluonDir)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) moveGluonCacheDir(oldGluonDir, newGluonDir string) error {
|
||||
|
||||
@ -57,7 +57,7 @@ func TestBridge_Settings_IMAPPort(t *testing.T) {
|
||||
curPort := bridge.GetIMAPPort()
|
||||
|
||||
// Set the port to 1144.
|
||||
require.NoError(t, bridge.SetIMAPPort(1144))
|
||||
require.NoError(t, bridge.SetIMAPPort(ctx, 1144))
|
||||
|
||||
// Get the new setting.
|
||||
require.Equal(t, 1144, bridge.GetIMAPPort())
|
||||
@ -75,7 +75,7 @@ func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
||||
require.False(t, bridge.GetIMAPSSL())
|
||||
|
||||
// Enable IMAP SSL.
|
||||
require.NoError(t, bridge.SetIMAPSSL(true))
|
||||
require.NoError(t, bridge.SetIMAPSSL(ctx, true))
|
||||
|
||||
// Get the new setting.
|
||||
require.True(t, bridge.GetIMAPSSL())
|
||||
@ -89,7 +89,7 @@ func TestBridge_Settings_SMTPPort(t *testing.T) {
|
||||
curPort := bridge.GetSMTPPort()
|
||||
|
||||
// Set the port to 1024.
|
||||
require.NoError(t, bridge.SetSMTPPort(1024))
|
||||
require.NoError(t, bridge.SetSMTPPort(ctx, 1024))
|
||||
|
||||
// Get the new setting.
|
||||
require.Equal(t, 1024, bridge.GetSMTPPort())
|
||||
@ -107,7 +107,7 @@ func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
||||
require.False(t, bridge.GetSMTPSSL())
|
||||
|
||||
// Enable SMTP SSL.
|
||||
require.NoError(t, bridge.SetSMTPSSL(true))
|
||||
require.NoError(t, bridge.SetSMTPSSL(ctx, true))
|
||||
|
||||
// Get the new setting.
|
||||
require.True(t, bridge.GetSMTPSSL())
|
||||
|
||||
@ -20,93 +20,16 @@ package bridge
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (bridge *Bridge) serveSMTP() error {
|
||||
port, err := func() (int, error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"port": bridge.vault.GetSMTPPort(),
|
||||
"ssl": bridge.vault.GetSMTPSSL(),
|
||||
}).Info("Starting SMTP server")
|
||||
|
||||
smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create SMTP listener: %w", err)
|
||||
}
|
||||
|
||||
bridge.smtpListener = smtpListener
|
||||
|
||||
bridge.tasks.Once(func(context.Context) {
|
||||
if err := bridge.smtpServer.Serve(smtpListener); err != nil {
|
||||
logrus.WithError(err).Info("SMTP server stopped")
|
||||
}
|
||||
})
|
||||
|
||||
if err := bridge.vault.SetSMTPPort(getPort(smtpListener.Addr())); err != nil {
|
||||
return 0, fmt.Errorf("failed to store SMTP port in vault: %w", err)
|
||||
}
|
||||
|
||||
return getPort(smtpListener.Addr()), nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
bridge.publish(events.SMTPServerError{
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.publish(events.SMTPServerReady{
|
||||
Port: port,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) restartSMTP() error {
|
||||
logrus.Info("Restarting SMTP server")
|
||||
|
||||
if err := bridge.closeSMTP(); err != nil {
|
||||
return fmt.Errorf("failed to close SMTP: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.SMTPServerStopped{})
|
||||
|
||||
bridge.smtpServer = newSMTPServer(bridge, bridge.tlsConfig, bridge.logSMTP)
|
||||
|
||||
return bridge.serveSMTP()
|
||||
}
|
||||
|
||||
// We close the listener ourselves even though it's also closed by smtpServer.Close().
|
||||
// This is because smtpServer.Serve() is called in a separate goroutine and might be executed
|
||||
// after we've already closed the server. However, go-smtp has a bug; it blocks on the listener
|
||||
// even after the server has been closed. So we close the listener ourselves to unblock it.
|
||||
func (bridge *Bridge) closeSMTP() error {
|
||||
logrus.Info("Closing SMTP server")
|
||||
|
||||
if bridge.smtpListener != nil {
|
||||
if err := bridge.smtpListener.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close SMTP listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.smtpServer.Close(); err != nil {
|
||||
logrus.WithError(err).Debug("Failed to close SMTP server (expected -- we close the listener ourselves)")
|
||||
}
|
||||
|
||||
bridge.publish(events.SMTPServerStopped{})
|
||||
|
||||
return nil
|
||||
func (bridge *Bridge) restartSMTP(ctx context.Context) error {
|
||||
return bridge.serverManager.RestartSMTP(ctx)
|
||||
}
|
||||
|
||||
func newSMTPServer(bridge *Bridge, tlsConfig *tls.Config, logSMTP bool) *smtp.Server {
|
||||
|
||||
@ -58,6 +58,11 @@ func (s *smtpSession) AuthPlain(username, password string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"username": username,
|
||||
"pkg": "smtp",
|
||||
}).Error("Incorrect login credentials.")
|
||||
|
||||
return fmt.Errorf("invalid username or password")
|
||||
}, s.usersLock)
|
||||
}
|
||||
@ -72,7 +77,7 @@ func (s *smtpSession) Logout() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *smtpSession) Mail(from string, opts *smtp.MailOptions) error {
|
||||
func (s *smtpSession) Mail(from string, _ *smtp.MailOptions) error {
|
||||
s.from = from
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -80,7 +80,7 @@ func TestBridge_Sync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
@ -112,15 +112,6 @@ func TestBridge_Sync(t *testing.T) {
|
||||
info, err := b.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Less(t, status.Messages, uint32(numMsg))
|
||||
}
|
||||
|
||||
// Remove the network limit, allowing the sync to finish.
|
||||
@ -136,7 +127,7 @@ func TestBridge_Sync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
@ -187,7 +178,7 @@ func _TestBridge_Sync_BadMessage(t *testing.T) { //nolint:unused,deadcode
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
@ -273,15 +264,6 @@ func TestBridge_SyncWithOngoingEvents(t *testing.T) {
|
||||
info, err := b.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
|
||||
status, err := client.Select(`Folders/folder`, false)
|
||||
require.NoError(t, err)
|
||||
require.Less(t, status.Messages, uint32(numMsg))
|
||||
}
|
||||
|
||||
// Create a new mailbox and move that last 1/3 of the messages into it to simulate user
|
||||
@ -311,7 +293,7 @@ func TestBridge_SyncWithOngoingEvents(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, info.State == bridge.Connected)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
client, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
|
||||
82
internal/bridge/sync_unix_test.go
Normal file
82
internal/bridge/sync_unix_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build !windows
|
||||
|
||||
package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Disabled due to flakyness.
|
||||
func _TestBridge_SyncExistsWithErrorWhenTooManyFilesAreOpen(t *testing.T) { //nolint:unused
|
||||
var rlimitCurrent syscall.Rlimit
|
||||
|
||||
require.NoError(t, syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimitCurrent))
|
||||
|
||||
// Restore RLimit for Process at the end of this test
|
||||
defer func() {
|
||||
require.NoError(t, syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimitCurrent))
|
||||
}()
|
||||
|
||||
rlimit := syscall.Rlimit{
|
||||
Max: 100,
|
||||
Cur: 100,
|
||||
}
|
||||
|
||||
require.NoError(t, syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit))
|
||||
|
||||
numMsg := 1 << 8
|
||||
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
userID, addrID, err := s.CreateUser("imap", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
|
||||
require.NoError(t, err)
|
||||
|
||||
withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) {
|
||||
createNumMessages(ctx, t, c, addrID, labelID, numMsg)
|
||||
})
|
||||
|
||||
// The initial user should be fully synced.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
syncCh, done := bridge.GetEvents(events.SyncFailed{})
|
||||
defer done()
|
||||
|
||||
userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
evt := <-syncCh
|
||||
switch e := evt.(type) {
|
||||
case events.SyncFailed:
|
||||
require.Equal(t, userID, e.UserID)
|
||||
default:
|
||||
require.Fail(t, "Expected events.SyncFailed{}")
|
||||
}
|
||||
})
|
||||
}, server.WithTLS(false))
|
||||
}
|
||||
@ -584,29 +584,7 @@ func (bridge *Bridge) newVaultUser(
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) (*vault.User, bool, error) {
|
||||
if !bridge.vault.HasUser(apiUser.ID) {
|
||||
user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to add user to vault: %w", err)
|
||||
}
|
||||
|
||||
return user, true, nil
|
||||
}
|
||||
|
||||
user, err := bridge.vault.NewUser(apiUser.ID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if err := user.SetAuth(authUID, authRef); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if err := user.SetKeyPass(saltedKeyPass); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return user, false, nil
|
||||
return bridge.vault.GetOrAddUser(apiUser.ID, apiUser.Name, apiUser.Email, authUID, authRef, saltedKeyPass)
|
||||
}
|
||||
|
||||
// logout logs out the given user, optionally logging them out from the API too.
|
||||
|
||||
@ -141,6 +141,9 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex
|
||||
})
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
userLoginAndSync(ctx, t, bridge, "user", password)
|
||||
|
||||
var messageIDs []string
|
||||
@ -176,6 +179,8 @@ func test_badMessage_badEvent(userFeedback func(t *testing.T, ctx context.Contex
|
||||
|
||||
userFeedback(t, ctx, bridge, badUserID)
|
||||
|
||||
smtpWaiter.Wait()
|
||||
|
||||
userContinueEventProcess(ctx, t, s, bridge)
|
||||
})
|
||||
})
|
||||
@ -194,6 +199,9 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) {
|
||||
})
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
userLoginAndSync(ctx, t, bridge, "user", password)
|
||||
|
||||
var messageIDs []string
|
||||
@ -217,6 +225,7 @@ func TestBridge_User_BadMessage_NoBadEvent(t *testing.T) {
|
||||
require.NoError(t, c.DeleteMessage(ctx, messageIDs...))
|
||||
})
|
||||
|
||||
smtpWaiter.Wait()
|
||||
userContinueEventProcess(ctx, t, s, bridge)
|
||||
})
|
||||
})
|
||||
@ -412,6 +421,17 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) {
|
||||
})
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
var count int32
|
||||
// The first 10 times bridge attempts to sync any of the messages, drop the connection.
|
||||
s.AddStatusHook(func(req *http.Request) (int, bool) {
|
||||
if strings.Contains(req.URL.Path, "/mail/v4/messages") {
|
||||
if atomic.AddInt32(&count, 1) < 10 {
|
||||
dropListener.DropAll()
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
})
|
||||
userLoginAndSync(ctx, t, bridge, "user", password)
|
||||
|
||||
mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
@ -421,30 +441,17 @@ func TestBridge_User_DropConn_NoBadEvent(t *testing.T) {
|
||||
createNumMessages(ctx, t, c, addrID, proton.InboxLabel, 10)
|
||||
})
|
||||
|
||||
var count int
|
||||
|
||||
// The first 10 times bridge attempts to sync any of the messages, drop the connection.
|
||||
s.AddStatusHook(func(req *http.Request) (int, bool) {
|
||||
if strings.Contains(req.URL.Path, "/mail/v4/messages") {
|
||||
if count++; count < 10 {
|
||||
dropListener.DropAll()
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
})
|
||||
|
||||
info, err := bridge.QueryUserInfo("user")
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = cli.Logout() }()
|
||||
|
||||
// The IMAP client will eventually see 20 messages.
|
||||
require.Eventually(t, func() bool {
|
||||
status, err := client.Status("INBOX", []imap.StatusItem{imap.StatusMessages})
|
||||
status, err := cli.Status("INBOX", []imap.StatusItem{imap.StatusMessages})
|
||||
return err == nil && status.Messages == 20
|
||||
}, 10*time.Second, 100*time.Millisecond)
|
||||
})
|
||||
@ -638,12 +645,12 @@ func TestBridge_User_SendDraftRemoveDraftFlag(t *testing.T) {
|
||||
info, err := bridge.QueryUserInfo("user")
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = cli.Logout() }()
|
||||
|
||||
messages, err := clientFetch(client, "Drafts")
|
||||
messages, err := clientFetch(cli, "Drafts")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
require.Contains(t, messages[0].Flags, imap.DraftFlag)
|
||||
@ -677,12 +684,12 @@ func TestBridge_User_SendDraftRemoveDraftFlag(t *testing.T) {
|
||||
info, err := bridge.QueryUserInfo("user")
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = cli.Logout() }()
|
||||
|
||||
messages, err := clientFetch(client, "Sent")
|
||||
messages, err := clientFetch(cli, "Sent")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
require.NotContains(t, messages[0].Flags, imap.DraftFlag)
|
||||
@ -771,15 +778,24 @@ func TestBridge_User_CreateDisabledAddress(t *testing.T) {
|
||||
func TestBridge_User_HandleParentLabelRename(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
imapWaiter := waitForIMAPServerReady(bridge)
|
||||
defer imapWaiter.Done()
|
||||
|
||||
smtpWaiter := waitForSMTPServerReady(bridge)
|
||||
defer smtpWaiter.Done()
|
||||
|
||||
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
||||
|
||||
info, err := bridge.QueryUserInfo(username)
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
imapWaiter.Wait()
|
||||
smtpWaiter.Wait()
|
||||
|
||||
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = cli.Logout() }()
|
||||
|
||||
withClient(ctx, t, s, username, password, func(ctx context.Context, c *proton.Client) {
|
||||
parentName := uuid.NewString()
|
||||
@ -795,7 +811,7 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
|
||||
|
||||
// Wait for the parent folder to be created.
|
||||
require.Eventually(t, func() bool {
|
||||
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
|
||||
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
|
||||
return mailbox.Name == fmt.Sprintf("Folders/%v", parentName)
|
||||
}) >= 0
|
||||
}, 100*user.EventPeriod, user.EventPeriod)
|
||||
@ -812,7 +828,7 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
|
||||
|
||||
// Wait for the parent folder to be created.
|
||||
require.Eventually(t, func() bool {
|
||||
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
|
||||
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
|
||||
return mailbox.Name == fmt.Sprintf("Folders/%v/%v", parentName, childName)
|
||||
}) >= 0
|
||||
}, 100*user.EventPeriod, user.EventPeriod)
|
||||
@ -827,14 +843,14 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
|
||||
|
||||
// Wait for the parent folder to be renamed.
|
||||
require.Eventually(t, func() bool {
|
||||
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
|
||||
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
|
||||
return mailbox.Name == fmt.Sprintf("Folders/%v", newParentName)
|
||||
}) >= 0
|
||||
}, 100*user.EventPeriod, user.EventPeriod)
|
||||
|
||||
// Wait for the child folder to be renamed.
|
||||
require.Eventually(t, func() bool {
|
||||
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
|
||||
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
|
||||
return mailbox.Name == fmt.Sprintf("Folders/%v/%v", newParentName, childName)
|
||||
}) >= 0
|
||||
}, 100*user.EventPeriod, user.EventPeriod)
|
||||
@ -843,48 +859,6 @@ func TestBridge_User_HandleParentLabelRename(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TBD: GODT-2527.
|
||||
func _TestBridge503DuringEventDoesNotCauseBadEvent(t *testing.T) { //nolint:unused,deadcode
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
// Create a user.
|
||||
userID, addrID, err := s.CreateUser("user", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create 10 messages for the user.
|
||||
withClient(ctx, t, s, "user", password, func(ctx context.Context, c *proton.Client) {
|
||||
createNumMessages(ctx, t, c, addrID, labelID, 10)
|
||||
})
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
userLoginAndSync(ctx, t, bridge, "user", password)
|
||||
|
||||
var messageIDs []string
|
||||
|
||||
// Create 10 more messages for the user, generating events.
|
||||
withClient(ctx, t, s, "user", password, func(ctx context.Context, c *proton.Client) {
|
||||
messageIDs = createNumMessages(ctx, t, c, addrID, labelID, 10)
|
||||
})
|
||||
|
||||
mocks.Reporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Any()).MinTimes(1)
|
||||
|
||||
s.AddStatusHook(func(req *http.Request) (int, bool) {
|
||||
if xslices.Index(xslices.Map(messageIDs[0:5], func(messageID string) string {
|
||||
return "/mail/v4/messages/" + messageID
|
||||
}), req.URL.Path) < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return http.StatusServiceUnavailable, true
|
||||
})
|
||||
|
||||
userContinueEventProcess(ctx, t, s, bridge)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// userLoginAndSync logs in user and waits until user is fully synced.
|
||||
func userLoginAndSync(
|
||||
ctx context.Context,
|
||||
@ -928,10 +902,10 @@ func userContinueEventProcess(
|
||||
info, err := bridge.QueryUserInfo("user")
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
cli, err := eventuallyDial(fmt.Sprintf("%v:%v", constants.Host, bridge.GetIMAPPort()))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = client.Logout() }()
|
||||
require.NoError(t, cli.Login(info.Addresses[0], string(info.BridgePass)))
|
||||
defer func() { _ = cli.Logout() }()
|
||||
|
||||
randomLabel := uuid.NewString()
|
||||
|
||||
@ -946,8 +920,21 @@ func userContinueEventProcess(
|
||||
|
||||
// Wait for the label to be created.
|
||||
require.Eventually(t, func() bool {
|
||||
return xslices.IndexFunc(clientList(client), func(mailbox *imap.MailboxInfo) bool {
|
||||
return xslices.IndexFunc(clientList(cli), func(mailbox *imap.MailboxInfo) bool {
|
||||
return mailbox.Name == "Labels/"+randomLabel
|
||||
}) >= 0
|
||||
}, 100*user.EventPeriod, user.EventPeriod)
|
||||
}
|
||||
|
||||
func eventuallyDial(addr string) (cli *client.Client, err error) {
|
||||
var sleep = 1 * time.Second
|
||||
for i := 0; i < 5; i++ {
|
||||
cli, err := client.Dial(addr)
|
||||
if err == nil {
|
||||
return cli, nil
|
||||
}
|
||||
time.Sleep(sleep)
|
||||
sleep *= 2
|
||||
}
|
||||
return nil, fmt.Errorf("after 5 attempts, last error: %s", err)
|
||||
}
|
||||
|
||||
@ -75,11 +75,7 @@ func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.U
|
||||
return nil
|
||||
}
|
||||
|
||||
if bridge.imapServer == nil {
|
||||
return fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
|
||||
gluonID, err := bridge.serverManager.AddGluonUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
||||
}
|
||||
@ -96,7 +92,7 @@ func (bridge *Bridge) handleUserAddressEnabled(ctx context.Context, user *user.U
|
||||
return nil
|
||||
}
|
||||
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
|
||||
gluonID, err := bridge.serverManager.AddGluonUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
||||
}
|
||||
@ -118,7 +114,7 @@ func (bridge *Bridge) handleUserAddressDisabled(ctx context.Context, user *user.
|
||||
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
|
||||
}
|
||||
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
if err := bridge.serverManager.RemoveGluonUser(ctx, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
||||
}
|
||||
|
||||
@ -134,16 +130,12 @@ func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.U
|
||||
return nil
|
||||
}
|
||||
|
||||
if bridge.imapServer == nil {
|
||||
return fmt.Errorf("no imap server instance running")
|
||||
}
|
||||
|
||||
gluonID, ok := user.GetGluonID(event.AddressID)
|
||||
if !ok {
|
||||
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
|
||||
}
|
||||
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
if err := bridge.serverManager.handleRemoveGluonUser(ctx, gluonID); err != nil {
|
||||
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@ -708,7 +708,26 @@ func TestBridge_User_Refresh(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_User_GetAddresses(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
// Create a user.
|
||||
userID, _, err := s.CreateUser("user", password)
|
||||
require.NoError(t, err)
|
||||
addrID2, err := s.CreateAddress(userID, "user@external.com", []byte("password"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.ChangeAddressType(userID, addrID2, proton.AddressTypeExternal))
|
||||
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
|
||||
userLoginAndSync(ctx, t, bridge, "user", password)
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(info.Addresses))
|
||||
require.Equal(t, info.Addresses[0], "user@proton.local")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// getErr returns the error that was passed to it.
|
||||
func getErr[T any](val T, err error) error {
|
||||
func getErr[T any](_ T, err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -50,7 +50,7 @@ int installTrustedCert(char const *bytes, unsigned long long length) {
|
||||
(id)kSecTrustSettingsResult: [NSNumber numberWithInt:kSecTrustSettingsResultTrustRoot],
|
||||
(id)kSecTrustSettingsPolicy: (__bridge id) policy,
|
||||
};
|
||||
status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainAdmin, (__bridge CFTypeRef)(trustSettings));
|
||||
status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainUser, (__bridge CFTypeRef)(trustSettings));
|
||||
CFRelease(policy);
|
||||
CFRelease(cert);
|
||||
|
||||
@ -72,7 +72,7 @@ int removeTrustedCert(char const *bytes, unsigned long long length) {
|
||||
(id)kSecTrustSettingsResult: [NSNumber numberWithInt:kSecTrustSettingsResultUnspecified],
|
||||
(id)kSecTrustSettingsPolicy: (__bridge id) policy,
|
||||
};
|
||||
OSStatus status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainAdmin, (__bridge CFTypeRef)(trustSettings));
|
||||
OSStatus status = SecTrustSettingsSetTrustSettings(cert, kSecTrustSettingsDomainUser, (__bridge CFTypeRef)(trustSettings));
|
||||
CFRelease(policy);
|
||||
if (errSecSuccess != status) {
|
||||
CFRelease(cert);
|
||||
@ -107,7 +107,6 @@ const (
|
||||
|
||||
// certPEMToDER converts a certificate in PEM format to DER format, which is the format required by Apple's Security framework.
|
||||
func certPEMToDER(certPEM []byte) ([]byte, error) {
|
||||
|
||||
block, left := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
return []byte{}, errors.New("invalid PEM certificate")
|
||||
@ -127,7 +126,7 @@ func installCert(certPEM []byte) error {
|
||||
}
|
||||
|
||||
p := C.CBytes(certDER)
|
||||
defer C.free(unsafe.Pointer(p))
|
||||
defer C.free(unsafe.Pointer(p)) //nolint:unconvert
|
||||
|
||||
errCode := C.installTrustedCert((*C.char)(p), (C.ulonglong)(len(certDER)))
|
||||
switch errCode {
|
||||
@ -147,7 +146,7 @@ func uninstallCert(certPEM []byte) error {
|
||||
}
|
||||
|
||||
p := C.CBytes(certDER)
|
||||
defer C.free(unsafe.Pointer(p))
|
||||
defer C.free(unsafe.Pointer(p)) //nolint:unconvert
|
||||
|
||||
if errCode := C.removeTrustedCert((*C.char)(p), (C.ulonglong)(len(certDER))); errCode != 0 {
|
||||
return fmt.Errorf("could not install certificate from keychain (error %v)", errCode)
|
||||
|
||||
@ -26,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
// This test implies human interactions to enter password and is disabled by default.
|
||||
func _TestTrustedCertsDarwin(t *testing.T) {
|
||||
func _TestTrustedCertsDarwin(t *testing.T) { //nolint:unused
|
||||
template, err := NewTLSTemplate()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ if(NOT UNIX)
|
||||
set(CMAKE_INSTALL_BINDIR ".")
|
||||
endif(NOT UNIX)
|
||||
|
||||
find_package(Qt6 COMPONENTS Core Quick Qml QuickControls2 Widgets REQUIRED)
|
||||
find_package(Qt6 COMPONENTS Core Quick Qml QuickControls2 Widgets Svg REQUIRED)
|
||||
qt_standard_project_setup()
|
||||
set(CMAKE_AUTORCC ON)
|
||||
message(STATUS "Using Qt ${Qt6_VERSION}")
|
||||
@ -147,6 +147,7 @@ target_link_libraries(bridge-gui
|
||||
Qt6::Quick
|
||||
Qt6::Qml
|
||||
Qt6::QuickControls2
|
||||
Qt6::Svg
|
||||
sentry::sentry
|
||||
bridgepp
|
||||
)
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include <QtQml>
|
||||
#include <QtWidgets>
|
||||
#include <QtQuickControls2>
|
||||
#include <QtSvg>
|
||||
#include <AppController.h>
|
||||
|
||||
|
||||
|
||||
@ -994,15 +994,44 @@ void QMLBackend::onUserBadEvent(QString const &userID, QString const &) {
|
||||
void QMLBackend::onIMAPLoginFailed(QString const &username) {
|
||||
HANDLE_EXCEPTION(
|
||||
SPUser const user = users_->getUserWithUsernameOrEmail(username);
|
||||
if ((!user) || (user->state() != UserState::SignedOut)) { // We want to pop-up only if a signed-out user has been detected
|
||||
if (!user) {
|
||||
return;
|
||||
}
|
||||
if (user->isInIMAPLoginFailureCooldown()) {
|
||||
return;
|
||||
|
||||
qint64 const cooldownDurationMs = 10 * 60 * 1000; // 10 minutes cooldown period for notifications
|
||||
switch (user->state()) {
|
||||
case UserState::SignedOut:
|
||||
if (user->isNotificationInCooldown(User::ENotification::IMAPLoginWhileSignedOut)) {
|
||||
return;
|
||||
}
|
||||
user->startNotificationCooldownPeriod(User::ENotification::IMAPLoginWhileSignedOut, cooldownDurationMs);
|
||||
emit selectUser(user->id(), true);
|
||||
emit imapLoginWhileSignedOut(username);
|
||||
break;
|
||||
|
||||
case UserState::Connected:
|
||||
if (user->isNotificationInCooldown(User::ENotification::IMAPPasswordFailure)) {
|
||||
return;
|
||||
}
|
||||
user->startNotificationCooldownPeriod(User::ENotification::IMAPPasswordFailure, cooldownDurationMs);
|
||||
emit selectUser(user->id(), false);
|
||||
trayIcon_->showErrorPopupNotification(tr("Incorrect password"),
|
||||
tr("Your email client can't connect to Proton Bridge. Make sure you are using the local Bridge password shown in Bridge."));
|
||||
break;
|
||||
|
||||
case UserState::Locked:
|
||||
if (user->isNotificationInCooldown(User::ENotification::IMAPLoginWhileLocked)) {
|
||||
return;
|
||||
}
|
||||
user->startNotificationCooldownPeriod(User::ENotification::IMAPLoginWhileLocked, cooldownDurationMs);
|
||||
emit selectUser(user->id(), false);
|
||||
trayIcon_->showErrorPopupNotification(tr("Connection in progress"),
|
||||
tr("Your Proton account in Bridge is being connected. Please wait or restart Bridge."));
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
user->startImapLoginFailureCooldown(60 * 60 * 1000); // 1 hour cooldown during which we will not display this notification to this user again.
|
||||
emit selectUser(user->id());
|
||||
emit imapLoginWhileSignedOut(username);
|
||||
)
|
||||
}
|
||||
|
||||
@ -1134,7 +1163,7 @@ void QMLBackend::displayBadEventDialog(QString const &userID) {
|
||||
emit userBadEvent(userID,
|
||||
tr("Bridge ran into an internal error and it is not able to proceed with the account %1. Synchronize your local database now or logout"
|
||||
" to do it later. Synchronization time depends on the size of your mailbox.").arg(elideLongString(user->primaryEmailOrUsername(), 30)));
|
||||
emit selectUser(userID);
|
||||
emit selectUser(userID, true);
|
||||
emit showMainWindow();
|
||||
)
|
||||
}
|
||||
|
||||
@ -180,6 +180,8 @@ public slots: // slot for signals received from QML -> To be forwarded to Bridge
|
||||
void onVersionChanged(); ///< Slot for the version change signal.
|
||||
void setMailServerSettings(int imapPort, int smtpPort, bool useSSLForIMAP, bool useSSLForSMTP) const; ///< Forwards a connection mode change request from QML to gRPC
|
||||
void sendBadEventUserFeedback(QString const &userID, bool doResync); ///< Slot the providing user feedback for a bad event.
|
||||
|
||||
public slots: // slots for functions that need to be processed locally.
|
||||
void setNormalTrayIcon(); ///< Set the tray icon to normal.
|
||||
void setErrorTrayIcon(QString const& stateString, QString const &statusIcon); ///< Set the tray icon to 'error' state.
|
||||
void setWarnTrayIcon(QString const& stateString, QString const &statusIcon); ///< Set the tray icon to 'warn' state.
|
||||
@ -245,7 +247,7 @@ signals: // Signals received from the Go backend, to be forwarded to QML
|
||||
void hideMainWindow(); ///< Signal for the 'hideMainWindow' gRPC stream event.
|
||||
void showHelp(); ///< Signal for the 'showHelp' event (from the context menu).
|
||||
void showSettings(); ///< Signal for the 'showHelp' event (from the context menu).
|
||||
void selectUser(QString const& userID); ///< Signal emitted in order to selected a user with a given ID in the list.
|
||||
void selectUser(QString const& userID, bool forceShowWindow); ///< Signal emitted in order to selected a user with a given ID in the list.
|
||||
void genericError(QString const &title, QString const &description); ///< Signal for the 'genericError' gRPC stream event.
|
||||
void imapLoginWhileSignedOut(QString const& username); ///< Signal for the notification of IMAP login attempt on a signed out account.
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ QString sentryAttachmentFilePath() {
|
||||
//****************************************************************************************************************************************************
|
||||
QByteArray getProtectedHostname() {
|
||||
QByteArray hostname = QCryptographicHash::hash(QSysInfo::machineHostName().toUtf8(), QCryptographicHash::Sha256);
|
||||
return hostname.toHex();
|
||||
return hostname.toBase64();
|
||||
}
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include <sentry.h>
|
||||
|
||||
void initSentry();
|
||||
QByteArray getProtectedHostname();
|
||||
void setSentryReportScope();
|
||||
sentry_options_t* newSentryOptions(const char * sentryDNS, const char * cacheDir);
|
||||
sentry_uuid_t reportSentryEvent(sentry_level_t level, const char *message);
|
||||
|
||||
@ -43,12 +43,75 @@ qint64 const iconRefreshDurationSecs = 10; ///< The total number of seconds duri
|
||||
QIcon loadIconFromImage(QString const &path) {
|
||||
QPixmap const pixmap(path);
|
||||
if (pixmap.isNull()) {
|
||||
throw Exception(QString("Could create icon from image '%1'.").arg(path));
|
||||
throw Exception(QString("Could not create an icon from an image '%1'.").arg(path));
|
||||
}
|
||||
return QIcon(pixmap);
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// \brief Generate an icon from a SVG renderer (a.k.a. path).
|
||||
///
|
||||
/// \param[in] renderer The SVG renderer.
|
||||
/// \param[in] color The color to use in case the SVG path is to be used as a mask.
|
||||
/// \return The icon.
|
||||
//****************************************************************************************************************************************************
|
||||
QIcon loadIconFromSVGRenderer(QSvgRenderer &renderer, QColor const &color = QColor()) {
|
||||
if (!renderer.isValid()) {
|
||||
return QIcon();
|
||||
}
|
||||
QIcon icon;
|
||||
qint32 size = 256;
|
||||
|
||||
while (size >= 16) {
|
||||
QPixmap pixmap(size, size);
|
||||
pixmap.fill(QColor(0, 0, 0, 0));
|
||||
QPainter painter(&pixmap);
|
||||
renderer.render(&painter);
|
||||
if (color.isValid()) {
|
||||
painter.setCompositionMode(QPainter::CompositionMode_SourceIn);
|
||||
painter.fillRect(pixmap.rect(), color);
|
||||
}
|
||||
painter.end();
|
||||
icon.addPixmap(pixmap);
|
||||
size /= 2;
|
||||
}
|
||||
|
||||
return icon;
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// \brief Load a multi-resolution icon from a SVG file. The image is assumed to be square. SVG is rasterized in 256, 128, 64, 32 and 16px.
|
||||
///
|
||||
/// Note: QPixmap can load SVG files directly, but our SVG file are defined in small shape size and QPixmap will rasterize them a very low resolution
|
||||
/// by default (eg. 16x16), which is insufficient for some uses. As a consequence, we manually generate a multi-resolution icon that render smoothly
|
||||
/// at any acceptable resolution for an icon.
|
||||
///
|
||||
/// \param[in] path The path of the SVG file.
|
||||
/// \return The icon.
|
||||
//****************************************************************************************************************************************************
|
||||
QIcon loadIconFromSVG(QString const &path, QColor const &color = QColor()) {
|
||||
QSvgRenderer renderer(path);
|
||||
QIcon const icon = loadIconFromSVGRenderer(renderer, color);
|
||||
if (icon.isNull()) {
|
||||
Exception(QString("Could not create an icon from a vector image '%1'.").arg(path));
|
||||
}
|
||||
return icon;
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
//
|
||||
//****************************************************************************************************************************************************
|
||||
QIcon loadIcon(QString const &path) {
|
||||
if (path.endsWith(".svg", Qt::CaseInsensitive)) {
|
||||
return loadIconFromSVG(path);
|
||||
}
|
||||
return loadIconFromImage(path);
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// \brief Retrieve the color associated with a tray icon state.
|
||||
///
|
||||
@ -95,6 +158,18 @@ QString stateText(TrayIcon::State state) {
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// \brief converts a QML resource path to Qt resource path.
|
||||
/// QML resource paths are a bit different from qt resource paths
|
||||
/// \param[in] path The resource path.
|
||||
/// \return
|
||||
//****************************************************************************************************************************************************
|
||||
QString qmlResourcePathToQt(QString const &path) {
|
||||
QString result = path;
|
||||
result.replace(QRegularExpression(R"(^\.\/)"), ":/qml/");
|
||||
return result;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
@ -103,17 +178,17 @@ QString stateText(TrayIcon::State state) {
|
||||
//****************************************************************************************************************************************************
|
||||
TrayIcon::TrayIcon()
|
||||
: QSystemTrayIcon()
|
||||
, menu_(new QMenu) {
|
||||
|
||||
, menu_(new QMenu)
|
||||
, notificationErrorIcon_(loadIconFromSVG(":/qml/icons/ic-alert.svg")) {
|
||||
this->generateDotIcons();
|
||||
this->setContextMenu(menu_.get());
|
||||
|
||||
connect(menu_.get(), &QMenu::aboutToShow, this, &TrayIcon::onMenuAboutToShow);
|
||||
connect(this, &TrayIcon::selectUser, &app().backend(), &QMLBackend::selectUser);
|
||||
connect(this, &TrayIcon::activated, this, &TrayIcon::onActivated);
|
||||
|
||||
// some OSes/Desktop managers will automatically show main window when clicked, but not all, so we do it manually.
|
||||
connect(this, &TrayIcon::messageClicked, &app().backend(), &QMLBackend::showMainWindow);
|
||||
this->show();
|
||||
this->setState(State::Normal, QString(), QString());
|
||||
|
||||
// TrayIcon does not expose its screen, so we connect relevant screen events to our DPI change handler.
|
||||
for (QScreen *screen: QGuiApplication::screens()) {
|
||||
@ -151,7 +226,7 @@ void TrayIcon::onUserClicked() {
|
||||
throw Exception("Could not retrieve context menu's selected user.");
|
||||
}
|
||||
|
||||
emit selectUser(userID);
|
||||
emit selectUser(userID, true);
|
||||
} catch (Exception const &e) {
|
||||
app().log().error(e.qwhat());
|
||||
}
|
||||
@ -212,18 +287,17 @@ void TrayIcon::onIconRefreshTimer() {
|
||||
//
|
||||
//****************************************************************************************************************************************************
|
||||
void TrayIcon::generateDotIcons() {
|
||||
QPixmap dotSVG(":/qml/icons/ic-dot.svg");
|
||||
QSvgRenderer dotSVG(QString(":/qml/icons/ic-dot.svg"));
|
||||
|
||||
struct IconColor {
|
||||
QIcon &icon;
|
||||
QColor color;
|
||||
};
|
||||
for (auto pair: QList<IconColor> {{ greenDot_, normalColor }, { greyDot_, greyColor }, { orangeDot_, warnColor }}) {
|
||||
QPixmap p = dotSVG;
|
||||
QPainter painter(&p);
|
||||
painter.setCompositionMode(QPainter::CompositionMode_SourceIn);
|
||||
painter.fillRect(p.rect(), pair.color);
|
||||
painter.end();
|
||||
pair.icon = QIcon(p);
|
||||
pair.icon = loadIconFromSVGRenderer(dotSVG, pair.color);
|
||||
if (pair.icon.isNull()) {
|
||||
throw Exception("Could not generate dot icon from vector file.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -242,26 +316,28 @@ void TrayIcon::setState(TrayIcon::State state, QString const &stateString, QStri
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// \param[in] title The title.
|
||||
/// \param[in] message The message.
|
||||
//****************************************************************************************************************************************************
|
||||
void TrayIcon::showErrorPopupNotification(QString const &title, QString const &message) {
|
||||
this->showMessage(title, message, notificationErrorIcon_);
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// \param[in] svgPath The path of the SVG file for the icon.
|
||||
/// \param[in] color The color to apply to the icon.
|
||||
//****************************************************************************************************************************************************
|
||||
void TrayIcon::generateStatusIcon(QString const &svgPath, QColor const &color) {
|
||||
// We use the SVG path as pixmap mask and fill it with the appropriate color
|
||||
QString resourcePath = svgPath;
|
||||
resourcePath.replace(QRegularExpression(R"(^\.\/)"), ":/qml/"); // QML resource path are a bit different from the Qt resources path.
|
||||
QPixmap pixmap(resourcePath);
|
||||
QPainter painter(&pixmap);
|
||||
painter.setCompositionMode(QPainter::CompositionMode_SourceIn);
|
||||
painter.fillRect(pixmap.rect(), color);
|
||||
painter.end();
|
||||
statusIcon_ = QIcon(pixmap);
|
||||
statusIcon_ = loadIconFromSVG(qmlResourcePathToQt(svgPath), color);
|
||||
}
|
||||
|
||||
|
||||
//**********************************************************************************************************************
|
||||
//****************************************************************************************************************************************************
|
||||
//
|
||||
//**********************************************************************************************************************
|
||||
//****************************************************************************************************************************************************
|
||||
void TrayIcon::refreshContextMenu() {
|
||||
if (!menu_) {
|
||||
app().log().error("Native tray icon context menu is null.");
|
||||
@ -297,3 +373,5 @@ void TrayIcon::refreshContextMenu() {
|
||||
menu_->addSeparator();
|
||||
menu_->addAction(tr("&Quit Bridge"), onMac ? QKeySequence("Ctrl+Q") : noShortcut, &app().backend(), &QMLBackend::quit);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -41,10 +41,10 @@ public: // data members
|
||||
TrayIcon& operator=(TrayIcon const&) = delete; ///< Disabled assignment operator.
|
||||
TrayIcon& operator=(TrayIcon&&) = delete; ///< Disabled move assignment operator.
|
||||
void setState(State state, QString const& stateString, QString const &statusIconPath); ///< Set the state of the icon
|
||||
void showNotificationPopup(QString const& title, QString const &message, QString const& iconPath); ///< Display a pop up notification.
|
||||
void showErrorPopupNotification(QString const& title, QString const &message); ///< Display a pop up notification.
|
||||
|
||||
signals:
|
||||
void selectUser(QString const& userID); ///< Signal for selecting a user with a given userID
|
||||
void selectUser(QString const& userID, bool forceShowWindow); ///< Signal for selecting a user with a given userID
|
||||
|
||||
private slots:
|
||||
void onMenuAboutToShow(); ///< Slot called before the context menu is shown.
|
||||
@ -67,6 +67,7 @@ private: // data members
|
||||
QIcon greenDot_; ///< The green dot icon.
|
||||
QIcon greyDot_; ///< The grey dot icon.
|
||||
QIcon orangeDot_; ///< The orange dot icon.
|
||||
QIcon const notificationErrorIcon_; ///< The error icon used for notifications.
|
||||
|
||||
QTimer iconRefreshTimer_; ///< The timer used to periodically refresh the icon when DPI changes.
|
||||
QDateTime iconRefreshDeadline_; ///< The deadline for refreshing the icon
|
||||
|
||||
@ -305,6 +305,8 @@ int main(int argc, char *argv[]) {
|
||||
// When not in attached mode, log entries are forwarded to bridge, which output it on stdout/stderr. bridge-gui's process monitor intercept
|
||||
// these outputs and output them on the command-line.
|
||||
log.setLevel(cliOptions.logLevel);
|
||||
log.info(QString("New Sentry reporter - id: %1.").arg(getProtectedHostname()));
|
||||
|
||||
QString bridgeexec;
|
||||
if (!cliOptions.attach) {
|
||||
if (isBridgeRunning()) {
|
||||
|
||||
@ -95,9 +95,11 @@ ApplicationWindow {
|
||||
root.showAndRise()
|
||||
}
|
||||
|
||||
function onSelectUser(userID) {
|
||||
function onSelectUser(userID, forceShowWindow) {
|
||||
contentWrapper.selectUser(userID)
|
||||
root.showAndRise()
|
||||
if (forceShowWindow) {
|
||||
root.showAndRise()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -535,11 +535,12 @@ QtObject {
|
||||
}
|
||||
|
||||
property Notification onlyPaidUsers: Notification {
|
||||
description: qsTr("Bridge is exclusive to our paid plans. Upgrade your account to use Bridge.")
|
||||
description: qsTr("Bridge is exclusive to our mail paid plans. Upgrade your account to use Bridge.")
|
||||
brief: qsTr("Upgrade your account")
|
||||
icon: "./icons/ic-exclamation-circle-filled.svg"
|
||||
type: Notification.NotificationType.Danger
|
||||
group: Notifications.Group.Configuration
|
||||
property var pricingLink: "https://proton.me/mail/pricing"
|
||||
|
||||
Connections {
|
||||
target: Backend
|
||||
@ -550,8 +551,9 @@ QtObject {
|
||||
|
||||
action: [
|
||||
Action {
|
||||
text: qsTr("OK")
|
||||
text: qsTr("Upgrade")
|
||||
onTriggered: {
|
||||
Qt.openUrlExternally(root.onlyPaidUsers.pricingLink)
|
||||
root.onlyPaidUsers.active = false
|
||||
}
|
||||
}
|
||||
|
||||
@ -344,7 +344,12 @@ FocusScope {
|
||||
if (str.length === 0) {
|
||||
return qsTr("Enter the 6-digit code")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
onTextChanged: {
|
||||
if (text.length >= 6) {
|
||||
twoFAButton.onClicked()
|
||||
}
|
||||
}
|
||||
|
||||
onAccepted: {
|
||||
|
||||
@ -6,19 +6,19 @@
|
||||
#include "focus.grpc.pb.h"
|
||||
|
||||
#include <functional>
|
||||
#include <grpcpp/impl/codegen/async_stream.h>
|
||||
#include <grpcpp/impl/codegen/async_unary_call.h>
|
||||
#include <grpcpp/impl/codegen/channel_interface.h>
|
||||
#include <grpcpp/impl/codegen/client_unary_call.h>
|
||||
#include <grpcpp/impl/codegen/client_callback.h>
|
||||
#include <grpcpp/impl/codegen/message_allocator.h>
|
||||
#include <grpcpp/impl/codegen/method_handler.h>
|
||||
#include <grpcpp/impl/codegen/rpc_service_method.h>
|
||||
#include <grpcpp/impl/codegen/server_callback.h>
|
||||
#include <grpcpp/support/async_stream.h>
|
||||
#include <grpcpp/support/async_unary_call.h>
|
||||
#include <grpcpp/impl/channel_interface.h>
|
||||
#include <grpcpp/impl/client_unary_call.h>
|
||||
#include <grpcpp/support/client_callback.h>
|
||||
#include <grpcpp/support/message_allocator.h>
|
||||
#include <grpcpp/support/method_handler.h>
|
||||
#include <grpcpp/impl/rpc_service_method.h>
|
||||
#include <grpcpp/support/server_callback.h>
|
||||
#include <grpcpp/impl/codegen/server_callback_handlers.h>
|
||||
#include <grpcpp/impl/codegen/server_context.h>
|
||||
#include <grpcpp/impl/codegen/service_type.h>
|
||||
#include <grpcpp/impl/codegen/sync_stream.h>
|
||||
#include <grpcpp/server_context.h>
|
||||
#include <grpcpp/impl/service_type.h>
|
||||
#include <grpcpp/support/sync_stream.h>
|
||||
namespace focus {
|
||||
|
||||
static const char* Focus_method_names[] = {
|
||||
|
||||
@ -25,23 +25,23 @@
|
||||
#include "focus.pb.h"
|
||||
|
||||
#include <functional>
|
||||
#include <grpcpp/impl/codegen/async_generic_service.h>
|
||||
#include <grpcpp/impl/codegen/async_stream.h>
|
||||
#include <grpcpp/impl/codegen/async_unary_call.h>
|
||||
#include <grpcpp/impl/codegen/client_callback.h>
|
||||
#include <grpcpp/impl/codegen/client_context.h>
|
||||
#include <grpcpp/impl/codegen/completion_queue.h>
|
||||
#include <grpcpp/impl/codegen/message_allocator.h>
|
||||
#include <grpcpp/impl/codegen/method_handler.h>
|
||||
#include <grpcpp/generic/async_generic_service.h>
|
||||
#include <grpcpp/support/async_stream.h>
|
||||
#include <grpcpp/support/async_unary_call.h>
|
||||
#include <grpcpp/support/client_callback.h>
|
||||
#include <grpcpp/client_context.h>
|
||||
#include <grpcpp/completion_queue.h>
|
||||
#include <grpcpp/support/message_allocator.h>
|
||||
#include <grpcpp/support/method_handler.h>
|
||||
#include <grpcpp/impl/codegen/proto_utils.h>
|
||||
#include <grpcpp/impl/codegen/rpc_method.h>
|
||||
#include <grpcpp/impl/codegen/server_callback.h>
|
||||
#include <grpcpp/impl/rpc_method.h>
|
||||
#include <grpcpp/support/server_callback.h>
|
||||
#include <grpcpp/impl/codegen/server_callback_handlers.h>
|
||||
#include <grpcpp/impl/codegen/server_context.h>
|
||||
#include <grpcpp/impl/codegen/service_type.h>
|
||||
#include <grpcpp/server_context.h>
|
||||
#include <grpcpp/impl/service_type.h>
|
||||
#include <grpcpp/impl/codegen/status.h>
|
||||
#include <grpcpp/impl/codegen/stub_options.h>
|
||||
#include <grpcpp/impl/codegen/sync_stream.h>
|
||||
#include <grpcpp/support/stub_options.h>
|
||||
#include <grpcpp/support/sync_stream.h>
|
||||
|
||||
namespace focus {
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
#error incompatible with your Protocol Buffer headers. Please update
|
||||
#error your headers.
|
||||
#endif
|
||||
#if 3021003 < PROTOBUF_MIN_PROTOC_VERSION
|
||||
#if 3021012 < PROTOBUF_MIN_PROTOC_VERSION
|
||||
#error This file was generated by an older version of protoc which is
|
||||
#error incompatible with your Protocol Buffer headers. Please
|
||||
#error regenerate this file with a newer version of protoc.
|
||||
|
||||
@ -6,19 +6,19 @@
|
||||
#include "bridge.grpc.pb.h"
|
||||
|
||||
#include <functional>
|
||||
#include <grpcpp/impl/codegen/async_stream.h>
|
||||
#include <grpcpp/impl/codegen/async_unary_call.h>
|
||||
#include <grpcpp/impl/codegen/channel_interface.h>
|
||||
#include <grpcpp/impl/codegen/client_unary_call.h>
|
||||
#include <grpcpp/impl/codegen/client_callback.h>
|
||||
#include <grpcpp/impl/codegen/message_allocator.h>
|
||||
#include <grpcpp/impl/codegen/method_handler.h>
|
||||
#include <grpcpp/impl/codegen/rpc_service_method.h>
|
||||
#include <grpcpp/impl/codegen/server_callback.h>
|
||||
#include <grpcpp/support/async_stream.h>
|
||||
#include <grpcpp/support/async_unary_call.h>
|
||||
#include <grpcpp/impl/channel_interface.h>
|
||||
#include <grpcpp/impl/client_unary_call.h>
|
||||
#include <grpcpp/support/client_callback.h>
|
||||
#include <grpcpp/support/message_allocator.h>
|
||||
#include <grpcpp/support/method_handler.h>
|
||||
#include <grpcpp/impl/rpc_service_method.h>
|
||||
#include <grpcpp/support/server_callback.h>
|
||||
#include <grpcpp/impl/codegen/server_callback_handlers.h>
|
||||
#include <grpcpp/impl/codegen/server_context.h>
|
||||
#include <grpcpp/impl/codegen/service_type.h>
|
||||
#include <grpcpp/impl/codegen/sync_stream.h>
|
||||
#include <grpcpp/server_context.h>
|
||||
#include <grpcpp/impl/service_type.h>
|
||||
#include <grpcpp/support/sync_stream.h>
|
||||
namespace grpc {
|
||||
|
||||
static const char* Bridge_method_names[] = {
|
||||
|
||||
@ -25,23 +25,23 @@
|
||||
#include "bridge.pb.h"
|
||||
|
||||
#include <functional>
|
||||
#include <grpcpp/impl/codegen/async_generic_service.h>
|
||||
#include <grpcpp/impl/codegen/async_stream.h>
|
||||
#include <grpcpp/impl/codegen/async_unary_call.h>
|
||||
#include <grpcpp/impl/codegen/client_callback.h>
|
||||
#include <grpcpp/impl/codegen/client_context.h>
|
||||
#include <grpcpp/impl/codegen/completion_queue.h>
|
||||
#include <grpcpp/impl/codegen/message_allocator.h>
|
||||
#include <grpcpp/impl/codegen/method_handler.h>
|
||||
#include <grpcpp/generic/async_generic_service.h>
|
||||
#include <grpcpp/support/async_stream.h>
|
||||
#include <grpcpp/support/async_unary_call.h>
|
||||
#include <grpcpp/support/client_callback.h>
|
||||
#include <grpcpp/client_context.h>
|
||||
#include <grpcpp/completion_queue.h>
|
||||
#include <grpcpp/support/message_allocator.h>
|
||||
#include <grpcpp/support/method_handler.h>
|
||||
#include <grpcpp/impl/codegen/proto_utils.h>
|
||||
#include <grpcpp/impl/codegen/rpc_method.h>
|
||||
#include <grpcpp/impl/codegen/server_callback.h>
|
||||
#include <grpcpp/impl/rpc_method.h>
|
||||
#include <grpcpp/support/server_callback.h>
|
||||
#include <grpcpp/impl/codegen/server_callback_handlers.h>
|
||||
#include <grpcpp/impl/codegen/server_context.h>
|
||||
#include <grpcpp/impl/codegen/service_type.h>
|
||||
#include <grpcpp/server_context.h>
|
||||
#include <grpcpp/impl/service_type.h>
|
||||
#include <grpcpp/impl/codegen/status.h>
|
||||
#include <grpcpp/impl/codegen/stub_options.h>
|
||||
#include <grpcpp/impl/codegen/sync_stream.h>
|
||||
#include <grpcpp/support/stub_options.h>
|
||||
#include <grpcpp/support/sync_stream.h>
|
||||
|
||||
namespace grpc {
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
#error incompatible with your Protocol Buffer headers. Please update
|
||||
#error your headers.
|
||||
#endif
|
||||
#if 3021003 < PROTOBUF_MIN_PROTOC_VERSION
|
||||
#if 3021012 < PROTOBUF_MIN_PROTOC_VERSION
|
||||
#error This file was generated by an older version of protoc which is
|
||||
#error incompatible with your Protocol Buffer headers. Please
|
||||
#error regenerate this file with a newer version of protoc.
|
||||
|
||||
@ -34,9 +34,7 @@ SPUser User::newUser(QObject *parent) {
|
||||
/// \param[in] parent The parent object.
|
||||
//****************************************************************************************************************************************************
|
||||
User::User(QObject *parent)
|
||||
: QObject(parent)
|
||||
, imapFailureCooldownEndTime_(QDateTime::currentDateTime()) {
|
||||
|
||||
: QObject(parent) {
|
||||
}
|
||||
|
||||
|
||||
@ -355,22 +353,18 @@ QString User::stateToString(UserState state) {
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// We display a notification and pop the application window if an IMAP client tries to connect to a signed out account, but we do not want to
|
||||
/// do it repeatedly, as it's an intrusive action. This function let's you define a period of time during which the notification should not be
|
||||
/// displayed.
|
||||
///
|
||||
/// \param durationMSecs The duration of the period in milliseconds.
|
||||
/// \param[in] durationMSecs The duration of the period in milliseconds.
|
||||
//****************************************************************************************************************************************************
|
||||
void User::startImapLoginFailureCooldown(qint64 durationMSecs) {
|
||||
imapFailureCooldownEndTime_ = QDateTime::currentDateTime().addMSecs(durationMSecs);
|
||||
void User::startNotificationCooldownPeriod(User::ENotification notification, qint64 durationMSecs) {
|
||||
notificationCooldownList_[notification] = QDateTime::currentDateTime().addMSecs(durationMSecs);
|
||||
}
|
||||
|
||||
|
||||
//****************************************************************************************************************************************************
|
||||
/// \return true if we currently are in a cooldown period for the notification
|
||||
/// \return true iff the notification is currently in a cooldown period.
|
||||
//****************************************************************************************************************************************************
|
||||
bool User::isInIMAPLoginFailureCooldown() const {
|
||||
return QDateTime::currentDateTime() < imapFailureCooldownEndTime_;
|
||||
bool User::isNotificationInCooldown(User::ENotification notification) const {
|
||||
return notificationCooldownList_.contains(notification) && (QDateTime::currentDateTime() < notificationCooldownList_[notification]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -62,6 +62,13 @@ typedef std::shared_ptr<class User> SPUser; ///< Type definition for shared poin
|
||||
class User : public QObject {
|
||||
|
||||
Q_OBJECT
|
||||
public: // data types
|
||||
enum class ENotification {
|
||||
IMAPLoginWhileSignedOut, ///< An IMAP client tried to login while the user is signed out.
|
||||
IMAPPasswordFailure, ///< An IMAP client provided an invalid password for the user.
|
||||
IMAPLoginWhileLocked, ///< An IMAP client tried to connect while the user is locked.
|
||||
};
|
||||
|
||||
public: // static member function
|
||||
static SPUser newUser(QObject *parent); ///< Create a new user
|
||||
static QString stateToString(UserState state); ///< Return a string describing a user state.
|
||||
@ -74,8 +81,8 @@ public: // member functions.
|
||||
User &operator=(User &&) = delete; ///< Disabled move assignment operator.
|
||||
void update(User const &user); ///< Update the user.
|
||||
Q_INVOKABLE QString primaryEmailOrUsername() const; ///< Return the user primary email, or, if unknown its username.
|
||||
void startImapLoginFailureCooldown(qint64 durationMSecs); ///< Start the user cooldown period for the IMAP login attempt while signed-out notification.
|
||||
bool isInIMAPLoginFailureCooldown() const; ///< Check if the user in a IMAP login failure notification.
|
||||
void startNotificationCooldownPeriod(ENotification notification, qint64 durationMSecs); ///< Start the user cooldown period for a notification.
|
||||
bool isNotificationInCooldown(ENotification notification) const; ///< Return true iff the notification is in a cooldown period.
|
||||
|
||||
public slots:
|
||||
// slots for QML generated calls
|
||||
@ -147,7 +154,7 @@ private: // member functions.
|
||||
User(QObject *parent); ///< Default constructor.
|
||||
|
||||
private: // data members.
|
||||
QDateTime imapFailureCooldownEndTime_; ///< The end date/time for the IMAP login failure notification cooldown period.
|
||||
QMap<ENotification, QDateTime> notificationCooldownList_; ///< A list of cooldown period end time for notifications.
|
||||
QString id_; ///< The userID.
|
||||
QString username_; ///< The username
|
||||
QString password_; ///< The IMAP password of the user.
|
||||
|
||||
@ -297,7 +297,7 @@ func (f *frontendCLI) configureAppleMail(c *ishell.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.bridge.ConfigureAppleMail(user.UserID, user.Addresses[0]); err != nil {
|
||||
if err := f.bridge.ConfigureAppleMail(context.Background(), user.UserID, user.Addresses[0]); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
@ -305,11 +305,11 @@ func (f *frontendCLI) configureAppleMail(c *ishell.Context) {
|
||||
f.Printf("Apple Mail configured for %v with address %v\n", user.Username, user.Addresses[0])
|
||||
}
|
||||
|
||||
func (f *frontendCLI) badEventSynchronize(c *ishell.Context) {
|
||||
func (f *frontendCLI) badEventSynchronize(_ *ishell.Context) {
|
||||
f.badEventFeedback(true)
|
||||
}
|
||||
|
||||
func (f *frontendCLI) badEventLogout(c *ishell.Context) {
|
||||
func (f *frontendCLI) badEventLogout(_ *ishell.Context) {
|
||||
f.badEventFeedback(false)
|
||||
}
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
|
||||
@ -60,6 +61,11 @@ func New(
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
|
||||
// We want to exit at the first Ctrl+C. By default, ishell requires two.
|
||||
fe.Interrupt(func(_ *ishell.Context, _ int, _ string) {
|
||||
os.Exit(1)
|
||||
})
|
||||
|
||||
// Clear commands.
|
||||
clearCmd := &ishell.Cmd{
|
||||
Name: "clear",
|
||||
|
||||
@ -31,7 +31,7 @@ import (
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
||||
func (f *frontendCLI) printLogDir(_ *ishell.Context) {
|
||||
if path, err := f.bridge.GetLogsPath(); err != nil {
|
||||
f.Println("Failed to determine location of log files")
|
||||
} else {
|
||||
@ -39,17 +39,17 @@ func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) printManual(c *ishell.Context) {
|
||||
func (f *frontendCLI) printManual(_ *ishell.Context) {
|
||||
f.Println("More instructions about the Bridge can be found at\n\n https://proton.me/mail/bridge")
|
||||
}
|
||||
|
||||
func (f *frontendCLI) printCredits(c *ishell.Context) {
|
||||
func (f *frontendCLI) printCredits(_ *ishell.Context) {
|
||||
for _, pkg := range strings.Split(bridge.Credits, ";") {
|
||||
f.Println(pkg)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) changeIMAPSecurity(c *ishell.Context) {
|
||||
func (f *frontendCLI) changeIMAPSecurity(_ *ishell.Context) {
|
||||
f.ShowPrompt(false)
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
@ -61,14 +61,14 @@ func (f *frontendCLI) changeIMAPSecurity(c *ishell.Context) {
|
||||
msg := fmt.Sprintf("Are you sure you want to change IMAP setting to %q", newSecurity)
|
||||
|
||||
if f.yesNoQuestion(msg) {
|
||||
if err := f.bridge.SetIMAPSSL(!f.bridge.GetIMAPSSL()); err != nil {
|
||||
if err := f.bridge.SetIMAPSSL(context.Background(), !f.bridge.GetIMAPSSL()); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) changeSMTPSecurity(c *ishell.Context) {
|
||||
func (f *frontendCLI) changeSMTPSecurity(_ *ishell.Context) {
|
||||
f.ShowPrompt(false)
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
@ -80,7 +80,7 @@ func (f *frontendCLI) changeSMTPSecurity(c *ishell.Context) {
|
||||
msg := fmt.Sprintf("Are you sure you want to change SMTP setting to %q", newSecurity)
|
||||
|
||||
if f.yesNoQuestion(msg) {
|
||||
if err := f.bridge.SetSMTPSSL(!f.bridge.GetSMTPSSL()); err != nil {
|
||||
if err := f.bridge.SetSMTPSSL(context.Background(), !f.bridge.GetSMTPSSL()); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
@ -103,7 +103,7 @@ func (f *frontendCLI) changeIMAPPort(c *ishell.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.bridge.SetIMAPPort(newIMAPPortInt); err != nil {
|
||||
if err := f.bridge.SetIMAPPort(context.Background(), newIMAPPortInt); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
@ -125,13 +125,13 @@ func (f *frontendCLI) changeSMTPPort(c *ishell.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.bridge.SetSMTPPort(newSMTPPortInt); err != nil {
|
||||
if err := f.bridge.SetSMTPPort(context.Background(), newSMTPPortInt); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) allowProxy(c *ishell.Context) {
|
||||
func (f *frontendCLI) allowProxy(_ *ishell.Context) {
|
||||
if f.bridge.GetProxyAllowed() {
|
||||
f.Println("Bridge is already set to use alternative routing to connect to Proton if it is being blocked.")
|
||||
return
|
||||
@ -147,7 +147,7 @@ func (f *frontendCLI) allowProxy(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) disallowProxy(c *ishell.Context) {
|
||||
func (f *frontendCLI) disallowProxy(_ *ishell.Context) {
|
||||
if !f.bridge.GetProxyAllowed() {
|
||||
f.Println("Bridge is already set to NOT use alternative routing to connect to Proton if it is being blocked.")
|
||||
return
|
||||
@ -163,7 +163,7 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) hideAllMail(c *ishell.Context) {
|
||||
func (f *frontendCLI) hideAllMail(_ *ishell.Context) {
|
||||
if !f.bridge.GetShowAllMail() {
|
||||
f.Println("All Mail folder is not listed in your local client.")
|
||||
return
|
||||
@ -179,7 +179,7 @@ func (f *frontendCLI) hideAllMail(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) showAllMail(c *ishell.Context) {
|
||||
func (f *frontendCLI) showAllMail(_ *ishell.Context) {
|
||||
if f.bridge.GetShowAllMail() {
|
||||
f.Println("All Mail folder is listed in your local client.")
|
||||
return
|
||||
|
||||
@ -23,7 +23,7 @@ import (
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
func (f *frontendCLI) checkUpdates(c *ishell.Context) {
|
||||
func (f *frontendCLI) checkUpdates(_ *ishell.Context) {
|
||||
updateCh, done := f.bridge.GetEvents(events.UpdateAvailable{}, events.UpdateNotAvailable{})
|
||||
defer done()
|
||||
|
||||
@ -38,7 +38,7 @@ func (f *frontendCLI) checkUpdates(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
|
||||
func (f *frontendCLI) enableAutoUpdates(_ *ishell.Context) {
|
||||
if f.bridge.GetAutoUpdate() {
|
||||
f.Println("Bridge is already set to automatically install updates.")
|
||||
return
|
||||
@ -54,7 +54,7 @@ func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
|
||||
func (f *frontendCLI) disableAutoUpdates(_ *ishell.Context) {
|
||||
if !f.bridge.GetAutoUpdate() {
|
||||
f.Println("Bridge is already set to NOT automatically install updates.")
|
||||
return
|
||||
@ -70,7 +70,7 @@ func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) selectEarlyChannel(c *ishell.Context) {
|
||||
func (f *frontendCLI) selectEarlyChannel(_ *ishell.Context) {
|
||||
if f.bridge.GetUpdateChannel() == updater.EarlyChannel {
|
||||
f.Println("Bridge is already on the early-access update channel.")
|
||||
return
|
||||
@ -86,7 +86,7 @@ func (f *frontendCLI) selectEarlyChannel(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) selectStableChannel(c *ishell.Context) {
|
||||
func (f *frontendCLI) selectStableChannel(_ *ishell.Context) {
|
||||
if f.bridge.GetUpdateChannel() == updater.StableChannel {
|
||||
f.Println("Bridge is already on the stable update channel.")
|
||||
return
|
||||
|
||||
@ -47,7 +47,7 @@ import (
|
||||
)
|
||||
|
||||
// CheckTokens implements the CheckToken gRPC service call.
|
||||
func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) CheckTokens(_ context.Context, clientConfigPath *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("CheckTokens")
|
||||
|
||||
path := clientConfigPath.Value
|
||||
@ -65,7 +65,7 @@ func (s *Service) CheckTokens(ctx context.Context, clientConfigPath *wrapperspb.
|
||||
return &wrapperspb.StringValue{Value: clientConfig.Token}, nil
|
||||
}
|
||||
|
||||
func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) AddLogEntry(_ context.Context, request *AddLogEntryRequest) (*emptypb.Empty, error) {
|
||||
entry := s.log
|
||||
|
||||
if len(request.Package) > 0 {
|
||||
@ -93,7 +93,7 @@ func (s *Service) AddLogEntry(ctx context.Context, request *AddLogEntryRequest)
|
||||
}
|
||||
|
||||
// GuiReady implement the GuiReady gRPC service call.
|
||||
func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*GuiReadyResponse, error) {
|
||||
func (s *Service) GuiReady(_ context.Context, _ *emptypb.Empty) (*GuiReadyResponse, error) {
|
||||
s.log.Debug("GuiReady")
|
||||
|
||||
s.initializationDone.Do(s.initializing.Done)
|
||||
@ -107,7 +107,7 @@ func (s *Service) GuiReady(ctx context.Context, _ *emptypb.Empty) (*GuiReadyResp
|
||||
}
|
||||
|
||||
// Quit implement the Quit gRPC service call.
|
||||
func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
func (s *Service) Quit(_ context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("Quit")
|
||||
return &emptypb.Empty{}, s.quit()
|
||||
}
|
||||
@ -143,13 +143,13 @@ func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.E
|
||||
return s.Quit(ctx, empty)
|
||||
}
|
||||
|
||||
func (s *Service) ShowOnStartup(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
func (s *Service) ShowOnStartup(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("ShowOnStartup")
|
||||
|
||||
return wrapperspb.Bool(s.showOnStartup), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetIsAutostartOn(_ context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("show", isOn.Value).Debug("SetIsAutostartOn")
|
||||
|
||||
defer func() { _ = s.SendEvent(NewToggleAutostartFinishedEvent()) }()
|
||||
@ -169,13 +169,13 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) IsAutostartOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
func (s *Service) IsAutostartOn(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAutostartOn")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetAutostart()), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetIsBetaEnabled(_ context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsBetaEnabled")
|
||||
|
||||
channel := updater.StableChannel
|
||||
@ -191,13 +191,13 @@ func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.Bo
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) IsBetaEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
func (s *Service) IsBetaEnabled(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsBetaEnabled")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetUpdateChannel() == updater.EarlyChannel), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetIsAllMailVisible(_ context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isVisible", isVisible.Value).Debug("SetIsAllMailVisible")
|
||||
|
||||
if err := s.bridge.SetShowAllMail(isVisible.Value); err != nil {
|
||||
@ -208,7 +208,7 @@ func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) IsAllMailVisible(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
func (s *Service) IsAllMailVisible(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAllMailVisible")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetShowAllMail()), nil
|
||||
@ -231,13 +231,13 @@ func (s *Service) IsTelemetryDisabled(_ context.Context, _ *emptypb.Empty) (*wra
|
||||
return wrapperspb.Bool(s.bridge.GetTelemetryDisabled()), nil
|
||||
}
|
||||
|
||||
func (s *Service) GoOs(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) GoOs(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("GoOs") // TO-DO We can probably get rid of this and use QSysInfo::product name
|
||||
|
||||
return wrapperspb.String(runtime.GOOS), nil
|
||||
}
|
||||
|
||||
func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
func (s *Service) TriggerReset(_ context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("TriggerReset")
|
||||
|
||||
go func() {
|
||||
@ -248,13 +248,13 @@ func (s *Service) TriggerReset(ctx context.Context, _ *emptypb.Empty) (*emptypb.
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) Version(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("Version")
|
||||
|
||||
return wrapperspb.String(s.bridge.GetCurrentVersion().Original()), nil
|
||||
}
|
||||
|
||||
func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) LogsPath(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("LogsPath")
|
||||
|
||||
path, err := s.bridge.GetLogsPath()
|
||||
@ -265,7 +265,7 @@ func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.S
|
||||
return wrapperspb.String(path), nil
|
||||
}
|
||||
|
||||
func (s *Service) LicensePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) LicensePath(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("LicensePath")
|
||||
|
||||
return wrapperspb.String(s.bridge.GetLicenseFilePath()), nil
|
||||
@ -275,7 +275,7 @@ func (s *Service) DependencyLicensesLink(_ context.Context, _ *emptypb.Empty) (*
|
||||
return wrapperspb.String(s.bridge.GetDependencyLicensesLink()), nil
|
||||
}
|
||||
|
||||
func (s *Service) ReleaseNotesPageLink(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) ReleaseNotesPageLink(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.latestLock.RLock()
|
||||
defer s.latestLock.RUnlock()
|
||||
|
||||
@ -289,7 +289,7 @@ func (s *Service) LandingPageLink(_ context.Context, _ *emptypb.Empty) (*wrapper
|
||||
return wrapperspb.String(s.latest.LandingPage), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetColorSchemeName(_ context.Context, name *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("ColorSchemeName", name.Value).Debug("SetColorSchemeName")
|
||||
|
||||
if !theme.IsAvailable(theme.Theme(name.Value)) {
|
||||
@ -305,7 +305,7 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) ColorSchemeName(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("ColorSchemeName")
|
||||
|
||||
current := s.bridge.GetColorScheme()
|
||||
@ -320,13 +320,13 @@ func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapp
|
||||
return wrapperspb.String(current), nil
|
||||
}
|
||||
|
||||
func (s *Service) CurrentEmailClient(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) CurrentEmailClient(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("CurrentEmailClient")
|
||||
|
||||
return wrapperspb.String(s.bridge.GetCurrentUserAgent()), nil
|
||||
}
|
||||
|
||||
func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) ReportBug(_ context.Context, report *ReportBugRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithFields(logrus.Fields{
|
||||
"osType": report.OsType,
|
||||
"osVersion": report.OsVersion,
|
||||
@ -382,7 +382,7 @@ func (s *Service) ExportTLSCertificates(_ context.Context, folderPath *wrappersp
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) ForceLauncher(_ context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("launcher", launcher.Value).Debug("ForceLauncher")
|
||||
|
||||
s.restarter.Override(launcher.Value)
|
||||
@ -390,7 +390,7 @@ func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.String
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetMainExecutable(_ context.Context, exe *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("executable", exe.Value).Debug("SetMainExecutable")
|
||||
|
||||
s.restarter.AddFlags("--wait", exe.Value)
|
||||
@ -398,7 +398,7 @@ func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringV
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) Login(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login")
|
||||
|
||||
go func() {
|
||||
@ -454,7 +454,7 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) Login2FA(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login2FA")
|
||||
|
||||
go func() {
|
||||
@ -499,7 +499,7 @@ func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.E
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) Login2Passwords(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login2Passwords")
|
||||
|
||||
go func() {
|
||||
@ -521,7 +521,7 @@ func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*em
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) LoginAbort(_ context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", loginAbort.Username).Debug("LoginAbort")
|
||||
|
||||
go func() {
|
||||
@ -565,7 +565,7 @@ func (s *Service) CheckUpdate(context.Context, *emptypb.Empty) (*emptypb.Empty,
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
func (s *Service) InstallUpdate(_ context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("InstallUpdate")
|
||||
|
||||
go func() {
|
||||
@ -579,7 +579,7 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetIsAutomaticUpdateOn(_ context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isOn", isOn.Value).Debug("SetIsAutomaticUpdateOn")
|
||||
|
||||
if currentlyOn := s.bridge.GetAutoUpdate(); currentlyOn == isOn.Value {
|
||||
@ -594,19 +594,19 @@ func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.B
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) IsAutomaticUpdateOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
func (s *Service) IsAutomaticUpdateOn(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAutomaticUpdateOn")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetAutoUpdate()), nil
|
||||
}
|
||||
|
||||
func (s *Service) DiskCachePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) DiskCachePath(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("DiskCachePath")
|
||||
|
||||
return wrapperspb.String(s.bridge.GetGluonCacheDir()), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetDiskCachePath(ctx context.Context, newPath *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetDiskCachePath(_ context.Context, newPath *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("path", newPath.Value).Debug("setDiskCachePath")
|
||||
|
||||
go func() {
|
||||
@ -637,7 +637,7 @@ func (s *Service) SetDiskCachePath(ctx context.Context, newPath *wrapperspb.Stri
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetIsDoHEnabled(_ context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsDohEnabled")
|
||||
|
||||
if err := s.bridge.SetProxyAllowed(isEnabled.Value); err != nil {
|
||||
@ -648,7 +648,7 @@ func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.Boo
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) IsDoHEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
func (s *Service) IsDoHEnabled(_ context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsDohEnabled")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetProxyAllowed()), nil
|
||||
@ -668,7 +668,7 @@ func (s *Service) MailServerSettings(_ context.Context, _ *emptypb.Empty) (*Imap
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSettings) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetMailServerSettings(ctx context.Context, settings *ImapSmtpSettings) (*emptypb.Empty, error) {
|
||||
s.log.
|
||||
WithField("ImapPort", settings.ImapPort).
|
||||
WithField("SmtpPort", settings.SmtpPort).
|
||||
@ -682,28 +682,28 @@ func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSet
|
||||
defer func() { _ = s.SendEvent(NewChangeMailServerSettingsFinishedEvent()) }()
|
||||
|
||||
if s.bridge.GetIMAPSSL() != settings.UseSSLForImap {
|
||||
if err := s.bridge.SetIMAPSSL(settings.UseSSLForImap); err != nil {
|
||||
if err := s.bridge.SetIMAPSSL(ctx, settings.UseSSLForImap); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set IMAP SSL")
|
||||
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_IMAP_CONNECTION_MODE_CHANGE_ERROR))
|
||||
}
|
||||
}
|
||||
|
||||
if s.bridge.GetSMTPSSL() != settings.UseSSLForSmtp {
|
||||
if err := s.bridge.SetSMTPSSL(settings.UseSSLForSmtp); err != nil {
|
||||
if err := s.bridge.SetSMTPSSL(ctx, settings.UseSSLForSmtp); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set SMTP SSL")
|
||||
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_SMTP_CONNECTION_MODE_CHANGE_ERROR))
|
||||
}
|
||||
}
|
||||
|
||||
if s.bridge.GetIMAPPort() != int(settings.ImapPort) {
|
||||
if err := s.bridge.SetIMAPPort(int(settings.ImapPort)); err != nil {
|
||||
if err := s.bridge.SetIMAPPort(ctx, int(settings.ImapPort)); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set IMAP port")
|
||||
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_IMAP_PORT_CHANGE_ERROR))
|
||||
}
|
||||
}
|
||||
|
||||
if s.bridge.GetSMTPPort() != int(settings.SmtpPort) {
|
||||
if err := s.bridge.SetSMTPPort(int(settings.SmtpPort)); err != nil {
|
||||
if err := s.bridge.SetSMTPPort(ctx, int(settings.SmtpPort)); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set SMTP port")
|
||||
_ = s.SendEvent(NewMailServerSettingsErrorEvent(MailServerSettingsErrorType_SMTP_PORT_CHANGE_ERROR))
|
||||
}
|
||||
@ -715,19 +715,19 @@ func (s *Service) SetMailServerSettings(_ context.Context, settings *ImapSmtpSet
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Hostname(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) Hostname(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("Hostname")
|
||||
|
||||
return wrapperspb.String(constants.Host), nil
|
||||
}
|
||||
|
||||
func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) (*wrapperspb.BoolValue, error) {
|
||||
func (s *Service) IsPortFree(_ context.Context, port *wrapperspb.Int32Value) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsPortFree")
|
||||
|
||||
return wrapperspb.Bool(ports.IsPortFree(int(port.Value))), nil
|
||||
}
|
||||
|
||||
func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
|
||||
func (s *Service) AvailableKeychains(_ context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
|
||||
s.log.Debug("AvailableKeychains")
|
||||
|
||||
return &AvailableKeychainsResponse{Keychains: maps.Keys(keychain.Helpers)}, nil
|
||||
@ -757,7 +757,7 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) CurrentKeychain(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
func (s *Service) CurrentKeychain(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("CurrentKeychain")
|
||||
|
||||
helper, err := s.bridge.GetKeychainApp()
|
||||
|
||||
@ -28,7 +28,7 @@ import (
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListResponse, error) {
|
||||
func (s *Service) GetUserList(_ context.Context, _ *emptypb.Empty) (*UserListResponse, error) {
|
||||
s.log.Debug("GetUserList")
|
||||
|
||||
userIDs := s.bridge.GetUserIDs()
|
||||
@ -51,7 +51,7 @@ func (s *Service) GetUserList(ctx context.Context, _ *emptypb.Empty) (*UserListR
|
||||
return &UserListResponse{Users: userList}, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) (*User, error) {
|
||||
func (s *Service) GetUser(_ context.Context, userID *wrapperspb.StringValue) (*User, error) {
|
||||
s.log.WithField("userID", userID).Debug("GetUser")
|
||||
|
||||
user, err := s.bridge.GetUserInfo(userID.Value)
|
||||
@ -62,7 +62,7 @@ func (s *Service) GetUser(ctx context.Context, userID *wrapperspb.StringValue) (
|
||||
return grpcUserFromInfo(user), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitModeRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) SetUserSplitMode(_ context.Context, splitMode *UserSplitModeRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("UserID", splitMode.UserID).WithField("Active", splitMode.Active).Debug("SetUserSplitMode")
|
||||
|
||||
user, err := s.bridge.GetUserInfo(splitMode.UserID)
|
||||
@ -96,7 +96,7 @@ func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitMode
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SendBadEventUserFeedback(ctx context.Context, feedback *UserBadEventFeedbackRequest) (*emptypb.Empty, error) {
|
||||
func (s *Service) SendBadEventUserFeedback(_ context.Context, feedback *UserBadEventFeedbackRequest) (*emptypb.Empty, error) {
|
||||
l := s.log.WithField("UserID", feedback.UserID).WithField("doResync", feedback.DoResync)
|
||||
l.Debug("SendBadEventUserFeedback")
|
||||
|
||||
@ -114,7 +114,7 @@ func (s *Service) SendBadEventUserFeedback(ctx context.Context, feedback *UserBa
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) LogoutUser(_ context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("UserID", userID.Value).Debug("LogoutUser")
|
||||
|
||||
if _, err := s.bridge.GetUserInfo(userID.Value); err != nil {
|
||||
@ -132,7 +132,7 @@ func (s *Service) LogoutUser(ctx context.Context, userID *wrapperspb.StringValue
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) RemoveUser(ctx context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
func (s *Service) RemoveUser(_ context.Context, userID *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("UserID", userID.Value).Debug("RemoveUser")
|
||||
|
||||
go func() {
|
||||
@ -152,7 +152,7 @@ func (s *Service) ConfigureUserAppleMail(ctx context.Context, request *Configure
|
||||
|
||||
sslWasEnabled := s.bridge.GetSMTPSSL()
|
||||
|
||||
if err := s.bridge.ConfigureAppleMail(request.UserID, request.Address); err != nil {
|
||||
if err := s.bridge.ConfigureAppleMail(ctx, request.UserID, request.Address); err != nil {
|
||||
s.log.WithField("userID", request.UserID).Error("Cannot configure AppleMail for user")
|
||||
return nil, status.Error(codes.Internal, "Apple Mail config failed")
|
||||
}
|
||||
|
||||
@ -113,7 +113,7 @@ func Init(logsPath, level string) error {
|
||||
// Debug or Trace.
|
||||
func setLevel(level string) error {
|
||||
if level == "" {
|
||||
return nil
|
||||
level = "debug"
|
||||
}
|
||||
|
||||
logLevel, err := logrus.ParseLevel(level)
|
||||
|
||||
@ -96,6 +96,7 @@ func GetTimeZone() string {
|
||||
|
||||
// NewReporter creates new sentry reporter with appName and appVersion to report.
|
||||
func NewReporter(appName string, identifier Identifier) *Reporter {
|
||||
logrus.WithField("id", GetProtectedHostname()).Info("New sentry reporter")
|
||||
return &Reporter{
|
||||
appName: appName,
|
||||
appVersion: constants.Revision,
|
||||
@ -203,7 +204,7 @@ func SkipDuringUnwind() {
|
||||
}
|
||||
|
||||
// EnhanceSentryEvent swaps type with value and removes panic handlers from the stacktrace.
|
||||
func EnhanceSentryEvent(event *sentry.Event, hint *sentry.EventHint) *sentry.Event {
|
||||
func EnhanceSentryEvent(event *sentry.Event, _ *sentry.EventHint) *sentry.Event {
|
||||
for idx, exception := range event.Exception {
|
||||
exception.Type, exception.Value = exception.Value, exception.Type
|
||||
if exception.Stacktrace != nil {
|
||||
|
||||
@ -62,6 +62,6 @@ func (i *InstallerDarwin) InstallUpdate(_ *semver.Version, r io.Reader) error {
|
||||
return syncFolders(oldBundle, newBundle)
|
||||
}
|
||||
|
||||
func (i *InstallerDarwin) IsAlreadyInstalled(version *semver.Version) bool {
|
||||
func (i *InstallerDarwin) IsAlreadyInstalled(_ *semver.Version) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -217,7 +217,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
|
||||
// If the address is enabled, we need to hook it up to the update channels.
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
@ -276,7 +276,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
|
||||
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
@ -394,7 +394,7 @@ func (user *User) handleLabelEvents(ctx context.Context, labelEvents []proton.La
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleCreateLabelEvent(ctx context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
func (user *User) handleCreateLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
return safe.LockRetErr(func() ([]imap.Update, error) {
|
||||
var updates []imap.Update
|
||||
|
||||
@ -480,7 +480,7 @@ func (user *User) handleUpdateLabelEvent(ctx context.Context, event proton.Label
|
||||
}, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteLabelEvent(ctx context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
func (user *User) handleDeleteLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
return safe.LockRetErr(func() ([]imap.Update, error) {
|
||||
var updates []imap.Update
|
||||
|
||||
@ -628,7 +628,14 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
}
|
||||
|
||||
update = imap.NewMessagesCreated(false, res.update)
|
||||
user.updateCh[full.AddressID].Enqueue(update)
|
||||
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !didPublish {
|
||||
update = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
@ -643,7 +650,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateMessageEvent(ctx context.Context, message proton.MessageMetadata) ([]imap.Update, error) { //nolint:unparam
|
||||
func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.MessageMetadata) ([]imap.Update, error) { //nolint:unparam
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"messageID": message.ID,
|
||||
@ -674,13 +681,20 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, message proton.M
|
||||
flags,
|
||||
)
|
||||
|
||||
user.updateCh[message.AddressID].Enqueue(update)
|
||||
didPublish, err := safePublishMessageUpdate(user, message.AddressID, update)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !didPublish {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []imap.Update{update}, nil
|
||||
}, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteMessageEvent(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
func (user *User) handleDeleteMessageEvent(_ context.Context, event proton.MessageEvent) ([]imap.Update, error) {
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
user.log.WithField("messageID", event.ID).Info("Handling message deleted event")
|
||||
|
||||
@ -696,7 +710,7 @@ func (user *User) handleDeleteMessageEvent(ctx context.Context, event proton.Mes
|
||||
}, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) {
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"messageID": event.ID,
|
||||
@ -743,13 +757,24 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
|
||||
true, // Is the message doesn't exist, silently create it.
|
||||
)
|
||||
|
||||
user.updateCh[full.AddressID].Enqueue(update)
|
||||
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !didPublish {
|
||||
update = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []imap.Update{update}, nil
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
@ -816,3 +841,37 @@ func (user *User) reportErrorNoContextCancel(title string, err error, reportCont
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// safePublishMessageUpdate handles the rare case where the address' update channel may have been deleted in the same
|
||||
// event. This rare case can take place if in the same event fetch request there is an update for delete address and
|
||||
// create/update message.
|
||||
// If the user is in combined mode, we simply push the update to the primary address. If the user is in split mode
|
||||
// we do not publish the update as the address no longer exists.
|
||||
func safePublishMessageUpdate(user *User, addressID string, update imap.Update) (bool, error) {
|
||||
v, ok := user.updateCh[addressID]
|
||||
if !ok {
|
||||
if user.GetAddressMode() == vault.CombinedMode {
|
||||
primAddr, err := getPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
primaryCh, ok := user.updateCh[primAddr.ID]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("primary address channel is not available")
|
||||
}
|
||||
|
||||
primaryCh.Enqueue(update)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logrus.Warnf("Update channel not found for address %v, it may have been already deleted", addressID)
|
||||
_ = user.reporter.ReportMessage("Message Update channel does not exist")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
v.Enqueue(update)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@ -270,7 +270,7 @@ func (conn *imapConnector) CreateMessage(
|
||||
mailboxID imap.MailboxID,
|
||||
literal []byte,
|
||||
flags imap.FlagSet,
|
||||
date time.Time,
|
||||
_ time.Time,
|
||||
) (imap.Message, []byte, error) {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
@ -459,11 +459,11 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
|
||||
var result bool
|
||||
|
||||
if v, ok := conn.apiLabels[string(labelFromID)]; ok && v.Type == proton.LabelTypeLabel {
|
||||
result = result || true
|
||||
result = true
|
||||
}
|
||||
|
||||
if v, ok := conn.apiLabels[string(labelToID)]; ok && (v.Type == proton.LabelTypeFolder || v.Type == proton.LabelTypeSystem) {
|
||||
result = result || true
|
||||
result = true
|
||||
}
|
||||
|
||||
return result
|
||||
@ -529,7 +529,7 @@ func (conn *imapConnector) GetMailboxVisibility(_ context.Context, mailboxID ima
|
||||
}
|
||||
|
||||
// Close the connector will no longer be used and all resources should be closed/released.
|
||||
func (conn *imapConnector) Close(ctx context.Context) error {
|
||||
func (conn *imapConnector) Close(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -544,7 +544,7 @@ func (conn *imapConnector) importMessage(
|
||||
|
||||
if err := safe.RLockRet(func() error {
|
||||
return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
messageID := ""
|
||||
var messageID string
|
||||
|
||||
if slices.Contains(labelIDs, proton.DraftsLabel) {
|
||||
msg, err := conn.createDraft(ctx, literal, addrKR, conn.apiAddrs[conn.addrID])
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@ -34,22 +35,30 @@ const sendEntryExpiry = 30 * time.Minute
|
||||
type sendRecorder struct {
|
||||
expiry time.Duration
|
||||
|
||||
entries map[string]*sendEntry
|
||||
entries map[string][]*sendEntry
|
||||
entriesLock sync.Mutex
|
||||
}
|
||||
|
||||
func newSendRecorder(expiry time.Duration) *sendRecorder {
|
||||
return &sendRecorder{
|
||||
expiry: expiry,
|
||||
entries: make(map[string]*sendEntry),
|
||||
entries: make(map[string][]*sendEntry),
|
||||
}
|
||||
}
|
||||
|
||||
type sendEntry struct {
|
||||
msgID string
|
||||
toList []string
|
||||
exp time.Time
|
||||
waitCh chan struct{}
|
||||
msgID string
|
||||
toList []string
|
||||
exp time.Time
|
||||
waitCh chan struct{}
|
||||
waitChClosed bool
|
||||
}
|
||||
|
||||
func (s *sendEntry) closeWaitChannel() {
|
||||
if !s.waitChClosed {
|
||||
close(s.waitCh)
|
||||
s.waitChClosed = true
|
||||
}
|
||||
}
|
||||
|
||||
// tryInsertWait tries to insert the given message into the send recorder.
|
||||
@ -102,25 +111,40 @@ func (h *sendRecorder) hasEntryWait(ctx context.Context, hash string, deadline t
|
||||
return h.hasEntryWait(ctx, hash, deadline)
|
||||
}
|
||||
|
||||
func (h *sendRecorder) removeExpiredUnsafe() {
|
||||
for hash, entry := range h.entries {
|
||||
remaining := xslices.Filter(entry, func(t *sendEntry) bool {
|
||||
return !t.exp.Before(time.Now())
|
||||
})
|
||||
|
||||
if len(remaining) == 0 {
|
||||
delete(h.entries, hash)
|
||||
} else {
|
||||
h.entries[hash] = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sendRecorder) tryInsert(hash string, toList []string) bool {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
for hash, entry := range h.entries {
|
||||
if entry.exp.Before(time.Now()) {
|
||||
delete(h.entries, hash)
|
||||
h.removeExpiredUnsafe()
|
||||
|
||||
entries, ok := h.entries[hash]
|
||||
if ok {
|
||||
for _, entry := range entries {
|
||||
if matchToList(entry.toList, toList) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := h.entries[hash]; ok && matchToList(h.entries[hash].toList, toList) {
|
||||
return false
|
||||
}
|
||||
|
||||
h.entries[hash] = &sendEntry{
|
||||
h.entries[hash] = append(entries, &sendEntry{
|
||||
exp: time.Now().Add(h.expiry),
|
||||
toList: toList,
|
||||
waitCh: make(chan struct{}),
|
||||
}
|
||||
})
|
||||
|
||||
return true
|
||||
}
|
||||
@ -129,11 +153,7 @@ func (h *sendRecorder) hasEntry(hash string) bool {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
for hash, entry := range h.entries {
|
||||
if entry.exp.Before(time.Now()) {
|
||||
delete(h.entries, hash)
|
||||
}
|
||||
}
|
||||
h.removeExpiredUnsafe()
|
||||
|
||||
if _, ok := h.entries[hash]; ok {
|
||||
return true
|
||||
@ -142,32 +162,46 @@ func (h *sendRecorder) hasEntry(hash string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sendRecorder) addMessageID(hash, msgID string) {
|
||||
// signalMessageSent should be called after a message has been successfully sent.
|
||||
func (h *sendRecorder) signalMessageSent(hash, msgID string, toList []string) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
entry, ok := h.entries[hash]
|
||||
entries, ok := h.entries[hash]
|
||||
if ok {
|
||||
entry.msgID = msgID
|
||||
} else {
|
||||
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
||||
for _, entry := range entries {
|
||||
if matchToList(entry.toList, toList) {
|
||||
entry.msgID = msgID
|
||||
entry.closeWaitChannel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
close(entry.waitCh)
|
||||
logrus.Warn("Cannot add message ID to send hash entry, it may have expired")
|
||||
}
|
||||
|
||||
func (h *sendRecorder) removeOnFail(hash string) {
|
||||
func (h *sendRecorder) removeOnFail(hash string, toList []string) {
|
||||
h.entriesLock.Lock()
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
entry, ok := h.entries[hash]
|
||||
if !ok || entry.msgID != "" {
|
||||
entries, ok := h.entries[hash]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
close(entry.waitCh)
|
||||
for idx, entry := range entries {
|
||||
if entry.msgID == "" && matchToList(entry.toList, toList) {
|
||||
entry.closeWaitChannel()
|
||||
|
||||
delete(h.entries, hash)
|
||||
remaining := xslices.Remove(entries, idx, 1)
|
||||
if len(remaining) != 0 {
|
||||
h.entries[hash] = remaining
|
||||
} else {
|
||||
delete(h.entries, hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time) (string, bool, error) {
|
||||
@ -191,7 +225,7 @@ func (h *sendRecorder) wait(ctx context.Context, hash string, deadline time.Time
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
if entry, ok := h.entries[hash]; ok {
|
||||
return entry.msgID, true, nil
|
||||
return entry[0].msgID, true, nil
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
@ -202,7 +236,7 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) {
|
||||
defer h.entriesLock.Unlock()
|
||||
|
||||
if entry, ok := h.entries[hash]; ok {
|
||||
return entry.waitCh, true
|
||||
return entry[0].waitCh, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
|
||||
@ -35,7 +35,7 @@ func TestSendHasher_Insert(t *testing.T) {
|
||||
require.NotEmpty(t, hash1)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.addMessageID(hash1, "abc")
|
||||
h.signalMessageSent(hash1, "abc", nil)
|
||||
|
||||
// Inserting a message with the same hash should return false.
|
||||
_, ok, err = testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
@ -59,7 +59,7 @@ func TestSendHasher_Insert_Expired(t *testing.T) {
|
||||
require.NotEmpty(t, hash1)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.addMessageID(hash1, "abc")
|
||||
h.signalMessageSent(hash1, "abc", nil)
|
||||
|
||||
// Wait for the entry to expire.
|
||||
time.Sleep(time.Second)
|
||||
@ -106,7 +106,7 @@ func TestSendHasher_Wait_SendSuccess(t *testing.T) {
|
||||
// Simulate successfully sending the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.addMessageID(hash, "abc")
|
||||
h.signalMessageSent(hash, "abc", nil)
|
||||
}()
|
||||
|
||||
// Inserting a message with the same hash should fail.
|
||||
@ -127,7 +127,7 @@ func TestSendHasher_Wait_SendFail(t *testing.T) {
|
||||
// Simulate failing to send the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.removeOnFail(hash)
|
||||
h.removeOnFail(hash, nil)
|
||||
}()
|
||||
|
||||
// Inserting a message with the same hash should succeed because the first message failed to send.
|
||||
@ -163,7 +163,7 @@ func TestSendHasher_HasEntry(t *testing.T) {
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.addMessageID(hash, "abc")
|
||||
h.signalMessageSent(hash, "abc", nil)
|
||||
|
||||
// The message was already sent; we should find it in the hasher.
|
||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
@ -184,7 +184,7 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
||||
// Simulate successfully sending the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.addMessageID(hash, "abc")
|
||||
h.signalMessageSent(hash, "abc", nil)
|
||||
}()
|
||||
|
||||
// The message was already sent; we should find it in the hasher.
|
||||
@ -194,6 +194,47 @@ func TestSendHasher_HasEntry_SendSuccess(t *testing.T) {
|
||||
require.Equal(t, "abc", messageID)
|
||||
}
|
||||
|
||||
func TestSendHasher_DualAddDoesNotCauseCrash(t *testing.T) {
|
||||
// There may be a rare case where one 2 smtp connections attempt to send the same message, but if the first message
|
||||
// is stuck long enough for it to expire, the second connection will remove it from the list and cause it to be
|
||||
// inserted as a new entry. The two clients end up sending the message twice and calling the `signalMessageSent` x2,
|
||||
// resulting in a crash.
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message. We call this method twice as it possible for multiple SMTP connections
|
||||
// to attempt to send the same message.
|
||||
h.signalMessageSent(hash, "abc", nil)
|
||||
h.signalMessageSent(hash, "abc", nil)
|
||||
|
||||
// The message was already sent; we should find it in the hasher.
|
||||
messageID, ok, err := testHasEntry(h, literal1, time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "abc", messageID)
|
||||
}
|
||||
|
||||
func TestSendHashed_MessageWithSameHasButDifferentRecipientsIsInserted(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
// Insert a message into the hasher.
|
||||
hash, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
hash2, ok, err := testTryInsert(h, literal1, time.Now().Add(time.Second), "Receiver <receiver@pm.me>", "Receiver2 <receiver2@pm.me>")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hash2)
|
||||
require.Equal(t, hash, hash2)
|
||||
}
|
||||
|
||||
func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
||||
h := newSendRecorder(sendEntryExpiry)
|
||||
|
||||
@ -206,7 +247,7 @@ func TestSendHasher_HasEntry_SendFail(t *testing.T) {
|
||||
// Simulate failing to send the message after half a second.
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
h.removeOnFail(hash)
|
||||
h.removeOnFail(hash, nil)
|
||||
}()
|
||||
|
||||
// The message failed to send; we should not find it in the hasher.
|
||||
@ -240,7 +281,7 @@ func TestSendHasher_HasEntry_Expired(t *testing.T) {
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Simulate successfully sending the message.
|
||||
h.addMessageID(hash, "abc")
|
||||
h.signalMessageSent(hash, "abc", nil)
|
||||
|
||||
// Wait for the entry to expire.
|
||||
time.Sleep(time.Second)
|
||||
@ -264,7 +305,6 @@ Content-Disposition: attachment; filename="attname.txt"
|
||||
attachment
|
||||
--longrandomstring--
|
||||
`
|
||||
|
||||
const literal2 = `From: Sender <sender@pm.me>
|
||||
To: Receiver <receiver@pm.me>
|
||||
Content-Type: multipart/mixed; boundary=longrandomstring
|
||||
|
||||
@ -89,7 +89,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
// If we fail to send this message, we should remove the hash from the send recorder.
|
||||
defer user.sendHash.removeOnFail(hash)
|
||||
defer user.sendHash.removeOnFail(hash, to)
|
||||
|
||||
// Create a new message parser from the reader.
|
||||
parser, err := parser.New(bytes.NewReader(b))
|
||||
@ -162,7 +162,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
// If the message was successfully sent, we can update the message ID in the record.
|
||||
user.sendHash.addMessageID(hash, sent.ID)
|
||||
user.sendHash.signalMessageSent(hash, sent.ID, to)
|
||||
|
||||
return nil
|
||||
})
|
||||
@ -438,6 +438,10 @@ func (user *User) createAttachments(
|
||||
}
|
||||
}
|
||||
|
||||
// Exclude name from params since this is already provided using Filename.
|
||||
delete(att.MIMEParams, "name")
|
||||
delete(att.MIMEParams, "filename")
|
||||
|
||||
attachment, err := client.UploadAttachment(ctx, addrKR, proton.CreateAttachmentReq{
|
||||
Filename: att.Name,
|
||||
MessageID: draftID,
|
||||
|
||||
@ -19,6 +19,6 @@
|
||||
|
||||
package user
|
||||
|
||||
func debugDumpToDisk(b []byte) error {
|
||||
func debugDumpToDisk(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -513,7 +513,16 @@ func (user *User) syncMessages(
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
|
||||
defer async.HandlePanic(user.panicHandler)
|
||||
|
||||
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
|
||||
kr, ok := addrKRs[msg.AddressID]
|
||||
if !ok {
|
||||
return &buildRes{
|
||||
messageID: msg.ID,
|
||||
addressID: msg.AddressID,
|
||||
err: fmt.Errorf("address does not have an unlocked keyring"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return buildRFC822(apiLabels, msg, kr, new(bytes.Buffer)), nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
@ -572,10 +581,10 @@ func (user *User) syncMessages(
|
||||
|
||||
// We could sync a placeholder message here, but for now we skip it entirely.
|
||||
continue
|
||||
} else {
|
||||
if err := vault.RemFailedMessageID(res.messageID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove failed message ID")
|
||||
}
|
||||
}
|
||||
|
||||
if err := vault.RemFailedMessageID(res.messageID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove failed message ID")
|
||||
}
|
||||
|
||||
targetInfo := addressToIndex[res.addressID]
|
||||
|
||||
@ -83,6 +83,18 @@ func getAddrIdx(apiAddrs map[string]proton.Address, idx int) (proton.Address, er
|
||||
return sorted[idx], nil
|
||||
}
|
||||
|
||||
func getPrimaryAddr(apiAddrs map[string]proton.Address) (proton.Address, error) {
|
||||
sorted := sortSlice(maps.Values(apiAddrs), func(a, b proton.Address) bool {
|
||||
return a.Order < b.Order
|
||||
})
|
||||
|
||||
if len(sorted) == 0 {
|
||||
return proton.Address{}, fmt.Errorf("no addresses available")
|
||||
}
|
||||
|
||||
return sorted[0], nil
|
||||
}
|
||||
|
||||
// sortSlice returns the given slice sorted by the given comparator.
|
||||
func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
|
||||
sorted := make([]Item, len(items))
|
||||
|
||||
@ -282,7 +282,7 @@ func (user *User) Match(query string) bool {
|
||||
func (user *User) Emails() []string {
|
||||
return safe.RLockRet(func() []string {
|
||||
addresses := xslices.Filter(maps.Values(user.apiAddrs), func(addr proton.Address) bool {
|
||||
return addr.Status == proton.AddressStatusEnabled
|
||||
return addr.Status == proton.AddressStatusEnabled && addr.Type != proton.AddressTypeExternal
|
||||
})
|
||||
|
||||
slices.SortFunc(addresses, func(a, b proton.Address) bool {
|
||||
@ -586,6 +586,8 @@ func (user *User) Close() {
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
}
|
||||
|
||||
user.updateCh = make(map[string]*async.QueuedChannel[imap.Update])
|
||||
}, user.updateChLock)
|
||||
|
||||
// Close the user's notify channel.
|
||||
@ -690,87 +692,89 @@ func (user *User) doEventPoll(ctx context.Context) error {
|
||||
user.eventLock.Lock()
|
||||
defer user.eventLock.Unlock()
|
||||
|
||||
event, more, err := user.client.GetEvent(ctx, user.vault.EventID())
|
||||
gpaEvents, more, err := user.client.GetEvent(ctx, user.vault.EventID())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get event (caused by %T): %w", internal.ErrCause(err), err)
|
||||
}
|
||||
|
||||
// If the event ID hasn't changed, there are no new events.
|
||||
if event.EventID == user.vault.EventID() {
|
||||
if gpaEvents[len(gpaEvents)-1].EventID == user.vault.EventID() {
|
||||
user.log.Debug("No new API events")
|
||||
return nil
|
||||
}
|
||||
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"old": user.vault.EventID(),
|
||||
"new": event,
|
||||
}).Info("Received new API event")
|
||||
for _, event := range gpaEvents {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"old": user.vault.EventID(),
|
||||
"new": event,
|
||||
}).Info("Received new API event")
|
||||
|
||||
// Handle the event.
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
// If the error is a context cancellation, return error to retry later.
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return fmt.Errorf("failed to handle event due to context cancellation: %w", err)
|
||||
// Handle the event.
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
// If the error is a context cancellation, return error to retry later.
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return fmt.Errorf("failed to handle event due to context cancellation: %w", err)
|
||||
}
|
||||
|
||||
// If the error is a network error, return error to retry later.
|
||||
if netErr := new(proton.NetError); errors.As(err, &netErr) {
|
||||
return fmt.Errorf("failed to handle event due to network issue: %w", err)
|
||||
}
|
||||
|
||||
// Catch all for uncategorized net errors that may slip through.
|
||||
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
||||
return fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
|
||||
}
|
||||
|
||||
// In case a json decode error slips through.
|
||||
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
|
||||
user.eventCh.Enqueue(events.UncategorizedEventError{
|
||||
UserID: user.ID(),
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to handle event due to JSON issue: %w", err)
|
||||
}
|
||||
|
||||
// If the error is an unexpected EOF, return error to retry later.
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return fmt.Errorf("failed to handle event due to EOF: %w", err)
|
||||
}
|
||||
|
||||
// If the error is a server-side issue, return error to retry later.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
|
||||
return fmt.Errorf("failed to handle event due to server error: %w", err)
|
||||
}
|
||||
|
||||
// Otherwise, the error is a client-side issue; notify bridge to handle it.
|
||||
user.log.WithField("event", event).Warn("Failed to handle API event")
|
||||
|
||||
user.eventCh.Enqueue(events.UserBadEvent{
|
||||
UserID: user.ID(),
|
||||
OldEventID: user.vault.EventID(),
|
||||
NewEventID: event.EventID,
|
||||
EventInfo: event.String(),
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to handle event due to client error: %w", err)
|
||||
}
|
||||
|
||||
// If the error is a network error, return error to retry later.
|
||||
if netErr := new(proton.NetError); errors.As(err, &netErr) {
|
||||
return fmt.Errorf("failed to handle event due to network issue: %w", err)
|
||||
}
|
||||
user.log.WithField("event", event).Debug("Handled API event")
|
||||
|
||||
// Catch all for uncategorized net errors that may slip through.
|
||||
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
||||
return fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
|
||||
}
|
||||
|
||||
// In case a json decode error slips through.
|
||||
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
|
||||
user.eventCh.Enqueue(events.UncategorizedEventError{
|
||||
// Update the event ID in the vault. If this fails, notify bridge to handle it.
|
||||
if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
user.eventCh.Enqueue(events.UserBadEvent{
|
||||
UserID: user.ID(),
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to handle event due to JSON issue: %w", err)
|
||||
return fmt.Errorf("failed to update event ID: %w", err)
|
||||
}
|
||||
|
||||
// If the error is an unexpected EOF, return error to retry later.
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return fmt.Errorf("failed to handle event due to EOF: %w", err)
|
||||
}
|
||||
|
||||
// If the error is a server-side issue, return error to retry later.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
|
||||
return fmt.Errorf("failed to handle event due to server error: %w", err)
|
||||
}
|
||||
|
||||
// Otherwise, the error is a client-side issue; notify bridge to handle it.
|
||||
user.log.WithField("event", event).Warn("Failed to handle API event")
|
||||
|
||||
user.eventCh.Enqueue(events.UserBadEvent{
|
||||
UserID: user.ID(),
|
||||
OldEventID: user.vault.EventID(),
|
||||
NewEventID: event.EventID,
|
||||
EventInfo: event.String(),
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to handle event due to client error: %w", err)
|
||||
user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
|
||||
}
|
||||
|
||||
user.log.WithField("event", event).Debug("Handled API event")
|
||||
|
||||
// Update the event ID in the vault. If this fails, notify bridge to handle it.
|
||||
if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
user.eventCh.Enqueue(events.UserBadEvent{
|
||||
UserID: user.ID(),
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to update event ID: %w", err)
|
||||
}
|
||||
|
||||
user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
|
||||
|
||||
if more {
|
||||
user.goPollAPIEvents(false)
|
||||
}
|
||||
|
||||
@ -30,7 +30,12 @@ import (
|
||||
// If CertPEMPath is set, it will attempt to read the certificate from the file.
|
||||
// Otherwise, or on read/validation failure, it will return the certificate from the vault.
|
||||
func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) {
|
||||
if certPath, keyPath := vault.get().Certs.CustomCertPath, vault.get().Certs.CustomKeyPath; certPath != "" && keyPath != "" {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
certs := vault.getUnsafe().Certs
|
||||
|
||||
if certPath, keyPath := certs.CustomCertPath, certs.CustomKeyPath; certPath != "" && keyPath != "" {
|
||||
if certPEM, keyPEM, err := readPEMCert(certPath, keyPath); err == nil {
|
||||
return certPEM, keyPEM
|
||||
}
|
||||
@ -38,7 +43,7 @@ func (vault *Vault) GetBridgeTLSCert() ([]byte, []byte) {
|
||||
logrus.Error("Failed to read certificate from file, using default")
|
||||
}
|
||||
|
||||
return vault.get().Certs.Bridge.Cert, vault.get().Certs.Bridge.Key
|
||||
return certs.Bridge.Cert, certs.Bridge.Key
|
||||
}
|
||||
|
||||
// SetBridgeTLSCertPath sets the path to PEM-encoded certificates for the bridge.
|
||||
@ -47,7 +52,7 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error {
|
||||
return fmt.Errorf("invalid certificate: %w", err)
|
||||
}
|
||||
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Certs.CustomCertPath = certPath
|
||||
data.Certs.CustomKeyPath = keyPath
|
||||
})
|
||||
@ -55,18 +60,18 @@ func (vault *Vault) SetBridgeTLSCertPath(certPath, keyPath string) error {
|
||||
|
||||
// SetBridgeTLSCertKey sets the path to PEM-encoded certificates for the bridge.
|
||||
func (vault *Vault) SetBridgeTLSCertKey(cert, key []byte) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Certs.Bridge.Cert = cert
|
||||
data.Certs.Bridge.Key = key
|
||||
})
|
||||
}
|
||||
|
||||
func (vault *Vault) GetCertsInstalled() bool {
|
||||
return vault.get().Certs.Installed
|
||||
return vault.getSafe().Certs.Installed
|
||||
}
|
||||
|
||||
func (vault *Vault) SetCertsInstalled(installed bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Certs.Installed = installed
|
||||
})
|
||||
}
|
||||
|
||||
@ -18,11 +18,11 @@
|
||||
package vault
|
||||
|
||||
func (vault *Vault) GetCookies() ([]byte, error) {
|
||||
return vault.get().Cookies, nil
|
||||
return vault.getSafe().Cookies, nil
|
||||
}
|
||||
|
||||
func (vault *Vault) SetCookies(cookies []byte) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Cookies = cookies
|
||||
})
|
||||
}
|
||||
|
||||
46
internal/vault/password_archive.go
Normal file
46
internal/vault/password_archive.go
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package vault
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// set archives the password for an email address, overwriting any existing archived value.
|
||||
func (p *PasswordArchive) set(emailAddress string, password []byte) {
|
||||
if p.Archive == nil {
|
||||
p.Archive = make(map[string][]byte)
|
||||
}
|
||||
|
||||
p.Archive[emailHashString(emailAddress)] = password
|
||||
}
|
||||
|
||||
// get retrieves the archived password for an email address, or nil if not found.
|
||||
func (p *PasswordArchive) get(emailAddress string) []byte {
|
||||
if p.Archive == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.Archive[emailHashString(emailAddress)]
|
||||
}
|
||||
|
||||
// emailHashString returns a hash string for an email address as a hexadecimal string.
|
||||
func emailHashString(emailAddress string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(emailAddress)))
|
||||
}
|
||||
@ -33,72 +33,72 @@ const (
|
||||
|
||||
// GetIMAPPort sets the port that the IMAP server should listen on.
|
||||
func (vault *Vault) GetIMAPPort() int {
|
||||
return vault.get().Settings.IMAPPort
|
||||
return vault.getSafe().Settings.IMAPPort
|
||||
}
|
||||
|
||||
// SetIMAPPort sets the port that the IMAP server should listen on.
|
||||
func (vault *Vault) SetIMAPPort(port int) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.IMAPPort = port
|
||||
})
|
||||
}
|
||||
|
||||
// GetSMTPPort sets the port that the SMTP server should listen on.
|
||||
func (vault *Vault) GetSMTPPort() int {
|
||||
return vault.get().Settings.SMTPPort
|
||||
return vault.getSafe().Settings.SMTPPort
|
||||
}
|
||||
|
||||
// SetSMTPPort sets the port that the SMTP server should listen on.
|
||||
func (vault *Vault) SetSMTPPort(port int) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.SMTPPort = port
|
||||
})
|
||||
}
|
||||
|
||||
// GetIMAPSSL sets whether the IMAP server should use SSL.
|
||||
func (vault *Vault) GetIMAPSSL() bool {
|
||||
return vault.get().Settings.IMAPSSL
|
||||
return vault.getSafe().Settings.IMAPSSL
|
||||
}
|
||||
|
||||
// SetIMAPSSL sets whether the IMAP server should use SSL.
|
||||
func (vault *Vault) SetIMAPSSL(ssl bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.IMAPSSL = ssl
|
||||
})
|
||||
}
|
||||
|
||||
// GetSMTPSSL sets whether the SMTP server should use SSL.
|
||||
func (vault *Vault) GetSMTPSSL() bool {
|
||||
return vault.get().Settings.SMTPSSL
|
||||
return vault.getSafe().Settings.SMTPSSL
|
||||
}
|
||||
|
||||
// SetSMTPSSL sets whether the SMTP server should use SSL.
|
||||
func (vault *Vault) SetSMTPSSL(ssl bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.SMTPSSL = ssl
|
||||
})
|
||||
}
|
||||
|
||||
// GetGluonCacheDir sets the directory where the gluon should store its data.
|
||||
func (vault *Vault) GetGluonCacheDir() string {
|
||||
return vault.get().Settings.GluonDir
|
||||
return vault.getSafe().Settings.GluonDir
|
||||
}
|
||||
|
||||
// SetGluonDir sets the directory where the gluon should store its data.
|
||||
func (vault *Vault) SetGluonDir(dir string) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.GluonDir = dir
|
||||
})
|
||||
}
|
||||
|
||||
// GetUpdateChannel sets the update channel.
|
||||
func (vault *Vault) GetUpdateChannel() updater.Channel {
|
||||
return vault.get().Settings.UpdateChannel
|
||||
return vault.getSafe().Settings.UpdateChannel
|
||||
}
|
||||
|
||||
// SetUpdateChannel sets the update channel.
|
||||
func (vault *Vault) SetUpdateChannel(channel updater.Channel) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.UpdateChannel = channel
|
||||
})
|
||||
}
|
||||
@ -106,7 +106,7 @@ func (vault *Vault) SetUpdateChannel(channel updater.Channel) error {
|
||||
// GetUpdateRollout sets the update rollout.
|
||||
func (vault *Vault) GetUpdateRollout() float64 {
|
||||
// The rollout value 0.6046602879796196 is forbidden. The RNG was not seeded when it was picked (GODT-2319).
|
||||
rollout := vault.get().Settings.UpdateRollout
|
||||
rollout := vault.getSafe().Settings.UpdateRollout
|
||||
if math.Abs(rollout-ForbiddenRollout) >= 0.00000001 {
|
||||
return rollout
|
||||
}
|
||||
@ -120,110 +120,110 @@ func (vault *Vault) GetUpdateRollout() float64 {
|
||||
|
||||
// SetUpdateRollout sets the update rollout.
|
||||
func (vault *Vault) SetUpdateRollout(rollout float64) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.UpdateRollout = rollout
|
||||
})
|
||||
}
|
||||
|
||||
// GetColorScheme sets the color scheme to be used by the bridge GUI.
|
||||
func (vault *Vault) GetColorScheme() string {
|
||||
return vault.get().Settings.ColorScheme
|
||||
return vault.getSafe().Settings.ColorScheme
|
||||
}
|
||||
|
||||
// SetColorScheme sets the color scheme to be used by the bridge GUI.
|
||||
func (vault *Vault) SetColorScheme(colorScheme string) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.ColorScheme = colorScheme
|
||||
})
|
||||
}
|
||||
|
||||
// GetProxyAllowed sets whether the bridge is allowed to use alternative routing.
|
||||
func (vault *Vault) GetProxyAllowed() bool {
|
||||
return vault.get().Settings.ProxyAllowed
|
||||
return vault.getSafe().Settings.ProxyAllowed
|
||||
}
|
||||
|
||||
// SetProxyAllowed sets whether the bridge is allowed to use alternative routing.
|
||||
func (vault *Vault) SetProxyAllowed(allowed bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.ProxyAllowed = allowed
|
||||
})
|
||||
}
|
||||
|
||||
// GetShowAllMail sets whether the bridge should show the All Mail folder.
|
||||
func (vault *Vault) GetShowAllMail() bool {
|
||||
return vault.get().Settings.ShowAllMail
|
||||
return vault.getSafe().Settings.ShowAllMail
|
||||
}
|
||||
|
||||
// SetShowAllMail sets whether the bridge should show the All Mail folder.
|
||||
func (vault *Vault) SetShowAllMail(showAllMail bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.ShowAllMail = showAllMail
|
||||
})
|
||||
}
|
||||
|
||||
// GetAutostart sets whether the bridge should autostart.
|
||||
func (vault *Vault) GetAutostart() bool {
|
||||
return vault.get().Settings.Autostart
|
||||
return vault.getSafe().Settings.Autostart
|
||||
}
|
||||
|
||||
// SetAutostart sets whether the bridge should autostart.
|
||||
func (vault *Vault) SetAutostart(autostart bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.Autostart = autostart
|
||||
})
|
||||
}
|
||||
|
||||
// GetAutoUpdate sets whether the bridge should automatically update.
|
||||
func (vault *Vault) GetAutoUpdate() bool {
|
||||
return vault.get().Settings.AutoUpdate
|
||||
return vault.getSafe().Settings.AutoUpdate
|
||||
}
|
||||
|
||||
// SetAutoUpdate sets whether the bridge should automatically update.
|
||||
func (vault *Vault) SetAutoUpdate(autoUpdate bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.AutoUpdate = autoUpdate
|
||||
})
|
||||
}
|
||||
|
||||
// GetTelemetryDisabled checks whether telemetry is disabled.
|
||||
func (vault *Vault) GetTelemetryDisabled() bool {
|
||||
return vault.get().Settings.TelemetryDisabled
|
||||
return vault.getSafe().Settings.TelemetryDisabled
|
||||
}
|
||||
|
||||
// SetTelemetryDisabled sets whether telemetry is disabled.
|
||||
func (vault *Vault) SetTelemetryDisabled(telemetryDisabled bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.TelemetryDisabled = telemetryDisabled
|
||||
})
|
||||
}
|
||||
|
||||
// GetLastVersion returns the last version of the bridge that was run.
|
||||
func (vault *Vault) GetLastVersion() *semver.Version {
|
||||
return semver.MustParse(vault.get().Settings.LastVersion)
|
||||
return semver.MustParse(vault.getSafe().Settings.LastVersion)
|
||||
}
|
||||
|
||||
// SetLastVersion sets the last version of the bridge that was run.
|
||||
func (vault *Vault) SetLastVersion(version *semver.Version) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.LastVersion = version.String()
|
||||
})
|
||||
}
|
||||
|
||||
// GetFirstStart returns whether this is the first time the bridge has been started.
|
||||
func (vault *Vault) GetFirstStart() bool {
|
||||
return vault.get().Settings.FirstStart
|
||||
return vault.getSafe().Settings.FirstStart
|
||||
}
|
||||
|
||||
// SetFirstStart sets whether this is the first time the bridge has been started.
|
||||
func (vault *Vault) SetFirstStart(firstStart bool) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.FirstStart = firstStart
|
||||
})
|
||||
}
|
||||
|
||||
// GetMaxSyncMemory returns the maximum amount of memory the sync process should use.
|
||||
func (vault *Vault) GetMaxSyncMemory() uint64 {
|
||||
v := vault.get().Settings.MaxSyncMemory
|
||||
v := vault.getSafe().Settings.MaxSyncMemory
|
||||
// can be zero if never written to vault before.
|
||||
if v == 0 {
|
||||
return DefaultMaxSyncMemory
|
||||
@ -234,14 +234,14 @@ func (vault *Vault) GetMaxSyncMemory() uint64 {
|
||||
|
||||
// SetMaxSyncMemory sets the maximum amount of memory the sync process should use.
|
||||
func (vault *Vault) SetMaxSyncMemory(maxMemory uint64) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.MaxSyncMemory = maxMemory
|
||||
})
|
||||
}
|
||||
|
||||
// GetLastUserAgent returns the last user agent recorded by bridge.
|
||||
func (vault *Vault) GetLastUserAgent() string {
|
||||
v := vault.get().Settings.LastUserAgent
|
||||
v := vault.getSafe().Settings.LastUserAgent
|
||||
|
||||
// Handle case where there may be no value.
|
||||
if len(v) == 0 {
|
||||
@ -253,19 +253,19 @@ func (vault *Vault) GetLastUserAgent() string {
|
||||
|
||||
// SetLastUserAgent store the last user agent recorded by bridge.
|
||||
func (vault *Vault) SetLastUserAgent(userAgent string) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.LastUserAgent = userAgent
|
||||
})
|
||||
}
|
||||
|
||||
// GetLastHeartbeatSent returns the last time heartbeat was sent.
|
||||
func (vault *Vault) GetLastHeartbeatSent() time.Time {
|
||||
return vault.get().Settings.LastHeartbeatSent
|
||||
return vault.getSafe().Settings.LastHeartbeatSent
|
||||
}
|
||||
|
||||
// SetLastHeartbeatSent store the last time heartbeat was sent.
|
||||
func (vault *Vault) SetLastHeartbeatSent(timestamp time.Time) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modSafe(func(data *Data) {
|
||||
data.Settings.LastHeartbeatSent = timestamp
|
||||
})
|
||||
}
|
||||
|
||||
@ -238,3 +238,30 @@ func TestVault_Settings_LastUserAgent(t *testing.T) {
|
||||
// Check the default first start value.
|
||||
require.Equal(t, vault.DefaultUserAgent, s.GetLastUserAgent())
|
||||
}
|
||||
|
||||
func Test_Settings_PasswordArchive(t *testing.T) {
|
||||
// Create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// The store should have no users.
|
||||
require.Empty(t, s.GetUserIDs())
|
||||
|
||||
// Create a new user.
|
||||
user, err := s.AddUser("userID1", "username1", "username1@pm.me", "authUID1", "authRef1", []byte("keyPass1"))
|
||||
require.NoError(t, err)
|
||||
bridgePass := user.BridgePass()
|
||||
|
||||
// Remove the user.
|
||||
require.NoError(t, user.Close())
|
||||
require.NoError(t, s.DeleteUser("userID1"))
|
||||
|
||||
// Add a different user. Another password is generated.
|
||||
user, err = s.AddUser("userID2", "username2", "username2@pm.me", "authUID2", "authRef2", []byte("keyPass2"))
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, user.BridgePass(), bridgePass)
|
||||
|
||||
// Add the first user again. The password is restored.
|
||||
user, err = s.AddUser("userID1", "username1", "username1@pm.me", "authUID1", "authRef1", []byte("keyPass1"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.BridgePass(), bridgePass)
|
||||
}
|
||||
|
||||
@ -48,11 +48,7 @@ func unmarshalFile[T any](gcm cipher.AEAD, b []byte, data *T) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := msgpack.Unmarshal(dec, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return msgpack.Unmarshal(dec, data)
|
||||
}
|
||||
|
||||
func marshalFile[T any](gcm cipher.AEAD, t T) ([]byte, error) {
|
||||
|
||||
25
internal/vault/types_password_archive.go
Normal file
25
internal/vault/types_password_archive.go
Normal file
@ -0,0 +1,25 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package vault
|
||||
|
||||
// PasswordArchive maps a list email address hashes to passwords.
|
||||
// The type is not defined as a map alias to prevent having to handle nil default values when vault was created by an older version of the application.
|
||||
type PasswordArchive struct {
|
||||
// we store the SHA-256 sum as string for readability and JSON marshalling of map[[32]byte][]byte will not be allowed, thus breaking vault-editor.
|
||||
Archive map[string][]byte
|
||||
}
|
||||
@ -53,6 +53,8 @@ type Settings struct {
|
||||
|
||||
LastHeartbeatSent time.Time
|
||||
|
||||
PasswordArchive PasswordArchive
|
||||
|
||||
// **WARNING**: These entry can't be removed until they vault has proper migration support.
|
||||
SyncWorkers int
|
||||
SyncAttPool int
|
||||
@ -105,5 +107,7 @@ func newDefaultSettings(gluonDir string) Settings {
|
||||
|
||||
LastUserAgent: DefaultUserAgent,
|
||||
LastHeartbeatSent: time.Time{},
|
||||
|
||||
PasswordArchive: PasswordArchive{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ func (status SyncStatus) IsComplete() bool {
|
||||
return status.HasLabels && status.HasMessages
|
||||
}
|
||||
|
||||
func newDefaultUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) UserData {
|
||||
func newDefaultUser(userID, username, primaryEmail, authUID, authRef string, keyPass, bridgePass []byte) UserData {
|
||||
return UserData{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
@ -82,7 +82,7 @@ func newDefaultUser(userID, username, primaryEmail, authUID, authRef string, key
|
||||
GluonKey: newRandomToken(32),
|
||||
GluonIDs: make(map[string]string),
|
||||
UIDValidity: make(map[string]imap.UID),
|
||||
BridgePass: newRandomToken(16),
|
||||
BridgePass: bridgePass,
|
||||
AddressMode: CombinedMode,
|
||||
|
||||
AuthUID: authUID,
|
||||
|
||||
@ -122,6 +122,14 @@ func (user *User) SetAuth(authUID, authRef string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) setAuthAndKeyPassUnsafe(authUID, authRef string, keyPass []byte) error {
|
||||
return user.vault.modUserUnsafe(user.userID, func(userData *UserData) {
|
||||
userData.AuthRef = authRef
|
||||
userData.AuthUID = authUID
|
||||
userData.KeyPass = keyPass
|
||||
})
|
||||
}
|
||||
|
||||
// KeyPass returns the user's (salted) key password.
|
||||
func (user *User) KeyPass() []byte {
|
||||
return user.vault.getUser(user.userID).KeyPass
|
||||
|
||||
@ -40,11 +40,11 @@ type Vault struct {
|
||||
path string
|
||||
gcm cipher.AEAD
|
||||
|
||||
enc []byte
|
||||
encLock sync.RWMutex
|
||||
enc []byte
|
||||
|
||||
ref map[string]int
|
||||
refLock sync.Mutex
|
||||
ref map[string]int
|
||||
|
||||
lock sync.RWMutex
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
@ -79,14 +79,46 @@ func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHan
|
||||
|
||||
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
||||
func (vault *Vault) GetUserIDs() []string {
|
||||
return xslices.Map(vault.get().Users, func(user UserData) string {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return xslices.Map(vault.getUnsafe().Users, func(user UserData) string {
|
||||
return user.UserID
|
||||
})
|
||||
}
|
||||
|
||||
func (vault *Vault) getUsers() ([]*User, error) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
users := vault.getUnsafe().Users
|
||||
|
||||
result := make([]*User, 0, len(users))
|
||||
|
||||
for _, user := range users {
|
||||
u, err := vault.newUserUnsafe(user.UserID)
|
||||
if err != nil {
|
||||
for _, v := range result {
|
||||
if err := v.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Fait to close user after failed get")
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, u)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// HasUser returns true if the vault contains a user with the given ID.
|
||||
func (vault *Vault) HasUser(userID string) bool {
|
||||
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}) >= 0
|
||||
}
|
||||
@ -106,46 +138,72 @@ func (vault *Vault) GetUser(userID string, fn func(*User)) error {
|
||||
|
||||
// NewUser returns a new vault user. It must be closed before it can be deleted.
|
||||
func (vault *Vault) NewUser(userID string) (*User, error) {
|
||||
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.newUserUnsafe(userID)
|
||||
}
|
||||
|
||||
func (vault *Vault) newUserUnsafe(userID string) (*User, error) {
|
||||
if idx := xslices.IndexFunc(vault.getUnsafe().Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}); idx < 0 {
|
||||
return nil, errors.New("no such user")
|
||||
}
|
||||
|
||||
return vault.attachUser(userID), nil
|
||||
return vault.attachUserUnsafe(userID), nil
|
||||
}
|
||||
|
||||
// ForUser executes a callback for each user in the vault.
|
||||
func (vault *Vault) ForUser(parallelism int, fn func(*User) error) error {
|
||||
userIDs := vault.GetUserIDs()
|
||||
users, err := vault.getUsers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return parallel.DoContext(context.Background(), parallelism, len(userIDs), func(_ context.Context, idx int) error {
|
||||
r := parallel.DoContext(context.Background(), parallelism, len(users), func(_ context.Context, idx int) error {
|
||||
defer async.HandlePanic(vault.panicHandler)
|
||||
|
||||
user, err := vault.NewUser(userIDs[idx])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = user.Close() }()
|
||||
|
||||
user := users[idx]
|
||||
return fn(user)
|
||||
})
|
||||
|
||||
for _, u := range users {
|
||||
if err := u.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user after ForUser")
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddUser creates a new user in the vault with the given ID, username and password.
|
||||
// A bridge password and gluon key are generated using the package's token generator.
|
||||
// A gluon key is generated using the package's token generator. If a password is found in the password archive for this user,
|
||||
// it is restored, otherwise a new bridge password is generated using the package's token generator.
|
||||
func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
|
||||
}
|
||||
|
||||
func (vault *Vault) addUserUnsafe(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, error) {
|
||||
logrus.WithField("userID", userID).Info("Adding vault user")
|
||||
|
||||
var exists bool
|
||||
|
||||
if err := vault.mod(func(data *Data) {
|
||||
if err := vault.modUnsafe(func(data *Data) {
|
||||
if idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
}); idx >= 0 {
|
||||
exists = true
|
||||
} else {
|
||||
data.Users = append(data.Users, newDefaultUser(userID, username, primaryEmail, authUID, authRef, keyPass))
|
||||
bridgePass := data.Settings.PasswordArchive.get(primaryEmail)
|
||||
if len(bridgePass) == 0 {
|
||||
bridgePass = newRandomToken(16)
|
||||
}
|
||||
|
||||
data.Users = append(data.Users, newDefaultUser(userID, username, primaryEmail, authUID, authRef, keyPass, bridgePass))
|
||||
}
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
@ -155,13 +213,42 @@ func (vault *Vault) AddUser(userID, username, primaryEmail, authUID, authRef str
|
||||
return nil, errors.New("user already exists")
|
||||
}
|
||||
|
||||
return vault.NewUser(userID)
|
||||
return vault.attachUserUnsafe(userID), nil
|
||||
}
|
||||
|
||||
// GetOrAddUser retrieves an existing user and updates the authRef and keyPass or creates a new user. Returns
|
||||
// the user and whether the user did not exist before.
|
||||
func (vault *Vault) GetOrAddUser(userID, username, primaryEmail, authUID, authRef string, keyPass []byte) (*User, bool, error) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
{
|
||||
users := vault.getUnsafe().Users
|
||||
|
||||
idx := xslices.IndexFunc(users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})
|
||||
|
||||
if idx >= 0 {
|
||||
user := vault.attachUserUnsafe(userID)
|
||||
|
||||
if err := user.setAuthAndKeyPassUnsafe(authUID, authRef, keyPass); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return user, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
u, err := vault.addUserUnsafe(userID, username, primaryEmail, authUID, authRef, keyPass)
|
||||
|
||||
return u, true, err
|
||||
}
|
||||
|
||||
// DeleteUser removes the given user from the vault.
|
||||
func (vault *Vault) DeleteUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Info("Deleting vault user")
|
||||
|
||||
@ -169,7 +256,7 @@ func (vault *Vault) DeleteUser(userID string) error {
|
||||
return fmt.Errorf("user %s is currently in use", userID)
|
||||
}
|
||||
|
||||
return vault.mod(func(data *Data) {
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})
|
||||
@ -177,23 +264,32 @@ func (vault *Vault) DeleteUser(userID string) error {
|
||||
if idx < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
data.Settings.PasswordArchive.set(data.Users[idx].PrimaryEmail, data.Users[idx].BridgePass)
|
||||
data.Users = append(data.Users[:idx], data.Users[idx+1:]...)
|
||||
})
|
||||
}
|
||||
|
||||
func (vault *Vault) Migrated() bool {
|
||||
return vault.get().Migrated
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return vault.getUnsafe().Migrated
|
||||
}
|
||||
|
||||
func (vault *Vault) SetMigrated() error {
|
||||
return vault.mod(func(data *Data) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
data.Migrated = true
|
||||
})
|
||||
}
|
||||
|
||||
func (vault *Vault) Reset(gluonDir string) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
*data = newDefaultData(gluonDir)
|
||||
})
|
||||
}
|
||||
@ -203,8 +299,8 @@ func (vault *Vault) Path() string {
|
||||
}
|
||||
|
||||
func (vault *Vault) Close() error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
if len(vault.ref) > 0 {
|
||||
return errors.New("vault is still in use")
|
||||
@ -215,10 +311,7 @@ func (vault *Vault) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vault *Vault) attachUser(userID string) *User {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
|
||||
func (vault *Vault) attachUserUnsafe(userID string) *User {
|
||||
logrus.WithField("userID", userID).Trace("Attaching vault user")
|
||||
|
||||
vault.ref[userID]++
|
||||
@ -230,8 +323,8 @@ func (vault *Vault) attachUser(userID string) *User {
|
||||
}
|
||||
|
||||
func (vault *Vault) detachUser(userID string) error {
|
||||
vault.refLock.Lock()
|
||||
defer vault.refLock.Unlock()
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).Trace("Detaching vault user")
|
||||
|
||||
@ -283,10 +376,14 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
||||
}, corrupt, nil
|
||||
}
|
||||
|
||||
func (vault *Vault) get() Data {
|
||||
vault.encLock.RLock()
|
||||
defer vault.encLock.RUnlock()
|
||||
func (vault *Vault) getSafe() Data {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
return vault.getUnsafe()
|
||||
}
|
||||
|
||||
func (vault *Vault) getUnsafe() Data {
|
||||
var data Data
|
||||
|
||||
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
||||
@ -296,10 +393,14 @@ func (vault *Vault) get() Data {
|
||||
return data
|
||||
}
|
||||
|
||||
func (vault *Vault) mod(fn func(data *Data)) error {
|
||||
vault.encLock.Lock()
|
||||
defer vault.encLock.Unlock()
|
||||
func (vault *Vault) modSafe(fn func(data *Data)) error {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUnsafe(fn)
|
||||
}
|
||||
|
||||
func (vault *Vault) modUnsafe(fn func(data *Data)) error {
|
||||
var data Data
|
||||
|
||||
if err := unmarshalFile(vault.gcm, vault.enc, &data); err != nil {
|
||||
@ -319,13 +420,31 @@ func (vault *Vault) mod(fn func(data *Data)) error {
|
||||
}
|
||||
|
||||
func (vault *Vault) getUser(userID string) UserData {
|
||||
return vault.get().Users[xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||
vault.lock.RLock()
|
||||
defer vault.lock.RUnlock()
|
||||
|
||||
users := vault.getUnsafe().Users
|
||||
|
||||
idx := xslices.IndexFunc(users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})]
|
||||
})
|
||||
|
||||
if idx < 0 {
|
||||
panic("Unknown user")
|
||||
}
|
||||
|
||||
return users[idx]
|
||||
}
|
||||
|
||||
func (vault *Vault) modUser(userID string, fn func(userData *UserData)) error {
|
||||
return vault.mod(func(data *Data) {
|
||||
vault.lock.Lock()
|
||||
defer vault.lock.Unlock()
|
||||
|
||||
return vault.modUserUnsafe(userID, fn)
|
||||
}
|
||||
|
||||
func (vault *Vault) modUserUnsafe(userID string, fn func(userData *UserData)) error {
|
||||
return vault.modUnsafe(func(data *Data) {
|
||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||
return user.UserID == userID
|
||||
})
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func (vault *Vault) ImportJSON(dec []byte) {
|
||||
vault.mod(func(data *Data) {
|
||||
vault.modSafe(func(data *Data) {
|
||||
if err := json.Unmarshal(dec, data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -32,7 +32,7 @@ func (vault *Vault) ImportJSON(dec []byte) {
|
||||
}
|
||||
|
||||
func (vault *Vault) ExportJSON() []byte {
|
||||
enc, err := json.MarshalIndent(vault.get(), "", " ")
|
||||
enc, err := json.MarshalIndent(vault.getSafe(), "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ func (v *Versioner) RemoveOldVersions() error {
|
||||
}
|
||||
|
||||
// RemoveOtherVersions removes all but the specific provided app version.
|
||||
func (v *Versioner) RemoveOtherVersions(versionToKeep *semver.Version) error {
|
||||
func (v *Versioner) RemoveOtherVersions(_ *semver.Version) error {
|
||||
// darwin does not use the versioner; removal is a noop.
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user