GODT-1815: Combined/Split mode

This commit is contained in:
James Houlahan
2022-09-28 11:29:33 +02:00
parent 9670e29d9f
commit e9672e6bba
55 changed files with 1909 additions and 705 deletions

View File

@ -232,6 +232,13 @@ func (bridge *Bridge) GetErrors() []error {
}
func (bridge *Bridge) Close(ctx context.Context) error {
// Abort any ongoing syncs.
for _, user := range bridge.users {
if err := user.AbortSync(ctx); err != nil {
return fmt.Errorf("failed to abort sync: %w", err)
}
}
// Close the IMAP server.
if err := bridge.closeIMAP(ctx); err != nil {
logrus.WithError(err).Error("Failed to close IMAP server")

View File

@ -4,19 +4,24 @@ import (
"context"
"os"
"testing"
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/tests"
"github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/account"
)
const (
@ -29,6 +34,13 @@ var (
v2_4_0 = semver.MustParse("2.4.0")
)
func init() {
user.DefaultEventPeriod = 100 * time.Millisecond
user.DefaultEventJitter = 0
account.GenerateKey = tests.FastGenerateKey
certs.GenerateCert = tests.FastGenerateCert
}
func TestBridge_ConnStatus(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, vaultKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
@ -156,19 +168,27 @@ func TestBridge_CheckUpdate(t *testing.T) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
// Get a stream of update not available events.
noUpdateCh, done := bridge.GetEvents(events.UpdateNotAvailable{})
defer done()
// We are currently on the latest version.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
// we should receive an event indicating that no update is available.
require.Equal(t, events.UpdateNotAvailable{}, <-noUpdateCh)
// Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
// Get a stream of update available events.
updateCh, done := bridge.GetEvents(events.UpdateAvailable{})
defer done()
// Check for updates.
bridge.CheckForUpdates()
// We should receive an event indicating that an update is available.
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
Version: v2_4_0,
@ -188,7 +208,7 @@ func TestBridge_AutoUpdate(t *testing.T) {
require.NoError(t, bridge.SetAutoUpdate(true))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateInstalled{})
updateCh, done := bridge.GetEvents(events.UpdateInstalled{})
defer done()
// Simulate a new version being available.
@ -196,6 +216,8 @@ func TestBridge_AutoUpdate(t *testing.T) {
// Check for updates.
bridge.CheckForUpdates()
// We should receive an event indicating that the update was installed.
require.Equal(t, events.UpdateInstalled{
Version: updater.VersionInfo{
Version: v2_4_0,
@ -213,8 +235,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
// Get a stream of update available events.
updateCh, done := bridge.GetEvents(events.UpdateAvailable{})
defer done()
// Simulate a new version being available, but it's too new for us.
@ -222,6 +244,8 @@ func TestBridge_ManualUpdate(t *testing.T) {
// Check for updates.
bridge.CheckForUpdates()
// We should receive an event indicating an update is available, but we can't install it.
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
Version: v2_4_0,

View File

@ -14,8 +14,9 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
return ErrNoSuchUser
}
// TODO: Handle split mode!
if address == "" {
address = user.Addresses()[0]
address = user.Emails()[0]
}
// If configuring apple mail for Catalina or newer, users should use SSL.
@ -32,7 +33,7 @@ func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
bridge.vault.GetIMAPSSL(),
bridge.vault.GetSMTPSSL(),
address,
strings.Join(user.Addresses(), ","),
strings.Join(user.Emails(), ","),
user.BridgePass(),
)
}

View File

@ -2,6 +2,7 @@ package bridge
import (
"context"
"fmt"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
@ -96,40 +97,39 @@ func (bridge *Bridge) GetGluonDir() string {
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
if newGluonDir == bridge.GetGluonDir() {
return nil
return fmt.Errorf("new gluon dir is the same as the old one")
}
if err := bridge.closeIMAP(context.Background()); err != nil {
return err
return fmt.Errorf("failed to close IMAP: %w", err)
}
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
return err
return fmt.Errorf("failed to move gluon dir: %w", err)
}
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
return err
return fmt.Errorf("failed to set new gluon dir: %w", err)
}
imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig)
if err != nil {
return err
}
for _, user := range bridge.users {
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return err
}
if err := imapServer.LoadUser(context.Background(), imapConn, user.GluonID(), user.GluonKey()); err != nil {
return err
}
return fmt.Errorf("failed to create new IMAP server: %w", err)
}
bridge.imapServer = imapServer
return bridge.serveIMAP()
for _, user := range bridge.users {
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
}
if err := bridge.serveIMAP(); err != nil {
return fmt.Errorf("failed to serve IMAP: %w", err)
}
return nil
}
func (bridge *Bridge) GetProxyAllowed() bool {

View File

@ -23,8 +23,8 @@ func (backend *smtpBackend) Login(state *smtp.ConnectionState, username string,
defer backend.usersLock.RUnlock()
for _, user := range backend.users {
if slices.Contains(user.Addresses(), username) && user.BridgePass() == password {
return user.NewSMTPSession(username)
if slices.Contains(user.Emails(), username) && user.BridgePass() == password {
return user.NewSMTPSession(username), nil
}
}

View File

@ -29,7 +29,7 @@ type UserInfo struct {
Addresses []string
// AddressMode is the user's address mode.
AddressMode AddressMode
AddressMode vault.AddressMode
// BridgePass is the user's bridge password.
BridgePass string
@ -41,13 +41,6 @@ type UserInfo struct {
MaxSpace int
}
type AddressMode int
const (
SplitMode AddressMode = iota
CombinedMode
)
// GetUserIDs returns the IDs of all known users (authorized or not).
func (bridge *Bridge) GetUserIDs() []string {
return bridge.vault.GetUserIDs()
@ -62,7 +55,7 @@ func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
user, ok := bridge.users[userID]
if !ok {
return getUserInfo(vaultUser.UserID(), vaultUser.Username()), nil
return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil
}
return getConnUserInfo(user), nil
@ -153,12 +146,43 @@ func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
return nil
}
func (bridge *Bridge) GetAddressMode(userID string) (AddressMode, error) {
panic("TODO")
}
// SetAddressMode sets the address mode for the given user.
func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode vault.AddressMode) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
func (bridge *Bridge) SetAddressMode(userID string, mode AddressMode) error {
panic("TODO")
if user.GetAddressMode() == mode {
return fmt.Errorf("address mode is already %q", mode)
}
if err := user.AbortSync(ctx); err != nil {
return fmt.Errorf("failed to abort sync: %w", err)
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
}
if err := user.SetAddressMode(ctx, mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
bridge.publish(events.AddressModeChanged{
UserID: userID,
AddressMode: mode,
})
user.DoSync(ctx)
return nil
}
// loadUsers loads authorized users from the vault.
@ -177,7 +201,7 @@ func (bridge *Bridge) loadUsers(ctx context.Context) error {
logrus.WithError(err).Error("Failed to load connected user")
if _, ok := err.(*resty.ResponseError); ok {
if err := user.Clear(); err != nil {
if err := bridge.vault.ClearUser(userID); err != nil {
logrus.WithError(err).Error("Failed to clear user")
}
}
@ -231,33 +255,41 @@ func (bridge *Bridge) addUser(
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil {
return err
return fmt.Errorf("failed to add existing user: %w", err)
}
user = existingUser
} else {
newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil {
return err
return fmt.Errorf("failed to add new user: %w", err)
}
user = newUser
}
go func() {
for event := range user.GetNotifyCh() {
switch event := event.(type) {
case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
}
// Connects the user's address(es) to gluon.
if err := bridge.addIMAPUser(ctx, user); err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
bridge.publish(event)
// Handle events coming from the user before forwarding them to the bridge.
// For example, if the user's addresses change, we need to update them in gluon.
go func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for event := range user.GetEventCh() {
if err := bridge.handleUserEvent(ctx, user, event); err != nil {
logrus.WithError(err).Error("Failed to handle user event")
} else {
bridge.publish(event)
}
}
}()
// Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user.
// As such, if we find this ID in the context, we should use it to update our user agent.
client.AddPreRequestHook(func(ctx context.Context, req *resty.Request) error {
if imapID, ok := imap.GetIMAPIDFromContext(ctx); ok {
bridge.identifier.SetClient(imapID.Name, imapID.Version)
@ -266,6 +298,11 @@ func (bridge *Bridge) addUser(
return nil
})
// TODO: Replace this with proper sync manager.
if !user.HasSync() {
user.DoSync(ctx)
}
bridge.publish(events.UserLoggedIn{
UserID: user.ID(),
})
@ -293,25 +330,6 @@ func (bridge *Bridge) addNewUser(
return nil, err
}
gluonKey, err := crypto.RandomToken(32)
if err != nil {
return nil, err
}
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, gluonKey)
if err != nil {
return nil, err
}
if err := vaultUser.SetGluonAuth(gluonID, gluonKey); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err
}
@ -349,15 +367,6 @@ func (bridge *Bridge) addExistingUser(
return nil, err
}
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, user.GluonID(), user.GluonKey()); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err
}
@ -376,31 +385,39 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi
return ErrNoSuchUser
}
vaultUser, err := bridge.vault.GetUser(userID)
if err != nil {
return err
}
if err := bridge.imapServer.RemoveUser(ctx, vaultUser.GluonID(), withFiles); err != nil {
return err
// TODO: The sync should be canceled by the sync manager.
if err := user.AbortSync(ctx); err != nil {
return fmt.Errorf("failed to abort user sync: %w", err)
}
if err := bridge.smtpBackend.removeUser(user); err != nil {
return err
return fmt.Errorf("failed to remove SMTP user: %w", err)
}
for _, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, withFiles); err != nil {
return fmt.Errorf("failed to remove IMAP user: %w", err)
}
}
if withAPI {
if err := user.Logout(ctx); err != nil {
return err
return fmt.Errorf("failed to logout user: %w", err)
}
}
if err := user.Close(ctx); err != nil {
return err
return fmt.Errorf("failed to close user: %w", err)
}
if err := vaultUser.Clear(); err != nil {
return err
if err := bridge.vault.ClearUser(userID); err != nil {
return fmt.Errorf("failed to clear user: %w", err)
}
if withFiles {
if err := bridge.vault.DeleteUser(userID); err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
}
delete(bridge.users, userID)
@ -412,12 +429,39 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi
return nil
}
// addIMAPUser connects the given user to gluon.
func (bridge *Bridge) addIMAPUser(ctx context.Context, user *user.User) error {
imapConn, err := user.NewIMAPConnectors()
if err != nil {
return fmt.Errorf("failed to create IMAP connectors: %w", err)
}
for addrID, imapConn := range imapConn {
if gluonID, ok := user.GetGluonID(addrID); ok {
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to load IMAP user: %w", err)
}
} else {
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add IMAP user: %w", err)
}
if err := user.SetGluonID(addrID, gluonID); err != nil {
return fmt.Errorf("failed to set IMAP user ID: %w", err)
}
}
}
return nil
}
// getUserInfo returns information about a disconnected user.
func getUserInfo(userID, username string) UserInfo {
func getUserInfo(userID, username string, addressMode vault.AddressMode) UserInfo {
return UserInfo{
UserID: userID,
Username: username,
AddressMode: CombinedMode,
AddressMode: addressMode,
}
}
@ -427,8 +471,8 @@ func getConnUserInfo(user *user.User) UserInfo {
Connected: true,
UserID: user.ID(),
Username: user.Name(),
Addresses: user.Addresses(),
AddressMode: CombinedMode,
Addresses: user.Emails(),
AddressMode: user.GetAddressMode(),
BridgePass: user.BridgePass(),
UsedSpace: user.UsedSpace(),
MaxSpace: user.MaxSpace(),

View File

@ -0,0 +1,118 @@
package bridge
import (
"context"
"fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
)
func (bridge *Bridge) handleUserEvent(ctx context.Context, user *user.User, event events.Event) error {
switch event := event.(type) {
case events.UserAddressCreated:
if err := bridge.handleUserAddressCreated(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address created event: %w", err)
}
case events.UserAddressUpdated:
if err := bridge.handleUserAddressUpdated(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address updated event: %w", err)
}
case events.UserAddressDeleted:
if err := bridge.handleUserAddressDeleted(ctx, user, event); err != nil {
return fmt.Errorf("failed to handle user address deleted event: %w", err)
}
case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
return fmt.Errorf("failed to logout user: %w", err)
}
}
return nil
}
func (bridge *Bridge) handleUserAddressCreated(ctx context.Context, user *user.User, event events.UserAddressCreated) error {
switch user.GetAddressMode() {
case vault.CombinedMode:
for addrID, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
imapConn, err := user.NewIMAPConnector(addrID)
if err != nil {
return fmt.Errorf("failed to create IMAP connector: %w", err)
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
}
case vault.SplitMode:
imapConn, err := user.NewIMAPConnector(event.AddressID)
if err != nil {
return fmt.Errorf("failed to create IMAP connector: %w", err)
}
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, user.GluonKey())
if err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
if err := user.SetGluonID(event.AddressID, gluonID); err != nil {
return fmt.Errorf("failed to set gluon ID: %w", err)
}
}
return nil
}
// TODO: Handle addresses that have been disabled!
func (bridge *Bridge) handleUserAddressUpdated(ctx context.Context, user *user.User, event events.UserAddressUpdated) error {
switch user.GetAddressMode() {
case vault.CombinedMode:
return fmt.Errorf("not implemented")
case vault.SplitMode:
return fmt.Errorf("not implemented")
}
return nil
}
func (bridge *Bridge) handleUserAddressDeleted(ctx context.Context, user *user.User, event events.UserAddressDeleted) error {
switch user.GetAddressMode() {
case vault.CombinedMode:
for addrID, gluonID := range user.GetGluonIDs() {
if err := bridge.imapServer.RemoveUser(ctx, gluonID, false); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
imapConn, err := user.NewIMAPConnector(addrID)
if err != nil {
return fmt.Errorf("failed to create IMAP connector: %w", err)
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, gluonID, user.GluonKey()); err != nil {
return fmt.Errorf("failed to add user to IMAP server: %w", err)
}
}
case vault.SplitMode:
gluonID, ok := user.GetGluonID(event.AddressID)
if !ok {
return fmt.Errorf("gluon ID not found for address %s", event.AddressID)
}
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
}
}
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server"
)
@ -283,3 +284,30 @@ func TestBridge_BridgePass(t *testing.T) {
})
})
}
func TestBridge_AddressMode(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, dialer *bridge.TestDialer, locator bridge.Locator, storeKey []byte) {
withBridge(t, ctx, s.GetHostURL(), dialer, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err)
// Get the user's info.
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
// The user is in combined mode by default.
require.Equal(t, vault.CombinedMode, info.AddressMode)
// Put the user in split mode.
require.NoError(t, bridge.SetAddressMode(ctx, userID, vault.SplitMode))
// Get the user's info.
info, err = bridge.GetUserInfo(userID)
require.NoError(t, err)
// The user is in split mode.
require.Equal(t, vault.SplitMode, info.AddressMode)
})
})
}

