feat(GODT-2822): Integrate and activate all service

The bridge now runs on the new architecture.
This commit is contained in:
Leander Beernaert
2023-07-28 14:59:12 +02:00
parent a187747c7c
commit 823ca4d207
22 changed files with 558 additions and 3794 deletions

View File

@ -288,7 +288,7 @@ MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscr
mv tmp internal/services/userevents/mocks_test.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/events EventPublisher \
> internal/events/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider \
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \
> internal/services/useridentity/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog

View File

@ -408,11 +408,6 @@ func (bridge *Bridge) GetErrors() []error {
func (bridge *Bridge) Close(ctx context.Context) {
logrus.Info("Closing bridge")
// Close the servers
if err := bridge.serverManager.CloseServers(ctx); err != nil {
logrus.WithError(err).Error("Failed to close servers")
}
// Close all users.
safe.Lock(func() {
for _, user := range bridge.users {
@ -420,6 +415,11 @@ func (bridge *Bridge) Close(ctx context.Context) {
}
}, bridge.usersLock)
// Close the servers
if err := bridge.serverManager.CloseServers(ctx); err != nil {
logrus.WithError(err).Error("Failed to close servers")
}
// Stop all ongoing tasks.
bridge.tasks.CancelAndWait()

View File

@ -54,6 +54,7 @@ import (
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
var (
@ -621,6 +622,10 @@ func TestBridge_AddressWithoutKeys(t *testing.T) {
defer m.Close()
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Watch for sync finished event.
syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
defer done()
// Create a user which will have an address without keys.
userID, _, err := s.CreateUser("nokeys", []byte("password"))
require.NoError(t, err)
@ -641,10 +646,6 @@ func TestBridge_AddressWithoutKeys(t *testing.T) {
// Remove the address keys.
require.NoError(t, s.RemoveAddressKey(userID, aliasAddrID, aliasAddr.Keys[0].ID))
// Watch for sync finished event.
syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
defer done()
// We should be able to log the user in.
require.NoError(t, getErr(bridge.LoginFull(context.Background(), "nokeys", []byte("password"), nil, nil)))
require.NoError(t, err)
@ -873,6 +874,9 @@ func TestBridge_ChangeAddressOrder(t *testing.T) {
// withEnv creates the full test environment and runs the tests.
func withEnv(t *testing.T, tests func(context.Context, *server.Server, *proton.NetCtl, bridge.Locator, []byte), opts ...server.Option) {
opt := goleak.IgnoreCurrent()
defer goleak.VerifyNone(t, opt)
server := server.New(opts...)
defer server.Close()

View File

@ -79,7 +79,7 @@ func (bridge *Bridge) CheckClientState(ctx context.Context, checkFlags bool, pro
}
log.Debug("Building state")
state, err := meta.BuildMailboxToMessageMap(usr)
state, err := meta.BuildMailboxToMessageMap(ctx, usr)
if err != nil {
log.WithError(err).Error("Failed to build state")
return result, err

View File

@ -36,7 +36,6 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent"
"github.com/sirupsen/logrus"
)
@ -45,16 +44,6 @@ func (bridge *Bridge) restartIMAP(ctx context.Context) error {
return bridge.serverManager.RestartIMAP(ctx)
}
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
return bridge.serverManager.AddIMAPUser(ctx, user)
}
// removeIMAPUser disconnects the given user from gluon, optionally also removing its files.
func (bridge *Bridge) removeIMAPUser(ctx context.Context, user *user.User, withData bool) error {
return bridge.serverManager.RemoveIMAPUser(ctx, user, withData)
}
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
switch event := event.(type) {
case imapEvents.UserAdded:

View File

@ -28,8 +28,8 @@ import (
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
bridgesmtp "github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
"github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
@ -97,16 +97,18 @@ func (sm *ServerManager) RestartSMTP(ctx context.Context) error {
return err
}
func (sm *ServerManager) AddIMAPUser(ctx context.Context, user *user.User) error {
_, err := sm.requests.Send(ctx, &smRequestAddIMAPUser{user: user})
return err
}
func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error {
_, err := sm.requests.Send(ctx, &smRequestRemoveIMAPUser{
user: user,
withData: withData,
func (sm *ServerManager) AddIMAPUser(
ctx context.Context,
connector connector.Connector,
addrID string,
idProvider imapservice.GluonIDProvider,
syncStateProvider imapservice.SyncStateProvider,
) error {
_, err := sm.requests.Send(ctx, &smRequestAddIMAPUser{
connector: connector,
addrID: addrID,
idProvider: idProvider,
syncStateProvider: syncStateProvider,
})
return err
@ -120,18 +122,11 @@ func (sm *ServerManager) SetGluonDir(ctx context.Context, gluonDir string) error
return err
}
func (sm *ServerManager) AddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) {
reply, err := cpc.SendTyped[string](ctx, sm.requests, &smRequestAddGluonUser{
conn: conn,
passphrase: passphrase,
})
return reply, err
}
func (sm *ServerManager) RemoveGluonUser(ctx context.Context, gluonID string) error {
_, err := sm.requests.Send(ctx, &smRequestRemoveGluonUser{
userID: gluonID,
func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, deleteData bool, provider imapservice.GluonIDProvider, addrID ...string) error {
_, err := sm.requests.Send(ctx, &smRequestRemoveIMAPUser{
withData: deleteData,
addrID: addrID,
idProvider: provider,
})
return err
@ -195,18 +190,16 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
request.Reply(ctx, nil, err)
case *smRequestAddIMAPUser:
err := sm.handleAddIMAPUser(ctx, r.user)
err := sm.handleAddIMAPUser(ctx, r.connector, r.addrID, r.idProvider, r.syncStateProvider)
request.Reply(ctx, nil, err)
if err == nil {
sm.loadedUserCount++
sm.handleLoadedUserCountChange(ctx, bridge)
}
case *smRequestRemoveIMAPUser:
err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData)
err := sm.handleRemoveIMAPUser(ctx, r.withData, r.idProvider, r.addrID...)
request.Reply(ctx, nil, err)
if err == nil {
sm.loadedUserCount--
sm.handleLoadedUserCountChange(ctx, bridge)
}
@ -214,14 +207,6 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
err := sm.handleSetGluonDir(ctx, bridge, r.dir)
request.Reply(ctx, nil, err)
case *smRequestAddGluonUser:
id, err := sm.handleAddGluonUser(ctx, r.conn, r.passphrase)
request.Reply(ctx, id, err)
case *smRequestRemoveGluonUser:
err := sm.handleRemoveGluonUser(ctx, r.userID)
request.Reply(ctx, nil, err)
case *smRequestAddSMTPAccount:
logrus.WithField("user", r.account.UserID()).Debug("Adding SMTP Account")
sm.smtpAccounts.AddAccount(r.account)
@ -277,27 +262,42 @@ func (sm *ServerManager) handleClose(ctx context.Context, bridge *Bridge) {
}
}
func (sm *ServerManager) handleAddIMAPUser(ctx context.Context, user *user.User) error {
func (sm *ServerManager) handleAddIMAPUser(ctx context.Context,
connector connector.Connector,
addrID string,
idProvider imapservice.GluonIDProvider,
syncStateProvider imapservice.SyncStateProvider,
) error {
// Due to the many different error exits, performer user count change at this stage rather we split the incrementing
// of users from the logic.
err := sm.handleAddIMAPUserImpl(ctx, connector, addrID, idProvider, syncStateProvider)
if err == nil {
sm.loadedUserCount++
}
return err
}
func (sm *ServerManager) handleAddIMAPUserImpl(ctx context.Context,
connector connector.Connector,
addrID string,
idProvider imapservice.GluonIDProvider,
syncStateProvider imapservice.SyncStateProvider,
) error {
if sm.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
}
for addrID, imapConn := range imapConn {
log := logrus.WithFields(logrus.Fields{
"userID": user.ID(),
"addrID": addrID,
})
log.Info("Adding user to imap server")
if gluonID, ok := user.GetGluonID(addrID); ok {
if gluonID, ok := idProvider.GetGluonID(addrID); ok {
log.WithField("gluonID", gluonID).Info("Loading existing IMAP user")
// Load the user, checking whether the DB was newly created.
isNew, err := sm.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey())
isNew, err := sm.imapServer.LoadUser(ctx, connector, gluonID, idProvider.GluonKey())
if err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
@ -312,32 +312,32 @@ func (sm *ServerManager) handleAddIMAPUser(ctx context.Context, user *user.User)
}
// Clear the sync status -- we need to resync all messages.
if err := user.ClearSyncStatus(); err != nil {
if err := syncStateProvider.ClearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
}
// Add the user back to the IMAP server.
if isNew, err := sm.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
if isNew, err := sm.imapServer.LoadUser(ctx, connector, gluonID, idProvider.GluonKey()); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
} else if isNew {
panic("IMAP user should already have a database")
}
} else if status := user.GetSyncStatus(); !status.HasLabels {
} else if status := syncStateProvider.GetSyncStatus(); !status.HasLabels {
// Otherwise, the DB already exists -- if the labels are not yet synced, we need to re-create the DB.
if err := sm.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
return fmt.Errorf("failed to remove old IMAP user: %w", err)
}
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
if err := idProvider.RemoveGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to remove old IMAP user ID: %w", err)
}
gluonID, err := sm.imapServer.AddUser(ctx, imapConn, user.GluonKey())
gluonID, err := sm.imapServer.AddUser(ctx, connector, idProvider.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
if err := idProvider.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
@ -346,45 +346,49 @@ func (sm *ServerManager) handleAddIMAPUser(ctx context.Context, user *user.User)
} else {
log.Info("Creating new IMAP user")
gluonID, err := sm.imapServer.AddUser(ctx, imapConn, user.GluonKey())
gluonID, err := sm.imapServer.AddUser(ctx, connector, idProvider.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
if err := idProvider.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
log.WithField("gluonID", gluonID).Info("Created new IMAP user")
}
}
// Trigger a sync for the user, if needed.
user.TriggerSync()
return nil
}
func (sm *ServerManager) handleRemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error {
func (sm *ServerManager) handleRemoveIMAPUser(ctx context.Context, withData bool, idProvider imapservice.GluonIDProvider, addrIDs ...string) error {
if sm.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
logrus.WithFields(logrus.Fields{
"userID": user.ID(),
"withData": withData,
"addresses": addrIDs,
}).Debug("Removing IMAP user")
for addrID, gluonID := range user.GetGluonIDs() {
for _, addrID := range addrIDs {
gluonID, ok := idProvider.GetGluonID(addrID)
if !ok {
logrus.Warnf("Could not find Gluon ID for addrID %v", addrID)
continue
}
if err := sm.imapServer.RemoveUser(ctx, gluonID, withData); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
if withData {
if err := user.RemoveGluonID(addrID, gluonID); err != nil {
if err := idProvider.RemoveGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to remove IMAP user ID: %w", err)
}
}
sm.loadedUserCount--
}
return nil
@ -396,7 +400,7 @@ func createIMAPServer(bridge *Bridge) (*gluon.Server, error) {
return nil, fmt.Errorf("failed to get Gluon Database directory: %w", err)
}
return newIMAPServer(
server, err := newIMAPServer(
bridge.vault.GetGluonCacheDir(),
gluonDataDir,
bridge.curVersion,
@ -409,6 +413,11 @@ func createIMAPServer(bridge *Bridge) (*gluon.Server, error) {
bridge.uidValidityGenerator,
bridge.panicHandler,
)
if err == nil {
bridge.publish(events.IMAPServerCreated{})
}
return server, err
}
func createSMTPServer(bridge *Bridge, accounts *bridgesmtp.Accounts) *smtp.Server {
@ -464,6 +473,8 @@ func (sm *ServerManager) closeIMAPServer(ctx context.Context, bridge *Bridge) er
}
sm.imapServer = nil
bridge.publish(events.IMAPServerClosed{})
}
return nil
@ -632,35 +643,12 @@ func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge,
bridge.heartbeat.SetCacheLocation(newGluonDir)
gluonDataDir, err := bridge.GetGluonDataDir()
if err != nil {
return fmt.Errorf("failed to get Gluon Database directory: %w", err)
}
imapServer, err := newIMAPServer(
bridge.vault.GetGluonCacheDir(),
gluonDataDir,
bridge.curVersion,
bridge.tlsConfig,
bridge.reporter,
bridge.logIMAPClient,
bridge.logIMAPServer,
bridge.imapEventCh,
bridge.tasks,
bridge.uidValidityGenerator,
bridge.panicHandler,
)
imapServer, err := createIMAPServer(bridge)
if err != nil {
return fmt.Errorf("failed to create new IMAP server: %w", err)
}
sm.imapServer = imapServer
for _, bridgeUser := range bridge.users {
if err := sm.handleAddIMAPUser(ctx, bridgeUser); err != nil {
return fmt.Errorf("failed to add users to new IMAP server: %w", err)
}
sm.loadedUserCount++
}
if sm.shouldStartServers() {
if err := sm.serveIMAP(ctx, bridge); err != nil {
@ -672,22 +660,6 @@ func (sm *ServerManager) handleSetGluonDir(ctx context.Context, bridge *Bridge,
}, bridge.usersLock)
}
func (sm *ServerManager) handleAddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) {
if sm.imapServer == nil {
return "", fmt.Errorf("no imap server instance running")
}
return sm.imapServer.AddUser(ctx, conn, passphrase)
}
func (sm *ServerManager) handleRemoveGluonUser(ctx context.Context, userID string) error {
if sm.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
return sm.imapServer.RemoveUser(ctx, userID, true)
}
func (sm *ServerManager) shouldStartServers() bool {
return sm.loadedUserCount >= 1
}
@ -699,27 +671,22 @@ type smRequestRestartIMAP struct{}
type smRequestRestartSMTP struct{}
type smRequestAddIMAPUser struct {
user *user.User
connector connector.Connector
addrID string
idProvider imapservice.GluonIDProvider
syncStateProvider imapservice.SyncStateProvider
}
type smRequestRemoveIMAPUser struct {
user *user.User
withData bool
addrID []string
idProvider imapservice.GluonIDProvider
}
type smRequestSetGluonDir struct {
dir string
}
type smRequestAddGluonUser struct {
conn connector.Connector
passphrase []byte
}
type smRequestRemoveGluonUser struct {
userID string
}
type smRequestAddSMTPAccount struct {
account *bridgesmtp.Service
}

View File

@ -278,18 +278,10 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
return fmt.Errorf("address mode is already %q", mode)
}
if err := bridge.removeIMAPUser(ctx, user, true); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
if err := user.SetAddressMode(ctx, mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
bridge.publish(events.AddressModeChanged{
UserID: userID,
AddressMode: mode,
@ -335,13 +327,7 @@ func (bridge *Bridge) SendBadEventUserFeedback(_ context.Context, userID string,
logrus.WithError(rerr).Error("Failed to report feedback failure")
}
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
user.BadEventFeedbackResync(ctx)
return nil
return user.BadEventFeedbackResync(ctx)
}
if rerr := bridge.reporter.ReportMessageWithContext(
@ -535,16 +521,13 @@ func (bridge *Bridge) addUserWithVault(
bridge.vault.GetMaxSyncMemory(),
statsPath,
bridge,
bridge.serverManager,
&bridgeEventSubscription{b: bridge},
)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
// Connect the user's address(es) to gluon.
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := bridge.addSMTPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add SMTP user: %w", err)
}
@ -558,11 +541,8 @@ func (bridge *Bridge) addUserWithVault(
"event": event,
}).Debug("Received user event")
if err := bridge.handleUserEvent(ctx, user, event); err != nil {
logrus.WithError(err).Error("Failed to handle user event")
} else {
bridge.handleUserEvent(ctx, user, event)
bridge.publish(event)
}
})
})
@ -613,10 +593,6 @@ func (bridge *Bridge) logoutUser(ctx context.Context, user *user.User, withAPI,
"withData": withData,
}).Debug("Logging out user")
if err := bridge.removeIMAPUser(ctx, user, withData); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
if err := bridge.removeSMTPUser(ctx, user); err != nil {
logrus.WithError(err).Error("Failed to remove SMTP user")
}

View File

@ -19,44 +19,17 @@ package bridge
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/proton-bridge/v3/internal"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/sirupsen/logrus"
)
func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, event events.Event) error {
func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, event events.Event) {
switch event := event.(type) {
case events.UserAddressCreated:
if err := bridge.handleUserAddressCreated(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address created event: %w", err)
}
case events.UserAddressEnabled:
if err := bridge.handleUserAddressEnabled(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address enabled event: %w", err)
}
case events.UserAddressDisabled:
if err := bridge.handleUserAddressDisabled(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address disabled event: %w", err)
}
case events.UserAddressDeleted:
if err := bridge.handleUserAddressDeleted(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address deleted event: %w", err)
}
case events.UserRefreshed:
if err := bridge.handleUserRefreshed(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user refreshed event: %w", err)
}
case events.UserDeauth:
bridge.handleUserDeauth(ctx, user)
@ -66,102 +39,6 @@ func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, even
case events.UncategorizedEventError:
bridge.handleUncategorizedErrorEvent(event)
}
return nil
}
func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.User, event events.UserAddressCreated) error {
if user.GetAddressMode() == vault.CombinedMode {
return nil
}
gluonID, err := bridge.serverManager.AddGluonUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
if err := user.SetGluonID(event.AddressID, gluonID); err != nil {
return fmt.Errorf("failed to set gluon ID: %w", err)
}
return nil
}
func (bridge *Bridge) handleUserAddressEnabled(ctx context.Context, user *user.User, event events.UserAddressEnabled) error {
if user.GetAddressMode() == vault.CombinedMode {
return nil
}
gluonID, err := bridge.serverManager.AddGluonUser(ctx, user.NewIMAPConnector(event.AddressID), user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
if err := user.SetGluonID(event.AddressID, gluonID); err != nil {
return fmt.Errorf("failed to set gluon ID: %w", err)
}
return nil
}
func (bridge *Bridge) handleUserAddressDisabled(ctx context.Context, user *user.User, event events.UserAddressDisabled) error {
if user.GetAddressMode() == vault.CombinedMode {
return nil
}
gluonID, ok := user.GetGluonID(event.AddressID)
if !ok {
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
}
if err := bridge.serverManager.RemoveGluonUser(ctx, gluonID); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
if err := user.RemoveGluonID(event.AddressID, gluonID); err != nil {
return fmt.Errorf("failed to remove gluon ID for address: %w", err)
}
return nil
}
func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.User, event events.UserAddressDeleted) error {
if user.GetAddressMode() == vault.CombinedMode {
return nil
}
gluonID, ok := user.GetGluonID(event.AddressID)
if !ok {
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
}
if err := bridge.serverManager.handleRemoveGluonUser(ctx, gluonID); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
if err := user.RemoveGluonID(event.AddressID, gluonID); err != nil {
return fmt.Errorf("failed to remove gluon ID for address: %w", err)
}
return nil
}
func (bridge *Bridge) handleUserRefreshed(ctx context.Context, user *user.User, event events.UserRefreshed) error {
return safe.RLockRet(func() error {
if event.CancelEventPool {
user.CancelSyncAndEventPoll()
}
if err := bridge.removeIMAPUser(ctx, user, true); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
return nil
}, bridge.usersLock)
}
func (bridge *Bridge) handleUserDeauth(ctx context.Context, user *user.User) {
@ -171,7 +48,7 @@ func (bridge *Bridge) handleUserDeauth(ctx context.Context, user *user.User) {
}, bridge.usersLock)
}
func (bridge *Bridge) handleUserBadEvent(_ context.Context, user *user.User, event events.UserBadEvent) {
func (bridge *Bridge) handleUserBadEvent(ctx context.Context, user *user.User, event events.UserBadEvent) {
safe.Lock(func() {
if rerr := bridge.reporter.ReportMessageWithContext("Failed to handle event", reporter.Context{
"user_id": user.ID(),
@ -184,12 +61,7 @@ func (bridge *Bridge) handleUserBadEvent(_ context.Context, user *user.User, eve
logrus.WithError(rerr).Error("Failed to report failed event handling")
}
user.CancelSyncAndEventPoll()
// Disable IMAP user
if err := bridge.removeIMAPUser(context.Background(), user, false); err != nil {
logrus.WithError(err).Error("Failed to remove IMAP user")
}
user.OnBadEvent(ctx)
}, bridge.usersLock)
}

View File

@ -20,6 +20,9 @@ package events
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/watcher"
)
type Event interface {
@ -39,3 +42,22 @@ type EventPublisher interface {
type NullEventPublisher struct{}
func (NullEventPublisher) PublishEvent(_ context.Context, _ Event) {}
type Subscription interface {
Add(ofType ...Event) *watcher.Watcher[Event]
Remove(watcher *watcher.Watcher[Event])
}
type NullSubscription struct{}
func (n NullSubscription) Add(ofType ...Event) *watcher.Watcher[Event] {
return watcher.New[Event](&async.NoopPanicHandler{}, ofType...)
}
func (n NullSubscription) Remove(watcher *watcher.Watcher[Event]) {
watcher.Close()
}
func NewNullSubscription() *NullSubscription {
return &NullSubscription{}
}

View File

@ -37,6 +37,22 @@ func (event IMAPServerStopped) String() string {
return "IMAPServerStopped"
}
type IMAPServerClosed struct {
eventBase
}
func (event IMAPServerClosed) String() string {
return "IMAPServerClosed"
}
type IMAPServerCreated struct {
eventBase
}
func (event IMAPServerCreated) String() string {
return "IMAPServerCreated"
}
type IMAPServerError struct {
eventBase

View File

@ -200,7 +200,7 @@ func (s *Service) CancelSync(ctx context.Context) error {
}
func (s *Service) ResumeSync(ctx context.Context) error {
_, err := s.cpc.Send(ctx, &cancelSyncReq{})
_, err := s.cpc.Send(ctx, &resumeSyncReq{})
return err
}

View File

@ -41,10 +41,6 @@ func (s *Service) onAddressEvent(ctx context.Context, events []proton.AddressEve
return nil
}
if s.addressMode != usertypes.AddressModeSplit {
return nil
}
for _, event := range events {
switch event.Action {
case proton.EventCreate:

View File

@ -34,7 +34,7 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/constants"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
imapservice "github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/xmaps"
@ -58,12 +58,21 @@ type DiagMailboxMessage struct {
Flags imap.FlagSet
}
func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]AccountMailboxMap, error) {
return safe.RLockRetErr(func() (map[string]AccountMailboxMap, error) {
func (apm DiagnosticMetadata) BuildMailboxToMessageMap(ctx context.Context, user *User) (map[string]AccountMailboxMap, error) {
apiAddrs, err := user.identityService.GetAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get addresses: %w", err)
}
apiLabels, err := user.imapService.GetLabels(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get labels: %w", err)
}
result := make(map[string]AccountMailboxMap)
mode := user.GetAddressMode()
primaryAddrID, err := usertypes.GetPrimaryAddr(user.apiAddrs)
primaryAddrID, err := usertypes.GetPrimaryAddr(apiAddrs)
if err != nil {
return nil, fmt.Errorf("failed to get primary addr for user: %w", err)
}
@ -73,7 +82,7 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A
addrID = primaryAddrID.ID
}
addr := user.apiAddrs[addrID]
addr := apiAddrs[addrID]
if addr.Status != proton.AddressStatusEnabled {
return nil, false
}
@ -89,13 +98,13 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A
for _, metadata := range apm.Metadata {
for _, label := range metadata.LabelIDs {
details, ok := user.apiLabels[label]
details, ok := apiLabels[label]
if !ok {
logrus.Warnf("User %v has message with unknown label '%v'", user.Name(), label)
continue
}
if !wantLabel(details) {
if !imapservice.WantLabel(details) {
continue
}
@ -108,14 +117,14 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A
if details.Type == proton.LabelTypeSystem {
mboxName = details.Name
} else {
mboxName = strings.Join(getMailboxName(details), "/")
mboxName = strings.Join(imapservice.GetMailboxName(details), "/")
}
mboxMessage := DiagMailboxMessage{
UserID: user.ID(),
ID: metadata.ID,
AddressID: metadata.AddressID,
Flags: buildFlagSetFromMessageMetadata(metadata),
Flags: imapservice.BuildFlagSetFromMessageMetadata(metadata),
}
if v, ok := account[mboxName]; ok {
@ -126,7 +135,6 @@ func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]A
}
}
return result, nil
}, user.apiAddrsLock, user.apiLabelsLock)
}
func (user *User) GetDiagnosticMetadata(ctx context.Context) (DiagnosticMetadata, error) {
@ -161,12 +169,19 @@ func (user *User) DebugDownloadMessages(
msgs map[string]DiagMailboxMessage,
progressCB func(string, int, int),
) error {
var err error
safe.RLock(func() {
err = func() error {
total := len(msgs)
userID := user.ID()
apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return fmt.Errorf("failed to get api user: %w", err)
}
apiAddrs, err := user.identityService.GetAddresses(ctx)
if err != nil {
return fmt.Errorf("failed to get address: %w", err)
}
counter := 1
for _, msg := range msgs {
if progressCB != nil {
@ -188,7 +203,7 @@ func (user *User) DebugDownloadMessages(
return err
}
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
if err := usertypes.WithAddrKR(apiUser, apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
switch {
case len(message.Attachments) > 0:
return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData)
@ -204,9 +219,6 @@ func (user *User) DebugDownloadMessages(
}
}
return nil
}()
}, user.apiAddrsLock, user.apiUserLock)
return err
}
func getBodyName(path string) string {

View File

@ -1,864 +0,0 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
// handleAPIEvent handles the given proton.Event.
func (user *User) handleAPIEvent(ctx context.Context, event proton.Event) error {
if event.Refresh&proton.RefreshMail != 0 {
return user.handleRefreshEvent(ctx, event.Refresh, event.EventID)
}
if event.User != nil {
user.handleUserEvent(ctx, *event.User)
}
if len(event.Addresses) > 0 {
if err := user.handleAddressEvents(ctx, event.Addresses); err != nil {
return err
}
}
if len(event.Labels) > 0 {
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
return err
}
}
if len(event.Messages) > 0 {
if err := user.handleMessageEvents(ctx, event.Messages); err != nil {
return err
}
}
if event.UsedSpace != nil {
user.handleUsedSpaceChange(*event.UsedSpace)
}
return nil
}
func (user *User) handleRefreshEvent(ctx context.Context, refresh proton.RefreshFlag, eventID string) error {
l := user.log.WithFields(logrus.Fields{
"eventID": eventID,
"refresh": refresh,
})
l.Info("Handling refresh event")
// Abort the event stream
defer user.pollAbort.Abort()
// Re-sync messages after the user, address and label refresh.
defer user.goSync()
return user.syncUserAddressesLabelsAndClearSync(ctx, false)
}
func (user *User) syncUserAddressesLabelsAndClearSync(ctx context.Context, cancelEventPool bool) error {
return safe.LockRet(func() error {
// Fetch latest user info.
apiUser, err := user.client.GetUser(ctx)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Fetch latest address info.
apiAddrs, err := user.client.GetAddresses(ctx)
if err != nil {
return fmt.Errorf("failed to get addresses: %w", err)
}
// Fetch latest label info.
apiLabels, err := user.client.GetLabels(ctx, proton.LabelTypeSystem, proton.LabelTypeFolder, proton.LabelTypeLabel)
if err != nil {
return fmt.Errorf("failed to get labels: %w", err)
}
// Update the API info in the user.
user.apiUser = apiUser
user.apiAddrs = usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
user.apiLabels = usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID })
// Clear sync status; we want to sync everything again.
if err := user.clearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
}
// The user was refreshed.
user.eventCh.Enqueue(events.UserRefreshed{
UserID: user.apiUser.ID,
CancelEventPool: cancelEventPool,
})
return nil
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
// handleUserEvent handles the given user event.
func (user *User) handleUserEvent(_ context.Context, userEvent proton.User) {
safe.Lock(func() {
user.log.WithFields(logrus.Fields{
"userID": userEvent.ID,
"username": logging.Sensitive(userEvent.Name),
}).Info("Handling user event")
user.apiUser = userEvent
user.eventCh.Enqueue(events.UserChanged{
UserID: user.apiUser.ID,
})
}, user.apiUserLock)
}
// handleAddressEvents handles the given address events.
// GODT-1945: If split address mode, need to signal back to bridge to update the addresses.
func (user *User) handleAddressEvents(ctx context.Context, addressEvents []proton.AddressEvent) error {
for _, event := range addressEvents {
switch event.Action {
case proton.EventCreate:
if err := user.handleCreateAddressEvent(ctx, event); err != nil {
user.reportError("Failed to apply address create event", err)
return fmt.Errorf("failed to handle create address event: %w", err)
}
case proton.EventUpdate, proton.EventUpdateFlags:
if err := user.handleUpdateAddressEvent(ctx, event); err != nil {
if errors.Is(err, ErrAddressDoesNotExist) {
logrus.Debugf("Address %v does not exist, will try create instead", event.Address.ID)
if createErr := user.handleCreateAddressEvent(ctx, event); createErr != nil {
user.reportError("Failed to apply address update event (with create)", createErr)
return fmt.Errorf("failed to handle update address event (with create): %w", createErr)
}
return nil
}
user.reportError("Failed to apply address update event", err)
return fmt.Errorf("failed to handle update address event: %w", err)
}
case proton.EventDelete:
if err := user.handleDeleteAddressEvent(ctx, event); err != nil {
user.reportError("Failed to apply address delete event", err)
return fmt.Errorf("failed to delete address: %w", err)
}
}
}
return nil
}
func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.AddressEvent) error {
if err := safe.LockRet(func() error {
user.log.WithFields(logrus.Fields{
"addressID": event.ID,
"email": logging.Sensitive(event.Address.Email),
}).Info("Handling address created event")
if _, ok := user.apiAddrs[event.Address.ID]; ok {
user.log.Debugf("Address %q already exists", event.ID)
return nil
}
user.apiAddrs[event.Address.ID] = event.Address
// If the address is disabled.
if event.Address.Status != proton.AddressStatusEnabled {
return nil
}
// If the address is enabled, we need to hook it up to the update channels.
switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
if err != nil {
return fmt.Errorf("failed to get primary address: %w", err)
}
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
case vault.SplitMode:
user.updateCh[event.Address.ID] = async.NewQueuedChannel[imap.Update](
0,
0,
user.panicHandler,
fmt.Sprintf("user-update-split-%v", event.Address.ID),
)
}
user.eventCh.Enqueue(events.UserAddressCreated{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
return nil
}, user.apiAddrsLock, user.updateChLock); err != nil {
return fmt.Errorf("failed to handle create address event: %w", err)
}
// Perform the sync in an RLock.
return safe.RLockRet(func() error {
if event.Address.Status != proton.AddressStatusEnabled {
return nil
}
if user.vault.AddressMode() == vault.SplitMode {
if err := syncLabels(ctx, user.apiLabels, user.updateCh[event.Address.ID]); err != nil {
return fmt.Errorf("failed to sync labels to new address: %w", err)
}
}
return nil
}, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
var ErrAddressDoesNotExist = errors.New("address does not exist")
func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.AddressEvent) error { //nolint:unparam
return safe.LockRet(func() error {
user.log.WithFields(logrus.Fields{
"addressID": event.ID,
"email": logging.Sensitive(event.Address.Email),
}).Info("Handling address updated event")
oldAddr, ok := user.apiAddrs[event.Address.ID]
if !ok {
return ErrAddressDoesNotExist
}
user.apiAddrs[event.Address.ID] = event.Address
switch {
// If the address was newly enabled:
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
if err != nil {
return fmt.Errorf("failed to get primary address: %w", err)
}
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
case vault.SplitMode:
user.updateCh[event.Address.ID] = async.NewQueuedChannel[imap.Update](
0,
0,
user.panicHandler,
fmt.Sprintf("user-update-split-%v", event.Address.ID),
)
}
user.eventCh.Enqueue(events.UserAddressEnabled{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
// If the address was newly disabled:
case oldAddr.Status == proton.AddressStatusEnabled && event.Address.Status != proton.AddressStatusEnabled:
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].CloseAndDiscardQueued()
}
delete(user.updateCh, event.ID)
user.eventCh.Enqueue(events.UserAddressDisabled{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
// Otherwise it's just an update:
default:
user.eventCh.Enqueue(events.UserAddressUpdated{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
}
return nil
}, user.apiAddrsLock, user.updateChLock)
}
func (user *User) handleDeleteAddressEvent(_ context.Context, event proton.AddressEvent) error {
return safe.LockRet(func() error {
user.log.WithField("addressID", event.ID).Info("Handling address deleted event")
addr, ok := user.apiAddrs[event.ID]
if !ok {
user.log.Debugf("Address %q does not exist", event.ID)
return nil
}
delete(user.apiAddrs, event.ID)
// If the address was disabled to begin with, we don't need to do anything.
if addr.Status != proton.AddressStatusEnabled {
return nil
}
// Otherwise, in split mode, drop the update queue.
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].CloseAndDiscardQueued()
}
// And in either mode, remove the address from the update channel map.
delete(user.updateCh, event.ID)
user.eventCh.Enqueue(events.UserAddressDeleted{
UserID: user.apiUser.ID,
AddressID: event.ID,
Email: addr.Email,
})
return nil
}, user.apiAddrsLock, user.updateChLock)
}
// handleLabelEvents handles the given label events.
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []proton.LabelEvent) error {
for _, event := range labelEvents {
switch event.Action {
case proton.EventCreate:
updates, err := user.handleCreateLabelEvent(ctx, event)
if err != nil {
return fmt.Errorf("failed to handle create label event: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return err
}
case proton.EventUpdate, proton.EventUpdateFlags:
updates, err := user.handleUpdateLabelEvent(ctx, event)
if err != nil {
return fmt.Errorf("failed to handle update label event: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return err
}
case proton.EventDelete:
updates, err := user.handleDeleteLabelEvent(ctx, event)
if err != nil {
return fmt.Errorf("failed to handle delete label event: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return fmt.Errorf("failed to handle delete label event in gluon: %w", err)
}
}
}
return nil
}
func (user *User) handleCreateLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
return safe.LockRetErr(func() ([]imap.Update, error) {
var updates []imap.Update
user.log.WithFields(logrus.Fields{
"labelID": event.ID,
"name": logging.Sensitive(event.Label.Name),
}).Info("Handling label created event")
user.apiLabels[event.Label.ID] = event.Label
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := newMailboxCreatedUpdate(imap.MailboxID(event.ID), getMailboxName(event.Label))
updateCh.Enqueue(update)
updates = append(updates, update)
}
user.eventCh.Enqueue(events.UserLabelCreated{
UserID: user.apiUser.ID,
LabelID: event.Label.ID,
Name: event.Label.Name,
})
return updates, nil
}, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleUpdateLabelEvent(ctx context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
return safe.LockRetErr(func() ([]imap.Update, error) {
var updates []imap.Update
user.log.WithFields(logrus.Fields{
"labelID": event.ID,
"name": logging.Sensitive(event.Label.Name),
}).Info("Handling label updated event")
stack := []proton.Label{event.Label}
for len(stack) > 0 {
label := stack[0]
stack = stack[1:]
// Only update the label if it exists; we don't want to create it as a client may have just deleted it.
if _, ok := user.apiLabels[label.ID]; ok {
user.apiLabels[label.ID] = event.Label
}
// API doesn't notify us that the path has changed. We need to fetch it again.
apiLabel, err := user.client.GetLabel(ctx, label.ID, label.Type)
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
user.log.WithError(apiErr).Warn("Failed to get label: label does not exist")
continue
} else if err != nil {
return nil, fmt.Errorf("failed to get label %q: %w", label.ID, err)
}
// Update the label in the map.
user.apiLabels[apiLabel.ID] = apiLabel
// Notify the IMAP clients.
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := imap.NewMailboxUpdated(
imap.MailboxID(apiLabel.ID),
getMailboxName(apiLabel),
)
updateCh.Enqueue(update)
updates = append(updates, update)
}
user.eventCh.Enqueue(events.UserLabelUpdated{
UserID: user.apiUser.ID,
LabelID: apiLabel.ID,
Name: apiLabel.Name,
})
children := xslices.Filter(maps.Values(user.apiLabels), func(other proton.Label) bool {
return other.ParentID == label.ID
})
stack = append(stack, children...)
}
return updates, nil
}, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleDeleteLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
return safe.LockRetErr(func() ([]imap.Update, error) {
var updates []imap.Update
user.log.WithField("labelID", event.ID).Info("Handling label deleted event")
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := imap.NewMailboxDeleted(imap.MailboxID(event.ID))
updateCh.Enqueue(update)
updates = append(updates, update)
}
delete(user.apiLabels, event.ID)
user.eventCh.Enqueue(events.UserLabelDeleted{
UserID: user.apiUser.ID,
LabelID: event.ID,
})
return updates, nil
}, user.apiLabelsLock, user.updateChLock)
}
// handleMessageEvents handles the given message events.
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []proton.MessageEvent) error {
for _, event := range messageEvents {
ctx = logging.WithLogrusField(ctx, "messageID", event.ID)
switch event.Action {
case proton.EventCreate:
updates, err := user.handleCreateMessageEvent(logging.WithLogrusField(ctx, "action", "create message"), event.Message)
if err != nil {
user.reportError("Failed to apply create message event", err)
return fmt.Errorf("failed to handle create message event: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return err
}
case proton.EventUpdate, proton.EventUpdateFlags:
// Draft update means to completely remove old message and upload the new data again, but we should
// only do this if the event is of type EventUpdate otherwise label switch operations will not work.
if (event.Message.IsDraft() || (event.Message.Flags&proton.MessageFlagSent != 0)) && event.Action == proton.EventUpdate {
updates, err := user.handleUpdateDraftOrSentMessage(
logging.WithLogrusField(ctx, "action", "update draft or sent message"),
event,
)
if err != nil {
user.reportError("Failed to apply update draft message event", err)
return fmt.Errorf("failed to handle update draft event: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return err
}
continue
}
// GODT-2028 - Use better events here. It should be possible to have 3 separate events that refrain to
// whether the flags, labels or read only data (header+body) has been changed. This requires fixing proton
// first so that it correctly reports those cases.
// Issue regular update to handle mailboxes and flag changes.
updates, err := user.handleUpdateMessageEvent(
logging.WithLogrusField(ctx, "action", "update message"),
event.Message,
)
if err != nil {
user.reportError("Failed to apply update message event", err)
return fmt.Errorf("failed to handle update message event: %w", err)
}
// If the update fails on the gluon side because it doesn't exist, we try to create the message instead.
if err := waitOnIMAPUpdates(ctx, updates); gluon.IsNoSuchMessage(err) {
user.log.WithError(err).Error("Failed to handle update message event in gluon, will try creating it")
updates, err := user.handleCreateMessageEvent(ctx, event.Message)
if err != nil {
return fmt.Errorf("failed to handle update message event as create: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return err
}
} else if err != nil {
return err
}
case proton.EventDelete:
updates, err := user.handleDeleteMessageEvent(
logging.WithLogrusField(ctx, "action", "delete message"),
event,
)
if err != nil {
user.reportError("Failed to apply delete message event", err)
return fmt.Errorf("failed to handle delete message event: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return fmt.Errorf("failed to handle delete message event in gluon: %w", err)
}
}
}
return nil
}
func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.MessageMetadata) ([]imap.Update, error) {
user.log.WithFields(logrus.Fields{
"messageID": message.ID,
"subject": logging.Sensitive(message.Subject),
}).Info("Handling message created event")
full, err := user.client.GetFullMessage(ctx, message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
user.log.WithField("messageID", message.ID).Warn("Cannot create new message: full message is missing on API")
return nil, nil
}
return nil, fmt.Errorf("failed to get full message: %w", err)
}
return safe.RLockRetErr(func() ([]imap.Update, error) {
var update imap.Update
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
if res.err != nil {
user.log.WithError(err).Error("Failed to build RFC822 message")
if err := user.vault.AddFailedMessageID(message.ID); err != nil {
user.log.WithError(err).Error("Failed to add failed message ID to vault")
}
user.reportErrorAndMessageID("Failed to build message (event create)", res.err, res.messageID)
return nil
}
if err := user.vault.RemFailedMessageID(message.ID); err != nil {
user.log.WithError(err).Error("Failed to remove failed message ID from vault")
}
update = imap.NewMessagesCreated(false, res.update)
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
if err != nil {
return err
}
if !didPublish {
update = nil
}
return nil
}); err != nil {
return nil, err
}
if update == nil {
return nil, nil
}
return []imap.Update{update}, nil
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.MessageMetadata) ([]imap.Update, error) { //nolint:unparam
return safe.RLockRetErr(func() ([]imap.Update, error) {
user.log.WithFields(logrus.Fields{
"messageID": message.ID,
"subject": logging.Sensitive(message.Subject),
}).Info("Handling message updated event")
flags := buildFlagSetFromMessageMetadata(message)
update := imap.NewMessageMailboxesUpdated(
imap.MessageID(message.ID),
usertypes.MapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
flags,
)
didPublish, err := safePublishMessageUpdate(user, message.AddressID, update)
if err != nil {
return nil, err
}
if !didPublish {
return nil, nil
}
return []imap.Update{update}, nil
}, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleDeleteMessageEvent(_ context.Context, event proton.MessageEvent) ([]imap.Update, error) {
return safe.RLockRetErr(func() ([]imap.Update, error) {
user.log.WithField("messageID", event.ID).Info("Handling message deleted event")
var updates []imap.Update
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := imap.NewMessagesDeleted(imap.MessageID(event.ID))
updateCh.Enqueue(update)
updates = append(updates, update)
}
return updates, nil
}, user.updateChLock)
}
func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) {
return safe.RLockRetErr(func() ([]imap.Update, error) {
user.log.WithFields(logrus.Fields{
"messageID": event.ID,
"subject": logging.Sensitive(event.Message.Subject),
"isDraft": event.Message.IsDraft(),
}).Info("Handling draft or sent updated event")
full, err := user.client.GetFullMessage(ctx, event.Message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
user.log.WithField("messageID", event.Message.ID).Warn("Cannot update message: full message is missing on API")
return nil, nil
}
return nil, fmt.Errorf("failed to get full draft: %w", err)
}
var update imap.Update
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
if res.err != nil {
logrus.WithError(err).Error("Failed to build RFC822 message")
if err := user.vault.AddFailedMessageID(event.ID); err != nil {
user.log.WithError(err).Error("Failed to add failed message ID to vault")
}
user.reportErrorAndMessageID("Failed to build draft message (event update)", res.err, res.messageID)
return nil
}
if err := user.vault.RemFailedMessageID(event.ID); err != nil {
user.log.WithError(err).Error("Failed to remove failed message ID from vault")
}
update = imap.NewMessageUpdated(
res.update.Message,
res.update.Literal,
res.update.MailboxIDs,
res.update.ParsedMessage,
true, // Is the message doesn't exist, silently create it.
)
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
if err != nil {
return err
}
if !didPublish {
update = nil
}
return nil
}); err != nil {
return nil, err
}
if update == nil {
return nil, nil
}
return []imap.Update{update}, nil
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleUsedSpaceChange(usedSpace int) {
safe.Lock(func() {
if user.apiUser.UsedSpace == usedSpace {
return
}
user.apiUser.UsedSpace = usedSpace
user.eventCh.Enqueue(events.UsedSpaceChanged{
UserID: user.apiUser.ID,
UsedSpace: usedSpace,
})
}, user.apiUserLock)
}
func getMailboxName(label proton.Label) []string {
var name []string
switch label.Type {
case proton.LabelTypeFolder:
name = append([]string{folderPrefix}, label.Path...)
case proton.LabelTypeLabel:
name = append([]string{labelPrefix}, label.Path...)
case proton.LabelTypeContactGroup:
fallthrough
case proton.LabelTypeSystem:
fallthrough
default:
name = label.Path
}
return name
}
func waitOnIMAPUpdates(ctx context.Context, updates []imap.Update) error {
for _, update := range updates {
if err, ok := update.WaitContext(ctx); ok && err != nil {
return fmt.Errorf("failed to apply gluon update %v: %w", update.String(), err)
}
}
return nil
}
func (user *User) reportError(title string, err error) {
user.reportErrorNoContextCancel(title, err, reporter.Context{})
}
func (user *User) reportErrorAndMessageID(title string, err error, messgeID string) {
user.reportErrorNoContextCancel(title, err, reporter.Context{"messageID": messgeID})
}
func (user *User) reportErrorNoContextCancel(title string, err error, reportContext reporter.Context) {
if !errors.Is(err, context.Canceled) {
reportContext["error"] = err
reportContext["error_type"] = internal.ErrCauseType(err)
if rerr := user.reporter.ReportMessageWithContext(title, reportContext); rerr != nil {
user.log.WithError(err).WithField("title", title).Error("Failed to report message")
}
}
}
// safePublishMessageUpdate handles the rare case where the address' update channel may have been deleted in the same
// event. This rare case can take place if in the same event fetch request there is an update for delete address and
// create/update message.
// If the user is in combined mode, we simply push the update to the primary address. If the user is in split mode
// we do not publish the update as the address no longer exists.
func safePublishMessageUpdate(user *User, addressID string, update imap.Update) (bool, error) {
v, ok := user.updateCh[addressID]
if !ok {
if user.GetAddressMode() == vault.CombinedMode {
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
if err != nil {
return false, fmt.Errorf("failed to get primary address: %w", err)
}
primaryCh, ok := user.updateCh[primAddr.ID]
if !ok {
return false, fmt.Errorf("primary address channel is not available")
}
primaryCh.Enqueue(update)
return true, nil
}
logrus.Warnf("Update channel not found for address %v, it may have been already deleted", addressID)
_ = user.reporter.ReportMessage("Message Update channel does not exist")
return false, nil
}
v.Enqueue(update)
return true, nil
}

View File

@ -1,750 +0,0 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"bytes"
"context"
"errors"
"fmt"
"net/mail"
"sync/atomic"
"time"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/rfc5322"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
// Verify that *imapConnector implements connector.Connector.
var _ connector.Connector = (*imapConnector)(nil)
var (
defaultFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals
defaultPermanentFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals
defaultAttributes = imap.NewFlagSet() // nolint:gochecknoglobals
)
const (
folderPrefix = "Folders"
labelPrefix = "Labels"
)
type imapConnector struct {
*User
addrID string
flags, permFlags, attrs imap.FlagSet
}
func newIMAPConnector(user *User, addrID string) *imapConnector {
return &imapConnector{
User: user,
addrID: addrID,
flags: defaultFlags,
permFlags: defaultPermanentFlags,
attrs: defaultAttributes,
}
}
// Authorize returns whether the given username/password combination are valid for this connector.
func (conn *imapConnector) Authorize(ctx context.Context, username string, password []byte) bool {
addrID, err := conn.CheckAuth(username, password)
if err != nil {
return false
}
if conn.vault.AddressMode() == vault.SplitMode && addrID != conn.addrID {
return false
}
conn.User.SendConfigStatusSuccess(ctx)
return true
}
// CreateMailbox creates a label with the given name.
func (conn *imapConnector) CreateMailbox(ctx context.Context, name []string) (imap.Mailbox, error) {
defer conn.goPollAPIEvents(false)
if len(name) < 2 {
return imap.Mailbox{}, fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
}
switch name[0] {
case folderPrefix:
return conn.createFolder(ctx, name[1:])
case labelPrefix:
return conn.createLabel(ctx, name[1:])
default:
return imap.Mailbox{}, fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
}
}
func (conn *imapConnector) createLabel(ctx context.Context, name []string) (imap.Mailbox, error) {
if len(name) != 1 {
return imap.Mailbox{}, fmt.Errorf("a label cannot have children: %w", connector.ErrOperationNotAllowed)
}
return safe.LockRetErr(func() (imap.Mailbox, error) {
label, err := conn.client.CreateLabel(ctx, proton.CreateLabelReq{
Name: name[0],
Color: "#f66",
Type: proton.LabelTypeLabel,
})
if err != nil {
return imap.Mailbox{}, err
}
conn.apiLabels[label.ID] = label
return toIMAPMailbox(label, conn.flags, conn.permFlags, conn.attrs), nil
}, conn.apiLabelsLock)
}
func (conn *imapConnector) createFolder(ctx context.Context, name []string) (imap.Mailbox, error) {
return safe.LockRetErr(func() (imap.Mailbox, error) {
var parentID string
if len(name) > 1 {
for _, label := range conn.apiLabels {
if !slices.Equal(label.Path, name[:len(name)-1]) {
continue
}
parentID = label.ID
break
}
if parentID == "" {
return imap.Mailbox{}, fmt.Errorf("parent folder %q does not exist: %w", name[:len(name)-1], connector.ErrOperationNotAllowed)
}
}
label, err := conn.client.CreateLabel(ctx, proton.CreateLabelReq{
Name: name[len(name)-1],
Color: "#f66",
Type: proton.LabelTypeFolder,
ParentID: parentID,
})
if err != nil {
return imap.Mailbox{}, err
}
// Add label to list so subsequent sub folder create requests work correct.
conn.apiLabels[label.ID] = label
return toIMAPMailbox(label, conn.flags, conn.permFlags, conn.attrs), nil
}, conn.apiLabelsLock)
}
// UpdateMailboxName sets the name of the label with the given ID.
func (conn *imapConnector) UpdateMailboxName(ctx context.Context, labelID imap.MailboxID, name []string) error {
return safe.LockRet(func() error {
defer conn.goPollAPIEvents(false)
if len(name) < 2 {
return fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
}
switch name[0] {
case folderPrefix:
return conn.updateFolder(ctx, labelID, name[1:])
case labelPrefix:
return conn.updateLabel(ctx, labelID, name[1:])
default:
return fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
}
}, conn.apiLabelsLock)
}
func (conn *imapConnector) updateLabel(ctx context.Context, labelID imap.MailboxID, name []string) error {
if len(name) != 1 {
return fmt.Errorf("a label cannot have children: %w", connector.ErrOperationNotAllowed)
}
label, err := conn.client.GetLabel(ctx, string(labelID), proton.LabelTypeLabel)
if err != nil {
return err
}
update, err := conn.client.UpdateLabel(ctx, label.ID, proton.UpdateLabelReq{
Name: name[0],
Color: label.Color,
})
if err != nil {
return err
}
conn.apiLabels[label.ID] = update
return nil
}
func (conn *imapConnector) updateFolder(ctx context.Context, labelID imap.MailboxID, name []string) error {
var parentID string
if len(name) > 1 {
for _, label := range conn.apiLabels {
if !slices.Equal(label.Path, name[:len(name)-1]) {
continue
}
parentID = label.ID
break
}
if parentID == "" {
return fmt.Errorf("parent folder %q does not exist: %w", name[:len(name)-1], connector.ErrOperationNotAllowed)
}
}
label, err := conn.client.GetLabel(ctx, string(labelID), proton.LabelTypeFolder)
if err != nil {
return err
}
update, err := conn.client.UpdateLabel(ctx, string(labelID), proton.UpdateLabelReq{
Name: name[len(name)-1],
Color: label.Color,
ParentID: parentID,
})
if err != nil {
return err
}
conn.apiLabels[label.ID] = update
return nil
}
// DeleteMailbox deletes the label with the given ID.
func (conn *imapConnector) DeleteMailbox(ctx context.Context, labelID imap.MailboxID) error {
return safe.LockRet(func() error {
defer conn.goPollAPIEvents(false)
if err := conn.client.DeleteLabel(ctx, string(labelID)); err != nil {
return err
}
delete(conn.apiLabels, string(labelID))
return nil
}, conn.apiLabelsLock)
}
// CreateMessage creates a new message on the remote.
func (conn *imapConnector) CreateMessage(
ctx context.Context,
mailboxID imap.MailboxID,
literal []byte,
flags imap.FlagSet,
_ time.Time,
) (imap.Message, []byte, error) {
defer conn.goPollAPIEvents(false)
if mailboxID == proton.AllMailLabel {
return imap.Message{}, nil, connector.ErrOperationNotAllowed
}
toList, err := getLiteralToList(literal)
if err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to retrieve addresses from literal:%w", err)
}
// Compute the hash of the message (to match it against SMTP messages).
hash, err := sendrecorder.GetMessageHash(literal)
if err != nil {
return imap.Message{}, nil, err
}
// Check if we already tried to send this message recently.
if messageID, ok, err := conn.sendHash.HasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err)
} else if ok {
conn.log.WithField("messageID", messageID).Warn("Message already sent")
// Query the server-side message.
full, err := conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
}
// Build the message as it is on the server.
if err := safe.RLockRet(func() error {
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
var err error
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
return err
}
return nil
})
}, conn.apiUserLock, conn.apiAddrsLock); err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to build message: %w", err)
}
return toIMAPMessage(full.MessageMetadata), literal, nil
}
wantLabelIDs := []string{string(mailboxID)}
if flags.Contains(imap.FlagFlagged) {
wantLabelIDs = append(wantLabelIDs, proton.StarredLabel)
}
var wantFlags proton.MessageFlag
unread := !flags.Contains(imap.FlagSeen)
if mailboxID != proton.DraftsLabel {
header, err := rfc822.Parse(literal).ParseHeader()
if err != nil {
return imap.Message{}, nil, err
}
switch {
case mailboxID == proton.InboxLabel:
wantFlags = wantFlags.Add(proton.MessageFlagReceived)
case mailboxID == proton.SentLabel:
wantFlags = wantFlags.Add(proton.MessageFlagSent)
case header.Has("Received"):
wantFlags = wantFlags.Add(proton.MessageFlagReceived)
default:
wantFlags = wantFlags.Add(proton.MessageFlagSent)
}
} else {
unread = false
}
if flags.Contains(imap.FlagAnswered) {
wantFlags = wantFlags.Add(proton.MessageFlagReplied)
}
msg, literal, err := conn.importMessage(ctx, literal, wantLabelIDs, wantFlags, unread)
if err != nil {
if errors.Is(err, proton.ErrImportSizeExceeded) {
// Remap error so that Gluon does not put this message in the recovery mailbox.
err = fmt.Errorf("%v: %w", err, connector.ErrMessageSizeExceedsLimits)
}
if apiErr := new(proton.APIError); errors.As(err, &apiErr) {
logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("Failed to import message")
} else {
logrus.WithError(err).Error("Failed to import message")
}
}
return msg, literal, err
}
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
msg, err := conn.client.GetFullMessage(ctx, string(id), usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
return nil, err
}
return safe.RLockRetErr(func() ([]byte, error) {
var literal []byte
err := usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
l, buildErr := message.BuildRFC822(addrKR, msg.Message, msg.AttData, defaultJobOpts())
if buildErr != nil {
return buildErr
}
literal = l
return nil
})
return literal, err
}, conn.apiUserLock, conn.apiAddrsLock)
}
// AddMessagesToMailbox labels the given messages with the given label ID.
func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
defer conn.goPollAPIEvents(false)
if isAllMailOrScheduled(mailboxID) {
return connector.ErrOperationNotAllowed
}
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(mailboxID))
}
// RemoveMessagesFromMailbox unlabels the given messages with the given label ID.
func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
defer conn.goPollAPIEvents(false)
if isAllMailOrScheduled(mailboxID) {
return connector.ErrOperationNotAllowed
}
msgIDs := usertypes.MapTo[imap.MessageID, string](messageIDs)
if err := conn.client.UnlabelMessages(ctx, msgIDs, string(mailboxID)); err != nil {
return err
}
if mailboxID == proton.TrashLabel || mailboxID == proton.DraftsLabel {
if err := conn.client.DeleteMessage(ctx, msgIDs...); err != nil {
return err
}
}
return nil
}
// MoveMessages removes the given messages from one label and adds them to the other label.
func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.MailboxID, labelToID imap.MailboxID) (bool, error) {
defer conn.goPollAPIEvents(false)
if (labelFromID == proton.InboxLabel && labelToID == proton.SentLabel) ||
(labelFromID == proton.SentLabel && labelToID == proton.InboxLabel) ||
isAllMailOrScheduled(labelFromID) ||
isAllMailOrScheduled(labelToID) {
return false, connector.ErrOperationNotAllowed
}
shouldExpungeOldLocation := func() bool {
conn.apiLabelsLock.RLock()
defer conn.apiLabelsLock.RUnlock()
var result bool
if v, ok := conn.apiLabels[string(labelFromID)]; ok && v.Type == proton.LabelTypeLabel {
result = true
}
if v, ok := conn.apiLabels[string(labelToID)]; ok && (v.Type == proton.LabelTypeFolder || v.Type == proton.LabelTypeSystem) {
result = true
}
return result
}()
if err := conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
return false, fmt.Errorf("labeling messages: %w", err)
}
if shouldExpungeOldLocation {
if err := conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
return false, fmt.Errorf("unlabeling messages: %w", err)
}
}
return shouldExpungeOldLocation, nil
}
// MarkMessagesSeen sets the seen value of the given messages.
func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error {
defer conn.goPollAPIEvents(false)
if seen {
return conn.client.MarkMessagesRead(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
}
return conn.client.MarkMessagesUnread(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
}
// MarkMessagesFlagged sets the flagged value of the given messages.
func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error {
defer conn.goPollAPIEvents(false)
if flagged {
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
}
return conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
}
// GetUpdates returns a stream of updates that the gluon server should apply.
// It is recommended that the returned channel is buffered with at least constants.ChannelBufferCount.
func (conn *imapConnector) GetUpdates() <-chan imap.Update {
return safe.RLockRet(func() <-chan imap.Update {
return conn.updateCh[conn.addrID].GetChannel()
}, conn.updateChLock)
}
// GetMailboxVisibility returns the visibility of a mailbox over IMAP.
func (conn *imapConnector) GetMailboxVisibility(_ context.Context, mailboxID imap.MailboxID) imap.MailboxVisibility {
switch mailboxID {
case proton.AllMailLabel:
if atomic.LoadUint32(&conn.showAllMail) != 0 {
return imap.Visible
}
return imap.Hidden
case proton.AllScheduledLabel:
return imap.HiddenIfEmpty
default:
return imap.Visible
}
}
// Close the connector will no longer be used and all resources should be closed/released.
func (conn *imapConnector) Close(_ context.Context) error {
return nil
}
func (conn *imapConnector) importMessage(
ctx context.Context,
literal []byte,
labelIDs []string,
flags proton.MessageFlag,
unread bool,
) (imap.Message, []byte, error) {
var full proton.FullMessage
if err := safe.RLockRet(func() error {
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
var messageID string
if slices.Contains(labelIDs, proton.DraftsLabel) {
msg, err := conn.createDraft(ctx, literal, addrKR, conn.apiAddrs[conn.addrID])
if err != nil {
return fmt.Errorf("failed to create draft: %w", err)
}
// apply labels
messageID = msg.ID
} else {
str, err := conn.client.ImportMessages(ctx, addrKR, 1, 1, []proton.ImportReq{{
Metadata: proton.ImportMetadata{
AddressID: conn.addrID,
LabelIDs: labelIDs,
Unread: proton.Bool(unread),
Flags: flags,
},
Message: literal,
}}...)
if err != nil {
return fmt.Errorf("failed to prepare message for import: %w", err)
}
res, err := stream.Collect(ctx, str)
if err != nil {
return fmt.Errorf("failed to import message: %w", err)
}
messageID = res[0].MessageID
}
var err error
if full, err = conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
return fmt.Errorf("failed to fetch message: %w", err)
}
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
return fmt.Errorf("failed to build message: %w", err)
}
return nil
})
}, conn.apiUserLock, conn.apiAddrsLock); err != nil {
return imap.Message{}, nil, err
}
return toIMAPMessage(full.MessageMetadata), literal, nil
}
func toIMAPMessage(message proton.MessageMetadata) imap.Message {
flags := buildFlagSetFromMessageMetadata(message)
var date time.Time
if message.Time > 0 {
date = time.Unix(message.Time, 0)
} else {
date = time.Now()
}
return imap.Message{
ID: imap.MessageID(message.ID),
Flags: flags,
Date: date,
}
}
func (conn *imapConnector) createDraft(ctx context.Context, literal []byte, addrKR *crypto.KeyRing, sender proton.Address) (proton.Message, error) {
// Create a new message parser from the reader.
parser, err := parser.New(bytes.NewReader(literal))
if err != nil {
return proton.Message{}, fmt.Errorf("failed to create parser: %w", err)
}
message, err := message.ParseWithParser(parser, true)
if err != nil {
return proton.Message{}, fmt.Errorf("failed to parse message: %w", err)
}
decBody := string(message.PlainBody)
if message.RichBody != "" {
decBody = string(message.RichBody)
}
draft, err := conn.client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{
Message: proton.DraftTemplate{
Subject: message.Subject,
Body: decBody,
MIMEType: message.MIMEType,
Sender: &mail.Address{Name: sender.DisplayName, Address: sender.Email},
ToList: message.ToList,
CCList: message.CCList,
BCCList: message.BCCList,
ExternalID: message.ExternalID,
},
})
if err != nil {
return proton.Message{}, fmt.Errorf("failed to create draft: %w", err)
}
for _, att := range message.Attachments {
disposition := proton.AttachmentDisposition
if att.Disposition == "inline" && att.ContentID != "" {
disposition = proton.InlineDisposition
}
if _, err := conn.client.UploadAttachment(ctx, addrKR, proton.CreateAttachmentReq{
MessageID: draft.ID,
Filename: att.Name,
MIMEType: rfc822.MIMEType(att.MIMEType),
Disposition: disposition,
ContentID: att.ContentID,
Body: att.Data,
}); err != nil {
return proton.Message{}, fmt.Errorf("failed to add attachment to draft: %w", err)
}
}
return draft, nil
}
func toIMAPMailbox(label proton.Label, flags, permFlags, attrs imap.FlagSet) imap.Mailbox {
if label.Type == proton.LabelTypeLabel {
label.Path = append([]string{labelPrefix}, label.Path...)
} else if label.Type == proton.LabelTypeFolder {
label.Path = append([]string{folderPrefix}, label.Path...)
}
return imap.Mailbox{
ID: imap.MailboxID(label.ID),
Name: label.Path,
Flags: flags,
PermanentFlags: permFlags,
Attributes: attrs,
}
}
func isAllMailOrScheduled(mailboxID imap.MailboxID) bool {
return (mailboxID == proton.AllMailLabel) || (mailboxID == proton.AllScheduledLabel)
}
func buildFlagSetFromMessageMetadata(message proton.MessageMetadata) imap.FlagSet {
flags := imap.NewFlagSet()
if message.Seen() {
flags.AddToSelf(imap.FlagSeen)
}
if message.Starred() {
flags.AddToSelf(imap.FlagFlagged)
}
if message.IsDraft() {
flags.AddToSelf(imap.FlagDraft)
}
if message.IsRepliedAll == true || message.IsReplied == true { //nolint: gosimple
flags.AddToSelf(imap.FlagAnswered)
}
return flags
}
func getLiteralToList(literal []byte) ([]string, error) {
headerLiteral, _ := rfc822.Split(literal)
header, err := rfc822.NewHeader(headerLiteral)
if err != nil {
return nil, err
}
var result []string
parseAddress := func(field string) error {
if fieldAddr, ok := header.GetChecked(field); ok {
addr, err := rfc5322.ParseAddressList(fieldAddr)
if err != nil {
return fmt.Errorf("failed to parse addresses for '%v': %w", field, err)
}
result = append(result, xslices.Map(addr, func(addr *mail.Address) string {
return addr.Address
})...)
return nil
}
return nil
}
if err := parseAddress("To"); err != nil {
return nil, err
}
if err := parseAddress("Cc"); err != nil {
return nil, err
}
if err := parseAddress("Bcc"); err != nil {
return nil, err
}
return result, nil
}

