mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 12:46:46 +00:00
Other: Safer user types
This commit is contained in:
2
go.mod
2
go.mod
@ -38,7 +38,7 @@ require (
|
|||||||
github.com/sirupsen/logrus v1.9.0
|
github.com/sirupsen/logrus v1.9.0
|
||||||
github.com/stretchr/testify v1.8.0
|
github.com/stretchr/testify v1.8.0
|
||||||
github.com/urfave/cli/v2 v2.16.3
|
github.com/urfave/cli/v2 v2.16.3
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012095146-bd94443eeb8e
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455
|
||||||
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
||||||
golang.org/x/net v0.1.0
|
golang.org/x/net v0.1.0
|
||||||
golang.org/x/sys v0.1.0
|
golang.org/x/sys v0.1.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@ -397,8 +397,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
|
|||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
||||||
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012095146-bd94443eeb8e h1:UBgcmAYZ45ylLlfmc8/0evP40LwVthBHRoMgGqt4YV8=
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 h1:TWNT/rPSUGjYsNTwWx5Fd029LipSv+h1XuBwFSd5cAo=
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012095146-bd94443eeb8e/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
|
||||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
|
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
|
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
|
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
|
||||||
bridgeCLI "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
|
bridgeCLI "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
|
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
|
||||||
@ -23,6 +24,7 @@ import (
|
|||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Visible flags
|
||||||
const (
|
const (
|
||||||
flagCPUProfile = "cpu-prof"
|
flagCPUProfile = "cpu-prof"
|
||||||
flagCPUProfileShort = "p"
|
flagCPUProfileShort = "p"
|
||||||
@ -40,8 +42,10 @@ const (
|
|||||||
|
|
||||||
flagLogIMAP = "log-imap"
|
flagLogIMAP = "log-imap"
|
||||||
flagLogSMTP = "log-smtp"
|
flagLogSMTP = "log-smtp"
|
||||||
|
)
|
||||||
|
|
||||||
// Hidden flags
|
// Hidden flags
|
||||||
|
const (
|
||||||
flagLauncher = "launcher"
|
flagLauncher = "launcher"
|
||||||
flagNoWindow = "no-window"
|
flagNoWindow = "no-window"
|
||||||
)
|
)
|
||||||
@ -137,7 +141,7 @@ func run(c *cli.Context) error {
|
|||||||
// Load the cookies from the vault.
|
// Load the cookies from the vault.
|
||||||
return withCookieJar(vault, func(cookieJar http.CookieJar) error {
|
return withCookieJar(vault, func(cookieJar http.CookieJar) error {
|
||||||
// Create a new bridge instance.
|
// Create a new bridge instance.
|
||||||
return withBridge(c, locations, identifier, reporter, vault, cookieJar, func(b *bridge.Bridge) error {
|
return withBridge(c, locations, identifier, reporter, vault, cookieJar, func(b *bridge.Bridge, eventCh <-chan events.Event) error {
|
||||||
if insecure {
|
if insecure {
|
||||||
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
|
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
|
||||||
b.PushError(bridge.ErrVaultInsecure)
|
b.PushError(bridge.ErrVaultInsecure)
|
||||||
@ -150,13 +154,13 @@ func run(c *cli.Context) error {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case c.Bool(flagCLI):
|
case c.Bool(flagCLI):
|
||||||
return bridgeCLI.New(b).Loop()
|
return bridgeCLI.New(b, eventCh).Loop()
|
||||||
|
|
||||||
case c.Bool(flagNonInteractive):
|
case c.Bool(flagNonInteractive):
|
||||||
select {}
|
select {}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
service, err := grpc.NewService(crashHandler, restarter, locations, b, !c.Bool(flagNoWindow))
|
service, err := grpc.NewService(crashHandler, restarter, locations, b, eventCh, !c.Bool(flagNoWindow))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create service: %w", err)
|
return fmt.Errorf("could not create service: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/dialer"
|
"github.com/ProtonMail/proton-bridge/v2/internal/dialer"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
|
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||||
@ -32,7 +33,7 @@ func withBridge(
|
|||||||
reporter *sentry.Reporter,
|
reporter *sentry.Reporter,
|
||||||
vault *vault.Vault,
|
vault *vault.Vault,
|
||||||
cookieJar http.CookieJar,
|
cookieJar http.CookieJar,
|
||||||
fn func(*bridge.Bridge) error,
|
fn func(*bridge.Bridge, <-chan events.Event) error,
|
||||||
) error {
|
) error {
|
||||||
// Get the current bridge version.
|
// Get the current bridge version.
|
||||||
version, err := semver.NewVersion(constants.Version)
|
version, err := semver.NewVersion(constants.Version)
|
||||||
@ -64,7 +65,7 @@ func withBridge(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new bridge.
|
// Create a new bridge.
|
||||||
bridge, err := bridge.New(
|
bridge, eventCh, err := bridge.New(
|
||||||
// The app stuff.
|
// The app stuff.
|
||||||
locations,
|
locations,
|
||||||
vault,
|
vault,
|
||||||
@ -96,7 +97,7 @@ func withBridge(
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return fn(bridge)
|
return fn(bridge, eventCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAutostarter() (*autostart.App, error) {
|
func newAutostarter() (*autostart.App, error) {
|
||||||
|
|||||||
@ -7,7 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Masterminds/semver/v3"
|
"github.com/Masterminds/semver/v3"
|
||||||
@ -16,9 +15,10 @@ import (
|
|||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
|
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -30,17 +30,15 @@ type Bridge struct {
|
|||||||
vault *vault.Vault
|
vault *vault.Vault
|
||||||
|
|
||||||
// users holds authorized users.
|
// users holds authorized users.
|
||||||
users map[string]*user.User
|
users *safe.Map[string, *user.User]
|
||||||
|
loadCh chan struct{}
|
||||||
|
loadWG try.Group
|
||||||
|
|
||||||
// api manages user API clients.
|
// api manages user API clients.
|
||||||
api *liteapi.Manager
|
api *liteapi.Manager
|
||||||
proxyCtl ProxyController
|
proxyCtl ProxyController
|
||||||
identifier Identifier
|
identifier Identifier
|
||||||
|
|
||||||
// watchers holds all registered event watchers.
|
|
||||||
watchers []*watcher.Watcher[events.Event]
|
|
||||||
watchersLock sync.RWMutex
|
|
||||||
|
|
||||||
// tlsConfig holds the bridge TLS config used by the IMAP and SMTP servers.
|
// tlsConfig holds the bridge TLS config used by the IMAP and SMTP servers.
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
@ -66,6 +64,9 @@ type Bridge struct {
|
|||||||
// locator is the bridge's locator.
|
// locator is the bridge's locator.
|
||||||
locator Locator
|
locator Locator
|
||||||
|
|
||||||
|
// watchers holds all registered event watchers.
|
||||||
|
watchers *safe.Slice[*watcher.Watcher[events.Event]]
|
||||||
|
|
||||||
// errors contains errors encountered during startup.
|
// errors contains errors encountered during startup.
|
||||||
errors []error
|
errors []error
|
||||||
|
|
||||||
@ -95,7 +96,7 @@ func New(
|
|||||||
|
|
||||||
logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity
|
logIMAPClient, logIMAPServer bool, // whether to log IMAP client/server activity
|
||||||
logSMTP bool, // whether to log SMTP activity
|
logSMTP bool, // whether to log SMTP activity
|
||||||
) (*Bridge, error) {
|
) (*Bridge, <-chan events.Event, error) {
|
||||||
api := liteapi.New(
|
api := liteapi.New(
|
||||||
liteapi.WithHostURL(apiURL),
|
liteapi.WithHostURL(apiURL),
|
||||||
liteapi.WithAppVersion(constants.AppVersion(curVersion.Original())),
|
liteapi.WithAppVersion(constants.AppVersion(curVersion.Original())),
|
||||||
@ -105,54 +106,62 @@ func New(
|
|||||||
|
|
||||||
tlsConfig, err := loadTLSConfig(vault)
|
tlsConfig, err := loadTLSConfig(vault)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load TLS config: %w", err)
|
return nil, nil, fmt.Errorf("failed to load TLS config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
gluonDir, err := getGluonDir(vault)
|
gluonDir, err := getGluonDir(vault)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get Gluon directory: %w", err)
|
return nil, nil, fmt.Errorf("failed to get Gluon directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
smtpBackend, err := newSMTPBackend()
|
smtpBackend, err := newSMTPBackend()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create SMTP backend: %w", err)
|
return nil, nil, fmt.Errorf("failed to create SMTP backend: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
imapServer, err := newIMAPServer(gluonDir, curVersion, tlsConfig, logIMAPClient, logIMAPServer)
|
imapServer, err := newIMAPServer(gluonDir, curVersion, tlsConfig, logIMAPClient, logIMAPServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create IMAP server: %w", err)
|
return nil, nil, fmt.Errorf("failed to create IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
focusService, err := focus.NewService()
|
focusService, err := focus.NewService()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create focus service: %w", err)
|
return nil, nil, fmt.Errorf("failed to create focus service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge := newBridge(
|
bridge := newBridge(
|
||||||
|
// App stuff
|
||||||
locator,
|
locator,
|
||||||
vault,
|
vault,
|
||||||
autostarter,
|
autostarter,
|
||||||
updater,
|
updater,
|
||||||
curVersion,
|
curVersion,
|
||||||
|
|
||||||
|
// API stuff
|
||||||
api,
|
api,
|
||||||
identifier,
|
identifier,
|
||||||
proxyCtl,
|
proxyCtl,
|
||||||
|
|
||||||
|
// Service stuff
|
||||||
tlsConfig,
|
tlsConfig,
|
||||||
imapServer,
|
imapServer,
|
||||||
smtpBackend,
|
smtpBackend,
|
||||||
focusService,
|
focusService,
|
||||||
|
|
||||||
|
// Logging stuff
|
||||||
logIMAPClient,
|
logIMAPClient,
|
||||||
logIMAPServer,
|
logIMAPServer,
|
||||||
logSMTP,
|
logSMTP,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Get an event channel for all events (individual events can be subscribed to later).
|
||||||
|
eventCh, _ := bridge.GetEvents()
|
||||||
|
|
||||||
if err := bridge.init(tlsReporter); err != nil {
|
if err := bridge.init(tlsReporter); err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize bridge: %w", err)
|
return nil, nil, fmt.Errorf("failed to initialize bridge: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return bridge, nil
|
return bridge, eventCh, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBridge(
|
func newBridge(
|
||||||
@ -174,7 +183,9 @@ func newBridge(
|
|||||||
) *Bridge {
|
) *Bridge {
|
||||||
return &Bridge{
|
return &Bridge{
|
||||||
vault: vault,
|
vault: vault,
|
||||||
users: make(map[string]*user.User),
|
|
||||||
|
users: safe.NewMap[string, *user.User](nil),
|
||||||
|
loadCh: make(chan struct{}, 1),
|
||||||
|
|
||||||
api: api,
|
api: api,
|
||||||
proxyCtl: proxyCtl,
|
proxyCtl: proxyCtl,
|
||||||
@ -193,6 +204,8 @@ func newBridge(
|
|||||||
autostarter: autostarter,
|
autostarter: autostarter,
|
||||||
locator: locator,
|
locator: locator,
|
||||||
|
|
||||||
|
watchers: safe.NewSlice[*watcher.Watcher[events.Event]](),
|
||||||
|
|
||||||
logIMAPClient: logIMAPClient,
|
logIMAPClient: logIMAPClient,
|
||||||
logIMAPServer: logIMAPServer,
|
logIMAPServer: logIMAPServer,
|
||||||
logSMTP: logSMTP,
|
logSMTP: logSMTP,
|
||||||
@ -227,10 +240,6 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := bridge.loadUsers(); err != nil {
|
|
||||||
return fmt.Errorf("failed to load users: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for range tlsReporter.GetTLSIssueCh() {
|
for range tlsReporter.GetTLSIssueCh() {
|
||||||
bridge.publish(events.TLSIssue{})
|
bridge.publish(events.TLSIssue{})
|
||||||
@ -261,6 +270,8 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
|
|||||||
bridge.PushError(ErrWatchUpdates)
|
bridge.PushError(ErrWatchUpdates)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go bridge.loadLoop()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,6 +299,9 @@ func (bridge *Bridge) Close(ctx context.Context) error {
|
|||||||
// Stop ongoing operations such as connectivity checks.
|
// Stop ongoing operations such as connectivity checks.
|
||||||
close(bridge.stopCh)
|
close(bridge.stopCh)
|
||||||
|
|
||||||
|
// Wait for ongoing user load operations to finish.
|
||||||
|
bridge.loadWG.Wait()
|
||||||
|
|
||||||
// Close the IMAP server.
|
// Close the IMAP server.
|
||||||
if err := bridge.closeIMAP(ctx); err != nil {
|
if err := bridge.closeIMAP(ctx); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to close IMAP server")
|
logrus.WithError(err).Error("Failed to close IMAP server")
|
||||||
@ -299,10 +313,10 @@ func (bridge *Bridge) Close(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close all users.
|
// Close all users.
|
||||||
for _, user := range bridge.users {
|
if err := bridge.users.IterValuesErr(func(user *user.User) error {
|
||||||
if err := user.Close(); err != nil {
|
return user.Close()
|
||||||
logrus.WithError(err).Error("Failed to close user")
|
}); err != nil {
|
||||||
}
|
logrus.WithError(err).Error("Failed to close users")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the focus service.
|
// Close the focus service.
|
||||||
@ -317,49 +331,44 @@ func (bridge *Bridge) Close(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) publish(event events.Event) {
|
func (bridge *Bridge) publish(event events.Event) {
|
||||||
bridge.watchersLock.RLock()
|
bridge.watchers.Iter(func(watcher *watcher.Watcher[events.Event]) {
|
||||||
defer bridge.watchersLock.RUnlock()
|
|
||||||
|
|
||||||
for _, watcher := range bridge.watchers {
|
|
||||||
if watcher.IsWatching(event) {
|
if watcher.IsWatching(event) {
|
||||||
if ok := watcher.Send(event); !ok {
|
if ok := watcher.Send(event); !ok {
|
||||||
logrus.WithField("event", event).Warn("Failed to send event to watcher")
|
logrus.WithField("event", event).Warn("Failed to send event to watcher")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] {
|
func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] {
|
||||||
bridge.watchersLock.Lock()
|
|
||||||
defer bridge.watchersLock.Unlock()
|
|
||||||
|
|
||||||
newWatcher := watcher.New(ofType...)
|
newWatcher := watcher.New(ofType...)
|
||||||
|
|
||||||
bridge.watchers = append(bridge.watchers, newWatcher)
|
bridge.watchers.Append(newWatcher)
|
||||||
|
|
||||||
return newWatcher
|
return newWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
|
func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
|
||||||
bridge.watchersLock.Lock()
|
bridge.watchers.Delete(oldWatcher)
|
||||||
defer bridge.watchersLock.Unlock()
|
|
||||||
|
|
||||||
bridge.watchers = xslices.Filter(bridge.watchers, func(other *watcher.Watcher[events.Event]) bool {
|
|
||||||
return other != oldWatcher
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) onStatusUp() {
|
func (bridge *Bridge) onStatusUp() {
|
||||||
bridge.publish(events.ConnStatusUp{})
|
bridge.publish(events.ConnStatusUp{})
|
||||||
|
|
||||||
if err := bridge.loadUsers(); err != nil {
|
bridge.loadCh <- struct{}{}
|
||||||
logrus.WithError(err).Error("Failed to load users")
|
|
||||||
}
|
bridge.users.IterValues(func(user *user.User) {
|
||||||
|
user.OnStatusUp()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) onStatusDown() {
|
func (bridge *Bridge) onStatusDown() {
|
||||||
bridge.publish(events.ConnStatusDown{})
|
bridge.publish(events.ConnStatusDown{})
|
||||||
|
|
||||||
|
bridge.users.IterValues(func(user *user.User) {
|
||||||
|
user.OnStatusDown()
|
||||||
|
})
|
||||||
|
|
||||||
upCh, done := bridge.GetEvents(events.ConnStatusUp{})
|
upCh, done := bridge.GetEvents(events.ConnStatusUp{})
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
|
|||||||
@ -136,7 +136,7 @@ func TestBridge_UserAgent(t *testing.T) {
|
|||||||
|
|
||||||
func TestBridge_Cookies(t *testing.T) {
|
func TestBridge_Cookies(t *testing.T) {
|
||||||
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
withTLSEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, vaultKey []byte) {
|
||||||
sessionIDs := safe.NewSet[string]()
|
sessionIDs := safe.NewValue([]string{})
|
||||||
|
|
||||||
// Save any session IDs we use.
|
// Save any session IDs we use.
|
||||||
s.AddCallWatcher(func(call server.Call) {
|
s.AddCallWatcher(func(call server.Call) {
|
||||||
@ -145,7 +145,9 @@ func TestBridge_Cookies(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionIDs.Insert(cookie.Value)
|
sessionIDs.Mod(func(sessionIDs *[]string) {
|
||||||
|
*sessionIDs = append(*sessionIDs, cookie.Value)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start bridge and add a user so that API assigns us a session ID via cookie.
|
// Start bridge and add a user so that API assigns us a session ID via cookie.
|
||||||
@ -160,8 +162,8 @@ func TestBridge_Cookies(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// We should have used just one session ID.
|
// We should have used just one session ID.
|
||||||
sessionIDs.Values(func(sessionIDs []string) {
|
sessionIDs.Load(func(sessionIDs []string) {
|
||||||
require.Len(t, sessionIDs, 1)
|
require.Len(t, xslices.Unique(sessionIDs), 1)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -405,7 +407,7 @@ func withBridge(
|
|||||||
defer func() { require.NoError(t, cookieJar.PersistCookies()) }()
|
defer func() { require.NoError(t, cookieJar.PersistCookies()) }()
|
||||||
|
|
||||||
// Create a new bridge.
|
// Create a new bridge.
|
||||||
bridge, err := bridge.New(
|
bridge, eventCh, err := bridge.New(
|
||||||
// The app stuff.
|
// The app stuff.
|
||||||
locator,
|
locator,
|
||||||
vault,
|
vault,
|
||||||
@ -428,6 +430,9 @@ func withBridge(
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for bridge to finish loading users.
|
||||||
|
waitForEvent(t, eventCh, events.AllUsersLoaded{})
|
||||||
|
|
||||||
// Close the bridge when done.
|
// Close the bridge when done.
|
||||||
defer func() { require.NoError(t, bridge.Close(ctx)) }()
|
defer func() { require.NoError(t, bridge.Close(ctx)) }()
|
||||||
|
|
||||||
@ -435,6 +440,17 @@ func withBridge(
|
|||||||
tests(bridge, mocks)
|
tests(bridge, mocks)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func waitForEvent[T any](t *testing.T, eventCh <-chan events.Event, wantEvent T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for event := range eventCh {
|
||||||
|
switch event.(type) {
|
||||||
|
case T:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// must is a helper function that panics on error.
|
// must is a helper function that panics on error.
|
||||||
func must[T any](val T, err error) T {
|
func must[T any](val T, err error) T {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -41,7 +41,12 @@ func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, descript
|
|||||||
if info, err := bridge.QueryUserInfo(username); err == nil {
|
if info, err := bridge.QueryUserInfo(username); err == nil {
|
||||||
account = info.Username
|
account = info.Username
|
||||||
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
|
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
|
||||||
account = bridge.users[userIDs[0]].Name()
|
user, err := bridge.vault.GetUser(userIDs[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account = user.Username()
|
||||||
}
|
}
|
||||||
|
|
||||||
var atts []liteapi.ReportBugAttachment
|
var atts []liteapi.ReportBugAttachment
|
||||||
|
|||||||
@ -1,47 +1,52 @@
|
|||||||
package bridge
|
package bridge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
|
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||||
user, ok := bridge.users[userID]
|
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
|
||||||
if !ok {
|
if address == "" {
|
||||||
return ErrNoSuchUser
|
address = user.Emails()[0]
|
||||||
}
|
|
||||||
|
|
||||||
if address == "" {
|
|
||||||
address = user.Emails()[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
username := address
|
|
||||||
addresses := address
|
|
||||||
|
|
||||||
if user.GetAddressMode() == vault.CombinedMode {
|
|
||||||
username = user.Emails()[0]
|
|
||||||
addresses = strings.Join(user.Emails(), ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If configuring apple mail for Catalina or newer, users should use SSL.
|
|
||||||
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
|
|
||||||
if err := bridge.SetSMTPSSL(true); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
username := address
|
||||||
|
addresses := address
|
||||||
|
|
||||||
|
if user.GetAddressMode() == vault.CombinedMode {
|
||||||
|
username = user.Emails()[0]
|
||||||
|
addresses = strings.Join(user.Emails(), ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If configuring apple mail for Catalina or newer, users should use SSL.
|
||||||
|
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
|
||||||
|
if err := bridge.SetSMTPSSL(true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (&clientconfig.AppleMail{}).Configure(
|
||||||
|
constants.Host,
|
||||||
|
bridge.vault.GetIMAPPort(),
|
||||||
|
bridge.vault.GetSMTPPort(),
|
||||||
|
bridge.vault.GetIMAPSSL(),
|
||||||
|
bridge.vault.GetSMTPSSL(),
|
||||||
|
username,
|
||||||
|
addresses,
|
||||||
|
user.BridgePass(),
|
||||||
|
)
|
||||||
|
}); !ok {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to configure apple mail: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return (&clientconfig.AppleMail{}).Configure(
|
return nil
|
||||||
constants.Host,
|
|
||||||
bridge.vault.GetIMAPPort(),
|
|
||||||
bridge.vault.GetSMTPPort(),
|
|
||||||
bridge.vault.GetIMAPSSL(),
|
|
||||||
bridge.vault.GetSMTPSSL(),
|
|
||||||
username,
|
|
||||||
addresses,
|
|
||||||
user.BridgePass(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Masterminds/semver/v3"
|
"github.com/Masterminds/semver/v3"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -119,10 +120,10 @@ func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error
|
|||||||
|
|
||||||
bridge.imapServer = imapServer
|
bridge.imapServer = imapServer
|
||||||
|
|
||||||
for _, user := range bridge.users {
|
if err := bridge.users.IterValuesErr(func(user *user.User) error {
|
||||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
return bridge.addIMAPUser(ctx, user)
|
||||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
}); err != nil {
|
||||||
}
|
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bridge.serveIMAP(); err != nil {
|
if err := bridge.serveIMAP(); err != nil {
|
||||||
|
|||||||
@ -1,13 +1,10 @@
|
|||||||
package bridge
|
package bridge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/subtle"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type smtpBackend struct {
|
type smtpBackend struct {
|
||||||
@ -26,13 +23,12 @@ func (backend *smtpBackend) Login(state *smtp.ConnectionState, email, password s
|
|||||||
defer backend.usersLock.RUnlock()
|
defer backend.usersLock.RUnlock()
|
||||||
|
|
||||||
for _, user := range backend.users {
|
for _, user := range backend.users {
|
||||||
if subtle.ConstantTimeCompare(user.BridgePass(), []byte(password)) != 1 {
|
session, err := user.NewSMTPSession(email, []byte(password))
|
||||||
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if email := strings.ToLower(email); slices.Contains(user.Emails(), email) {
|
return session, nil
|
||||||
return user.NewSMTPSession(email)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, ErrNoSuchUser
|
return nil, ErrNoSuchUser
|
||||||
|
|||||||
@ -19,7 +19,7 @@ func TestBridge_Sync(t *testing.T) {
|
|||||||
s := server.New()
|
s := server.New()
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
numMsg := 1 << 10
|
numMsg := 1 << 8
|
||||||
|
|
||||||
withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
withEnv(t, s, func(ctx context.Context, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
|
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
|
||||||
@ -80,51 +80,51 @@ func TestBridge_Sync(t *testing.T) {
|
|||||||
|
|
||||||
// Login the user; its sync should fail.
|
// Login the user; its sync should fail.
|
||||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
syncCh, done := chToType[events.Event, events.SyncFailed](bridge.GetEvents(events.SyncFailed{}))
|
{
|
||||||
defer done()
|
syncCh, done := chToType[events.Event, events.SyncFailed](bridge.GetEvents(events.SyncFailed{}))
|
||||||
|
defer done()
|
||||||
|
|
||||||
userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
|
userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, userID, (<-syncCh).UserID)
|
require.Equal(t, userID, (<-syncCh).UserID)
|
||||||
|
|
||||||
info, err := bridge.GetUserInfo(userID)
|
info, err := bridge.GetUserInfo(userID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, info.Connected)
|
require.True(t, info.Connected)
|
||||||
|
|
||||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||||
defer func() { _ = client.Logout() }()
|
defer func() { _ = client.Logout() }()
|
||||||
|
|
||||||
status, err := client.Select(`Folders/folder`, false)
|
status, err := client.Select(`Folders/folder`, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Less(t, status.Messages, uint32(numMsg))
|
require.Less(t, status.Messages, uint32(numMsg))
|
||||||
})
|
}
|
||||||
|
|
||||||
// Remove the network limit, allowing the sync to finish.
|
// Remove the network limit, allowing the sync to finish.
|
||||||
netCtl.SetReadLimit(0)
|
netCtl.SetReadLimit(0)
|
||||||
|
|
||||||
// Login the user; its sync should now finish.
|
{
|
||||||
// If we then connect an IMAP client, it should eventually see all the messages.
|
syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
|
||||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
defer done()
|
||||||
syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
|
|
||||||
defer done()
|
|
||||||
|
|
||||||
require.Equal(t, userID, (<-syncCh).UserID)
|
require.Equal(t, userID, (<-syncCh).UserID)
|
||||||
|
|
||||||
info, err := bridge.GetUserInfo(userID)
|
info, err := bridge.GetUserInfo(userID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, info.Connected)
|
require.True(t, info.Connected)
|
||||||
|
|
||||||
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
client, err := client.Dial(fmt.Sprintf(":%v", bridge.GetIMAPPort()))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
||||||
defer func() { _ = client.Logout() }()
|
defer func() { _ = client.Logout() }()
|
||||||
|
|
||||||
status, err := client.Select(`Folders/folder`, false)
|
status, err := client.Select(`Folders/folder`, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, uint32(numMsg), status.Messages)
|
require.Equal(t, uint32(numMsg), status.Messages)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
@ -53,23 +54,26 @@ func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
|||||||
return UserInfo{}, err
|
return UserInfo{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := bridge.users[userID]
|
if info, ok := safe.MapGetRet(bridge.users, userID, func(user *user.User) UserInfo {
|
||||||
if !ok {
|
return getConnUserInfo(user)
|
||||||
return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil
|
}); ok {
|
||||||
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return getConnUserInfo(user), nil
|
return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryUserInfo queries the user info by username or address.
|
// QueryUserInfo queries the user info by username or address.
|
||||||
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
|
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
|
||||||
for userID, user := range bridge.users {
|
return safe.MapValuesRetErr(bridge.users, func(users []*user.User) (UserInfo, error) {
|
||||||
if user.Match(query) {
|
for _, user := range users {
|
||||||
return bridge.GetUserInfo(userID)
|
if user.Match(query) {
|
||||||
|
return getConnUserInfo(user), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return UserInfo{}, ErrNoSuchUser
|
return UserInfo{}, ErrNoSuchUser
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginAuth begins the login process. It returns an authorized client that might need 2FA.
|
// LoginAuth begins the login process. It returns an authorized client that might need 2FA.
|
||||||
@ -79,7 +83,7 @@ func (bridge *Bridge) LoginAuth(ctx context.Context, username string, password [
|
|||||||
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
|
return nil, liteapi.Auth{}, fmt.Errorf("failed to create new API client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := bridge.users[auth.UserID]; ok {
|
if bridge.users.Has(auth.UserID) {
|
||||||
if err := client.AuthDelete(ctx); err != nil {
|
if err := client.AuthDelete(ctx); err != nil {
|
||||||
logrus.WithError(err).Warn("Failed to delete auth")
|
logrus.WithError(err).Warn("Failed to delete auth")
|
||||||
}
|
}
|
||||||
@ -187,34 +191,37 @@ func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
|||||||
|
|
||||||
// SetAddressMode sets the address mode for the given user.
|
// SetAddressMode sets the address mode for the given user.
|
||||||
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
|
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
|
||||||
user, ok := bridge.users[userID]
|
if ok, err := bridge.users.GetErr(userID, func(user *user.User) error {
|
||||||
if !ok {
|
if user.GetAddressMode() == mode {
|
||||||
return ErrNoSuchUser
|
return fmt.Errorf("address mode is already %q", mode)
|
||||||
}
|
|
||||||
|
|
||||||
if user.GetAddressMode() == mode {
|
|
||||||
return fmt.Errorf("address mode is already %q", mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, gluonID := range user.GetGluonIDs() {
|
|
||||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
|
||||||
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.SetAddressMode(ctx, mode); err != nil {
|
for _, gluonID := range user.GetGluonIDs() {
|
||||||
|
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.SetAddressMode(ctx, mode); err != nil {
|
||||||
|
return fmt.Errorf("failed to set address mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||||
|
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bridge.publish(events.AddressModeChanged{
|
||||||
|
UserID: userID,
|
||||||
|
AddressMode: mode,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); !ok {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
} else if err != nil {
|
||||||
return fmt.Errorf("failed to set address mode: %w", err)
|
return fmt.Errorf("failed to set address mode: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
|
||||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bridge.publish(events.AddressModeChanged{
|
|
||||||
UserID: userID,
|
|
||||||
AddressMode: mode,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,10 +248,30 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
|
|||||||
return apiUser.ID, nil
|
return apiUser.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadUsers is a loop that, when polled, attempts to load authorized users from the vault.
|
// loadLoop is a loop that, when polled, attempts to load authorized users from the vault.
|
||||||
|
func (bridge *Bridge) loadLoop() {
|
||||||
|
for {
|
||||||
|
bridge.loadWG.GoTry(func(ok bool) {
|
||||||
|
if ok {
|
||||||
|
if err := bridge.loadUsers(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to load users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-bridge.stopCh:
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-bridge.loadCh:
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) loadUsers() error {
|
func (bridge *Bridge) loadUsers() error {
|
||||||
return bridge.vault.ForUser(func(user *vault.User) error {
|
if err := bridge.vault.ForUser(func(user *vault.User) error {
|
||||||
if _, ok := bridge.users[user.UserID()]; ok {
|
if bridge.users.Has(user.UserID()) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,7 +298,13 @@ func (bridge *Bridge) loadUsers() error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to iterate over users: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bridge.publish(events.AllUsersLoaded{})
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadUser loads an existing user from the vault.
|
// loadUser loads an existing user from the vault.
|
||||||
@ -387,7 +420,7 @@ func (bridge *Bridge) addNewUser(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.users[apiUser.ID] = user
|
bridge.users.Set(apiUser.ID, user)
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
@ -417,7 +450,7 @@ func (bridge *Bridge) addExistingUser(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.users[apiUser.ID] = user
|
bridge.users.Set(apiUser.ID, user)
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
@ -451,37 +484,38 @@ func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
|
|||||||
|
|
||||||
// logoutUser logs the given user out from bridge.
|
// logoutUser logs the given user out from bridge.
|
||||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
||||||
user, ok := bridge.users[userID]
|
if ok, err := bridge.users.GetDeleteErr(userID, func(user *user.User) error {
|
||||||
if !ok {
|
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||||
return ErrNoSuchUser
|
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, gluonID := range user.GetGluonIDs() {
|
|
||||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to remove IMAP user")
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.Logout(ctx); err != nil {
|
for _, gluonID := range user.GetGluonIDs() {
|
||||||
logrus.WithError(err).Error("Failed to logout user")
|
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
|
||||||
}
|
logrus.WithError(err).Error("Failed to remove IMAP user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := user.Close(); err != nil {
|
if err := user.Logout(ctx); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to close user")
|
logrus.WithError(err).Error("Failed to logout user")
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(bridge.users, userID)
|
if err := user.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to close user")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); !ok {
|
||||||
|
return ErrNoSuchUser
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteUser deletes the given user from bridge.
|
// deleteUser deletes the given user from bridge.
|
||||||
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
||||||
if user, ok := bridge.users[userID]; ok {
|
if ok := bridge.users.GetDelete(userID, func(user *user.User) {
|
||||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
logrus.WithError(err).Error("Failed to remove user from SMTP backend")
|
||||||
}
|
}
|
||||||
@ -499,13 +533,13 @@ func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
|||||||
if err := user.Close(); err != nil {
|
if err := user.Close(); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to close user")
|
logrus.WithError(err).Error("Failed to close user")
|
||||||
}
|
}
|
||||||
|
}); !ok {
|
||||||
|
logrus.Debug("The bridge user was not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to delete user from vault")
|
logrus.WithError(err).Error("Failed to delete user from vault")
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(bridge.users, userID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserInfo returns information about a disconnected user.
|
// getUserInfo returns information about a disconnected user.
|
||||||
|
|||||||
@ -43,23 +43,13 @@ func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.U
|
|||||||
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
imapConn, err := user.NewIMAPConnector(addrID)
|
if err := bridge.imapServer.LoadUser(ctx, user.NewIMAPConnector(addrID), gluonID, user.GluonKey()); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create IMAP connector: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
|
||||||
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case vault.SplitMode:
|
case vault.SplitMode:
|
||||||
imapConn, err := user.NewIMAPConnector(event.AddressID)
|
gluonID, err := bridge.imapServer.AddUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create IMAP connector: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
@ -93,12 +83,7 @@ func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.U
|
|||||||
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
imapConn, err := user.NewIMAPConnector(addrID)
|
if err := bridge.imapServer.LoadUser(ctx, user.NewIMAPConnector(addrID), gluonID, user.GluonKey()); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create IMAP connector: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
|
|
||||||
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
return fmt.Errorf("failed to add user to IMAP server: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,10 @@ package events
|
|||||||
|
|
||||||
import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
|
|
||||||
|
type AllUsersLoaded struct {
|
||||||
|
eventBase
|
||||||
|
}
|
||||||
|
|
||||||
type UserLoaded struct {
|
type UserLoaded struct {
|
||||||
eventBase
|
eventBase
|
||||||
|
|
||||||
|
|||||||
@ -38,7 +38,7 @@ type frontendCLI struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new CLI frontend configured with the given options.
|
// New returns a new CLI frontend configured with the given options.
|
||||||
func New(bridge *bridge.Bridge) *frontendCLI {
|
func New(bridge *bridge.Bridge, eventCh <-chan events.Event) *frontendCLI {
|
||||||
fe := &frontendCLI{
|
fe := &frontendCLI{
|
||||||
Shell: ishell.New(),
|
Shell: ishell.New(),
|
||||||
bridge: bridge,
|
bridge: bridge,
|
||||||
@ -253,15 +253,12 @@ func New(bridge *bridge.Bridge) *frontendCLI {
|
|||||||
Completer: fe.completeUsernames,
|
Completer: fe.completeUsernames,
|
||||||
})
|
})
|
||||||
|
|
||||||
go fe.watchEvents()
|
go fe.watchEvents(eventCh)
|
||||||
|
|
||||||
return fe
|
return fe
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *frontendCLI) watchEvents() {
|
func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) {
|
||||||
eventCh, done := f.bridge.GetEvents()
|
|
||||||
defer done()
|
|
||||||
|
|
||||||
// TODO: Better error events.
|
// TODO: Better error events.
|
||||||
for _, err := range f.bridge.GetErrors() {
|
for _, err := range f.bridge.GetErrors() {
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@ -64,6 +64,7 @@ type Service struct { // nolint:structcheck
|
|||||||
panicHandler *crash.Handler
|
panicHandler *crash.Handler
|
||||||
restarter *restarter.Restarter
|
restarter *restarter.Restarter
|
||||||
bridge *bridge.Bridge
|
bridge *bridge.Bridge
|
||||||
|
eventCh <-chan events.Event
|
||||||
newVersionInfo updater.VersionInfo
|
newVersionInfo updater.VersionInfo
|
||||||
|
|
||||||
authClient *liteapi.Client
|
authClient *liteapi.Client
|
||||||
@ -84,6 +85,7 @@ func NewService(
|
|||||||
restarter *restarter.Restarter,
|
restarter *restarter.Restarter,
|
||||||
locations *locations.Locations,
|
locations *locations.Locations,
|
||||||
bridge *bridge.Bridge,
|
bridge *bridge.Bridge,
|
||||||
|
eventCh <-chan events.Event,
|
||||||
showOnStartup bool,
|
showOnStartup bool,
|
||||||
) (*Service, error) {
|
) (*Service, error) {
|
||||||
tlsConfig, certPEM, err := newTLSConfig()
|
tlsConfig, certPEM, err := newTLSConfig()
|
||||||
@ -115,6 +117,7 @@ func NewService(
|
|||||||
panicHandler: panicHandler,
|
panicHandler: panicHandler,
|
||||||
restarter: restarter,
|
restarter: restarter,
|
||||||
bridge: bridge,
|
bridge: bridge,
|
||||||
|
eventCh: eventCh,
|
||||||
|
|
||||||
log: logrus.WithField("pkg", "grpc"),
|
log: logrus.WithField("pkg", "grpc"),
|
||||||
initializing: sync.WaitGroup{},
|
initializing: sync.WaitGroup{},
|
||||||
@ -200,9 +203,6 @@ func (s *Service) WaitUntilFrontendIsReady() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) watchEvents() {
|
func (s *Service) watchEvents() {
|
||||||
eventCh, done := s.bridge.GetEvents()
|
|
||||||
defer done()
|
|
||||||
|
|
||||||
// TODO: Better error events.
|
// TODO: Better error events.
|
||||||
for _, err := range s.bridge.GetErrors() {
|
for _, err := range s.bridge.GetErrors() {
|
||||||
switch {
|
switch {
|
||||||
@ -220,7 +220,7 @@ func (s *Service) watchEvents() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for event := range eventCh {
|
for event := range s.eventCh {
|
||||||
switch event := event.(type) {
|
switch event := event.(type) {
|
||||||
case events.ConnStatusUp:
|
case events.ConnStatusUp:
|
||||||
_ = s.SendEvent(NewInternetStatusEvent(true))
|
_ = s.SendEvent(NewInternetStatusEvent(true))
|
||||||
@ -243,6 +243,9 @@ func (s *Service) watchEvents() {
|
|||||||
case events.UserChanged:
|
case events.UserChanged:
|
||||||
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
|
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
|
||||||
|
|
||||||
|
case events.UserLoaded:
|
||||||
|
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
|
||||||
|
|
||||||
case events.UserLoggedIn:
|
case events.UserLoggedIn:
|
||||||
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
|
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
|
||||||
|
|
||||||
|
|||||||
@ -3,18 +3,26 @@ package safe
|
|||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Map[Key comparable, Val any] struct {
|
type Map[Key comparable, Val any] struct {
|
||||||
data map[Key]Val
|
data map[Key]Val
|
||||||
lock sync.RWMutex
|
order []Key
|
||||||
|
sort func(a, b Key, data map[Key]Val) bool
|
||||||
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] {
|
func NewMap[Key comparable, Val any](sort func(a, b Key, data map[Key]Val) bool) *Map[Key, Val] {
|
||||||
m := &Map[Key, Val]{
|
return &Map[Key, Val]{
|
||||||
data: make(map[Key]Val),
|
data: make(map[Key]Val),
|
||||||
|
sort: sort,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMapFrom[Key comparable, Val any](from map[Key]Val, sort func(a, b Key, data map[Key]Val) bool) *Map[Key, Val] {
|
||||||
|
m := NewMap(sort)
|
||||||
|
|
||||||
for key, val := range from {
|
for key, val := range from {
|
||||||
m.Set(key, val)
|
m.Set(key, val)
|
||||||
@ -23,12 +31,36 @@ func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Has(key Key) bool {
|
func (m *Map[Key, Val]) Index(idx int, fn func(Key, Val)) bool {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
_, ok := m.data[key]
|
if idx < 0 || idx >= len(m.order) {
|
||||||
return ok
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(m.order[idx], m.data[m.order[idx]])
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) Has(key Key) bool {
|
||||||
|
return m.HasFunc(func(k Key, v Val) bool {
|
||||||
|
return k == key
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) HasFunc(fn func(key Key, val Val) bool) bool {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
|
for key, val := range m.data {
|
||||||
|
if fn(key, val) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool {
|
func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool {
|
||||||
@ -46,15 +78,45 @@ func (m *Map[Key, Val]) Get(key Key, fn func(Val)) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) GetErr(key Key, fn func(Val) error) (bool, error) {
|
func (m *Map[Key, Val]) GetErr(key Key, fn func(Val) error) (bool, error) {
|
||||||
m.lock.RLock()
|
var err error
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
ok := m.Get(key, func(val Val) {
|
||||||
|
err = fn(val)
|
||||||
|
})
|
||||||
|
|
||||||
|
return ok, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) GetDelete(key Key, fn func(Val)) bool {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
val, ok := m.data[key]
|
val, ok := m.data[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, nil
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, fn(val)
|
fn(val)
|
||||||
|
|
||||||
|
delete(m.data, key)
|
||||||
|
|
||||||
|
if idx := xslices.Index(m.order, key); idx >= 0 {
|
||||||
|
m.order = append(m.order[:idx], m.order[idx+1:]...)
|
||||||
|
} else {
|
||||||
|
panic("order and data out of sync")
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) GetDeleteErr(key Key, fn func(Val) error) (bool, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ok := m.GetDelete(key, func(val Val) {
|
||||||
|
err = fn(val)
|
||||||
|
})
|
||||||
|
|
||||||
|
return ok, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Set(key Key, val Val) {
|
func (m *Map[Key, Val]) Set(key Key, val Val) {
|
||||||
@ -62,84 +124,140 @@ func (m *Map[Key, Val]) Set(key Key, val Val) {
|
|||||||
defer m.lock.Unlock()
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
m.data[key] = val
|
m.data[key] = val
|
||||||
|
|
||||||
|
m.order = append(m.order, key)
|
||||||
|
|
||||||
|
if m.sort != nil {
|
||||||
|
slices.SortFunc(m.order, func(a, b Key) bool {
|
||||||
|
return m.sort(a, b, m.data)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Delete(key Key) {
|
func (m *Map[Key, Val]) SetFrom(key Key, other Key) {
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
defer m.lock.Unlock()
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
delete(m.data, key)
|
m.data[key] = m.data[other]
|
||||||
|
|
||||||
|
m.order = append(m.order, key)
|
||||||
|
|
||||||
|
if m.sort != nil {
|
||||||
|
slices.SortFunc(m.order, func(a, b Key) bool {
|
||||||
|
return m.sort(a, b, m.data)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) {
|
func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
for key, val := range m.data {
|
for _, key := range m.order {
|
||||||
fn(key, val)
|
fn(key, m.data[key])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Keys(fn func(keys []Key)) {
|
func (m *Map[Key, Val]) IterKeys(fn func(Key)) {
|
||||||
m.lock.RLock()
|
m.Iter(func(key Key, _ Val) {
|
||||||
defer m.lock.RUnlock()
|
fn(key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn(maps.Keys(m.data))
|
func (m *Map[Key, Val]) IterKeysErr(fn func(Key) error) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
m.IterKeys(func(key Key) {
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fn(key)
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) IterValues(fn func(Val)) {
|
||||||
|
m.Iter(func(_ Key, val Val) {
|
||||||
|
fn(val)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) IterValuesErr(fn func(Val) error) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
m.IterValues(func(val Val) {
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fn(val)
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[Key, Val]) Values(fn func(vals []Val)) {
|
func (m *Map[Key, Val]) Values(fn func(vals []Val)) {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
fn(maps.Values(m.data))
|
vals := make([]Val, len(m.order))
|
||||||
|
|
||||||
|
for i, key := range m.order {
|
||||||
|
vals[i] = m.data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret, fallback func() Ret) Ret {
|
func (m *Map[Key, Val]) ValuesErr(fn func(vals []Val) error) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
m.Values(func(vals []Val) {
|
||||||
|
err = fn(vals)
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[Key, Val]) MapErr(fn func(map[Key]Val) error) error {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
|
||||||
val, ok := m.data[key]
|
return fn(m.data)
|
||||||
if !ok {
|
|
||||||
return fallback()
|
|
||||||
}
|
|
||||||
|
|
||||||
return fn(val)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
|
func MapGetRet[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) Ret) (Ret, bool) {
|
||||||
m.lock.RLock()
|
var ret Ret
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
|
||||||
val, ok := m.data[key]
|
ok := m.Get(key, func(val Val) {
|
||||||
if !ok {
|
ret = fn(val)
|
||||||
return fallback()
|
})
|
||||||
}
|
|
||||||
|
|
||||||
return fn(val)
|
return ret, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindMap[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) Ret, fallback func() Ret) Ret {
|
func MapValuesRet[Key comparable, Val, Ret any](m *Map[Key, Val], fn func([]Val) Ret) Ret {
|
||||||
m.lock.RLock()
|
var ret Ret
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
|
||||||
for _, val := range m.data {
|
m.Values(func(vals []Val) {
|
||||||
if cmp(val) {
|
ret = fn(vals)
|
||||||
return fn(val)
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fallback()
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], cmp func(Val) bool, fn func(Val) (Ret, error), fallback func() (Ret, error)) (Ret, error) {
|
func MapValuesRetErr[Key comparable, Val, Ret any](m *Map[Key, Val], fn func([]Val) (Ret, error)) (Ret, error) {
|
||||||
m.lock.RLock()
|
var ret Ret
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
|
||||||
for _, val := range m.data {
|
err := m.ValuesErr(func(vals []Val) error {
|
||||||
if cmp(val) {
|
var err error
|
||||||
return fn(val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fallback()
|
ret, err = fn(vals)
|
||||||
|
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return ret, err
|
||||||
}
|
}
|
||||||
|
|||||||
97
internal/safe/map_test.go
Normal file
97
internal/safe/map_test.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package safe
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSafe_Map(t *testing.T) {
|
||||||
|
m := NewMap(func(a, b string, data map[string]string) bool {
|
||||||
|
return a < b
|
||||||
|
})
|
||||||
|
|
||||||
|
m.Set("a", "b")
|
||||||
|
|
||||||
|
if !m.Has("a") {
|
||||||
|
t.Fatal("expected to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Has("b") {
|
||||||
|
t.Fatal("expected not to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Set("b", "c")
|
||||||
|
|
||||||
|
if !m.Has("b") {
|
||||||
|
t.Fatal("expected to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.HasFunc(func(key string, val string) bool {
|
||||||
|
return key == "b"
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.Get("b", func(val string) {
|
||||||
|
if val != "c" {
|
||||||
|
t.Fatal("expected to have value")
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.Index(0, func(key string, val string) {
|
||||||
|
if key != "a" || val != "b" {
|
||||||
|
t.Fatal("expected to have key and value")
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected to have index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.Index(1, func(key string, val string) {
|
||||||
|
if key != "b" || val != "c" {
|
||||||
|
t.Fatal("expected to have key and value")
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected to have index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Index(2, func(key string, val string) {
|
||||||
|
t.Fatal("expected not to have index")
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected not to have index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.GetDelete("b", func(val string) {
|
||||||
|
if val != "c" {
|
||||||
|
t.Fatal("expected to have value")
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Has("b") {
|
||||||
|
t.Fatal("expected not to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.GetDelete("b", func(val string) {
|
||||||
|
t.Fatal("expected not to have value")
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected not to have key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.Index(0, func(key string, val string) {
|
||||||
|
if key != "a" || val != "b" {
|
||||||
|
t.Fatal("expected to have key and value")
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected to have index")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Values(func(val []string) {
|
||||||
|
if len(val) != 1 {
|
||||||
|
t.Fatal("expected to have values")
|
||||||
|
}
|
||||||
|
|
||||||
|
if val[0] != "b" {
|
||||||
|
t.Fatal("expected to have value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -1,46 +0,0 @@
|
|||||||
package safe
|
|
||||||
|
|
||||||
import "golang.org/x/exp/maps"
|
|
||||||
|
|
||||||
type Set[Val comparable] Map[Val, struct{}]
|
|
||||||
|
|
||||||
func NewSet[Val comparable](vals ...Val) *Set[Val] {
|
|
||||||
set := (*Set[Val])(NewMap[Val, struct{}](nil))
|
|
||||||
|
|
||||||
for _, val := range vals {
|
|
||||||
set.Insert(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
return set
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Set[Val]) Has(key Val) bool {
|
|
||||||
m.lock.RLock()
|
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
|
||||||
_, ok := m.data[key]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Set[Val]) Insert(key Val) {
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
|
|
||||||
m.data[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Set[Val]) Iter(fn func(key Val)) {
|
|
||||||
m.lock.RLock()
|
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
|
||||||
for key := range m.data {
|
|
||||||
fn(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Set[Val]) Values(fn func(vals []Val)) {
|
|
||||||
m.lock.RLock()
|
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
|
||||||
fn(maps.Keys(m.data))
|
|
||||||
}
|
|
||||||
@ -1,13 +1,17 @@
|
|||||||
package safe
|
package safe
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
type Slice[Val any] struct {
|
"github.com/bradenaw/juniper/xslices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Slice[Val comparable] struct {
|
||||||
data []Val
|
data []Val
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSlice[Val any](from []Val) *Slice[Val] {
|
func NewSlice[Val comparable](from ...Val) *Slice[Val] {
|
||||||
s := &Slice[Val]{
|
s := &Slice[Val]{
|
||||||
data: make([]Val, len(from)),
|
data: make([]Val, len(from)),
|
||||||
}
|
}
|
||||||
@ -17,37 +21,27 @@ func NewSlice[Val any](from []Val) *Slice[Val] {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Slice[Val]) Get(fn func(data []Val)) {
|
func (s *Slice[Val]) Iter(fn func(val Val)) {
|
||||||
s.lock.RLock()
|
s.lock.RLock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
fn(s.data)
|
for _, val := range s.data {
|
||||||
|
fn(val)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Slice[Val]) GetErr(fn func(data []Val) error) error {
|
func (s *Slice[Val]) Append(val Val) {
|
||||||
s.lock.RLock()
|
|
||||||
defer s.lock.RUnlock()
|
|
||||||
|
|
||||||
return fn(s.data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Slice[Val]) Set(data []Val) {
|
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
s.data = data
|
s.data = append(s.data, val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSlice[Val, Ret any](s *Slice[Val], fn func(data []Val) Ret) Ret {
|
func (s *Slice[Val]) Delete(val Val) {
|
||||||
s.lock.RLock()
|
s.lock.Lock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
return fn(s.data)
|
s.data = xslices.Filter(s.data, func(v Val) bool {
|
||||||
}
|
return v != val
|
||||||
|
})
|
||||||
func GetSliceErr[Val, Ret any](s *Slice[Val], fn func(data []Val) (Ret, error)) (Ret, error) {
|
|
||||||
s.lock.RLock()
|
|
||||||
defer s.lock.RUnlock()
|
|
||||||
|
|
||||||
return fn(s.data)
|
|
||||||
}
|
}
|
||||||
|
|||||||
34
internal/safe/slice_test.go
Normal file
34
internal/safe/slice_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package safe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSlice(t *testing.T) {
|
||||||
|
s := NewSlice(1, 2, 3, 4, 5)
|
||||||
|
|
||||||
|
{
|
||||||
|
var have []int
|
||||||
|
|
||||||
|
s.Iter(func(val int) {
|
||||||
|
have = append(have, val)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, []int{1, 2, 3, 4, 5}, have)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Append(6)
|
||||||
|
s.Delete(3)
|
||||||
|
|
||||||
|
{
|
||||||
|
var have []int
|
||||||
|
|
||||||
|
s.Iter(func(val int) {
|
||||||
|
have = append(have, val)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, []int{1, 2, 4, 5, 6}, have)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -13,37 +13,57 @@ func NewValue[T any](data T) *Value[T] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Value[T]) Get(fn func(data T)) {
|
func (s *Value[T]) Load(fn func(data T)) {
|
||||||
s.lock.RLock()
|
s.lock.RLock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
fn(s.data)
|
fn(s.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Value[T]) GetErr(fn func(data T) error) error {
|
func (s *Value[T]) LoadErr(fn func(data T) error) error {
|
||||||
s.lock.RLock()
|
var err error
|
||||||
defer s.lock.RUnlock()
|
|
||||||
|
|
||||||
return fn(s.data)
|
s.Load(func(data T) {
|
||||||
|
err = fn(data)
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Value[T]) Set(data T) {
|
func (s *Value[T]) Save(data T) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
s.data = data
|
s.data = data
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetType[T, Ret any](s *Value[T], fn func(data T) Ret) Ret {
|
func (s *Value[T]) Mod(fn func(data *T)) {
|
||||||
s.lock.RLock()
|
s.lock.Lock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
return fn(s.data)
|
fn(&s.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTypeErr[T, Ret any](s *Value[T], fn func(data T) (Ret, error)) (Ret, error) {
|
func LoadRet[T, Ret any](s *Value[T], fn func(data T) Ret) Ret {
|
||||||
s.lock.RLock()
|
var ret Ret
|
||||||
defer s.lock.RUnlock()
|
|
||||||
|
|
||||||
return fn(s.data)
|
s.Load(func(data T) {
|
||||||
|
ret = fn(data)
|
||||||
|
})
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadRetErr[T, Ret any](s *Value[T], fn func(data T) (Ret, error)) (Ret, error) {
|
||||||
|
var ret Ret
|
||||||
|
|
||||||
|
err := s.LoadErr(func(data T) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ret, err = fn(data)
|
||||||
|
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return ret, err
|
||||||
}
|
}
|
||||||
|
|||||||
37
internal/safe/value_test.go
Normal file
37
internal/safe/value_test.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package safe
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestValue(t *testing.T) {
|
||||||
|
v := NewValue("foo")
|
||||||
|
|
||||||
|
v.Load(func(data string) {
|
||||||
|
if data != "foo" {
|
||||||
|
t.Error("expected foo")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
v.Save("bar")
|
||||||
|
|
||||||
|
v.Load(func(data string) {
|
||||||
|
if data != "bar" {
|
||||||
|
t.Error("expected bar")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
v.Mod(func(data *string) {
|
||||||
|
*data = "baz"
|
||||||
|
})
|
||||||
|
|
||||||
|
v.Load(func(data string) {
|
||||||
|
if data != "baz" {
|
||||||
|
t.Error("expected baz")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if LoadRet(v, func(data string) string {
|
||||||
|
return data
|
||||||
|
}) != "baz" {
|
||||||
|
t.Error("expected baz")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,6 +2,7 @@ package try
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@ -47,3 +48,31 @@ func catch(handlers ...func() error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Group struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *Group) GoTry(fn func(bool)) {
|
||||||
|
if wg.mu.TryLock() {
|
||||||
|
go func() {
|
||||||
|
defer wg.mu.Unlock()
|
||||||
|
fn(true)
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
go fn(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *Group) Lock() {
|
||||||
|
wg.mu.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *Group) Unlock() {
|
||||||
|
wg.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *Group) Wait() {
|
||||||
|
wg.mu.Lock()
|
||||||
|
defer wg.mu.Unlock()
|
||||||
|
}
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/ProtonMail/gluon/queue"
|
"github.com/ProtonMail/gluon/queue"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"gitlab.protontech.ch/go/liteapi"
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
@ -28,12 +27,6 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.MailSettings != nil {
|
|
||||||
if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(event.Labels) > 0 {
|
if len(event.Labels) > 0 {
|
||||||
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
|
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -51,14 +44,7 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error
|
|||||||
|
|
||||||
// handleUserEvent handles the given user event.
|
// handleUserEvent handles the given user event.
|
||||||
func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error {
|
func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error {
|
||||||
userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil)
|
user.apiUser.Save(userEvent)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.apiUser.Set(userEvent)
|
|
||||||
|
|
||||||
user.userKR.Set(userKR)
|
|
||||||
|
|
||||||
user.eventCh.Enqueue(events.UserChanged{
|
user.eventCh.Enqueue(events.UserChanged{
|
||||||
UserID: user.ID(),
|
UserID: user.ID(),
|
||||||
@ -93,22 +79,18 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||||
addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) {
|
user.apiAddrs.Set(event.Address.ID, event.Address)
|
||||||
return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR)
|
|
||||||
})
|
switch user.vault.AddressMode() {
|
||||||
if err != nil {
|
case vault.CombinedMode:
|
||||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
|
||||||
|
user.updateCh.SetFrom(event.Address.ID, addrID)
|
||||||
|
})
|
||||||
|
|
||||||
|
case vault.SplitMode:
|
||||||
|
user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get addresses: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user.apiAddrs.Set(apiAddrs)
|
|
||||||
|
|
||||||
user.addrKRs.Set(event.Address.ID, addrKR)
|
|
||||||
|
|
||||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||||
UserID: user.ID(),
|
UserID: user.ID(),
|
||||||
AddressID: event.Address.ID,
|
AddressID: event.Address.ID,
|
||||||
@ -116,9 +98,11 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad
|
|||||||
})
|
})
|
||||||
|
|
||||||
if user.vault.AddressMode() == vault.SplitMode {
|
if user.vault.AddressMode() == vault.SplitMode {
|
||||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error {
|
||||||
|
return syncLabels(ctx, user.client, updateCh)
|
||||||
if err := syncLabels(ctx, user.client, user.updateCh[event.Address.ID]); err != nil {
|
}); !ok {
|
||||||
|
return fmt.Errorf("no such address %q", event.Address.ID)
|
||||||
|
} else if err != nil {
|
||||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -127,21 +111,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||||
addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) {
|
user.apiAddrs.Set(event.Address.ID, event.Address)
|
||||||
return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get addresses: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user.apiAddrs.Set(apiAddrs)
|
|
||||||
|
|
||||||
user.addrKRs.Set(event.Address.ID, addrKR)
|
|
||||||
|
|
||||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||||
UserID: user.ID(),
|
UserID: user.ID(),
|
||||||
@ -153,25 +123,20 @@ func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.Ad
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||||
email, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
var email string
|
||||||
return getAddrEmail(apiAddrs, event.ID)
|
|
||||||
})
|
if ok := user.apiAddrs.GetDelete(event.ID, func(apiAddr liteapi.Address) {
|
||||||
if err != nil {
|
email = apiAddr.Email
|
||||||
return fmt.Errorf("failed to get address email: %w", err)
|
}); !ok {
|
||||||
|
return fmt.Errorf("no such address %q", event.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||||
if err != nil {
|
if user.vault.AddressMode() == vault.SplitMode {
|
||||||
return fmt.Errorf("failed to get addresses: %w", err)
|
updateCh.Close()
|
||||||
}
|
}
|
||||||
|
}); !ok {
|
||||||
user.apiAddrs.Set(apiAddrs)
|
return fmt.Errorf("no such address %q", event.ID)
|
||||||
|
|
||||||
user.addrKRs.Delete(event.ID)
|
|
||||||
|
|
||||||
if len(user.updateCh) > 1 {
|
|
||||||
user.updateCh[event.ID].Close()
|
|
||||||
delete(user.updateCh, event.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
user.eventCh.Enqueue(events.UserAddressDeleted{
|
user.eventCh.Enqueue(events.UserAddressDeleted{
|
||||||
@ -183,13 +148,6 @@ func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.Ad
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleMailSettingsEvent handles the given mail settings event.
|
|
||||||
func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error {
|
|
||||||
user.settings.Set(mailSettingsEvent)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleLabelEvents handles the given label events.
|
// handleLabelEvents handles the given label events.
|
||||||
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error {
|
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error {
|
||||||
for _, event := range labelEvents {
|
for _, event := range labelEvents {
|
||||||
@ -215,25 +173,25 @@ func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.L
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||||
for _, updateCh := range user.updateCh {
|
user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||||
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
||||||
}
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||||
for _, updateCh := range user.updateCh {
|
user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||||
updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
||||||
}
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||||
for _, updateCh := range user.updateCh {
|
user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||||
updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID)))
|
updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID)))
|
||||||
}
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -269,29 +227,18 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me
|
|||||||
return fmt.Errorf("failed to get full message: %w", err)
|
return fmt.Errorf("failed to get full message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
buildRes, err := safe.GetMapErr(
|
return user.withAddrKR(event.Message.AddressID, func(addrKR *crypto.KeyRing) error {
|
||||||
user.addrKRs,
|
buildRes, err := buildRFC822(ctx, full, addrKR)
|
||||||
full.AddressID,
|
if err != nil {
|
||||||
func(addrKR *crypto.KeyRing) (*buildRes, error) {
|
return fmt.Errorf("failed to build RFC822 message: %w", err)
|
||||||
return buildRFC822(ctx, full, addrKR)
|
}
|
||||||
},
|
|
||||||
func() (*buildRes, error) {
|
|
||||||
return nil, fmt.Errorf("address keyring not found")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to build RFC822: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(user.updateCh) > 1 {
|
user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||||
user.updateCh[buildRes.addressID].Enqueue(imap.NewMessagesCreated(buildRes.update))
|
updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||||
} else {
|
|
||||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
|
||||||
user.updateCh[apiAddrs[0].ID].Enqueue(imap.NewMessagesCreated(buildRes.update))
|
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
||||||
@ -302,13 +249,9 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.Me
|
|||||||
event.Message.Starred(),
|
event.Message.Starred(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(user.updateCh) > 1 {
|
user.updateCh.Get(event.Message.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||||
user.updateCh[event.Message.AddressID].Enqueue(update)
|
updateCh.Enqueue(update)
|
||||||
} else {
|
})
|
||||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
|
||||||
user.updateCh[apiAddrs[0].ID].Enqueue(update)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,13 +2,13 @@ package user
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/ProtonMail/gluon/queue"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"gitlab.protontech.ch/go/liteapi"
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
@ -25,27 +25,18 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type imapConnector struct {
|
type imapConnector struct {
|
||||||
client *liteapi.Client
|
*User
|
||||||
updateCh <-chan imap.Update
|
|
||||||
|
|
||||||
emails []string
|
addrID string
|
||||||
password []byte
|
|
||||||
|
|
||||||
flags, permFlags, attrs imap.FlagSet
|
flags, permFlags, attrs imap.FlagSet
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIMAPConnector(
|
func newIMAPConnector(user *User, addrID string) *imapConnector {
|
||||||
client *liteapi.Client,
|
|
||||||
updateCh <-chan imap.Update,
|
|
||||||
password []byte,
|
|
||||||
emails ...string,
|
|
||||||
) *imapConnector {
|
|
||||||
return &imapConnector{
|
return &imapConnector{
|
||||||
client: client,
|
User: user,
|
||||||
updateCh: updateCh,
|
|
||||||
|
|
||||||
emails: emails,
|
addrID: addrID,
|
||||||
password: password,
|
|
||||||
|
|
||||||
flags: defaultFlags,
|
flags: defaultFlags,
|
||||||
permFlags: defaultPermanentFlags,
|
permFlags: defaultPermanentFlags,
|
||||||
@ -55,13 +46,16 @@ func newIMAPConnector(
|
|||||||
|
|
||||||
// Authorize returns whether the given username/password combination are valid for this connector.
|
// Authorize returns whether the given username/password combination are valid for this connector.
|
||||||
func (conn *imapConnector) Authorize(username string, password []byte) bool {
|
func (conn *imapConnector) Authorize(username string, password []byte) bool {
|
||||||
if subtle.ConstantTimeCompare(conn.password, password) != 1 {
|
addrID, err := conn.checkAuth(username, password)
|
||||||
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return xslices.IndexFunc(conn.emails, func(address string) bool {
|
if conn.vault.AddressMode() == vault.SplitMode && addrID != conn.addrID {
|
||||||
return strings.EqualFold(address, username)
|
return false
|
||||||
}) >= 0
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLabel returns information about the label with the given ID.
|
// GetLabel returns information about the label with the given ID.
|
||||||
@ -246,7 +240,14 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [
|
|||||||
// GetUpdates returns a stream of updates that the gluon server should apply.
|
// GetUpdates returns a stream of updates that the gluon server should apply.
|
||||||
// It is recommended that the returned channel is buffered with at least constants.ChannelBufferCount.
|
// It is recommended that the returned channel is buffered with at least constants.ChannelBufferCount.
|
||||||
func (conn *imapConnector) GetUpdates() <-chan imap.Update {
|
func (conn *imapConnector) GetUpdates() <-chan imap.Update {
|
||||||
return conn.updateCh
|
updateCh, ok := safe.MapGetRet(conn.updateCh, conn.addrID, func(updateCh *queue.QueuedChannel[imap.Update]) <-chan imap.Update {
|
||||||
|
return updateCh.GetChannel()
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("update channel for %q not found", conn.addrID))
|
||||||
|
}
|
||||||
|
|
||||||
|
return updateCh
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUIDValidity returns the default UID validity for this user.
|
// GetUIDValidity returns the default UID validity for this user.
|
||||||
|
|||||||
60
internal/user/keys.go
Normal file
60
internal/user/keys.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error {
|
||||||
|
return user.apiUser.LoadErr(func(apiUser liteapi.User) error {
|
||||||
|
userKR, err := apiUser.Keys.Unlock(user.vault.KeyPass(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||||
|
}
|
||||||
|
defer userKR.ClearPrivateParams()
|
||||||
|
|
||||||
|
return fn(userKR)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing) error) error {
|
||||||
|
return user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||||
|
if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error {
|
||||||
|
addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||||
|
}
|
||||||
|
defer userKR.ClearPrivateParams()
|
||||||
|
|
||||||
|
return fn(addrKR)
|
||||||
|
}); !ok {
|
||||||
|
return fmt.Errorf("no such address %q", addrID)
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) withAddrKRs(fn func(map[string]*crypto.KeyRing) error) error {
|
||||||
|
return user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||||
|
return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||||
|
addrKRs := make(map[string]*crypto.KeyRing)
|
||||||
|
|
||||||
|
for _, apiAddr := range apiAddrs {
|
||||||
|
addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||||
|
}
|
||||||
|
defer userKR.ClearPrivateParams()
|
||||||
|
|
||||||
|
addrKRs[apiAddr.ID] = addrKR
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn(addrKRs)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -34,22 +34,25 @@ type smtpSession struct {
|
|||||||
// from is the current sending address (taken from the return path).
|
// from is the current sending address (taken from the return path).
|
||||||
from string
|
from string
|
||||||
|
|
||||||
|
// fromAddrID is the ID of the curent sending address (taken from the return path).
|
||||||
|
fromAddrID string
|
||||||
|
|
||||||
// to holds all to for the current message.
|
// to holds all to for the current message.
|
||||||
to []string
|
to []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSMTPSession(user *User, email string) (*smtpSession, error) {
|
func newSMTPSession(user *User, email string) (*smtpSession, error) {
|
||||||
authID, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (*smtpSession, error) {
|
||||||
return getAddrID(apiAddrs, email)
|
authID, err := getAddrID(apiAddrs, email)
|
||||||
})
|
if err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("failed to get address ID: %w", err)
|
||||||
return nil, fmt.Errorf("failed to get address ID: %w", err)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return &smtpSession{
|
return &smtpSession{
|
||||||
User: user,
|
User: user,
|
||||||
authID: authID,
|
authID: authID,
|
||||||
}, nil
|
}, nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Discard currently processed message.
|
// Discard currently processed message.
|
||||||
@ -58,6 +61,7 @@ func (session *smtpSession) Reset() {
|
|||||||
|
|
||||||
// Clear the from and to fields.
|
// Clear the from and to fields.
|
||||||
session.from = ""
|
session.from = ""
|
||||||
|
session.fromAddrID = ""
|
||||||
session.to = nil
|
session.to = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,7 +78,7 @@ func (session *smtpSession) Logout() error {
|
|||||||
func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||||
logrus.Info("SMTP session mail")
|
logrus.Info("SMTP session mail")
|
||||||
|
|
||||||
return session.apiAddrs.GetErr(func(apiAddrs []liteapi.Address) error {
|
return session.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||||
switch {
|
switch {
|
||||||
case opts.RequireTLS:
|
case opts.RequireTLS:
|
||||||
return ErrNotImplemented
|
return ErrNotImplemented
|
||||||
@ -93,12 +97,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := getAddrID(apiAddrs, sanitizeEmail(from)); err != nil {
|
addrID, err := getAddrID(apiAddrs, sanitizeEmail(from))
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("invalid return path: %w", err)
|
return fmt.Errorf("invalid return path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
session.from = from
|
session.from = from
|
||||||
|
|
||||||
|
session.fromAddrID = addrID
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -138,18 +145,13 @@ func (session *smtpSession) Data(r io.Reader) error {
|
|||||||
return fmt.Errorf("failed to create parser: %w", err)
|
return fmt.Errorf("failed to create parser: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
message, err := safe.GetSliceErr(session.apiAddrs, func(apiAddrs []liteapi.Address) (liteapi.Message, error) {
|
return session.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||||
addrID, err := getAddrID(apiAddrs, session.from)
|
return session.withAddrKR(session.fromAddrID, func(addrKR *crypto.KeyRing) error {
|
||||||
if err != nil {
|
return session.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||||
return liteapi.Message{}, fmt.Errorf("invalid return path: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return safe.GetMapErr(session.addrKRs, addrID, func(addrKR *crypto.KeyRing) (liteapi.Message, error) {
|
|
||||||
return safe.GetTypeErr(session.settings, func(settings liteapi.MailSettings) (liteapi.Message, error) {
|
|
||||||
// Use the first key for encrypting the message.
|
// Use the first key for encrypting the message.
|
||||||
addrKR, err := addrKR.FirstKey()
|
addrKR, err := addrKR.FirstKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return liteapi.Message{}, fmt.Errorf("failed to get first key: %w", err)
|
return fmt.Errorf("failed to get first key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the message contains a sender, use it instead of the one from the return path.
|
// If the message contains a sender, use it instead of the one from the return path.
|
||||||
@ -157,51 +159,61 @@ func (session *smtpSession) Data(r io.Reader) error {
|
|||||||
session.from = sender
|
session.from = sender
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load the user's mail settings.
|
||||||
|
settings, err := session.client.GetMailSettings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get mail settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// If we have to attach the public key, do it now.
|
// If we have to attach the public key, do it now.
|
||||||
if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
|
if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
|
||||||
key, err := addrKR.GetKey(0)
|
key, err := addrKR.GetKey(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return liteapi.Message{}, fmt.Errorf("failed to get sending key: %w", err)
|
return fmt.Errorf("failed to get sending key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, err := key.GetArmoredPublicKey()
|
pubKey, err := key.GetArmoredPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return liteapi.Message{}, fmt.Errorf("failed to get public key: %w", err)
|
return fmt.Errorf("failed to get public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse the message we want to send (after we have attached the public key).
|
||||||
message, err := message.ParseWithParser(parser)
|
message, err := message.ParseWithParser(parser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return liteapi.Message{}, fmt.Errorf("failed to parse message: %w", err)
|
return fmt.Errorf("failed to parse message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sendWithKey(
|
// Collect all the user's emails so we can match them to the outgoing message.
|
||||||
|
emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||||
|
return addr.Email
|
||||||
|
})
|
||||||
|
|
||||||
|
sent, err := sendWithKey(
|
||||||
ctx,
|
ctx,
|
||||||
session.client,
|
session.client,
|
||||||
session.authID,
|
session.authID,
|
||||||
session.vault.AddressMode(),
|
session.vault.AddressMode(),
|
||||||
apiAddrs,
|
|
||||||
settings,
|
settings,
|
||||||
session.userKR,
|
userKR,
|
||||||
addrKR,
|
addrKR,
|
||||||
|
emails,
|
||||||
session.from,
|
session.from,
|
||||||
session.to,
|
session.to,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.WithField("messageID", sent.ID).Info("Message sent")
|
||||||
|
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
}, func() (liteapi.Message, error) {
|
|
||||||
return liteapi.Message{}, ErrMissingAddrKey
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to send message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logrus.WithField("messageID", message.ID).Info("Message sent")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendWithKey sends the message with the given address key.
|
// sendWithKey sends the message with the given address key.
|
||||||
@ -210,10 +222,10 @@ func sendWithKey(
|
|||||||
client *liteapi.Client,
|
client *liteapi.Client,
|
||||||
authAddrID string,
|
authAddrID string,
|
||||||
addrMode vault.AddressMode,
|
addrMode vault.AddressMode,
|
||||||
apiAddrs []liteapi.Address,
|
|
||||||
settings liteapi.MailSettings,
|
settings liteapi.MailSettings,
|
||||||
userKR *safe.Value[*crypto.KeyRing],
|
userKR *crypto.KeyRing,
|
||||||
addrKR *crypto.KeyRing,
|
addrKR *crypto.KeyRing,
|
||||||
|
emails []string,
|
||||||
from string,
|
from string,
|
||||||
to []string,
|
to []string,
|
||||||
message message.Message,
|
message message.Message,
|
||||||
@ -243,7 +255,7 @@ func sendWithKey(
|
|||||||
return liteapi.Message{}, fmt.Errorf("failed to get armored message body: %w", err)
|
return liteapi.Message{}, fmt.Errorf("failed to get armored message body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
draft, err := createDraft(ctx, client, apiAddrs, from, to, parentID, liteapi.DraftTemplate{
|
draft, err := createDraft(ctx, client, emails, from, to, parentID, liteapi.DraftTemplate{
|
||||||
Subject: message.Subject,
|
Subject: message.Subject,
|
||||||
Body: armBody,
|
Body: armBody,
|
||||||
MIMEType: message.MIMEType,
|
MIMEType: message.MIMEType,
|
||||||
@ -264,9 +276,7 @@ func sendWithKey(
|
|||||||
return liteapi.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
return liteapi.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
recipients, err := safe.GetTypeErr(userKR, func(userKR *crypto.KeyRing) (recipients, error) {
|
recipients, err := getRecipients(ctx, client, userKR, settings, draft)
|
||||||
return getRecipients(ctx, client, userKR, settings, draft)
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return liteapi.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
return liteapi.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
||||||
}
|
}
|
||||||
@ -357,7 +367,7 @@ func getParentID(
|
|||||||
func createDraft(
|
func createDraft(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *liteapi.Client,
|
client *liteapi.Client,
|
||||||
apiAddrs []liteapi.Address,
|
emails []string,
|
||||||
from string,
|
from string,
|
||||||
to []string,
|
to []string,
|
||||||
parentID string,
|
parentID string,
|
||||||
@ -371,12 +381,12 @@ func createDraft(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check that the sending address is owned by the user, and if so, sanitize it.
|
// Check that the sending address is owned by the user, and if so, sanitize it.
|
||||||
if idx := xslices.IndexFunc(apiAddrs, func(addr liteapi.Address) bool {
|
if idx := xslices.IndexFunc(emails, func(email string) bool {
|
||||||
return strings.EqualFold(addr.Email, sanitizeEmail(template.Sender.Address))
|
return strings.EqualFold(email, sanitizeEmail(template.Sender.Address))
|
||||||
}); idx < 0 {
|
}); idx < 0 {
|
||||||
return liteapi.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address)
|
return liteapi.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address)
|
||||||
} else {
|
} else {
|
||||||
template.Sender.Address = constructEmail(template.Sender.Address, apiAddrs[idx].Email)
|
template.Sender.Address = constructEmail(template.Sender.Address, emails[idx])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check ToList: ensure that ToList only contains addresses we actually plan to send to.
|
// Check ToList: ensure that ToList only contains addresses we actually plan to send to.
|
||||||
|
|||||||
@ -10,12 +10,13 @@ import (
|
|||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
"github.com/ProtonMail/gluon/queue"
|
"github.com/ProtonMail/gluon/queue"
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/bradenaw/juniper/stream"
|
"github.com/bradenaw/juniper/stream"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"gitlab.protontech.ch/go/liteapi"
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -24,27 +25,43 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (user *User) sync(ctx context.Context) error {
|
func (user *User) sync(ctx context.Context) error {
|
||||||
if !user.vault.SyncStatus().HasLabels {
|
return user.withAddrKRs(func(addrKRs map[string]*crypto.KeyRing) error {
|
||||||
if err := syncLabels(ctx, user.client, maps.Values(user.updateCh)...); err != nil {
|
logrus.Info("Beginning sync")
|
||||||
return fmt.Errorf("failed to sync labels: %w", err)
|
|
||||||
|
if !user.vault.SyncStatus().HasLabels {
|
||||||
|
logrus.Info("Syncing labels")
|
||||||
|
|
||||||
|
if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
|
||||||
|
return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to sync labels: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.vault.SetHasLabels(true); err != nil {
|
||||||
|
return fmt.Errorf("failed to set has labels: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logrus.Info("Labels are already synced, skipping")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.vault.SetHasLabels(true); err != nil {
|
if !user.vault.SyncStatus().HasMessages {
|
||||||
return fmt.Errorf("failed to set has labels: %w", err)
|
logrus.Info("Syncing labels")
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.vault.SyncStatus().HasMessages {
|
if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
|
||||||
if err := user.syncMessages(ctx); err != nil {
|
return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
|
||||||
return fmt.Errorf("failed to sync messages: %w", err)
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to sync messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.vault.SetHasMessages(true); err != nil {
|
||||||
|
return fmt.Errorf("failed to set has messages: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logrus.Info("Messages are already synced, skipping")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.vault.SetHasMessages(true); err != nil {
|
return nil
|
||||||
return fmt.Errorf("failed to set has messages: %w", err)
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error {
|
func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error {
|
||||||
@ -102,48 +119,44 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) syncMessages(ctx context.Context) error {
|
func syncMessages(
|
||||||
|
ctx context.Context,
|
||||||
|
userID string,
|
||||||
|
client *liteapi.Client,
|
||||||
|
vault *vault.User,
|
||||||
|
addrKRs map[string]*crypto.KeyRing,
|
||||||
|
updateCh map[string]*queue.QueuedChannel[imap.Update],
|
||||||
|
eventCh *queue.QueuedChannel[events.Event],
|
||||||
|
) error {
|
||||||
// Determine which messages to sync.
|
// Determine which messages to sync.
|
||||||
allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
metadata, err := client.GetAllMessageMetadata(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get all message metadata: %w", err)
|
return fmt.Errorf("get all message metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata := allMetadata
|
// Get the message IDs to sync.
|
||||||
|
messageIDs := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
||||||
|
return metadata.ID
|
||||||
|
})
|
||||||
|
|
||||||
// If possible, begin syncing from one beyond the last synced message.
|
// If possible, begin syncing from one beyond the last synced message.
|
||||||
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
|
if idx := xslices.Index(messageIDs, vault.SyncStatus().LastMessageID); idx >= 0 {
|
||||||
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
|
messageIDs = messageIDs[idx+1:]
|
||||||
return metadata.ID == beginID
|
|
||||||
}); idx >= 0 {
|
|
||||||
metadata = metadata[idx+1:]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the metadata, building the messages.
|
// Fetch and build each message.
|
||||||
buildCh := stream.Chunk(stream.Map(
|
buildCh := stream.Map(
|
||||||
user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
client.GetFullMessages(ctx, messageIDs...),
|
||||||
return metadata.ID
|
|
||||||
})...),
|
|
||||||
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
||||||
return safe.GetMapErr(
|
return buildRFC822(ctx, full, addrKRs[full.AddressID])
|
||||||
user.addrKRs,
|
|
||||||
full.AddressID,
|
|
||||||
func(addrKR *crypto.KeyRing) (*buildRes, error) {
|
|
||||||
return buildRFC822(ctx, full, addrKR)
|
|
||||||
},
|
|
||||||
func() (*buildRes, error) {
|
|
||||||
return nil, fmt.Errorf("address keyring not found")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
},
|
},
|
||||||
), maxBatchSize)
|
)
|
||||||
defer buildCh.Close()
|
defer buildCh.Close()
|
||||||
|
|
||||||
// Create the flushers, one per update channel.
|
// Create the flushers, one per update channel.
|
||||||
flushers := make(map[string]*flusher)
|
flushers := make(map[string]*flusher)
|
||||||
|
|
||||||
for addrID, updateCh := range user.updateCh {
|
for addrID, updateCh := range updateCh {
|
||||||
flusher := newFlusher(updateCh, maxUpdateSize)
|
flusher := newFlusher(updateCh, maxUpdateSize)
|
||||||
defer flusher.flush(ctx, true)
|
defer flusher.flush(ctx, true)
|
||||||
|
|
||||||
@ -151,42 +164,27 @@ func (user *User) syncMessages(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a reporter to report sync progress updates.
|
// Create a reporter to report sync progress updates.
|
||||||
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
|
reporter := newReporter(userID, eventCh, len(messageIDs), time.Second)
|
||||||
defer reporter.done()
|
defer reporter.done()
|
||||||
|
|
||||||
var count int
|
|
||||||
|
|
||||||
// Send each update to the appropriate flusher.
|
// Send each update to the appropriate flusher.
|
||||||
for {
|
return forEach(ctx, stream.Chunk(buildCh, maxBatchSize), func(batch []*buildRes) error {
|
||||||
batch, err := buildCh.Next(ctx)
|
for _, res := range batch {
|
||||||
if errors.Is(err, stream.End) {
|
flushers[res.addressID].push(ctx, res.update)
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("failed to get next sync batch: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
|
||||||
for _, res := range batch {
|
|
||||||
if len(flushers) > 1 {
|
|
||||||
flushers[res.addressID].push(ctx, res.update)
|
|
||||||
} else {
|
|
||||||
flushers[apiAddrs[0].ID].push(ctx, res.update)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, flusher := range flushers {
|
for _, flusher := range flushers {
|
||||||
flusher.flush(ctx, true)
|
flusher.flush(ctx, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
|
if err := vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
|
||||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter.add(len(batch))
|
reporter.add(len(batch))
|
||||||
|
|
||||||
count += len(batch)
|
return nil
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
|
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
|
||||||
@ -232,3 +230,18 @@ func wantLabelID(labelID string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error {
|
||||||
|
for {
|
||||||
|
res, err := streamer.Next(ctx)
|
||||||
|
if errors.Is(err, stream.End) {
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to get next stream item: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fn(res); err != nil {
|
||||||
|
return fmt.Errorf("failed to process stream item: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -1,10 +1,17 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// mapTo converts the slice to the given type.
|
||||||
|
// This is not runtime safe, so make sure the slice is of the correct type!
|
||||||
|
// (This is a workaround for the fact that slices cannot be converted to other types generically).
|
||||||
func mapTo[From, To any](from []From) []To {
|
func mapTo[From, To any](from []From) []To {
|
||||||
to := make([]To, 0, len(from))
|
to := make([]To, 0, len(from))
|
||||||
|
|
||||||
@ -19,3 +26,79 @@ func mapTo[From, To any](from []From) []To {
|
|||||||
|
|
||||||
return to
|
return to
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// groupBy returns a map of the given slice grouped by the given key.
|
||||||
|
// Duplicate keys are overwritten.
|
||||||
|
func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value {
|
||||||
|
groups := make(map[Key]Value)
|
||||||
|
|
||||||
|
for _, item := range items {
|
||||||
|
groups[key(item)] = item
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortAddr returns whether the first address should be sorted before the second.
|
||||||
|
func sortAddr(addrIDA, addrIDB string, apiAddrs map[string]liteapi.Address) bool {
|
||||||
|
return apiAddrs[addrIDA].Order < apiAddrs[addrIDB].Order
|
||||||
|
}
|
||||||
|
|
||||||
|
// hexEncode returns the hexadecimal encoding of the given byte slice.
|
||||||
|
func hexEncode(b []byte) []byte {
|
||||||
|
enc := make([]byte, hex.EncodedLen(len(b)))
|
||||||
|
|
||||||
|
hex.Encode(enc, b)
|
||||||
|
|
||||||
|
return enc
|
||||||
|
}
|
||||||
|
|
||||||
|
// hexDecode returns the bytes represented by the hexadecimal encoding of the given byte slice.
|
||||||
|
func hexDecode(b []byte) ([]byte, error) {
|
||||||
|
dec := make([]byte, hex.DecodedLen(len(b)))
|
||||||
|
|
||||||
|
if _, err := hex.Decode(dec, b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAddrID returns the address ID for the given email address.
|
||||||
|
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||||
|
for _, addr := range apiAddrs {
|
||||||
|
if addr.Email == email {
|
||||||
|
return addr.ID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("address %s not found", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAddrEmail returns the email address of the given address ID.
|
||||||
|
func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
||||||
|
for _, addr := range apiAddrs {
|
||||||
|
if addr.ID == addrID {
|
||||||
|
return addr.Email, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("address %s not found", addrID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||||
|
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-stopCh:
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ctx, cancel
|
||||||
|
}
|
||||||
|
|||||||
@ -1,19 +1,18 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"crypto/subtle"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ProtonMail/gluon/connector"
|
"github.com/ProtonMail/gluon/connector"
|
||||||
"github.com/ProtonMail/gluon/imap"
|
"github.com/ProtonMail/gluon/imap"
|
||||||
"github.com/ProtonMail/gluon/queue"
|
"github.com/ProtonMail/gluon/queue"
|
||||||
"github.com/ProtonMail/gluon/wait"
|
|
||||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
@ -32,15 +31,11 @@ type User struct {
|
|||||||
eventCh *queue.QueuedChannel[events.Event]
|
eventCh *queue.QueuedChannel[events.Event]
|
||||||
|
|
||||||
apiUser *safe.Value[liteapi.User]
|
apiUser *safe.Value[liteapi.User]
|
||||||
apiAddrs *safe.Slice[liteapi.Address]
|
apiAddrs *safe.Map[string, liteapi.Address]
|
||||||
settings *safe.Value[liteapi.MailSettings]
|
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
|
||||||
|
|
||||||
userKR *safe.Value[*crypto.KeyRing]
|
|
||||||
addrKRs *safe.Map[string, *crypto.KeyRing]
|
|
||||||
|
|
||||||
updateCh map[string]*queue.QueuedChannel[imap.Update]
|
|
||||||
syncStopCh chan struct{}
|
syncStopCh chan struct{}
|
||||||
syncWG wait.Group
|
syncLock try.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User) (*User, error) {
|
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User) (*User, error) {
|
||||||
@ -50,9 +45,8 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
return nil, fmt.Errorf("failed to get addresses: %w", err)
|
return nil, fmt.Errorf("failed to get addresses: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unlock the user's keyrings.
|
// Check we can unlock the keyrings.
|
||||||
userKR, addrKRs, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass())
|
if _, _, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass()); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unlock user: %w", err)
|
return nil, fmt.Errorf("failed to unlock user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,20 +62,21 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the user's mail settings.
|
// Create update channels for each of the user's addresses.
|
||||||
settings, err := client.GetMailSettings(ctx)
|
// In combined mode, the addresses all share the same update channel.
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get mail settings: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create update channels for each of the user's addresses (if in combined mode, just the primary).
|
|
||||||
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
|
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
|
||||||
|
|
||||||
for _, addr := range apiAddrs {
|
switch encVault.AddressMode() {
|
||||||
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
case vault.CombinedMode:
|
||||||
|
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||||
|
|
||||||
if encVault.AddressMode() == vault.CombinedMode {
|
for _, addr := range apiAddrs {
|
||||||
break
|
updateCh[addr.ID] = primaryUpdateCh
|
||||||
|
}
|
||||||
|
|
||||||
|
case vault.SplitMode:
|
||||||
|
for _, addr := range apiAddrs {
|
||||||
|
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,19 +86,15 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||||
|
|
||||||
apiUser: safe.NewValue(apiUser),
|
apiUser: safe.NewValue(apiUser),
|
||||||
apiAddrs: safe.NewSlice(apiAddrs),
|
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||||
settings: safe.NewValue(settings),
|
updateCh: safe.NewMapFrom(updateCh, nil),
|
||||||
|
|
||||||
userKR: safe.NewValue(userKR),
|
|
||||||
addrKRs: safe.NewMap(addrKRs),
|
|
||||||
|
|
||||||
updateCh: updateCh,
|
|
||||||
syncStopCh: make(chan struct{}),
|
syncStopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// When we receive an auth object, we update it in the vault.
|
// When we receive an auth object, we update it in the vault.
|
||||||
// This will be used to authorize the user on the next run.
|
// This will be used to authorize the user on the next run.
|
||||||
client.AddAuthHandler(func(auth liteapi.Auth) {
|
user.client.AddAuthHandler(func(auth liteapi.Auth) {
|
||||||
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
|
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to update auth in vault")
|
logrus.WithError(err).Error("Failed to update auth in vault")
|
||||||
}
|
}
|
||||||
@ -111,23 +102,24 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
|
|
||||||
// When we are deauthorized, we send a deauth event to the event channel.
|
// When we are deauthorized, we send a deauth event to the event channel.
|
||||||
// Bridge will react to this event by logging out the user.
|
// Bridge will react to this event by logging out the user.
|
||||||
client.AddDeauthHandler(func() {
|
user.client.AddDeauthHandler(func() {
|
||||||
user.eventCh.Enqueue(events.UserDeauth{
|
user.eventCh.Enqueue(events.UserDeauth{
|
||||||
UserID: user.ID(),
|
UserID: user.ID(),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// TODO: Don't start the event loop until the initial sync has finished!
|
||||||
|
eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID())
|
||||||
|
|
||||||
// If we haven't synced yet, do it first.
|
// If we haven't synced yet, do it first.
|
||||||
// If it fails, we don't start the event loop.
|
// If it fails, we don't start the event loop.
|
||||||
// Otherwise, begin processing API events, logging any errors that occur.
|
// Otherwise, begin processing API events, logging any errors that occur.
|
||||||
go func() {
|
go func() {
|
||||||
if status := user.vault.SyncStatus(); !status.HasMessages {
|
if err := <-user.startSync(); err != nil {
|
||||||
if err := <-user.startSync(); err != nil {
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for err := range user.streamEvents() {
|
for err := range user.streamEvents(eventCh) {
|
||||||
logrus.WithError(err).Error("Error while streaming events")
|
logrus.WithError(err).Error("Error while streaming events")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -137,40 +129,34 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
|
|
||||||
// ID returns the user's ID.
|
// ID returns the user's ID.
|
||||||
func (user *User) ID() string {
|
func (user *User) ID() string {
|
||||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string {
|
||||||
return apiUser.ID
|
return apiUser.ID
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the user's username.
|
// Name returns the user's username.
|
||||||
func (user *User) Name() string {
|
func (user *User) Name() string {
|
||||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string {
|
||||||
return apiUser.Name
|
return apiUser.Name
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Match matches the given query against the user's username and email addresses.
|
// Match matches the given query against the user's username and email addresses.
|
||||||
func (user *User) Match(query string) bool {
|
func (user *User) Match(query string) bool {
|
||||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) bool {
|
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) bool {
|
||||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) bool {
|
if query == apiUser.Name {
|
||||||
if query == apiUser.Name {
|
return true
|
||||||
return true
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for _, addr := range apiAddrs {
|
return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool {
|
||||||
if addr.Email == query {
|
return addr.Email == query
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emails returns all the user's email addresses.
|
// Emails returns all the user's email addresses via the callback.
|
||||||
func (user *User) Emails() []string {
|
func (user *User) Emails() []string {
|
||||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
return safe.MapValuesRet(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
||||||
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||||
return addr.Email
|
return addr.Email
|
||||||
})
|
})
|
||||||
@ -184,28 +170,38 @@ func (user *User) GetAddressMode() vault.AddressMode {
|
|||||||
|
|
||||||
// SetAddressMode sets the user's address mode.
|
// SetAddressMode sets the user's address mode.
|
||||||
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
|
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
|
||||||
for _, updateCh := range user.updateCh {
|
user.stopSync()
|
||||||
updateCh.Close()
|
user.lockSync()
|
||||||
}
|
defer user.unlockSync()
|
||||||
|
|
||||||
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
|
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||||
|
for _, updateCh := range xslices.Unique(updateCh) {
|
||||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
updateCh.Close()
|
||||||
for _, addr := range apiAddrs {
|
|
||||||
user.updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
|
||||||
|
|
||||||
if mode == vault.CombinedMode {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
|
||||||
|
|
||||||
|
switch mode {
|
||||||
|
case vault.CombinedMode:
|
||||||
|
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||||
|
|
||||||
|
user.apiAddrs.IterKeys(func(addrID string) {
|
||||||
|
updateCh[addrID] = primaryUpdateCh
|
||||||
|
})
|
||||||
|
|
||||||
|
case vault.SplitMode:
|
||||||
|
user.apiAddrs.IterKeys(func(addrID string) {
|
||||||
|
updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
user.updateCh = safe.NewMapFrom(updateCh, nil)
|
||||||
|
|
||||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||||
return fmt.Errorf("failed to set address mode: %w", err)
|
return fmt.Errorf("failed to set address mode: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user.stopSync()
|
|
||||||
|
|
||||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||||
}
|
}
|
||||||
@ -246,25 +242,19 @@ func (user *User) GluonKey() []byte {
|
|||||||
|
|
||||||
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
|
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
|
||||||
func (user *User) BridgePass() []byte {
|
func (user *User) BridgePass() []byte {
|
||||||
buf := new(bytes.Buffer)
|
return hexEncode(user.vault.BridgePass())
|
||||||
|
|
||||||
if _, err := hex.NewEncoder(buf).Write(user.vault.BridgePass()); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsedSpace returns the total space used by the user on the API.
|
// UsedSpace returns the total space used by the user on the API.
|
||||||
func (user *User) UsedSpace() int {
|
func (user *User) UsedSpace() int {
|
||||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int {
|
||||||
return apiUser.UsedSpace
|
return apiUser.UsedSpace
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxSpace returns the amount of space the user can use on the API.
|
// MaxSpace returns the amount of space the user can use on the API.
|
||||||
func (user *User) MaxSpace() int {
|
func (user *User) MaxSpace() int {
|
||||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int {
|
||||||
return apiUser.MaxSpace
|
return apiUser.MaxSpace
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -275,37 +265,9 @@ func (user *User) GetEventCh() <-chan events.Event {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewIMAPConnector returns an IMAP connector for the given address.
|
// NewIMAPConnector returns an IMAP connector for the given address.
|
||||||
// If not in split mode, this function returns an error.
|
// If not in split mode, this must be the primary address.
|
||||||
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
|
func (user *User) NewIMAPConnector(addrID string) connector.Connector {
|
||||||
return safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (connector.Connector, error) {
|
return newIMAPConnector(user, addrID)
|
||||||
var emails []string
|
|
||||||
|
|
||||||
switch user.vault.AddressMode() {
|
|
||||||
case vault.CombinedMode:
|
|
||||||
if addrID != apiAddrs[0].ID {
|
|
||||||
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
emails = xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
|
||||||
return addr.Email
|
|
||||||
})
|
|
||||||
|
|
||||||
case vault.SplitMode:
|
|
||||||
email, err := getAddrEmail(apiAddrs, addrID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
emails = []string{email}
|
|
||||||
}
|
|
||||||
|
|
||||||
return newIMAPConnector(
|
|
||||||
user.client,
|
|
||||||
user.updateCh[addrID].GetChannel(),
|
|
||||||
user.BridgePass(),
|
|
||||||
emails...,
|
|
||||||
), nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
|
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
|
||||||
@ -314,23 +276,48 @@ func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
|
|||||||
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||||
imapConn := make(map[string]connector.Connector)
|
imapConn := make(map[string]connector.Connector)
|
||||||
|
|
||||||
for addrID := range user.updateCh {
|
switch user.vault.AddressMode() {
|
||||||
conn, err := user.NewIMAPConnector(addrID)
|
case vault.CombinedMode:
|
||||||
if err != nil {
|
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
|
||||||
return nil, fmt.Errorf("failed to create IMAP connector: %w", err)
|
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||||
}
|
})
|
||||||
|
|
||||||
imapConn[addrID] = conn
|
case vault.SplitMode:
|
||||||
|
user.apiAddrs.IterKeys(func(addrID string) {
|
||||||
|
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return imapConn, nil
|
return imapConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSMTPSession returns an SMTP session for the user.
|
// NewSMTPSession returns an SMTP session for the user.
|
||||||
func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
|
func (user *User) NewSMTPSession(email string, password []byte) (smtp.Session, error) {
|
||||||
|
if _, err := user.checkAuth(email, password); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return newSMTPSession(user, email)
|
return newSMTPSession(user, email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnStatusUp is called when the connection goes up.
|
||||||
|
func (user *User) OnStatusUp() {
|
||||||
|
go func() {
|
||||||
|
logrus.Info("Connection up, checking if sync is needed")
|
||||||
|
|
||||||
|
if err := <-user.startSync(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to sync on status up")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnStatusDown is called when the connection goes down.
|
||||||
|
func (user *User) OnStatusDown() {
|
||||||
|
logrus.Info("Connection down, aborting any ongoing syncs")
|
||||||
|
|
||||||
|
user.stopSync()
|
||||||
|
}
|
||||||
|
|
||||||
// Logout logs the user out from the API.
|
// Logout logs the user out from the API.
|
||||||
// If withVault is true, the user's vault is also cleared.
|
// If withVault is true, the user's vault is also cleared.
|
||||||
func (user *User) Logout(ctx context.Context) error {
|
func (user *User) Logout(ctx context.Context) error {
|
||||||
@ -350,13 +337,18 @@ func (user *User) Close() error {
|
|||||||
// Cancel ongoing syncs.
|
// Cancel ongoing syncs.
|
||||||
user.stopSync()
|
user.stopSync()
|
||||||
|
|
||||||
|
// Wait for ongoing syncs to stop.
|
||||||
|
user.waitSync()
|
||||||
|
|
||||||
// Close the user's API client.
|
// Close the user's API client.
|
||||||
user.client.Close()
|
user.client.Close()
|
||||||
|
|
||||||
// Close the user's update channels.
|
// Close the user's update channels.
|
||||||
for _, updateCh := range user.updateCh {
|
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||||
updateCh.Close()
|
for _, updateCh := range xslices.Unique(updateCh) {
|
||||||
}
|
updateCh.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Close the user's notify channel.
|
// Close the user's notify channel.
|
||||||
user.eventCh.Close()
|
user.eventCh.Close()
|
||||||
@ -364,16 +356,37 @@ func (user *User) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) checkAuth(email string, password []byte) (string, error) {
|
||||||
|
dec, err := hexDecode(password)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if subtle.ConstantTimeCompare(user.vault.BridgePass(), dec) != 1 {
|
||||||
|
return "", fmt.Errorf("invalid password")
|
||||||
|
}
|
||||||
|
|
||||||
|
return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||||
|
for _, addr := range apiAddrs {
|
||||||
|
if addr.Email == strings.ToLower(email) {
|
||||||
|
return addr.ID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("invalid email")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// streamEvents begins streaming API events for the user.
|
// streamEvents begins streaming API events for the user.
|
||||||
// When we receive an API event, we attempt to handle it.
|
// When we receive an API event, we attempt to handle it.
|
||||||
// If successful, we update the event ID in the vault.
|
// If successful, we update the event ID in the vault.
|
||||||
func (user *User) streamEvents() <-chan error {
|
func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
|
||||||
errCh := make(chan error)
|
errCh := make(chan error)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(errCh)
|
defer close(errCh)
|
||||||
|
|
||||||
for event := range user.client.NewEventStreamer(EventPeriod, EventJitter, user.vault.EventID()).Subscribe() {
|
for event := range eventCh {
|
||||||
if err := user.handleAPIEvent(context.Background(), event); err != nil {
|
if err := user.handleAPIEvent(context.Background(), event); err != nil {
|
||||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||||
@ -387,11 +400,21 @@ func (user *User) streamEvents() <-chan error {
|
|||||||
|
|
||||||
// startSync begins a startSync for the user.
|
// startSync begins a startSync for the user.
|
||||||
func (user *User) startSync() <-chan error {
|
func (user *User) startSync() <-chan error {
|
||||||
|
if user.vault.SyncStatus().IsComplete() {
|
||||||
|
logrus.Debug("Already synced, skipping")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
errCh := make(chan error)
|
errCh := make(chan error)
|
||||||
|
|
||||||
user.syncWG.Go(func() {
|
user.syncLock.GoTry(func(ok bool) {
|
||||||
defer close(errCh)
|
defer close(errCh)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
logrus.Debug("Sync already in progress, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
|
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -421,46 +444,24 @@ func (user *User) startSync() <-chan error {
|
|||||||
func (user *User) stopSync() {
|
func (user *User) stopSync() {
|
||||||
select {
|
select {
|
||||||
case user.syncStopCh <- struct{}{}:
|
case user.syncStopCh <- struct{}{}:
|
||||||
user.syncWG.Wait()
|
logrus.Debug("Sent sync abort signal")
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// ...
|
logrus.Debug("No sync to abort")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
// lockSync prevents a new sync from starting.
|
||||||
for _, addr := range apiAddrs {
|
func (user *User) lockSync() {
|
||||||
if addr.Email == email {
|
user.syncLock.Lock()
|
||||||
return addr.ID, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("address %s not found", email)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
// unlockSync allows a new sync to start.
|
||||||
for _, addr := range apiAddrs {
|
func (user *User) unlockSync() {
|
||||||
if addr.ID == addrID {
|
user.syncLock.Unlock()
|
||||||
return addr.Email, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("address %s not found", addrID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
// waitSync waits for any ongoing sync to finish.
|
||||||
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
func (user *User) waitSync() {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
user.syncLock.Wait()
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-stopCh:
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
case <-ctx.Done():
|
|
||||||
// ...
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return ctx, cancel
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -103,6 +103,10 @@ type SyncStatus struct {
|
|||||||
LastMessageID string
|
LastMessageID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (status SyncStatus) IsComplete() bool {
|
||||||
|
return status.HasLabels && status.HasMessages
|
||||||
|
}
|
||||||
|
|
||||||
func newDefaultUser(userID, username, authUID, authRef string, keyPass []byte) UserData {
|
func newDefaultUser(userID, username, authUID, authRef string, keyPass []byte) UserData {
|
||||||
return UserData{
|
return UserData{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
|
|||||||
@ -10,15 +10,18 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
||||||
"github.com/bradenaw/juniper/xslices"
|
"github.com/bradenaw/juniper/xslices"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Vault is an encrypted data vault that stores bridge and user data.
|
||||||
type Vault struct {
|
type Vault struct {
|
||||||
path string
|
path string
|
||||||
enc []byte
|
enc []byte
|
||||||
gcm cipher.AEAD
|
gcm cipher.AEAD
|
||||||
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
|
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
|
||||||
@ -150,6 +153,9 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) get() Data {
|
func (vault *Vault) get() Data {
|
||||||
|
vault.lock.RLock()
|
||||||
|
defer vault.lock.RUnlock()
|
||||||
|
|
||||||
dec, err := decrypt(vault.gcm, vault.enc)
|
dec, err := decrypt(vault.gcm, vault.enc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -165,20 +171,28 @@ func (vault *Vault) get() Data {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) mod(fn func(data *Data)) error {
|
func (vault *Vault) mod(fn func(data *Data)) error {
|
||||||
data := vault.get()
|
vault.lock.Lock()
|
||||||
|
defer vault.lock.Unlock()
|
||||||
|
|
||||||
fn(&data)
|
dec, err := decrypt(vault.gcm, vault.enc)
|
||||||
|
|
||||||
return vault.set(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vault *Vault) set(data Data) error {
|
|
||||||
dec, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
enc, err := encrypt(vault.gcm, dec)
|
var data Data
|
||||||
|
|
||||||
|
if err := json.Unmarshal(dec, &data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(&data)
|
||||||
|
|
||||||
|
mod, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
enc, err := encrypt(vault.gcm, mod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -21,7 +21,6 @@
|
|||||||
package versioner
|
package versioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -33,10 +32,9 @@ import (
|
|||||||
// RemoveOldVersions is a noop on darwin; we don't test it there.
|
// RemoveOldVersions is a noop on darwin; we don't test it there.
|
||||||
|
|
||||||
func TestRemoveOldVersions(t *testing.T) {
|
func TestRemoveOldVersions(t *testing.T) {
|
||||||
updates, err := os.MkdirTemp(t.TempDir(), "updates")
|
tempDir := t.TempDir()
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
v := newTestVersioner(t, "myCoolApp", updates, "2.3.4-beta", "2.3.4", "2.3.5", "2.4.0")
|
v := newTestVersioner(t, "myCoolApp", tempDir, "2.3.4-beta", "2.3.4", "2.3.5", "2.4.0")
|
||||||
|
|
||||||
allVersions, err := v.ListVersions()
|
allVersions, err := v.ListVersions()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -49,5 +47,5 @@ func TestRemoveOldVersions(t *testing.T) {
|
|||||||
assert.Len(t, cleanedVersions, 1)
|
assert.Len(t, cleanedVersions, 1)
|
||||||
|
|
||||||
assert.Equal(t, semver.MustParse("2.4.0"), cleanedVersions[0].version)
|
assert.Equal(t, semver.MustParse("2.4.0"), cleanedVersions[0].version)
|
||||||
assert.Equal(t, filepath.Join(updates, "2.4.0"), cleanedVersions[0].path)
|
assert.Equal(t, filepath.Join(tempDir, "2.4.0"), cleanedVersions[0].path)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -51,7 +51,7 @@ func (t *testCtx) startBridge() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create the bridge.
|
// Create the bridge.
|
||||||
bridge, err := bridge.New(
|
bridge, eventCh, err := bridge.New(
|
||||||
t.locator,
|
t.locator,
|
||||||
vault,
|
vault,
|
||||||
t.mocks.Autostarter,
|
t.mocks.Autostarter,
|
||||||
@ -73,6 +73,9 @@ func (t *testCtx) startBridge() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait for the users to be loaded.
|
||||||
|
waitForEvent(eventCh, events.AllUsersLoaded{})
|
||||||
|
|
||||||
// Save the bridge t.
|
// Save the bridge t.
|
||||||
t.bridge = bridge
|
t.bridge = bridge
|
||||||
|
|
||||||
@ -101,3 +104,12 @@ func (t *testCtx) stopBridge() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func waitForEvent[T any](eventCh <-chan events.Event, wantEvent T) {
|
||||||
|
for event := range eventCh {
|
||||||
|
switch event.(type) {
|
||||||
|
case T:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user