View File

@ -1,5 +1,7 @@
package events
import "github.com/ProtonMail/proton-bridge/v2/internal/vault"
type UserLoggedIn struct {
eventBase
@ -33,20 +35,30 @@ type UserChanged struct {
type UserAddressCreated struct {
eventBase
UserID string
Address string
UserID string
AddressID string
Email string
}
type UserAddressChanged struct {
type UserAddressUpdated struct {
eventBase
UserID string
Address string
UserID string
AddressID string
Email string
}
type UserAddressDeleted struct {
eventBase
UserID string
Address string
UserID string
AddressID string
Email string
}
type AddressModeChanged struct {
eventBase
UserID string
AddressMode vault.AddressMode
}

View File

@ -23,6 +23,7 @@ import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/abiosoft/ishell"
)
@ -39,7 +40,7 @@ func (f *frontendCLI) listAccounts(c *ishell.Context) {
connected = "connected"
}
mode := "split"
if user.AddressMode == bridge.CombinedMode {
if user.AddressMode == vault.CombinedMode {
mode = "combined"
}
f.Printf(spacing, idx, user.Username, connected, mode)
@ -58,7 +59,7 @@ func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
return
}
if user.AddressMode == bridge.CombinedMode {
if user.AddressMode == vault.CombinedMode {
f.showAccountAddressInfo(user, user.Addresses[0])
} else {
for _, address := range user.Addresses {
@ -225,19 +226,19 @@ func (f *frontendCLI) changeMode(c *ishell.Context) {
return
}
var targetMode bridge.AddressMode
var targetMode vault.AddressMode
if user.AddressMode == bridge.CombinedMode {
targetMode = bridge.SplitMode
if user.AddressMode == vault.CombinedMode {
targetMode = vault.SplitMode
} else {
targetMode = bridge.CombinedMode
targetMode = vault.CombinedMode
}
if !f.yesNoQuestion("Are you sure you want to change the mode for account " + bold(user.Username) + " to " + bold(targetMode)) {
return
}
if err := f.bridge.SetAddressMode(user.UserID, targetMode); err != nil {
if err := f.bridge.SetAddressMode(context.Background(), user.UserID, targetMode); err != nil {
f.printAndLogError("Cannot switch address mode:", err)
}

View File

@ -296,7 +296,7 @@ func (f *frontendCLI) watchEvents() {
f.notifyLogout(user.Username)
case events.UserAddressChanged:
case events.UserAddressUpdated:
user, err := f.bridge.GetUserInfo(event.UserID)
if err != nil {
return
@ -305,7 +305,7 @@ func (f *frontendCLI) watchEvents() {
f.Printf("Address changed for %s. You may need to reconfigure your email client.\n", user.Username)
case events.UserAddressDeleted:
f.notifyLogout(event.Address)
f.notifyLogout(event.Email)
case events.SyncStarted:
user, err := f.bridge.GetUserInfo(event.UserID)

View File

@ -228,13 +228,13 @@ func (s *Service) watchEvents() {
_ = s.SendEvent(NewShowMainWindowEvent())
case events.UserAddressCreated:
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
_ = s.SendEvent(NewMailAddressChangeEvent(event.Email))
case events.UserAddressChanged:
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
case events.UserAddressUpdated:
_ = s.SendEvent(NewMailAddressChangeEvent(event.Email))
case events.UserAddressDeleted:
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Address))
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Email))
case events.UserChanged:
_ = s.SendEvent(NewUserChangedEvent(event.UserID))

View File

@ -20,7 +20,7 @@ package grpc
import (
"context"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@ -74,15 +74,15 @@ func (s *Service) SetUserSplitMode(ctx context.Context, splitMode *UserSplitMode
defer s.panicHandler.HandlePanic()
defer func() { _ = s.SendEvent(NewUserToggleSplitModeFinishedEvent(splitMode.UserID)) }()
var targetMode bridge.AddressMode
var targetMode vault.AddressMode
if splitMode.Active && user.AddressMode == bridge.CombinedMode {
targetMode = bridge.SplitMode
} else if !splitMode.Active && user.AddressMode == bridge.SplitMode {
targetMode = bridge.CombinedMode
if splitMode.Active && user.AddressMode == vault.CombinedMode {
targetMode = vault.SplitMode
} else if !splitMode.Active && user.AddressMode == vault.SplitMode {
targetMode = vault.CombinedMode
}
if err := s.bridge.SetAddressMode(user.UserID, targetMode); err != nil {
if err := s.bridge.SetAddressMode(context.Background(), user.UserID, targetMode); err != nil {
logrus.WithError(err).Error("Failed to set address mode")
}
}()

View File

@ -22,6 +22,7 @@ import (
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/sirupsen/logrus"
)
@ -64,7 +65,7 @@ func grpcUserFromInfo(user bridge.UserInfo) *User {
Username: user.Username,
AvatarText: getInitials(user.Username),
LoggedIn: user.Connected,
SplitMode: user.AddressMode == bridge.SplitMode,
SplitMode: user.AddressMode == vault.SplitMode,
SetupGuideSeen: true, // users listed have already seen the setup guide.
UsedBytes: int64(user.UsedSpace),
TotalBytes: int64(user.MaxSpace),

41
internal/pool/job.go Normal file
View File

@ -0,0 +1,41 @@
package pool
import "context"
type job[In, Out any] struct {
ctx context.Context
req In
res chan Out
err chan error
done chan struct{}
}
func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] {
return &job[In, Out]{
ctx: ctx,
req: req,
res: make(chan Out),
err: make(chan error),
done: make(chan struct{}),
}
}
func (job *job[In, Out]) result() (Out, error) {
return <-job.res, <-job.err
}
func (job *job[In, Out]) postSuccess(res Out) {
close(job.err)
job.res <- res
}
func (job *job[In, Out]) postFailure(err error) {
close(job.res)
job.err <- err
}
func (job *job[In, Out]) waitDone() {
<-job.done
}

View File

@ -13,16 +13,16 @@ var ErrJobCancelled = errors.New("Job cancelled by surrounding context")
// Pool is a worker pool that handles input of type In and returns results of type Out.
type Pool[In comparable, Out any] struct {
queue *queue.QueuedChannel[*Job[In, Out]]
queue *queue.QueuedChannel[*job[In, Out]]
size int
}
// DoneFunc must be called to free up pool resources.
type DoneFunc func()
// doneFunc must be called to free up pool resources.
type doneFunc func()
// New returns a new pool.
func New[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
queue := queue.NewQueuedChannel[*Job[In, Out]](0, 0)
queue := queue.NewQueuedChannel[*job[In, Out]](0, 0)
for i := 0; i < size; i++ {
go func() {
@ -51,17 +51,6 @@ func New[In comparable, Out any](size int, work func(context.Context, In) (Out,
}
}
// NewJob submits a job to the pool. It returns a job handle and a DoneFunc.
// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
// which frees up the worker in the pool for reuse.
func (pool *Pool[In, Out]) NewJob(ctx context.Context, req In) (*Job[In, Out], DoneFunc) {
job := newJob[In, Out](ctx, req)
pool.queue.Enqueue(job)
return job, func() { close(job.done) }
}
// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred.
func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, Out, error) error) error {
ctx, cancel := context.WithCancel(ctx)
@ -81,10 +70,10 @@ func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(In, O
go func() {
defer wg.Done()
job, done := pool.NewJob(ctx, req)
job, done := pool.newJob(ctx, req)
defer done()
res, err := job.Result()
res, err := job.result()
if err := fn(req, res, err); err != nil {
lock.Lock()
@ -134,44 +123,25 @@ func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) (map[In]Ou
return data, nil
}
// ProcessOne submits one job to the pool and returns the result.
func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) {
job, done := pool.newJob(ctx, req)
defer done()
return job.result()
}
func (pool *Pool[In, Out]) Done() {
pool.queue.Close()
}
type Job[In, Out any] struct {
ctx context.Context
req In
// newJob submits a job to the pool. It returns a job handle and a DoneFunc.
// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
// which frees up the worker in the pool for reuse.
func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc) {
job := newJob[In, Out](ctx, req)
res chan Out
err chan error
pool.queue.Enqueue(job)
done chan struct{}
}
func newJob[In, Out any](ctx context.Context, req In) *Job[In, Out] {
return &Job[In, Out]{
ctx: ctx,
req: req,
res: make(chan Out),
err: make(chan error),
done: make(chan struct{}),
}
}
func (job *Job[In, Out]) Result() (Out, error) {
return <-job.res, <-job.err
}
func (job *Job[In, Out]) postSuccess(res Out) {
close(job.err)
job.res <- res
}
func (job *Job[In, Out]) postFailure(err error) {
close(job.res)
job.err <- err
}
func (job *Job[In, Out]) waitDone() {
<-job.done
return job, func() { close(job.done) }
}

View File

@ -15,16 +15,16 @@ import (
func TestPool_NewJob(t *testing.T) {
doubler := newDoubler(runtime.NumCPU())
job1, done1 := doubler.NewJob(context.Background(), 1)
job1, done1 := doubler.newJob(context.Background(), 1)
defer done1()
job2, done2 := doubler.NewJob(context.Background(), 2)
job2, done2 := doubler.newJob(context.Background(), 2)
defer done2()
res2, err := job2.Result()
res2, err := job2.result()
require.NoError(t, err)
res1, err := job1.Result()
res1, err := job1.result()
require.NoError(t, err)
assert.Equal(t, 2, res1)
@ -36,31 +36,31 @@ func TestPool_NewJob_Done(t *testing.T) {
doubler := newDoubler(2)
// Start two jobs. Don't mark the jobs as done yet.
job1, done1 := doubler.NewJob(context.Background(), 1)
job2, done2 := doubler.NewJob(context.Background(), 2)
job1, done1 := doubler.newJob(context.Background(), 1)
job2, done2 := doubler.newJob(context.Background(), 2)
// Get the first result.
res1, _ := job1.Result()
res1, _ := job1.result()
assert.Equal(t, 2, res1)
// Get the first result.
res2, _ := job2.Result()
res2, _ := job2.result()
assert.Equal(t, 4, res2)
// Additional jobs will wait.
job3, _ := doubler.NewJob(context.Background(), 3)
job4, _ := doubler.NewJob(context.Background(), 4)
job3, _ := doubler.newJob(context.Background(), 3)
job4, _ := doubler.newJob(context.Background(), 4)
// Channel to collect results from jobs 3 and 4.
resCh := make(chan int, 2)
go func() {
res, _ := job3.Result()
res, _ := job3.result()
resCh <- res
}()
go func() {
res, _ := job4.Result()
res, _ := job4.result()
resCh <- res
}()

View File

@ -0,0 +1,46 @@
package user
import "gitlab.protontech.ch/go/liteapi"
type addrList struct {
apiAddrs ordMap[string, string, liteapi.Address]
}
func newAddrList(apiAddrs []liteapi.Address) *addrList {
return &addrList{
apiAddrs: newOrdMap(
func(addr liteapi.Address) string { return addr.ID },
func(addr liteapi.Address) string { return addr.Email },
func(a, b liteapi.Address) bool { return a.Order < b.Order },
apiAddrs...,
),
}
}
func (list *addrList) insert(address liteapi.Address) {
list.apiAddrs.insert(address)
}
func (list *addrList) delete(addrID string) string {
return list.apiAddrs.delete(addrID)
}
func (list *addrList) primary() string {
return list.apiAddrs.keys()[0]
}
func (list *addrList) addrIDs() []string {
return list.apiAddrs.keys()
}
func (list *addrList) emails() []string {
return list.apiAddrs.values()
}
func (list *addrList) email(addrID string) string {
return list.apiAddrs.get(addrID)
}
func (list *addrList) addrMap() map[string]string {
return list.apiAddrs.toMap()
}

View File

@ -2,16 +2,20 @@ package user
import (
"context"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
"github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
type request struct {
messageID string
addressID string
addrKR *crypto.KeyRing
}
@ -54,8 +58,38 @@ func newBuilder(f fetcher, msgWorkers, attWorkers int) *pool.Pool[request, *imap
return nil, err
}
return getMessageCreatedUpdate(msg, literal)
return newMessageCreatedUpdate(msg, literal)
})
return msgPool
}
func newMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
return nil, err
}
flags := imap.NewFlagSet()
if !message.Unread {
flags = flags.Add(imap.FlagSeen)
}
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
flags = flags.Add(imap.FlagFlagged)
}
imapMessage := imap.Message{
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}
return &imap.MessageCreated{
Message: imapMessage,
Literal: literal,
LabelIDs: mapTo[string, imap.LabelID](xslices.Filter(message.LabelIDs, wantLabelID)),
ParsedMessage: parsedMessage,
}, nil
}

View File

@ -8,5 +8,5 @@ var (
ErrNotSupported = errors.New("not supported")
ErrInvalidReturnPath = errors.New("invalid return path")
ErrInvalidRecipient = errors.New("invalid recipient")
ErrMissingAddressKey = errors.New("missing address key")
ErrMissingAddrKey = errors.New("missing address key")
)

View File

@ -2,43 +2,44 @@ package user
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
// handleAPIEvent handles the given liteapi.Event.
func (user *User) handleAPIEvent(event liteapi.Event) error {
func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error {
if event.User != nil {
if err := user.handleUserEvent(*event.User); err != nil {
if err := user.handleUserEvent(ctx, *event.User); err != nil {
return err
}
}
if len(event.Addresses) > 0 {
if err := user.handleAddressEvents(event.Addresses); err != nil {
if err := user.handleAddressEvents(ctx, event.Addresses); err != nil {
return err
}
}
if event.MailSettings != nil {
if err := user.handleMailSettingsEvent(*event.MailSettings); err != nil {
if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil {
return err
}
}
if len(event.Labels) > 0 {
if err := user.handleLabelEvents(event.Labels); err != nil {
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
return err
}
}
if len(event.Messages) > 0 {
if err := user.handleMessageEvents(event.Messages); err != nil {
if err := user.handleMessageEvents(ctx, event.Messages); err != nil {
return err
}
}
@ -47,7 +48,7 @@ func (user *User) handleAPIEvent(event liteapi.Event) error {
}
// handleUserEvent handles the given user event.
func (user *User) handleUserEvent(userEvent liteapi.User) error {
func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error {
userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil)
if err != nil {
return err
@ -57,49 +58,31 @@ func (user *User) handleUserEvent(userEvent liteapi.User) error {
user.userKR = userKR
user.notifyCh <- events.UserChanged{
user.eventCh.Enqueue(events.UserChanged{
UserID: user.ID(),
}
})
return nil
}
// handleAddressEvents handles the given address events.
// TODO: If split address mode, need to signal back to bridge to update the addresses!
func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) error {
func (user *User) handleAddressEvents(ctx context.Context, addressEvents []liteapi.AddressEvent) error {
for _, event := range addressEvents {
switch event.Action {
case liteapi.EventDelete:
address, err := user.deleteAddress(event.ID)
if err != nil {
return err
}
// TODO: This is not the same as addressChangedLogout event!
// That was only relevant in split mode. This is used differently now.
user.notifyCh <- events.UserAddressDeleted{
UserID: user.ID(),
Address: address.Email,
}
case liteapi.EventCreate:
if err := user.createAddress(event.Address); err != nil {
return err
}
user.notifyCh <- events.UserAddressCreated{
UserID: user.ID(),
Address: event.Address.Email,
if err := user.handleCreateAddressEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle create address event: %w", err)
}
case liteapi.EventUpdate:
if err := user.updateAddress(event.Address); err != nil {
return err
if err := user.handleUpdateAddressEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle update address event: %w", err)
}
user.notifyCh <- events.UserAddressChanged{
UserID: user.ID(),
Address: event.Address.Email,
case liteapi.EventDelete:
if err := user.handleDeleteAddressEvent(ctx, event); err != nil {
return fmt.Errorf("failed to delete address: %w", err)
}
}
}
@ -107,111 +90,189 @@ func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) erro
return nil
}
// createAddress creates the given address.
func (user *User) createAddress(address liteapi.Address) error {
addrKR, err := address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
if err != nil {
return err
return fmt.Errorf("failed to unlock address keys: %w", err)
}
if user.imapConn != nil {
user.imapConn.addAddress(address.Email)
user.apiAddrs.insert(event.Address)
user.addrKRs[event.Address.ID] = addrKR
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
if err := user.syncLabels(ctx, event.Address.ID); err != nil {
return fmt.Errorf("failed to sync labels to new address: %w", err)
}
}
user.addresses = append(user.addresses, address)
user.addrKRs[address.ID] = addrKR
user.eventCh.Enqueue(events.UserAddressCreated{
UserID: user.ID(),
AddressID: event.Address.ID,
Email: event.Address.Email,
})
return nil
}
// updateAddress updates the given address.
func (user *User) updateAddress(address liteapi.Address) error {
if _, err := user.deleteAddress(address.ID); err != nil {
return err
func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
if err != nil {
return fmt.Errorf("failed to unlock address keys: %w", err)
}
return user.createAddress(address)
}
user.apiAddrs.insert(event.Address)
// deleteAddress deletes the given address.
func (user *User) deleteAddress(addressID string) (liteapi.Address, error) {
idx := xslices.IndexFunc(user.addresses, func(address liteapi.Address) bool {
return address.ID == addressID
user.addrKRs[event.Address.ID] = addrKR
user.eventCh.Enqueue(events.UserAddressUpdated{
UserID: user.ID(),
AddressID: event.Address.ID,
Email: event.Address.Email,
})
if idx < 0 {
return liteapi.Address{}, ErrNoSuchAddress
return nil
}
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
email := user.apiAddrs.delete(event.ID)
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].Close()
delete(user.updateCh, event.ID)
}
if user.imapConn != nil {
user.imapConn.remAddress(user.addresses[idx].Email)
}
user.eventCh.Enqueue(events.UserAddressDeleted{
UserID: user.ID(),
AddressID: event.ID,
Email: email,
})
var address liteapi.Address
address, user.addresses = user.addresses[idx], append(user.addresses[:idx], user.addresses[idx+1:]...)
delete(user.addrKRs, addressID)
return address, nil
return nil
}
// handleMailSettingsEvent handles the given mail settings event.
func (user *User) handleMailSettingsEvent(mailSettingsEvent liteapi.MailSettings) error {
func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error {
user.settings = mailSettingsEvent
return nil
}
// handleLabelEvents handles the given label events.
func (user *User) handleLabelEvents(labelEvents []liteapi.LabelEvent) error {
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error {
for _, event := range labelEvents {
switch event.Action {
case liteapi.EventDelete:
user.updateCh <- imap.NewMailboxDeleted(imap.LabelID(event.ID))
case liteapi.EventCreate:
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label))
if err := user.handleCreateLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle create label event: %w", err)
}
case liteapi.EventUpdate, liteapi.EventUpdateFlags:
user.updateCh <- imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label))
if err := user.handleUpdateLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle update label event: %w", err)
}
case liteapi.EventDelete:
if err := user.handleDeleteLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle delete label event: %w", err)
}
}
}
return nil
}
func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)))
}
return nil
}
func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)))
}
return nil
}
func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID)))
}
return nil
}
// handleMessageEvents handles the given message events.
func (user *User) handleMessageEvents(messageEvents []liteapi.MessageEvent) error {
ctx, cancel := context.WithCancel(context.Background())
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for _, event := range messageEvents {
switch event.Action {
case liteapi.EventDelete:
return ErrNotImplemented
case liteapi.EventCreate:
messages, err := user.builder.ProcessAll(ctx, []request{{event.ID, user.addrKRs[event.Message.AddressID]}})
if err != nil {
return err
if err := user.handleCreateMessageEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle create message event: %w", err)
}
user.updateCh <- imap.NewMessagesCreated(maps.Values(messages)...)
case liteapi.EventUpdate, liteapi.EventUpdateFlags:
user.updateCh <- imap.NewMessageLabelsUpdated(
imap.MessageID(event.ID),
imapLabelIDs(filterLabelIDs(event.Message.LabelIDs)),
bool(!event.Message.Unread),
slices.Contains(event.Message.LabelIDs, liteapi.StarredLabel),
)
if err := user.handleUpdateMessageEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle update message event: %w", err)
}
case liteapi.EventDelete:
return ErrNotImplemented
}
}
return nil
}
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
var addressID string
if user.GetAddressMode() == vault.CombinedMode {
addressID = user.apiAddrs.primary()
} else {
addressID = event.Message.AddressID
}
message, err := user.builder.ProcessOne(ctx, request{
messageID: event.ID,
addressID: addressID,
addrKR: user.addrKRs[event.Message.AddressID],
})
if err != nil {
return err
}
user.updateCh[addressID].Enqueue(imap.NewMessagesCreated(message))
return nil
}
func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
update := imap.NewMessageLabelsUpdated(
imap.MessageID(event.ID),
mapTo[string, imap.LabelID](xslices.Filter(event.Message.LabelIDs, wantLabelID)),
event.Message.Seen(),
event.Message.Starred(),
)
if user.GetAddressMode() == vault.CombinedMode {
user.updateCh[user.apiAddrs.primary()].Enqueue(update)
} else {
user.updateCh[event.Message.AddressID].Enqueue(update)
}
return nil
}
func getMailboxName(label liteapi.Label) []string {
var name []string

76
internal/user/flusher.go Normal file
View File

@ -0,0 +1,76 @@
package user
import (
"sync"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
)
type flusher struct {
userID string
updateCh *queue.QueuedChannel[imap.Update]
eventCh *queue.QueuedChannel[events.Event]
updates []*imap.MessageCreated
maxChunkSize int
curChunkSize int
count int
total int
start time.Time
pushLock sync.Mutex
}
func newFlusher(
userID string,
updateCh *queue.QueuedChannel[imap.Update],
eventCh *queue.QueuedChannel[events.Event],
total, maxChunkSize int,
) *flusher {
return &flusher{
userID: userID,
updateCh: updateCh,
eventCh: eventCh,
maxChunkSize: maxChunkSize,
total: total,
start: time.Now(),
}
}
func (f *flusher) push(update *imap.MessageCreated) {
f.pushLock.Lock()
defer f.pushLock.Unlock()
f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
f.flush()
}
}
func (f *flusher) flush() {
if len(f.updates) == 0 {
return
}
f.count += len(f.updates)
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
f.eventCh.Enqueue(newSyncProgress(f.userID, f.count, f.total, f.start))
f.updates = nil
f.curChunkSize = 0
}
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
return events.SyncProgress{
UserID: userID,
Progress: float64(count) / float64(total),
Elapsed: time.Since(start),
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
}
}