View File

@ -36,8 +36,14 @@ func BenchmarkAddrKeyRing(b *testing.B) {
withUser(b, ctx, s, m, "username", "password", func(user *User) {
b.StartTimer()
apiUser, err := user.identityService.GetAPIUser(ctx)
require.NoError(b, err)
apiAddrs, err := user.identityService.GetAddresses(ctx)
require.NoError(b, err)
for i := 0; i < b.N; i++ {
require.NoError(b, usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
require.NoError(b, usertypes.WithAddrKRs(apiUser, apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
return nil
}))
}

View File

@ -1,918 +0,0 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"bytes"
"context"
"fmt"
"os"
"runtime"
"strings"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices"
"github.com/pbnjay/memory"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
// syncSystemLabels ensures that system labels are all known to gluon.
func (user *User) syncSystemLabels(ctx context.Context) error {
return safe.RLockRet(func() error {
var updates []imap.Update
for _, label := range xslices.Filter(maps.Values(user.apiLabels), func(label proton.Label) bool { return label.Type == proton.LabelTypeSystem }) {
if !wantLabel(label) {
continue
}
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := newSystemMailboxCreatedUpdate(imap.MailboxID(label.ID), label.Name)
updateCh.Enqueue(update)
updates = append(updates, update)
}
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return fmt.Errorf("could not sync system labels: %w", err)
}
return nil
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
// doSync begins syncing the user's data.
// It first ensures the latest event ID is known; if not, it fetches it.
// It sends a SyncStarted event and then either SyncFinished or SyncFailed
// depending on whether the sync was successful.
func (user *User) doSync(ctx context.Context) error {
if user.vault.EventID() == "" {
eventID, err := user.client.GetLatestEventID(ctx)
if err != nil {
return fmt.Errorf("failed to get latest event ID: %w", err)
}
if err := user.vault.SetEventID(eventID); err != nil {
return fmt.Errorf("failed to set latest event ID: %w", err)
}
}
start := time.Now()
user.log.WithField("start", start).Info("Beginning user sync")
user.eventCh.Enqueue(events.SyncStarted{
UserID: user.ID(),
})
if err := user.sync(ctx); err != nil {
user.log.WithError(err).Warn("Failed to sync user")
user.eventCh.Enqueue(events.SyncFailed{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to sync: %w", err)
}
user.log.WithField("duration", time.Since(start)).Info("Finished user sync")
user.eventCh.Enqueue(events.SyncFinished{
UserID: user.ID(),
})
return nil
}
func (user *User) sync(ctx context.Context) error {
return safe.RLockRet(func() error {
return usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
if !user.vault.SyncStatus().HasLabels {
user.log.Info("Syncing labels")
if err := syncLabels(ctx, user.apiLabels, xslices.Unique(maps.Values(user.updateCh))...); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := user.vault.SetHasLabels(true); err != nil {
return fmt.Errorf("failed to set has labels: %w", err)
}
user.log.Info("Synced labels")
} else {
user.log.Info("Labels are already synced, skipping")
}
if !user.vault.SyncStatus().HasMessages {
user.log.Info("Syncing messages")
// Determine which messages to sync.
messageIDs, err := user.client.GetMessageIDs(ctx, "")
if err != nil {
return fmt.Errorf("failed to get message IDs to sync: %w", err)
}
logrus.Debugf("User has the following failed synced message ids: %v", user.vault.SyncStatus().FailedMessageIDs)
// Remove any messages that have already failed to sync.
messageIDs = xslices.Filter(messageIDs, func(messageID string) bool {
return !slices.Contains(user.vault.SyncStatus().FailedMessageIDs, messageID)
})
// Reverse the order of the message IDs so that the newest messages are synced first.
xslices.Reverse(messageIDs)
// If we have a message ID that we've already synced, then we can skip all messages before it.
if idx := xslices.Index(messageIDs, user.vault.SyncStatus().LastMessageID); idx >= 0 {
messageIDs = messageIDs[idx+1:]
}
// Sync the messages.
if err := user.syncMessages(
ctx,
user.ID(),
messageIDs,
user.client,
user.reporter,
user.vault,
user.apiLabels,
addrKRs,
user.updateCh,
user.eventCh,
user.maxSyncMemory,
); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
if err := user.vault.SetHasMessages(true); err != nil {
return fmt.Errorf("failed to set has messages: %w", err)
}
user.log.Info("Synced messages")
} else {
user.log.Info("Messages are already synced, skipping")
}
return nil
})
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
// nolint:exhaustive
func syncLabels(ctx context.Context, apiLabels map[string]proton.Label, updateCh ...*async.QueuedChannel[imap.Update]) error {
var updates []imap.Update
// Create placeholder Folders/Labels mailboxes with the \Noselect attribute.
for _, prefix := range []string{folderPrefix, labelPrefix} {
for _, updateCh := range updateCh {
update := newPlaceHolderMailboxCreatedUpdate(prefix)
updateCh.Enqueue(update)
updates = append(updates, update)
}
}
// Sync the user's labels.
for labelID, label := range apiLabels {
if !wantLabel(label) {
continue
}
switch label.Type {
case proton.LabelTypeSystem:
for _, updateCh := range updateCh {
update := newSystemMailboxCreatedUpdate(imap.MailboxID(label.ID), label.Name)
updateCh.Enqueue(update)
updates = append(updates, update)
}
case proton.LabelTypeFolder, proton.LabelTypeLabel:
for _, updateCh := range updateCh {
update := newMailboxCreatedUpdate(imap.MailboxID(labelID), getMailboxName(label))
updateCh.Enqueue(update)
updates = append(updates, update)
}
default:
return fmt.Errorf("unknown label type: %d", label.Type)
}
}
// Wait for all label updates to be applied.
for _, update := range updates {
err, ok := update.WaitContext(ctx)
if ok && err != nil {
return fmt.Errorf("failed to apply label create update in gluon %v: %w", update.String(), err)
}
}
return nil
}
const Kilobyte = uint64(1024)
const Megabyte = 1024 * Kilobyte
const Gigabyte = 1024 * Megabyte
func toMB(v uint64) float64 {
return float64(v) / float64(Megabyte)
}
type syncLimits struct {
MaxDownloadRequestMem uint64
MinDownloadRequestMem uint64
MaxMessageBuildingMem uint64
MinMessageBuildingMem uint64
MaxSyncMemory uint64
MaxParallelDownloads int
}
func newSyncLimits(maxSyncMemory uint64) syncLimits {
limits := syncLimits{
// There's no point in using more than 128MB of download data per stage, after that we reach a point of diminishing
// returns as we can't keep the pipeline fed fast enough.
MaxDownloadRequestMem: 128 * Megabyte,
// Any lower than this and we may fail to download messages.
MinDownloadRequestMem: 40 * Megabyte,
// This value can be increased to your hearts content. The more system memory the user has, the more messages
// we can build in parallel.
MaxMessageBuildingMem: 128 * Megabyte,
MinMessageBuildingMem: 64 * Megabyte,
// Maximum recommend value for parallel downloads by the API team.
MaxParallelDownloads: 20,
MaxSyncMemory: maxSyncMemory,
}
if _, ok := os.LookupEnv("BRIDGE_SYNC_FORCE_MINIMUM_SPEC"); ok {
logrus.Warn("Sync specs forced to minimum")
limits.MaxDownloadRequestMem = 50 * Megabyte
limits.MaxMessageBuildingMem = 80 * Megabyte
limits.MaxParallelDownloads = 2
limits.MaxSyncMemory = 800 * Megabyte
}
return limits
}
// nolint:gocyclo
func (user *User) syncMessages(
ctx context.Context,
userID string,
messageIDs []string,
client *proton.Client,
sentry reporter.Reporter,
vault *vault.User,
apiLabels map[string]proton.Label,
addrKRs map[string]*crypto.KeyRing,
updateCh map[string]*async.QueuedChannel[imap.Update],
eventCh *async.QueuedChannel[events.Event],
cfgMaxSyncMemory uint64,
) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Track the amount of time to process all the messages.
syncStartTime := time.Now()
defer func() { logrus.WithField("duration", time.Since(syncStartTime)).Info("Message sync completed") }()
user.log.WithFields(logrus.Fields{
"messages": len(messageIDs),
"numCPU": runtime.NumCPU(),
}).Info("Starting message sync")
// Create the flushers, one per update channel.
// Create a reporter to report sync progress updates.
syncReporter := newSyncReporter(userID, eventCh, len(messageIDs), time.Second)
defer syncReporter.done()
// Expected mem usage for this whole process should be the sum of MaxMessageBuildingMem and MaxDownloadRequestMem
// times x due to pipeline and all additional memory used by network requests and compression+io.
totalMemory := memory.TotalMemory()
syncLimits := newSyncLimits(cfgMaxSyncMemory)
if syncLimits.MaxSyncMemory >= totalMemory/2 {
logrus.Warnf("Requested max sync memory of %v MB is greater than half of system memory (%v MB), forcing to half of system memory",
toMB(syncLimits.MaxSyncMemory), toMB(totalMemory/2))
syncLimits.MaxSyncMemory = totalMemory / 2
}
if syncLimits.MaxSyncMemory < 800*Megabyte {
logrus.Warnf("Requested max sync memory of %v MB, but minimum recommended is 800 MB, forcing max syncMemory to 800MB", toMB(syncLimits.MaxSyncMemory))
syncLimits.MaxSyncMemory = 800 * Megabyte
}
logrus.Debugf("Total System Memory: %v", toMB(totalMemory))
// Linter says it's not used. This is a lie.
//nolint: staticcheck
syncMaxDownloadRequestMem := syncLimits.MaxDownloadRequestMem
// Linter says it's not used. This is a lie.
//nolint: staticcheck
syncMaxMessageBuildingMem := syncLimits.MaxMessageBuildingMem
// If less than 2GB available try and limit max memory to 512 MB
switch {
case syncLimits.MaxSyncMemory < 2*Gigabyte:
if syncLimits.MaxSyncMemory < 800*Megabyte {
logrus.Warnf("System has less than 800MB of memory, you may experience issues sycing large mailboxes")
}
syncMaxDownloadRequestMem = syncLimits.MinDownloadRequestMem
syncMaxMessageBuildingMem = syncLimits.MinMessageBuildingMem
case syncLimits.MaxSyncMemory == 2*Gigabyte:
// Increasing the max download capacity has very little effect on sync speed. We could increase the download
// memory but the user would see less sync notifications. A smaller value here leads to more frequent
// updates. Additionally, most of sync time is spent in the message building.
syncMaxDownloadRequestMem = syncLimits.MaxDownloadRequestMem
// Currently limited so that if a user has multiple accounts active it also doesn't cause excessive memory usage.
syncMaxMessageBuildingMem = syncLimits.MaxMessageBuildingMem
default:
// Divide by 8 as download stage and build stage will use aprox. 4x the specified memory.
remainingMemory := (syncLimits.MaxSyncMemory - 2*Gigabyte) / 8
syncMaxDownloadRequestMem = syncLimits.MaxDownloadRequestMem + remainingMemory
syncMaxMessageBuildingMem = syncLimits.MaxMessageBuildingMem + remainingMemory
}
logrus.Debugf("Max memory usage for sync Download=%vMB Building=%vMB Predicted Max Total=%vMB",
toMB(syncMaxDownloadRequestMem),
toMB(syncMaxMessageBuildingMem),
toMB((syncMaxMessageBuildingMem*4)+(syncMaxDownloadRequestMem*4)),
)
type flushUpdate struct {
messageID string
err error
batchLen int
}
type downloadRequest struct {
ids []string
expectedSize uint64
err error
}
type downloadedMessageBatch struct {
batch []proton.FullMessage
}
type builtMessageBatch struct {
batch []*buildRes
}
downloadCh := make(chan downloadRequest)
buildCh := make(chan downloadedMessageBatch)
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
// to the update flushing goroutine.
flushCh := make(chan builtMessageBatch)
flushUpdateCh := make(chan flushUpdate)
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
// Go routine in charge of downloading message metadata
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(downloadCh)
const MetadataDataPageSize = 150
var downloadReq downloadRequest
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
metadataChunks := xslices.Chunk(messageIDs, MetadataDataPageSize)
for i, metadataChunk := range metadataChunks {
logrus.Debugf("Metadata Request (%v of %v), previous: %v", i, len(metadataChunks), len(downloadReq.ids))
metadata, err := client.GetMessageMetadataPage(ctx, 0, len(metadataChunk), proton.MessageFilter{ID: metadataChunk})
if err != nil {
logrus.WithError(err).Errorf("Failed to download message metadata for chunk %v", i)
downloadReq.err = err
select {
case downloadCh <- downloadReq:
case <-ctx.Done():
return
}
return
}
if ctx.Err() != nil {
return
}
// Build look up table so that messages are processed in the same order.
metadataMap := make(map[string]int, len(metadata))
for i, v := range metadata {
metadataMap[v.ID] = i
}
for i, id := range metadataChunk {
m := &metadata[metadataMap[id]]
nextSize := downloadReq.expectedSize + uint64(m.Size)
if nextSize >= syncMaxDownloadRequestMem || len(downloadReq.ids) >= 256 {
logrus.Debugf("Download Request Sent at %v of %v", i, len(metadata))
select {
case downloadCh <- downloadReq:
case <-ctx.Done():
return
}
downloadReq.expectedSize = 0
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
nextSize = uint64(m.Size)
}
downloadReq.ids = append(downloadReq.ids, id)
downloadReq.expectedSize = nextSize
}
}
if len(downloadReq.ids) != 0 {
logrus.Debugf("Sending remaining download request")
select {
case downloadCh <- downloadReq:
case <-ctx.Done():
return
}
}
}, logging.Labels{"sync-stage": "meta-data"})
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(buildCh)
defer close(errorCh)
defer func() {
logrus.Debugf("sync downloader exit")
}()
attachmentDownloader := user.newAttachmentDownloader(ctx, client, syncLimits.MaxParallelDownloads)
defer attachmentDownloader.close()
for request := range downloadCh {
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
if request.err != nil {
errorCh <- request.err
return
}
if ctx.Err() != nil {
errorCh <- ctx.Err()
return
}
result, err := parallel.MapContext(ctx, syncLimits.MaxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
defer async.HandlePanic(user.panicHandler)
var result proton.FullMessage
msg, err := client.GetMessage(ctx, id)
if err != nil {
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
return proton.FullMessage{}, err
}
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
if err != nil {
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
return proton.FullMessage{}, err
}
result.Message = msg
result.AttData = attachments
return result, nil
})
if err != nil {
errorCh <- err
return
}
select {
case buildCh <- downloadedMessageBatch{
batch: result,
}:
case <-ctx.Done():
return
}
}
}, logging.Labels{"sync-stage": "download"})
// Goroutine which builds messages after they have been downloaded
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(flushCh)
defer func() {
logrus.Debugf("sync builder exit")
}()
maxMessagesInParallel := runtime.NumCPU()
for buildBatch := range buildCh {
if ctx.Err() != nil {
return
}
chunks := chunkSyncBuilderBatch(buildBatch.batch, syncMaxMessageBuildingMem)
for index, chunk := range chunks {
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
defer async.HandlePanic(user.panicHandler)
kr, ok := addrKRs[msg.AddressID]
if !ok {
logrus.Errorf("Address '%v' on message '%v' does not have an unlocked kerying", msg.AddressID, msg.ID)
return &buildRes{
messageID: msg.ID,
addressID: msg.AddressID,
err: fmt.Errorf("address does not have an unlocked keyring"),
}, nil
}
res := buildRFC822(apiLabels, msg, kr, new(bytes.Buffer))
if res.err != nil {
logrus.WithError(res.err).WithField("msgID", msg.ID).Error("Failed to build message (syn)")
}
return res, nil
})
if err != nil {
return
}
select {
case flushCh <- builtMessageBatch{result}:
case <-ctx.Done():
return
}
}
}
}, logging.Labels{"sync-stage": "builder"})
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(flushUpdateCh)
defer func() {
logrus.Debugf("sync flush exit")
}()
type updateTargetInfo struct {
queueIndex int
ch *async.QueuedChannel[imap.Update]
}
pendingUpdates := make([][]*imap.MessageCreated, len(updateCh))
addressToIndex := make(map[string]updateTargetInfo)
{
i := 0
for addrID, updateCh := range updateCh {
addressToIndex[addrID] = updateTargetInfo{
ch: updateCh,
queueIndex: i,
}
i++
}
}
for downloadBatch := range flushCh {
logrus.Debugf("Flush batch: %v", len(downloadBatch.batch))
for _, res := range downloadBatch.batch {
if res.err != nil {
if err := vault.AddFailedMessageID(res.messageID); err != nil {
logrus.WithError(err).Error("Failed to add failed message ID")
}
if err := sentry.ReportMessageWithContext("Failed to build message (sync)", reporter.Context{
"messageID": res.messageID,
"error": res.err,
}); err != nil {
logrus.WithError(err).Error("Failed to report message build error")
}
// We could sync a placeholder message here, but for now we skip it entirely.
continue
}
if err := vault.RemFailedMessageID(res.messageID); err != nil {
logrus.WithError(err).Error("Failed to remove failed message ID")
}
targetInfo := addressToIndex[res.addressID]
pendingUpdates[targetInfo.queueIndex] = append(pendingUpdates[targetInfo.queueIndex], res.update)
}
for _, info := range addressToIndex {
up := imap.NewMessagesCreated(true, pendingUpdates[info.queueIndex]...)
info.ch.Enqueue(up)
err, ok := up.WaitContext(ctx)
if ok && err != nil {
flushUpdateCh <- flushUpdate{
err: fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err),
}
return
}
pendingUpdates[info.queueIndex] = pendingUpdates[info.queueIndex][:0]
}
select {
case flushUpdateCh <- flushUpdate{
messageID: downloadBatch.batch[0].messageID,
err: nil,
batchLen: len(downloadBatch.batch),
}:
case <-ctx.Done():
return
}
}
}, logging.Labels{"sync-stage": "flush"})
for flushUpdate := range flushUpdateCh {
if flushUpdate.err != nil {
return flushUpdate.err
}
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
return fmt.Errorf("failed to set last synced message ID: %w", err)
}
syncReporter.add(flushUpdate.batchLen)
}
return <-errorCh
}
func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated {
if strings.EqualFold(labelName, imap.Inbox) {
labelName = imap.Inbox
}
attrs := imap.NewFlagSet(imap.AttrNoInferiors)
permanentFlags := defaultPermanentFlags
flags := defaultFlags
switch labelID {
case proton.TrashLabel:
attrs = attrs.Add(imap.AttrTrash)
case proton.SpamLabel:
attrs = attrs.Add(imap.AttrJunk)
case proton.AllMailLabel:
attrs = attrs.Add(imap.AttrAll)
flags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged)
permanentFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged)
case proton.ArchiveLabel:
attrs = attrs.Add(imap.AttrArchive)
case proton.SentLabel:
attrs = attrs.Add(imap.AttrSent)
case proton.DraftsLabel:
attrs = attrs.Add(imap.AttrDrafts)
case proton.StarredLabel:
attrs = attrs.Add(imap.AttrFlagged)
case proton.AllScheduledLabel:
labelName = "Scheduled" // API actual name is "All Scheduled"
}
return imap.NewMailboxCreated(imap.Mailbox{
ID: labelID,
Name: []string{labelName},
Flags: flags,
PermanentFlags: permanentFlags,
Attributes: attrs,
})
}
func newPlaceHolderMailboxCreatedUpdate(labelName string) *imap.MailboxCreated {
return imap.NewMailboxCreated(imap.Mailbox{
ID: imap.MailboxID(labelName),
Name: []string{labelName},
Flags: defaultFlags,
PermanentFlags: defaultPermanentFlags,
Attributes: imap.NewFlagSet(imap.AttrNoSelect),
})
}
func newMailboxCreatedUpdate(labelID imap.MailboxID, labelName []string) *imap.MailboxCreated {
return imap.NewMailboxCreated(imap.Mailbox{
ID: labelID,
Name: labelName,
Flags: defaultFlags,
PermanentFlags: defaultPermanentFlags,
Attributes: imap.NewFlagSet(),
})
}
func wantLabel(label proton.Label) bool {
if label.Type != proton.LabelTypeSystem {
return true
}
// nolint:exhaustive
switch label.ID {
case proton.InboxLabel:
return true
case proton.TrashLabel:
return true
case proton.SpamLabel:
return true
case proton.AllMailLabel:
return true
case proton.ArchiveLabel:
return true
case proton.SentLabel:
return true
case proton.DraftsLabel:
return true
case proton.StarredLabel:
return true
case proton.AllScheduledLabel:
return true
default:
return false
}
}
func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string {
return xslices.Filter(labelIDs, func(labelID string) bool {
apiLabel, ok := apiLabels[labelID]
if !ok {
return false
}
return wantLabel(apiLabel)
})
}
type attachmentResult struct {
attachment []byte
err error
}
type attachmentJob struct {
id string
size int64
result chan attachmentResult
}
type attachmentDownloader struct {
workerCh chan attachmentJob
cancel context.CancelFunc
}
func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) {
for {
select {
case <-ctx.Done():
return
case job, ok := <-work:
if !ok {
return
}
var b bytes.Buffer
b.Grow(int(job.size))
err := client.GetAttachmentInto(ctx, job.id, &b)
select {
case <-ctx.Done():
close(job.result)
return
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
close(job.result)
}
}
}
}
func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
ctx, cancel := context.WithCancel(ctx)
for i := 0; i < workerCount; i++ {
workerCh = make(chan attachmentJob)
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
"sync": fmt.Sprintf("att-downloader %v", i),
})
}
return &attachmentDownloader{
workerCh: workerCh,
cancel: cancel,
}
}
func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
resultChs := make([]chan attachmentResult, len(attachments))
for i, id := range attachments {
resultChs[i] = make(chan attachmentResult, 1)
select {
case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
case <-ctx.Done():
return nil, ctx.Err()
}
}
result := make([][]byte, len(attachments))
var err error
for i := 0; i < len(attachments); i++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case r := <-resultChs[i]:
if r.err != nil {
err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
}
result[i] = r.attachment
}
}
return result, err
}
func (a *attachmentDownloader) close() {
a.cancel()
}
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
var expectedMemUsage uint64
var chunks [][]proton.FullMessage
var lastIndex int
var index int
for _, v := range batch {
var dataSize uint64
for _, a := range v.Attachments {
dataSize += uint64(a.Size)
}
// 2x increase for attachment due to extra memory needed for decrypting and writing
// in memory buffer.
dataSize *= 2
dataSize += uint64(len(v.Body))
nextMemSize := expectedMemUsage + dataSize
if nextMemSize >= maxMemory {
chunks = append(chunks, batch[lastIndex:index])
lastIndex = index
expectedMemUsage = dataSize
} else {
expectedMemUsage = nextMemSize
}
index++
}
if lastIndex < len(batch) {
chunks = append(chunks, batch[lastIndex:])
}
return chunks
}