View File

@ -25,11 +25,12 @@ const (
)
type imapConnector struct {
addrID string
client *liteapi.Client
updateCh <-chan imap.Update
addresses []string
password string
emails []string
password string
flags, permFlags, attrs imap.FlagSet
}
@ -37,15 +38,15 @@ type imapConnector struct {
func newIMAPConnector(
client *liteapi.Client,
updateCh <-chan imap.Update,
addresses []string,
password string,
emails ...string,
) *imapConnector {
return &imapConnector{
client: client,
updateCh: updateCh,
addresses: addresses,
password: password,
emails: emails,
password: password,
flags: defaultFlags,
permFlags: defaultPermanentFlags,
@ -59,7 +60,7 @@ func (conn *imapConnector) Authorize(username string, password string) bool {
return false
}
return xslices.IndexFunc(conn.addresses, func(address string) bool {
return xslices.IndexFunc(conn.emails, func(address string) bool {
return strings.EqualFold(address, username)
}) >= 0
}
@ -187,7 +188,7 @@ func (conn *imapConnector) GetMessage(ctx context.Context, messageID imap.Messag
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}, imapLabelIDs(message.LabelIDs), nil
}, mapTo[string, imap.LabelID](message.LabelIDs), nil
}
// CreateMessage creates a new message on the remote.
@ -204,21 +205,21 @@ func (conn *imapConnector) CreateMessage(
// LabelMessages labels the given messages with the given label ID.
func (conn *imapConnector) LabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelID))
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
}
// UnlabelMessages unlabels the given messages with the given label ID.
func (conn *imapConnector) UnlabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelID))
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
}
// MoveMessages removes the given messages from one label and adds them to the other label.
func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.LabelID, labelToID imap.LabelID) error {
if err := conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelToID)); err != nil {
if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
return fmt.Errorf("labeling messages: %w", err)
}
if err := conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelFromID)); err != nil {
if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
return fmt.Errorf("unlabeling messages: %w", err)
}
@ -228,18 +229,18 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
// MarkMessagesSeen sets the seen value of the given messages.
func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error {
if seen {
return conn.client.MarkMessagesRead(ctx, strMessageIDs(messageIDs)...)
return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...)
} else {
return conn.client.MarkMessagesUnread(ctx, strMessageIDs(messageIDs)...)
return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...)
}
}
// MarkMessagesFlagged sets the flagged value of the given messages.
func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error {
if flagged {
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel)
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
} else {
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel)
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
}
}
@ -249,45 +250,17 @@ func (conn *imapConnector) GetUpdates() <-chan imap.Update {
return conn.updateCh
}
// Close the connector when it will no longer be used and all resources should be closed/released.
func (conn *imapConnector) Close(ctx context.Context) error {
// GetUIDValidity returns the default UID validity for this user.
func (conn *imapConnector) GetUIDValidity() imap.UID {
return imap.UID(1)
}
// SetUIDValidity sets the default UID validity for this user.
func (conn *imapConnector) SetUIDValidity(uidValidity imap.UID) error {
return nil
}
func (conn *imapConnector) addAddress(address string) {
conn.addresses = append(conn.addresses, address)
}
func (conn *imapConnector) remAddress(address string) {
idx := slices.Index(conn.addresses, address)
if idx < 0 {
return
}
conn.addresses = append(conn.addresses[:idx], conn.addresses[idx+1:]...)
}
func strLabelIDs(imapLabelIDs []imap.LabelID) []string {
return xslices.Map(imapLabelIDs, func(labelID imap.LabelID) string {
return string(labelID)
})
}
func imapLabelIDs(labelIDs []string) []imap.LabelID {
return xslices.Map(labelIDs, func(labelID string) imap.LabelID {
return imap.LabelID(labelID)
})
}
func strMessageIDs(imapMessageIDs []imap.MessageID) []string {
return xslices.Map(imapMessageIDs, func(messageID imap.MessageID) string {
return string(messageID)
})
}
func imapMessageIDs(messageIDs []string) []imap.MessageID {
return xslices.Map(messageIDs, func(messageID string) imap.MessageID {
return imap.MessageID(messageID)
})
// Close the connector will no longer be used and all resources should be closed/released.
func (conn *imapConnector) Close(ctx context.Context) error {
return nil
}

89
internal/user/map.go Normal file
View File

@ -0,0 +1,89 @@
package user
import (
"github.com/bradenaw/juniper/xslices"
"golang.org/x/exp/slices"
)
type ordMap[Key comparable, Val, Data any] struct {
data map[Key]Data
order []Key
toKey func(Data) Key
toVal func(Data) Val
isLess func(Data, Data) bool
}
func newOrdMap[Key comparable, Val, Data any](
key func(Data) Key,
value func(Data) Val,
less func(Data, Data) bool,
data ...Data,
) ordMap[Key, Val, Data] {
m := ordMap[Key, Val, Data]{
data: make(map[Key]Data),
toKey: key,
toVal: value,
isLess: less,
}
for _, d := range data {
m.insert(d)
}
return m
}
func (set *ordMap[Key, Val, Data]) insert(data Data) {
if _, ok := set.data[set.toKey(data)]; ok {
set.delete(set.toKey(data))
}
set.data[set.toKey(data)] = data
set.order = append(set.order, set.toKey(data))
slices.SortFunc(set.order, func(a, b Key) bool {
return set.isLess(set.data[a], set.data[b])
})
}
func (set *ordMap[Key, Val, Data]) delete(key Key) Val {
data, ok := set.data[key]
if !ok {
return *new(Val)
}
delete(set.data, key)
set.order = xslices.Filter(set.order, func(otherKey Key) bool {
return otherKey != key
})
return set.toVal(data)
}
func (set *ordMap[Key, Val, Data]) get(key Key) Val {
return set.toVal(set.data[key])
}
func (set *ordMap[Key, Val, Data]) keys() []Key {
return set.order
}
func (set *ordMap[Key, Val, Data]) values() []Val {
return xslices.Map(set.order, func(key Key) Val {
return set.toVal(set.data[key])
})
}
func (set *ordMap[Key, Val, Data]) toMap() map[Key]Val {
m := make(map[Key]Val)
for _, key := range set.order {
m[key] = set.toVal(set.data[key])
}
return m
}

48
internal/user/map_test.go Normal file
View File

@ -0,0 +1,48 @@
package user
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMap(t *testing.T) {
type Key int
type Value string
type Data struct {
key Key
value Value
}
m := newOrdMap(
func(d Data) Key { return d.key },
func(d Data) Value { return d.value },
func(a, b Data) bool { return a.key < b.key },
Data{key: 1, value: "a"},
Data{key: 2, value: "b"},
Data{key: 3, value: "c"},
)
// Insert some new data.
m.insert(Data{key: 4, value: "d"})
m.insert(Data{key: 5, value: "e"})
// Delete some data.
require.Equal(t, Value("c"), m.delete(3))
require.Equal(t, Value("a"), m.delete(1))
require.Equal(t, Value("e"), m.delete(5))
// Check the remaining keys and values are correct.
require.Equal(t, []Key{2, 4}, m.keys())
require.Equal(t, []Value{"b", "d"}, m.values())
// Overwrite some data.
m.insert(Data{key: 2, value: "two"})
m.insert(Data{key: 4, value: "four"})
// Check the remaining keys and values are correct.
require.Equal(t, []Key{2, 4}, m.keys())
require.Equal(t, []Value{"two", "four"}, m.values())
}

View File

@ -20,12 +20,14 @@ import (
)
type smtpSession struct {
client *liteapi.Client
username string
addresses []liteapi.Address
userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing
settings liteapi.MailSettings
client *liteapi.Client
username string
emails map[string]string
settings liteapi.MailSettings
userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing
from string
to map[string]struct{}
@ -34,18 +36,20 @@ type smtpSession struct {
func newSMTPSession(
client *liteapi.Client,
username string,
addresses []liteapi.Address,
addresses map[string]string,
settings liteapi.MailSettings,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
settings liteapi.MailSettings,
) *smtpSession {
return &smtpSession{
client: client,
username: username,
addresses: addresses,
userKR: userKR,
addrKRs: addrKRs,
settings: settings,
client: client,
username: username,
emails: addresses,
settings: settings,
userKR: userKR,
addrKRs: addrKRs,
from: "",
to: make(map[string]struct{}),
@ -86,15 +90,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
return ErrNotImplemented
}
idx := xslices.IndexFunc(session.addresses, func(address liteapi.Address) bool {
return strings.EqualFold(address.Email, from)
})
if idx < 0 {
return ErrInvalidReturnPath
for addrID, email := range session.emails {
if strings.EqualFold(from, email) {
session.from = addrID
}
}
session.from = session.addresses[idx].ID
if session.from == "" {
return ErrInvalidReturnPath
}
return nil
}
@ -129,10 +133,10 @@ func (session *smtpSession) Data(r io.Reader) error {
addrKR, ok := session.addrKRs[session.from]
if !ok {
return ErrMissingAddressKey
return ErrMissingAddrKey
}
addrKR, err := addrKR.FirstKey()
addrKey, err := addrKR.FirstKey()
if err != nil {
return fmt.Errorf("failed to get first key: %w", err)
}
@ -143,7 +147,7 @@ func (session *smtpSession) Data(r io.Reader) error {
}
if session.settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
key, err := addrKR.GetKey(0)
key, err := addrKey.GetKey(0)
if err != nil {
return fmt.Errorf("failed to get user public key: %w", err)
}
@ -153,7 +157,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to get user public key: %w", err)
}
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKey.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
}
message, err := message.ParseWithParser(parser)
@ -161,7 +165,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to parse message: %w", err)
}
draft, attKeys, err := session.createDraft(ctx, addrKR, message)
draft, attKeys, err := session.createDraft(ctx, addrKey, message)
if err != nil {
return fmt.Errorf("failed to create draft: %w", err)
}
@ -171,7 +175,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to get recipients: %w", err)
}
req, err := createSendReq(addrKR, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
req, err := createSendReq(addrKey, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
if err != nil {
return fmt.Errorf("failed to create packages: %w", err)
}

View File

@ -4,57 +4,34 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
const chunkSize = 1 << 20
func (user *User) sync(ctx context.Context) error {
user.notifyCh <- events.SyncStarted{
UserID: user.ID(),
}
if err := user.syncLabels(ctx); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := user.syncMessages(ctx); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
user.notifyCh <- events.SyncFinished{
UserID: user.ID(),
}
if err := user.vault.SetSync(true); err != nil {
return fmt.Errorf("failed to update sync status: %w", err)
}
return nil
}
func (user *User) syncLabels(ctx context.Context) error {
func (user *User) syncLabels(ctx context.Context, addrIDs ...string) error {
// Sync the system folders.
system, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem)
if err != nil {
return err
}
for _, label := range system {
user.updateCh <- newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name)
for _, label := range xslices.Filter(system, func(label liteapi.Label) bool { return wantLabelID(label.ID) }) {
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name))
}
}
// Create Folders/Labels mailboxes with a random ID and with the \Noselect attribute.
for _, prefix := range []string{folderPrefix, labelPrefix} {
user.updateCh <- newPlaceHolderMailboxCreatedUpdate(prefix)
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newPlaceHolderMailboxCreatedUpdate(prefix))
}
}
// Sync the API folders.
@ -64,7 +41,9 @@ func (user *User) syncLabels(ctx context.Context) error {
}
for _, folder := range folders {
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path})
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}))
}
}
// Sync the API labels.
@ -74,7 +53,9 @@ func (user *User) syncLabels(ctx context.Context) error {
}
for _, label := range labels {
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path})
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}))
}
}
return nil
@ -84,27 +65,53 @@ func (user *User) syncMessages(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Determine which messages to sync.
// TODO: This needs to be done better using the new API route to retrieve just the message IDs.
metadata, err := user.client.GetAllMessageMetadata(ctx)
if err != nil {
return err
}
// If in split mode, we need to send each message to a different IMAP connector.
isSplitMode := user.vault.AddressMode() == vault.SplitMode
// Collect the build requests -- we need:
// - the message ID to build,
// - the keyring to decrypt the message,
// - and the address to send the message to (for split mode).
requests := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) request {
var addressID string
if isSplitMode {
addressID = metadata.AddressID
} else {
addressID = user.apiAddrs.primary()
}
return request{
messageID: metadata.ID,
addressID: addressID,
addrKR: user.addrKRs[metadata.AddressID],
}
})
flusher := newFlusher(user.ID(), user.updateCh, user.notifyCh, len(metadata), chunkSize)
defer flusher.flush()
// Create the flushers, one per update channel.
flushers := make(map[string]*flusher)
for addrID, updateCh := range user.updateCh {
flusher := newFlusher(user.ID(), updateCh, user.eventCh, len(requests), chunkSize)
defer flusher.flush()
flushers[addrID] = flusher
}
// Build the messages and send them to the correct flusher.
if err := user.builder.Process(ctx, requests, func(req request, res *imap.MessageCreated, err error) error {
if err != nil {
return fmt.Errorf("failed to build message %s: %w", req.messageID, err)
}
flusher.push(res)
flushers[req.addressID].push(res)
return nil
}); err != nil {
@ -114,95 +121,15 @@ func (user *User) syncMessages(ctx context.Context) error {
return nil
}
type flusher struct {
userID string
func (user *User) syncWait() {
for _, updateCh := range user.updateCh {
waiter := imap.NewNoop()
defer waiter.Wait()
updates []*imap.MessageCreated
updateCh chan<- imap.Update
notifyCh chan<- events.Event
maxChunkSize int
curChunkSize int
count int
total int
start time.Time
pushLock sync.Mutex
}
func newFlusher(userID string, updateCh chan<- imap.Update, notifyCh chan<- events.Event, total, maxChunkSize int) *flusher {
return &flusher{
userID: userID,
updateCh: updateCh,
notifyCh: notifyCh,
maxChunkSize: maxChunkSize,
total: total,
start: time.Now(),
updateCh.Enqueue(waiter)
}
}
func (f *flusher) push(update *imap.MessageCreated) {
f.pushLock.Lock()
defer f.pushLock.Unlock()
f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
f.flush()
}
}
func (f *flusher) flush() {
if len(f.updates) == 0 {
return
}
f.count += len(f.updates)
f.updateCh <- imap.NewMessagesCreated(f.updates...)
f.notifyCh <- newSyncProgress(f.userID, f.count, f.total, f.start)
f.updates = nil
f.curChunkSize = 0
}
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
return events.SyncProgress{
UserID: userID,
Progress: float64(count) / float64(total),
Elapsed: time.Since(start),
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
}
}
func getMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
return nil, err
}
flags := imap.NewFlagSet()
if !message.Unread {
flags = flags.Add(imap.FlagSeen)
}
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
flags = flags.Add(imap.FlagFlagged)
}
imapMessage := imap.Message{
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}
return &imap.MessageCreated{
Message: imapMessage,
Literal: literal,
LabelIDs: imapLabelIDs(filterLabelIDs(message.LabelIDs)),
ParsedMessage: parsedMessage,
}, nil
}
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
if strings.EqualFold(labelName, imap.Inbox) {
labelName = imap.Inbox
@ -237,18 +164,12 @@ func newMailboxCreatedUpdate(labelID imap.LabelID, labelName []string) *imap.Mai
})
}
func filterLabelIDs(labelIDs []string) []string {
var filteredLabelIDs []string
func wantLabelID(labelID string) bool {
switch labelID {
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
return false
for _, labelID := range labelIDs {
switch labelID {
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
// ... skip ...
default:
filteredLabelIDs = append(filteredLabelIDs, labelID)
}
default:
return true
}
return filteredLabelIDs
}

13
internal/user/types.go Normal file
View File

@ -0,0 +1,13 @@
package user
import "reflect"
func mapTo[From, To any](from []From) []To {
to := make([]To, 0, len(from))
for _, from := range from {
to = append(to, reflect.ValueOf(from).Convert(reflect.TypeOf(to).Elem()).Interface().(To))
}
return to
}

View File

@ -0,0 +1,20 @@
package user
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestToType(t *testing.T) {
type myString string
// Slices of different types are not equal.
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
// But converting them to the same type makes them equal.
require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"}))
// The conversion can happen in the other direction too.
require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"}))
}

View File

@ -2,19 +2,22 @@ package user
import (
"context"
"fmt"
"runtime"
"time"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
@ -23,40 +26,38 @@ var (
DefaultEventJitter = 20 * time.Second
)
// TODO: Is it bad to store the key pass in the user? Any worse than storing private keys?
type User struct {
vault *vault.User
client *liteapi.Client
builder *pool.Pool[request, *imap.MessageCreated]
eventCh *queue.QueuedChannel[events.Event]
apiUser liteapi.User
addresses []liteapi.Address
settings liteapi.MailSettings
notifyCh chan events.Event
updateCh chan imap.Update
apiUser liteapi.User
apiAddrs *addrList
userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing
imapConn *imapConnector
settings liteapi.MailSettings
updateCh map[string]*queue.QueuedChannel[imap.Update]
syncWG gluon.WaitGroup
}
func New(
ctx context.Context,
vault *vault.User,
encVault *vault.User,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
) (*User, error) {
if vault.EventID() == "" {
if encVault.EventID() == "" {
eventID, err := client.GetLatestEventID(ctx)
if err != nil {
return nil, err
}
if err := vault.SetEventID(eventID); err != nil {
if err := encVault.SetEventID(eventID); err != nil {
return nil, err
}
}
@ -67,19 +68,29 @@ func New(
}
user := &User{
apiUser: apiUser,
addresses: apiAddrs,
settings: settings,
vault: vault,
vault: encVault,
client: client,
builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()),
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
notifyCh: make(chan events.Event),
updateCh: make(chan imap.Update),
apiUser: apiUser,
apiAddrs: newAddrList(apiAddrs),
userKR: userKR,
addrKRs: addrKRs,
userKR: userKR,
addrKRs: addrKRs,
settings: settings,
updateCh: make(map[string]*queue.QueuedChannel[imap.Update]),
}
// Initialize update channels for each of the user's addresses.
for _, addrID := range user.apiAddrs.addrIDs() {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
// If in combined mode, we only need one update channel.
if encVault.AddressMode() == vault.CombinedMode {
break
}
}
// When we receive an auth object, we update it in the store.
@ -93,111 +104,234 @@ func New(
// When we are deauthorized, we send a deauth event to the notify channel.
// Bridge will catch this and log the user out.
client.AddDeauthHandler(func() {
user.notifyCh <- events.UserDeauth{
user.eventCh.Enqueue(events.UserDeauth{
UserID: user.ID(),
}
})
})
// When we receive an API event, we attempt to handle it. If successful, we send the event to the event channel.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
go func() {
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, vault.EventID()).Subscribe() {
if err := user.handleAPIEvent(event); err != nil {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, encVault.EventID()).Subscribe() {
if err := user.handleAPIEvent(ctx, event); err != nil {
logrus.WithError(err).Error("Failed to handle event")
} else {
if err := user.vault.SetEventID(event.EventID); err != nil {
logrus.WithError(err).Error("Failed to update event ID")
}
} else if err := user.vault.SetEventID(event.EventID); err != nil {
logrus.WithError(err).Error("Failed to update event ID")
}
}
}()
// TODO: Use a proper sync manager! (if partial sync, pickup from where we last stopped)
if !vault.HasSync() {
go user.sync(context.Background())
}
return user, nil
}
// ID returns the user's ID.
func (user *User) ID() string {
return user.apiUser.ID
}
// Name returns the user's username.
func (user *User) Name() string {
return user.apiUser.Name
}
// Match matches the given query against the user's username and email addresses.
func (user *User) Match(query string) bool {
if query == user.Name() {
if query == user.apiUser.Name {
return true
}
return slices.Contains(user.Addresses(), query)
return slices.Contains(user.apiAddrs.emails(), query)
}
func (user *User) Addresses() []string {
return xslices.Map(
sort(user.addresses, func(a, b liteapi.Address) bool {
return a.Order < b.Order
}),
func(address liteapi.Address) string {
return address.Email
},
)
// Emails returns all the user's email addresses.
func (user *User) Emails() []string {
return user.apiAddrs.emails()
}
func (user *User) GluonID() string {
return user.vault.GluonID()
// GetAddressMode returns the user's current address mode.
func (user *User) GetAddressMode() vault.AddressMode {
return user.vault.AddressMode()
}
// SetAddressMode sets the user's address mode.
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
for _, updateCh := range user.updateCh {
updateCh.Close()
}
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
for _, addrID := range user.apiAddrs.addrIDs() {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
if mode == vault.CombinedMode {
break
}
}
if err := user.vault.SetAddressMode(mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
return nil
}
// GetGluonIDs returns the users gluon IDs.
func (user *User) GetGluonIDs() map[string]string {
return user.vault.GetGluonIDs()
}
// GetGluonID returns the gluon ID for the given address, if present.
func (user *User) GetGluonID(addrID string) (string, bool) {
gluonID, ok := user.vault.GetGluonIDs()[addrID]
if !ok {
return "", false
}
return gluonID, true
}
// SetGluonID sets the gluon ID for the given address.
func (user *User) SetGluonID(addrID, gluonID string) error {
return user.vault.SetGluonID(addrID, gluonID)
}
// GluonKey returns the user's gluon key from the vault.
func (user *User) GluonKey() []byte {
return user.vault.GluonKey()
}
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
func (user *User) BridgePass() string {
return user.vault.BridgePass()
}
// UsedSpace returns the total space used by the user on the API.
func (user *User) UsedSpace() int {
return user.apiUser.UsedSpace
}
// MaxSpace returns the amount of space the user can use on the API.
func (user *User) MaxSpace() int {
return user.apiUser.MaxSpace
}
// GetNotifyCh returns a channel which notifies of events happening to the user (such as deauth, address change)
func (user *User) GetNotifyCh() <-chan events.Event {
return user.notifyCh
// HasSync returns whether the user has finished syncing.
func (user *User) HasSync() bool {
return user.vault.HasSync()
}
func (user *User) NewGluonConnector(ctx context.Context) (connector.Connector, error) {
if user.imapConn != nil {
if err := user.imapConn.Close(ctx); err != nil {
return nil, err
// AbortSync aborts any ongoing sync.
// TODO: This should abort the sync rather than just waiting.
// Should probably be done automatically when one of the user's IMAP connectors is closed.
func (user *User) AbortSync(ctx context.Context) error {
user.syncWG.Wait()
return nil
}
// DoSync performs a sync for the user.
func (user *User) DoSync(ctx context.Context) <-chan error {
errCh := queue.NewQueuedChannel[error](0, 0)
user.syncWG.Go(func() {
defer errCh.Close()
user.eventCh.Enqueue(events.SyncStarted{
UserID: user.ID(),
})
errCh.Enqueue(func() error {
if err := user.syncLabels(ctx, maps.Keys(user.updateCh)...); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := user.syncMessages(ctx); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
user.syncWait()
if err := user.vault.SetSync(true); err != nil {
return fmt.Errorf("failed to set sync status: %w", err)
}
return nil
}())
user.eventCh.Enqueue(events.SyncFinished{
UserID: user.ID(),
})
})
return errCh.GetChannel()
}
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change)
func (user *User) GetEventCh() <-chan events.Event {
return user.eventCh.GetChannel()
}
// NewIMAPConnector returns an IMAP connector for the given address.
// If not in split mode, this function returns an error.
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
var emails []string
switch user.vault.AddressMode() {
case vault.CombinedMode:
if addrID != user.apiAddrs.primary() {
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
}
emails = user.apiAddrs.emails()
case vault.SplitMode:
emails = []string{user.apiAddrs.email(addrID)}
}
user.imapConn = newIMAPConnector(user.client, user.updateCh, user.Addresses(), user.vault.BridgePass())
return user.imapConn, nil
return newIMAPConnector(
user.client,
user.updateCh[addrID].GetChannel(),
user.vault.BridgePass(),
emails...,
), nil
}
func (user *User) NewSMTPSession(username string) (smtp.Session, error) {
return newSMTPSession(user.client, username, user.addresses, user.userKR, user.addrKRs, user.settings), nil
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
// In combined mode, this is just the user's primary address.
// In split mode, this is all the user's addresses.
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
imapConn := make(map[string]connector.Connector)
for addrID := range user.updateCh {
conn, err := user.NewIMAPConnector(addrID)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP connector: %w", err)
}
imapConn[addrID] = conn
}
return imapConn, nil
}
// NewSMTPSession returns an SMTP session for the user.
func (user *User) NewSMTPSession(username string) smtp.Session {
return newSMTPSession(user.client, username, user.apiAddrs.addrMap(), user.settings, user.userKR, user.addrKRs)
}
// Logout logs the user out from the API.
func (user *User) Logout(ctx context.Context) error {
return user.client.AuthDelete(ctx)
}
// Close closes ongoing connections and cleans up resources.
func (user *User) Close(ctx context.Context) error {
// Close the user's IMAP connectors.
if user.imapConn != nil {
if err := user.imapConn.Close(ctx); err != nil {
return err
}
}
// Wait for ongoing syncs to finish.
user.syncWG.Wait()
// Close the user's message builder.
user.builder.Done()
@ -205,15 +339,13 @@ func (user *User) Close(ctx context.Context) error {
// Close the user's API client.
user.client.Close()
// Close the user's update channels.
for _, updateCh := range user.updateCh {
updateCh.Close()
}
// Close the user's notify channel.
close(user.notifyCh)
user.eventCh.Close()
return nil
}
// sort returns the slice, sorted by the given callback.
func sort[T any](slice []T, less func(a, b T) bool) []T {
slices.SortFunc(slice, less)
return slice
}

162
internal/user/user_test.go Normal file
View File

@ -0,0 +1,162 @@
package user_test
import (
"context"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/tests"
"github.com/bradenaw/juniper/iterator"
"github.com/emersion/go-imap"
"github.com/emersion/go-imap/client"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/account"
)
func init() {
user.DefaultEventPeriod = 100 * time.Millisecond
user.DefaultEventJitter = 0
account.GenerateKey = tests.FastGenerateKey
certs.GenerateCert = tests.FastGenerateCert
}
func TestUser_Data(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
// User's ID should be correct.
require.Equal(t, userID, user.ID())
// User's name should be correct.
require.Equal(t, "username", user.Name())
// User's email should be correct.
require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails())
// By default, user should be in combined mode.
require.Equal(t, vault.CombinedMode, user.GetAddressMode())
// By default, user should have a non-empty bridge password.
require.NotEmpty(t, user.BridgePass())
})
})
}
func TestUser_Sync(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
// Get the user's IMAP connectors.
imapConn, err := user.NewIMAPConnectors()
require.NoError(t, err)
// Pretend to be gluon applying all the updates.
go func() {
for _, imapConn := range imapConn {
for update := range imapConn.GetUpdates() {
update.Done()
}
}
}()
// Trigger a user sync.
errCh := user.DoSync(ctx)
// User starts a sync at startup.
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
// User finishes a sync at startup.
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
// The sync completes without error.
require.NoError(t, <-errCh)
})
})
}
func TestUser_Deauth(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
eventCh := user.GetEventCh()
// Revoke the user's auth token.
require.NoError(t, s.RevokeUser(userID))
// The user should eventually be logged out.
require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserDeauth); return ok }, 5*time.Second, 100*time.Millisecond)
})
})
}
func withAPI(t *testing.T, ctx context.Context, username, password string, emails []string, fn func(context.Context, *server.Server, string, []string)) {
server := server.New()
defer server.Close()
var addrIDs []string
userID, addrID, err := server.AddUser(username, password, emails[0])
require.NoError(t, err)
addrIDs = append(addrIDs, addrID)
for _, email := range emails[1:] {
addrID, err := server.AddAddress(userID, email, password)
require.NoError(t, err)
addrIDs = append(addrIDs, addrID)
}
fn(ctx, server, userID, addrIDs)
}
func withUser(t *testing.T, ctx context.Context, apiURL, username, password string, fn func(*user.User)) {
c, apiAuth, err := liteapi.New(liteapi.WithHostURL(apiURL)).NewClientWithLogin(ctx, username, password)
require.NoError(t, err)
defer func() { require.NoError(t, c.Close()) }()
apiUser, apiAddrs, userKR, addrKRs, passphrase, err := c.Unlock(ctx, []byte(password))
require.NoError(t, err)
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
require.NoError(t, err)
require.False(t, corrupt)
vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, passphrase)
require.NoError(t, err)
user, err := user.New(ctx, vaultUser, c, apiUser, apiAddrs, userKR, addrKRs)
require.NoError(t, err)
defer func() { require.NoError(t, user.Close(ctx)) }()
fn(user)
}
func withIMAPClient(t *testing.T, addr string, fn func(*client.Client)) {
c, err := client.Dial(addr)
require.NoError(t, err)
defer c.Close()
fn(c)
}
func fetch(t *testing.T, c *client.Client, seqset string, items ...imap.FetchItem) []*imap.Message {
msgCh := make(chan *imap.Message)
go func() {
require.NoError(t, c.Fetch(must(imap.ParseSeqSet(seqset)), items, msgCh))
}()
return iterator.Collect(iterator.Chan(msgCh))
}
func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}

View File

@ -5,9 +5,14 @@ import (
)
// RandomToken is a function that returns a random token.
var RandomToken func(size int) ([]byte, error)
// By default, we use crypto.RandomToken to generate tokens.
func init() {
RandomToken = crypto.RandomToken
var RandomToken = crypto.RandomToken
func newRandomToken(size int) []byte {
token, err := RandomToken(size)
if err != nil {
panic(err)
}
return token
}

View File

@ -4,6 +4,7 @@ import (
"math/rand"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
)
@ -45,15 +46,24 @@ type Settings struct {
FirstStartGUI bool
}
type AddressMode int
const (
CombinedMode AddressMode = iota
SplitMode
)
// UserData holds information about a single bridge user.
// The user may or may not be logged in.
type UserData struct {
UserID string
Username string
GluonID string
GluonKey []byte
BridgePass string
GluonKey []byte
GluonIDs map[string]string
UIDValidity map[string]imap.UID
BridgePass []byte
AddressMode AddressMode
AuthUID string
AuthRef string

View File

@ -1,5 +1,11 @@
package vault
import (
"encoding/hex"
"github.com/ProtonMail/gluon/imap"
)
type User struct {
vault *Vault
userID string
@ -13,16 +19,41 @@ func (user *User) Username() string {
return user.vault.getUser(user.userID).Username
}
func (user *User) GluonID() string {
return user.vault.getUser(user.userID).GluonID
func (user *User) GetGluonIDs() map[string]string {
return user.vault.getUser(user.userID).GluonIDs
}
func (user *User) SetGluonID(addrID, gluonID string) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.GluonIDs[addrID] = gluonID
})
}
func (user *User) GetUIDValidity(addrID string) (imap.UID, bool) {
validity, ok := user.vault.getUser(user.userID).UIDValidity[addrID]
if !ok {
return imap.UID(0), false
}
return validity, true
}
func (user *User) SetUIDValidity(addrID string, validity imap.UID) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.UIDValidity[addrID] = validity
})
}
func (user *User) GluonKey() []byte {
return user.vault.getUser(user.userID).GluonKey
}
func (user *User) AddressMode() AddressMode {
return user.vault.getUser(user.userID).AddressMode
}
func (user *User) BridgePass() string {
return user.vault.getUser(user.userID).BridgePass
return hex.EncodeToString(user.vault.getUser(user.userID).BridgePass)
}
func (user *User) AuthUID() string {
@ -51,7 +82,7 @@ func (user *User) SetKeyPass(keyPass []byte) error {
})
}
// SetAuth updates the auth secrets for the given user.
// SetAuth sets the auth secrets for the given user.
func (user *User) SetAuth(authUID, authRef string) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.AuthUID = authUID
@ -59,33 +90,23 @@ func (user *User) SetAuth(authUID, authRef string) error {
})
}
// SetGluonAuth updates the gluon ID and key for the given user.
func (user *User) SetGluonAuth(gluonID string, gluonKey []byte) error {
// SetAddressMode sets the address mode for the given user.
func (user *User) SetAddressMode(mode AddressMode) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.GluonID = gluonID
data.GluonKey = gluonKey
data.AddressMode = mode
})
}
// SetEventID updates the event ID for the given user.
// SetEventID sets the event ID for the given user.
func (user *User) SetEventID(eventID string) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.EventID = eventID
})
}
// SetSync updates the sync state for the given user.
// SetSync sets the sync state for the given user.
func (user *User) SetSync(hasSync bool) error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.HasSync = hasSync
})
}
// Clear clears the secrets for the given user.
func (user *User) Clear() error {
return user.vault.modUser(user.userID, func(data *UserData) {
data.AuthUID = ""
data.AuthRef = ""
data.KeyPass = nil
})
}