View File

@ -1,174 +0,0 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"bytes"
"html/template"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
"github.com/bradenaw/juniper/xslices"
)
type buildRes struct {
messageID string
addressID string
update *imap.MessageCreated
err error
}
func defaultJobOpts() message.JobOptions {
return message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
}
}
func buildRFC822(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing, buffer *bytes.Buffer) *buildRes {
var (
update *imap.MessageCreated
err error
)
buffer.Grow(full.Size)
if buildErr := message.BuildRFC822Into(addrKR, full.Message, full.AttData, defaultJobOpts(), buffer); buildErr != nil {
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, buildErr)
err = buildErr
} else if created, parseErr := newMessageCreatedUpdate(apiLabels, full.MessageMetadata, buffer.Bytes()); parseErr != nil {
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, parseErr)
err = parseErr
} else {
update = created
}
return &buildRes{
messageID: full.ID,
addressID: full.AddressID,
update: update,
err: err,
}
}
func newMessageCreatedUpdate(
apiLabels map[string]proton.Label,
message proton.MessageMetadata,
literal []byte,
) (*imap.MessageCreated, error) {
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
return nil, err
}
return &imap.MessageCreated{
Message: toIMAPMessage(message),
Literal: literal,
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
ParsedMessage: parsedMessage,
}, nil
}
func newMessageCreatedFailedUpdate(
apiLabels map[string]proton.Label,
message proton.MessageMetadata,
err error,
) *imap.MessageCreated {
literal := newFailedMessageLiteral(message.ID, time.Unix(message.Time, 0), message.Subject, err)
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
panic(err)
}
return &imap.MessageCreated{
Message: toIMAPMessage(message),
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
Literal: literal,
ParsedMessage: parsedMessage,
}
}
func newFailedMessageLiteral(
messageID string,
date time.Time,
subject string,
syncErr error,
) []byte {
var buf bytes.Buffer
if tmpl, err := template.New("header").Parse(failedMessageHeaderTemplate); err != nil {
panic(err)
} else if b, err := tmplExec(tmpl, map[string]any{
"Date": date.In(time.UTC).Format(time.RFC822),
}); err != nil {
panic(err)
} else if _, err := buf.Write(b); err != nil {
panic(err)
}
if tmpl, err := template.New("body").Parse(failedMessageBodyTemplate); err != nil {
panic(err)
} else if b, err := tmplExec(tmpl, map[string]any{
"MessageID": messageID,
"Subject": subject,
"Error": syncErr.Error(),
}); err != nil {
panic(err)
} else if _, err := buf.Write(lineWrap(algo.B64Encode(b))); err != nil {
panic(err)
}
return buf.Bytes()
}
func tmplExec(template *template.Template, data any) ([]byte, error) {
var buf bytes.Buffer
if err := template.Execute(&buf, data); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func lineWrap(b []byte) []byte {
return bytes.Join(xslices.Chunk(b, 76), []byte{'\r', '\n'})
}
const failedMessageHeaderTemplate = `Date: {{.Date}}
Subject: Message failed to build
Content-Type: text/plain
Content-Transfer-Encoding: base64
`
const failedMessageBodyTemplate = `Failed to build message:
Subject: {{.Subject}}
Error: {{.Error}}
MessageID: {{.MessageID}}
`

View File

@ -1,80 +0,0 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"errors"
"testing"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require"
)
func TestNewFailedMessageLiteral(t *testing.T) {
literal := newFailedMessageLiteral("abcd-efgh", time.Unix(123456789, 0), "subject", errors.New("oops"))
header, err := rfc822.Parse(literal).ParseHeader()
require.NoError(t, err)
require.Equal(t, "Message failed to build", header.Get("Subject"))
require.Equal(t, "29 Nov 73 21:33 UTC", header.Get("Date"))
require.Equal(t, "text/plain", header.Get("Content-Type"))
require.Equal(t, "base64", header.Get("Content-Transfer-Encoding"))
b, err := rfc822.Parse(literal).DecodedBody()
require.NoError(t, err)
require.Equal(t, string(b), "Failed to build message: \nSubject: subject\nError: oops\nMessageID: abcd-efgh\n")
parsed, err := imap.NewParsedMessage(literal)
require.NoError(t, err)
require.Equal(t, `("29 Nov 73 21:33 UTC" "Message failed to build" NIL NIL NIL NIL NIL NIL NIL NIL)`, parsed.Envelope)
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2)`, parsed.Body)
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2 NIL NIL NIL NIL)`, parsed.Structure)
}
func TestSyncChunkSyncBuilderBatch(t *testing.T) {
// GODT-2424 - Some messages were not fully built due to a bug in the chunking if the total memory used by the
// message would be higher than the maximum we allowed.
const totalMessageCount = 100
msg := proton.FullMessage{
Message: proton.Message{
Attachments: []proton.Attachment{
{
Size: int64(8 * Megabyte),
},
},
},
AttData: nil,
}
messages := xslices.Repeat(msg, totalMessageCount)
chunks := chunkSyncBuilderBatch(messages, 16*Megabyte)
var totalMessagesInChunks int
for _, v := range chunks {
totalMessagesInChunks += len(v)
}
require.Equal(t, totalMessagesInChunks, totalMessageCount)
}