View File

@ -4,6 +4,7 @@ import (
"encoding/hex"
"testing"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/stretchr/testify/require"
)
@ -32,30 +33,48 @@ func TestUser(t *testing.T) {
require.NoError(t, user2.SetSync(false))
// Set gluon data for user 1 and 2.
require.NoError(t, user1.SetGluonAuth("gluonID1", []byte("gluonKey1")))
require.NoError(t, user2.SetGluonAuth("gluonID2", []byte("gluonKey2")))
require.NoError(t, user1.SetGluonID("addrID1", "gluonID1"))
require.NoError(t, user2.SetGluonID("addrID2", "gluonID2"))
require.NoError(t, user1.SetUIDValidity("addrID1", imap.UID(1)))
require.NoError(t, user2.SetUIDValidity("addrID2", imap.UID(2)))
// List available users.
require.ElementsMatch(t, []string{"userID1", "userID2"}, s.GetUserIDs())
// Check gluon information for user 1.
gluonID1, ok := user1.GetGluonIDs()["addrID1"]
require.True(t, ok)
require.Equal(t, "gluonID1", gluonID1)
uidValidity1, ok := user1.GetUIDValidity("addrID1")
require.True(t, ok)
require.Equal(t, imap.UID(1), uidValidity1)
require.NotEmpty(t, user1.GluonKey())
// Get auth information for user 1.
require.Equal(t, "userID1", user1.UserID())
require.Equal(t, "user1", user1.Username())
require.Equal(t, "gluonID1", user1.GluonID())
require.Equal(t, []byte("gluonKey1"), user1.GluonKey())
require.Equal(t, hex.EncodeToString([]byte("token")), user1.BridgePass())
require.Equal(t, vault.CombinedMode, user1.AddressMode())
require.Equal(t, "authUID1", user1.AuthUID())
require.Equal(t, "authRef1", user1.AuthRef())
require.Equal(t, []byte("keyPass1"), user1.KeyPass())
require.Equal(t, "eventID1", user1.EventID())
require.Equal(t, true, user1.HasSync())
// Check gluon information for user 1.
gluonID2, ok := user2.GetGluonIDs()["addrID2"]
require.True(t, ok)
require.Equal(t, "gluonID2", gluonID2)
uidValidity2, ok := user2.GetUIDValidity("addrID2")
require.True(t, ok)
require.Equal(t, imap.UID(2), uidValidity2)
require.NotEmpty(t, user2.GluonKey())
// Get auth information for user 2.
require.Equal(t, "userID2", user2.UserID())
require.Equal(t, "user2", user2.Username())
require.Equal(t, "gluonID2", user2.GluonID())
require.Equal(t, []byte("gluonKey2"), user2.GluonKey())
require.Equal(t, hex.EncodeToString([]byte("token")), user2.BridgePass())
require.Equal(t, vault.CombinedMode, user2.AddressMode())
require.Equal(t, "authUID2", user2.AuthUID())
require.Equal(t, "authRef2", user2.AuthRef())
require.Equal(t, []byte("keyPass2"), user2.KeyPass())
@ -63,8 +82,8 @@ func TestUser(t *testing.T) {
require.Equal(t, false, user2.HasSync())
// Clear the users.
require.NoError(t, user1.Clear())
require.NoError(t, user2.Clear())
require.NoError(t, s.ClearUser("userID1"))
require.NoError(t, s.ClearUser("userID2"))
// Their secrets should now be cleared.
require.Equal(t, "", user1.AuthUID())

View File

@ -4,7 +4,6 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"io/fs"
@ -12,6 +11,7 @@ import (
"os"
"path/filepath"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/bradenaw/juniper/xslices"
)
@ -99,16 +99,16 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
return nil, errors.New("user already exists")
}
tok, err := RandomToken(16)
if err != nil {
return nil, err
}
if err := vault.mod(func(data *Data) {
data.Users = append(data.Users, UserData{
UserID: userID,
Username: username,
BridgePass: hex.EncodeToString(tok),
UserID: userID,
Username: username,
GluonKey: newRandomToken(32),
GluonIDs: make(map[string]string),
UIDValidity: make(map[string]imap.UID),
BridgePass: newRandomToken(16),
AddressMode: CombinedMode,
AuthUID: authUID,
AuthRef: authRef,
@ -121,6 +121,14 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
return vault.GetUser(userID)
}
func (vault *Vault) ClearUser(userID string) error {
return vault.modUser(userID, func(data *UserData) {
data.AuthUID = ""
data.AuthRef = ""
data.KeyPass = nil
})
}
// DeleteUser removes the given user from the vault.
func (vault *Vault) DeleteUser(userID string) error {
return vault.mod(func(data *Data) {