View File

@ -1,72 +0,0 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
)
type syncReporter struct {
userID string
eventCh *async.QueuedChannel[events.Event]
start time.Time
total int
count int
last time.Time
freq time.Duration
}
func newSyncReporter(userID string, eventCh *async.QueuedChannel[events.Event], total int, freq time.Duration) *syncReporter {
return &syncReporter{
userID: userID,
eventCh: eventCh,
start: time.Now(),
total: total,
freq: freq,
}
}
func (rep *syncReporter) add(delta int) {
rep.count += delta
if time.Since(rep.last) > rep.freq {
rep.eventCh.Enqueue(events.SyncProgress{
UserID: rep.userID,
Progress: float64(rep.count) / float64(rep.total),
Elapsed: time.Since(rep.start),
Remaining: time.Since(rep.start) * time.Duration(rep.total-(rep.count+1)) / time.Duration(rep.count+1),
})
rep.last = time.Now()
}
}
func (rep *syncReporter) done() {
rep.eventCh.Enqueue(events.SyncProgress{
UserID: rep.userID,
Progress: 1,
Elapsed: time.Since(rep.start),
Remaining: 0,
})
}

View File

@ -19,27 +19,19 @@ package user
import (
"context"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal"
"github.com/ProtonMail/proton-bridge/v3/internal/configstatus"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
"github.com/ProtonMail/proton-bridge/v3/internal/services/orderedtasks"
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
"github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
@ -65,6 +57,7 @@ const (
)
type User struct {
id string
log *logrus.Entry
vault *vault.User
@ -75,27 +68,7 @@ type User struct {
eventCh *async.QueuedChannel[events.Event]
eventLock safe.RWMutex
apiUser proton.User
apiUserLock safe.RWMutex
apiAddrs map[string]proton.Address
apiAddrsLock safe.RWMutex
apiLabels map[string]proton.Label
apiLabelsLock safe.RWMutex
updateCh map[string]*async.QueuedChannel[imap.Update]
updateChLock safe.RWMutex
tasks *async.Group
syncAbort async.Abortable
pollAbort async.Abortable
goSync func()
pollAPIEventsCh chan chan struct{}
goPollAPIEvents func(wait bool)
showAllMail uint32
maxSyncMemory uint64
@ -108,9 +81,11 @@ type User struct {
eventService *userevents.Service
identityService *useridentity.Service
smtpService *smtp.Service
imapService *imapservice.Service
serviceGroup *orderedtasks.OrderedCancelGroup
}
// New returns a new user.
func New(
ctx context.Context,
encVault *vault.User,
@ -122,6 +97,49 @@ func New(
maxSyncMemory uint64,
statsDir string,
telemetryManager telemetry.Availability,
serverManager imapservice.IMAPServerManager,
eventSubscription events.Subscription,
) (*User, error) {
user, err := newImpl(
ctx,
encVault,
client,
reporter,
apiUser,
crashHandler,
showAllMail,
maxSyncMemory,
statsDir,
telemetryManager,
serverManager,
eventSubscription,
)
if err != nil {
// Cleanup any pending resources on error
if user != nil {
user.Close()
}
return nil, err
}
return user, nil
}
// New returns a new user.
func newImpl(
ctx context.Context,
encVault *vault.User,
client *proton.Client,
reporter reporter.Reporter,
apiUser proton.User,
crashHandler async.PanicHandler,
showAllMail bool,
maxSyncMemory uint64,
statsDir string,
telemetryManager telemetry.Availability,
serverManager imapservice.IMAPServerManager,
eventSubscription events.Subscription,
) (*User, error) {
logrus.WithField("userID", apiUser.ID).Info("Creating new user")
@ -137,7 +155,7 @@ func New(
return nil, fmt.Errorf("failed to get labels: %w", err)
}
identityState := useridentity.NewState(apiUser, slices.Clone(apiAddrs), client)
identityState := useridentity.NewState(apiUser, apiAddrs, client)
logrus.WithFields(logrus.Fields{
"userID": apiUser.ID,
@ -151,31 +169,12 @@ func New(
return nil, fmt.Errorf("failed to init configuration status file: %w", err)
}
// Use null publisher for now to avoid conflicts with original event loop.
eventPublisher := &events.NullEventPublisher{}
// Use in memory store to avoid conflicts with original event loop.
idStore := userevents.NewInMemoryEventIDStore()
_ = idStore.Store(context.Background(), encVault.EventID())
eventService := userevents.NewService(
apiUser.ID,
client,
// Use in memory store to avoid conflicts with the original event loop.
idStore,
eventPublisher,
EventPeriod,
5*time.Minute,
crashHandler,
)
sendRecorder := sendrecorder.NewSendRecorder(sendrecorder.SendEntryExpiry)
identityService := useridentity.NewService(eventService, eventPublisher, identityState)
// Create the user object.
user := &User{
log: logrus.WithField("userID", apiUser.ID),
id: apiUser.ID,
vault: encVault,
client: client,
@ -185,22 +184,7 @@ func New(
eventCh: async.NewQueuedChannel[events.Event](0, 0, crashHandler, fmt.Sprintf("bridge-user-%v", apiUser.ID)),
eventLock: safe.NewRWMutex(),
apiUser: apiUser,
apiUserLock: safe.NewRWMutex(),
apiAddrs: usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
apiAddrsLock: safe.NewRWMutex(),
apiLabels: usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID }),
apiLabelsLock: safe.NewRWMutex(),
updateCh: make(map[string]*async.QueuedChannel[imap.Update]),
updateChLock: safe.NewRWMutex(),
tasks: async.NewGroup(context.Background(), crashHandler),
pollAPIEventsCh: make(chan chan struct{}),
showAllMail: b32(showAllMail),
maxSyncMemory: maxSyncMemory,
@ -209,11 +193,24 @@ func New(
configStatus: configStatus,
telemetryManager: telemetryManager,
identityService: identityService,
serviceGroup: orderedtasks.NewOrderedCancelGroup(crashHandler),
smtpService: nil,
eventService: eventService,
}
user.eventService = userevents.NewService(
apiUser.ID,
client,
userevents.NewVaultEventIDStore(encVault),
user,
EventPeriod,
5*time.Minute,
crashHandler,
)
addressMode := usertypes.VaultToAddressMode(encVault.AddressMode())
user.identityService = useridentity.NewService(user.eventService, user, identityState, encVault, user)
user.smtpService = smtp.NewService(
apiUser.ID,
client,
@ -223,20 +220,37 @@ func New(
encVault,
encVault,
user,
eventService,
usertypes.VaultToAddressMode(encVault.AddressMode()),
user.eventService,
addressMode,
identityState.Clone(),
)
user.imapService = imapservice.NewService(
client,
identityState.Clone(),
user,
encVault,
user.eventService,
serverManager,
user,
encVault,
encVault,
crashHandler,
sendRecorder,
user,
reporter,
addressMode,
eventSubscription,
user.maxSyncMemory,
showAllMail,
)
// Check for status_progress when triggered.
user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) {
user.SendConfigStatusProgress(ctx)
})
defer user.goStatusProgress()
// Initialize the user's update channels for its current address mode.
user.initUpdateCh(encVault.AddressMode())
// When we receive an auth object, we update it in the vault.
// This will be used to authorize the user on the next run.
user.client.AddAuthHandler(func(auth proton.Auth) {
@ -259,119 +273,83 @@ func New(
return nil
})
// When triggered, poll the API for events, optionally blocking until the poll is complete.
user.goPollAPIEvents = func(wait bool) {
doneCh := make(chan struct{})
go func() {
defer async.HandlePanic(user.panicHandler)
user.pollAPIEventsCh <- doneCh
}()
if wait {
<-doneCh
}
}
// When triggered, sync the user and then begin streaming API events.
user.goSync = user.tasks.Trigger(func(ctx context.Context) {
user.log.Info("Sync triggered")
// Sync the user.
user.syncAbort.Do(ctx, func(ctx context.Context) {
if user.vault.SyncStatus().IsComplete() {
user.log.Info("Sync already complete, only system label will be updated")
if err := user.syncSystemLabels(ctx); err != nil {
user.log.WithError(err).Error("Failed to update system labels")
return
}
user.log.Info("System label update complete, starting API event stream")
return
}
for {
if err := ctx.Err(); err != nil {
user.log.WithError(err).Error("Sync aborted")
return
} else if err := user.doSync(ctx); err != nil {
user.log.WithError(err).Error("Failed to sync, will retry later")
sleepCtx(ctx, SyncRetryCooldown)
} else {
user.log.Info("Sync complete, starting API event stream")
return
}
}
})
// Once we know the sync has completed, we can start polling for API events.
if user.vault.SyncStatus().IsComplete() {
user.pollAbort.Do(ctx, func(ctx context.Context) {
user.startEvents(ctx)
})
}
})
// Start Event Service
if err := user.eventService.Start(ctx, user.tasks); err != nil {
return nil, fmt.Errorf("failed to start event service: %w", err)
if err := user.eventService.Start(ctx, user.serviceGroup); err != nil {
return user, fmt.Errorf("failed to start event service: %w", err)
}
// Start Identity Service
user.identityService.Start(user.tasks)
user.identityService.Start(ctx, user.serviceGroup)
// Start SMTP Service
user.smtpService.Start(user.tasks)
user.smtpService.Start(ctx, user.serviceGroup)
if err := user.eventService.Resume(ctx); err != nil {
return nil, fmt.Errorf("failed to resume event service")
// Start IMAP Service
if err := user.imapService.Start(ctx, user.serviceGroup); err != nil {
return user, fmt.Errorf("failed to start imap service: %w", err)
}
return user, nil
}
func (user *User) TriggerSync() {
user.goSync()
}
// ID returns the user's ID.
func (user *User) ID() string {
return safe.RLockRet(func() string {
return user.apiUser.ID
}, user.apiUserLock)
return user.id
}
// Name returns the user's username.
func (user *User) Name() string {
return safe.RLockRet(func() string {
return user.apiUser.Name
}, user.apiUserLock)
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()
apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return ""
}
return apiUser.Name
}
// Match matches the given query against the user's username and email addresses.
func (user *User) Match(query string) bool {
return safe.RLockRet(func() bool {
if query == user.apiUser.Name {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()
apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return false
}
if query == apiUser.Name {
return true
}
for _, addr := range user.apiAddrs {
apiAddrs, err := user.identityService.GetAddresses(ctx)
if err != nil {
return false
}
for _, addr := range apiAddrs {
if query == addr.Email {
return true
}
}
return false
}, user.apiUserLock, user.apiAddrsLock)
}
// Emails returns all the user's active email addresses.
// It returns them in sorted order; the user's primary address is first.
func (user *User) Emails() []string {
return safe.RLockRet(func() []string {
addresses := xslices.Filter(maps.Values(user.apiAddrs), func(addr proton.Address) bool {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()
apiAddrs, err := user.identityService.GetAddresses(ctx)
if err != nil {
return nil
}
addresses := xslices.Filter(maps.Values(apiAddrs), func(addr proton.Address) bool {
return addr.Status == proton.AddressStatusEnabled && addr.Type != proton.AddressTypeExternal
})
@ -382,7 +360,6 @@ func (user *User) Emails() []string {
return xslices.Map(addresses, func(addr proton.Address) string {
return addr.Email
})
}, user.apiAddrsLock)
}
// GetAddressMode returns the user's current address mode.
@ -394,10 +371,6 @@ func (user *User) GetAddressMode() vault.AddressMode {
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
user.log.WithField("mode", mode).Info("Setting address mode")
user.syncAbort.Abort()
user.pollAbort.Abort()
return safe.LockRet(func() error {
if err := user.vault.SetAddressMode(mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
@ -406,36 +379,50 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
return fmt.Errorf("failed to set smtp address mode: %w", err)
}
if err := user.clearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
if err := user.imapService.SetAddressMode(ctx, usertypes.VaultToAddressMode(mode)); err != nil {
return fmt.Errorf("failed to imap address mode: %w", err)
}
return nil
}, user.eventLock, user.apiAddrsLock, user.updateChLock)
}
// CancelSyncAndEventPoll stops the sync or event poll go-routine.
func (user *User) CancelSyncAndEventPoll() {
user.syncAbort.Abort()
user.pollAbort.Abort()
}
// BadEventFeedbackResync sends user feedback whether should do message re-sync.
func (user *User) BadEventFeedbackResync(ctx context.Context) {
user.CancelSyncAndEventPoll()
func (user *User) BadEventFeedbackResync(ctx context.Context) error {
if err := user.imapService.OnBadEventResync(ctx); err != nil {
return fmt.Errorf("failed to execute bad event request on imap service: %w", err)
}
// We need to cancel the event poll later again as it is not guaranteed, due to timing, that we have a
// task to cancel.
if err := user.syncUserAddressesLabelsAndClearSync(ctx, true); err != nil {
user.log.WithError(err).Error("Bad event resync failed")
if err := user.identityService.Resync(ctx); err != nil {
return fmt.Errorf("failed to resync identity service: %w", err)
}
if err := user.smtpService.Resync(ctx); err != nil {
return fmt.Errorf("failed to resync smtp service: %w", err)
}
if err := user.imapService.Resync(ctx); err != nil {
return fmt.Errorf("failed to resync imap service: %w", err)
}
return nil
}
func (user *User) OnBadEvent(ctx context.Context) {
if err := user.imapService.OnBadEvent(ctx); err != nil {
user.log.WithError(err).Error("Failed to notify imap service of bad event")
}
}
// SetShowAllMail sets whether to show the All Mail mailbox.
func (user *User) SetShowAllMail(show bool) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()
user.log.WithField("show", show).Info("Setting show all mail")
atomic.StoreUint32(&user.showAllMail, b32(show))
if err := user.imapService.ShowAllMail(ctx, show); err != nil {
user.log.WithError(err).Error("Failed to set show all mail")
}
}
// GetGluonIDs returns the users gluon IDs.
@ -498,16 +485,28 @@ func (user *User) BridgePass() []byte {
// UsedSpace returns the total space used by the user on the API.
func (user *User) UsedSpace() int {
return safe.RLockRet(func() int {
return user.apiUser.UsedSpace
}, user.apiUserLock)
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()
apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return 0
}
return apiUser.UsedSpace
}
// MaxSpace returns the amount of space the user can use on the API.
func (user *User) MaxSpace() int {
return safe.RLockRet(func() int {
return user.apiUser.MaxSpace
}, user.apiUserLock)
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()
apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return 0
}
return apiUser.MaxSpace
}
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change).
@ -515,118 +514,32 @@ func (user *User) GetEventCh() <-chan events.Event {
return user.eventCh.GetChannel()
}
// NewIMAPConnector returns an IMAP connector for the given address.
// If not in split mode, this must be the primary address.
func (user *User) NewIMAPConnector(addrID string) connector.Connector {
return newIMAPConnector(user, addrID)
}
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
// In combined mode, this is just the user's primary address.
// In split mode, this is all the user's addresses.
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
return safe.RLockRetErr(func() (map[string]connector.Connector, error) {
imapConn := make(map[string]connector.Connector)
switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := usertypes.GetAddrIdx(user.apiAddrs, 0)
if err != nil {
return nil, fmt.Errorf("failed to get primary address: %w", err)
}
imapConn[primAddr.ID] = newIMAPConnector(user, primAddr.ID)
case vault.SplitMode:
for addrID := range user.apiAddrs {
imapConn[addrID] = newIMAPConnector(user, addrID)
}
}
return imapConn, nil
}, user.apiAddrsLock)
}
// CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user.
// It returns the address ID of the authenticated address.
func (user *User) CheckAuth(email string, password []byte) (string, error) {
user.log.WithField("email", logging.Sensitive(email)).Debug("Checking authentication")
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()
if email == "crash@bandicoot" {
panic("your wish is my command.. I crash")
}
dec, err := algo.B64RawDecode(password)
if err != nil {
return "", fmt.Errorf("failed to decode password: %w", err)
}
if subtle.ConstantTimeCompare(user.vault.BridgePass(), dec) != 1 {
err := fmt.Errorf("invalid password")
user.ReportConfigStatusFailure(err.Error())
return "", err
}
return safe.RLockRetErr(func() (string, error) {
for _, addr := range user.apiAddrs {
if addr.Status != proton.AddressStatusEnabled {
continue
}
if strings.EqualFold(addr.Email, email) {
return addr.ID, nil
}
}
return "", fmt.Errorf("invalid email")
}, user.apiAddrsLock)
return user.identityService.CheckAuth(ctx, email, password)
}
// OnStatusUp is called when the connection goes up.
func (user *User) OnStatusUp(context.Context) {
func (user *User) OnStatusUp(ctx context.Context) {
user.log.Info("Connection is up")
user.goSync()
if err := user.imapService.ResumeSync(ctx); err != nil {
user.log.WithError(err).Error("Failed to resume sync")
}
}
// OnStatusDown is called when the connection goes down.
func (user *User) OnStatusDown(context.Context) {
func (user *User) OnStatusDown(ctx context.Context) {
user.log.Info("Connection is down")
user.syncAbort.Abort()
user.pollAbort.Abort()
user.eventService.Pause()
if err := user.imapService.CancelSync(ctx); err != nil {
user.log.WithError(err).Error("Failed to cancel sync")
}
// GetSyncStatus returns the sync status of the user.
func (user *User) GetSyncStatus() vault.SyncStatus {
return user.vault.GetSyncStatus()
}
// ClearSyncStatus clears the sync status of the user.
// This also drops any updates in the update channel(s).
// Warning: the gluon user must be removed and re-added if this happens!
func (user *User) ClearSyncStatus() error {
user.log.Info("Clearing sync status")
return safe.LockRet(func() error {
return user.clearSyncStatus()
}, user.eventLock, user.apiAddrsLock, user.updateChLock)
}
// clearSyncStatus clears the sync status of the user.
// This also drops any updates in the update channel(s).
// Warning: the gluon user must be removed and re-added if this happens!
// It is assumed that the eventLock, apiAddrsLock and updateChLock are already locked.
func (user *User) clearSyncStatus() error {
user.log.Info("Clearing sync status")
user.initUpdateCh(user.vault.AddressMode())
if err := user.vault.ClearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
}
return nil
}
// Logout logs the user out from the API.
@ -635,6 +548,19 @@ func (user *User) Logout(ctx context.Context, withAPI bool) error {
user.log.Debug("Canceling ongoing tasks")
if err := user.imapService.OnLogout(ctx); err != nil {
return fmt.Errorf("failed to remove user from server: %w", err)
}
// Stop Services
user.serviceGroup.CancelAndWait()
// Cleanup Event Service.
user.eventService.Close()
// Close imap service.
user.imapService.Close()
user.tasks.CancelAndWait()
if withAPI {
@ -658,27 +584,24 @@ func (user *User) Logout(ctx context.Context, withAPI bool) error {
func (user *User) Close() {
user.log.Info("Closing user")
// Stop Services
user.serviceGroup.CancelAndWait()
// Cleanup Event Service.
user.eventService.Close()
// Close imap service.
user.imapService.Close()
// Stop any ongoing background tasks.
user.tasks.CancelAndWait()
// Close the user's API client.
user.client.Close()
// Close the user's update channels.
safe.Lock(func() {
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
updateCh.CloseAndDiscardQueued()
}
user.updateCh = make(map[string]*async.QueuedChannel[imap.Update])
}, user.updateChLock)
// Close the user's notify channel.
user.eventCh.CloseAndDiscardQueued()
// Cleanup Event Service.
user.eventService.Close()
// Close the user's vault.
if err := user.vault.Close(); err != nil {
user.log.WithError(err).Error("Failed to close vault")
@ -715,12 +638,6 @@ func (user *User) SendTelemetry(ctx context.Context, data []byte) error {
return nil
}
func (user *User) WithSMTPData(ctx context.Context, op func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error {
return safe.RLockRet(func() error {
return op(ctx, user.apiAddrs, user.apiUser, user.vault)
}, user.apiUserLock, user.apiAddrsLock, user.eventLock)
}
func (user *User) ReportSMTPAuthFailed(username string) {
emails := user.Emails()
for _, mail := range emails {
@ -738,182 +655,6 @@ func (user *User) GetSMTPService() *smtp.Service {
return user.smtpService
}
// initUpdateCh initializes the user's update channels in the given address mode.
// It is assumed that user.apiAddrs and user.updateCh are already locked.
func (user *User) initUpdateCh(mode vault.AddressMode) {
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
updateCh.CloseAndDiscardQueued()
}
user.updateCh = make(map[string]*async.QueuedChannel[imap.Update])
switch mode {
case vault.CombinedMode:
primaryUpdateCh := async.NewQueuedChannel[imap.Update](
0,
0,
user.panicHandler,
"user-update-combined",
)
for addrID := range user.apiAddrs {
user.updateCh[addrID] = primaryUpdateCh
}
case vault.SplitMode:
for addrID := range user.apiAddrs {
user.updateCh[addrID] = async.NewQueuedChannel[imap.Update](
0,
0,
user.panicHandler,
fmt.Sprintf("user-update-split-%v", addrID),
)
}
}
}
// startEvents streams events from the API, logging any errors that occur.
// This does nothing until the sync has been marked as complete.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
func (user *User) startEvents(ctx context.Context) {
ticker := proton.NewTicker(EventPeriod, EventJitter, user.panicHandler)
defer ticker.Stop()
for {
var doneCh chan struct{}
select {
case <-ctx.Done():
return
case doneCh = <-user.pollAPIEventsCh:
// ...
case <-ticker.C:
// ...
}
user.log.Debug("Event poll triggered")
if err := user.doEventPoll(ctx); err != nil {
user.log.WithError(err).Error("Failed to poll events")
}
if doneCh != nil {
close(doneCh)
}
}
}
// doEventPoll is called whenever API events should be polled.
func (user *User) doEventPoll(ctx context.Context) error {
user.eventLock.Lock()
defer user.eventLock.Unlock()
gpaEvents, more, err := user.client.GetEvent(ctx, user.vault.EventID())
if err != nil {
return fmt.Errorf("failed to get event (caused by %T): %w", internal.ErrCause(err), err)
}
// If the event ID hasn't changed, there are no new events.
if gpaEvents[len(gpaEvents)-1].EventID == user.vault.EventID() {
user.log.Debug("No new API events")
return nil
}
for _, event := range gpaEvents {
user.log.WithFields(logrus.Fields{
"old": user.vault.EventID(),
"new": event,
}).Info("Received new API event")
// Handle the event.
if err := user.handleAPIEvent(ctx, event); err != nil {
// If the error is a context cancellation, return error to retry later.
if errors.Is(err, context.Canceled) {
return fmt.Errorf("failed to handle event due to context cancellation: %w", err)
}
// If the error is a network error, return error to retry later.
if netErr := new(proton.NetError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issue: %w", err)
}
// Catch all for uncategorized net errors that may slip through.
if netErr := new(net.OpError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
}
// In case a json decode error slips through.
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
user.eventCh.Enqueue(events.UncategorizedEventError{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to handle event due to JSON issue: %w", err)
}
// If the error is an unexpected EOF, return error to retry later.
if errors.Is(err, io.ErrUnexpectedEOF) {
return fmt.Errorf("failed to handle event due to EOF: %w", err)
}
// If the error is a server-side issue, return error to retry later.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
return fmt.Errorf("failed to handle event due to server error: %w", err)
}
// Otherwise, the error is a client-side issue; notify bridge to handle it.
user.log.WithField("event", event).Warn("Failed to handle API event")
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
OldEventID: user.vault.EventID(),
NewEventID: event.EventID,
EventInfo: event.String(),
Error: err,
})
return fmt.Errorf("failed to handle event due to client error: %w", err)
}
user.log.WithField("event", event).Debug("Handled API event")
// Update the event ID in the vault. If this fails, notify bridge to handle it.
if err := user.vault.SetEventID(event.EventID); err != nil {
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
Error: err,
})
return fmt.Errorf("failed to update event ID: %w", err)
}
user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
}
if more {
user.goPollAPIEvents(false)
}
return nil
}
// b32 returns a uint32 0 or 1 representing b.
func b32(b bool) uint32 {
if b {
return 1
}
return 0
}
// sleepCtx sleeps for the given duration, or until the context is canceled.
func sleepCtx(ctx context.Context, d time.Duration) {
select {
case <-ctx.Done():
case <-time.After(d):
}
func (user *User) PublishEvent(_ context.Context, event events.Event) {
user.eventCh.Enqueue(event)
}

View File

@ -26,6 +26,8 @@ import (
"github.com/ProtonMail/go-proton-api/server"
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/ProtonMail/proton-bridge/v3/internal/certs"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/tests"
@ -147,9 +149,28 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
ctl := gomock.NewController(tb)
defer ctl.Finish()
manager := mocks.NewMockHeartbeatManager(ctl)
manager.EXPECT().IsTelemetryAvailable(context.Background()).AnyTimes()
user, err := New(ctx, vaultUser, client, nil, apiUser, nil, true, vault.DefaultMaxSyncMemory, tb.TempDir(), manager)
nullEventSubscription := events.NewNullSubscription()
nullServerManager := imapservice.NewNullIMAPServerManager()
user, err := New(
ctx,
vaultUser,
client,
nil,
apiUser,
nil,
true,
vault.DefaultMaxSyncMemory,
tb.TempDir(),
manager,
nullServerManager,
nullEventSubscription,
)
require.NoError(tb, err)
defer user.Close()